diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index f8d341445..72dd34415 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -241,6 +241,9 @@ impl CodeGenerator<'_> { self.push_indent(); self.buf.push_str("pub struct "); self.buf.push_str(&to_upper_camel(&message_name)); + if self.message_graph.message_has_lifetime(&fq_message_name) { + self.buf.push_str("<'a>"); + } self.buf.push_str(" {\n"); self.depth += 1; @@ -406,13 +409,15 @@ impl CodeGenerator<'_> { let deprecated = self.deprecated(&field.descriptor); let optional = self.optional(&field.descriptor); let boxed = self.boxed(&field.descriptor, fq_message_name, None); - let ty = self.resolve_type(&field.descriptor, fq_message_name); + let cowed = self.cowed(&field.descriptor, fq_message_name, None); + let ty = self.resolve_type(&field.descriptor, fq_message_name, cowed); debug!( - " field: {:?}, type: {:?}, boxed: {}", + " field: {:?}, type: {:?}, boxed: {} cowed: {}", field.descriptor.name(), ty, - boxed + boxed, + cowed ); self.append_doc(fq_message_name, Some(field.descriptor.name())); @@ -424,10 +429,10 @@ impl CodeGenerator<'_> { self.push_indent(); self.buf.push_str("#[prost("); - let type_tag = self.field_type_tag(&field.descriptor); + let type_tag = self.field_type_tag(&field.descriptor, cowed); self.buf.push_str(&type_tag); - if type_ == Type::Bytes { + if !cowed && type_ == Type::Bytes { let bytes_type = self .config .bytes_type @@ -532,8 +537,9 @@ impl CodeGenerator<'_> { key: &FieldDescriptorProto, value: &FieldDescriptorProto, ) { - let key_ty = self.resolve_type(key, fq_message_name); - let value_ty = self.resolve_type(value, fq_message_name); + let map_cowed = self.cowed(&field.descriptor, fq_message_name, None); + let key_ty = self.resolve_type(key, fq_message_name, map_cowed); + let value_ty = self.resolve_type(value, fq_message_name, map_cowed); debug!( " map field: {:?}, key type: {:?}, value type: {:?}", @@ -551,8 +557,8 @@ impl CodeGenerator<'_> { .get_first_field(fq_message_name, field.descriptor.name()) .copied() .unwrap_or_default(); - let key_tag = self.field_type_tag(key); - let value_tag = self.map_value_type_tag(value); + let key_tag = self.field_type_tag(key, map_cowed); + let value_tag = self.map_value_type_tag(value, map_cowed); self.buf.push_str(&format!( "#[prost({}=\"{}, {}\", tag=\"{}\")]\n", @@ -597,9 +603,14 @@ impl CodeGenerator<'_> { self.append_field_attributes(fq_message_name, oneof.descriptor.name()); self.push_indent(); self.buf.push_str(&format!( - "pub {}: ::core::option::Option<{}>,\n", + "pub {}: ::core::option::Option<{}{}>,\n", oneof.rust_name(), - type_name + type_name, + if self.message_graph.message_has_lifetime(fq_message_name) { + "<'a>" + } else { + "" + }, )); } @@ -628,6 +639,9 @@ impl CodeGenerator<'_> { self.push_indent(); self.buf.push_str("pub enum "); self.buf.push_str(&to_upper_camel(oneof.descriptor.name())); + if self.message_graph.message_has_lifetime(fq_message_name) { + self.buf.push_str("<'a>"); + } self.buf.push_str(" {\n"); self.path.push(2); @@ -637,8 +651,14 @@ impl CodeGenerator<'_> { self.append_doc(fq_message_name, Some(field.descriptor.name())); self.path.pop(); + let cowed = self.cowed( + &field.descriptor, + fq_message_name, + Some(oneof.descriptor.name()), + ); + self.push_indent(); - let ty_tag = self.field_type_tag(&field.descriptor); + let ty_tag = self.field_type_tag(&field.descriptor, cowed); self.buf.push_str(&format!( "#[prost({}, tag=\"{}\")]\n", ty_tag, @@ -647,7 +667,7 @@ impl CodeGenerator<'_> { self.append_field_attributes(&oneof_name, field.descriptor.name()); self.push_indent(); - let ty = self.resolve_type(&field.descriptor, fq_message_name); + let ty = self.resolve_type(&field.descriptor, fq_message_name, cowed); let boxed = self.boxed( &field.descriptor, @@ -656,10 +676,11 @@ impl CodeGenerator<'_> { ); debug!( - " oneof: {:?}, type: {:?}, boxed: {}", + " oneof: {:?}, type: {:?}, boxed: {} cowed: {}", field.descriptor.name(), ty, - boxed + boxed, + cowed, ); if boxed { @@ -883,8 +904,8 @@ impl CodeGenerator<'_> { let name = method.name.take().unwrap(); let input_proto_type = method.input_type.take().unwrap(); let output_proto_type = method.output_type.take().unwrap(); - let input_type = self.resolve_ident(&input_proto_type); - let output_type = self.resolve_ident(&output_proto_type); + let input_type = self.resolve_ident(&input_proto_type).0; + let output_type = self.resolve_ident(&output_proto_type).0; let client_streaming = method.client_streaming(); let server_streaming = method.server_streaming(); @@ -947,7 +968,12 @@ impl CodeGenerator<'_> { self.buf.push_str("}\n"); } - fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String { + fn resolve_type( + &self, + field: &FieldDescriptorProto, + fq_message_name: &str, + cowed: bool, + ) -> String { match field.r#type() { Type::Float => String::from("f32"), Type::Double => String::from("f64"), @@ -956,7 +982,13 @@ impl CodeGenerator<'_> { Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"), Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"), Type::Bool => String::from("bool"), + Type::String if cowed => { + format!("{}::alloc::borrow::Cow<'a, str>", prost_path(self.config)) + } Type::String => format!("{}::alloc::string::String", prost_path(self.config)), + Type::Bytes if cowed => { + format!("{}::alloc::borrow::Cow<'a, [u8]>", prost_path(self.config)) + } Type::Bytes => self .config .bytes_type @@ -965,16 +997,28 @@ impl CodeGenerator<'_> { .unwrap_or_default() .rust_type() .to_owned(), - Type::Group | Type::Message => self.resolve_ident(field.type_name()), + Type::Group | Type::Message => { + let (mut s, is_extern) = self.resolve_ident(field.type_name()); + if !is_extern + && cowed + && self + .message_graph + .field_has_lifetime(fq_message_name, field) + { + s.push_str("<'a>"); + } + s + } } } - fn resolve_ident(&self, pb_ident: &str) -> String { + /// Returns the identifier and a bool indicating if its an extern + fn resolve_ident(&self, pb_ident: &str) -> (String, bool) { // protoc should always give fully qualified identifiers. assert_eq!(".", &pb_ident[..1]); if let Some(proto_ident) = self.extern_paths.resolve_ident(pb_ident) { - return proto_ident; + return (proto_ident, true); } let mut local_path = self @@ -1000,14 +1044,15 @@ impl CodeGenerator<'_> { ident_path.next(); } - local_path + let s = local_path .map(|_| "super".to_string()) .chain(ident_path.map(to_snake)) .chain(iter::once(to_upper_camel(ident_type))) - .join("::") + .join("::"); + (s, false) } - fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { + fn field_type_tag(&self, field: &FieldDescriptorProto, cowed: bool) -> Cow<'static, str> { match field.r#type() { Type::Float => Cow::Borrowed("float"), Type::Double => Cow::Borrowed("double"), @@ -1022,24 +1067,26 @@ impl CodeGenerator<'_> { Type::Sfixed32 => Cow::Borrowed("sfixed32"), Type::Sfixed64 => Cow::Borrowed("sfixed64"), Type::Bool => Cow::Borrowed("bool"), + Type::String if cowed => Cow::Borrowed("cow_str"), Type::String => Cow::Borrowed("string"), + Type::Bytes if cowed => Cow::Borrowed("cow_bytes"), Type::Bytes => Cow::Borrowed("bytes"), Type::Group => Cow::Borrowed("group"), Type::Message => Cow::Borrowed("message"), Type::Enum => Cow::Owned(format!( "enumeration={:?}", - self.resolve_ident(field.type_name()) + self.resolve_ident(field.type_name()).0 )), } } - fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> { + fn map_value_type_tag(&self, field: &FieldDescriptorProto, cowed: bool) -> Cow<'static, str> { match field.r#type() { Type::Enum => Cow::Owned(format!( "enumeration({})", - self.resolve_ident(field.type_name()) + self.resolve_ident(field.type_name()).0 )), - _ => self.field_type_tag(field), + _ => self.field_type_tag(field, cowed), } } @@ -1101,6 +1148,33 @@ impl CodeGenerator<'_> { false } + /// Returns whether the Rust type for this field needs to be `Cow<_>`. + fn cowed( + &self, + field: &FieldDescriptorProto, + fq_message_name: &str, + oneof: Option<&str>, + ) -> bool { + let fd_type = field.r#type(); + + // We only support Cow for Bytes and String + if !matches!( + fd_type, + Type::Message | Type::Group | Type::Bytes | Type::String + ) { + return false; + } + + let config_path = match oneof { + None => Cow::Borrowed(fq_message_name), + Some(ooname) => Cow::Owned(format!("{fq_message_name}.{ooname}")), + }; + self.config + .cowed + .get_first_field(&config_path, field.name()) + .is_some() + } + /// Returns `true` if the field options includes the `deprecated` option. fn deprecated(&self, field: &FieldDescriptorProto) -> bool { field diff --git a/prost-build/src/config.rs b/prost-build/src/config.rs index bb9f8697a..b63c2a295 100644 --- a/prost-build/src/config.rs +++ b/prost-build/src/config.rs @@ -36,6 +36,7 @@ pub struct Config { pub(crate) enum_attributes: PathMap, pub(crate) field_attributes: PathMap, pub(crate) boxed: PathMap<()>, + pub(crate) cowed: PathMap<()>, pub(crate) prost_types: bool, pub(crate) strip_enum_prefix: bool, pub(crate) out_dir: Option, @@ -373,6 +374,14 @@ impl Config { self } + pub fn cowed

(&mut self, path: P) -> &mut Self + where + P: AsRef, + { + self.cowed.insert(path.as_ref().to_string(), ()); + self + } + /// Configures the code generator to use the provided service generator. pub fn service_generator(&mut self, service_generator: Box) -> &mut Self { self.service_generator = Some(service_generator); @@ -1101,7 +1110,11 @@ impl Config { let mut modules = HashMap::new(); let mut packages = HashMap::new(); - let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1), self.boxed.clone()); + let message_graph = MessageGraph::new( + requests.iter().map(|x| &x.1), + self.boxed.clone(), + self.cowed.clone(), + ); let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types) .map_err(|error| Error::new(ErrorKind::InvalidInput, error))?; @@ -1181,6 +1194,7 @@ impl default::Default for Config { enum_attributes: PathMap::default(), field_attributes: PathMap::default(), boxed: PathMap::default(), + cowed: PathMap::default(), prost_types: true, strip_enum_prefix: true, out_dir: None, diff --git a/prost-build/src/message_graph.rs b/prost-build/src/message_graph.rs index e2bcad918..b8666ad29 100644 --- a/prost-build/src/message_graph.rs +++ b/prost-build/src/message_graph.rs @@ -1,4 +1,4 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use petgraph::algo::has_path_connecting; use petgraph::graph::NodeIndex; @@ -19,18 +19,21 @@ pub struct MessageGraph { graph: Graph, messages: HashMap, boxed: PathMap<()>, + cowed: PathMap<()>, } impl MessageGraph { pub(crate) fn new<'a>( files: impl Iterator, boxed: PathMap<()>, + cowed: PathMap<()>, ) -> MessageGraph { let mut msg_graph = MessageGraph { index: HashMap::new(), graph: Graph::new(), messages: HashMap::new(), boxed, + cowed, }; for file in files { @@ -153,4 +156,50 @@ impl MessageGraph { ) } } + + fn message_has_lifetime_internal( + &self, + fq_message_name: &str, + visited: &mut HashSet, + ) -> bool { + visited.insert(fq_message_name.to_string()); + assert_eq!(".", &fq_message_name[..1]); + self.get_message(fq_message_name) + .unwrap() + .field + .iter() + .any(|field| self.field_has_lifetime_internal(fq_message_name, field, visited)) + } + + pub fn message_has_lifetime(&self, fq_message_name: &str) -> bool { + let mut visited = Default::default(); + self.message_has_lifetime_internal(fq_message_name, &mut visited) + } + + fn field_has_lifetime_internal( + &self, + fq_message_name: &str, + field: &FieldDescriptorProto, + visited: &mut HashSet, + ) -> bool { + assert_eq!(".", &fq_message_name[..1]); + + if field.r#type() == Type::Message { + if visited.contains(field.type_name()) { + return false; + } + self.message_has_lifetime_internal(field.type_name(), visited) + } else { + matches!(field.r#type(), Type::Bytes | Type::String) + && self + .cowed + .get_first_field(fq_message_name, field.name()) + .is_some() + } + } + + pub fn field_has_lifetime(&self, fq_message_name: &str, field: &FieldDescriptorProto) -> bool { + let mut visited = Default::default(); + self.field_has_lifetime_internal(fq_message_name, field, &mut visited) + } } diff --git a/prost-derive/src/field/map.rs b/prost-derive/src/field/map.rs index 4855cc5c6..a8ccf0555 100644 --- a/prost-derive/src/field/map.rs +++ b/prost-derive/src/field/map.rs @@ -367,6 +367,7 @@ fn key_ty_from_str(s: &str) -> Result { | scalar::Ty::Sfixed32 | scalar::Ty::Sfixed64 | scalar::Ty::Bool + | scalar::Ty::CowStr | scalar::Ty::String => Ok(ty), _ => bail!("invalid map key type: {}", s), } diff --git a/prost-derive/src/field/scalar.rs b/prost-derive/src/field/scalar.rs index c2e870524..d6cae075d 100644 --- a/prost-derive/src/field/scalar.rs +++ b/prost-derive/src/field/scalar.rs @@ -194,6 +194,8 @@ impl Field { Kind::Plain(ref default) | Kind::Required(ref default) => { let default = default.typed(); match self.ty { + Ty::CowStr => quote!(#ident = ::prost::alloc::borrow::Cow::Borrowed("")), + Ty::CowBytes => quote!(#ident = ::prost::alloc::borrow::Cow::Borrowed(&[])), Ty::String | Ty::Bytes(..) => quote!(#ident.clear()), _ => quote!(#ident = #default), } @@ -398,6 +400,8 @@ pub enum Ty { Sfixed64, Bool, String, + CowStr, + CowBytes, Bytes(BytesTy), Enumeration(Path), } @@ -442,6 +446,8 @@ impl Ty { Meta::Path(ref name) if name.is_ident("sfixed64") => Ty::Sfixed64, Meta::Path(ref name) if name.is_ident("bool") => Ty::Bool, Meta::Path(ref name) if name.is_ident("string") => Ty::String, + Meta::Path(ref name) if name.is_ident("cow_str") => Ty::CowStr, + Meta::Path(ref name) if name.is_ident("cow_bytes") => Ty::CowBytes, Meta::Path(ref name) if name.is_ident("bytes") => Ty::Bytes(BytesTy::Vec), Meta::NameValue(MetaNameValue { ref path, @@ -486,6 +492,8 @@ impl Ty { "sfixed32" => Ty::Sfixed32, "sfixed64" => Ty::Sfixed64, "bool" => Ty::Bool, + "cow_str" => Ty::CowStr, + "cow_bytes" => Ty::CowBytes, "string" => Ty::String, "bytes" => Ty::Bytes(BytesTy::Vec), s if s.len() > enumeration_len && &s[..enumeration_len] == "enumeration" => { @@ -523,6 +531,8 @@ impl Ty { Ty::Sfixed64 => "sfixed64", Ty::Bool => "bool", Ty::String => "string", + Ty::CowStr => "cow_str", + Ty::CowBytes => "cow_bytes", Ty::Bytes(..) => "bytes", Ty::Enumeration(..) => "enum", } @@ -531,6 +541,8 @@ impl Ty { // TODO: rename to 'owned_type'. pub fn rust_type(&self) -> TokenStream { match self { + Ty::CowStr => quote!(::prost::alloc::borrow::Cow<'a, str>), + Ty::CowBytes => quote!(::prost::alloc::borrow::Cow<'a, [u8]>), Ty::String => quote!(::prost::alloc::string::String), Ty::Bytes(ty) => ty.rust_type(), _ => self.rust_ref_type(), @@ -553,8 +565,8 @@ impl Ty { Ty::Sfixed32 => quote!(i32), Ty::Sfixed64 => quote!(i64), Ty::Bool => quote!(bool), - Ty::String => quote!(&str), - Ty::Bytes(..) => quote!(&[u8]), + Ty::CowStr | Ty::String => quote!(&str), + Ty::CowBytes | Ty::Bytes(..) => quote!(&[u8]), Ty::Enumeration(..) => quote!(i32), } } @@ -568,7 +580,7 @@ impl Ty { /// Returns false if the scalar type is length delimited (i.e., `string` or `bytes`). pub fn is_numeric(&self) -> bool { - !matches!(self, Ty::String | Ty::Bytes(..)) + !matches!(self, Ty::CowStr | Ty::String | Ty::CowBytes | Ty::Bytes(..)) } } @@ -610,6 +622,8 @@ pub enum DefaultValue { U64(u64), Bool(bool), String(String), + CowStr(std::borrow::Cow<'static, str>), + CowBytes(Vec), Bytes(Vec), Enumeration(TokenStream), Path(Path), @@ -774,6 +788,8 @@ impl DefaultValue { Ty::Bool => DefaultValue::Bool(false), Ty::String => DefaultValue::String(String::new()), + Ty::CowStr => DefaultValue::CowStr(std::borrow::Cow::Borrowed("")), + Ty::CowBytes => DefaultValue::CowBytes(Vec::new()), Ty::Bytes(..) => DefaultValue::Bytes(Vec::new()), Ty::Enumeration(ref path) => DefaultValue::Enumeration(quote!(#path::default())), } @@ -785,6 +801,17 @@ impl DefaultValue { quote!(::prost::alloc::string::String::new()) } DefaultValue::String(ref value) => quote!(#value.into()), + DefaultValue::CowStr(ref value) if value.is_empty() => { + quote!(::prost::alloc::borrow::Cow::Borrowed("")) + } + DefaultValue::CowStr(ref value) => quote!(#value.into()), + DefaultValue::CowBytes(ref value) if value.is_empty() => { + quote!(::core::default::Default::default()) + } + DefaultValue::CowBytes(ref value) => { + let lit = LitByteStr::new(value, Span::call_site()); + quote!(#lit.as_ref().into()) + } DefaultValue::Bytes(ref value) if value.is_empty() => { quote!(::core::default::Default::default()) } @@ -817,6 +844,11 @@ impl ToTokens for DefaultValue { DefaultValue::U64(value) => value.to_tokens(tokens), DefaultValue::Bool(value) => value.to_tokens(tokens), DefaultValue::String(ref value) => value.to_tokens(tokens), + DefaultValue::CowStr(ref value) => value.to_tokens(tokens), + DefaultValue::CowBytes(ref value) => { + let byte_str = LitByteStr::new(value, Span::call_site()); + tokens.append_all(quote!(#byte_str as &[u8])); + } DefaultValue::Bytes(ref value) => { let byte_str = LitByteStr::new(value, Span::call_site()); tokens.append_all(quote!(#byte_str as &[u8])); diff --git a/prost/src/encoding.rs b/prost/src/encoding.rs index 8fd7cbf04..7b4d9390a 100644 --- a/prost/src/encoding.rs +++ b/prost/src/encoding.rs @@ -4,6 +4,7 @@ #![allow(clippy::implicit_hasher, clippy::ptr_arg)] +use alloc::borrow::Cow; use alloc::collections::BTreeMap; use alloc::format; use alloc::string::String; @@ -630,12 +631,96 @@ pub mod string { } } +pub mod cow_str { + use super::*; + + pub fn encode(tag: u32, value: &Cow<'_, str>, buf: &mut impl BufMut) { + encode_key(tag, WireType::LengthDelimited, buf); + encode_varint(value.as_ref().len() as u64, buf); + buf.put_slice(value.as_ref().as_bytes()); + } + + pub fn merge( + wire_type: WireType, + value: &mut Cow<'_, str>, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + // ## Unsafety + // + // `string::merge` reuses `bytes::merge`, with an additional check of utf-8 + // well-formedness. If the utf-8 is not well-formed, or if any other error occurs, then the + // string is cleared, so as to avoid leaking a string field with invalid data. + // + // This implementation uses the unsafe `String::as_mut_vec` method instead of the safe + // alternative of temporarily swapping an empty `String` into the field, because it results + // in up to 10% better performance on the protobuf message decoding benchmarks. + // + // It's required when using `String::as_mut_vec` that invalid utf-8 data not be leaked into + // the backing `String`. To enforce this, even in the event of a panic in `bytes::merge` or + // in the buf implementation, a drop guard is used. + unsafe { + struct DropGuard<'a>(&'a mut Vec); + impl Drop for DropGuard<'_> { + #[inline] + fn drop(&mut self) { + self.0.clear(); + } + } + + *value = Cow::Owned(String::default()); + let drop_guard = DropGuard(value.to_mut().as_mut_vec()); + bytes::merge_one_copy(wire_type, drop_guard.0, buf, ctx)?; + match str::from_utf8(drop_guard.0) { + Ok(_) => { + // Success; do not clear the bytes. + mem::forget(drop_guard); + Ok(()) + } + Err(_) => Err(DecodeError::new( + "invalid string value: data is not UTF-8 encoded", + )), + } + } + } + + encode_repeated!(Cow<'_, str>); + + //length_delimited!(String); + pub fn merge_repeated( + wire_type: WireType, + values: &mut Vec>, + buf: &mut impl Buf, + ctx: DecodeContext, + ) -> Result<(), DecodeError> { + check_wire_type(WireType::LengthDelimited, wire_type)?; + let mut value = Default::default(); + merge(wire_type, &mut value, buf, ctx)?; + values.push(value); + Ok(()) + } + + #[inline] + pub fn encoded_len(tag: u32, value: &Cow<'_, str>) -> usize { + key_len(tag) + encoded_len_varint(value.as_ref().len() as u64) + value.len() + } + + #[inline] + pub fn encoded_len_repeated(tag: u32, values: &[Cow<'_, str>]) -> usize { + key_len(tag) * values.len() + + values + .iter() + .map(|value| encoded_len_varint(value.as_ref().len() as u64) + value.as_ref().len()) + .sum::() + } +} + pub trait BytesAdapter: sealed::BytesAdapter {} mod sealed { use super::{Buf, BufMut}; - pub trait BytesAdapter: Default + Sized + 'static { + pub trait BytesAdapter: Default + Sized { fn len(&self) -> usize; /// Replace contents of this buffer with the contents of another buffer. @@ -684,6 +769,26 @@ impl sealed::BytesAdapter for Vec { } } +impl<'a> BytesAdapter for Cow<'a, [u8]> {} + +impl<'a> sealed::BytesAdapter for Cow<'a, [u8]> { + fn len(&self) -> usize { + self.as_ref().len() + } + + fn replace_with(&mut self, buf: impl Buf) { + let mut v = Vec::new(); + v.clear(); + v.reserve(buf.remaining()); + v.put(buf); + *self = Cow::Owned(v); + } + + fn append_to(&self, buf: &mut impl BufMut) { + buf.put(self.as_ref()) + } +} + pub mod bytes { use super::*; @@ -781,6 +886,10 @@ pub mod bytes { } } +pub mod cow_bytes { + pub use super::bytes::*; +} + pub mod message { use super::*; diff --git a/tests/build.rs b/tests/build.rs index 6bcd68862..68943265c 100644 --- a/tests/build.rs +++ b/tests/build.rs @@ -172,6 +172,20 @@ fn main() { .compile_protos(&[src.join("boxed_field.proto")], includes) .unwrap(); + prost_build::Config::new() + .cowed(".cowed_field.Foo") + .bytes([".cowed_field.Bar.myBarBytes"]) + .cowed(".cowed_field.Bar.myBarCowBytes") + .cowed(".cowed_field.Bar.myBarCowStr") + .cowed(".cowed_field.Bar.myNormalCowMap") + .cowed(".cowed_field.Bar.myBtreeCowMap") + .btree_map([ + ".cowed_field.Bar.myBtreeMap", + ".cowed_field.Bar.myBtreeCowMap", + ]) + .compile_protos(&[src.join("cowed_field.proto")], includes) + .unwrap(); + // Check that attempting to compile a .proto without a package declaration does not result in an error. config .compile_protos(&[src.join("no_package.proto")], includes) diff --git a/tests/src/cowed_field.proto b/tests/src/cowed_field.proto new file mode 100644 index 000000000..3fa9891ae --- /dev/null +++ b/tests/src/cowed_field.proto @@ -0,0 +1,38 @@ +syntax = "proto3"; +package cowed_field; + +import "google/protobuf/wrappers.proto"; +import "google/protobuf/timestamp.proto"; + +message Bar { + string myBarStr = 1; + string myBarCowStr = 2; + bytes myBarBytes = 3; + bytes myBarVec = 4; + bytes myBarCowBytes = 5; + map myNormalMap = 6; + map myNormalCowMap = 7; + map myBtreeMap = 8; + map myBtreeCowMap = 9; +} + +message Foo { + string myStr = 1; + uint32 myInt = 2; + repeated string myRepeat = 3; + map myFirstMap = 4; + map mySecondMap = 5; + map myThirdMap = 6; + optional string myOptStr = 7; + google.protobuf.Timestamp before = 8; + bytes myBytes = 9; + google.protobuf.StringValue googleStr = 10; + map myBytesMap = 11; + bytes myVecBytes = 12; + oneof extraDetails { + uint32 oneOfInt = 100; + string oneOfStr = 101; + bytes oneOfBytes = 102; + Bar oneOfBar = 103; + }; +} diff --git a/tests/src/cowed_field.rs b/tests/src/cowed_field.rs new file mode 100644 index 000000000..e8ef84673 --- /dev/null +++ b/tests/src/cowed_field.rs @@ -0,0 +1,271 @@ +#[allow(clippy::enum_variant_names)] +mod foo { + include!(concat!(env!("OUT_DIR"), "/cowed_field.rs")); +} + +#[test] +/// Confirm `Foo` is cowed by creating an instance +fn test_foo_is_cowed() { + use crate::cowed_field::foo::*; + use prost::Message; + use std::borrow::Cow; + use std::collections::{BTreeMap, HashMap}; + + // Define a static byte array for testing + static STATIC_ARRAY: [u8; 2] = [99, 98]; + let cow: Cow<'static, [u8]> = Cow::Borrowed(&STATIC_ARRAY); + + // Validate Bar different types + let _ = Bar { + my_bar_str: "test".to_string(), + my_bar_cow_str: Cow::Owned("cow_str".to_string()), + my_bar_bytes: prost::bytes::Bytes::from("Hello world"), + my_bar_vec: vec![1, 2, 3], + my_bar_cow_bytes: Cow::Borrowed(&STATIC_ARRAY), + my_normal_map: HashMap::from([(5, "normal_map".to_string())]), + my_normal_cow_map: HashMap::from([(5, Cow::Borrowed("normal_cow_map"))]), + my_btree_map: BTreeMap::from([(5, "btree_map".to_string())]), + my_btree_cow_map: BTreeMap::from([(5, Cow::Borrowed("btree_cow_map"))]), + }; + + // Build Foo with a mix of Owned and Borrowed Cow variants + let f = Foo { + my_str: Cow::Owned("world".to_string()), + my_int: 5, + my_repeat: vec![Cow::Borrowed("hello")], + my_first_map: HashMap::from([(5, Cow::Borrowed("first_map"))]), + my_second_map: HashMap::from([(Cow::Borrowed("second_map"), 5)]), + my_third_map: HashMap::from([( + Cow::Borrowed("third_map_key"), + Cow::Borrowed("third_map_value"), + )]), + my_opt_str: None, + before: None, + my_bytes: Cow::Borrowed(&[1, 2, 3]), + google_str: Some("google".to_string()), + my_bytes_map: HashMap::from([(7, cow)]), + my_vec_bytes: Cow::Borrowed(&[4, 5, 6]), + extra_details: Some(foo::ExtraDetails::OneOfStr(Cow::Borrowed( + "ExtraDetailsStr", + ))), + }; + + // Encode the Foo instance to a byte vector + let encoded = f.encode_to_vec(); + + let g = Foo::decode(encoded.as_ref()).expect("Decoding failed"); + + // === Assertions for Equality === + + // Assert that `my_str` fields are equal + assert_eq!( + f.my_str.as_ref(), + g.my_str.as_ref(), + "my_str fields do not match" + ); + + // Assert that `my_int` fields are equal + assert_eq!(f.my_int, g.my_int, "my_int fields do not match"); + + // Assert that `my_repeat` vectors are equal + assert_eq!( + f.my_repeat.len(), + g.my_repeat.len(), + "my_repeat lengths do not match" + ); + for (i, (item_f, item_g)) in f.my_repeat.iter().zip(g.my_repeat.iter()).enumerate() { + assert_eq!( + item_f.as_ref(), + item_g.as_ref(), + "my_repeat[{}] elements do not match", + i + ); + } + + // Assert that `my_first_map` maps are equal + assert_eq!( + f.my_first_map.len(), + g.my_first_map.len(), + "my_first_map lengths do not match" + ); + for (key, val_f) in &f.my_first_map { + let val_g = g + .my_first_map + .get(key) + .expect("Key missing in g.my_first_map"); + assert_eq!( + val_f.as_ref(), + val_g.as_ref(), + "my_first_map values for key {:?} do not match", + key + ); + } + + // Assert that `my_second_map` maps are equal + assert_eq!( + f.my_second_map.len(), + g.my_second_map.len(), + "my_second_map lengths do not match" + ); + for (key_f, val_f) in &f.my_second_map { + let val_g = g + .my_second_map + .get(key_f) + .expect("Key missing in g.my_second_map"); + assert_eq!( + *val_f, *val_g, + "my_second_map values for key {:?} do not match", + key_f + ); + } + + // Assert that `my_third_map` maps are equal + assert_eq!( + f.my_third_map.len(), + g.my_third_map.len(), + "my_third_map lengths do not match" + ); + for (key_f, val_f) in &f.my_third_map { + let (key_g, val_g) = g + .my_third_map + .get_key_value(key_f) + .expect("Key missing in g.my_third_map"); + assert_eq!( + key_f.as_ref(), + key_g.as_ref(), + "my_third_map keys do not match" + ); + assert_eq!( + val_f.as_ref(), + val_g.as_ref(), + "my_third_map values do not match" + ); + } + + // Assert that `my_opt_str` fields are equal + assert_eq!(f.my_opt_str, g.my_opt_str, "my_opt_str fields do not match"); + + // Assert that `before` fields are equal + assert_eq!(f.before, g.before, "before fields do not match"); + + // Assert that `my_bytes` fields are equal + assert_eq!( + f.my_bytes.as_ref(), + g.my_bytes.as_ref(), + "my_bytes fields do not match" + ); + + // Assert that `google_str` fields are equal + assert_eq!(f.google_str, g.google_str, "google_str fields do not match"); + + // Assert that `my_bytes_map` maps are equal + assert_eq!( + f.my_bytes_map.len(), + g.my_bytes_map.len(), + "my_bytes_map lengths do not match" + ); + for (key, val_f) in &f.my_bytes_map { + let val_g = g + .my_bytes_map + .get(key) + .expect("Key missing in g.my_bytes_map"); + assert_eq!( + val_f.as_ref(), + val_g.as_ref(), + "my_bytes_map values for key {:?} do not match", + key + ); + } + + // Assert that `my_vec_bytes` fields are equal + assert_eq!( + f.my_vec_bytes.as_ref(), + g.my_vec_bytes.as_ref(), + "my_vec_bytes fields do not match" + ); + + // Assert that `extra_details` fields are equal + match (&f.extra_details, &g.extra_details) { + (Some(foo::ExtraDetails::OneOfStr(a)), Some(foo::ExtraDetails::OneOfStr(b))) => { + assert_eq!( + a.as_ref(), + b.as_ref(), + "extra_details OneOfStr variants do not match" + ); + } + (None, None) => {} // Both are None, which is fine + _ => panic!("extra_details variants do not match"), + } + + // === Additional Assertions for Cow::Owned in Decoded Instance (`g`) === + + // Assert that `g.my_repeat` elements are `Cow::Owned` + for (i, item_g) in g.my_repeat.iter().enumerate() { + assert!( + matches!(item_g, Cow::Owned(_)), + "g.my_repeat[{}] should be Cow::Owned", + i + ); + } + + // Assert that `g.my_first_map` values are `Cow::Owned` + for (key, val_g) in &g.my_first_map { + assert!( + matches!(val_g, Cow::Owned(_)), + "g.my_first_map[{}] should be Cow::Owned", + key + ); + } + + // Assert that `g.my_second_map` keys are `Cow::Owned` + for key_g in g.my_second_map.keys() { + assert!( + matches!(key_g, Cow::Owned(_)), + "g.my_second_map key {:?} should be Cow::Owned", + key_g + ); + } + + // Assert that `g.my_third_map` keys and values are `Cow::Owned` + for (key_g, val_g) in &g.my_third_map { + assert!( + matches!(key_g, Cow::Owned(_)), + "g.my_third_map key {:?} should be Cow::Owned", + key_g + ); + assert!( + matches!(val_g, Cow::Owned(_)), + "g.my_third_map value {:?} should be Cow::Owned", + val_g + ); + } + + // Assert that `g.my_bytes` is `Cow::Owned` + assert!( + matches!(g.my_bytes, Cow::Owned(_)), + "g.my_bytes should be Cow::Owned" + ); + + // Assert that `g.my_bytes_map` values are `Cow::Owned` + for (key, val_g) in &g.my_bytes_map { + assert!( + matches!(val_g, Cow::Owned(_)), + "g.my_bytes_map[{}] should be Cow::Owned", + key + ); + } + + // Assert that `g.my_vec_bytes` is `Cow::Owned` + assert!( + matches!(g.my_vec_bytes, Cow::Owned(_)), + "g.my_vec_bytes should be Cow::Owned" + ); + + // Assert that `g.extra_details` variants are `Cow::Owned` if present + if let Some(foo::ExtraDetails::OneOfStr(b)) = &g.extra_details { + assert!( + matches!(b, Cow::Owned(_)), + "g.extra_details OneOfStr should be Cow::Owned" + ); + } +} diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 11c549423..d768c5768 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -56,6 +56,10 @@ mod type_names; #[cfg(test)] mod boxed_field; +#[cfg(feature = "std")] +#[cfg(test)] +mod cowed_field; + #[cfg(test)] mod custom_debug;