class EdgeType(object):
    """Enumerates the kinds of one-level step a ProtoEdge can describe."""
    SCALAR = 1
    REPEATED = 2
    MAP = 3


class ProtoEdge(collections.namedtuple('ProtoEdge',
                                       ('type_', 'field', 'index'))):
    """A single hop from a message to a value nested directly inside it.

    Messages may contain other messages to any depth, so locating a nested
    value takes a sequence of hops. Each ProtoEdge pins down exactly one of
    those hops without ambiguity.

    Properties:
        type_: EdgeType, The kind of hop this edge represents.
        field: str, The name of the message-typed field being followed.
        index: Any, Extra data needed to complete the hop. Its meaning
            depends on "type_":
                SCALAR: ignored.
                REPEATED: a numeric index into "field"'s list.
                MAP: a key into "field"'s mapping.
    """
    __slots__ = ()

    def __str__(self):
        # Only scalar hops are rendered as a bare field name; every other
        # edge type is rendered with its index in brackets.
        if self.type_ != EdgeType.SCALAR:
            return '{}[{}]'.format(self.field, self.index)
        return self.field


# TODO(craigcitro): Make these non-global.
_UNRECOGNIZED_FIELD_MAPPINGS = {}
_CUSTOM_MESSAGE_CODECS = {}


def _IsMap(message, field):
    """Returns whether the value assigned to "field" is a map-type message.

    A value is treated as a map when it is a Message that declares a field
    named "additionalProperties" and that field is repeated.

    Args:
        message: The Message instance that owns "field".
        field: The field descriptor to inspect; only its "name" is read.

    Returns:
        False when the assigned value is not a Message or has no
        "additionalProperties" field; otherwise that field's "repeated"
        flag.
    """
    value = message.get_assigned_value(field.name)
    if not isinstance(value, messages.Message):
        return False
    try:
        additional_properties = value.field_by_name('additionalProperties')
    except KeyError:
        return False
    return additional_properties.repeated


def _MapItems(message, field):
    """Yields the (key, value) pairs stored in a map-type field.

    Args:
        message: The Message instance that owns "field".
        field: The field descriptor of the map; must satisfy _IsMap.

    Yields:
        (key, value) for every entry of the map's "additionalProperties".

    Raises:
        ValueError: if "field" is not a map. Raised explicitly rather than
            via "assert" so the check is not stripped under "python -O".
    """
    if not _IsMap(message, field):
        raise ValueError(
            'Field "{}" is not a map-type field.'.format(field.name))
    map_message = message.get_assigned_value(field.name)
    for kv_pair in map_message.get_assigned_value('additionalProperties'):
        yield kv_pair.key, kv_pair.value


def _ChildEdges(message):
    """Yields (ProtoEdge, value) for every one-level step out of "message".

    Repeated fields produce one REPEATED edge per list element, map fields
    one MAP edge per entry, and every other field a single SCALAR edge.
    """
    for field in message.all_fields():
        if field.repeated:
            items = message.get_assigned_value(field.name)
            for i, item in enumerate(items):
                yield ProtoEdge(EdgeType.REPEATED, field.name, i), item
        elif _IsMap(message, field):
            for key, item in _MapItems(message, field):
                yield ProtoEdge(EdgeType.MAP, field.name, key), item
        else:
            value = message.get_assigned_value(field.name)
            yield ProtoEdge(EdgeType.SCALAR, field.name, None), value


def UnrecognizedFieldIter(message, _edges=()):  # pylint: disable=invalid-name
    """Yields the locations of unrecognized fields within "message".

    If a sub-message is found to have unrecognized fields, that sub-message
    will not be searched any further. We prune the search of the sub-message
    because we assume it is malformed and further checks will not yield
    productive errors.

    Args:
        message: The Message instance to search.
        _edges: Internal arg for passing state.

    Yields:
        (edges_to_message, field_names):
            edges_to_message: Tuple[ProtoEdge], The edges (relative to
                "message") describing the path to the sub-message where the
                unrecognized fields were found.
            field_names: List[str], The names of the field(s) that were
                unrecognized in the sub-message.
    """
    if not isinstance(message, messages.Message):
        # This is a primitive leaf, no errors found down this path.
        return

    field_names = message.all_unrecognized_fields()
    if field_names:
        # This message is malformed. Stop recursing and report it.
        yield _edges, field_names
        return

    # Recurse through all fields in the current message, extending the path
    # by one edge per step.
    for edge, child in _ChildEdges(message):
        for result in UnrecognizedFieldIter(child, _edges + (edge,)):
            yield result


# The message classes below are fixtures that the same change adds to
# apitools/base/py/encoding_test.py.


@encoding.MapUnrecognizedFields('additionalProperties')
class AdditionalPropertiesWithEnumMessage(messages.Message):

    class AdditionalProperty(messages.Message):
        key = messages.StringField(1)
        value = messages.MessageField(MessageWithEnum, 2)

    additionalProperties = messages.MessageField(
        'AdditionalProperty', 1, repeated=True)


class NestedMapMessage(messages.Message):

    msg_field = messages.MessageField(AdditionalPropertiesWithEnumMessage, 1)


class RepeatedNestedMapMessage(messages.Message):

    map_field = messages.MessageField(NestedMapMessage, 1, repeated=True)


class NestedWithEnumMessage(messages.Message):

    class ThisEnum(messages.Enum):
        VALUE_ONE = 1
        VALUE_TWO = 2

    msg_field = messages.MessageField(MessageWithEnum, 1)
    enum_field = messages.EnumField(ThisEnum, 2)


class RepeatedNestedMessage(messages.Message):

    msg_field = messages.MessageField(SimpleMessage, 1, repeated=True)
def testUnrecognizedFieldIter(self):
    """An unknown key one level down is reported behind one SCALAR edge.

    NOTE(review): ExtraNestedMessage is defined outside this view;
    presumably it declares "nested" but not "extra_field" -- confirm.
    """
    m = encoding.DictToMessage({
        'nested': {
            'nested': {'a': 'b'},
            'nested_list': ['foo'],
            'extra_field': 'foo',
        }
    }, ExtraNestedMessage)
    results = list(encoding.UnrecognizedFieldIter(m))
    self.assertEqual(1, len(results))
    edges, fields = results[0]
    expected_edge = encoding.ProtoEdge(
        encoding.EdgeType.SCALAR, 'nested', None)
    # The path is a tuple of edges, not a list.
    self.assertEqual((expected_edge,), edges)
    self.assertEqual(['extra_field'], fields)


def testUnrecognizedFieldIterRepeated(self):
    """An unknown key in a list element is reported with its list index.

    Only the second element carries an unknown key, so exactly one result
    is expected, behind a REPEATED edge with index 1.
    NOTE(review): SimpleMessage is defined outside this view; presumably
    it declares "field" but not "not_a_field" -- confirm.
    """
    m = encoding.DictToMessage({
        'msg_field': [
            {'field': 'foo'},
            {'not_a_field': 'bar'}
        ]
    }, RepeatedNestedMessage)
    results = list(encoding.UnrecognizedFieldIter(m))
    self.assertEqual(1, len(results))
    edges, fields = results[0]
    expected_edge = encoding.ProtoEdge(
        encoding.EdgeType.REPEATED, 'msg_field', 1)
    self.assertEqual((expected_edge,), edges)
    self.assertEqual(['not_a_field'], fields)


def testUnrecognizedFieldIterNestedMap(self):
    """An unknown key inside a map value yields a two-edge path.

    The expected path is a REPEATED edge into "map_field" followed by a
    MAP edge keyed by the offending entry, "bar".
    NOTE(review): MessageWithEnum is defined outside this view; presumably
    it declares "field_one" but not "not_a_field" -- confirm.
    """
    m = encoding.DictToMessage({
        'map_field': [{
            'msg_field': {
                'foo': {'field_one': 1},
                'bar': {'not_a_field': 1},
            }
        }]
    }, RepeatedNestedMapMessage)
    results = list(encoding.UnrecognizedFieldIter(m))
    self.assertEqual(1, len(results))
    edges, fields = results[0]
    expected_edges = (
        encoding.ProtoEdge(encoding.EdgeType.REPEATED, 'map_field', 0),
        encoding.ProtoEdge(encoding.EdgeType.MAP, 'msg_field', 'bar'),
    )
    self.assertEqual(expected_edges, edges)
    self.assertEqual(['not_a_field'], fields)


def testUnrecognizedFieldIterAbortAfterFirstError(self):
    """Exactly one result is reported for this doubly suspicious input.

    "enum_field" is given 3, which is not a ThisEnum value (only 1 and 2
    are defined).
    NOTE(review): presumably the bad top-level enum marks the outer
    message as having unrecognized fields, so the search is pruned before
    "msg_field" is visited -- confirm against DictToMessage and
    MessageWithEnum, which are outside this view.
    """
    m = encoding.DictToMessage({
        'msg_field': {'field_one': 3},
        'enum_field': 3,
    }, NestedWithEnumMessage)
    self.assertEqual(1, len(list(encoding.UnrecognizedFieldIter(m))))