diff --git a/apitools/base/py/encoding.py b/apitools/base/py/encoding.py index 26deb23c..ec7518cb 100644 --- a/apitools/base/py/encoding.py +++ b/apitools/base/py/encoding.py @@ -88,10 +88,9 @@ def Register(field_type): return Register -# TODO(craigcitro): Delete this function with the switch to proto2. def CopyProtoMessage(message): - codec = protojson.ProtoJson() - return codec.decode_message(type(message), codec.encode_message(message)) + """Make a deep copy of a message.""" + return JsonToMessage(type(message), MessageToJson(message)) def MessageToJson(message, include_fields=None): @@ -438,12 +437,19 @@ def _DecodeUnrecognizedFields(message, pair_type): return new_values +def _CopyProtoMessageVanillaProtoJson(message): + codec = protojson.ProtoJson() + return codec.decode_message(type(message), codec.encode_message(message)) + + def _EncodeUnknownFields(message): """Remap unknown fields in message out of message.source.""" source = _UNRECOGNIZED_FIELD_MAPPINGS.get(type(message)) if source is None: return message - result = CopyProtoMessage(message) + # CopyProtoMessage uses _ProtoJsonApiTools, which uses this message. Use + # the vanilla protojson-based copy function to avoid infinite recursion. + result = _CopyProtoMessageVanillaProtoJson(message) pairs_field = message.field_by_name(source) if not isinstance(pairs_field, messages.MessageField): raise exceptions.InvalidUserInputError( diff --git a/apitools/base/py/encoding_test.py b/apitools/base/py/encoding_test.py index d73b1917..682d4567 100644 --- a/apitools/base/py/encoding_test.py +++ b/apitools/base/py/encoding_test.py @@ -188,6 +188,26 @@ def testCopyProtoMessage(self): msg.field = 'def' self.assertNotEqual(msg.field, new_msg.field) + def testCopyProtoMessageInvalidEnum(self): + json_msg = '{"field_one": "BAD_VALUE"}' + orig_msg = encoding.JsonToMessage(MessageWithEnum, json_msg) + new_msg = encoding.CopyProtoMessage(orig_msg) + for msg in (orig_msg, new_msg): + self.assertEqual(msg.all_unrecognized_fields(), ['field_one']) + self.assertEqual( + msg.get_unrecognized_field_info('field_one', + value_default=None), + ('BAD_VALUE', messages.Variant.ENUM)) + + def testCopyProtoMessageAdditionalProperties(self): + msg = AdditionalPropertiesMessage(additionalProperties=[ + AdditionalPropertiesMessage.AdditionalProperty( + key='key', value='value')]) + new_msg = encoding.CopyProtoMessage(msg) + self.assertEqual(len(new_msg.additionalProperties), 1) + self.assertEqual(new_msg.additionalProperties[0].key, 'key') + self.assertEqual(new_msg.additionalProperties[0].value, 'value') + def testBytesEncoding(self): b64_str = 'AAc+' b64_msg = '{"field": "%s"}' % b64_str