diff options
author | Jisi Liu <liujisi@google.com> | 2017-08-21 10:39:27 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-08-21 10:39:27 -0700 |
commit | ba4e54724d2e6a1881c4fe88664d81fbacaf8c08 (patch) | |
tree | 5451ab616f035b92d006e51ec01b4c0dabb3f6d0 /python/google/protobuf/internal | |
parent | 472f700884a2cd4e9b621c5e4b706a9b6de7c870 (diff) | |
parent | 139775ccc040a07e07c5407e34834dab27928cbc (diff) | |
download | protobuf-ba4e54724d2e6a1881c4fe88664d81fbacaf8c08.tar.gz protobuf-ba4e54724d2e6a1881c4fe88664d81fbacaf8c08.tar.bz2 protobuf-ba4e54724d2e6a1881c4fe88664d81fbacaf8c08.zip |
Merge pull request #3529 from pherl/merge3.4.x
Merge 3.4.x into master
Diffstat (limited to 'python/google/protobuf/internal')
12 files changed, 449 insertions, 115 deletions
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index 460a4a6c..422af590 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -100,6 +100,27 @@ if _implementation_version_str != '2': _implementation_version = int(_implementation_version_str) +# Detect if serialization should be deterministic by default +try: + # The presence of this module in a build allows the proto implementation to + # be upgraded merely via build deps. + # + # NOTE: Merely importing this automatically enables deterministic proto + # serialization for C++ code, but we still need to export it as a boolean so + # that we can do the same for `_implementation_type == 'python'`. + # + # NOTE2: It is possible for C++ code to enable deterministic serialization by + # default _without_ affecting Python code, if the C++ implementation is not in + # use by this module. That is intended behavior, so we don't actually expose + # this boolean outside of this module. + # + # pylint: disable=g-import-not-at-top,unused-import + from google.protobuf import enable_deterministic_proto_serialization + _python_deterministic_proto_serialization = True +except ImportError: + _python_deterministic_proto_serialization = False + + # Usage of this function is discouraged. Clients shouldn't care which # implementation of the API is in use. Note that there is no guarantee # that differences between APIs will be maintained. @@ -111,3 +132,8 @@ def Type(): # See comment on 'Type' above. def Version(): return _implementation_version + + +# For internal use only +def IsPythonDefaultSerializationDeterministic(): + return _python_deterministic_proto_serialization diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index c1733a48..6015e6f8 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -131,11 +131,19 @@ class DescriptorPoolTest(unittest.TestCase): self.assertEqual('google/protobuf/internal/factory_test2.proto', file_desc4.name) + file_desc5 = self.pool.FindFileContainingSymbol( + 'protobuf_unittest.TestService') + self.assertIsInstance(file_desc5, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/unittest.proto', + file_desc5.name) + # Tests the generated pool. assert descriptor_pool.Default().FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message.one_more_field') assert descriptor_pool.Default().FindFileContainingSymbol( 'google.protobuf.python.internal.another_field') + assert descriptor_pool.Default().FindFileContainingSymbol( + 'protobuf_unittest.TestService') def testFindFileContainingSymbolFailure(self): with self.assertRaises(KeyError): @@ -506,10 +514,10 @@ class MessageType(object): subtype.CheckType(test, desc, name, file_desc) for index, (name, field) in enumerate(self.field_list): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) for index, (name, field) in enumerate(self.extensions): - field.CheckField(test, desc, name, index) + field.CheckField(test, desc, name, index, file_desc) class EnumField(object): @@ -519,7 +527,7 @@ class EnumField(object): self.type_name = type_name self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] enum_desc = msg_desc.enum_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -536,6 +544,7 @@ class EnumField(object): test.assertFalse(enum_desc.values_by_name[self.default_value].has_options) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(enum_desc, field_desc.enum_type) + test.assertEqual(file_desc, enum_desc.file) class MessageField(object): @@ -544,7 +553,7 @@ class MessageField(object): self.number = number self.type_name = type_name - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] field_type_desc = msg_desc.nested_types_by_name[self.type_name] test.assertEqual(name, field_desc.name) @@ -558,6 +567,7 @@ class MessageField(object): test.assertFalse(field_desc.has_default_value) test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(field_type_desc, field_desc.message_type) + test.assertEqual(file_desc, field_desc.file) class StringField(object): @@ -566,7 +576,7 @@ class StringField(object): self.number = number self.default_value = default_value - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.fields_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -578,6 +588,7 @@ class StringField(object): field_desc.cpp_type) test.assertTrue(field_desc.has_default_value) test.assertEqual(self.default_value, field_desc.default_value) + test.assertEqual(file_desc, field_desc.file) class ExtensionField(object): @@ -586,7 +597,7 @@ class ExtensionField(object): self.number = number self.extended_type = extended_type - def CheckField(self, test, msg_desc, name, index): + def CheckField(self, test, msg_desc, name, index, file_desc): field_desc = msg_desc.extensions_by_name[name] test.assertEqual(name, field_desc.name) expected_field_full_name = '.'.join([msg_desc.full_name, name]) @@ -601,6 +612,7 @@ class ExtensionField(object): test.assertEqual(msg_desc, field_desc.extension_scope) test.assertEqual(msg_desc, field_desc.message_type) test.assertEqual(self.extended_type, field_desc.containing_type.name) + test.assertEqual(file_desc, field_desc.file) class AddDescriptorTest(unittest.TestCase): @@ -746,15 +758,10 @@ class AddDescriptorTest(unittest.TestCase): self.assertIs(options, file_descriptor.GetOptions()) -@unittest.skipIf( - api_implementation.Type() != 'cpp', - 'default_pool is only supported by the C++ implementation') class DefaultPoolTest(unittest.TestCase): def testFindMethods(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool + pool = descriptor_pool.Default() self.assertIs( pool.FindFileByName('google/protobuf/unittest.proto'), unittest_pb2.DESCRIPTOR) @@ -765,19 +772,22 @@ class DefaultPoolTest(unittest.TestCase): pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) self.assertIs( - pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), - unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) - self.assertIs( pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), unittest_pb2.ForeignEnum.DESCRIPTOR) + if api_implementation.Type() != 'cpp': + self.skipTest('Only the C++ implementation correctly indexes all types') + self.assertIs( + pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) self.assertIs( pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + self.assertIs( + pool.FindServiceByName('protobuf_unittest.TestService'), + unittest_pb2.DESCRIPTOR.services_by_name['TestService']) def testAddFileDescriptor(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - pool = _message.default_pool + pool = descriptor_pool.Default() file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') pool.Add(file_desc) pool.AddSerializedFile(file_desc.SerializeToString()) diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 1f148ab9..c0010081 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -521,6 +521,12 @@ class GeneratedDescriptorTest(unittest.TestCase): del enum self.assertEqual('FOO', next(values_iter).name) + def testServiceDescriptor(self): + service_descriptor = unittest_pb2.DESCRIPTOR.services_by_name['TestService'] + self.assertEqual(service_descriptor.name, 'TestService') + self.assertEqual(service_descriptor.methods[0].name, 'Foo') + self.assertIs(service_descriptor.file, unittest_pb2.DESCRIPTOR) + class DescriptorCopyToProtoTest(unittest.TestCase): """Tests for CopyTo functions of Descriptor.""" diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index 80e59cab..ebec42e5 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -372,7 +372,7 @@ def MapSizer(field_descriptor, is_message_map): def _VarintEncoder(): """Return an encoder for a basic varint value (does not include tag).""" - def EncodeVarint(write, value): + def EncodeVarint(write, value, unused_deterministic): bits = value & 0x7f value >>= 7 while value: @@ -388,7 +388,7 @@ def _SignedVarintEncoder(): """Return an encoder for a basic signed varint value (does not include tag).""" - def EncodeSignedVarint(write, value): + def EncodeSignedVarint(write, value, unused_deterministic): if value < 0: value += (1 << 64) bits = value & 0x7f @@ -411,7 +411,7 @@ def _VarintBytes(value): called at startup time so it doesn't need to be fast.""" pieces = [] - _EncodeVarint(pieces.append, value) + _EncodeVarint(pieces.append, value, True) return b"".join(pieces) @@ -440,27 +440,27 @@ def _SimpleEncoder(wire_type, encode_value, compute_value_size): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) size = 0 for element in value: size += compute_value_size(element) - local_EncodeVarint(write, size) + local_EncodeVarint(write, size, deterministic) for element in value: - encode_value(write, element) + encode_value(write, element, deterministic) return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_type) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(tag_bytes) - encode_value(write, element) + encode_value(write, element, deterministic) return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_type) - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(tag_bytes) - return encode_value(write, value) + return encode_value(write, value, deterministic) return EncodeField return SpecificEncoder @@ -474,27 +474,27 @@ def _ModifiedEncoder(wire_type, encode_value, compute_value_size, modify_value): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) size = 0 for element in value: size += compute_value_size(modify_value(element)) - local_EncodeVarint(write, size) + local_EncodeVarint(write, size, deterministic) for element in value: - encode_value(write, modify_value(element)) + encode_value(write, modify_value(element), deterministic) return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_type) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(tag_bytes) - encode_value(write, modify_value(element)) + encode_value(write, modify_value(element), deterministic) return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_type) - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(tag_bytes) - return encode_value(write, modify_value(value)) + return encode_value(write, modify_value(value), deterministic) return EncodeField return SpecificEncoder @@ -515,22 +515,22 @@ def _StructPackEncoder(wire_type, format): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) - local_EncodeVarint(write, len(value) * value_size) + local_EncodeVarint(write, len(value) * value_size, deterministic) for element in value: write(local_struct_pack(format, element)) return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_type) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, unused_deterministic): for element in value: write(tag_bytes) write(local_struct_pack(format, element)) return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_type) - def EncodeField(write, value): + def EncodeField(write, value, unused_deterministic): write(tag_bytes) return write(local_struct_pack(format, value)) return EncodeField @@ -581,9 +581,9 @@ def _FloatingPointEncoder(wire_type, format): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) - local_EncodeVarint(write, len(value) * value_size) + local_EncodeVarint(write, len(value) * value_size, deterministic) for element in value: # This try/except block is going to be faster than any code that # we could write to check whether element is finite. @@ -594,7 +594,7 @@ def _FloatingPointEncoder(wire_type, format): return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_type) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, unused_deterministic): for element in value: write(tag_bytes) try: @@ -604,7 +604,7 @@ def _FloatingPointEncoder(wire_type, format): return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_type) - def EncodeField(write, value): + def EncodeField(write, value, unused_deterministic): write(tag_bytes) try: write(local_struct_pack(format, value)) @@ -650,9 +650,9 @@ def BoolEncoder(field_number, is_repeated, is_packed): if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint - def EncodePackedField(write, value): + def EncodePackedField(write, value, deterministic): write(tag_bytes) - local_EncodeVarint(write, len(value)) + local_EncodeVarint(write, len(value), deterministic) for element in value: if element: write(true_byte) @@ -661,7 +661,7 @@ def BoolEncoder(field_number, is_repeated, is_packed): return EncodePackedField elif is_repeated: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, unused_deterministic): for element in value: write(tag_bytes) if element: @@ -671,7 +671,7 @@ def BoolEncoder(field_number, is_repeated, is_packed): return EncodeRepeatedField else: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_VARINT) - def EncodeField(write, value): + def EncodeField(write, value, unused_deterministic): write(tag_bytes) if value: return write(true_byte) @@ -687,18 +687,18 @@ def StringEncoder(field_number, is_repeated, is_packed): local_len = len assert not is_packed if is_repeated: - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: encoded = element.encode('utf-8') write(tag) - local_EncodeVarint(write, local_len(encoded)) + local_EncodeVarint(write, local_len(encoded), deterministic) write(encoded) return EncodeRepeatedField else: - def EncodeField(write, value): + def EncodeField(write, value, deterministic): encoded = value.encode('utf-8') write(tag) - local_EncodeVarint(write, local_len(encoded)) + local_EncodeVarint(write, local_len(encoded), deterministic) return write(encoded) return EncodeField @@ -711,16 +711,16 @@ def BytesEncoder(field_number, is_repeated, is_packed): local_len = len assert not is_packed if is_repeated: - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(tag) - local_EncodeVarint(write, local_len(element)) + local_EncodeVarint(write, local_len(element), deterministic) write(element) return EncodeRepeatedField else: - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(tag) - local_EncodeVarint(write, local_len(value)) + local_EncodeVarint(write, local_len(value), deterministic) return write(value) return EncodeField @@ -732,16 +732,16 @@ def GroupEncoder(field_number, is_repeated, is_packed): end_tag = TagBytes(field_number, wire_format.WIRETYPE_END_GROUP) assert not is_packed if is_repeated: - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(start_tag) - element._InternalSerialize(write) + element._InternalSerialize(write, deterministic) write(end_tag) return EncodeRepeatedField else: - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(start_tag) - value._InternalSerialize(write) + value._InternalSerialize(write, deterministic) return write(end_tag) return EncodeField @@ -753,17 +753,17 @@ def MessageEncoder(field_number, is_repeated, is_packed): local_EncodeVarint = _EncodeVarint assert not is_packed if is_repeated: - def EncodeRepeatedField(write, value): + def EncodeRepeatedField(write, value, deterministic): for element in value: write(tag) - local_EncodeVarint(write, element.ByteSize()) - element._InternalSerialize(write) + local_EncodeVarint(write, element.ByteSize(), deterministic) + element._InternalSerialize(write, deterministic) return EncodeRepeatedField else: - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(tag) - local_EncodeVarint(write, value.ByteSize()) - return value._InternalSerialize(write) + local_EncodeVarint(write, value.ByteSize(), deterministic) + return value._InternalSerialize(write, deterministic) return EncodeField @@ -790,10 +790,10 @@ def MessageSetItemEncoder(field_number): end_bytes = TagBytes(1, wire_format.WIRETYPE_END_GROUP) local_EncodeVarint = _EncodeVarint - def EncodeField(write, value): + def EncodeField(write, value, deterministic): write(start_bytes) - local_EncodeVarint(write, value.ByteSize()) - value._InternalSerialize(write) + local_EncodeVarint(write, value.ByteSize(), deterministic) + value._InternalSerialize(write, deterministic) return write(end_bytes) return EncodeField @@ -818,9 +818,10 @@ def MapEncoder(field_descriptor): message_type = field_descriptor.message_type encode_message = MessageEncoder(field_descriptor.number, False, False) - def EncodeField(write, value): - for key in value: + def EncodeField(write, value, deterministic): + value_keys = sorted(value.keys()) if deterministic else value.keys() + for key in value_keys: entry_msg = message_type._concrete_class(key=key, value=value[key]) - encode_message(write, entry_msg) + encode_message(write, entry_msg, deterministic) return EncodeField diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto index bb1b54ad..5fcbc5ac 100644 --- a/python/google/protobuf/internal/factory_test2.proto +++ b/python/google/protobuf/internal/factory_test2.proto @@ -97,3 +97,8 @@ message MessageWithNestedEnumOnly { extend Factory1Message { optional string another_field = 1002; } + +message MessageWithOption { + option no_standard_descriptor_accessor = true; + optional int32 field1 = 1; +} diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py index 5ed65622..077b64db 100644 --- a/python/google/protobuf/internal/json_format_test.py +++ b/python/google/protobuf/internal/json_format_test.py @@ -49,6 +49,7 @@ from google.protobuf import field_mask_pb2 from google.protobuf import struct_pb2 from google.protobuf import timestamp_pb2 from google.protobuf import wrappers_pb2 +from google.protobuf import unittest_mset_pb2 from google.protobuf.internal import well_known_types from google.protobuf import json_format from google.protobuf.util import json_format_proto3_pb2 @@ -158,6 +159,84 @@ class JsonFormatTest(JsonFormatBase): json_format.Parse(text, parsed_message) self.assertEqual(message, parsed_message) + def testExtensionToJsonAndBack(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + message_text = json_format.MessageToJson( + message + ) + parsed_message = unittest_mset_pb2.TestMessageSetContainer() + json_format.Parse(message_text, parsed_message) + self.assertEqual(message, parsed_message) + + def testExtensionToDictAndBack(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + message_dict = json_format.MessageToDict( + message + ) + parsed_message = unittest_mset_pb2.TestMessageSetContainer() + json_format.ParseDict(message_dict, parsed_message) + self.assertEqual(message, parsed_message) + + def testExtensionSerializationDictMatchesProto3Spec(self): + """See go/proto3-json-spec for spec. + """ + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + message_dict = json_format.MessageToDict( + message + ) + golden_dict = { + 'messageSet': { + '[protobuf_unittest.' + 'TestMessageSetExtension1.messageSetExtension]': { + 'i': 23, + }, + '[protobuf_unittest.' + 'TestMessageSetExtension2.messageSetExtension]': { + 'str': u'foo', + }, + }, + } + self.assertEqual(golden_dict, message_dict) + + + def testExtensionSerializationJsonMatchesProto3Spec(self): + """See go/proto3-json-spec for spec. + """ + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + message_text = json_format.MessageToJson( + message + ) + ext1_text = ('protobuf_unittest.TestMessageSetExtension1.' + 'messageSetExtension') + ext2_text = ('protobuf_unittest.TestMessageSetExtension2.' + 'messageSetExtension') + golden_text = ('{"messageSet": {' + ' "[%s]": {' + ' "i": 23' + ' },' + ' "[%s]": {' + ' "str": "foo"' + ' }' + '}}') % (ext1_text, ext2_text) + self.assertEqual(json.loads(golden_text), json.loads(message_text)) + + def testJsonEscapeString(self): message = json_format_proto3_pb2.TestMessage() if sys.version_info[0] < 3: @@ -768,7 +847,7 @@ class JsonFormatTest(JsonFormatBase): text = '{"value": "0000-01-01T00:00:00Z"}' self.assertRaisesRegexp( json_format.ParseError, - 'Failed to parse value field: year is out of range.', + 'Failed to parse value field: year (0 )?is out of range.', json_format.Parse, text, message) # Time bigger than maxinum time. message.value.seconds = 253402300800 @@ -840,6 +919,12 @@ class JsonFormatTest(JsonFormatBase): json_format.Parse('{"int32_value": 12345}', message) self.assertEqual(12345, message.int32_value) + def testIndent(self): + message = json_format_proto3_pb2.TestMessage() + message.int32_value = 12345 + self.assertEqual('{\n"int32Value": 12345\n}', + json_format.MessageToJson(message, indent=0)) + def testParseDict(self): expected = 12345 js_dict = {'int32Value': expected} @@ -862,6 +947,22 @@ class JsonFormatTest(JsonFormatBase): parsed_message = json_format_proto3_pb2.TestCustomJsonName() self.CheckParseBack(message, parsed_message) + def testSortKeys(self): + # Testing sort_keys is not perfectly working, as by random luck we could + # get the output sorted. We just use a selection of names. + message = json_format_proto3_pb2.TestMessage(bool_value=True, + int32_value=1, + int64_value=3, + uint32_value=4, + string_value='bla') + self.assertEqual( + json_format.MessageToJson(message, sort_keys=True), + # We use json.dumps() instead of a hardcoded string due to differences + # between Python 2 and Python 3. + json.dumps({'boolValue': True, 'int32Value': 1, 'int64Value': '3', + 'uint32Value': 4, 'stringValue': 'bla'}, + indent=2, sort_keys=True)) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 3373b21f..dda72cdd 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -140,6 +140,42 @@ class MessageTest(BaseTestCase): golden_copy = copy.deepcopy(golden_message) self.assertEqual(golden_data, golden_copy.SerializeToString()) + def testDeterminismParameters(self, message_module): + # This message is always deterministically serialized, even if determinism + # is disabled, so we can use it to verify that all the determinism + # parameters work correctly. + golden_data = (b'\xe2\x02\nOne string' + b'\xe2\x02\nTwo string' + b'\xe2\x02\nRed string' + b'\xe2\x02\x0bBlue string') + golden_message = message_module.TestAllTypes() + golden_message.repeated_string.extend([ + 'One string', + 'Two string', + 'Red string', + 'Blue string', + ]) + self.assertEqual(golden_data, + golden_message.SerializeToString(deterministic=None)) + self.assertEqual(golden_data, + golden_message.SerializeToString(deterministic=False)) + self.assertEqual(golden_data, + golden_message.SerializeToString(deterministic=True)) + + class BadArgError(Exception): + pass + + class BadArg(object): + + def __nonzero__(self): + raise BadArgError() + + def __bool__(self): + raise BadArgError() + + with self.assertRaises(BadArgError): + golden_message.SerializeToString(deterministic=BadArg()) + def testPickleSupport(self, message_module): golden_data = test_util.GoldenFileData('golden_message') golden_message = message_module.TestAllTypes() @@ -381,6 +417,7 @@ class MessageTest(BaseTestCase): self.assertEqual(message.repeated_int32[0], 1) self.assertEqual(message.repeated_int32[1], 2) self.assertEqual(message.repeated_int32[2], 3) + self.assertEqual(str(message.repeated_int32), str([1, 2, 3])) message.repeated_float.append(1.1) message.repeated_float.append(1.3) @@ -397,6 +434,7 @@ class MessageTest(BaseTestCase): self.assertEqual(message.repeated_string[0], 'a') self.assertEqual(message.repeated_string[1], 'b') self.assertEqual(message.repeated_string[2], 'c') + self.assertEqual(str(message.repeated_string), str([u'a', u'b', u'c'])) message.repeated_bytes.append(b'a') message.repeated_bytes.append(b'c') @@ -405,6 +443,7 @@ class MessageTest(BaseTestCase): self.assertEqual(message.repeated_bytes[0], b'a') self.assertEqual(message.repeated_bytes[1], b'b') self.assertEqual(message.repeated_bytes[2], b'c') + self.assertEqual(str(message.repeated_bytes), str([b'a', b'b', b'c'])) def testSortingRepeatedScalarFieldsCustomComparator(self, message_module): """Check some different types with custom comparator.""" @@ -443,6 +482,8 @@ class MessageTest(BaseTestCase): self.assertEqual(message.repeated_nested_message[3].bb, 4) self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[5].bb, 6) + self.assertEqual(str(message.repeated_nested_message), + '[bb: 1\n, bb: 2\n, bb: 3\n, bb: 4\n, bb: 5\n, bb: 6\n]') def testSortingRepeatedCompositeFieldsStable(self, message_module): """Check passing a custom comparator to sort a repeated composite field.""" @@ -1270,6 +1311,14 @@ class Proto3Test(BaseTestCase): self.assertEqual(1234567, m2.optional_nested_enum) self.assertEqual(7654321, m2.repeated_nested_enum[0]) + # ParseFromString in Proto2 should accept unknown enums too. + m3 = unittest_pb2.TestAllTypes() + m3.ParseFromString(serialized) + m2.Clear() + m2.ParseFromString(m3.SerializeToString()) + self.assertEqual(1234567, m2.optional_nested_enum) + self.assertEqual(7654321, m2.repeated_nested_enum[0]) + # Map isn't really a proto3-only feature. But there is no proto2 equivalent # of google/protobuf/map_unittest.proto right now, so it's not easy to # test both with the same test like we do for the other proto2/proto3 tests. @@ -1441,6 +1490,23 @@ class Proto3Test(BaseTestCase): self.assertIn(-456, msg2.map_int32_foreign_message) self.assertEqual(2, len(msg2.map_int32_foreign_message)) + def testNestedMessageMapItemDelete(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_all_types[1].optional_nested_message.bb = 1 + del msg.map_int32_all_types[1] + msg.map_int32_all_types[2].optional_nested_message.bb = 2 + self.assertEqual(1, len(msg.map_int32_all_types)) + msg.map_int32_all_types[1].optional_nested_message.bb = 1 + self.assertEqual(2, len(msg.map_int32_all_types)) + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + keys = [1, 2] + # The loop triggers PyErr_Occurred() in c extension. + for key in keys: + del msg2.map_int32_all_types[key] + def testMapByteSize(self): msg = map_unittest_pb2.TestMap() msg.map_int32_int32[1] = 1 @@ -1655,6 +1721,35 @@ class Proto3Test(BaseTestCase): items2 = msg.map_string_string.items() self.assertEqual(items1, items2) + def testMapDeterministicSerialization(self): + golden_data = (b'r\x0c\n\x07init_op\x12\x01d' + b'r\n\n\x05item1\x12\x01e' + b'r\n\n\x05item2\x12\x01f' + b'r\n\n\x05item3\x12\x01g' + b'r\x0b\n\x05item4\x12\x02QQ' + b'r\x12\n\rlocal_init_op\x12\x01a' + b'r\x0e\n\tsummaries\x12\x01e' + b'r\x18\n\x13trainable_variables\x12\x01b' + b'r\x0e\n\tvariables\x12\x01c') + msg = map_unittest_pb2.TestMap() + msg.map_string_string['local_init_op'] = 'a' + msg.map_string_string['trainable_variables'] = 'b' + msg.map_string_string['variables'] = 'c' + msg.map_string_string['init_op'] = 'd' + msg.map_string_string['summaries'] = 'e' + msg.map_string_string['item1'] = 'e' + msg.map_string_string['item2'] = 'f' + msg.map_string_string['item3'] = 'g' + msg.map_string_string['item4'] = 'QQ' + + # If deterministic serialization is not working correctly, this will be + # "flaky" depending on the exact python dict hash seed. + # + # Fortunately, there are enough items in this map that it is extremely + # unlikely to ever hit the "right" in-order combination, so the test + # itself should fail reliably. + self.assertEqual(golden_data, msg.SerializeToString(deterministic=True)) + def testMapIterationClearMessage(self): # Iterator needs to work even if message and map are deleted. msg = map_unittest_pb2.TestMap() diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto index 11f85ef6..98fcbcb6 100644 --- a/python/google/protobuf/internal/more_extensions_dynamic.proto +++ b/python/google/protobuf/internal/more_extensions_dynamic.proto @@ -47,4 +47,5 @@ message DynamicMessageType { extend ExtendedMessage { optional int32 dynamic_int32_extension = 100; optional DynamicMessageType dynamic_message_extension = 101; + repeated DynamicMessageType repeated_dynamic_message_extension = 102; } diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index cb97cb28..c363d843 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -58,6 +58,7 @@ import weakref import six # We use "as" to avoid name collisions with variables. +from google.protobuf.internal import api_implementation from google.protobuf.internal import containers from google.protobuf.internal import decoder from google.protobuf.internal import encoder @@ -1026,29 +1027,34 @@ def _AddByteSizeMethod(message_descriptor, cls): def _AddSerializeToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - def SerializeToString(self): + def SerializeToString(self, **kwargs): # Check if the message has all of its required fields set. errors = [] if not self.IsInitialized(): raise message_mod.EncodeError( 'Message %s is missing required fields: %s' % ( self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) - return self.SerializePartialToString() + return self.SerializePartialToString(**kwargs) cls.SerializeToString = SerializeToString def _AddSerializePartialToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - def SerializePartialToString(self): + def SerializePartialToString(self, **kwargs): out = BytesIO() - self._InternalSerialize(out.write) + self._InternalSerialize(out.write, **kwargs) return out.getvalue() cls.SerializePartialToString = SerializePartialToString - def InternalSerialize(self, write_bytes): + def InternalSerialize(self, write_bytes, deterministic=None): + if deterministic is None: + deterministic = ( + api_implementation.IsPythonDefaultSerializationDeterministic()) + else: + deterministic = bool(deterministic) for field_descriptor, field_value in self.ListFields(): - field_descriptor._encoder(write_bytes, field_value) + field_descriptor._encoder(write_bytes, field_value, deterministic) for tag_bytes, value_bytes in self._unknown_fields: write_bytes(tag_bytes) write_bytes(value_bytes) diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 188310b2..424b29cc 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -452,16 +452,18 @@ class TextFormatTest(TextFormatBase): text_format.Parse(text_format.MessageToString(m), m2) self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + def testMergeMultipleOneof(self, message_module): + m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) + m2 = message_module.TestAllTypes() + text_format.Merge(m_string, m2) + self.assertEqual('oneof_string', m2.WhichOneof('oneof_field')) + def testParseMultipleOneof(self, message_module): m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) m2 = message_module.TestAllTypes() - if message_module is unittest_pb2: - with self.assertRaisesRegexp(text_format.ParseError, - ' is specified along with field '): - text_format.Parse(m_string, m2) - else: + with self.assertRaisesRegexp(text_format.ParseError, + ' is specified along with field '): text_format.Parse(m_string, m2) - self.assertEqual('oneof_string', m2.WhichOneof('oneof_field')) # These are tests that aren't fundamentally specific to proto2, but are at @@ -1026,8 +1028,7 @@ class Proto3Tests(unittest.TestCase): packed_message.data = 'string1' message.repeated_any_value.add().Pack(packed_message) self.assertEqual( - text_format.MessageToString(message, - descriptor_pool=descriptor_pool.Default()), + text_format.MessageToString(message), 'repeated_any_value {\n' ' [type.googleapis.com/protobuf_unittest.OneString] {\n' ' data: "string0"\n' @@ -1039,18 +1040,6 @@ class Proto3Tests(unittest.TestCase): ' }\n' '}\n') - def testPrintMessageExpandAnyNoDescriptorPool(self): - packed_message = unittest_pb2.OneString() - packed_message.data = 'string' - message = any_test_pb2.TestAny() - message.any_value.Pack(packed_message) - self.assertEqual( - text_format.MessageToString(message, descriptor_pool=None), - 'any_value {\n' - ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n' - ' value: "\\n\\006string"\n' - '}\n') - def testPrintMessageExpandAnyDescriptorPoolMissingType(self): packed_message = unittest_pb2.OneString() packed_message.data = 'string' @@ -1071,8 +1060,7 @@ class Proto3Tests(unittest.TestCase): message.any_value.Pack(packed_message) self.assertEqual( text_format.MessageToString(message, - pointy_brackets=True, - descriptor_pool=descriptor_pool.Default()), + pointy_brackets=True), 'any_value <\n' ' [type.googleapis.com/protobuf_unittest.OneString] <\n' ' data: "string"\n' @@ -1086,8 +1074,7 @@ class Proto3Tests(unittest.TestCase): message.any_value.Pack(packed_message) self.assertEqual( text_format.MessageToString(message, - as_one_line=True, - descriptor_pool=descriptor_pool.Default()), + as_one_line=True), 'any_value {' ' [type.googleapis.com/protobuf_unittest.OneString]' ' { data: "string" } ' @@ -1115,12 +1102,12 @@ class Proto3Tests(unittest.TestCase): ' data: "string"\n' ' }\n' '}\n') - text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + text_format.Merge(text, message) packed_message = unittest_pb2.OneString() message.any_value.Unpack(packed_message) self.assertEqual('string', packed_message.data) message.Clear() - text_format.Parse(text, message, descriptor_pool=descriptor_pool.Default()) + text_format.Parse(text, message) packed_message = unittest_pb2.OneString() message.any_value.Unpack(packed_message) self.assertEqual('string', packed_message.data) @@ -1137,7 +1124,7 @@ class Proto3Tests(unittest.TestCase): ' data: "string1"\n' ' }\n' '}\n') - text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + text_format.Merge(text, message) packed_message = unittest_pb2.OneString() message.repeated_any_value[0].Unpack(packed_message) self.assertEqual('string0', packed_message.data) @@ -1151,22 +1138,22 @@ class Proto3Tests(unittest.TestCase): ' data: "string"\n' ' >\n' '}\n') - text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + text_format.Merge(text, message) packed_message = unittest_pb2.OneString() message.any_value.Unpack(packed_message) self.assertEqual('string', packed_message.data) - def testMergeExpandedAnyNoDescriptorPool(self): + def testMergeAlternativeUrl(self): message = any_test_pb2.TestAny() text = ('any_value {\n' - ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' [type.otherapi.com/protobuf_unittest.OneString] {\n' ' data: "string"\n' ' }\n' '}\n') - with self.assertRaises(text_format.ParseError) as e: - text_format.Merge(text, message, descriptor_pool=None) - self.assertEqual(str(e.exception), - 'Descriptor pool required to parse expanded Any field') + text_format.Merge(text, message) + packed_message = unittest_pb2.OneString() + self.assertEqual('type.otherapi.com/protobuf_unittest.OneString', + message.any_value.type_url) def testMergeExpandedAnyDescriptorPoolMissingType(self): message = any_test_pb2.TestAny() @@ -1425,5 +1412,101 @@ class TokenizerTest(unittest.TestCase): tokenizer.ConsumeCommentOrTrailingComment()) self.assertTrue(tokenizer.AtEnd()) + +# Tests for pretty printer functionality. +@_parameterized.Parameters((unittest_pb2), (unittest_proto3_arena_pb2)) +class PrettyPrinterTest(TextFormatBase): + + def testPrettyPrintNoMatch(self, message_module): + + def printer(message, indent, as_one_line): + del message, indent, as_one_line + return None + + message = message_module.TestAllTypes() + msg = message.repeated_nested_message.add() + msg.bb = 42 + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + 'repeated_nested_message { bb: 42 }') + + def testPrettyPrintOneLine(self, message_module): + + def printer(m, indent, as_one_line): + del indent, as_one_line + if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR: + return 'My lucky number is %s' % m.bb + + message = message_module.TestAllTypes() + msg = message.repeated_nested_message.add() + msg.bb = 42 + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + 'repeated_nested_message { My lucky number is 42 }') + + def testPrettyPrintMultiLine(self, message_module): + + def printer(m, indent, as_one_line): + if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR: + line_deliminator = (' ' if as_one_line else '\n') + ' ' * indent + return 'My lucky number is:%s%s' % (line_deliminator, m.bb) + return None + + message = message_module.TestAllTypes() + msg = message.repeated_nested_message.add() + msg.bb = 42 + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + 'repeated_nested_message { My lucky number is: 42 }') + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=False, message_formatter=printer), + 'repeated_nested_message {\n My lucky number is:\n 42\n}\n') + + def testPrettyPrintEntireMessage(self, message_module): + + def printer(m, indent, as_one_line): + del indent, as_one_line + if m.DESCRIPTOR == message_module.TestAllTypes.DESCRIPTOR: + return 'The is the message!' + return None + + message = message_module.TestAllTypes() + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=False, message_formatter=printer), + 'The is the message!\n') + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + 'The is the message!') + + def testPrettyPrintMultipleParts(self, message_module): + + def printer(m, indent, as_one_line): + del indent, as_one_line + if m.DESCRIPTOR == message_module.TestAllTypes.NestedMessage.DESCRIPTOR: + return 'My lucky number is %s' % m.bb + return None + + message = message_module.TestAllTypes() + message.optional_int32 = 61 + msg = message.repeated_nested_message.add() + msg.bb = 42 + msg = message.repeated_nested_message.add() + msg.bb = 99 + msg = message.optional_nested_message + msg.bb = 1 + self.CompareToGoldenText( + text_format.MessageToString( + message, as_one_line=True, message_formatter=printer), + ('optional_int32: 61 ' + 'optional_nested_message { My lucky number is 1 } ' + 'repeated_nested_message { My lucky number is 42 } ' + 'repeated_nested_message { My lucky number is 99 }')) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index d631abee..d0c7ffda 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -350,12 +350,12 @@ class Duration(object): self.nanos, _NANOS_PER_MICROSECOND)) def FromTimedelta(self, td): - """Convertd timedelta to Duration.""" + """Converts timedelta to Duration.""" self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY, td.microseconds * _NANOS_PER_MICROSECOND) def _NormalizeDuration(self, seconds, nanos): - """Set Duration by seconds and nonas.""" + """Set Duration by seconds and nanos.""" # Force nanos to be negative if the duration is negative. if seconds < 0 and nanos > 0: seconds += 1 diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 077f630f..123a537c 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -284,7 +284,7 @@ class TimeUtilTest(TimeUtilTestBase): '1972-01-01T01:00:00.01+08',) self.assertRaisesRegexp( ValueError, - 'year is out of range', + 'year (0 )?is out of range', message.FromJsonString, '0000-01-01T00:00:00Z') message.seconds = 253402300800 |