Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for marking fields as Cow #1202

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 102 additions & 28 deletions prost-build/src/code_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,9 @@ impl CodeGenerator<'_> {
self.push_indent();
self.buf.push_str("pub struct ");
self.buf.push_str(&to_upper_camel(&message_name));
if self.message_graph.message_has_lifetime(&fq_message_name) {
self.buf.push_str("<'a>");
}
self.buf.push_str(" {\n");

self.depth += 1;
Expand Down Expand Up @@ -406,13 +409,15 @@ impl CodeGenerator<'_> {
let deprecated = self.deprecated(&field.descriptor);
let optional = self.optional(&field.descriptor);
let boxed = self.boxed(&field.descriptor, fq_message_name, None);
let ty = self.resolve_type(&field.descriptor, fq_message_name);
let cowed = self.cowed(&field.descriptor, fq_message_name, None);
let ty = self.resolve_type(&field.descriptor, fq_message_name, cowed);

debug!(
" field: {:?}, type: {:?}, boxed: {}",
" field: {:?}, type: {:?}, boxed: {} cowed: {}",
field.descriptor.name(),
ty,
boxed
boxed,
cowed
);

self.append_doc(fq_message_name, Some(field.descriptor.name()));
Expand All @@ -424,10 +429,10 @@ impl CodeGenerator<'_> {

self.push_indent();
self.buf.push_str("#[prost(");
let type_tag = self.field_type_tag(&field.descriptor);
let type_tag = self.field_type_tag(&field.descriptor, cowed);
self.buf.push_str(&type_tag);

if type_ == Type::Bytes {
if !cowed && type_ == Type::Bytes {
let bytes_type = self
.config
.bytes_type
Expand Down Expand Up @@ -532,8 +537,9 @@ impl CodeGenerator<'_> {
key: &FieldDescriptorProto,
value: &FieldDescriptorProto,
) {
let key_ty = self.resolve_type(key, fq_message_name);
let value_ty = self.resolve_type(value, fq_message_name);
let map_cowed = self.cowed(&field.descriptor, fq_message_name, None);
let key_ty = self.resolve_type(key, fq_message_name, map_cowed);
let value_ty = self.resolve_type(value, fq_message_name, map_cowed);

debug!(
" map field: {:?}, key type: {:?}, value type: {:?}",
Expand All @@ -551,8 +557,8 @@ impl CodeGenerator<'_> {
.get_first_field(fq_message_name, field.descriptor.name())
.copied()
.unwrap_or_default();
let key_tag = self.field_type_tag(key);
let value_tag = self.map_value_type_tag(value);
let key_tag = self.field_type_tag(key, map_cowed);
let value_tag = self.map_value_type_tag(value, map_cowed);

self.buf.push_str(&format!(
"#[prost({}=\"{}, {}\", tag=\"{}\")]\n",
Expand Down Expand Up @@ -597,9 +603,14 @@ impl CodeGenerator<'_> {
self.append_field_attributes(fq_message_name, oneof.descriptor.name());
self.push_indent();
self.buf.push_str(&format!(
"pub {}: ::core::option::Option<{}>,\n",
"pub {}: ::core::option::Option<{}{}>,\n",
oneof.rust_name(),
type_name
type_name,
if self.message_graph.message_has_lifetime(fq_message_name) {
"<'a>"
} else {
""
},
));
}

Expand Down Expand Up @@ -628,6 +639,9 @@ impl CodeGenerator<'_> {
self.push_indent();
self.buf.push_str("pub enum ");
self.buf.push_str(&to_upper_camel(oneof.descriptor.name()));
if self.message_graph.message_has_lifetime(fq_message_name) {
self.buf.push_str("<'a>");
}
self.buf.push_str(" {\n");

self.path.push(2);
Expand All @@ -637,8 +651,14 @@ impl CodeGenerator<'_> {
self.append_doc(fq_message_name, Some(field.descriptor.name()));
self.path.pop();

let cowed = self.cowed(
&field.descriptor,
fq_message_name,
Some(oneof.descriptor.name()),
);

self.push_indent();
let ty_tag = self.field_type_tag(&field.descriptor);
let ty_tag = self.field_type_tag(&field.descriptor, cowed);
self.buf.push_str(&format!(
"#[prost({}, tag=\"{}\")]\n",
ty_tag,
Expand All @@ -647,7 +667,7 @@ impl CodeGenerator<'_> {
self.append_field_attributes(&oneof_name, field.descriptor.name());

self.push_indent();
let ty = self.resolve_type(&field.descriptor, fq_message_name);
let ty = self.resolve_type(&field.descriptor, fq_message_name, cowed);

let boxed = self.boxed(
&field.descriptor,
Expand All @@ -656,10 +676,11 @@ impl CodeGenerator<'_> {
);

debug!(
" oneof: {:?}, type: {:?}, boxed: {}",
" oneof: {:?}, type: {:?}, boxed: {} cowed: {}",
field.descriptor.name(),
ty,
boxed
boxed,
cowed,
);

if boxed {
Expand Down Expand Up @@ -883,8 +904,8 @@ impl CodeGenerator<'_> {
let name = method.name.take().unwrap();
let input_proto_type = method.input_type.take().unwrap();
let output_proto_type = method.output_type.take().unwrap();
let input_type = self.resolve_ident(&input_proto_type);
let output_type = self.resolve_ident(&output_proto_type);
let input_type = self.resolve_ident(&input_proto_type).0;
let output_type = self.resolve_ident(&output_proto_type).0;
let client_streaming = method.client_streaming();
let server_streaming = method.server_streaming();

Expand Down Expand Up @@ -947,7 +968,12 @@ impl CodeGenerator<'_> {
self.buf.push_str("}\n");
}

fn resolve_type(&self, field: &FieldDescriptorProto, fq_message_name: &str) -> String {
fn resolve_type(
&self,
field: &FieldDescriptorProto,
fq_message_name: &str,
cowed: bool,
) -> String {
match field.r#type() {
Type::Float => String::from("f32"),
Type::Double => String::from("f64"),
Expand All @@ -956,7 +982,13 @@ impl CodeGenerator<'_> {
Type::Int32 | Type::Sfixed32 | Type::Sint32 | Type::Enum => String::from("i32"),
Type::Int64 | Type::Sfixed64 | Type::Sint64 => String::from("i64"),
Type::Bool => String::from("bool"),
Type::String if cowed => {
format!("{}::alloc::borrow::Cow<'a, str>", prost_path(self.config))
}
Type::String => format!("{}::alloc::string::String", prost_path(self.config)),
Type::Bytes if cowed => {
format!("{}::alloc::borrow::Cow<'a, [u8]>", prost_path(self.config))
}
Type::Bytes => self
.config
.bytes_type
Expand All @@ -965,16 +997,28 @@ impl CodeGenerator<'_> {
.unwrap_or_default()
.rust_type()
.to_owned(),
Type::Group | Type::Message => self.resolve_ident(field.type_name()),
Type::Group | Type::Message => {
let (mut s, is_extern) = self.resolve_ident(field.type_name());
if !is_extern
&& cowed
&& self
.message_graph
.field_has_lifetime(fq_message_name, field)
{
s.push_str("<'a>");
}
s
}
}
}

fn resolve_ident(&self, pb_ident: &str) -> String {
/// Returns the identifier and a bool indicating if its an extern
fn resolve_ident(&self, pb_ident: &str) -> (String, bool) {
// protoc should always give fully qualified identifiers.
assert_eq!(".", &pb_ident[..1]);

if let Some(proto_ident) = self.extern_paths.resolve_ident(pb_ident) {
return proto_ident;
return (proto_ident, true);
}

let mut local_path = self
Expand All @@ -1000,14 +1044,15 @@ impl CodeGenerator<'_> {
ident_path.next();
}

local_path
let s = local_path
.map(|_| "super".to_string())
.chain(ident_path.map(to_snake))
.chain(iter::once(to_upper_camel(ident_type)))
.join("::")
.join("::");
(s, false)
}

fn field_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
fn field_type_tag(&self, field: &FieldDescriptorProto, cowed: bool) -> Cow<'static, str> {
match field.r#type() {
Type::Float => Cow::Borrowed("float"),
Type::Double => Cow::Borrowed("double"),
Expand All @@ -1022,24 +1067,26 @@ impl CodeGenerator<'_> {
Type::Sfixed32 => Cow::Borrowed("sfixed32"),
Type::Sfixed64 => Cow::Borrowed("sfixed64"),
Type::Bool => Cow::Borrowed("bool"),
Type::String if cowed => Cow::Borrowed("cow_str"),
Type::String => Cow::Borrowed("string"),
Type::Bytes if cowed => Cow::Borrowed("cow_bytes"),
Type::Bytes => Cow::Borrowed("bytes"),
Type::Group => Cow::Borrowed("group"),
Type::Message => Cow::Borrowed("message"),
Type::Enum => Cow::Owned(format!(
"enumeration={:?}",
self.resolve_ident(field.type_name())
self.resolve_ident(field.type_name()).0
)),
}
}

fn map_value_type_tag(&self, field: &FieldDescriptorProto) -> Cow<'static, str> {
fn map_value_type_tag(&self, field: &FieldDescriptorProto, cowed: bool) -> Cow<'static, str> {
match field.r#type() {
Type::Enum => Cow::Owned(format!(
"enumeration({})",
self.resolve_ident(field.type_name())
self.resolve_ident(field.type_name()).0
)),
_ => self.field_type_tag(field),
_ => self.field_type_tag(field, cowed),
}
}

Expand Down Expand Up @@ -1101,6 +1148,33 @@ impl CodeGenerator<'_> {
false
}

/// Returns whether the Rust type for this field needs to be `Cow<_>`.
fn cowed(
&self,
field: &FieldDescriptorProto,
fq_message_name: &str,
oneof: Option<&str>,
) -> bool {
let fd_type = field.r#type();

// We only support Cow for Bytes and String
if !matches!(
fd_type,
Type::Message | Type::Group | Type::Bytes | Type::String
) {
return false;
}

let config_path = match oneof {
None => Cow::Borrowed(fq_message_name),
Some(ooname) => Cow::Owned(format!("{fq_message_name}.{ooname}")),
};
self.config
.cowed
.get_first_field(&config_path, field.name())
.is_some()
}

/// Returns `true` if the field options includes the `deprecated` option.
fn deprecated(&self, field: &FieldDescriptorProto) -> bool {
field
Expand Down
16 changes: 15 additions & 1 deletion prost-build/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub struct Config {
pub(crate) enum_attributes: PathMap<String>,
pub(crate) field_attributes: PathMap<String>,
pub(crate) boxed: PathMap<()>,
pub(crate) cowed: PathMap<()>,
pub(crate) prost_types: bool,
pub(crate) strip_enum_prefix: bool,
pub(crate) out_dir: Option<PathBuf>,
Expand Down Expand Up @@ -373,6 +374,14 @@ impl Config {
self
}

pub fn cowed<P>(&mut self, path: P) -> &mut Self
where
P: AsRef<str>,
{
self.cowed.insert(path.as_ref().to_string(), ());
self
}

/// Configures the code generator to use the provided service generator.
pub fn service_generator(&mut self, service_generator: Box<dyn ServiceGenerator>) -> &mut Self {
self.service_generator = Some(service_generator);
Expand Down Expand Up @@ -1101,7 +1110,11 @@ impl Config {
let mut modules = HashMap::new();
let mut packages = HashMap::new();

let message_graph = MessageGraph::new(requests.iter().map(|x| &x.1), self.boxed.clone());
let message_graph = MessageGraph::new(
requests.iter().map(|x| &x.1),
self.boxed.clone(),
self.cowed.clone(),
);
let extern_paths = ExternPaths::new(&self.extern_paths, self.prost_types)
.map_err(|error| Error::new(ErrorKind::InvalidInput, error))?;

Expand Down Expand Up @@ -1181,6 +1194,7 @@ impl default::Default for Config {
enum_attributes: PathMap::default(),
field_attributes: PathMap::default(),
boxed: PathMap::default(),
cowed: PathMap::default(),
prost_types: true,
strip_enum_prefix: true,
out_dir: None,
Expand Down
Loading
Loading