From 48e438bb61ec28edc928eb597e321f8a23b40465 Mon Sep 17 00:00:00 2001 From: Zack Newman Date: Tue, 28 Nov 2017 15:13:51 -0500 Subject: [PATCH] Fix issue with copying unrecognized Enum values (#188) Before, this resulted in an error on decoding; now, it succeeds. I removed the comment about deleting the CopyProtoMessage function since (1) there's uses of the function both internally and by users of apitools, and (2) it's clearly non-trivial to correctly copy a proto message. --- apitools/base/py/encoding.py | 14 ++++++++++---- apitools/base/py/encoding_test.py | 20 ++++++++++++++++++++ 2 files changed, 30 insertions(+), 4 deletions(-) diff --git a/apitools/base/py/encoding.py b/apitools/base/py/encoding.py index 26deb23c..ec7518cb 100644 --- a/apitools/base/py/encoding.py +++ b/apitools/base/py/encoding.py @@ -88,10 +88,9 @@ def Register(field_type): return Register -# TODO(craigcitro): Delete this function with the switch to proto2. def CopyProtoMessage(message): - codec = protojson.ProtoJson() - return codec.decode_message(type(message), codec.encode_message(message)) + """Make a deep copy of a message.""" + return JsonToMessage(type(message), MessageToJson(message)) def MessageToJson(message, include_fields=None): @@ -438,12 +437,19 @@ def _DecodeUnrecognizedFields(message, pair_type): return new_values +def _CopyProtoMessageVanillaProtoJson(message): + codec = protojson.ProtoJson() + return codec.decode_message(type(message), codec.encode_message(message)) + + def _EncodeUnknownFields(message): """Remap unknown fields in message out of message.source.""" source = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message)) if source is None: return message - result = CopyProtoMessage(message) + # CopyProtoMessage uses _ProtoJsonApiTools, which uses this message. Use + # the vanilla protojson-based copy function to avoid infinite recursion. + result = _CopyProtoMessageVanillaProtoJson(message) pairs_field = message.field_by_name(source) if not isinstance(pairs_field, messages.MessageField): raise exceptions.InvalidUserInputError( diff --git a/apitools/base/py/encoding_test.py b/apitools/base/py/encoding_test.py index d73b1917..682d4567 100644 --- a/apitools/base/py/encoding_test.py +++ b/apitools/base/py/encoding_test.py @@ -188,6 +188,26 @@ def testCopyProtoMessage(self): msg.field = 'def' self.assertNotEqual(msg.field, new_msg.field) + def testCopyProtoMessageInvalidEnum(self): + json_msg = '{"field_one": "BAD_VALUE"}' + orig_msg = encoding.JsonToMessage(MessageWithEnum, json_msg) + new_msg = encoding.CopyProtoMessage(orig_msg) + for msg in (orig_msg, new_msg): + self.assertEqual(msg.all_unrecognized_fields(), ['field_one']) + self.assertEqual( + msg.get_unrecognized_field_info('field_one', + value_default=None), + ('BAD_VALUE', messages.Variant.ENUM)) + + def testCopyProtoMessageAdditionalProperties(self): + msg = AdditionalPropertiesMessage(additionalProperties=[ + AdditionalPropertiesMessage.AdditionalProperty( + key='key', value='value')]) + new_msg = encoding.CopyProtoMessage(msg) + self.assertEqual(len(new_msg.additionalProperties), 1) + self.assertEqual(new_msg.additionalProperties[0].key, 'key') + self.assertEqual(new_msg.additionalProperties[0].value, 'value') + def testBytesEncoding(self): b64_str = 'AAc+' b64_msg = '{"field": "%s"}' % b64_str