Skip to content

Commit

Permalink
Fix issue with copying unrecognized Enum values (#188)
Browse files Browse the repository at this point in the history
Before, this resulted in an error on decoding; now, it succeeds.

I removed the comment about deleting the CopyProtoMessage function since
(1) there's uses of the function both internally and by users of
apitools, and (2) it's clearly non-trivial to correctly copy a proto
message.
  • Loading branch information
znewman01 authored and craigcitro committed Nov 28, 2017
1 parent 9c954cd commit 48e438b
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
14 changes: 10 additions & 4 deletions apitools/base/py/encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,10 +88,9 @@ def Register(field_type):
return Register


# TODO(craigcitro): Delete this function with the switch to proto2.
def CopyProtoMessage(message):
codec = protojson.ProtoJson()
return codec.decode_message(type(message), codec.encode_message(message))
"""Make a deep copy of a message."""
return JsonToMessage(type(message), MessageToJson(message))


def MessageToJson(message, include_fields=None):
Expand Down Expand Up @@ -438,12 +437,19 @@ def _DecodeUnrecognizedFields(message, pair_type):
return new_values


def _CopyProtoMessageVanillaProtoJson(message):
codec = protojson.ProtoJson()
return codec.decode_message(type(message), codec.encode_message(message))


def _EncodeUnknownFields(message):
"""Remap unknown fields in message out of message.source."""
source = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message))
if source is None:
return message
result = CopyProtoMessage(message)
# CopyProtoMessage uses _ProtoJsonApiTools, which uses this message. Use
# the vanilla protojson-based copy function to avoid infinite recursion.
result = _CopyProtoMessageVanillaProtoJson(message)
pairs_field = message.field_by_name(source)
if not isinstance(pairs_field, messages.MessageField):
raise exceptions.InvalidUserInputError(
Expand Down
20 changes: 20 additions & 0 deletions apitools/base/py/encoding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,26 @@ def testCopyProtoMessage(self):
msg.field = 'def'
self.assertNotEqual(msg.field, new_msg.field)

def testCopyProtoMessageInvalidEnum(self):
json_msg = '{"field_one": "BAD_VALUE"}'
orig_msg = encoding.JsonToMessage(MessageWithEnum, json_msg)
new_msg = encoding.CopyProtoMessage(orig_msg)
for msg in (orig_msg, new_msg):
self.assertEqual(msg.all_unrecognized_fields(), ['field_one'])
self.assertEqual(
msg.get_unrecognized_field_info('field_one',
value_default=None),
('BAD_VALUE', messages.Variant.ENUM))

def testCopyProtoMessageAdditionalProperties(self):
msg = AdditionalPropertiesMessage(additionalProperties=[
AdditionalPropertiesMessage.AdditionalProperty(
key='key', value='value')])
new_msg = encoding.CopyProtoMessage(msg)
self.assertEqual(len(new_msg.additionalProperties), 1)
self.assertEqual(new_msg.additionalProperties[0].key, 'key')
self.assertEqual(new_msg.additionalProperties[0].value, 'value')

def testBytesEncoding(self):
b64_str = 'AAc+'
b64_msg = '{"field": "%s"}' % b64_str
Expand Down

0 comments on commit 48e438b

Please sign in to comment.