From cf14183bcd5485b4a71541599ddce0b35eb71352 Mon Sep 17 00:00:00 2001 From: Jisi Liu Date: Thu, 28 Apr 2016 14:34:59 -0700 Subject: Down integrate from Google internal. --- python/google/protobuf/descriptor.py | 9 +- python/google/protobuf/descriptor_pool.py | 4 +- .../google/protobuf/internal/api_implementation.py | 6 + .../protobuf/internal/descriptor_pool_test.py | 46 +- .../protobuf/internal/descriptor_pool_test2.proto | 1 + .../protobuf/internal/message_factory_test.py | 54 ++ python/google/protobuf/internal/message_test.py | 34 +- .../google/protobuf/internal/proto_builder_test.py | 1 + python/google/protobuf/internal/python_message.py | 76 +- python/google/protobuf/internal/reflection_test.py | 12 +- .../google/protobuf/internal/text_format_test.py | 93 +++ python/google/protobuf/internal/type_checkers.py | 23 +- .../protobuf/internal/unknown_fields_test.py | 20 + .../google/protobuf/internal/well_known_types.py | 8 +- .../protobuf/internal/well_known_types_test.py | 8 + python/google/protobuf/message.py | 3 + python/google/protobuf/pyext/descriptor.cc | 11 +- python/google/protobuf/pyext/descriptor_pool.cc | 8 +- python/google/protobuf/pyext/descriptor_pool.h | 11 +- python/google/protobuf/pyext/extension_dict.cc | 19 +- python/google/protobuf/pyext/extension_dict.h | 6 + python/google/protobuf/pyext/map_container.cc | 17 +- python/google/protobuf/pyext/map_container.h | 7 +- python/google/protobuf/pyext/message.cc | 120 +-- python/google/protobuf/pyext/message.h | 33 +- .../protobuf/pyext/repeated_composite_container.cc | 14 +- .../protobuf/pyext/repeated_composite_container.h | 7 +- python/google/protobuf/text_format.py | 816 +++++++++++---------- 28 files changed, 927 insertions(+), 540 deletions(-) (limited to 'python/google') diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index 5f613c88..3209b34d 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -783,6 +783,8 @@ class FileDescriptor(DescriptorBase): serialized_pb: (str) Byte string of serialized descriptor_pb2.FileDescriptorProto. dependencies: List of other FileDescriptors this FileDescriptor depends on. + public_dependencies: A list of FileDescriptors, subset of the dependencies + above, which were declared as "public". message_types_by_name: Dict of message names of their descriptors. enum_types_by_name: Dict of enum names and their descriptors. extensions_by_name: Dict of extension names and their descriptors. @@ -794,7 +796,8 @@ class FileDescriptor(DescriptorBase): _C_DESCRIPTOR_CLASS = _message.FileDescriptor def __new__(cls, name, package, options=None, serialized_pb=None, - dependencies=None, syntax=None, pool=None): + dependencies=None, public_dependencies=None, + syntax=None, pool=None): # FileDescriptor() is called from various places, not only from generated # files, to register dynamic proto files and messages. if serialized_pb: @@ -805,7 +808,8 @@ class FileDescriptor(DescriptorBase): return super(FileDescriptor, cls).__new__(cls) def __init__(self, name, package, options=None, serialized_pb=None, - dependencies=None, syntax=None, pool=None): + dependencies=None, public_dependencies=None, + syntax=None, pool=None): """Constructor.""" super(FileDescriptor, self).__init__(options, 'FileOptions') @@ -822,6 +826,7 @@ class FileDescriptor(DescriptorBase): self.enum_types_by_name = {} self.extensions_by_name = {} self.dependencies = (dependencies or []) + self.public_dependencies = (public_dependencies or []) if (api_implementation.Type() == 'cpp' and self.serialized_pb is not None): diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 3e80795c..20a33701 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -319,6 +319,7 @@ class DescriptorPool(object): if file_proto.name not in self._file_descriptors: built_deps = list(self._GetDeps(file_proto.dependency)) direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] + public_deps = [direct_deps[i] for i in file_proto.public_dependency] file_descriptor = descriptor.FileDescriptor( pool=self, @@ -327,7 +328,8 @@ class DescriptorPool(object): syntax=file_proto.syntax, options=file_proto.options, serialized_pb=file_proto.SerializeToString(), - dependencies=direct_deps) + dependencies=direct_deps, + public_dependencies=public_deps) if _USE_C_DESCRIPTORS: # When using C++ descriptors, all objects defined in the file were added # to the C++ database when the FileDescriptor was built above. diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index ffcf7511..460a4a6c 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -32,6 +32,7 @@ """ import os +import warnings import sys try: @@ -78,6 +79,11 @@ _implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', if _implementation_type != 'python': _implementation_type = 'cpp' +if 'PyPy' in sys.version and _implementation_type == 'cpp': + warnings.warn('PyPy does not work yet with cpp protocol buffers. ' + 'Falling back to the python implementation.') + _implementation_type = 'python' + # This environment variable can be used to switch between the two # 'cpp' implementations, overriding the compile-time constants in the # _api_implementation module. Right now only '2' is supported. Any other diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 4b1811d8..6a13e0bc 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -51,6 +51,7 @@ from google.protobuf.internal import descriptor_pool_test1_pb2 from google.protobuf.internal import descriptor_pool_test2_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 +from google.protobuf.internal import more_messages_pb2 from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import descriptor_pool @@ -60,11 +61,8 @@ from google.protobuf import symbol_database class DescriptorPoolTest(unittest.TestCase): - def CreatePool(self): - return descriptor_pool.DescriptorPool() - def setUp(self): - self.pool = self.CreatePool() + self.pool = descriptor_pool.DescriptorPool() self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( factory_test1_pb2.DESCRIPTOR.serialized_pb) self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( @@ -275,10 +273,13 @@ class DescriptorPoolTest(unittest.TestCase): self.testFindMessageTypeByName() def testComplexNesting(self): + more_messages_desc = descriptor_pb2.FileDescriptorProto.FromString( + more_messages_pb2.DESCRIPTOR.serialized_pb) test1_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) test2_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(more_messages_desc) self.pool.Add(test1_desc) self.pool.Add(test2_desc) TEST1_FILE.CheckFile(self, self.pool) @@ -350,25 +351,15 @@ class DescriptorPoolTest(unittest.TestCase): _CheckDefaultValues(message_class()) -@unittest.skipIf(api_implementation.Type() != 'cpp', - 'explicit tests of the C++ implementation') -class CppDescriptorPoolTest(DescriptorPoolTest): - # TODO(amauryfa): remove when descriptor_pool.DescriptorPool() creates true - # C++ descriptor pool object for C++ implementation. - - def CreatePool(self): - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - return _message.DescriptorPool() - - class ProtoFile(object): - def __init__(self, name, package, messages, dependencies=None): + def __init__(self, name, package, messages, dependencies=None, + public_dependencies=None): self.name = name self.package = package self.messages = messages self.dependencies = dependencies or [] + self.public_dependencies = public_dependencies or [] def CheckFile(self, test, pool): file_desc = pool.FindFileByName(self.name) @@ -376,6 +367,8 @@ class ProtoFile(object): test.assertEqual(self.package, file_desc.package) dependencies_names = [f.name for f in file_desc.dependencies] test.assertEqual(self.dependencies, dependencies_names) + public_dependencies_names = [f.name for f in file_desc.public_dependencies] + test.assertEqual(self.public_dependencies, public_dependencies_names) for name, msg_type in self.messages.items(): msg_type.CheckType(test, None, name, file_desc) @@ -613,18 +606,9 @@ class AddDescriptorTest(unittest.TestCase): pool.FindFileContainingSymbol( 'protobuf_unittest.TestAllTypes') - def _GetDescriptorPoolClass(self): - # Test with both implementations of descriptor pools. - if api_implementation.Type() == 'cpp': - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - return _message.DescriptorPool - else: - return descriptor_pool.DescriptorPool - def testEmptyDescriptorPool(self): - # Check that an empty DescriptorPool() contains no message. - pool = self._GetDescriptorPoolClass()() + # Check that an empty DescriptorPool() contains no messages. + pool = descriptor_pool.DescriptorPool() proto_file_name = descriptor_pb2.DESCRIPTOR.name self.assertRaises(KeyError, pool.FindFileByName, proto_file_name) # Add the above file to the pool @@ -636,7 +620,7 @@ class AddDescriptorTest(unittest.TestCase): def testCustomDescriptorPool(self): # Create a new pool, and add a file descriptor. - pool = self._GetDescriptorPoolClass()() + pool = descriptor_pool.DescriptorPool() file_desc = descriptor_pb2.FileDescriptorProto( name='some/file.proto', package='package') file_desc.message_type.add(name='Message') @@ -757,7 +741,9 @@ TEST2_FILE = ProtoFile( ExtensionField(1001, 'DescriptorPoolTest1')), ]), }, - dependencies=['google/protobuf/internal/descriptor_pool_test1.proto']) + dependencies=['google/protobuf/internal/descriptor_pool_test1.proto', + 'google/protobuf/internal/more_messages.proto'], + public_dependencies=['google/protobuf/internal/more_messages.proto']) if __name__ == '__main__': diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto index e3fa660c..a218eccb 100644 --- a/python/google/protobuf/internal/descriptor_pool_test2.proto +++ b/python/google/protobuf/internal/descriptor_pool_test2.proto @@ -33,6 +33,7 @@ syntax = "proto2"; package google.protobuf.python.internal; import "google/protobuf/internal/descriptor_pool_test1.proto"; +import public "google/protobuf/internal/more_messages.proto"; message DescriptorPoolTest3 { diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 54b1f688..7bb7d1ac 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -131,6 +131,60 @@ class MessageFactoryTest(unittest.TestCase): self.assertEqual('test1', msg1.Extensions[ext1]) self.assertEqual('test2', msg1.Extensions[ext2]) + def testDuplicateExtensionNumber(self): + pool = descriptor_pool.DescriptorPool() + factory = message_factory.MessageFactory(pool=pool) + + # Add Container message. + f = descriptor_pb2.FileDescriptorProto() + f.name = 'google/protobuf/internal/container.proto' + f.package = 'google.protobuf.python.internal' + msg = f.message_type.add() + msg.name = 'Container' + rng = msg.extension_range.add() + rng.start = 1 + rng.end = 10 + pool.Add(f) + msgs = factory.GetMessages([f.name]) + self.assertIn('google.protobuf.python.internal.Container', msgs) + + # Extend container. + f = descriptor_pb2.FileDescriptorProto() + f.name = 'google/protobuf/internal/extension.proto' + f.package = 'google.protobuf.python.internal' + f.dependency.append('google/protobuf/internal/container.proto') + msg = f.message_type.add() + msg.name = 'Extension' + ext = msg.extension.add() + ext.name = 'extension_field' + ext.number = 2 + ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + ext.type_name = 'Extension' + ext.extendee = 'Container' + pool.Add(f) + msgs = factory.GetMessages([f.name]) + self.assertIn('google.protobuf.python.internal.Extension', msgs) + + # Add Duplicate extending the same field number. + f = descriptor_pb2.FileDescriptorProto() + f.name = 'google/protobuf/internal/duplicate.proto' + f.package = 'google.protobuf.python.internal' + f.dependency.append('google/protobuf/internal/container.proto') + msg = f.message_type.add() + msg.name = 'Duplicate' + ext = msg.extension.add() + ext.name = 'extension_field' + ext.number = 2 + ext.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + ext.type_name = 'Duplicate' + ext.extendee = 'Container' + pool.Add(f) + + with self.assertRaises(Exception) as cm: + factory.GetMessages([f.name]) + + self.assertIsInstance(cm.exception, (AssertionError, ValueError)) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 1232ccc9..4ee31d8e 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -57,18 +57,18 @@ try: except ImportError: import unittest -from google.protobuf.internal import _parameterized +from google.protobuf import map_unittest_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pool -from google.protobuf import map_unittest_pb2 from google.protobuf import message_factory from google.protobuf import text_format -from google.protobuf import unittest_pb2 -from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf import message +from google.protobuf.internal import _parameterized if six.PY3: long = int @@ -1265,7 +1265,10 @@ class Proto3Test(unittest.TestCase): self.assertFalse(-2**33 in msg.map_int64_int64) self.assertFalse(123 in msg.map_uint32_uint32) self.assertFalse(2**33 in msg.map_uint64_uint64) + self.assertFalse(123 in msg.map_int32_double) + self.assertFalse(False in msg.map_bool_bool) self.assertFalse('abc' in msg.map_string_string) + self.assertFalse(111 in msg.map_int32_bytes) self.assertFalse(888 in msg.map_int32_enum) # Accessing an unset key returns the default. @@ -1273,7 +1276,12 @@ class Proto3Test(unittest.TestCase): self.assertEqual(0, msg.map_int64_int64[-2**33]) self.assertEqual(0, msg.map_uint32_uint32[123]) self.assertEqual(0, msg.map_uint64_uint64[2**33]) + self.assertEqual(0.0, msg.map_int32_double[123]) + self.assertTrue(isinstance(msg.map_int32_double[123], float)) + self.assertEqual(False, msg.map_bool_bool[False]) + self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) self.assertEqual('', msg.map_string_string['abc']) + self.assertEqual(b'', msg.map_int32_bytes[111]) self.assertEqual(0, msg.map_int32_enum[888]) # It also sets the value in the map @@ -1281,7 +1289,10 @@ class Proto3Test(unittest.TestCase): self.assertTrue(-2**33 in msg.map_int64_int64) self.assertTrue(123 in msg.map_uint32_uint32) self.assertTrue(2**33 in msg.map_uint64_uint64) + self.assertTrue(123 in msg.map_int32_double) + self.assertTrue(False in msg.map_bool_bool) self.assertTrue('abc' in msg.map_string_string) + self.assertTrue(111 in msg.map_int32_bytes) self.assertTrue(888 in msg.map_int32_enum) self.assertIsInstance(msg.map_string_string['abc'], six.text_type) @@ -1587,6 +1598,21 @@ class Proto3Test(unittest.TestCase): matching_dict = {2: 4, 3: 6, 4: 8} self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) + def testMapItems(self): + # Map items used to have strange behaviors when use c extension. Because + # [] may reorder the map and invalidate any exsting iterators. + # TODO(jieluo): Check if [] reordering the map is a bug or intended + # behavior. + msg = map_unittest_pb2.TestMap() + msg.map_string_string['local_init_op'] = '' + msg.map_string_string['trainable_variables'] = '' + msg.map_string_string['variables'] = '' + msg.map_string_string['init_op'] = '' + msg.map_string_string['summaries'] = '' + items1 = msg.map_string_string.items() + items2 = msg.map_string_string.items() + self.assertEqual(items1, items2) + def testMapIterationClearMessage(self): # Iterator needs to work even if message and map are deleted. msg = map_unittest_pb2.TestMap() diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py index 822ad895..36dfbfde 100644 --- a/python/google/protobuf/internal/proto_builder_test.py +++ b/python/google/protobuf/internal/proto_builder_test.py @@ -40,6 +40,7 @@ try: import unittest2 as unittest except ImportError: import unittest + from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pool from google.protobuf import proto_builder diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 87f60666..f8f73dd2 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -56,7 +56,14 @@ import struct import weakref import six -import six.moves.copyreg as copyreg +try: + import six.moves.copyreg as copyreg +except ImportError: + # On some platforms, for example gMac, we run native Python because there is + # nothing like hermetic Python. This means lesser control on the system and + # the six.moves package may be missing (is missing on 20150321 on gMac). Be + # extra conservative and try to load the old replacement if it fails. + import copy_reg as copyreg # We use "as" to avoid name collisions with variables. from google.protobuf.internal import containers @@ -490,6 +497,9 @@ def _AddInitMethod(message_descriptor, cls): if field is None: raise TypeError("%s() got an unexpected keyword argument '%s'" % (message_descriptor.name, field_name)) + if field_value is None: + # field=None is the same as no field at all. + continue if field.label == _FieldDescriptor.LABEL_REPEATED: copy = field._default_constructor(self) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite @@ -889,17 +899,6 @@ def _AddClearExtensionMethod(cls): cls.ClearExtension = ClearExtension -def _AddClearMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - def Clear(self): - # Clear fields. - self._fields = {} - self._unknown_fields = () - self._oneofs = {} - self._Modified() - cls.Clear = Clear - - def _AddHasExtensionMethod(cls): """Helper for _AddMessageMethods().""" def HasExtension(self, extension_handle): @@ -999,16 +998,6 @@ def _AddUnicodeMethod(unused_message_descriptor, cls): cls.__unicode__ = __unicode__ -def _AddSetListenerMethod(cls): - """Helper for _AddMessageMethods().""" - def SetListener(self, listener): - if listener is None: - self._listener = message_listener_mod.NullMessageListener() - else: - self._listener = listener - cls._SetListener = SetListener - - def _BytesForNonRepeatedElement(value, field_number, field_type): """Returns the number of bytes needed to serialize a non-repeated element. The returned byte count includes space for tag information and any @@ -1288,6 +1277,32 @@ def _AddWhichOneofMethod(message_descriptor, cls): cls.WhichOneof = WhichOneof +def _Clear(self): + # Clear fields. + self._fields = {} + self._unknown_fields = () + self._oneofs = {} + self._Modified() + + +def _DiscardUnknownFields(self): + self._unknown_fields = [] + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for sub_message in value: + sub_message.DiscardUnknownFields() + else: + value.DiscardUnknownFields() + + +def _SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + + def _AddMessageMethods(message_descriptor, cls): """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) @@ -1296,12 +1311,10 @@ def _AddMessageMethods(message_descriptor, cls): if message_descriptor.is_extendable: _AddClearExtensionMethod(cls) _AddHasExtensionMethod(cls) - _AddClearMethod(message_descriptor, cls) _AddEqualsMethod(message_descriptor, cls) _AddStrMethod(message_descriptor, cls) _AddReprMethod(message_descriptor, cls) _AddUnicodeMethod(message_descriptor, cls) - _AddSetListenerMethod(cls) _AddByteSizeMethod(message_descriptor, cls) _AddSerializeToStringMethod(message_descriptor, cls) _AddSerializePartialToStringMethod(message_descriptor, cls) @@ -1309,6 +1322,10 @@ def _AddMessageMethods(message_descriptor, cls): _AddIsInitializedMethod(message_descriptor, cls) _AddMergeFromMethod(cls) _AddWhichOneofMethod(message_descriptor, cls) + # Adds methods which do not depend on cls. + cls.Clear = _Clear + cls.DiscardUnknownFields = _DiscardUnknownFields + cls._SetListener = _SetListener def _AddPrivateHelperMethods(message_descriptor, cls): @@ -1518,3 +1535,14 @@ class _ExtensionDict(object): Extension field descriptor. """ return self._extended_message._extensions_by_name.get(name, None) + + def _FindExtensionByNumber(self, number): + """Tries to find a known extension with the field number. + + Args: + number: Extension field number. + + Returns: + Extension field descriptor. + """ + return self._extended_message._extensions_by_number.get(number, None) diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 9e61ea0e..6dc2fffe 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -120,11 +120,13 @@ class ReflectionTest(unittest.TestCase): proto = unittest_pb2.TestAllTypes( optional_int32=24, optional_double=54.321, - optional_string='optional_string') + optional_string='optional_string', + optional_float=None) self.assertEqual(24, proto.optional_int32) self.assertEqual(54.321, proto.optional_double) self.assertEqual('optional_string', proto.optional_string) + self.assertFalse(proto.HasField("optional_float")) def testRepeatedScalarConstructor(self): # Constructor with only repeated scalar types should succeed. @@ -132,12 +134,14 @@ class ReflectionTest(unittest.TestCase): repeated_int32=[1, 2, 3, 4], repeated_double=[1.23, 54.321], repeated_bool=[True, False, False], - repeated_string=["optional_string"]) + repeated_string=["optional_string"], + repeated_float=None) self.assertEqual([1, 2, 3, 4], list(proto.repeated_int32)) self.assertEqual([1.23, 54.321], list(proto.repeated_double)) self.assertEqual([True, False, False], list(proto.repeated_bool)) self.assertEqual(["optional_string"], list(proto.repeated_string)) + self.assertEqual([], list(proto.repeated_float)) def testRepeatedCompositeConstructor(self): # Constructor with only repeated composite types should succeed. @@ -188,7 +192,8 @@ class ReflectionTest(unittest.TestCase): repeated_foreign_message=[ unittest_pb2.ForeignMessage(c=-43), unittest_pb2.ForeignMessage(c=45324), - unittest_pb2.ForeignMessage(c=12)]) + unittest_pb2.ForeignMessage(c=12)], + optional_nested_message=None) self.assertEqual(24, proto.optional_int32) self.assertEqual('optional_string', proto.optional_string) @@ -205,6 +210,7 @@ class ReflectionTest(unittest.TestCase): unittest_pb2.ForeignMessage(c=45324), unittest_pb2.ForeignMessage(c=12)], list(proto.repeated_foreign_message)) + self.assertFalse(proto.HasField("optional_nested_message")) def testConstructorTypeError(self): self.assertRaises( diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 8ce0a44f..ab2bf05b 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -250,6 +250,36 @@ class TextFormatTest(TextFormatBase): message.c = 123 self.assertEqual('c: 123\n', str(message)) + def testPrintField(self, message_module): + message = message_module.TestAllTypes() + field = message.DESCRIPTOR.fields_by_name['optional_float'] + value = message.optional_float + out = text_format.TextWriter(False) + text_format.PrintField(field, value, out) + self.assertEqual('optional_float: 0.0\n', out.getvalue()) + out.close() + # Test Printer + out = text_format.TextWriter(False) + printer = text_format._Printer(out) + printer.PrintField(field, value) + self.assertEqual('optional_float: 0.0\n', out.getvalue()) + out.close() + + def testPrintFieldValue(self, message_module): + message = message_module.TestAllTypes() + field = message.DESCRIPTOR.fields_by_name['optional_float'] + value = message.optional_float + out = text_format.TextWriter(False) + text_format.PrintFieldValue(field, value, out) + self.assertEqual('0.0', out.getvalue()) + out.close() + # Test Printer + out = text_format.TextWriter(False) + printer = text_format._Printer(out) + printer.PrintFieldValue(field, value) + self.assertEqual('0.0', out.getvalue()) + out.close() + def testParseAllFields(self, message_module): message = message_module.TestAllTypes() test_util.SetAllFields(message) @@ -616,6 +646,26 @@ class Proto2Tests(TextFormatBase): ' text: \"bar\"\n' '}\n') + def testPrintMessageSetByFieldNumber(self): + out = text_format.TextWriter(False) + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + text_format.PrintMessage(message, out, use_field_number=True) + self.CompareToGoldenText( + out.getvalue(), + '1 {\n' + ' 1545008 {\n' + ' 15: 23\n' + ' }\n' + ' 1547769 {\n' + ' 25: \"foo\"\n' + ' }\n' + '}\n') + out.close() + def testPrintMessageSetAsOneLine(self): message = unittest_mset_pb2.TestMessageSetContainer() ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension @@ -656,6 +706,48 @@ class Proto2Tests(TextFormatBase): self.assertEqual(23, message.message_set.Extensions[ext1].i) self.assertEqual('foo', message.message_set.Extensions[ext2].str) + def testParseMessageByFieldNumber(self): + message = unittest_pb2.TestAllTypes() + text = ('34: 1\n' + 'repeated_uint64: 2\n') + text_format.Parse(text, message, allow_field_number=True) + self.assertEqual(1, message.repeated_uint64[0]) + self.assertEqual(2, message.repeated_uint64[1]) + + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('1 {\n' + ' 1545008 {\n' + ' 15: 23\n' + ' }\n' + ' 1547769 {\n' + ' 25: \"foo\"\n' + ' }\n' + '}\n') + text_format.Parse(text, message, allow_field_number=True) + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + self.assertEqual(23, message.message_set.Extensions[ext1].i) + self.assertEqual('foo', message.message_set.Extensions[ext2].str) + + # Can't parse field number without set allow_field_number=True. + message = unittest_pb2.TestAllTypes() + text = '34:1\n' + six.assertRaisesRegex( + self, + text_format.ParseError, + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"34".'), + text_format.Parse, text, message) + + # Can't parse if field number is not found. + text = '1234:1\n' + six.assertRaisesRegex( + self, + text_format.ParseError, + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"1234".'), + text_format.Parse, text, message, allow_field_number=True) + def testPrintAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) @@ -696,6 +788,7 @@ class Proto2Tests(TextFormatBase): text = ('message_set {\n' ' [unknown_extension] {\n' ' i: 23\n' + ' bin: "\xe0"' ' [nested_unknown_ext]: {\n' ' i: 23\n' ' test: "test_string"\n' diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index f30ca6a8..1be3ad9a 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -109,6 +109,16 @@ class TypeChecker(object): return proposed_value +class TypeCheckerWithDefault(TypeChecker): + + def __init__(self, default_value, *acceptable_types): + TypeChecker.__init__(self, acceptable_types) + self._default_value = default_value + + def DefaultValue(self): + return self._default_value + + # IntValueChecker and its subclasses perform integer type-checks # and bounds-checks. class IntValueChecker(object): @@ -212,12 +222,13 @@ _VALUE_CHECKERS = { _FieldDescriptor.CPPTYPE_INT64: Int64ValueChecker(), _FieldDescriptor.CPPTYPE_UINT32: Uint32ValueChecker(), _FieldDescriptor.CPPTYPE_UINT64: Uint64ValueChecker(), - _FieldDescriptor.CPPTYPE_DOUBLE: TypeChecker( - float, int, long), - _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( - float, int, long), - _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), - _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes), + _FieldDescriptor.CPPTYPE_DOUBLE: TypeCheckerWithDefault( + 0.0, float, int, long), + _FieldDescriptor.CPPTYPE_FLOAT: TypeCheckerWithDefault( + 0.0, float, int, long), + _FieldDescriptor.CPPTYPE_BOOL: TypeCheckerWithDefault( + False, bool, int), + _FieldDescriptor.CPPTYPE_STRING: TypeCheckerWithDefault(b'', bytes), } diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index bb2748e4..84073f1c 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -119,6 +119,26 @@ class UnknownFieldsTest(unittest.TestCase): message.ParseFromString(self.all_fields.SerializeToString()) self.assertNotEqual(self.empty_message, message) + def testDiscardUnknownFields(self): + self.empty_message.DiscardUnknownFields() + self.assertEqual(b'', self.empty_message.SerializeToString()) + # Test message field and repeated message field. + message = unittest_pb2.TestAllTypes() + other_message = unittest_pb2.TestAllTypes() + other_message.optional_string = 'discard' + message.optional_nested_message.ParseFromString( + other_message.SerializeToString()) + message.repeated_nested_message.add().ParseFromString( + other_message.SerializeToString()) + self.assertNotEqual( + b'', message.optional_nested_message.SerializeToString()) + self.assertNotEqual( + b'', message.repeated_nested_message[0].SerializeToString()) + message.DiscardUnknownFields() + self.assertEqual(b'', message.optional_nested_message.SerializeToString()) + self.assertEqual( + b'', message.repeated_nested_message[0].SerializeToString()) + class UnknownFieldsAccessorsTest(unittest.TestCase): diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index d35fcc5f..7c5dffd0 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -82,10 +82,14 @@ class Any(object): msg.ParseFromString(self.value) return True + def TypeName(self): + """Returns the protobuf type name of the inner message.""" + # Only last part is to be used: b/25630112 + return self.type_url.split('/')[-1] + def Is(self, descriptor): """Checks if this Any represents the given protobuf type.""" - # Only last part is to be used: b/25630112 - return self.type_url.split('/')[-1] == descriptor.full_name + return self.TypeName() == descriptor.full_name class Timestamp(object): diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 18329205..2f32ac99 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -610,6 +610,14 @@ class AnyTest(unittest.TestCase): raise AttributeError('%s should not have Pack method.' % msg_descriptor.full_name) + def testMessageName(self): + # Creates and sets message. + submessage = any_test_pb2.TestAny() + submessage.int_value = 12345 + msg = any_pb2.Any() + msg.Pack(submessage) + self.assertEqual(msg.TypeName(), 'google.protobuf.internal.TestAny') + def testPackWithCustomTypeUrl(self): submessage = any_test_pb2.TestAny() submessage.int_value = 12345 diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index de2f5697..606f735f 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -255,6 +255,9 @@ class Message(object): def ClearExtension(self, extension_handle): raise NotImplementedError + def DiscardUnknownFields(self): + raise NotImplementedError + def ByteSize(self): """Returns the serialized size of this message. Recursively calls ByteSize() on all contained messages. diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 07550706..23557538 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -200,8 +200,8 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { // read-only instance. const Message& options(descriptor->options()); const Descriptor *message_type = options.GetDescriptor(); - PyObject* message_class(cdescriptor_pool::GetMessageClass( - pool, message_type)); + CMessageClass* message_class( + cdescriptor_pool::GetMessageClass(pool, message_type)); if (message_class == NULL) { // The Options message was not found in the current DescriptorPool. // In this case, there cannot be extensions to these options, and we can @@ -215,7 +215,8 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { message_type->full_name().c_str()); return NULL; } - ScopedPyObjectPtr value(PyEval_CallObject(message_class, NULL)); + ScopedPyObjectPtr value( + PyEval_CallObject(message_class->AsPyObject(), NULL)); if (value == NULL) { return NULL; } @@ -433,11 +434,11 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) { // which contains this descriptor. // This might not be the one you expect! For example the returned object does // not know about extensions defined in a custom pool. - PyObject* concrete_class(cdescriptor_pool::GetMessageClass( + CMessageClass* concrete_class(cdescriptor_pool::GetMessageClass( GetDescriptorPool_FromPool(_GetDescriptor(self)->file()->pool()), _GetDescriptor(self))); Py_XINCREF(concrete_class); - return concrete_class; + return concrete_class->AsPyObject(); } static PyObject* GetFieldsByName(PyBaseDescriptor* self, void *closure) { diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index 0bc76bc9..1faff96b 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -190,8 +190,8 @@ PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) { // Add a message class to our database. int RegisterMessageClass(PyDescriptorPool* self, - const Descriptor *message_descriptor, - PyObject *message_class) { + const Descriptor* message_descriptor, + CMessageClass* message_class) { Py_INCREF(message_class); typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; std::pair ret = self->classes_by_descriptor->insert( @@ -205,8 +205,8 @@ int RegisterMessageClass(PyDescriptorPool* self, } // Retrieve the message class added to our database. -PyObject *GetMessageClass(PyDescriptorPool* self, - const Descriptor *message_descriptor) { +CMessageClass* GetMessageClass(PyDescriptorPool* self, + const Descriptor* message_descriptor) { typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; iterator ret = self->classes_by_descriptor->find(message_descriptor); if (ret == self->classes_by_descriptor->end()) { diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h index 16bc910c..2a42c112 100644 --- a/python/google/protobuf/pyext/descriptor_pool.h +++ b/python/google/protobuf/pyext/descriptor_pool.h @@ -42,6 +42,9 @@ class MessageFactory; namespace python { +// The (meta) type of all Messages classes. +struct CMessageClass; + // Wraps operations to the global DescriptorPool which contains information // about all messages and fields. // @@ -78,7 +81,7 @@ typedef struct PyDescriptorPool { // // Descriptor pointers stored here are owned by the DescriptorPool above. // Python references to classes are owned by this PyDescriptorPool. - typedef hash_map ClassesByMessageMap; + typedef hash_map ClassesByMessageMap; ClassesByMessageMap* classes_by_descriptor; // Cache the options for any kind of descriptor. @@ -101,14 +104,14 @@ const Descriptor* FindMessageTypeByName(PyDescriptorPool* self, // On error, returns -1 with a Python exception set. int RegisterMessageClass(PyDescriptorPool* self, const Descriptor* message_descriptor, - PyObject* message_class); + CMessageClass* message_class); // Retrieves the Python class registered with the given message descriptor. // // Returns a *borrowed* reference if found, otherwise returns NULL with an // exception set. -PyObject* GetMessageClass(PyDescriptorPool* self, - const Descriptor* message_descriptor); +CMessageClass* GetMessageClass(PyDescriptorPool* self, + const Descriptor* message_descriptor); // The functions below are also exposed as methods of the DescriptorPool type. diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 555bd293..21bbb8c2 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -130,7 +130,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject *message_class = cdescriptor_pool::GetMessageClass( + CMessageClass* message_class = cdescriptor_pool::GetMessageClass( cmessage::GetDescriptorPoolForMessage(self->parent), descriptor->message_type()); if (message_class == NULL) { @@ -239,6 +239,21 @@ PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { } } +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number) { + ScopedPyObjectPtr extensions_by_number(PyObject_GetAttrString( + reinterpret_cast(self->parent), "_extensions_by_number")); + if (extensions_by_number == NULL) { + return NULL; + } + PyObject* result = PyDict_GetItem(extensions_by_number.get(), number); + if (result == NULL) { + Py_RETURN_NONE; + } else { + Py_INCREF(result); + return result; + } +} + ExtensionDict* NewExtensionDict(CMessage *parent) { ExtensionDict* self = reinterpret_cast( PyType_GenericAlloc(&ExtensionDict_Type, 0)); @@ -271,6 +286,8 @@ static PyMethodDef Methods[] = { EDMETHOD(HasExtension, METH_O, "Checks if the object has an extension."), EDMETHOD(_FindExtensionByName, METH_O, "Finds an extension by name."), + EDMETHOD(_FindExtensionByNumber, METH_O, + "Finds an extension by field number."), { NULL, NULL } }; diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h index 1e7f6f7b..049d2e45 100644 --- a/python/google/protobuf/pyext/extension_dict.h +++ b/python/google/protobuf/pyext/extension_dict.h @@ -123,6 +123,12 @@ PyObject* ClearExtension(ExtensionDict* self, // Returns a new reference. PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name); +// Gets an extension from the dict given the extension field number as +// opposed to descriptor. +// +// Returns a new reference. +PyObject* _FindExtensionByNumber(ExtensionDict* self, PyObject* number); + } // namespace extension_dict } // namespace python } // namespace protobuf diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index df9138a4..e022406d 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -32,6 +32,11 @@ #include +#include +#ifndef _SHARED_PTR_H +#include +#endif + #include #include #include @@ -70,7 +75,7 @@ class MapReflectionFriend { struct MapIterator { PyObject_HEAD; - scoped_ptr< ::google::protobuf::MapIterator> iter; + google::protobuf::scoped_ptr< ::google::protobuf::MapIterator> iter; // A pointer back to the container, so we can notice changes to the version. // We own a ref on this. @@ -610,8 +615,7 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { PyObject* ret = PyDict_GetItem(self->message_dict, key.get()); if (ret == NULL) { - CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init, - message->GetDescriptor()); + CMessage* cmsg = cmessage::NewEmptyMessage(self->message_class); ret = reinterpret_cast(cmsg); if (cmsg == NULL) { @@ -634,7 +638,7 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { PyObject* NewMessageMapContainer( CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor, - PyObject* concrete_class) { + CMessageClass* message_class) { if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { return NULL; } @@ -669,8 +673,8 @@ PyObject* NewMessageMapContainer( "Could not allocate message dict."); } - Py_INCREF(concrete_class); - self->subclass_init = concrete_class; + Py_INCREF(message_class); + self->message_class = message_class; if (self->key_field_descriptor == NULL || self->value_field_descriptor == NULL) { @@ -763,6 +767,7 @@ static void MessageMapDealloc(PyObject* _self) { MessageMapContainer* self = GetMessageMap(_self); self->owner.reset(); Py_DECREF(self->message_dict); + Py_DECREF(self->message_class); Py_TYPE(_self)->tp_free(_self); } diff --git a/python/google/protobuf/pyext/map_container.h b/python/google/protobuf/pyext/map_container.h index 27ee6dbd..b11dfa34 100644 --- a/python/google/protobuf/pyext/map_container.h +++ b/python/google/protobuf/pyext/map_container.h @@ -55,6 +55,7 @@ using internal::shared_ptr; namespace python { struct CMessage; +struct CMessageClass; // This struct is used directly for ScalarMap, and is the base class of // MessageMapContainer, which is used for MessageMap. @@ -104,8 +105,8 @@ struct MapContainer { }; struct MessageMapContainer : public MapContainer { - // A callable that is used to create new child messages. - PyObject* subclass_init; + // The type used to create new child messages. + CMessageClass* message_class; // A dict mapping Message* -> CMessage. PyObject* message_dict; @@ -132,7 +133,7 @@ extern PyObject* NewScalarMapContainer( // field descriptor. extern PyObject* NewMessageMapContainer( CMessage* parent, const FieldDescriptor* parent_field_descriptor, - PyObject* concrete_class); + CMessageClass* message_class); } // namespace python } // namespace protobuf diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 6d7b2b0f..83c151ff 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -98,31 +98,6 @@ static PyObject* PythonMessage_class; static PyObject* kEmptyWeakref; static PyObject* WKT_classes = NULL; -// Defines the Metaclass of all Message classes. -// It allows us to cache some C++ pointers in the class object itself, they are -// faster to extract than from the type's dictionary. - -struct PyMessageMeta { - // This is how CPython subclasses C structures: the base structure must be - // the first member of the object. - PyHeapTypeObject super; - - // C++ descriptor of this message. - const Descriptor* message_descriptor; - - // Owned reference, used to keep the pointer above alive. - PyObject* py_message_descriptor; - - // The Python DescriptorPool used to create the class. It is needed to resolve - // fields descriptors, including extensions fields; its C++ MessageFactory is - // used to instantiate submessages. - // This can be different from DESCRIPTOR.file.pool, in the case of a custom - // DescriptorPool which defines new extensions. - // We own the reference, because it's important to keep the descriptors and - // factory alive. - PyDescriptorPool* py_descriptor_pool; -}; - namespace message_meta { static int InsertEmptyWeakref(PyTypeObject* base); @@ -173,10 +148,6 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { } // For each enum set cls. = EnumTypeWrapper(). - // - // The enum descriptor we get from - // .enum_types_by_name[name] - // which was built previously. for (int i = 0; i < descriptor->enum_type_count(); ++i) { const EnumDescriptor* enum_descriptor = descriptor->enum_type(i); ScopedPyObjectPtr enum_type( @@ -309,7 +280,7 @@ static PyObject* New(PyTypeObject* type, if (result == NULL) { return NULL; } - PyMessageMeta* newtype = reinterpret_cast(result.get()); + CMessageClass* newtype = reinterpret_cast(result.get()); // Insert the empty weakref into the base classes. if (InsertEmptyWeakref( @@ -338,7 +309,7 @@ static PyObject* New(PyTypeObject* type, // Add the message to the DescriptorPool. if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool, - descriptor, result.get()) < 0) { + descriptor, newtype) < 0) { return NULL; } @@ -349,7 +320,7 @@ static PyObject* New(PyTypeObject* type, return result.release(); } -static void Dealloc(PyMessageMeta *self) { +static void Dealloc(CMessageClass *self) { Py_DECREF(self->py_message_descriptor); Py_DECREF(self->py_descriptor_pool); Py_TYPE(self)->tp_free(reinterpret_cast(self)); @@ -378,10 +349,10 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) { } // namespace message_meta -PyTypeObject PyMessageMeta_Type = { +PyTypeObject CMessageClass_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME ".MessageMeta", // tp_name - sizeof(PyMessageMeta), // tp_basicsize + sizeof(CMessageClass), // tp_basicsize 0, // tp_itemsize (destructor)message_meta::Dealloc, // tp_dealloc 0, // tp_print @@ -419,16 +390,16 @@ PyTypeObject PyMessageMeta_Type = { message_meta::New, // tp_new }; -static PyMessageMeta* CheckMessageClass(PyTypeObject* cls) { - if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) { +static CMessageClass* CheckMessageClass(PyTypeObject* cls) { + if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); return NULL; } - return reinterpret_cast(cls); + return reinterpret_cast(cls); } static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { - PyMessageMeta* type = CheckMessageClass(cls); + CMessageClass* type = CheckMessageClass(cls); if (type == NULL) { return NULL; } @@ -783,9 +754,9 @@ namespace cmessage { PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) { // No need to check the type: the type of instances of CMessage is always - // an instance of PyMessageMeta. Let's prove it with a debug-only check. + // an instance of CMessageClass. Let's prove it with a debug-only check. GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); - return reinterpret_cast(Py_TYPE(message))->py_descriptor_pool; + return reinterpret_cast(Py_TYPE(message))->py_descriptor_pool; } MessageFactory* GetFactoryForMessage(CMessage* message) { @@ -1090,6 +1061,10 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { PyString_AsString(name)); return -1; } + if (value == Py_None) { + // field=None is the same as no field at all. + continue; + } if (descriptor->is_map()) { ScopedPyObjectPtr map(GetAttr(self, name)); const FieldDescriptor* value_descriptor = @@ -1220,9 +1195,9 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { // Allocates an incomplete Python Message: the caller must fill self->message, // self->owner and eventually self->parent. -CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { +CMessage* NewEmptyMessage(CMessageClass* type) { CMessage* self = reinterpret_cast( - PyType_GenericAlloc(reinterpret_cast(type), 0)); + PyType_GenericAlloc(&type->super.ht_type, 0)); if (self == NULL) { return NULL; } @@ -1242,7 +1217,7 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { // Creates a new C++ message and takes ownership. static PyObject* New(PyTypeObject* cls, PyObject* unused_args, PyObject* unused_kwargs) { - PyMessageMeta* type = CheckMessageClass(cls); + CMessageClass* type = CheckMessageClass(cls); if (type == NULL) { return NULL; } @@ -1258,8 +1233,7 @@ static PyObject* New(PyTypeObject* cls, return NULL; } - CMessage* self = NewEmptyMessage(reinterpret_cast(type), - message_descriptor); + CMessage* self = NewEmptyMessage(type); if (self == NULL) { return NULL; } @@ -2023,10 +1997,34 @@ static PyObject* RegisterExtension(PyObject* cls, PyErr_SetString(PyExc_TypeError, "no extensions_by_number on class"); return NULL; } + ScopedPyObjectPtr number(PyObject_GetAttrString(extension_handle, "number")); if (number == NULL) { return NULL; } + + // If the extension was already registered by number, check that it is the + // same. + existing_extension = PyDict_GetItem(extensions_by_number.get(), number.get()); + if (existing_extension != NULL) { + const FieldDescriptor* existing_extension_descriptor = + GetExtensionDescriptor(existing_extension); + if (existing_extension_descriptor != descriptor) { + const Descriptor* msg_desc = GetMessageDescriptor( + reinterpret_cast(cls)); + PyErr_Format( + PyExc_ValueError, + "Extensions \"%s\" and \"%s\" both try to extend message type " + "\"%s\" with field number %ld.", + existing_extension_descriptor->full_name().c_str(), + descriptor->full_name().c_str(), + msg_desc->full_name().c_str(), + PyInt_AsLong(number.get())); + return NULL; + } + // Nothing else to do. + Py_RETURN_NONE; + } if (PyDict_SetItem(extensions_by_number.get(), number.get(), extension_handle) < 0) { return NULL; @@ -2166,6 +2164,12 @@ static PyObject* ListFields(CMessage* self) { return all_fields.release(); } +static PyObject* DiscardUnknownFields(CMessage* self) { + AssureWritable(self); + self->message->DiscardUnknownFields(); + Py_RETURN_NONE; +} + PyObject* FindInitializationErrors(CMessage* self) { Message* message = self->message; vector errors; @@ -2309,14 +2313,13 @@ PyObject* InternalGetSubMessage( const Message& sub_message = reflection->GetMessage( *self->message, field_descriptor, pool->message_factory); - PyObject *message_class = cdescriptor_pool::GetMessageClass( + CMessageClass* message_class = cdescriptor_pool::GetMessageClass( pool, field_descriptor->message_type()); if (message_class == NULL) { return NULL; } - CMessage* cmsg = cmessage::NewEmptyMessage(message_class, - sub_message.GetDescriptor()); + CMessage* cmsg = cmessage::NewEmptyMessage(message_class); if (cmsg == NULL) { return NULL; } @@ -2585,6 +2588,8 @@ static PyMethodDef Methods[] = { "Clears a message field." }, { "CopyFrom", (PyCFunction)CopyFrom, METH_O, "Copies a protocol message into the current message." }, + { "DiscardUnknownFields", (PyCFunction)DiscardUnknownFields, METH_NOARGS, + "Discards the unknown fields." }, { "FindInitializationErrors", (PyCFunction)FindInitializationErrors, METH_NOARGS, "Finds unset required fields." }, @@ -2654,7 +2659,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { const Descriptor* entry_type = field_descriptor->message_type(); const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject* value_class = cdescriptor_pool::GetMessageClass( + CMessageClass* value_class = cdescriptor_pool::GetMessageClass( GetDescriptorPoolForMessage(self), value_type->message_type()); if (value_class == NULL) { return NULL; @@ -2677,7 +2682,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyObject* py_container = NULL; if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject *message_class = cdescriptor_pool::GetMessageClass( + CMessageClass* message_class = cdescriptor_pool::GetMessageClass( GetDescriptorPoolForMessage(self), field_descriptor->message_type()); if (message_class == NULL) { return NULL; @@ -2749,7 +2754,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { } // namespace cmessage PyTypeObject CMessage_Type = { - PyVarObject_HEAD_INIT(&PyMessageMeta_Type, 0) + PyVarObject_HEAD_INIT(&CMessageClass_Type, 0) FULL_MODULE_NAME ".CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize @@ -2864,12 +2869,12 @@ bool InitProto2MessageModule(PyObject *m) { // Initialize constants defined in this file. InitGlobals(); - PyMessageMeta_Type.tp_base = &PyType_Type; - if (PyType_Ready(&PyMessageMeta_Type) < 0) { + CMessageClass_Type.tp_base = &PyType_Type; + if (PyType_Ready(&CMessageClass_Type) < 0) { return false; } PyModule_AddObject(m, "MessageMeta", - reinterpret_cast(&PyMessageMeta_Type)); + reinterpret_cast(&CMessageClass_Type)); if (PyType_Ready(&CMessage_Type) < 0) { return false; @@ -3077,9 +3082,10 @@ bool InitProto2MessageModule(PyObject *m) { } // namespace protobuf static PyMethodDef ModuleMethods[] = { - {"SetAllowOversizeProtos", - (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, - METH_O, "Enable/disable oversize proto parsing."}, + {"SetAllowOversizeProtos", + (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, + METH_O, "Enable/disable oversize proto parsing."}, + { NULL, NULL} }; #if PY_MAJOR_VERSION >= 3 diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index c2b62649..9dce198f 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -116,12 +116,43 @@ typedef struct CMessage { extern PyTypeObject CMessage_Type; + +// The (meta) type of all Messages classes. +// It allows us to cache some C++ pointers in the class object itself, they are +// faster to extract than from the type's dictionary. + +struct CMessageClass { + // This is how CPython subclasses C structures: the base structure must be + // the first member of the object. + PyHeapTypeObject super; + + // C++ descriptor of this message. + const Descriptor* message_descriptor; + + // Owned reference, used to keep the pointer above alive. + PyObject* py_message_descriptor; + + // The Python DescriptorPool used to create the class. It is needed to resolve + // fields descriptors, including extensions fields; its C++ MessageFactory is + // used to instantiate submessages. + // This can be different from DESCRIPTOR.file.pool, in the case of a custom + // DescriptorPool which defines new extensions. + // We own the reference, because it's important to keep the descriptors and + // factory alive. + PyDescriptorPool* py_descriptor_pool; + + PyObject* AsPyObject() { + return reinterpret_cast(this); + } +}; + + namespace cmessage { // Internal function to create a new empty Message Python object, but with empty // pointers to the C++ objects. // The caller must fill self->message, self->owner and eventually self->parent. -CMessage* NewEmptyMessage(PyObject* type, const Descriptor* descriptor); +CMessage* NewEmptyMessage(CMessageClass* type); // Release a submessage from its proto tree, making it a new top-level messgae. // A new message will be created if this is a read-only default instance. diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index b01123b4..4f339e77 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -107,8 +107,7 @@ static int UpdateChildMessages(RepeatedCompositeContainer* self) { for (Py_ssize_t i = child_length; i < message_length; ++i) { const Message& sub_message = reflection->GetRepeatedMessage( *(self->message), self->parent_field_descriptor, i); - CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init, - sub_message.GetDescriptor()); + CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class); ScopedPyObjectPtr py_cmsg(reinterpret_cast(cmsg)); if (cmsg == NULL) { return -1; @@ -140,8 +139,7 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self, Message* sub_message = message->GetReflection()->AddMessage(message, self->parent_field_descriptor); - CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init, - sub_message->GetDescriptor()); + CMessage* cmsg = cmessage::NewEmptyMessage(self->child_message_class); if (cmsg == NULL) return NULL; @@ -168,7 +166,7 @@ static PyObject* AddToReleased(RepeatedCompositeContainer* self, // Create a new Message detached from the rest. PyObject* py_cmsg = PyEval_CallObjectWithKeywords( - self->subclass_init, NULL, kwargs); + self->child_message_class->AsPyObject(), NULL, kwargs); if (py_cmsg == NULL) return NULL; @@ -506,7 +504,7 @@ int SetOwner(RepeatedCompositeContainer* self, PyObject *NewContainer( CMessage* parent, const FieldDescriptor* parent_field_descriptor, - PyObject *concrete_class) { + CMessageClass* concrete_class) { if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { return NULL; } @@ -523,7 +521,7 @@ PyObject *NewContainer( self->parent_field_descriptor = parent_field_descriptor; self->owner = parent->owner; Py_INCREF(concrete_class); - self->subclass_init = concrete_class; + self->child_message_class = concrete_class; self->child_messages = PyList_New(0); return reinterpret_cast(self); @@ -531,7 +529,7 @@ PyObject *NewContainer( static void Dealloc(RepeatedCompositeContainer* self) { Py_CLEAR(self->child_messages); - Py_CLEAR(self->subclass_init); + Py_CLEAR(self->child_message_class); // TODO(tibell): Do we need to call delete on these objects to make // sure their destructors are called? self->owner.reset(); diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h index 442ce7e3..25463037 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.h +++ b/python/google/protobuf/pyext/repeated_composite_container.h @@ -58,6 +58,7 @@ using internal::shared_ptr; namespace python { struct CMessage; +struct CMessageClass; // A RepeatedCompositeContainer can be in one of two states: attached // or released. @@ -94,8 +95,8 @@ typedef struct RepeatedCompositeContainer { // calling Clear() or ClearField() on the parent. Message* message; - // A callable that is used to create new child messages. - PyObject* subclass_init; + // The type used to create new child messages. + CMessageClass* child_message_class; // A list of child messages. PyObject* child_messages; @@ -110,7 +111,7 @@ namespace repeated_composite_container { PyObject *NewContainer( CMessage* parent, const FieldDescriptor* parent_field_descriptor, - PyObject *concrete_class); + CMessageClass *child_message_class); // Appends a new CMessage to the container and returns it. The // CMessage is initialized using the content of kwargs. diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index a6f41ca8..6f1e3c8b 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -99,7 +99,7 @@ class TextWriter(object): def MessageToString(message, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): + float_format=None, use_field_number=False): """Convert protobuf message to text format. Floating point values can be formatted compactly with 15 digits of @@ -118,15 +118,16 @@ def MessageToString(message, as_utf8=False, as_one_line=False, field number order. float_format: If set, use this to specify floating point number formatting (per the "Format Specification Mini-Language"); otherwise, str() is used. + use_field_number: If True, print field numbers instead of names. Returns: A string of the text formatted protocol buffer message. """ out = TextWriter(as_utf8) - PrintMessage(message, out, as_utf8=as_utf8, as_one_line=as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) + printer = _Printer(out, 0, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format, + use_field_number) + printer.PrintMessage(message) result = out.getvalue() out.close() if as_one_line: @@ -142,133 +143,187 @@ def _IsMapEntry(field): def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, - float_format=None): - fields = message.ListFields() - if use_index_order: - fields.sort(key=lambda x: x[0].index) - for field, value in fields: - if _IsMapEntry(field): - for key in sorted(value): - # This is slow for maps with submessage entires because it copies the - # entire tree. Unfortunately this would take significant refactoring - # of this file to work around. - # - # TODO(haberman): refactor and optimize if this becomes an issue. - entry_submsg = field.message_type._concrete_class( - key=key, value=value[key]) - PrintField(field, entry_submsg, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, float_format=float_format) - elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - for element in value: - PrintField(field, element, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) - else: - PrintField(field, value, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) + float_format=None, use_field_number=False): + printer = _Printer(out, indent, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format, + use_field_number) + printer.PrintMessage(message) def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, float_format=None): - """Print a single field name/value pair. For repeated fields, the value - should be a single element. - """ - - out.write(' ' * indent) - if field.is_extension: - out.write('[') - if (field.containing_type.GetOptions().message_set_wire_format and - field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and - field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): - out.write(field.message_type.full_name) - else: - out.write(field.full_name) - out.write(']') - elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: - # For groups, use the capitalized name. - out.write(field.message_type.name) - else: - out.write(field.name) - - if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - # The colon is optional in this case, but our cross-language golden files - # don't include it. - out.write(': ') - - PrintFieldValue(field, value, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) - if as_one_line: - out.write(' ') - else: - out.write('\n') + """Print a single field name/value pair.""" + printer = _Printer(out, indent, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format) + printer.PrintField(field, value) def PrintFieldValue(field, value, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, float_format=None): - """Print a single field value (not including name). For repeated fields, - the value should be a single element.""" + """Print a single field value (not including name).""" + printer = _Printer(out, indent, as_utf8, as_one_line, + pointy_brackets, use_index_order, float_format) + printer.PrintFieldValue(field, value) - if pointy_brackets: - openb = '<' - closeb = '>' - else: - openb = '{' - closeb = '}' - - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - if as_one_line: - out.write(' %s ' % openb) - PrintMessage(value, out, indent, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) - out.write(closeb) - else: - out.write(' %s\n' % openb) - PrintMessage(value, out, indent + 2, as_utf8, as_one_line, - pointy_brackets=pointy_brackets, - use_index_order=use_index_order, - float_format=float_format) - out.write(' ' * indent + closeb) - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: - enum_value = field.enum_type.values_by_number.get(value, None) - if enum_value is not None: - out.write(enum_value.name) + +class _Printer(object): + """Text format printer for protocol message.""" + + def __init__(self, out, indent=0, as_utf8=False, as_one_line=False, + pointy_brackets=False, use_index_order=False, float_format=None, + use_field_number=False): + """Initialize the Printer. + + Floating point values can be formatted compactly with 15 digits of + precision (which is the most that IEEE 754 "double" can guarantee) + using float_format='.15g'. To ensure that converting to text and back to a + proto will result in an identical value, float_format='.17g' should be used. + + Args: + out: To record the text format result. + indent: The indent level for pretty print. + as_utf8: Produce text output in UTF8 format. + as_one_line: Don't introduce newlines between fields. + pointy_brackets: If True, use angle brackets instead of curly braces for + nesting. + use_index_order: If True, print fields of a proto message using the order + defined in source code instead of the field number. By default, use the + field number order. + float_format: If set, use this to specify floating point number formatting + (per the "Format Specification Mini-Language"); otherwise, str() is + used. + use_field_number: If True, print field numbers instead of names. + """ + self.out = out + self.indent = indent + self.as_utf8 = as_utf8 + self.as_one_line = as_one_line + self.pointy_brackets = pointy_brackets + self.use_index_order = use_index_order + self.float_format = float_format + self.use_field_number = use_field_number + + def PrintMessage(self, message): + """Convert protobuf message to text format. + + Args: + message: The protocol buffers message. + """ + fields = message.ListFields() + if self.use_index_order: + fields.sort(key=lambda x: x[0].index) + for field, value in fields: + if _IsMapEntry(field): + for key in sorted(value): + # This is slow for maps with submessage entires because it copies the + # entire tree. Unfortunately this would take significant refactoring + # of this file to work around. + # + # TODO(haberman): refactor and optimize if this becomes an issue. + entry_submsg = field.message_type._concrete_class( + key=key, value=value[key]) + self.PrintField(field, entry_submsg) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + for element in value: + self.PrintField(field, element) + else: + self.PrintField(field, value) + + def PrintField(self, field, value): + """Print a single field name/value pair.""" + out = self.out + out.write(' ' * self.indent) + if self.use_field_number: + out.write(str(field.number)) else: - out.write(str(value)) - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: - out.write('\"') - if isinstance(value, six.text_type): - out_value = value.encode('utf-8') + if field.is_extension: + out.write('[') + if (field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): + out.write(field.message_type.full_name) + else: + out.write(field.full_name) + out.write(']') + elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: + # For groups, use the capitalized name. + out.write(field.message_type.name) + else: + out.write(field.name) + + if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # The colon is optional in this case, but our cross-language golden files + # don't include it. + out.write(': ') + + self.PrintFieldValue(field, value) + if self.as_one_line: + out.write(' ') else: - out_value = value - if field.type == descriptor.FieldDescriptor.TYPE_BYTES: - # We need to escape non-UTF8 chars in TYPE_BYTES field. - out_as_utf8 = False + out.write('\n') + + def PrintFieldValue(self, field, value): + """Print a single field value (not including name). + + For repeated fields, the value should be a single element. + + Args: + field: The descriptor of the field to be printed. + value: The value of the field. + """ + out = self.out + if self.pointy_brackets: + openb = '<' + closeb = '>' else: - out_as_utf8 = as_utf8 - out.write(text_encoding.CEscape(out_value, out_as_utf8)) - out.write('\"') - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: - if value: - out.write('true') + openb = '{' + closeb = '}' + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + if self.as_one_line: + out.write(' %s ' % openb) + self.PrintMessage(value) + out.write(closeb) + else: + out.write(' %s\n' % openb) + self.indent += 2 + self.PrintMessage(value) + self.indent -= 2 + out.write(' ' * self.indent + closeb) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + out.write(enum_value.name) + else: + out.write(str(value)) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + out.write('\"') + if isinstance(value, six.text_type): + out_value = value.encode('utf-8') + else: + out_value = value + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # We need to escape non-UTF8 chars in TYPE_BYTES field. + out_as_utf8 = False + else: + out_as_utf8 = self.as_utf8 + out.write(text_encoding.CEscape(out_value, out_as_utf8)) + out.write('\"') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + if value: + out.write('true') + else: + out.write('false') + elif field.cpp_type in _FLOAT_TYPES and self.float_format is not None: + out.write('{1:{0}}'.format(self.float_format, value)) else: - out.write('false') - elif field.cpp_type in _FLOAT_TYPES and float_format is not None: - out.write('{1:{0}}'.format(float_format, value)) - else: - out.write(str(value)) + out.write(str(value)) -def Parse(text, message, allow_unknown_extension=False): +def Parse(text, message, + allow_unknown_extension=False, allow_field_number=False): """Parses an text representation of a protocol message into a message. Args: @@ -276,6 +331,7 @@ def Parse(text, message, allow_unknown_extension=False): message: A protocol buffer message to merge into. allow_unknown_extension: if True, skip over missing extensions and keep parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. @@ -285,10 +341,12 @@ def Parse(text, message, allow_unknown_extension=False): """ if not isinstance(text, str): text = text.decode('utf-8') - return ParseLines(text.split('\n'), message, allow_unknown_extension) + return ParseLines(text.split('\n'), message, allow_unknown_extension, + allow_field_number) -def Merge(text, message, allow_unknown_extension=False): +def Merge(text, message, allow_unknown_extension=False, + allow_field_number=False): """Parses an text representation of a protocol message into a message. Like Parse(), but allows repeated values for a non-repeated field, and uses @@ -299,6 +357,7 @@ def Merge(text, message, allow_unknown_extension=False): message: A protocol buffer message to merge into. allow_unknown_extension: if True, skip over missing extensions and keep parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. @@ -306,10 +365,12 @@ def Merge(text, message, allow_unknown_extension=False): Raises: ParseError: On text parsing problems. """ - return MergeLines(text.split('\n'), message, allow_unknown_extension) + return MergeLines(text.split('\n'), message, allow_unknown_extension, + allow_field_number) -def ParseLines(lines, message, allow_unknown_extension=False): +def ParseLines(lines, message, allow_unknown_extension=False, + allow_field_number=False): """Parses an text representation of a protocol message into a message. Args: @@ -317,6 +378,7 @@ def ParseLines(lines, message, allow_unknown_extension=False): message: A protocol buffer message to merge into. allow_unknown_extension: if True, skip over missing extensions and keep parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. @@ -324,11 +386,12 @@ def ParseLines(lines, message, allow_unknown_extension=False): Raises: ParseError: On text parsing problems. """ - _ParseOrMerge(lines, message, False, allow_unknown_extension) - return message + parser = _Parser(allow_unknown_extension, allow_field_number) + return parser.ParseLines(lines, message) -def MergeLines(lines, message, allow_unknown_extension=False): +def MergeLines(lines, message, allow_unknown_extension=False, + allow_field_number=False): """Parses an text representation of a protocol message into a message. Args: @@ -336,6 +399,7 @@ def MergeLines(lines, message, allow_unknown_extension=False): message: A protocol buffer message to merge into. allow_unknown_extension: if True, skip over missing extensions and keep parsing + allow_field_number: if True, both field number and field name are allowed. Returns: The same message passed as argument. @@ -343,146 +407,272 @@ def MergeLines(lines, message, allow_unknown_extension=False): Raises: ParseError: On text parsing problems. """ - _ParseOrMerge(lines, message, True, allow_unknown_extension) - return message + parser = _Parser(allow_unknown_extension, allow_field_number) + return parser.MergeLines(lines, message) -def _ParseOrMerge(lines, - message, - allow_multiple_scalars, - allow_unknown_extension=False): - """Converts an text representation of a protocol message into a message. +class _Parser(object): + """Text format parser for protocol message.""" - Args: - lines: Lines of a message's text representation. - message: A protocol buffer message to merge into. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". - allow_unknown_extension: if True, skip over missing extensions and keep - parsing + def __init__(self, allow_unknown_extension=False, allow_field_number=False): + self.allow_unknown_extension = allow_unknown_extension + self.allow_field_number = allow_field_number - Raises: - ParseError: On text parsing problems. - """ - tokenizer = _Tokenizer(lines) - while not tokenizer.AtEnd(): - _MergeField(tokenizer, message, allow_multiple_scalars, - allow_unknown_extension) + def ParseFromString(self, text, message): + """Parses an text representation of a protocol message into a message.""" + if not isinstance(text, str): + text = text.decode('utf-8') + return self.ParseLines(text.split('\n'), message) + def ParseLines(self, lines, message): + """Parses an text representation of a protocol message into a message.""" + self._allow_multiple_scalars = False + self._ParseOrMerge(lines, message) + return message -def _MergeField(tokenizer, - message, - allow_multiple_scalars, - allow_unknown_extension=False): - """Merges a single protocol message field into a message. + def MergeFromString(self, text, message): + """Merges an text representation of a protocol message into a message.""" + return self._MergeLines(text.split('\n'), message) - Args: - tokenizer: A tokenizer to parse the field name and values. - message: A protocol message to record the data. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". - allow_unknown_extension: if True, skip over missing extensions and keep - parsing. + def MergeLines(self, lines, message): + """Merges an text representation of a protocol message into a message.""" + self._allow_multiple_scalars = True + self._ParseOrMerge(lines, message) + return message - Raises: - ParseError: In case of text parsing problems. - """ - message_descriptor = message.DESCRIPTOR - if (hasattr(message_descriptor, 'syntax') and - message_descriptor.syntax == 'proto3'): - # Proto3 doesn't represent presence so we can't test if multiple - # scalars have occurred. We have to allow them. - allow_multiple_scalars = True - if tokenizer.TryConsume('['): - name = [tokenizer.ConsumeIdentifier()] - while tokenizer.TryConsume('.'): - name.append(tokenizer.ConsumeIdentifier()) - name = '.'.join(name) - - if not message_descriptor.is_extendable: - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" does not have extensions.' % - message_descriptor.full_name) - # pylint: disable=protected-access - field = message.Extensions._FindExtensionByName(name) - # pylint: enable=protected-access - if not field: - if allow_unknown_extension: - field = None - else: - raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" not registered.' % name) - elif message_descriptor != field.containing_type: - raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" does not extend message type "%s".' % ( - name, message_descriptor.full_name)) + def _ParseOrMerge(self, lines, message): + """Converts an text representation of a protocol message into a message. - tokenizer.Consume(']') + Args: + lines: Lines of a message's text representation. + message: A protocol buffer message to merge into. - else: - name = tokenizer.ConsumeIdentifier() - field = message_descriptor.fields_by_name.get(name, None) - - # Group names are expected to be capitalized as they appear in the - # .proto file, which actually matches their type names, not their field - # names. - if not field: - field = message_descriptor.fields_by_name.get(name.lower(), None) - if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: - field = None - - if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and - field.message_type.name != name): - field = None - - if not field: - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" has no field named "%s".' % ( - message_descriptor.full_name, name)) - - if field: - if not allow_multiple_scalars and field.containing_oneof: - # Check if there's a different field set in this oneof. - # Note that we ignore the case if the same field was set before, and we - # apply allow_multiple_scalars to non-scalar fields as well. - which_oneof = message.WhichOneof(field.containing_oneof.name) - if which_oneof is not None and which_oneof != field.name: + Raises: + ParseError: On text parsing problems. + """ + tokenizer = _Tokenizer(lines) + while not tokenizer.AtEnd(): + self._MergeField(tokenizer, message) + + def _MergeField(self, tokenizer, message): + """Merges a single protocol message field into a message. + + Args: + tokenizer: A tokenizer to parse the field name and values. + message: A protocol message to record the data. + + Raises: + ParseError: In case of text parsing problems. + """ + message_descriptor = message.DESCRIPTOR + if (hasattr(message_descriptor, 'syntax') and + message_descriptor.syntax == 'proto3'): + # Proto3 doesn't represent presence so we can't test if multiple + # scalars have occurred. We have to allow them. + self._allow_multiple_scalars = True + if tokenizer.TryConsume('['): + name = [tokenizer.ConsumeIdentifier()] + while tokenizer.TryConsume('.'): + name.append(tokenizer.ConsumeIdentifier()) + name = '.'.join(name) + + if not message_descriptor.is_extendable: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" does not have extensions.' % + message_descriptor.full_name) + # pylint: disable=protected-access + field = message.Extensions._FindExtensionByName(name) + # pylint: enable=protected-access + if not field: + if self.allow_unknown_extension: + field = None + else: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" not registered.' % name) + elif message_descriptor != field.containing_type: raise tokenizer.ParseErrorPreviousToken( - 'Field "%s" is specified along with field "%s", another member of ' - 'oneof "%s" for message type "%s".' % ( - field.name, which_oneof, field.containing_oneof.name, - message_descriptor.full_name)) + 'Extension "%s" does not extend message type "%s".' % ( + name, message_descriptor.full_name)) + + tokenizer.Consume(']') - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - tokenizer.TryConsume(':') - merger = _MergeMessageField else: - tokenizer.Consume(':') - merger = _MergeScalarField - - if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED - and tokenizer.TryConsume('[')): - # Short repeated format, e.g. "foo: [1, 2, 3]" - while True: - merger(tokenizer, message, field, allow_multiple_scalars, - allow_unknown_extension) - if tokenizer.TryConsume(']'): break - tokenizer.Consume(',') + name = tokenizer.ConsumeIdentifier() + if self.allow_field_number and name.isdigit(): + number = ParseInteger(name, True, True) + field = message_descriptor.fields_by_number.get(number, None) + if not field and message_descriptor.is_extendable: + field = message.Extensions._FindExtensionByNumber(number) + else: + field = message_descriptor.fields_by_name.get(name, None) + + # Group names are expected to be capitalized as they appear in the + # .proto file, which actually matches their type names, not their field + # names. + if not field: + field = message_descriptor.fields_by_name.get(name.lower(), None) + if field and field.type != descriptor.FieldDescriptor.TYPE_GROUP: + field = None + + if (field and field.type == descriptor.FieldDescriptor.TYPE_GROUP and + field.message_type.name != name): + field = None + + if not field: + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" has no field named "%s".' % ( + message_descriptor.full_name, name)) + + if field: + if not self._allow_multiple_scalars and field.containing_oneof: + # Check if there's a different field set in this oneof. + # Note that we ignore the case if the same field was set before, and we + # apply _allow_multiple_scalars to non-scalar fields as well. + which_oneof = message.WhichOneof(field.containing_oneof.name) + if which_oneof is not None and which_oneof != field.name: + raise tokenizer.ParseErrorPreviousToken( + 'Field "%s" is specified along with field "%s", another member ' + 'of oneof "%s" for message type "%s".' % ( + field.name, which_oneof, field.containing_oneof.name, + message_descriptor.full_name)) + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + tokenizer.TryConsume(':') + merger = self._MergeMessageField + else: + tokenizer.Consume(':') + merger = self._MergeScalarField + + if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED + and tokenizer.TryConsume('[')): + # Short repeated format, e.g. "foo: [1, 2, 3]" + while True: + merger(tokenizer, message, field) + if tokenizer.TryConsume(']'): break + tokenizer.Consume(',') + + else: + merger(tokenizer, message, field) + + else: # Proto field is unknown. + assert self.allow_unknown_extension + _SkipFieldContents(tokenizer) + + # For historical reasons, fields may optionally be separated by commas or + # semicolons. + if not tokenizer.TryConsume(','): + tokenizer.TryConsume(';') + + def _MergeMessageField(self, tokenizer, message, field): + """Merges a single scalar field into a message. + + Args: + tokenizer: A tokenizer to parse the field value. + message: The message of which field is a member. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of text parsing problems. + """ + is_map_entry = _IsMapEntry(field) + if tokenizer.TryConsume('<'): + end_token = '>' else: - merger(tokenizer, message, field, allow_multiple_scalars, - allow_unknown_extension) + tokenizer.Consume('{') + end_token = '}' + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + sub_message = message.Extensions[field].add() + elif is_map_entry: + # pylint: disable=protected-access + sub_message = field.message_type._concrete_class() + else: + sub_message = getattr(message, field.name).add() + else: + if field.is_extension: + sub_message = message.Extensions[field] + else: + sub_message = getattr(message, field.name) + sub_message.SetInParent() + + while not tokenizer.TryConsume(end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,)) + self._MergeField(tokenizer, sub_message) + + if is_map_entry: + value_cpptype = field.message_type.fields_by_name['value'].cpp_type + if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + value = getattr(message, field.name)[sub_message.key] + value.MergeFrom(sub_message.value) + else: + getattr(message, field.name)[sub_message.key] = sub_message.value - else: # Proto field is unknown. - assert allow_unknown_extension - _SkipFieldContents(tokenizer) + def _MergeScalarField(self, tokenizer, message, field): + """Merges a single scalar field into a message. - # For historical reasons, fields may optionally be separated by commas or - # semicolons. - if not tokenizer.TryConsume(','): - tokenizer.TryConsume(';') + Args: + tokenizer: A tokenizer to parse the field value. + message: A protocol message to record the data. + field: The descriptor of the field to be merged. + + Raises: + ParseError: In case of text parsing problems. + RuntimeError: On runtime errors. + """ + _ = self.allow_unknown_extension + value = None + + if field.type in (descriptor.FieldDescriptor.TYPE_INT32, + descriptor.FieldDescriptor.TYPE_SINT32, + descriptor.FieldDescriptor.TYPE_SFIXED32): + value = tokenizer.ConsumeInt32() + elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, + descriptor.FieldDescriptor.TYPE_SINT64, + descriptor.FieldDescriptor.TYPE_SFIXED64): + value = tokenizer.ConsumeInt64() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, + descriptor.FieldDescriptor.TYPE_FIXED32): + value = tokenizer.ConsumeUint32() + elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, + descriptor.FieldDescriptor.TYPE_FIXED64): + value = tokenizer.ConsumeUint64() + elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, + descriptor.FieldDescriptor.TYPE_DOUBLE): + value = tokenizer.ConsumeFloat() + elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: + value = tokenizer.ConsumeBool() + elif field.type == descriptor.FieldDescriptor.TYPE_STRING: + value = tokenizer.ConsumeString() + elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: + value = tokenizer.ConsumeByteString() + elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: + value = tokenizer.ConsumeEnum(field) + else: + raise RuntimeError('Unknown field type %d' % field.type) + + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if field.is_extension: + message.Extensions[field].append(value) + else: + getattr(message, field.name).append(value) + else: + if field.is_extension: + if not self._allow_multiple_scalars and message.HasExtension(field): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" extensions.' % + (message.DESCRIPTOR.full_name, field.full_name)) + else: + message.Extensions[field] = value + else: + if not self._allow_multiple_scalars and message.HasField(field.name): + raise tokenizer.ParseErrorPreviousToken( + 'Message type "%s" should not have multiple "%s" fields.' % + (message.DESCRIPTOR.full_name, field.name)) + else: + setattr(message, field.name, value) def _SkipFieldContents(tokenizer): @@ -555,10 +745,10 @@ def _SkipFieldValue(tokenizer): Raises: ParseError: In case an invalid field value is found. """ - # String tokens can come in multiple adjacent string literals. + # String/bytes tokens can come in multiple adjacent string literals. # If we can consume one, consume as many as we can. - if tokenizer.TryConsumeString(): - while tokenizer.TryConsumeString(): + if tokenizer.TryConsumeByteString(): + while tokenizer.TryConsumeByteString(): pass return @@ -569,132 +759,6 @@ def _SkipFieldValue(tokenizer): raise ParseError('Invalid field value: ' + tokenizer.token) -def _MergeMessageField(tokenizer, message, field, allow_multiple_scalars, - allow_unknown_extension): - """Merges a single scalar field into a message. - - Args: - tokenizer: A tokenizer to parse the field value. - message: The message of which field is a member. - field: The descriptor of the field to be merged. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". - allow_unknown_extension: if True, skip over missing extensions and keep - parsing. - - Raises: - ParseError: In case of text parsing problems. - """ - is_map_entry = _IsMapEntry(field) - - if tokenizer.TryConsume('<'): - end_token = '>' - else: - tokenizer.Consume('{') - end_token = '}' - - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - if field.is_extension: - sub_message = message.Extensions[field].add() - elif is_map_entry: - # pylint: disable=protected-access - sub_message = field.message_type._concrete_class() - else: - sub_message = getattr(message, field.name).add() - else: - if field.is_extension: - sub_message = message.Extensions[field] - else: - sub_message = getattr(message, field.name) - sub_message.SetInParent() - - while not tokenizer.TryConsume(end_token): - if tokenizer.AtEnd(): - raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token,)) - _MergeField(tokenizer, sub_message, allow_multiple_scalars, - allow_unknown_extension) - - if is_map_entry: - value_cpptype = field.message_type.fields_by_name['value'].cpp_type - if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - value = getattr(message, field.name)[sub_message.key] - value.MergeFrom(sub_message.value) - else: - getattr(message, field.name)[sub_message.key] = sub_message.value - - -def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars, - allow_unknown_extension): - """Merges a single scalar field into a message. - - Args: - tokenizer: A tokenizer to parse the field value. - message: A protocol message to record the data. - field: The descriptor of the field to be merged. - allow_multiple_scalars: Determines if repeated values for a non-repeated - field are permitted, e.g., the string "foo: 1 foo: 2" for a - required/optional field named "foo". - allow_unknown_extension: Unused, just here for consistency with - _MergeMessageField. - - Raises: - ParseError: In case of text parsing problems. - RuntimeError: On runtime errors. - """ - _ = allow_unknown_extension - value = None - - if field.type in (descriptor.FieldDescriptor.TYPE_INT32, - descriptor.FieldDescriptor.TYPE_SINT32, - descriptor.FieldDescriptor.TYPE_SFIXED32): - value = tokenizer.ConsumeInt32() - elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, - descriptor.FieldDescriptor.TYPE_SINT64, - descriptor.FieldDescriptor.TYPE_SFIXED64): - value = tokenizer.ConsumeInt64() - elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, - descriptor.FieldDescriptor.TYPE_FIXED32): - value = tokenizer.ConsumeUint32() - elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, - descriptor.FieldDescriptor.TYPE_FIXED64): - value = tokenizer.ConsumeUint64() - elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, - descriptor.FieldDescriptor.TYPE_DOUBLE): - value = tokenizer.ConsumeFloat() - elif field.type == descriptor.FieldDescriptor.TYPE_BOOL: - value = tokenizer.ConsumeBool() - elif field.type == descriptor.FieldDescriptor.TYPE_STRING: - value = tokenizer.ConsumeString() - elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: - value = tokenizer.ConsumeByteString() - elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: - value = tokenizer.ConsumeEnum(field) - else: - raise RuntimeError('Unknown field type %d' % field.type) - - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - if field.is_extension: - message.Extensions[field].append(value) - else: - getattr(message, field.name).append(value) - else: - if field.is_extension: - if not allow_multiple_scalars and message.HasExtension(field): - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" should not have multiple "%s" extensions.' % - (message.DESCRIPTOR.full_name, field.full_name)) - else: - message.Extensions[field] = value - else: - if not allow_multiple_scalars and message.HasField(field.name): - raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" should not have multiple "%s" fields.' % - (message.DESCRIPTOR.full_name, field.name)) - else: - setattr(message, field.name, value) - - class _Tokenizer(object): """Protocol buffer text representation tokenizer. @@ -925,9 +989,9 @@ class _Tokenizer(object): self.NextToken() return result - def TryConsumeString(self): + def TryConsumeByteString(self): try: - self.ConsumeString() + self.ConsumeByteString() return True except ParseError: return False -- cgit v1.2.3