From bd92659f7d51bfeb9a05e20921088af81888a80e Mon Sep 17 00:00:00 2001 From: smoczy123 Date: Mon, 6 Jan 2025 23:34:53 +0100 Subject: [PATCH] Added vector type deserialization --- scylla-cql/src/frame/frame_errors.rs | 4 + scylla-cql/src/frame/response/mod.rs | 1 + scylla-cql/src/frame/response/result.rs | 83 ++++- scylla-cql/src/frame/response/type_parser.rs | 267 ++++++++++++++ scylla-cql/src/frame/types.rs | 2 +- scylla-cql/src/frame/value.rs | 4 + .../src/types/deserialize/frame_slice.rs | 29 ++ scylla-cql/src/types/deserialize/value.rs | 341 +++++++++++++++++- 8 files changed, 712 insertions(+), 19 deletions(-) create mode 100644 scylla-cql/src/frame/response/type_parser.rs diff --git a/scylla-cql/src/frame/frame_errors.rs b/scylla-cql/src/frame/frame_errors.rs index 2eac4c50f6..6540603af3 100644 --- a/scylla-cql/src/frame/frame_errors.rs +++ b/scylla-cql/src/frame/frame_errors.rs @@ -454,6 +454,8 @@ pub enum CqlTypeParseError { TupleLengthParseError(LowLevelDeserializationError), #[error("CQL Type not yet implemented, id: {0}")] TypeNotImplemented(u16), + #[error("Failed to parse abstract type")] + AbstractTypeParseError(), } /// A low level deserialization error. @@ -485,6 +487,8 @@ pub enum LowLevelDeserializationError { InvalidInetLength(u8), #[error("UTF8 deserialization failed: {0}")] UTF8DeserializationError(#[from] std::str::Utf8Error), + #[error(transparent)] + ParseIntError(#[from] std::num::ParseIntError), } impl From for LowLevelDeserializationError { diff --git a/scylla-cql/src/frame/response/mod.rs b/scylla-cql/src/frame/response/mod.rs index 08ade35c30..cdb25065e0 100644 --- a/scylla-cql/src/frame/response/mod.rs +++ b/scylla-cql/src/frame/response/mod.rs @@ -4,6 +4,7 @@ pub mod error; pub mod event; pub mod result; pub mod supported; +pub mod type_parser; use std::sync::Arc; diff --git a/scylla-cql/src/frame/response/result.rs b/scylla-cql/src/frame/response/result.rs index 02620de536..72e3e2f780 100644 --- a/scylla-cql/src/frame/response/result.rs +++ b/scylla-cql/src/frame/response/result.rs @@ -1,3 +1,4 @@ +use super::type_parser; #[allow(deprecated)] use crate::cql_to_rust::{FromRow, FromRowError}; use crate::frame::frame_errors::{ @@ -16,7 +17,8 @@ use crate::frame::value::{ use crate::types::deserialize::result::{RawRowIterator, TypedRowIterator}; use crate::types::deserialize::row::DeserializeRow; use crate::types::deserialize::value::{ - mk_deser_err, BuiltinDeserializationErrorKind, DeserializeValue, MapIterator, UdtIterator, + mk_deser_err, BuiltinDeserializationErrorKind, ConstLengthVectorIterator, DeserializeValue, + MapIterator, UdtIterator, VariableLengthVectorIterator, }; use crate::types::deserialize::{DeserializationError, FrameSlice, TypeCheckError}; use bytes::{Buf, Bytes}; @@ -81,6 +83,7 @@ pub enum ColumnType<'frame> { Tuple(Vec>), Uuid, Varint, + Vector(Box>, u32), } impl ColumnType<'_> { @@ -128,6 +131,43 @@ impl ColumnType<'_> { } ColumnType::Uuid => ColumnType::Uuid, ColumnType::Varint => ColumnType::Varint, + ColumnType::Vector(elem_type, dim) => { + ColumnType::Vector(Box::new(elem_type.into_owned()), dim) + } + } + } + + pub fn type_size(&self) -> Option { + match self { + ColumnType::Custom(_) => None, + ColumnType::Ascii => None, + ColumnType::Boolean => Some(1), + ColumnType::Blob => None, + ColumnType::Counter => None, + ColumnType::Date => Some(8), + ColumnType::Decimal => None, + ColumnType::Double => Some(8), + ColumnType::Duration => None, + ColumnType::Float => Some(4), + ColumnType::Int => Some(4), + ColumnType::BigInt => Some(8), + ColumnType::Text => None, + ColumnType::Timestamp => Some(8), + ColumnType::Inet => None, + ColumnType::List(_) => None, + ColumnType::Map(_, _) => None, + ColumnType::Set(_) => None, + ColumnType::UserDefinedType { .. } => None, + // Note that although SmallInt and TinyInt is of a fixed size, + // Cassandra (erroneously) treats it as a variable-size + ColumnType::SmallInt => None, + ColumnType::TinyInt => None, + ColumnType::Time => Some(8), + ColumnType::Timeuuid => Some(16), + ColumnType::Tuple(_) => None, + ColumnType::Uuid => Some(16), + ColumnType::Varint => None, + ColumnType::Vector(elem_type, _) => None, } } } @@ -171,6 +211,7 @@ pub enum CqlValue { Tuple(Vec>), Uuid(Uuid), Varint(CqlVarint), + Vector(Vec), } impl<'a> TableSpec<'a> { @@ -438,10 +479,18 @@ impl CqlValue { } } + pub fn as_vector(&self) -> Option<&Vec> { + match self { + Self::Vector(s) => Some(s), + _ => None, + } + } + pub fn into_vec(self) -> Option> { match self { Self::List(s) => Some(s), Self::Set(s) => Some(s), + Self::Vector(s) => Some(s), _ => None, } } @@ -864,17 +913,12 @@ fn deser_type_generic<'frame, 'result, StrT: Into>>( types::read_short(buf).map_err(|err| CqlTypeParseError::TypeIdParseError(err.into()))?; Ok(match id { 0x0000 => { - // We use types::read_string instead of read_string argument here on purpose. - // Chances are the underlying string is `...DurationType`, in which case - // we don't need to allocate it at all. Only for Custom types - // (which we don't support anyway) do we need to allocate. - // OTOH, the provided `read_string` function deserializes borrowed OR owned string; - // here we want to always deserialize borrowed string. - let type_str = - types::read_string(buf).map_err(CqlTypeParseError::CustomTypeNameParseError)?; - match type_str { - "org.apache.cassandra.db.marshal.DurationType" => Duration, - _ => Custom(type_str.to_owned().into()), + let type_str = read_string(buf).map_err(CqlTypeParseError::CustomTypeNameParseError)?; + let type_cow: Cow<'result, str> = type_str.into(); + if let Ok(typ) = type_parser::TypeParser::parse(&type_cow) { + typ + } else { + Ascii } } 0x0001 => Ascii, @@ -1405,6 +1449,16 @@ pub fn deser_cql_value( .collect::>()?; CqlValue::Tuple(t) } + Vector(elem_type, _dimensions) if elem_type.type_size().is_some() => { + let v = ConstLengthVectorIterator::<'_, '_, CqlValue>::deserialize(typ, v)?; + let v: Vec = v.collect::>()?; + CqlValue::Vector(v) + } + Vector(_, _) => { + let v = VariableLengthVectorIterator::<'_, '_, CqlValue>::deserialize(typ, v)?; + let v: Vec = v.collect::>()?; + CqlValue::Vector(v) + } }) } @@ -1525,6 +1579,7 @@ mod test_utils { Self::Set(_) => 0x0022, Self::UserDefinedType { .. } => 0x0030, Self::Tuple(_) => 0x0031, + Self::Vector(_, _) => 0x0000, } } @@ -1563,6 +1618,10 @@ mod test_utils { ColumnType::List(elem_type) | ColumnType::Set(elem_type) => { elem_type.serialize(buf)?; } + ColumnType::Vector(elem_type, dimensions) => { + elem_type.serialize(buf)?; + types::write_short_length(*dimensions as usize, buf)?; + } ColumnType::Map(key_type, value_type) => { key_type.serialize(buf)?; value_type.serialize(buf)?; diff --git a/scylla-cql/src/frame/response/type_parser.rs b/scylla-cql/src/frame/response/type_parser.rs new file mode 100644 index 0000000000..289d9e97ca --- /dev/null +++ b/scylla-cql/src/frame/response/type_parser.rs @@ -0,0 +1,267 @@ +use crate::frame::frame_errors::CqlTypeParseError; +use std::{borrow::Cow, char}; + +use super::result::ColumnType; + +pub(crate) struct TypeParser<'result> { + pos: usize, + str: &'result str, +} + +impl<'result> TypeParser<'result> { + fn new(str: &str) -> TypeParser { + TypeParser { pos: 0, str } + } + + pub(crate) fn parse(str: &'result str) -> Result, CqlTypeParseError> { + let mut parser = TypeParser::new(str); + parser.do_parse() + } + + fn is_eos(&self) -> bool { + self.pos >= self.str.len() + } + + fn is_blank(c: char) -> bool { + c == ' ' || c == '\t' || c == '\n' + } + + fn is_identifier_char(c: char) -> bool { + c.is_alphanumeric() || c == '+' || c == '-' || c == '_' || c == '.' || c == '&' + } + + fn read_next_identifier(&mut self) -> &'result str { + let start = self.pos; + while !self.is_eos() + && TypeParser::is_identifier_char(self.str.as_bytes()[self.pos] as char) + { + self.pos += 1; + } + &self.str[start..self.pos] + } + + fn skip_blank(&mut self) -> usize { + while !self.is_eos() && TypeParser::is_blank(self.str.as_bytes()[self.pos] as char) { + self.pos += 1; + } + self.pos + } + + fn skip_blank_and_comma(&mut self) -> bool { + let mut comma_found = false; + while !self.is_eos() { + let c = self.str.as_bytes()[self.pos] as char; + if c == ',' { + if comma_found { + return true; + } else { + comma_found = true; + } + } else if !TypeParser::is_blank(c) { + return true; + } + self.pos += 1; + } + return false; + } + + fn get_simple_abstract_type(name: &str) -> Result { + let string_class_name: String; + let class_name: &str; + if name.contains("org.apache.cassandra.db.marshal.") { + class_name = name + } else { + string_class_name = "org.apache.cassandra.db.marshal.".to_owned() + name; + class_name = &string_class_name; + } + + match class_name { + "org.apache.cassandra.db.marshal.AsciiType" => Ok(ColumnType::Ascii), + "org.apache.cassandra.db.marshal.BooleanType" => Ok(ColumnType::Boolean), + "org.apache.cassandra.db.marshal.BytesType" => Ok(ColumnType::Blob), + "org.apache.cassandra.db.marshal.CounterColumnType" => Ok(ColumnType::Counter), + "org.apache.cassandra.db.marshal.DateType" => Ok(ColumnType::Date), + "org.apache.cassandra.db.marshal.DecimalType" => Ok(ColumnType::Decimal), + "org.apache.cassandra.db.marshal.DoubleType" => Ok(ColumnType::Double), + "org.apache.cassandra.db.marshal.DurationType" => Ok(ColumnType::Duration), + "org.apache.cassandra.db.marshal.FloatType" => Ok(ColumnType::Float), + "org.apache.cassandra.db.marshal.InetAddressType" => Ok(ColumnType::Inet), + "org.apache.cassandra.db.marshal.Int32Type" => Ok(ColumnType::Int), + "org.apache.cassandra.db.marshal.IntegerType" => Ok(ColumnType::Varint), + "org.apache.cassandra.db.marshal.LongType" => Ok(ColumnType::BigInt), + "org.apache.cassandra.db.marshal.SimpleDateType" => Ok(ColumnType::Date), + "org.apache.cassandra.db.marshal.ShortType" => Ok(ColumnType::SmallInt), + "org.apache.cassandra.db.marshal.UTF8Type" => Ok(ColumnType::Text), + "org.apache.cassandra.db.marshal.ByteType" => Ok(ColumnType::TinyInt), + "org.apache.cassandra.db.marshal.UUIDType" => Ok(ColumnType::Uuid), + "org.apache.cassandra.db.marshal.TimeUUIDType" => Ok(ColumnType::Timeuuid), + "org.apache.cassandra.db.marshal.SmallIntType" => Ok(ColumnType::SmallInt), + "org.apache.cassandra.db.marshal.TinyIntType" => Ok(ColumnType::TinyInt), + "org.apache.cassandra.db.marshal.TimeType" => Ok(ColumnType::Time), + "org.apache.cassandra.db.marshal.TimestampType" => Ok(ColumnType::Timestamp), + _ => Err(CqlTypeParseError::AbstractTypeParseError()), + } + } + + fn get_type_parameters(&mut self) -> Result>, CqlTypeParseError> { + let mut parameters = Vec::new(); + if self.is_eos() { + return Ok(parameters); + } + if self.str.as_bytes()[self.pos] != '(' as u8 { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + self.pos += 1; + loop { + if !self.skip_blank_and_comma() { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + if self.str.as_bytes()[self.pos] == ')' as u8 { + self.pos += 1; + return Ok(parameters); + } + let typ = self.do_parse()?; + parameters.push(typ); + } + } + + fn get_vector_parameters(&mut self) -> Result<(ColumnType<'result>, u32), CqlTypeParseError> { + if self.is_eos() || self.str.as_bytes()[self.pos] != '(' as u8 { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + self.pos += 1; + self.skip_blank_and_comma(); + if self.str.as_bytes()[self.pos] == ')' as u8 { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + + let typ = self.do_parse()?; + let start = self.pos; + while !self.is_eos() && char::is_numeric(self.str.as_bytes()[self.pos] as char) { + self.pos += 1; + } + let len = u32::from_str_radix(&self.str[start..self.pos], 10) + .map_err(|_| CqlTypeParseError::AbstractTypeParseError())?; + if self.is_eos() || self.str.as_bytes()[self.pos] != ')' as u8 { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + self.pos += 1; + Ok((typ, len)) + } + + fn from_hex(s: &str) -> Result, CqlTypeParseError> { + if s.len() % 2 != 0 { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + for c in s.chars() { + if !c.is_digit(16) { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + } + let mut bytes = Vec::new(); + for i in 0..s.len() / 2 { + let byte = u8::from_str_radix(&s[i * 2..i * 2 + 2], 16) + .map_err(|_| CqlTypeParseError::AbstractTypeParseError())?; + bytes.push(byte); + } + Ok(bytes) + } + + fn get_udt_parameters( + &mut self, + ) -> Result< + ( + Cow<'result, str>, + Cow<'result, str>, + Vec<(Cow<'result, str>, ColumnType<'result>)>, + ), + CqlTypeParseError, + > { + unimplemented!("get_udt_parameters"); + } + + fn get_complex_abstract_type( + &mut self, + name: &str, + ) -> Result, CqlTypeParseError> { + let string_class_name: String; + let class_name: &str; + if name.contains("org.apache.cassandra.db.marshal.") { + class_name = name + } else { + string_class_name = "org.apache.cassandra.db.marshal.".to_owned() + name; + class_name = &string_class_name; + } + match class_name { + "org.apache.cassandra.db.marshal.ListType" => { + let mut params = self.get_type_parameters()?; + if params.len() != 1 { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + Ok(ColumnType::List(Box::new(params.remove(0)))) + } + "org.apache.cassandra.db.marshal.SetType" => { + let mut params = self.get_type_parameters()?; + if params.len() != 1 { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + Ok(ColumnType::Set(Box::new(params.remove(0)))) + } + "org.apache.cassandra.db.marshal.MapType" => { + let mut params = self.get_type_parameters()?; + if params.len() != 2 { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + Ok(ColumnType::Map( + Box::new(params.remove(0)), + Box::new(params.remove(0)), + )) + } + "org.apache.cassandra.db.marshal.TupleType" => { + let params = self.get_type_parameters()?; + if params.is_empty() { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + Ok(ColumnType::Tuple(params)) + } + "org.apache.cassandra.db.marshal.VectorType" => { + let (typ, len) = self.get_vector_parameters()?; + Ok(ColumnType::Vector(Box::new(typ), len)) + } + "org.apache.cassandra.db.marshal.UserType" => { + let (keyspace, name, fields) = self.get_udt_parameters()?; + Ok(ColumnType::UserDefinedType { + type_name: name, + keyspace, + field_types: fields, + }) + } + _ => return Err(CqlTypeParseError::AbstractTypeParseError()), + } + } + + fn do_parse(&mut self) -> Result, CqlTypeParseError> { + self.skip_blank(); + + let mut name = self.read_next_identifier(); + if name.is_empty() { + if !self.is_eos() { + return Err(CqlTypeParseError::AbstractTypeParseError()); + } + return Ok(ColumnType::Blob); + } + + if self.str.as_bytes()[self.pos] == ':' as u8 { + self.pos += 1; + let _ = usize::from_str_radix(name, 16) + .map_err(|_| CqlTypeParseError::AbstractTypeParseError()); + name = self.read_next_identifier(); + } + self.skip_blank(); + if !self.is_eos() && self.str.as_bytes()[self.pos] == '(' as u8 { + return Ok(self.get_complex_abstract_type(name)?); + } else { + return Ok(TypeParser::get_simple_abstract_type(name)?); + } + } +} diff --git a/scylla-cql/src/frame/types.rs b/scylla-cql/src/frame/types.rs index 2ea5a8b6b3..179ba82751 100644 --- a/scylla-cql/src/frame/types.rs +++ b/scylla-cql/src/frame/types.rs @@ -145,7 +145,7 @@ impl<'a> RawValue<'a> { } } -fn read_raw_bytes<'a>( +pub fn read_raw_bytes<'a>( count: usize, buf: &mut &'a [u8], ) -> Result<&'a [u8], LowLevelDeserializationError> { diff --git a/scylla-cql/src/frame/value.rs b/scylla-cql/src/frame/value.rs index 3cf6f28cbc..c23d52210f 100644 --- a/scylla-cql/src/frame/value.rs +++ b/scylla-cql/src/frame/value.rs @@ -1534,6 +1534,10 @@ mod legacy { CqlValue::Map(m) => serialize_map(m.iter().map(|p| (&p.0, &p.1)), m.len(), buf), CqlValue::Tuple(t) => serialize_tuple(t.iter(), buf), + CqlValue::Vector(v) => { + unimplemented!("Vector serialization is not implemented yet"); + } + // A UDT value is composed of successive [bytes] values, one for each field of the UDT // value (in the order defined by the type), so they serialize in a same way tuples do. CqlValue::UserDefinedType { fields, .. } => { diff --git a/scylla-cql/src/types/deserialize/frame_slice.rs b/scylla-cql/src/types/deserialize/frame_slice.rs index 4471960a03..b628ea0225 100644 --- a/scylla-cql/src/types/deserialize/frame_slice.rs +++ b/scylla-cql/src/types/deserialize/frame_slice.rs @@ -155,6 +155,35 @@ impl<'frame> FrameSlice<'frame> { original_frame: self.original_frame, })) } + + /// Reads and consumes a fixed number of bytes item from the beginning of the frame, + /// returning a subslice that encompasses that item. + /// + /// If this slice is empty, returns `Ok(None)`. + /// Otherwise, if the slice does not contain enough data, it returns `Err`. + /// If the operation fails then the slice remains unchanged. + #[inline] + pub(super) fn read_subslice( + &mut self, + count: usize, + ) -> Result>, LowLevelDeserializationError> { + if self.is_empty() { + return Ok(None); + } + + // We copy the slice reference, not to mutate the FrameSlice in case of an error. + let mut slice = self.frame_subslice; + + let cql_bytes = types::read_raw_bytes(count, &mut slice)?; + + // `read_raw_bytes` hasn't failed, so now we must update the FrameSlice. + self.frame_subslice = slice; + + Ok(Some(Self { + frame_subslice: cql_bytes, + original_frame: self.original_frame, + })) + } } #[cfg(test)] diff --git a/scylla-cql/src/types/deserialize/value.rs b/scylla-cql/src/types/deserialize/value.rs index f1979ab63e..ff6d41cfaf 100644 --- a/scylla-cql/src/types/deserialize/value.rs +++ b/scylla-cql/src/types/deserialize/value.rs @@ -810,18 +810,52 @@ where T: DeserializeValue<'frame, 'metadata>, { fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { - // It makes sense for both Set and List to deserialize to Vec. - ListlikeIterator::<'frame, 'metadata, T>::type_check(typ) - .map_err(typck_error_replace_rust_name::) + // It makes sense for both Set, List and Vector to deserialize to Vec. + match typ { + ColumnType::List(_) | ColumnType::Set(_) => { + ListlikeIterator::<'frame, 'metadata, T>::type_check(typ) + .map_err(typck_error_replace_rust_name::) + } + ColumnType::Vector(el_t, _) if el_t.type_size().is_some() => { + ConstLengthVectorIterator::<'frame, 'metadata, T>::type_check(typ) + .map_err(typck_error_replace_rust_name::) + } + ColumnType::Vector(_, _) => { + VariableLengthVectorIterator::<'frame, 'metadata, T>::type_check(typ) + .map_err(typck_error_replace_rust_name::) + } + _ => Err(mk_typck_err::( + typ, + BuiltinTypeCheckErrorKind::SetOrListError( + SetOrListTypeCheckErrorKind::NotSetOrList, + ), + )), + } } fn deserialize( typ: &'metadata ColumnType<'metadata>, v: Option>, ) -> Result { - ListlikeIterator::<'frame, 'metadata, T>::deserialize(typ, v) - .and_then(|it| it.collect::>()) - .map_err(deser_error_replace_rust_name::) + match typ { + ColumnType::List(_) | ColumnType::Set(_) => { + ListlikeIterator::<'frame, 'metadata, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } + + ColumnType::Vector(el_t, _) if el_t.type_size().is_some() => { + ConstLengthVectorIterator::<'frame, 'metadata, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } + ColumnType::Vector(_, _) => { + VariableLengthVectorIterator::<'frame, 'metadata, T>::deserialize(typ, v) + .and_then(|it| it.collect::>()) + .map_err(deser_error_replace_rust_name::) + } + _ => unreachable!("Should be prevented by typecheck"), + } } } @@ -880,6 +914,197 @@ where } } +pub struct ConstLengthVectorIterator<'frame, 'metadata, T> { + coll_typ: &'metadata ColumnType<'metadata>, + elem_typ: &'metadata ColumnType<'metadata>, + count: usize, + raw_iter: VectorBytesSequenceIterator<'frame>, + phantom_data: std::marker::PhantomData, +} + +impl<'frame, 'metadata, T> ConstLengthVectorIterator<'frame, 'metadata, T> { + fn new( + coll_typ: &'metadata ColumnType<'metadata>, + elem_typ: &'metadata ColumnType<'metadata>, + count: usize, + elem_len: usize, + slice: FrameSlice<'frame>, + ) -> Self { + Self { + coll_typ, + elem_typ, + count, + raw_iter: VectorBytesSequenceIterator::new(count, elem_len, slice), + phantom_data: std::marker::PhantomData, + } + } +} + +impl<'frame, 'metadata, T> DeserializeValue<'frame, 'metadata> + for ConstLengthVectorIterator<'frame, 'metadata, T> +where + T: DeserializeValue<'frame, 'metadata>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + match typ { + ColumnType::Vector(t, _) => { + if t.type_size().is_none() { + return Err(mk_typck_err::( + typ, + VectorTypeCheckErrorKind::ElementSizeUnknown, + )); + } + >::type_check(t).map_err(|err| { + mk_typck_err::(typ, VectorTypeCheckErrorKind::ElementTypeCheckFailed(err)) + })?; + Ok(()) + } + _ => Err(mk_typck_err::( + typ, + VectorTypeCheckErrorKind::NotVector, + )), + } + } + + fn deserialize( + typ: &'metadata ColumnType<'metadata>, + v: Option>, + ) -> Result { + let (t, dim) = match typ { + ColumnType::Vector(t, dim) => (t, dim), + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + + let v = ensure_not_null_frame_slice::(typ, v)?; + + Ok(Self::new(typ, t, *dim as usize, t.type_size().unwrap(), v)) + } +} + +impl<'frame, 'metadata, T> Iterator for ConstLengthVectorIterator<'frame, 'metadata, T> +where + T: DeserializeValue<'frame, 'metadata>, +{ + type Item = Result; + + fn next(&mut self) -> Option { + let raw = self.raw_iter.next()?.map_err(|err| { + mk_deser_err::( + self.coll_typ, + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), + ) + }); + Some(raw.and_then(|raw| { + T::deserialize(self.elem_typ, raw).map_err(|err| { + mk_deser_err::( + self.coll_typ, + VectorDeserializationErrorKind::ElementDeserializationFailed(err), + ) + }) + })) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + +pub struct VariableLengthVectorIterator<'frame, 'metadata, T> { + coll_typ: &'metadata ColumnType<'metadata>, + elem_typ: &'metadata ColumnType<'metadata>, + raw_iter: FixedLengthBytesSequenceIterator<'frame>, + phantom_data: std::marker::PhantomData, +} +impl<'frame, 'metadata, T> VariableLengthVectorIterator<'frame, 'metadata, T> { + fn new( + coll_typ: &'metadata ColumnType<'metadata>, + elem_typ: &'metadata ColumnType<'metadata>, + count: usize, + slice: FrameSlice<'frame>, + ) -> Self { + Self { + coll_typ, + elem_typ, + raw_iter: FixedLengthBytesSequenceIterator::new(count, slice), + phantom_data: std::marker::PhantomData, + } + } +} + +impl<'frame, 'metadata, T> DeserializeValue<'frame, 'metadata> + for VariableLengthVectorIterator<'frame, 'metadata, T> +where + T: DeserializeValue<'frame, 'metadata>, +{ + fn type_check(typ: &ColumnType) -> Result<(), TypeCheckError> { + match typ { + ColumnType::Vector(t, _) => { + if t.type_size().is_some() { + return Err(mk_typck_err::( + typ, + VectorTypeCheckErrorKind::ElementSizeUnknown, + )); + } + >::type_check(t).map_err(|err| { + mk_typck_err::(typ, VectorTypeCheckErrorKind::ElementTypeCheckFailed(err)) + })?; + Ok(()) + } + _ => Err(mk_typck_err::( + typ, + VectorTypeCheckErrorKind::NotVector, + )), + } + } + + fn deserialize( + typ: &'metadata ColumnType<'metadata>, + v: Option>, + ) -> Result { + let (t, dim) = match typ { + ColumnType::Vector(t, dim) => (t, dim), + _ => { + unreachable!("Typecheck should have prevented this scenario!") + } + }; + + let v = ensure_not_null_frame_slice::(typ, v)?; + + Ok(Self::new(typ, t, *dim as usize, v)) + } +} +impl<'frame, 'metadata, T> Iterator for VariableLengthVectorIterator<'frame, 'metadata, T> +where + T: DeserializeValue<'frame, 'metadata>, +{ + type Item = Result; + + fn next(&mut self) -> Option { + let raw = self.raw_iter.next()?.map_err(|err| { + mk_deser_err::( + self.coll_typ, + BuiltinDeserializationErrorKind::RawCqlBytesReadError(err), + ) + }); + Some(raw.and_then(|raw| { + T::deserialize(self.elem_typ, raw).map_err(|err| { + mk_deser_err::( + self.coll_typ, + VectorDeserializationErrorKind::ElementDeserializationFailed(err), + ) + }) + })) + } + + #[inline] + fn size_hint(&self) -> (usize, Option) { + self.raw_iter.size_hint() + } +} + /// An iterator over a CQL map. pub struct MapIterator<'frame, 'metadata, K, V> { coll_typ: &'metadata ColumnType<'metadata>, @@ -1409,6 +1634,38 @@ impl<'frame> Iterator for BytesSequenceIterator<'frame> { } } +/// Iterates over a sequence of CQL vector items from a frame subslice, expecting +/// a particular number of items. +/// +/// The iterator does not consider it to be an error if there are some bytes +/// remaining in the slice after parsing requested amount of items. +#[derive(Clone, Copy, Debug)] +pub struct VectorBytesSequenceIterator<'frame> { + slice: FrameSlice<'frame>, + elem_len: usize, + remaining: usize, +} + +impl<'frame> VectorBytesSequenceIterator<'frame> { + fn new(count: usize, elem_len: usize, slice: FrameSlice<'frame>) -> Self { + Self { + slice, + elem_len, + remaining: count, + } + } +} + +impl<'frame> Iterator for VectorBytesSequenceIterator<'frame> { + type Item = Result>, LowLevelDeserializationError>; + + #[inline] + fn next(&mut self) -> Option { + self.remaining = self.remaining.checked_sub(1)?; + Some(self.slice.read_subslice(self.elem_len)) + } +} + // Error facilities /// Type checking of one of the built-in types failed. @@ -1474,6 +1731,9 @@ pub enum BuiltinTypeCheckErrorKind { /// A type check failure specific to a CQL set or list. SetOrListError(SetOrListTypeCheckErrorKind), + /// A type check failure specific to a CQL vector. + VectorError(VectorTypeCheckErrorKind), + /// A type check failure specific to a CQL map. MapError(MapTypeCheckErrorKind), @@ -1491,6 +1751,13 @@ impl From for BuiltinTypeCheckErrorKind { } } +impl From for BuiltinTypeCheckErrorKind { + #[inline] + fn from(value: VectorTypeCheckErrorKind) -> Self { + BuiltinTypeCheckErrorKind::VectorError(value) + } +} + impl From for BuiltinTypeCheckErrorKind { #[inline] fn from(value: MapTypeCheckErrorKind) -> Self { @@ -1519,6 +1786,7 @@ impl Display for BuiltinTypeCheckErrorKind { write!(f, "expected one of the CQL types: {expected:?}") } BuiltinTypeCheckErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinTypeCheckErrorKind::VectorError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::MapError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::TupleError(err) => err.fmt(f), BuiltinTypeCheckErrorKind::UdtError(err) => err.fmt(f), @@ -1554,6 +1822,38 @@ impl Display for SetOrListTypeCheckErrorKind { } } +/// Describes why type checking a vector type failed. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum VectorTypeCheckErrorKind { + NotVector, + /// Incompatible element types. + ElementTypeCheckFailed(TypeCheckError), + /// The size of the elements in the vector is unknown. + ElementSizeUnknown, + // The size of the elements in the vector is known. + ElementSizeKnown, +} + +impl Display for VectorTypeCheckErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VectorTypeCheckErrorKind::NotVector => { + f.write_str("the CQL type the Rust type was attempted to be type checked against was not a vector") + } + VectorTypeCheckErrorKind::ElementTypeCheckFailed(err) => { + write!(f, "the vector element types between the CQL type and the Rust type failed to type check against each other: {}", err) + } + VectorTypeCheckErrorKind::ElementSizeUnknown => { + f.write_str("the size of the elements in the vector is unknown") + } + VectorTypeCheckErrorKind::ElementSizeKnown => { + f.write_str("the size of the elements in the vector is known") + } + } + } +} + /// Describes why type checking of a map type failed. #[derive(Debug, Clone)] #[non_exhaustive] @@ -1803,6 +2103,9 @@ pub enum BuiltinDeserializationErrorKind { /// A deserialization failure specific to a CQL set or list. SetOrListError(SetOrListDeserializationErrorKind), + /// A deserialization failure specific to a CQL vector. + VectorError(VectorDeserializationErrorKind), + /// A deserialization failure specific to a CQL map. MapError(MapDeserializationErrorKind), @@ -1841,6 +2144,7 @@ impl Display for BuiltinDeserializationErrorKind { "the length of read value in bytes ({got}) is not suitable for IP address; expected 4 or 16" ), BuiltinDeserializationErrorKind::SetOrListError(err) => err.fmt(f), + BuiltinDeserializationErrorKind::VectorError(err) => err.fmt(f), BuiltinDeserializationErrorKind::MapError(err) => err.fmt(f), BuiltinDeserializationErrorKind::TupleError(err) => err.fmt(f), BuiltinDeserializationErrorKind::UdtError(err) => err.fmt(f), @@ -1880,6 +2184,31 @@ impl From for BuiltinDeserializationErrorKind } } +/// Describes why deserialization of a vector type failed. +#[derive(Debug)] +#[non_exhaustive] +pub enum VectorDeserializationErrorKind { + /// One of the elements of the vector failed to deserialize. + ElementDeserializationFailed(DeserializationError), +} + +impl Display for VectorDeserializationErrorKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VectorDeserializationErrorKind::ElementDeserializationFailed(err) => { + write!(f, "failed to deserialize one of the elements: {}", err) + } + } + } +} + +impl From for BuiltinDeserializationErrorKind { + #[inline] + fn from(err: VectorDeserializationErrorKind) -> Self { + Self::VectorError(err) + } +} + /// Describes why deserialization of a map type failed. #[derive(Debug)] #[non_exhaustive]