Skip to content

Commit

Permalink
Merge pull request #207 from msuozzo/master
Browse files Browse the repository at this point in the history
Add a function for detecting unrecognized fields.
  • Loading branch information
kevinli7 authored Apr 3, 2018
2 parents 286e63d + 615f9c4 commit 3abcfe1
Show file tree
Hide file tree
Showing 2 changed files with 202 additions and 0 deletions.
109 changes: 109 additions & 0 deletions apitools/base/py/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,40 @@
CodecResult = collections.namedtuple('CodecResult', ['value', 'complete'])


class EdgeType(object):
    """Names the kinds of transition an edge can make.

    SCALAR is 1, REPEATED is 2 and MAP is 3.
    """
    SCALAR, REPEATED, MAP = 1, 2, 3


class ProtoEdge(collections.namedtuple('ProtoEdge', 'type_ field index')):
    """Describes a single step from a message down to one of its values.

    Message-typed fields can nest to arbitrary depth, so locating a value
    inside a message instance may take several steps. Each instance of this
    class pins down exactly one of those steps without ambiguity.

    Properties:
      type_: EdgeType, which kind of step this edge takes.
      field: str, The name of the message-typed field being stepped through.
      index: Any, Extra data selecting the value within "field". Its meaning
          depends on "type_":
            SCALAR: unused.
            REPEATED: a numeric position within the field's list.
            MAP: a key within the field's mapping.

    str() renders a SCALAR edge as the bare field name and any other edge as
    "field[index]".
    """
    __slots__ = ()

    def __str__(self):
        if self.type_ != EdgeType.SCALAR:
            return '{}[{}]'.format(self.field, self.index)
        return self.field


# TODO(craigcitro): Make these non-global.
# Module-level dicts, both empty when the module is first loaded.
# NOTE(review): the code that fills and reads them is outside this excerpt;
# check the rest of the module before relying on their contents.
_UNRECOGNIZED_FIELD_MAPPINGS = {}
_CUSTOM_MESSAGE_CODECS = {}
Expand Down Expand Up @@ -710,3 +744,78 @@ def _IsRepeatedJsonValue(msg):
if isinstance(msg, extra_types.JsonArray):
msg = msg.entries
return msg


def _IsMap(message, field):
    """Reports whether the value assigned to "field" is shaped like a map.

    A value counts as a map when it is a Message that declares a field named
    'additionalProperties' and that field is repeated.

    Args:
      message: The Message instance holding the field.
      field: The field descriptor to inspect; only its "name" is used.

    Returns:
      False when the assigned value is not a Message or has no
      'additionalProperties' field; otherwise that field's "repeated" flag.
    """
    candidate = message.get_assigned_value(field.name)
    if not isinstance(candidate, messages.Message):
        return False
    try:
        props_field = candidate.field_by_name('additionalProperties')
    except KeyError:
        # No such field declared, so this is an ordinary message.
        return False
    return props_field.repeated


def _MapItems(message, field):
    """Yields a (key, value) pair for every entry of the map held in "field".

    Args:
      message: The Message instance holding the map field.
      field: The field descriptor of the map; must satisfy _IsMap.

    Yields:
      (key, value) taken from each item of the map's 'additionalProperties'.
    """
    assert _IsMap(message, field)
    entries = message.get_assigned_value(field.name).get_assigned_value(
        'additionalProperties')
    for entry in entries:
        yield entry.key, entry.value


def UnrecognizedFieldIter(message, _edges=()):  # pylint: disable=invalid-name
    """Walks "message" and yields where unrecognized fields were found.

    The walk does not descend into a sub-message that itself carries
    unrecognized fields. Such a message is assumed to be malformed, so
    anything reported beneath it is unlikely to be a productive error.

    Args:
      message: The Message instance to search. A value that is not a Message
          is treated as a leaf and yields nothing.
      _edges: Internal arg carrying the path walked so far.

    Yields:
      (edges_to_message, field_names):
        edges_to_message: Tuple[ProtoEdge, ...], The edges (relative to the
            top-level "message") describing the path to the sub-message
            where the unrecognized fields were found.
        field_names: The result of that sub-message's
            all_unrecognized_fields(), i.e. the name(s) of the field(s) that
            were not recognized.
    """
    if not isinstance(message, messages.Message):
        # Primitive leaf: nothing to report along this path.
        return

    unrecognized = message.all_unrecognized_fields()
    if unrecognized:
        # Malformed message: report it and prune the search here.
        yield _edges, unrecognized
        return

    for field in message.all_fields():
        value = message.get_assigned_value(field.name)
        # Build the (edge, child) pairs lazily; they are consumed right away,
        # before "field" is rebound by the next iteration.
        if field.repeated:
            steps = (
                (ProtoEdge(EdgeType.REPEATED, field.name, position), item)
                for position, item in enumerate(value))
        elif _IsMap(message, field):
            steps = (
                (ProtoEdge(EdgeType.MAP, field.name, key), item)
                for key, item in _MapItems(message, field))
        else:
            steps = [(ProtoEdge(EdgeType.SCALAR, field.name, None), value)]

        for edge, child in steps:
            for found in UnrecognizedFieldIter(child, _edges + (edge,)):
                yield found
93 changes: 93 additions & 0 deletions apitools/base/py/encoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,42 @@ class AdditionalProperty(messages.Message):
'AdditionalProperty', 1, repeated=True)


@encoding.MapUnrecognizedFields('additionalProperties')
class AdditionalPropertiesWithEnumMessage(messages.Message):
    """Test message whose entries pair a string key with a MessageWithEnum.

    Decorated with encoding.MapUnrecognizedFields('additionalProperties').
    """

    class AdditionalProperty(messages.Message):
        """One entry: a string "key" and a MessageWithEnum "value"."""
        key = messages.StringField(1)
        value = messages.MessageField(MessageWithEnum, 2)

    # Repeated list of the entries defined above.
    additionalProperties = messages.MessageField(
        'AdditionalProperty', 1, repeated=True)


class NestedMapMessage(messages.Message):
    """Test message with one AdditionalPropertiesWithEnumMessage field."""

    msg_field = messages.MessageField(AdditionalPropertiesWithEnumMessage, 1)


class RepeatedNestedMapMessage(messages.Message):
    """Test message with one repeated field of NestedMapMessage values."""

    map_field = messages.MessageField(NestedMapMessage, 1, repeated=True)


class NestedWithEnumMessage(messages.Message):
    """Test message with a MessageWithEnum field and an enum field."""

    class ThisEnum(messages.Enum):
        """Enum with two values, numbered 1 and 2."""
        VALUE_ONE = 1
        VALUE_TWO = 2

    msg_field = messages.MessageField(MessageWithEnum, 1)
    enum_field = messages.EnumField(ThisEnum, 2)


class RepeatedNestedMessage(messages.Message):
    """Test message with one repeated field of SimpleMessage values."""

    msg_field = messages.MessageField(SimpleMessage, 1, repeated=True)


@encoding.MapUnrecognizedFields('additionalProperties')
class MapToBytesValue(messages.Message):
class AdditionalProperty(messages.Message):
Expand Down Expand Up @@ -677,3 +713,60 @@ def testDictToAdditionalPropertyMessageNumeric(self):
key='key', value=1)
]
self.assertEqual(encoded_msg, expected_msg)

def testUnrecognizedFieldIter(self):
    """Expects one result: ['extra_field'] behind the scalar 'nested' edge."""
    payload = {
        'nested': {
            'nested': {'a': 'b'},
            'nested_list': ['foo'],
            'extra_field': 'foo',
        }
    }
    msg = encoding.DictToMessage(payload, ExtraNestedMessage)
    found = list(encoding.UnrecognizedFieldIter(msg))
    self.assertEqual(1, len(found))
    path, names = found[0]
    nested_step = encoding.ProtoEdge(
        encoding.EdgeType.SCALAR, 'nested', None)
    self.assertEqual((nested_step,), path)
    self.assertEqual(['extra_field'], names)

def testUnrecognizedFieldIterRepeated(self):
    """Expects one result: ['not_a_field'] at index 1 of 'msg_field'."""
    items = [
        {'field': 'foo'},
        {'not_a_field': 'bar'}
    ]
    msg = encoding.DictToMessage({'msg_field': items}, RepeatedNestedMessage)
    found = list(encoding.UnrecognizedFieldIter(msg))
    self.assertEqual(1, len(found))
    path, names = found[0]
    second_item_step = encoding.ProtoEdge(
        encoding.EdgeType.REPEATED, 'msg_field', 1)
    self.assertEqual((second_item_step,), path)
    self.assertEqual(['not_a_field'], names)

def testUnrecognizedFieldIterNestedMap(self):
    """Expects one result under map key 'bar' inside list item 0."""
    map_entries = {
        'foo': {'field_one': 1},
        'bar': {'not_a_field': 1},
    }
    msg = encoding.DictToMessage(
        {'map_field': [{'msg_field': map_entries}]},
        RepeatedNestedMapMessage)
    found = list(encoding.UnrecognizedFieldIter(msg))
    self.assertEqual(1, len(found))
    path, names = found[0]
    list_step = encoding.ProtoEdge(
        encoding.EdgeType.REPEATED, 'map_field', 0)
    map_step = encoding.ProtoEdge(
        encoding.EdgeType.MAP, 'msg_field', 'bar')
    self.assertEqual((list_step, map_step), path)
    self.assertEqual(['not_a_field'], names)

def testUnrecognizedFieldIterAbortAfterFirstError(self):
    """Expects exactly one result for a payload setting both fields to 3."""
    payload = {
        'msg_field': {'field_one': 3},
        'enum_field': 3,
    }
    msg = encoding.DictToMessage(payload, NestedWithEnumMessage)
    found = list(encoding.UnrecognizedFieldIter(msg))
    self.assertEqual(1, len(found))

0 comments on commit 3abcfe1

Please sign in to comment.