Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

scylla-macros: attributes for better control over name checks #882

Merged
merged 5 commits into from
Dec 15, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions scylla-cql/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ pub use scylla_macros::ValueList;
/// }
/// ```
///
/// # Attributes
/// # Struct attributes
///
/// `#[scylla(flavor = "flavor_name")]`
///
Expand Down Expand Up @@ -86,6 +86,22 @@ pub use scylla_macros::ValueList;
/// It's not possible to automatically resolve those issues in the procedural
/// macro itself, so in those cases the user must provide an alternative path
/// to either the `scylla` or `scylla-cql` crate.
///
/// `#[scylla(skip_name_checks)]
///
/// _Specific only to the `enforce_order` flavor._
///
/// Skips checking Rust field names against names of the UDT fields. With this
/// annotation, the generated implementation will allow mismatch between Rust
/// struct field names and UDT field names, i.e. it's OK if i-th field has a
/// different name in Rust and in the UDT. Fields are still being type-checked.
///
Comment on lines +90 to +98
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering if it wouldn't be better to make this a new flavor instead of an attribute - one less place to make a mistake for a user. But seeing that it's a compile time error, maybe it doesn't matter?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not entirely sure either; implementation-wise it was more natural to add an attribute as it introduces very slight changes to how the enforce_order flavor generates code.

One argument that I see for leaving it as a separate attribute is that it disables some safety checks, while both current flavors don't compromise on safety. It's good to make the less safe option more verbose IMO.

In the future we should consider adding support for tuple structs which very natually fit into the enforce_order + skip_name_checks pattern, we could probably remove the current skip_name_checks attribute when doing that.

/// # Field attributes
///
/// `#[scylla(rename = "name_in_the_udt")]`
///
/// Serializes the field to the UDT struct field with given name instead of
/// its Rust name.
pub use scylla_macros::SerializeCql;

/// Derive macro for the [`SerializeRow`](crate::types::serialize::row::SerializeRow) trait
Expand Down Expand Up @@ -123,7 +139,7 @@ pub use scylla_macros::SerializeCql;
/// }
/// ```
///
/// # Attributes
/// # Struct attributes
///
/// `#[scylla(flavor = "flavor_name")]`
///
Expand Down Expand Up @@ -163,6 +179,13 @@ pub use scylla_macros::SerializeCql;
/// It's not possible to automatically resolve those issues in the procedural
/// macro itself, so in those cases the user must provide an alternative path
/// to either the `scylla` or `scylla-cql` crate.
///
/// # Field attributes
///
/// `#[scylla(rename = "column_or_bind_marker_name")]`
///
/// Serializes the field to the column / bind marker with given name instead of
/// its Rust name.
pub use scylla_macros::SerializeRow;

// Reexports for derive(IntoUserType)
Expand Down
48 changes: 48 additions & 0 deletions scylla-cql/src/types/serialize/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1381,4 +1381,52 @@ mod tests {
.iter()
.all(|v| v == RawValue::Value(&[0, 0, 0, 0, 0x07, 0x5b, 0xcd, 0x15])))
}

#[derive(SerializeRow, Debug)]
#[scylla(crate = crate)]
struct TestRowWithColumnRename {
a: String,
#[scylla(rename = "x")]
b: i32,
}

#[derive(SerializeRow, Debug)]
#[scylla(crate = crate, flavor = "enforce_order")]
struct TestRowWithColumnRenameAndEnforceOrder {
a: String,
#[scylla(rename = "x")]
b: i32,
}

#[test]
fn test_row_serialization_with_column_rename() {
let spec = [col("x", ColumnType::Int), col("a", ColumnType::Text)];

let reference = do_serialize((42i32, "Ala ma kota"), &spec);
let row = do_serialize(
TestRowWithColumnRename {
a: "Ala ma kota".to_owned(),
b: 42,
},
&spec,
);

assert_eq!(reference, row);
}

#[test]
fn test_row_serialization_with_column_rename_and_enforce_order() {
let spec = [col("a", ColumnType::Text), col("x", ColumnType::Int)];

let reference = do_serialize(("Ala ma kota", 42i32), &spec);
let row = do_serialize(
TestRowWithColumnRenameAndEnforceOrder {
a: "Ala ma kota".to_owned(),
b: 42,
},
&spec,
);

assert_eq!(reference, row);
}
}
119 changes: 119 additions & 0 deletions scylla-cql/src/types/serialize/value.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2354,4 +2354,123 @@ mod tests {
)
));
}

#[derive(SerializeCql, Debug)]
#[scylla(crate = crate)]
struct TestUdtWithFieldRename {
a: String,
#[scylla(rename = "x")]
b: i32,
}

#[derive(SerializeCql, Debug)]
#[scylla(crate = crate, flavor = "enforce_order")]
struct TestUdtWithFieldRenameAndEnforceOrder {
a: String,
#[scylla(rename = "x")]
b: i32,
}

#[test]
fn test_udt_serialization_with_field_rename() {
let typ = ColumnType::UserDefinedType {
type_name: "typ".to_string(),
keyspace: "ks".to_string(),
field_types: vec![
("x".to_string(), ColumnType::Int),
("a".to_string(), ColumnType::Text),
],
};

let mut reference = Vec::new();
// Total length of the struct is 23
reference.extend_from_slice(&23i32.to_be_bytes());
// Field 'x'
reference.extend_from_slice(&4i32.to_be_bytes());
reference.extend_from_slice(&42i32.to_be_bytes());
// Field 'a'
reference.extend_from_slice(&("Ala ma kota".len() as i32).to_be_bytes());
reference.extend_from_slice("Ala ma kota".as_bytes());

let udt = do_serialize(
TestUdtWithFieldRename {
a: "Ala ma kota".to_owned(),
b: 42,
},
&typ,
);

assert_eq!(reference, udt);
}

#[test]
fn test_udt_serialization_with_field_rename_and_enforce_order() {
let typ = ColumnType::UserDefinedType {
type_name: "typ".to_string(),
keyspace: "ks".to_string(),
field_types: vec![
("a".to_string(), ColumnType::Text),
("x".to_string(), ColumnType::Int),
],
};

let mut reference = Vec::new();
// Total length of the struct is 23
reference.extend_from_slice(&23i32.to_be_bytes());
// Field 'a'
reference.extend_from_slice(&("Ala ma kota".len() as i32).to_be_bytes());
reference.extend_from_slice("Ala ma kota".as_bytes());
// Field 'x'
reference.extend_from_slice(&4i32.to_be_bytes());
reference.extend_from_slice(&42i32.to_be_bytes());

let udt = do_serialize(
TestUdtWithFieldRenameAndEnforceOrder {
a: "Ala ma kota".to_owned(),
b: 42,
},
&typ,
);

assert_eq!(reference, udt);
}

#[derive(SerializeCql, Debug)]
#[scylla(crate = crate, flavor = "enforce_order", skip_name_checks)]
struct TestUdtWithSkippedNameChecks {
a: String,
b: i32,
}

#[test]
fn test_udt_serialization_with_skipped_name_checks() {
let typ = ColumnType::UserDefinedType {
type_name: "typ".to_string(),
keyspace: "ks".to_string(),
field_types: vec![
("a".to_string(), ColumnType::Text),
("x".to_string(), ColumnType::Int),
],
};

let mut reference = Vec::new();
// Total length of the struct is 23
reference.extend_from_slice(&23i32.to_be_bytes());
// Field 'a'
reference.extend_from_slice(&("Ala ma kota".len() as i32).to_be_bytes());
reference.extend_from_slice("Ala ma kota".as_bytes());
// Field 'x'
reference.extend_from_slice(&4i32.to_be_bytes());
reference.extend_from_slice(&42i32.to_be_bytes());

let udt = do_serialize(
TestUdtWithFieldRenameAndEnforceOrder {
a: "Ala ma kota".to_owned(),
b: 42,
},
&typ,
);

assert_eq!(reference, udt);
}
}
108 changes: 98 additions & 10 deletions scylla-macros/src/serialize/cql.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use darling::FromAttributes;
use proc_macro::TokenStream;
use proc_macro2::Span;
Expand All @@ -11,7 +13,11 @@ struct Attributes {
#[darling(rename = "crate")]
crate_path: Option<syn::Path>,

flavor: Option<Flavor>,
#[darling(default)]
flavor: Flavor,

#[darling(default)]
skip_name_checks: bool,
}

impl Attributes {
Expand All @@ -23,9 +29,30 @@ impl Attributes {
}
}

struct Field {
ident: syn::Ident,
ty: syn::Type,
attrs: FieldAttributes,
}

impl Field {
fn field_name(&self) -> String {
match &self.attrs.rename {
Some(name) => name.clone(),
None => self.ident.to_string(),
}
}
}

#[derive(FromAttributes)]
#[darling(attributes(scylla))]
struct FieldAttributes {
rename: Option<String>,
}

struct Context {
attributes: Attributes,
fields: Vec<syn::Field>,
fields: Vec<Field>,
}

pub fn derive_serialize_cql(tokens_input: TokenStream) -> Result<syn::ItemImpl, syn::Error> {
Expand All @@ -38,12 +65,23 @@ pub fn derive_serialize_cql(tokens_input: TokenStream) -> Result<syn::ItemImpl,
let crate_path = attributes.crate_path();
let implemented_trait: syn::Path = parse_quote!(#crate_path::SerializeCql);

let fields = named_fields.named.iter().cloned().collect();
let fields = named_fields
.named
.iter()
.map(|f| {
FieldAttributes::from_attributes(&f.attrs).map(|attrs| Field {
ident: f.ident.clone().unwrap(),
ty: f.ty.clone(),
attrs,
})
})
.collect::<Result<_, _>>()?;
let ctx = Context { attributes, fields };
ctx.validate(&input.ident)?;

let gen: Box<dyn Generator> = match ctx.attributes.flavor {
Some(Flavor::MatchByName) | None => Box::new(FieldSortingGenerator { ctx: &ctx }),
Some(Flavor::EnforceOrder) => Box::new(FieldOrderedGenerator { ctx: &ctx }),
Flavor::MatchByName => Box::new(FieldSortingGenerator { ctx: &ctx }),
Flavor::EnforceOrder => Box::new(FieldOrderedGenerator { ctx: &ctx }),
};

let serialize_item = gen.generate_serialize();
Expand All @@ -57,6 +95,49 @@ pub fn derive_serialize_cql(tokens_input: TokenStream) -> Result<syn::ItemImpl,
}

impl Context {
fn validate(&self, struct_ident: &syn::Ident) -> Result<(), syn::Error> {
let mut errors = darling::Error::accumulator();

if self.attributes.skip_name_checks {
// Skipping name checks is only available in enforce_order mode
if self.attributes.flavor != Flavor::EnforceOrder {
let err = darling::Error::custom(
"the `skip_name_checks` attribute is only allowed with the `enforce_order` flavor",
)
.with_span(struct_ident);
errors.push(err);
}

// `rename` annotations don't make sense with skipped name checks
for field in self.fields.iter() {
if field.attrs.rename.is_some() {
let err = darling::Error::custom(
"the `rename` annotations don't make sense with `skip_name_checks` attribute",
)
.with_span(&field.ident);
errors.push(err);
}
}
}

// Check for name collisions
let mut used_names = HashMap::<String, &Field>::new();
for field in self.fields.iter() {
let field_name = field.field_name();
if let Some(other_field) = used_names.get(&field_name) {
let other_field_ident = &other_field.ident;
let msg = format!("the UDT field name `{field_name}` used by this struct field is already used by field `{other_field_ident}`");
let err = darling::Error::custom(msg).with_span(&field.ident);
errors.push(err);
} else {
used_names.insert(field_name, field);
}
}

errors.finish()?;
Ok(())
}

fn generate_udt_type_match(&self, err: syn::Expr) -> syn::Stmt {
let crate_path = self.attributes.crate_path();

Expand Down Expand Up @@ -126,9 +207,11 @@ impl<'a> Generator for FieldSortingGenerator<'a> {
.iter()
.map(|f| f.ident.clone())
.collect::<Vec<_>>();
let rust_field_names = rust_field_idents
let rust_field_names = self
.ctx
.fields
.iter()
.map(|i| i.as_ref().unwrap().to_string())
.map(|f| f.field_name())
.collect::<Vec<_>>();
let udt_field_names = rust_field_names.clone(); // For now, it's the same
let field_types = self.ctx.fields.iter().map(|f| &f.ty).collect::<Vec<_>>();
Expand Down Expand Up @@ -269,13 +352,18 @@ impl<'a> Generator for FieldOrderedGenerator<'a> {

// Serialize each field
for field in self.ctx.fields.iter() {
let rust_field_ident = field.ident.as_ref().unwrap();
let rust_field_name = rust_field_ident.to_string();
let rust_field_ident = &field.ident;
let rust_field_name = field.field_name();
let typ = &field.ty;
let name_check_expression: syn::Expr = if !self.ctx.attributes.skip_name_checks {
parse_quote! { field_name == #rust_field_name }
} else {
parse_quote! { true }
};
statements.push(parse_quote! {
match field_iter.next() {
Some((field_name, typ)) => {
if field_name == #rust_field_name {
if #name_check_expression {
let sub_builder = #crate_path::CellValueBuilder::make_sub_writer(&mut builder);
match <#typ as #crate_path::SerializeCql>::serialize(&self.#rust_field_ident, typ, sub_builder) {
Ok(_proof) => {},
Expand Down
Loading