diff --git a/apitools/base/protorpclite/protojson.py b/apitools/base/protorpclite/protojson.py index a46d0290..8e923b5c 100644 --- a/apitools/base/protorpclite/protojson.py +++ b/apitools/base/protorpclite/protojson.py @@ -283,11 +283,19 @@ def __decode_dictionary(self, message_type, dictionary): valid_value = [self.decode_field(field, item) for item in value] setattr(message, field.name, valid_value) - else: - # This is just for consistency with the old behavior. - if value == []: - continue + continue + # This is just for consistency with the old behavior. + if value == []: + continue + try: setattr(message, field.name, self.decode_field(field, value)) + except messages.DecodeError: + # Save unknown enum values. + if not isinstance(field, messages.EnumField): + raise + variant = self.__find_variant(value) + if variant: + message.set_unrecognized_field(key, value, variant) return message diff --git a/apitools/base/protorpclite/protojson_test.py b/apitools/base/protorpclite/protojson_test.py index 4e4702a9..0018bfc3 100644 --- a/apitools/base/protorpclite/protojson_test.py +++ b/apitools/base/protorpclite/protojson_test.py @@ -198,12 +198,18 @@ def testNumericEnumeration(self): def testNumericEnumerationNegativeTest(self): """Test with an invalid number for the enum value.""" - self.assertRaisesRegexp( - messages.DecodeError, - 'Invalid enum value "89"', - protojson.decode_message, - MyMessage, - '{"an_enum": 89}') + # The message should successfully decode. + message = protojson.decode_message(MyMessage, + '{"an_enum": 89}') + + expected_message = MyMessage() + + self.assertEquals(expected_message, message) + # The roundtrip should result in equivalent encoded + # message. + self.assertEquals( + '{"an_enum": 89}', + protojson.encode_message(message)) def testAlphaEnumeration(self): """Test that alpha enum values work.""" @@ -214,23 +220,36 @@ def testAlphaEnumeration(self): self.assertEquals(expected_message, message) + def testAlphaEnumerationNegativeTest(self): """The alpha enum value is invalid.""" - self.assertRaisesRegexp( - messages.DecodeError, - 'Invalid enum value "IAMINVALID"', - protojson.decode_message, - MyMessage, - '{"an_enum": "IAMINVALID"}') + # The message should successfully decode. + message = protojson.decode_message(MyMessage, + '{"an_enum": "IAMINVALID"}') + + expected_message = MyMessage() + + self.assertEquals(expected_message, message) + # The roundtrip should result in equivalent encoded + # message. + self.assertEquals( + '{"an_enum": "IAMINVALID"}', + protojson.encode_message(message)) def testEnumerationNegativeTestWithEmptyString(self): """The enum value is an empty string.""" - self.assertRaisesRegexp( - messages.DecodeError, - 'Invalid enum value ""', - protojson.decode_message, - MyMessage, - '{"an_enum": ""}') + # The message should successfully decode. + message = protojson.decode_message(MyMessage, + '{"an_enum": ""}') + + expected_message = MyMessage() + + self.assertEquals(expected_message, message) + # The roundtrip should result in equivalent encoded + # message. + self.assertEquals( + '{"an_enum": ""}', + protojson.encode_message(message)) def testNullValues(self): """Test that null values overwrite existing values.""" diff --git a/apitools/base/protorpclite/test_util.py b/apitools/base/protorpclite/test_util.py index a86cfc72..71ca0450 100644 --- a/apitools/base/protorpclite/test_util.py +++ b/apitools/base/protorpclite/test_util.py @@ -579,11 +579,17 @@ def testContentType(self): self.assertTrue(isinstance(self.PROTOLIB.CONTENT_TYPE, str)) def testDecodeInvalidEnumType(self): - self.assertRaisesWithRegexpMatch(messages.DecodeError, - 'Invalid enum value ', - self.PROTOLIB.decode_message, - OptionalMessage, - self.encoded_invalid_enum) + # Since protos need to be able to add new enums, a message should be + # successfully decoded even if the enum value is invalid. Encoding the + # decoded message should result in equivalence with the original encoded + # message containing an invalid enum. + decoded = self.PROTOLIB.decode_message( + OptionalMessage, self.encoded_invalid_enum) + message = OptionalMessage() + self.assertEqual(message, decoded) + encoded = self.PROTOLIB.encode_message( + decoded) + self.assertEqual(self.encoded_invalid_enum, encoded) def testDateTimeNoTimeZone(self): """Test that DateTimeFields are encoded/decoded correctly.""" diff --git a/apitools/base/py/encoding_test.py b/apitools/base/py/encoding_test.py index 682d4567..182d4551 100644 --- a/apitools/base/py/encoding_test.py +++ b/apitools/base/py/encoding_test.py @@ -106,6 +106,29 @@ class AdditionalProperty(messages.Message): 'AdditionalProperty', 1, repeated=True) +@encoding.MapUnrecognizedFields('additionalProperties') +class MapToMessageWithEnum(messages.Message): + + class AdditionalProperty(messages.Message): + key = messages.StringField(1) + value = messages.MessageField(MessageWithEnum, 2) + + additionalProperties = messages.MessageField( + 'AdditionalProperty', 1, repeated=True) + + +@encoding.MapUnrecognizedFields('additionalProperties') +class NestedAdditionalPropertiesWithEnumMessage(messages.Message): + + class AdditionalProperty(messages.Message): + key = messages.StringField(1) + value = messages.MessageField( + MapToMessageWithEnum, 2) + + additionalProperties = messages.MessageField( + 'AdditionalProperty', 1, repeated=True) + + @encoding.MapUnrecognizedFields('additionalProperties') class MapToBytesValue(messages.Message): class AdditionalProperty(messages.Message): @@ -208,6 +231,21 @@ def testCopyProtoMessageAdditionalProperties(self): self.assertEqual(new_msg.additionalProperties[0].key, 'key') self.assertEqual(new_msg.additionalProperties[0].value, 'value') + def testCopyProtoMessageMappingInvalidEnum(self): + json_msg = '{"key_one": {"field_one": "BAD_VALUE"}}' + orig_msg = encoding.JsonToMessage( + MapToMessageWithEnum, + json_msg) + new_msg = encoding.CopyProtoMessage(orig_msg) + for msg in (orig_msg, new_msg): + self.assertEqual( + msg.additionalProperties[0].value.all_unrecognized_fields(), + ['field_one']) + self.assertEqual( + msg.additionalProperties[0].value.get_unrecognized_field_info( + 'field_one', value_default=None), + ('BAD_VALUE', messages.Variant.ENUM)) + def testBytesEncoding(self): b64_str = 'AAc+' b64_msg = '{"field": "%s"}' % b64_str @@ -286,6 +324,15 @@ def testDateTimeEncodingInAMap(self): ' "2nd": "2015-07-02T23:33:25.541000+00:00"}', encoding.MessageToJson(msg)) + def testInvalidEnumEncodingInAMap(self): + json_msg = '{"key_one": {"field_one": "BAD_VALUE"}}' + msg = encoding.JsonToMessage( + MapToMessageWithEnum, + json_msg) + new_msg = encoding.MessageToJson(msg) + self.assertEqual('{"key_one": {"field_one": "BAD_VALUE"}}', + encoding.MessageToJson(msg)) + def testIncludeFields(self): msg = SimpleMessage() self.assertEqual('{}', encoding.MessageToJson(msg)) @@ -434,6 +481,15 @@ def testUnknownNestedRoundtrip(self): self.assertEqual(json.loads(json_message), json.loads(encoding.MessageToJson(message))) + def testUnknownEnumNestedRoundtrip(self): + json_with_typo = ( + '{"outer_key": {"key_one": {"field_one": "VALUE_OEN",' + ' "field_two": "VALUE_OEN"}}}') + msg = encoding.JsonToMessage( + NestedAdditionalPropertiesWithEnumMessage, json_with_typo) + self.assertEqual(json.loads(json_with_typo), + json.loads(encoding.MessageToJson(msg))) + def testJsonDatetime(self): msg = TimeMessage(timefield=datetime.datetime( 2014, 7, 2, 23, 33, 25, 541000,