diff options
author | Feng Xiao <xfxyjwf@gmail.com> | 2015-12-11 17:09:20 -0800 |
---|---|---|
committer | Feng Xiao <xfxyjwf@gmail.com> | 2015-12-11 17:10:28 -0800 |
commit | e841bac4fcf47f809e089a70d5f84ac37b3883df (patch) | |
tree | d25dc5fc814db182c04c5f276ff1a609c5965a5a /python | |
parent | 99a6a95c751a28a3cc33dd2384959179f83f682c (diff) | |
download | protobuf-e841bac4fcf47f809e089a70d5f84ac37b3883df.tar.gz protobuf-e841bac4fcf47f809e089a70d5f84ac37b3883df.tar.bz2 protobuf-e841bac4fcf47f809e089a70d5f84ac37b3883df.zip |
Down-integrate from internal code base.
Diffstat (limited to 'python')
35 files changed, 3564 insertions, 1807 deletions
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index 2bf36532..5f613c88 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -786,25 +786,33 @@ class FileDescriptor(DescriptorBase): message_types_by_name: Dict of message names of their descriptors. enum_types_by_name: Dict of enum names and their descriptors. extensions_by_name: Dict of extension names and their descriptors. + pool: the DescriptorPool this descriptor belongs to. When not passed to the + constructor, the global default pool is used. """ if _USE_C_DESCRIPTORS: _C_DESCRIPTOR_CLASS = _message.FileDescriptor def __new__(cls, name, package, options=None, serialized_pb=None, - dependencies=None, syntax=None): + dependencies=None, syntax=None, pool=None): # FileDescriptor() is called from various places, not only from generated # files, to register dynamic proto files and messages. if serialized_pb: + # TODO(amauryfa): use the pool passed as argument. This will work only + # for C++-implemented DescriptorPools. return _message.default_pool.AddSerializedFile(serialized_pb) else: return super(FileDescriptor, cls).__new__(cls) def __init__(self, name, package, options=None, serialized_pb=None, - dependencies=None, syntax=None): + dependencies=None, syntax=None, pool=None): """Constructor.""" super(FileDescriptor, self).__init__(options, 'FileOptions') + if pool is None: + from google.protobuf import descriptor_pool + pool = descriptor_pool.Default() + self.pool = pool self.message_types_by_name = {} self.name = name self.package = package diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py index b10021e9..1333f996 100644 --- a/python/google/protobuf/descriptor_database.py +++ b/python/google/protobuf/descriptor_database.py @@ -65,6 +65,7 @@ class DescriptorDatabase(object): raise DescriptorDatabaseConflictingDefinitionError( '%s already added, but with different descriptor.' % proto_name) + # Add the top-level Message, Enum and Extension descriptors to the index. package = file_desc_proto.package for message in file_desc_proto.message_type: self._file_desc_protos_by_symbol.update( @@ -72,6 +73,9 @@ class DescriptorDatabase(object): for enum in file_desc_proto.enum_type: self._file_desc_protos_by_symbol[ '.'.join((package, enum.name))] = file_desc_proto + for extension in file_desc_proto.extension: + self._file_desc_protos_by_symbol[ + '.'.join((package, extension.name))] = file_desc_proto def FindFileByName(self, name): """Finds the file descriptor proto by file name. diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 6a1b4b5e..3e80795c 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -83,6 +83,12 @@ def _NormalizeFullyQualifiedName(name): class DescriptorPool(object): """A collection of protobufs dynamically constructed by descriptor protos.""" + if _USE_C_DESCRIPTORS: + + def __new__(cls, descriptor_db=None): + # pylint: disable=protected-access + return descriptor._message.DescriptorPool(descriptor_db) + def __init__(self, descriptor_db=None): """Initializes a Pool of proto buffs. @@ -264,6 +270,39 @@ class DescriptorPool(object): self.FindFileContainingSymbol(full_name) return self._enum_descriptors[full_name] + def FindFieldByName(self, full_name): + """Loads the named field descriptor from the pool. + + Args: + full_name: The full name of the field descriptor to load. + + Returns: + The field descriptor for the named field. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + message_name, _, field_name = full_name.rpartition('.') + message_descriptor = self.FindMessageTypeByName(message_name) + return message_descriptor.fields_by_name[field_name] + + def FindExtensionByName(self, full_name): + """Loads the named extension descriptor from the pool. + + Args: + full_name: The full name of the extension descriptor to load. + + Returns: + A FieldDescriptor, describing the named extension. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + message_name, _, extension_name = full_name.rpartition('.') + try: + # Most extensions are nested inside a message. + scope = self.FindMessageTypeByName(message_name) + except KeyError: + # Some extensions are defined at file scope. + scope = self.FindFileContainingSymbol(full_name) + return scope.extensions_by_name[extension_name] + def _ConvertFileProtoToFileDescriptor(self, file_proto): """Creates a FileDescriptor from a proto or returns a cached copy. @@ -282,6 +321,7 @@ class DescriptorPool(object): direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] file_descriptor = descriptor.FileDescriptor( + pool=self, name=file_proto.name, package=file_proto.package, syntax=file_proto.syntax, @@ -598,10 +638,24 @@ class DescriptorPool(object): field_desc.default_value = text_encoding.CUnescape( field_proto.default_value) else: + # All other types are of the "int" type. field_desc.default_value = int(field_proto.default_value) else: field_desc.has_default_value = False - field_desc.default_value = None + if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or + field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): + field_desc.default_value = 0.0 + elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: + field_desc.default_value = u'' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: + field_desc.default_value = False + elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.default_value = field_desc.enum_type.values[0].number + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: + field_desc.default_value = b'' + else: + # All other types are of the "int" type. + field_desc.default_value = 0 field_desc.type = field_proto.type @@ -680,3 +734,16 @@ class DescriptorPool(object): def _PrefixWithDot(name): return name if name.startswith('.') else '.%s' % name + + +if _USE_C_DESCRIPTORS: + # TODO(amauryfa): This pool could be constructed from Python code, when we + # support a flag like 'use_cpp_generated_pool=True'. + # pylint: disable=protected-access + _DEFAULT = descriptor._message.default_pool +else: + _DEFAULT = DescriptorPool() + + +def Default(): + return _DEFAULT diff --git a/python/google/protobuf/internal/any_test.proto b/python/google/protobuf/internal/any_test.proto new file mode 100644 index 00000000..cd641ca0 --- /dev/null +++ b/python/google/protobuf/internal/any_test.proto @@ -0,0 +1,42 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: jieluo@google.com (Jie Luo) + +syntax = "proto3"; + +package google.protobuf.internal; + +import "google/protobuf/any.proto"; + +message TestAny { + google.protobuf.Any value = 1; + int32 int_value = 2; +} diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 9c8275eb..97cdd848 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -464,6 +464,9 @@ class ScalarMap(MutableMapping): return val def __contains__(self, item): + # We check the key's type to match the strong-typing flavor of the API. + # Also this makes it easier to match the behavior of the C++ implementation. + self._key_checker.CheckValue(item) return item in self._values # We need to override this explicitly, because our defaultdict-like behavior @@ -491,10 +494,20 @@ class ScalarMap(MutableMapping): def __iter__(self): return iter(self._values) + def __repr__(self): + return repr(self._values) + def MergeFrom(self, other): self._values.update(other._values) self._message_listener.Modified() + def InvalidateIterators(self): + # It appears that the only way to reliably invalidate iterators to + # self._values is to ensure that its size changes. + original = self._values + self._values = original.copy() + original[None] = None + # This is defined in the abstract base, but we can do it much more cheaply. def clear(self): self._values.clear() @@ -576,12 +589,22 @@ class MessageMap(MutableMapping): def __iter__(self): return iter(self._values) + def __repr__(self): + return repr(self._values) + def MergeFrom(self, other): for key in other: self[key].MergeFrom(other[key]) # self._message_listener.Modified() not required here, because # mutations to submessages already propagate. + def InvalidateIterators(self): + # It appears that the only way to reliably invalidate iterators to + # self._values is to ensure that its size changes. + original = self._values + self._values = original.copy() + original[None] = None + # This is defined in the abstract base, but we can do it much more cheaply. def clear(self): self._values.clear() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index da9a78db..f1d6bf99 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -40,6 +40,8 @@ try: import unittest2 as unittest except ImportError: import unittest +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 from google.protobuf.internal import api_implementation @@ -51,13 +53,17 @@ from google.protobuf.internal import test_util from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import descriptor_pool +from google.protobuf import message_factory from google.protobuf import symbol_database class DescriptorPoolTest(unittest.TestCase): + def CreatePool(self): + return descriptor_pool.DescriptorPool() + def setUp(self): - self.pool = descriptor_pool.DescriptorPool() + self.pool = self.CreatePool() self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( factory_test1_pb2.DESCRIPTOR.serialized_pb) self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( @@ -89,7 +95,7 @@ class DescriptorPoolTest(unittest.TestCase): 'google.protobuf.python.internal.Factory1Message') self.assertIsInstance(file_desc1, descriptor.FileDescriptor) self.assertEqual('google/protobuf/internal/factory_test1.proto', - file_desc1.name) + file_desc1.name) self.assertEqual('google.protobuf.python.internal', file_desc1.package) self.assertIn('Factory1Message', file_desc1.message_types_by_name) @@ -97,7 +103,7 @@ class DescriptorPoolTest(unittest.TestCase): 'google.protobuf.python.internal.Factory2Message') self.assertIsInstance(file_desc2, descriptor.FileDescriptor) self.assertEqual('google/protobuf/internal/factory_test2.proto', - file_desc2.name) + file_desc2.name) self.assertEqual('google.protobuf.python.internal', file_desc2.package) self.assertIn('Factory2Message', file_desc2.message_types_by_name) @@ -111,7 +117,7 @@ class DescriptorPoolTest(unittest.TestCase): self.assertIsInstance(msg1, descriptor.Descriptor) self.assertEqual('Factory1Message', msg1.name) self.assertEqual('google.protobuf.python.internal.Factory1Message', - msg1.full_name) + msg1.full_name) self.assertEqual(None, msg1.containing_type) nested_msg1 = msg1.nested_types[0] @@ -132,7 +138,7 @@ class DescriptorPoolTest(unittest.TestCase): self.assertIsInstance(msg2, descriptor.Descriptor) self.assertEqual('Factory2Message', msg2.name) self.assertEqual('google.protobuf.python.internal.Factory2Message', - msg2.full_name) + msg2.full_name) self.assertIsNone(msg2.containing_type) nested_msg2 = msg2.nested_types[0] @@ -223,6 +229,37 @@ class DescriptorPoolTest(unittest.TestCase): with self.assertRaises(KeyError): self.pool.FindEnumTypeByName('Does not exist') + def testFindFieldByName(self): + field = self.pool.FindFieldByName( + 'google.protobuf.python.internal.Factory1Message.list_value') + self.assertEqual(field.name, 'list_value') + self.assertEqual(field.label, field.LABEL_REPEATED) + with self.assertRaises(KeyError): + self.pool.FindFieldByName('Does not exist') + + def testFindExtensionByName(self): + # An extension defined in a message. + extension = self.pool.FindExtensionByName( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + self.assertEqual(extension.name, 'one_more_field') + # An extension defined at file scope. + extension = self.pool.FindExtensionByName( + 'google.protobuf.python.internal.another_field') + self.assertEqual(extension.name, 'another_field') + self.assertEqual(extension.number, 1002) + with self.assertRaises(KeyError): + self.pool.FindFieldByName('Does not exist') + + def testExtensionsAreNotFields(self): + with self.assertRaises(KeyError): + self.pool.FindFieldByName('google.protobuf.python.internal.another_field') + with self.assertRaises(KeyError): + self.pool.FindFieldByName( + 'google.protobuf.python.internal.Factory2Message.one_more_field') + with self.assertRaises(KeyError): + self.pool.FindExtensionByName( + 'google.protobuf.python.internal.Factory1Message.list_value') + def testUserDefinedDB(self): db = descriptor_database.DescriptorDatabase() self.pool = descriptor_pool.DescriptorPool(db) @@ -231,8 +268,7 @@ class DescriptorPoolTest(unittest.TestCase): self.testFindMessageTypeByName() def testAddSerializedFile(self): - db = descriptor_database.DescriptorDatabase() - self.pool = descriptor_pool.DescriptorPool(db) + self.pool = descriptor_pool.DescriptorPool() self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString()) self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString()) self.testFindMessageTypeByName() @@ -274,6 +310,56 @@ class DescriptorPoolTest(unittest.TestCase): 'google/protobuf/internal/descriptor_pool_test1.proto') _CheckDefaultValue(file_descriptor) + def testDefaultValueForCustomMessages(self): + """Check the value returned by non-existent fields.""" + def _CheckValueAndType(value, expected_value, expected_type): + self.assertEqual(value, expected_value) + self.assertIsInstance(value, expected_type) + + def _CheckDefaultValues(msg): + try: + int64 = long + except NameError: # Python3 + int64 = int + try: + unicode_type = unicode + except NameError: # Python3 + unicode_type = str + _CheckValueAndType(msg.optional_int32, 0, int) + _CheckValueAndType(msg.optional_uint64, 0, (int64, int)) + _CheckValueAndType(msg.optional_float, 0, (float, int)) + _CheckValueAndType(msg.optional_double, 0, (float, int)) + _CheckValueAndType(msg.optional_bool, False, bool) + _CheckValueAndType(msg.optional_string, u'', unicode_type) + _CheckValueAndType(msg.optional_bytes, b'', bytes) + _CheckValueAndType(msg.optional_nested_enum, msg.FOO, int) + # First for the generated message + _CheckDefaultValues(unittest_pb2.TestAllTypes()) + # Then for a message built with from the DescriptorPool. + pool = descriptor_pool.DescriptorPool() + pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_public_pb2.DESCRIPTOR.serialized_pb)) + pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_import_pb2.DESCRIPTOR.serialized_pb)) + pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + unittest_pb2.DESCRIPTOR.serialized_pb)) + message_class = message_factory.MessageFactory(pool).GetPrototype( + pool.FindMessageTypeByName( + unittest_pb2.TestAllTypes.DESCRIPTOR.full_name)) + _CheckDefaultValues(message_class()) + + +@unittest.skipIf(api_implementation.Type() != 'cpp', + 'explicit tests of the C++ implementation') +class CppDescriptorPoolTest(DescriptorPoolTest): + # TODO(amauryfa): remove when descriptor_pool.DescriptorPool() creates true + # C++ descriptor pool object for C++ implementation. + + def CreatePool(self): + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + return _message.DescriptorPool() + class ProtoFile(object): @@ -468,6 +554,8 @@ class AddDescriptorTest(unittest.TestCase): pool.FindFileContainingSymbol( prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name) + @unittest.skipIf(api_implementation.Type() == 'cpp', + 'With the cpp implementation, Add() must be called first') def testMessage(self): self._TestMessage('') self._TestMessage('.') @@ -502,10 +590,14 @@ class AddDescriptorTest(unittest.TestCase): pool.FindFileContainingSymbol( prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + @unittest.skipIf(api_implementation.Type() == 'cpp', + 'With the cpp implementation, Add() must be called first') def testEnum(self): self._TestEnum('') self._TestEnum('.') + @unittest.skipIf(api_implementation.Type() == 'cpp', + 'With the cpp implementation, Add() must be called first') def testFile(self): pool = descriptor_pool.DescriptorPool() pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR) @@ -520,6 +612,76 @@ class AddDescriptorTest(unittest.TestCase): pool.FindFileContainingSymbol( 'protobuf_unittest.TestAllTypes') + def _GetDescriptorPoolClass(self): + # Test with both implementations of descriptor pools. + if api_implementation.Type() == 'cpp': + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + return _message.DescriptorPool + else: + return descriptor_pool.DescriptorPool + + def testEmptyDescriptorPool(self): + # Check that an empty DescriptorPool() contains no message. + pool = self._GetDescriptorPoolClass()() + proto_file_name = descriptor_pb2.DESCRIPTOR.name + self.assertRaises(KeyError, pool.FindFileByName, proto_file_name) + # Add the above file to the pool + file_descriptor = descriptor_pb2.FileDescriptorProto() + descriptor_pb2.DESCRIPTOR.CopyToProto(file_descriptor) + pool.Add(file_descriptor) + # Now it exists. + self.assertTrue(pool.FindFileByName(proto_file_name)) + + def testCustomDescriptorPool(self): + # Create a new pool, and add a file descriptor. + pool = self._GetDescriptorPoolClass()() + file_desc = descriptor_pb2.FileDescriptorProto( + name='some/file.proto', package='package') + file_desc.message_type.add(name='Message') + pool.Add(file_desc) + self.assertEqual(pool.FindFileByName('some/file.proto').name, + 'some/file.proto') + self.assertEqual(pool.FindMessageTypeByName('package.Message').name, + 'Message') + + +@unittest.skipIf( + api_implementation.Type() != 'cpp', + 'default_pool is only supported by the C++ implementation') +class DefaultPoolTest(unittest.TestCase): + + def testFindMethods(self): + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + pool = _message.default_pool + self.assertIs( + pool.FindFileByName('google/protobuf/unittest.proto'), + unittest_pb2.DESCRIPTOR) + self.assertIs( + pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'), + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertIs( + pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), + unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) + self.assertIs( + pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) + self.assertIs( + pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), + unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertIs( + pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), + unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + + def testAddFileDescriptor(self): + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + pool = _message.default_pool + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + pool.Add(file_desc) + pool.AddSerializedFile(file_desc.SerializeToString()) + TEST1_FILE = ProtoFile( 'google/protobuf/internal/descriptor_pool_test1.proto', diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 99afee63..fee09a56 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -47,6 +47,7 @@ from google.protobuf import descriptor_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util from google.protobuf import descriptor +from google.protobuf import descriptor_pool from google.protobuf import symbol_database from google.protobuf import text_format @@ -75,9 +76,9 @@ class DescriptorTest(unittest.TestCase): enum_proto.value.add(name='FOREIGN_BAR', number=5) enum_proto.value.add(name='FOREIGN_BAZ', number=6) - descriptor_pool = symbol_database.Default().pool - descriptor_pool.Add(file_proto) - self.my_file = descriptor_pool.FindFileByName(file_proto.name) + self.pool = self.GetDescriptorPool() + self.pool.Add(file_proto) + self.my_file = self.pool.FindFileByName(file_proto.name) self.my_message = self.my_file.message_types_by_name[message_proto.name] self.my_enum = self.my_message.enum_types_by_name[enum_proto.name] @@ -97,6 +98,9 @@ class DescriptorTest(unittest.TestCase): self.my_method ]) + def GetDescriptorPool(self): + return symbol_database.Default().pool + def testEnumValueName(self): self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4), 'FOREIGN_FOO') @@ -393,6 +397,9 @@ class DescriptorTest(unittest.TestCase): def testFileDescriptor(self): self.assertEqual(self.my_file.name, 'some/filename/some.proto') self.assertEqual(self.my_file.package, 'protobuf_unittest') + self.assertEqual(self.my_file.pool, self.pool) + # Generated modules also belong to the default pool. + self.assertEqual(unittest_pb2.DESCRIPTOR.pool, descriptor_pool.Default()) @unittest.skipIf( api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, @@ -407,6 +414,13 @@ class DescriptorTest(unittest.TestCase): message_descriptor.fields.append(None) +class NewDescriptorTest(DescriptorTest): + """Redo the same tests as above, but with a separate DescriptorPool.""" + + def GetDescriptorPool(self): + return descriptor_pool.DescriptorPool() + + class GeneratedDescriptorTest(unittest.TestCase): """Tests for the properties of descriptors in generated code.""" diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py index 69197865..be3ad11a 100644 --- a/python/google/protobuf/internal/json_format_test.py +++ b/python/google/protobuf/internal/json_format_test.py @@ -42,6 +42,7 @@ try: import unittest2 as unittest except ImportError: import unittest +from google.protobuf.internal import well_known_types from google.protobuf import json_format from google.protobuf.util import json_format_proto3_pb2 @@ -269,15 +270,15 @@ class JsonFormatTest(JsonFormatBase): '}')) parsed_message = json_format_proto3_pb2.TestTimestamp() self.CheckParseBack(message, parsed_message) - text = (r'{"value": "1972-01-01T01:00:00.01+08:00",' + text = (r'{"value": "1970-01-01T00:00:00.01+08:00",' r'"repeatedValue":[' - r' "1972-01-01T01:00:00.01+08:30",' - r' "1972-01-01T01:00:00.01-01:23"]}') + r' "1970-01-01T00:00:00.01+08:30",' + r' "1970-01-01T00:00:00.01-01:23"]}') json_format.Parse(text, parsed_message) - self.assertEqual(parsed_message.value.seconds, 63104400) + self.assertEqual(parsed_message.value.seconds, -8 * 3600) self.assertEqual(parsed_message.value.nanos, 10000000) - self.assertEqual(parsed_message.repeated_value[0].seconds, 63106200) - self.assertEqual(parsed_message.repeated_value[1].seconds, 63070620) + self.assertEqual(parsed_message.repeated_value[0].seconds, -8.5 * 3600) + self.assertEqual(parsed_message.repeated_value[1].seconds, 3600 + 23 * 60) def testDurationMessage(self): message = json_format_proto3_pb2.TestDuration() @@ -389,7 +390,7 @@ class JsonFormatTest(JsonFormatBase): def testParseEmptyText(self): self.CheckError('', - r'Failed to load JSON: (Expecting value)|(No JSON)') + r'Failed to load JSON: (Expecting value)|(No JSON).') def testParseBadEnumValue(self): self.CheckError( @@ -414,7 +415,7 @@ class JsonFormatTest(JsonFormatBase): if sys.version_info < (2, 7): return self.CheckError('{"int32Value": 1,\n"int32Value":2}', - 'Failed to load JSON: duplicate key int32Value') + 'Failed to load JSON: duplicate key int32Value.') def testInvalidBoolValue(self): self.CheckError('{"boolValue": 1}', @@ -431,39 +432,43 @@ class JsonFormatTest(JsonFormatBase): json_format.Parse, text, message) self.CheckError('{"int32Value": 012345}', (r'Failed to load JSON: Expecting \'?,\'? delimiter: ' - r'line 1')) + r'line 1.')) self.CheckError('{"int32Value": 1.0}', 'Failed to parse int32Value field: ' - 'Couldn\'t parse integer: 1.0') + 'Couldn\'t parse integer: 1.0.') self.CheckError('{"int32Value": " 1 "}', 'Failed to parse int32Value field: ' - 'Couldn\'t parse integer: " 1 "') + 'Couldn\'t parse integer: " 1 ".') + self.CheckError('{"int32Value": "1 "}', + 'Failed to parse int32Value field: ' + 'Couldn\'t parse integer: "1 ".') self.CheckError('{"int32Value": 12345678901234567890}', 'Failed to parse int32Value field: Value out of range: ' - '12345678901234567890') + '12345678901234567890.') self.CheckError('{"int32Value": 1e5}', 'Failed to parse int32Value field: ' - 'Couldn\'t parse integer: 100000.0') + 'Couldn\'t parse integer: 100000.0.') self.CheckError('{"uint32Value": -1}', - 'Failed to parse uint32Value field: Value out of range: -1') + 'Failed to parse uint32Value field: ' + 'Value out of range: -1.') def testInvalidFloatValue(self): self.CheckError('{"floatValue": "nan"}', 'Failed to parse floatValue field: Couldn\'t ' - 'parse float "nan", use "NaN" instead') + 'parse float "nan", use "NaN" instead.') def testInvalidBytesValue(self): self.CheckError('{"bytesValue": "AQI"}', - 'Failed to parse bytesValue field: Incorrect padding') + 'Failed to parse bytesValue field: Incorrect padding.') self.CheckError('{"bytesValue": "AQI*"}', - 'Failed to parse bytesValue field: Incorrect padding') + 'Failed to parse bytesValue field: Incorrect padding.') def testInvalidMap(self): message = json_format_proto3_pb2.TestMap() text = '{"int32Map": {"null": 2, "2": 3}}' self.assertRaisesRegexp( json_format.ParseError, - 'Failed to parse int32Map field: Couldn\'t parse integer: "null"', + 'Failed to parse int32Map field: invalid literal', json_format.Parse, text, message) text = '{"int32Map": {1: 2, "2": 3}}' self.assertRaisesRegexp( @@ -474,7 +479,7 @@ class JsonFormatTest(JsonFormatBase): text = '{"boolMap": {"null": 1}}' self.assertRaisesRegexp( json_format.ParseError, - 'Failed to parse boolMap field: Expect "true" or "false", not null.', + 'Failed to parse boolMap field: Expected "true" or "false", not null.', json_format.Parse, text, message) if sys.version_info < (2, 7): return @@ -490,30 +495,29 @@ class JsonFormatTest(JsonFormatBase): self.assertRaisesRegexp( json_format.ParseError, 'time data \'10000-01-01T00:00:00\' does not match' - ' format \'%Y-%m-%dT%H:%M:%S\'', + ' format \'%Y-%m-%dT%H:%M:%S\'.', json_format.Parse, text, message) text = '{"value": "1970-01-01T00:00:00.0123456789012Z"}' self.assertRaisesRegexp( - json_format.ParseError, - 'Failed to parse value field: Failed to parse Timestamp: ' + well_known_types.ParseError, 'nanos 0123456789012 more than 9 fractional digits.', json_format.Parse, text, message) text = '{"value": "1972-01-01T01:00:00.01+08"}' self.assertRaisesRegexp( - json_format.ParseError, - (r'Failed to parse value field: Invalid timezone offset value: \+08'), + well_known_types.ParseError, + (r'Invalid timezone offset value: \+08.'), json_format.Parse, text, message) # Time smaller than minimum time. text = '{"value": "0000-01-01T00:00:00Z"}' self.assertRaisesRegexp( json_format.ParseError, - 'Failed to parse value field: year is out of range', + 'Failed to parse value field: year is out of range.', json_format.Parse, text, message) # Time bigger than maxinum time. message.value.seconds = 253402300800 self.assertRaisesRegexp( - json_format.SerializeToJsonError, - 'Failed to serialize value field: year is out of range', + OverflowError, + 'date value out of range', json_format.MessageToJson, message) def testInvalidOneof(self): diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index d760b898..2fbe5ea7 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -45,6 +45,7 @@ from google.protobuf import descriptor_database from google.protobuf import descriptor_pool from google.protobuf import message_factory + class MessageFactoryTest(unittest.TestCase): def setUp(self): @@ -104,8 +105,8 @@ class MessageFactoryTest(unittest.TestCase): def testGetMessages(self): # performed twice because multiple calls with the same input must be allowed for _ in range(2): - messages = message_factory.GetMessages([self.factory_test2_fd, - self.factory_test1_fd]) + messages = message_factory.GetMessages([self.factory_test1_fd, + self.factory_test2_fd]) self.assertTrue( set(['google.protobuf.python.internal.Factory2Message', 'google.protobuf.python.internal.Factory1Message'], @@ -116,7 +117,7 @@ class MessageFactoryTest(unittest.TestCase): set(['google.protobuf.python.internal.Factory2Message.one_more_field', 'google.protobuf.python.internal.another_field'], ).issubset( - set(messages['google.protobuf.python.internal.Factory1Message'] + set(messages['google.protobuf.python.internal.Factory1Message'] ._extensions_by_name.keys()))) factory_msg1 = messages['google.protobuf.python.internal.Factory1Message'] msg1 = messages['google.protobuf.python.internal.Factory1Message']() diff --git a/python/google/protobuf/internal/message_set_extensions.proto b/python/google/protobuf/internal/message_set_extensions.proto index 702c8d07..14e5f193 100644 --- a/python/google/protobuf/internal/message_set_extensions.proto +++ b/python/google/protobuf/internal/message_set_extensions.proto @@ -54,6 +54,14 @@ message TestMessageSetExtension2 { optional string str = 25; } +message TestMessageSetExtension3 { + optional string text = 35; +} + +extend TestMessageSet { + optional TestMessageSetExtension3 message_set_extension3 = 98418655; +} + // This message was used to generate // //net/proto2/python/internal/testdata/message_set_message, but is commented // out since it must not actually exist in code, to simulate an "unknown" diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 13c3caa6..d03f2d25 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -60,6 +60,7 @@ from google.protobuf.internal import _parameterized from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 +from google.protobuf.internal import any_test_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util @@ -1279,12 +1280,13 @@ class Proto3Test(unittest.TestCase): self.assertIsInstance(msg.map_string_string['abc'], six.text_type) - # Accessing an unset key still throws TypeError of the type of the key + # Accessing an unset key still throws TypeError if the type of the key # is incorrect. with self.assertRaises(TypeError): msg.map_string_string[123] - self.assertFalse(123 in msg.map_string_string) + with self.assertRaises(TypeError): + 123 in msg.map_string_string def testMapGet(self): # Need to test that get() properly returns the default, even though the dict @@ -1591,31 +1593,49 @@ class Proto3Test(unittest.TestCase): # For the C++ implementation this tests the correctness of # ScalarMapContainer::Release() msg = map_unittest_pb2.TestMap() - map = msg.map_int32_int32 + int32_map = msg.map_int32_int32 - map[2] = 4 - map[3] = 6 - map[4] = 8 + int32_map[2] = 4 + int32_map[3] = 6 + int32_map[4] = 8 msg.ClearField('map_int32_int32') + self.assertEqual(b'', msg.SerializeToString()) matching_dict = {2: 4, 3: 6, 4: 8} - self.assertMapIterEquals(map.items(), matching_dict) + self.assertMapIterEquals(int32_map.items(), matching_dict) - def testMapIterValidAfterFieldCleared(self): - # Map iterator needs to work even if field is cleared. + def testMessageMapValidAfterFieldCleared(self): + # Map needs to work even if field is cleared. # For the C++ implementation this tests the correctness of # ScalarMapContainer::Release() msg = map_unittest_pb2.TestMap() + int32_foreign_message = msg.map_int32_foreign_message - msg.map_int32_int32[2] = 4 - msg.map_int32_int32[3] = 6 - msg.map_int32_int32[4] = 8 + int32_foreign_message[2].c = 5 - it = msg.map_int32_int32.items() + msg.ClearField('map_int32_foreign_message') + self.assertEqual(b'', msg.SerializeToString()) + self.assertTrue(2 in int32_foreign_message.keys()) + + def testMapIterInvalidatedByClearField(self): + # Map iterator is invalidated when field is cleared. + # But this case does need to not crash the interpreter. + # For the C++ implementation this tests the correctness of + # ScalarMapContainer::Release() + msg = map_unittest_pb2.TestMap() + + it = iter(msg.map_int32_int32) msg.ClearField('map_int32_int32') - matching_dict = {2: 4, 3: 6, 4: 8} - self.assertMapIterEquals(it, matching_dict) + with self.assertRaises(RuntimeError): + for _ in it: + pass + + it = iter(msg.map_int32_foreign_message) + msg.ClearField('map_int32_foreign_message') + with self.assertRaises(RuntimeError): + for _ in it: + pass def testMapDelete(self): msg = map_unittest_pb2.TestMap() @@ -1646,6 +1666,37 @@ class Proto3Test(unittest.TestCase): msg.map_string_foreign_message['foo'].c = 5 self.assertEqual(0, len(msg.FindInitializationErrors())) + def testAnyMessage(self): + # Creates and sets message. + msg = any_test_pb2.TestAny() + msg_descriptor = msg.DESCRIPTOR + all_types = unittest_pb2.TestAllTypes() + all_descriptor = all_types.DESCRIPTOR + all_types.repeated_string.append(u'\u00fc\ua71f') + # Packs to Any. + msg.value.Pack(all_types) + self.assertEqual(msg.value.type_url, + 'type.googleapis.com/%s' % all_descriptor.full_name) + self.assertEqual(msg.value.value, + all_types.SerializeToString()) + # Tests Is() method. + self.assertTrue(msg.value.Is(all_descriptor)) + self.assertFalse(msg.value.Is(msg_descriptor)) + # Unpacks Any. + unpacked_message = unittest_pb2.TestAllTypes() + self.assertTrue(msg.value.Unpack(unpacked_message)) + self.assertEqual(all_types, unpacked_message) + # Unpacks to different type. + self.assertFalse(msg.value.Unpack(msg)) + # Only Any messages have Pack method. + try: + msg.Pack(all_types) + except AttributeError: + pass + else: + raise AttributeError('%s should not have Pack method.' % + msg_descriptor.full_name) + class ValidTypeNamesTest(unittest.TestCase): diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 2b87f704..87f60666 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -65,6 +65,7 @@ from google.protobuf.internal import encoder from google.protobuf.internal import enum_type_wrapper from google.protobuf.internal import message_listener as message_listener_mod from google.protobuf.internal import type_checkers +from google.protobuf.internal import well_known_types from google.protobuf.internal import wire_format from google.protobuf import descriptor as descriptor_mod from google.protobuf import message as message_mod @@ -72,6 +73,7 @@ from google.protobuf import symbol_database from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor +_AnyFullTypeName = 'google.protobuf.Any' class GeneratedProtocolMessageType(type): @@ -127,6 +129,8 @@ class GeneratedProtocolMessageType(type): Newly-allocated class. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + if descriptor.full_name in well_known_types.WKTBASES: + bases += (well_known_types.WKTBASES[descriptor.full_name],) _AddClassAttributesForNestedExtensions(descriptor, dictionary) _AddSlots(descriptor, dictionary) @@ -261,7 +265,6 @@ def _IsMessageSetExtension(field): field.containing_type.has_options and field.containing_type.GetOptions().message_set_wire_format and field.type == _FieldDescriptor.TYPE_MESSAGE and - field.message_type == field.extension_scope and field.label == _FieldDescriptor.LABEL_OPTIONAL) @@ -543,7 +546,8 @@ def _GetFieldByName(message_descriptor, field_name): try: return message_descriptor.fields_by_name[field_name] except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + raise ValueError('Protocol message %s has no "%s" field.' % + (message_descriptor.name, field_name)) def _AddPropertiesForFields(descriptor, cls): @@ -848,9 +852,15 @@ def _AddClearFieldMethod(message_descriptor, cls): else: return except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + raise ValueError('Protocol message %s() has no "%s" field.' % + (message_descriptor.name, field_name)) if field in self._fields: + # To match the C++ implementation, we need to invalidate iterators + # for map fields when ClearField() happens. + if hasattr(self._fields[field], 'InvalidateIterators'): + self._fields[field].InvalidateIterators() + # Note: If the field is a sub-message, its listener will still point # at us. That's fine, because the worst than can happen is that it # will call _Modified() and invalidate our byte size. Big deal. @@ -904,7 +914,19 @@ def _AddHasExtensionMethod(cls): return extension_handle in self._fields cls.HasExtension = HasExtension -def _UnpackAny(msg): +def _InternalUnpackAny(msg): + """Unpacks Any message and returns the unpacked message. + + This internal method is differnt from public Any Unpack method which takes + the target message as argument. _InternalUnpackAny method does not have + target message type and need to find the message type in descriptor pool. + + Args: + msg: An Any message to be unpacked. + + Returns: + The unpacked message. + """ type_url = msg.type_url db = symbol_database.Default() @@ -935,9 +957,9 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True - if self.DESCRIPTOR.full_name == "google.protobuf.Any": - any_a = _UnpackAny(self) - any_b = _UnpackAny(other) + if self.DESCRIPTOR.full_name == _AnyFullTypeName: + any_a = _InternalUnpackAny(self) + any_b = _InternalUnpackAny(other) if any_a and any_b: return any_a == any_b @@ -962,6 +984,13 @@ def _AddStrMethod(message_descriptor, cls): cls.__str__ = __str__ +def _AddReprMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __repr__(self): + return text_format.MessageToString(self) + cls.__repr__ = __repr__ + + def _AddUnicodeMethod(unused_message_descriptor, cls): """Helper for _AddMessageMethods().""" @@ -1270,6 +1299,7 @@ def _AddMessageMethods(message_descriptor, cls): _AddClearMethod(message_descriptor, cls) _AddEqualsMethod(message_descriptor, cls) _AddStrMethod(message_descriptor, cls) + _AddReprMethod(message_descriptor, cls) _AddUnicodeMethod(message_descriptor, cls) _AddSetListenerMethod(cls) _AddByteSizeMethod(message_descriptor, cls) @@ -1280,6 +1310,7 @@ def _AddMessageMethods(message_descriptor, cls): _AddMergeFromMethod(cls) _AddWhichOneofMethod(message_descriptor, cls) + def _AddPrivateHelperMethods(message_descriptor, cls): """Adds implementation of private helper methods to cls.""" diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 8621d61e..752f2f5d 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -2394,8 +2394,10 @@ class SerializationTest(unittest.TestCase): extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2 extension1 = extension_message1.message_set_extension extension2 = extension_message2.message_set_extension + extension3 = message_set_extensions_pb2.message_set_extension3 proto.Extensions[extension1].i = 123 proto.Extensions[extension2].str = 'foo' + proto.Extensions[extension3].text = 'bar' # Serialize using the MessageSet wire format (this is specified in the # .proto file). @@ -2407,7 +2409,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual( len(serialized), raw.MergeFromString(serialized)) - self.assertEqual(2, len(raw.item)) + self.assertEqual(3, len(raw.item)) message1 = message_set_extensions_pb2.TestMessageSetExtension1() self.assertEqual( @@ -2421,6 +2423,12 @@ class SerializationTest(unittest.TestCase): message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) + message3 = message_set_extensions_pb2.TestMessageSetExtension3() + self.assertEqual( + len(raw.item[2].message), + message3.MergeFromString(raw.item[2].message)) + self.assertEqual('bar', message3.text) + # Deserialize using the MessageSet wire format. proto2 = message_set_extensions_pb2.TestMessageSet() self.assertEqual( @@ -2428,6 +2436,7 @@ class SerializationTest(unittest.TestCase): proto2.MergeFromString(serialized)) self.assertEqual(123, proto2.Extensions[extension1].i) self.assertEqual('foo', proto2.Extensions[extension2].str) + self.assertEqual('bar', proto2.Extensions[extension3].text) # Check byte size. self.assertEqual(proto2.ByteSize(), len(serialized)) @@ -2757,9 +2766,10 @@ class SerializationTest(unittest.TestCase): def testInitArgsUnknownFieldName(self): def InitalizeEmptyMessageWithExtraKeywordArg(): unused_proto = unittest_pb2.TestEmptyMessage(unknown='unknown') - self._CheckRaises(ValueError, - InitalizeEmptyMessageWithExtraKeywordArg, - 'Protocol message has no "unknown" field.') + self._CheckRaises( + ValueError, + InitalizeEmptyMessageWithExtraKeywordArg, + 'Protocol message TestEmptyMessage has no "unknown" field.') def testInitRequiredKwargs(self): proto = unittest_pb2.TestRequired(a=1, b=1, c=1) diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index cca0ee63..0e14556c 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -51,8 +51,22 @@ from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util +from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf import text_format + +# Low-level nuts-n-bolts tests. +class SimpleTextFormatTests(unittest.TestCase): + + # The members of _QUOTES are formatted into a regexp template that + # expects single characters. Therefore it's an error (in addition to being + # non-sensical in the first place) to try to specify a "quote mark" that is + # more than one character. + def TestQuoteMarksAreSingleChars(self): + for quote in text_format._QUOTES: + self.assertEqual(1, len(quote)) + + # Base class with some common functionality. class TextFormatBase(unittest.TestCase): @@ -287,6 +301,19 @@ class TextFormatTest(TextFormatBase): self.assertEqual(u'one', message.repeated_string[0]) self.assertEqual(u'two', message.repeated_string[1]) + def testParseRepeatedScalarShortFormat(self, message_module): + message = message_module.TestAllTypes() + text = ('repeated_int64: [100, 200];\n' + 'repeated_int64: 300,\n' + 'repeated_string: ["one", "two"];\n') + text_format.Parse(text, message) + + self.assertEqual(100, message.repeated_int64[0]) + self.assertEqual(200, message.repeated_int64[1]) + self.assertEqual(300, message.repeated_int64[2]) + self.assertEqual(u'one', message.repeated_string[0]) + self.assertEqual(u'two', message.repeated_string[1]) + def testParseEmptyText(self, message_module): message = message_module.TestAllTypes() text = '' @@ -301,7 +328,7 @@ class TextFormatTest(TextFormatBase): def testParseSingleWord(self, message_module): message = message_module.TestAllTypes() text = 'foo' - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' r'"foo".'), @@ -310,7 +337,7 @@ class TextFormatTest(TextFormatBase): def testParseUnknownField(self, message_module): message = message_module.TestAllTypes() text = 'unknown_field: 8\n' - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' r'"unknown_field".'), @@ -319,7 +346,7 @@ class TextFormatTest(TextFormatBase): def testParseBadEnumValue(self, message_module): message = message_module.TestAllTypes() text = 'optional_nested_enum: BARR' - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' r'has no value named BARR.'), @@ -327,7 +354,7 @@ class TextFormatTest(TextFormatBase): message = message_module.TestAllTypes() text = 'optional_nested_enum: 100' - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' r'has no value with number 100.'), @@ -336,7 +363,7 @@ class TextFormatTest(TextFormatBase): def testParseBadIntValue(self, message_module): message = message_module.TestAllTypes() text = 'optional_int32: bork' - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, ('1:17 : Couldn\'t parse integer: bork'), text_format.Parse, text, message) @@ -553,6 +580,15 @@ class Proto2Tests(TextFormatBase): ' }\n' '}\n') + message = message_set_extensions_pb2.TestMessageSet() + ext = message_set_extensions_pb2.message_set_extension3 + message.Extensions[ext].text = 'bar' + self.CompareToGoldenText( + text_format.MessageToString(message), + '[google.protobuf.internal.TestMessageSetExtension3] {\n' + ' text: \"bar\"\n' + '}\n') + def testPrintMessageSetAsOneLine(self): message = unittest_mset_pb2.TestMessageSetContainer() ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension @@ -627,15 +663,125 @@ class Proto2Tests(TextFormatBase): text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) + def testParseAllowedUnknownExtension(self): + # Skip over unknown extension correctly. + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('message_set {\n' + ' [unknown_extension] {\n' + ' i: 23\n' + ' [nested_unknown_ext]: {\n' + ' i: 23\n' + ' test: "test_string"\n' + ' floaty_float: -0.315\n' + ' num: -inf\n' + ' multiline_str: "abc"\n' + ' "def"\n' + ' "xyz."\n' + ' [nested_unknown_ext]: <\n' + ' i: 23\n' + ' i: 24\n' + ' pointfloat: .3\n' + ' test: "test_string"\n' + ' floaty_float: -0.315\n' + ' num: -inf\n' + ' long_string: "test" "test2" \n' + ' >\n' + ' }\n' + ' }\n' + ' [unknown_extension]: 5\n' + '}\n') + text_format.Parse(text, message, allow_unknown_extension=True) + golden = 'message_set {\n}\n' + self.CompareToGoldenText(text_format.MessageToString(message), golden) + + # Catch parse errors in unknown extension. + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' [unknown_extension] {\n' + ' i:\n' # Missing value. + ' }\n' + '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + 'Invalid field value: }', + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' [unknown_extension] {\n' + ' str: "malformed string\n' # Missing closing quote. + ' }\n' + '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + 'Invalid field value: "', + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' [unknown_extension] {\n' + ' str: "malformed\n multiline\n string\n' + ' }\n' + '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + 'Invalid field value: "', + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' [malformed_extension] <\n' + ' i: -5\n' + ' \n' # Missing '>' here. + '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + '5:1 : Expected ">".', + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + # Don't allow unknown fields with allow_unknown_extension=True. + message = unittest_mset_pb2.TestMessageSetContainer() + malformed = ('message_set {\n' + ' unknown_field: true\n' + ' \n' # Missing '>' here. + '}\n') + six.assertRaisesRegex(self, + text_format.ParseError, + ('2:3 : Message type ' + '"proto2_wireformat_unittest.TestMessageSet" has no' + ' field named "unknown_field".'), + text_format.Parse, malformed, message, + allow_unknown_extension=True) + + # Parse known extension correcty. + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + text_format.Parse(text, message, allow_unknown_extension=True) + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + self.assertEqual(23, message.message_set.Extensions[ext1].i) + self.assertEqual('foo', message.message_set.Extensions[ext2].str) + def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() text = '[unknown_extension]: 8\n' - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, '1:2 : Extension "unknown_extension" not registered.', text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' 'extensions.'), @@ -654,7 +800,7 @@ class Proto2Tests(TextFormatBase): message = unittest_pb2.TestAllExtensions() text = ('[protobuf_unittest.optional_int32_extension]: 42 ' '[protobuf_unittest.optional_int32_extension]: 67') - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, ('1:96 : Message type "protobuf_unittest.TestAllExtensions" ' 'should not have multiple ' @@ -665,7 +811,7 @@ class Proto2Tests(TextFormatBase): message = unittest_pb2.TestAllTypes() text = ('optional_nested_message { bb: 1 } ' 'optional_nested_message { bb: 2 }') - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' 'should not have multiple "bb" fields.'), @@ -675,7 +821,7 @@ class Proto2Tests(TextFormatBase): message = unittest_pb2.TestAllTypes() text = ('optional_int32: 42 ' 'optional_int32: 67') - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' 'have multiple "optional_int32" fields.'), @@ -684,11 +830,11 @@ class Proto2Tests(TextFormatBase): def testParseGroupNotClosed(self): message = unittest_pb2.TestAllTypes() text = 'RepeatedGroup: <' - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected ">".', text_format.Parse, text, message) text = 'RepeatedGroup: {' - six.assertRaisesRegex(self, + six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected "}".', text_format.Parse, text, message) diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py new file mode 100644 index 00000000..d3de9831 --- /dev/null +++ b/python/google/protobuf/internal/well_known_types.py @@ -0,0 +1,622 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains well known classes. + +This files defines well known classes which need extra maintenance including: + - Any + - Duration + - FieldMask + - Timestamp +""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +from datetime import datetime +from datetime import timedelta + +from google.protobuf.descriptor import FieldDescriptor + +_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S' +_NANOS_PER_SECOND = 1000000000 +_NANOS_PER_MILLISECOND = 1000000 +_NANOS_PER_MICROSECOND = 1000 +_MILLIS_PER_SECOND = 1000 +_MICROS_PER_SECOND = 1000000 +_SECONDS_PER_DAY = 24 * 3600 + + +class Error(Exception): + """Top-level module error.""" + + +class ParseError(Error): + """Thrown in case of parsing error.""" + + +class Any(object): + """Class for Any Message type.""" + + def Pack(self, msg): + """Packs the specified message into current Any message.""" + self.type_url = 'type.googleapis.com/%s' % msg.DESCRIPTOR.full_name + self.value = msg.SerializeToString() + + def Unpack(self, msg): + """Unpacks the current Any message into specified message.""" + descriptor = msg.DESCRIPTOR + if not self.Is(descriptor): + return False + msg.ParseFromString(self.value) + return True + + def Is(self, descriptor): + """Checks if this Any represents the given protobuf type.""" + # Only last part is to be used: b/25630112 + return self.type_url.split('/')[-1] == descriptor.full_name + + +class Timestamp(object): + """Class for Timestamp message type.""" + + def ToJsonString(self): + """Converts Timestamp to RFC 3339 date string format. + + Returns: + A string converted from timestamp. The string is always Z-normalized + and uses 3, 6 or 9 fractional digits as required to represent the + exact time. Example of the return format: '1972-01-01T10:00:20.021Z' + """ + nanos = self.nanos % _NANOS_PER_SECOND + total_sec = self.seconds + (self.nanos - nanos) // _NANOS_PER_SECOND + seconds = total_sec % _SECONDS_PER_DAY + days = (total_sec - seconds) // _SECONDS_PER_DAY + dt = datetime(1970, 1, 1) + timedelta(days, seconds) + + result = dt.isoformat() + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return result + 'Z' + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return result + '.%03dZ' % (nanos / 1e6) + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. + return result + '.%06dZ' % (nanos / 1e3) + # Serialize 9 fractional digits. + return result + '.%09dZ' % nanos + + def FromJsonString(self, value): + """Parse a RFC 3339 date string format to Timestamp. + + Args: + value: A date string. Any fractional digits (or none) and any offset are + accepted as long as they fit into nano-seconds precision. + Example of accepted format: '1972-01-01T10:00:20.021-05:00' + + Raises: + ParseError: On parsing problems. + """ + timezone_offset = value.find('Z') + if timezone_offset == -1: + timezone_offset = value.find('+') + if timezone_offset == -1: + timezone_offset = value.rfind('-') + if timezone_offset == -1: + raise ParseError( + 'Failed to parse timestamp: missing valid timezone offset.') + time_value = value[0:timezone_offset] + # Parse datetime and nanos. + point_position = time_value.find('.') + if point_position == -1: + second_value = time_value + nano_value = '' + else: + second_value = time_value[:point_position] + nano_value = time_value[point_position + 1:] + date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT) + td = date_object - datetime(1970, 1, 1) + seconds = td.seconds + td.days * _SECONDS_PER_DAY + if len(nano_value) > 9: + raise ParseError( + 'Failed to parse Timestamp: nanos {0} more than ' + '9 fractional digits.'.format(nano_value)) + if nano_value: + nanos = round(float('0.' + nano_value) * 1e9) + else: + nanos = 0 + # Parse timezone offsets. + if value[timezone_offset] == 'Z': + if len(value) != timezone_offset + 1: + raise ParseError('Failed to parse timestamp: invalid trailing' + ' data {0}.'.format(value)) + else: + timezone = value[timezone_offset:] + pos = timezone.find(':') + if pos == -1: + raise ParseError( + 'Invalid timezone offset value: {0}.'.format(timezone)) + if timezone[0] == '+': + seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + else: + seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + # Set seconds and nanos + self.seconds = int(seconds) + self.nanos = int(nanos) + + def GetCurrentTime(self): + """Get the current UTC into Timestamp.""" + self.FromDatetime(datetime.utcnow()) + + def ToNanoseconds(self): + """Converts Timestamp to nanoseconds since epoch.""" + return self.seconds * _NANOS_PER_SECOND + self.nanos + + def ToMicroseconds(self): + """Converts Timestamp to microseconds since epoch.""" + return (self.seconds * _MICROS_PER_SECOND + + self.nanos // _NANOS_PER_MICROSECOND) + + def ToMilliseconds(self): + """Converts Timestamp to milliseconds since epoch.""" + return (self.seconds * _MILLIS_PER_SECOND + + self.nanos // _NANOS_PER_MILLISECOND) + + def ToSeconds(self): + """Converts Timestamp to seconds since epoch.""" + return self.seconds + + def FromNanoseconds(self, nanos): + """Converts nanoseconds since epoch to Timestamp.""" + self.seconds = nanos // _NANOS_PER_SECOND + self.nanos = nanos % _NANOS_PER_SECOND + + def FromMicroseconds(self, micros): + """Converts microseconds since epoch to Timestamp.""" + self.seconds = micros // _MICROS_PER_SECOND + self.nanos = (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND + + def FromMilliseconds(self, millis): + """Converts milliseconds since epoch to Timestamp.""" + self.seconds = millis // _MILLIS_PER_SECOND + self.nanos = (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND + + def FromSeconds(self, seconds): + """Converts seconds since epoch to Timestamp.""" + self.seconds = seconds + self.nanos = 0 + + def ToDatetime(self): + """Converts Timestamp to datetime.""" + return datetime.utcfromtimestamp( + self.seconds + self.nanos / float(_NANOS_PER_SECOND)) + + def FromDatetime(self, dt): + """Converts datetime to Timestamp.""" + td = dt - datetime(1970, 1, 1) + self.seconds = td.seconds + td.days * _SECONDS_PER_DAY + self.nanos = td.microseconds * _NANOS_PER_MICROSECOND + + +class Duration(object): + """Class for Duration message type.""" + + def ToJsonString(self): + """Converts Duration to string format. + + Returns: + A string converted from self. The string format will contains + 3, 6, or 9 fractional digits depending on the precision required to + represent the exact Duration value. For example: "1s", "1.010s", + "1.000000100s", "-3.100s" + """ + if self.seconds < 0 or self.nanos < 0: + result = '-' + seconds = - self.seconds + int((0 - self.nanos) // 1e9) + nanos = (0 - self.nanos) % 1e9 + else: + result = '' + seconds = self.seconds + int(self.nanos // 1e9) + nanos = self.nanos % 1e9 + result += '%d' % seconds + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return result + 's' + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return result + '.%03ds' % (nanos / 1e6) + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. + return result + '.%06ds' % (nanos / 1e3) + # Serialize 9 fractional digits. + return result + '.%09ds' % nanos + + def FromJsonString(self, value): + """Converts a string to Duration. + + Args: + value: A string to be converted. The string must end with 's'. Any + fractional digits (or none) are accepted as long as they fit into + precision. For example: "1s", "1.01s", "1.0000001s", "-3.100s + + Raises: + ParseError: On parsing problems. + """ + if len(value) < 1 or value[-1] != 's': + raise ParseError( + 'Duration must end with letter "s": {0}.'.format(value)) + try: + pos = value.find('.') + if pos == -1: + self.seconds = int(value[:-1]) + self.nanos = 0 + else: + self.seconds = int(value[:pos]) + if value[0] == '-': + self.nanos = int(round(float('-0{0}'.format(value[pos: -1])) *1e9)) + else: + self.nanos = int(round(float('0{0}'.format(value[pos: -1])) *1e9)) + except ValueError: + raise ParseError( + 'Couldn\'t parse duration: {0}.'.format(value)) + + def ToNanoseconds(self): + """Converts a Duration to nanoseconds.""" + return self.seconds * _NANOS_PER_SECOND + self.nanos + + def ToMicroseconds(self): + """Converts a Duration to microseconds.""" + micros = _RoundTowardZero(self.nanos, _NANOS_PER_MICROSECOND) + return self.seconds * _MICROS_PER_SECOND + micros + + def ToMilliseconds(self): + """Converts a Duration to milliseconds.""" + millis = _RoundTowardZero(self.nanos, _NANOS_PER_MILLISECOND) + return self.seconds * _MILLIS_PER_SECOND + millis + + def ToSeconds(self): + """Converts a Duration to seconds.""" + return self.seconds + + def FromNanoseconds(self, nanos): + """Converts nanoseconds to Duration.""" + self._NormalizeDuration(nanos // _NANOS_PER_SECOND, + nanos % _NANOS_PER_SECOND) + + def FromMicroseconds(self, micros): + """Converts microseconds to Duration.""" + self._NormalizeDuration( + micros // _MICROS_PER_SECOND, + (micros % _MICROS_PER_SECOND) * _NANOS_PER_MICROSECOND) + + def FromMilliseconds(self, millis): + """Converts milliseconds to Duration.""" + self._NormalizeDuration( + millis // _MILLIS_PER_SECOND, + (millis % _MILLIS_PER_SECOND) * _NANOS_PER_MILLISECOND) + + def FromSeconds(self, seconds): + """Converts seconds to Duration.""" + self.seconds = seconds + self.nanos = 0 + + def ToTimedelta(self): + """Converts Duration to timedelta.""" + return timedelta( + seconds=self.seconds, microseconds=_RoundTowardZero( + self.nanos, _NANOS_PER_MICROSECOND)) + + def FromTimedelta(self, td): + """Convertd timedelta to Duration.""" + self._NormalizeDuration(td.seconds + td.days * _SECONDS_PER_DAY, + td.microseconds * _NANOS_PER_MICROSECOND) + + def _NormalizeDuration(self, seconds, nanos): + """Set Duration by seconds and nonas.""" + # Force nanos to be negative if the duration is negative. + if seconds < 0 and nanos > 0: + seconds += 1 + nanos -= _NANOS_PER_SECOND + self.seconds = seconds + self.nanos = nanos + + +def _RoundTowardZero(value, divider): + """Truncates the remainder part after division.""" + # For some languanges, the sign of the remainder is implementation + # dependent if any of the operands is negative. Here we enforce + # "rounded toward zero" semantics. For example, for (-5) / 2 an + # implementation may give -3 as the result with the remainder being + # 1. This function ensures we always return -2 (closer to zero). + result = value // divider + remainder = value % divider + if result < 0 and remainder > 0: + return result + 1 + else: + return result + + +class FieldMask(object): + """Class for FieldMask message type.""" + + def ToJsonString(self): + """Converts FieldMask to string according to proto3 JSON spec.""" + return ','.join(self.paths) + + def FromJsonString(self, value): + """Converts string to FieldMask according to proto3 JSON spec.""" + self.Clear() + for path in value.split(','): + self.paths.append(path) + + def IsValidForDescriptor(self, message_descriptor): + """Checks whether the FieldMask is valid for Message Descriptor.""" + for path in self.paths: + if not _IsValidPath(message_descriptor, path): + return False + return True + + def AllFieldsFromDescriptor(self, message_descriptor): + """Gets all direct fields of Message Descriptor to FieldMask.""" + self.Clear() + for field in message_descriptor.fields: + self.paths.append(field.name) + + def CanonicalFormFromMask(self, mask): + """Converts a FieldMask to the canonical form. + + Removes paths that are covered by another path. For example, + "foo.bar" is covered by "foo" and will be removed if "foo" + is also in the FieldMask. Then sorts all paths in alphabetical order. + + Args: + mask: The original FieldMask to be converted. + """ + tree = _FieldMaskTree(mask) + tree.ToFieldMask(self) + + def Union(self, mask1, mask2): + """Merges mask1 and mask2 into this FieldMask.""" + _CheckFieldMaskMessage(mask1) + _CheckFieldMaskMessage(mask2) + tree = _FieldMaskTree(mask1) + tree.MergeFromFieldMask(mask2) + tree.ToFieldMask(self) + + def Intersect(self, mask1, mask2): + """Intersects mask1 and mask2 into this FieldMask.""" + _CheckFieldMaskMessage(mask1) + _CheckFieldMaskMessage(mask2) + tree = _FieldMaskTree(mask1) + intersection = _FieldMaskTree() + for path in mask2.paths: + tree.IntersectPath(path, intersection) + intersection.ToFieldMask(self) + + def MergeMessage( + self, source, destination, + replace_message_field=False, replace_repeated_field=False): + """Merges fields specified in FieldMask from source to destination. + + Args: + source: Source message. + destination: The destination message to be merged into. + replace_message_field: Replace message field if True. Merge message + field if False. + replace_repeated_field: Replace repeated field if True. Append + elements of repeated field if False. + """ + tree = _FieldMaskTree(self) + tree.MergeMessage( + source, destination, replace_message_field, replace_repeated_field) + + +def _IsValidPath(message_descriptor, path): + """Checks whether the path is valid for Message Descriptor.""" + parts = path.split('.') + last = parts.pop() + for name in parts: + field = message_descriptor.fields_by_name[name] + if (field is None or + field.label == FieldDescriptor.LABEL_REPEATED or + field.type != FieldDescriptor.TYPE_MESSAGE): + return False + message_descriptor = field.message_type + return last in message_descriptor.fields_by_name + + +def _CheckFieldMaskMessage(message): + """Raises ValueError if message is not a FieldMask.""" + message_descriptor = message.DESCRIPTOR + if (message_descriptor.name != 'FieldMask' or + message_descriptor.file.name != 'google/protobuf/field_mask.proto'): + raise ValueError('Message {0} is not a FieldMask.'.format( + message_descriptor.full_name)) + + +class _FieldMaskTree(object): + """Represents a FieldMask in a tree structure. + + For example, given a FieldMask "foo.bar,foo.baz,bar.baz", + the FieldMaskTree will be: + [_root] -+- foo -+- bar + | | + | +- baz + | + +- bar --- baz + In the tree, each leaf node represents a field path. + """ + + def __init__(self, field_mask=None): + """Initializes the tree by FieldMask.""" + self._root = {} + if field_mask: + self.MergeFromFieldMask(field_mask) + + def MergeFromFieldMask(self, field_mask): + """Merges a FieldMask to the tree.""" + for path in field_mask.paths: + self.AddPath(path) + + def AddPath(self, path): + """Adds a field path into the tree. + + If the field path to add is a sub-path of an existing field path + in the tree (i.e., a leaf node), it means the tree already matches + the given path so nothing will be added to the tree. If the path + matches an existing non-leaf node in the tree, that non-leaf node + will be turned into a leaf node with all its children removed because + the path matches all the node's children. Otherwise, a new path will + be added. + + Args: + path: The field path to add. + """ + node = self._root + for name in path.split('.'): + if name not in node: + node[name] = {} + elif not node[name]: + # Pre-existing empty node implies we already have this entire tree. + return + node = node[name] + # Remove any sub-trees we might have had. + node.clear() + + def ToFieldMask(self, field_mask): + """Converts the tree to a FieldMask.""" + field_mask.Clear() + _AddFieldPaths(self._root, '', field_mask) + + def IntersectPath(self, path, intersection): + """Calculates the intersection part of a field path with this tree. + + Args: + path: The field path to calculates. + intersection: The out tree to record the intersection part. + """ + node = self._root + for name in path.split('.'): + if name not in node: + return + elif not node[name]: + intersection.AddPath(path) + return + node = node[name] + intersection.AddLeafNodes(path, node) + + def AddLeafNodes(self, prefix, node): + """Adds leaf nodes begin with prefix to this tree.""" + if not node: + self.AddPath(prefix) + for name in node: + child_path = prefix + '.' + name + self.AddLeafNodes(child_path, node[name]) + + def MergeMessage( + self, source, destination, + replace_message, replace_repeated): + """Merge all fields specified by this tree from source to destination.""" + _MergeMessage( + self._root, source, destination, replace_message, replace_repeated) + + +def _StrConvert(value): + """Converts value to str if it is not.""" + # This file is imported by c extension and some methods like ClearField + # requires string for the field name. py2/py3 has different text + # type and may use unicode. + if not isinstance(value, str): + return value.encode('utf-8') + return value + + +def _MergeMessage( + node, source, destination, replace_message, replace_repeated): + """Merge all fields specified by a sub-tree from source to destination.""" + source_descriptor = source.DESCRIPTOR + for name in node: + child = node[name] + field = source_descriptor.fields_by_name[name] + if field is None: + raise ValueError('Error: Can\'t find field {0} in message {1}.'.format( + name, source_descriptor.full_name)) + if child: + # Sub-paths are only allowed for singular message fields. + if (field.label == FieldDescriptor.LABEL_REPEATED or + field.cpp_type != FieldDescriptor.CPPTYPE_MESSAGE): + raise ValueError('Error: Field {0} in message {1} is not a singular ' + 'message field and cannot have sub-fields.'.format( + name, source_descriptor.full_name)) + _MergeMessage( + child, getattr(source, name), getattr(destination, name), + replace_message, replace_repeated) + continue + if field.label == FieldDescriptor.LABEL_REPEATED: + if replace_repeated: + destination.ClearField(_StrConvert(name)) + repeated_source = getattr(source, name) + repeated_destination = getattr(destination, name) + if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + for item in repeated_source: + repeated_destination.add().MergeFrom(item) + else: + repeated_destination.extend(repeated_source) + else: + if field.cpp_type == FieldDescriptor.CPPTYPE_MESSAGE: + if replace_message: + destination.ClearField(_StrConvert(name)) + if source.HasField(name): + getattr(destination, name).MergeFrom(getattr(source, name)) + else: + setattr(destination, name, getattr(source, name)) + + +def _AddFieldPaths(node, prefix, field_mask): + """Adds the field paths descended from node to field_mask.""" + if not node: + field_mask.paths.append(prefix) + return + for name in sorted(node): + if prefix: + child_path = prefix + '.' + name + else: + child_path = name + _AddFieldPaths(node[name], child_path, field_mask) + + +WKTBASES = { + 'google.protobuf.Any': Any, + 'google.protobuf.Duration': Duration, + 'google.protobuf.FieldMask': FieldMask, + 'google.protobuf.Timestamp': Timestamp, +} diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py new file mode 100644 index 00000000..60b0c47d --- /dev/null +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -0,0 +1,509 @@ +#! /usr/bin/env python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.internal.well_known_types.""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +from datetime import datetime + +from google.protobuf import duration_pb2 +from google.protobuf import field_mask_pb2 +from google.protobuf import timestamp_pb2 +import unittest +from google.protobuf import unittest_pb2 +from google.protobuf.internal import test_util +from google.protobuf.internal import well_known_types +from google.protobuf import descriptor + + +class TimeUtilTestBase(unittest.TestCase): + + def CheckTimestampConversion(self, message, text): + self.assertEqual(text, message.ToJsonString()) + parsed_message = timestamp_pb2.Timestamp() + parsed_message.FromJsonString(text) + self.assertEqual(message, parsed_message) + + def CheckDurationConversion(self, message, text): + self.assertEqual(text, message.ToJsonString()) + parsed_message = duration_pb2.Duration() + parsed_message.FromJsonString(text) + self.assertEqual(message, parsed_message) + + +class TimeUtilTest(TimeUtilTestBase): + + def testTimestampSerializeAndParse(self): + message = timestamp_pb2.Timestamp() + # Generated output should contain 3, 6, or 9 fractional digits. + message.seconds = 0 + message.nanos = 0 + self.CheckTimestampConversion(message, '1970-01-01T00:00:00Z') + message.nanos = 10000000 + self.CheckTimestampConversion(message, '1970-01-01T00:00:00.010Z') + message.nanos = 10000 + self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000010Z') + message.nanos = 10 + self.CheckTimestampConversion(message, '1970-01-01T00:00:00.000000010Z') + # Test min timestamps. + message.seconds = -62135596800 + message.nanos = 0 + self.CheckTimestampConversion(message, '0001-01-01T00:00:00Z') + # Test max timestamps. + message.seconds = 253402300799 + message.nanos = 999999999 + self.CheckTimestampConversion(message, '9999-12-31T23:59:59.999999999Z') + # Test negative timestamps. + message.seconds = -1 + self.CheckTimestampConversion(message, '1969-12-31T23:59:59.999999999Z') + + # Parsing accepts an fractional digits as long as they fit into nano + # precision. + message.FromJsonString('1970-01-01T00:00:00.1Z') + self.assertEqual(0, message.seconds) + self.assertEqual(100000000, message.nanos) + # Parsing accpets offsets. + message.FromJsonString('1970-01-01T00:00:00-08:00') + self.assertEqual(8 * 3600, message.seconds) + self.assertEqual(0, message.nanos) + + def testDurationSerializeAndParse(self): + message = duration_pb2.Duration() + # Generated output should contain 3, 6, or 9 fractional digits. + message.seconds = 0 + message.nanos = 0 + self.CheckDurationConversion(message, '0s') + message.nanos = 10000000 + self.CheckDurationConversion(message, '0.010s') + message.nanos = 10000 + self.CheckDurationConversion(message, '0.000010s') + message.nanos = 10 + self.CheckDurationConversion(message, '0.000000010s') + + # Test min and max + message.seconds = 315576000000 + message.nanos = 999999999 + self.CheckDurationConversion(message, '315576000000.999999999s') + message.seconds = -315576000000 + message.nanos = -999999999 + self.CheckDurationConversion(message, '-315576000000.999999999s') + + # Parsing accepts an fractional digits as long as they fit into nano + # precision. + message.FromJsonString('0.1s') + self.assertEqual(100000000, message.nanos) + message.FromJsonString('0.0000001s') + self.assertEqual(100, message.nanos) + + def testTimestampIntegerConversion(self): + message = timestamp_pb2.Timestamp() + message.FromNanoseconds(1) + self.assertEqual('1970-01-01T00:00:00.000000001Z', + message.ToJsonString()) + self.assertEqual(1, message.ToNanoseconds()) + + message.FromNanoseconds(-1) + self.assertEqual('1969-12-31T23:59:59.999999999Z', + message.ToJsonString()) + self.assertEqual(-1, message.ToNanoseconds()) + + message.FromMicroseconds(1) + self.assertEqual('1970-01-01T00:00:00.000001Z', + message.ToJsonString()) + self.assertEqual(1, message.ToMicroseconds()) + + message.FromMicroseconds(-1) + self.assertEqual('1969-12-31T23:59:59.999999Z', + message.ToJsonString()) + self.assertEqual(-1, message.ToMicroseconds()) + + message.FromMilliseconds(1) + self.assertEqual('1970-01-01T00:00:00.001Z', + message.ToJsonString()) + self.assertEqual(1, message.ToMilliseconds()) + + message.FromMilliseconds(-1) + self.assertEqual('1969-12-31T23:59:59.999Z', + message.ToJsonString()) + self.assertEqual(-1, message.ToMilliseconds()) + + message.FromSeconds(1) + self.assertEqual('1970-01-01T00:00:01Z', + message.ToJsonString()) + self.assertEqual(1, message.ToSeconds()) + + message.FromSeconds(-1) + self.assertEqual('1969-12-31T23:59:59Z', + message.ToJsonString()) + self.assertEqual(-1, message.ToSeconds()) + + message.FromNanoseconds(1999) + self.assertEqual(1, message.ToMicroseconds()) + # For negative values, Timestamp will be rounded down. + # For example, "1969-12-31T23:59:59.5Z" (i.e., -0.5s) rounded to seconds + # will be "1969-12-31T23:59:59Z" (i.e., -1s) rather than + # "1970-01-01T00:00:00Z" (i.e., 0s). + message.FromNanoseconds(-1999) + self.assertEqual(-2, message.ToMicroseconds()) + + def testDurationIntegerConversion(self): + message = duration_pb2.Duration() + message.FromNanoseconds(1) + self.assertEqual('0.000000001s', + message.ToJsonString()) + self.assertEqual(1, message.ToNanoseconds()) + + message.FromNanoseconds(-1) + self.assertEqual('-0.000000001s', + message.ToJsonString()) + self.assertEqual(-1, message.ToNanoseconds()) + + message.FromMicroseconds(1) + self.assertEqual('0.000001s', + message.ToJsonString()) + self.assertEqual(1, message.ToMicroseconds()) + + message.FromMicroseconds(-1) + self.assertEqual('-0.000001s', + message.ToJsonString()) + self.assertEqual(-1, message.ToMicroseconds()) + + message.FromMilliseconds(1) + self.assertEqual('0.001s', + message.ToJsonString()) + self.assertEqual(1, message.ToMilliseconds()) + + message.FromMilliseconds(-1) + self.assertEqual('-0.001s', + message.ToJsonString()) + self.assertEqual(-1, message.ToMilliseconds()) + + message.FromSeconds(1) + self.assertEqual('1s', message.ToJsonString()) + self.assertEqual(1, message.ToSeconds()) + + message.FromSeconds(-1) + self.assertEqual('-1s', + message.ToJsonString()) + self.assertEqual(-1, message.ToSeconds()) + + # Test truncation behavior. + message.FromNanoseconds(1999) + self.assertEqual(1, message.ToMicroseconds()) + + # For negative values, Duration will be rounded towards 0. + message.FromNanoseconds(-1999) + self.assertEqual(-1, message.ToMicroseconds()) + + def testDatetimeConverison(self): + message = timestamp_pb2.Timestamp() + dt = datetime(1970, 1, 1) + message.FromDatetime(dt) + self.assertEqual(dt, message.ToDatetime()) + + message.FromMilliseconds(1999) + self.assertEqual(datetime(1970, 1, 1, 0, 0, 1, 999000), + message.ToDatetime()) + + def testTimedeltaConversion(self): + message = duration_pb2.Duration() + message.FromNanoseconds(1999999999) + td = message.ToTimedelta() + self.assertEqual(1, td.seconds) + self.assertEqual(999999, td.microseconds) + + message.FromNanoseconds(-1999999999) + td = message.ToTimedelta() + self.assertEqual(-1, td.days) + self.assertEqual(86398, td.seconds) + self.assertEqual(1, td.microseconds) + + message.FromMicroseconds(-1) + td = message.ToTimedelta() + self.assertEqual(-1, td.days) + self.assertEqual(86399, td.seconds) + self.assertEqual(999999, td.microseconds) + converted_message = duration_pb2.Duration() + converted_message.FromTimedelta(td) + self.assertEqual(message, converted_message) + + def testInvalidTimestamp(self): + message = timestamp_pb2.Timestamp() + self.assertRaisesRegexp( + ValueError, + 'time data \'10000-01-01T00:00:00\' does not match' + ' format \'%Y-%m-%dT%H:%M:%S\'', + message.FromJsonString, '10000-01-01T00:00:00.00Z') + self.assertRaisesRegexp( + well_known_types.ParseError, + 'nanos 0123456789012 more than 9 fractional digits.', + message.FromJsonString, + '1970-01-01T00:00:00.0123456789012Z') + self.assertRaisesRegexp( + well_known_types.ParseError, + (r'Invalid timezone offset value: \+08.'), + message.FromJsonString, + '1972-01-01T01:00:00.01+08',) + self.assertRaisesRegexp( + ValueError, + 'year is out of range', + message.FromJsonString, + '0000-01-01T00:00:00Z') + message.seconds = 253402300800 + self.assertRaisesRegexp( + OverflowError, + 'date value out of range', + message.ToJsonString) + + def testInvalidDuration(self): + message = duration_pb2.Duration() + self.assertRaisesRegexp( + well_known_types.ParseError, + 'Duration must end with letter "s": 1.', + message.FromJsonString, '1') + self.assertRaisesRegexp( + well_known_types.ParseError, + 'Couldn\'t parse duration: 1...2s.', + message.FromJsonString, '1...2s') + + +class FieldMaskTest(unittest.TestCase): + + def testStringFormat(self): + mask = field_mask_pb2.FieldMask() + self.assertEqual('', mask.ToJsonString()) + mask.paths.append('foo') + self.assertEqual('foo', mask.ToJsonString()) + mask.paths.append('bar') + self.assertEqual('foo,bar', mask.ToJsonString()) + + mask.FromJsonString('') + self.assertEqual('', mask.ToJsonString()) + mask.FromJsonString('foo') + self.assertEqual(['foo'], mask.paths) + mask.FromJsonString('foo,bar') + self.assertEqual(['foo', 'bar'], mask.paths) + + def testDescriptorToFieldMask(self): + mask = field_mask_pb2.FieldMask() + msg_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + mask.AllFieldsFromDescriptor(msg_descriptor) + self.assertEqual(75, len(mask.paths)) + self.assertTrue(mask.IsValidForDescriptor(msg_descriptor)) + for field in msg_descriptor.fields: + self.assertTrue(field.name in mask.paths) + mask.paths.append('optional_nested_message.bb') + self.assertTrue(mask.IsValidForDescriptor(msg_descriptor)) + mask.paths.append('repeated_nested_message.bb') + self.assertFalse(mask.IsValidForDescriptor(msg_descriptor)) + + def testCanonicalFrom(self): + mask = field_mask_pb2.FieldMask() + out_mask = field_mask_pb2.FieldMask() + # Paths will be sorted. + mask.FromJsonString('baz.quz,bar,foo') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('bar,baz.quz,foo', out_mask.ToJsonString()) + # Duplicated paths will be removed. + mask.FromJsonString('foo,bar,foo') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('bar,foo', out_mask.ToJsonString()) + # Sub-paths of other paths will be removed. + mask.FromJsonString('foo.b1,bar.b1,foo.b2,bar') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('bar,foo.b1,foo.b2', out_mask.ToJsonString()) + + # Test more deeply nested cases. + mask.FromJsonString( + 'foo.bar.baz1,foo.bar.baz2.quz,foo.bar.baz2') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('foo.bar.baz1,foo.bar.baz2', + out_mask.ToJsonString()) + mask.FromJsonString( + 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('foo.bar.baz1,foo.bar.baz2', + out_mask.ToJsonString()) + mask.FromJsonString( + 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo.bar') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('foo.bar', out_mask.ToJsonString()) + mask.FromJsonString( + 'foo.bar.baz1,foo.bar.baz2,foo.bar.baz2.quz,foo') + out_mask.CanonicalFormFromMask(mask) + self.assertEqual('foo', out_mask.ToJsonString()) + + def testUnion(self): + mask1 = field_mask_pb2.FieldMask() + mask2 = field_mask_pb2.FieldMask() + out_mask = field_mask_pb2.FieldMask() + mask1.FromJsonString('foo,baz') + mask2.FromJsonString('bar,quz') + out_mask.Union(mask1, mask2) + self.assertEqual('bar,baz,foo,quz', out_mask.ToJsonString()) + # Overlap with duplicated paths. + mask1.FromJsonString('foo,baz.bb') + mask2.FromJsonString('baz.bb,quz') + out_mask.Union(mask1, mask2) + self.assertEqual('baz.bb,foo,quz', out_mask.ToJsonString()) + # Overlap with paths covering some other paths. + mask1.FromJsonString('foo.bar.baz,quz') + mask2.FromJsonString('foo.bar,bar') + out_mask.Union(mask1, mask2) + self.assertEqual('bar,foo.bar,quz', out_mask.ToJsonString()) + + def testIntersect(self): + mask1 = field_mask_pb2.FieldMask() + mask2 = field_mask_pb2.FieldMask() + out_mask = field_mask_pb2.FieldMask() + # Test cases without overlapping. + mask1.FromJsonString('foo,baz') + mask2.FromJsonString('bar,quz') + out_mask.Intersect(mask1, mask2) + self.assertEqual('', out_mask.ToJsonString()) + # Overlap with duplicated paths. + mask1.FromJsonString('foo,baz.bb') + mask2.FromJsonString('baz.bb,quz') + out_mask.Intersect(mask1, mask2) + self.assertEqual('baz.bb', out_mask.ToJsonString()) + # Overlap with paths covering some other paths. + mask1.FromJsonString('foo.bar.baz,quz') + mask2.FromJsonString('foo.bar,bar') + out_mask.Intersect(mask1, mask2) + self.assertEqual('foo.bar.baz', out_mask.ToJsonString()) + mask1.FromJsonString('foo.bar,bar') + mask2.FromJsonString('foo.bar.baz,quz') + out_mask.Intersect(mask1, mask2) + self.assertEqual('foo.bar.baz', out_mask.ToJsonString()) + + def testMergeMessage(self): + # Test merge one field. + src = unittest_pb2.TestAllTypes() + test_util.SetAllFields(src) + for field in src.DESCRIPTOR.fields: + if field.containing_oneof: + continue + field_name = field.name + dst = unittest_pb2.TestAllTypes() + # Only set one path to mask. + mask = field_mask_pb2.FieldMask() + mask.paths.append(field_name) + mask.MergeMessage(src, dst) + # The expected result message. + msg = unittest_pb2.TestAllTypes() + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + repeated_src = getattr(src, field_name) + repeated_msg = getattr(msg, field_name) + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + for item in repeated_src: + repeated_msg.add().CopyFrom(item) + else: + repeated_msg.extend(repeated_src) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + getattr(msg, field_name).CopyFrom(getattr(src, field_name)) + else: + setattr(msg, field_name, getattr(src, field_name)) + # Only field specified in mask is merged. + self.assertEqual(msg, dst) + + # Test merge nested fields. + nested_src = unittest_pb2.NestedTestAllTypes() + nested_dst = unittest_pb2.NestedTestAllTypes() + nested_src.child.payload.optional_int32 = 1234 + nested_src.child.child.payload.optional_int32 = 5678 + mask = field_mask_pb2.FieldMask() + mask.FromJsonString('child.payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(0, nested_dst.child.child.payload.optional_int32) + + mask.FromJsonString('child.child.payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(5678, nested_dst.child.child.payload.optional_int32) + + nested_dst.Clear() + mask.FromJsonString('child.child.payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(0, nested_dst.child.payload.optional_int32) + self.assertEqual(5678, nested_dst.child.child.payload.optional_int32) + + nested_dst.Clear() + mask.FromJsonString('child') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(5678, nested_dst.child.child.payload.optional_int32) + + # Test MergeOptions. + nested_dst.Clear() + nested_dst.child.payload.optional_int64 = 4321 + # Message fields will be merged by default. + mask.FromJsonString('child.payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(4321, nested_dst.child.payload.optional_int64) + # Change the behavior to replace message fields. + mask.FromJsonString('child.payload') + mask.MergeMessage(nested_src, nested_dst, True, False) + self.assertEqual(1234, nested_dst.child.payload.optional_int32) + self.assertEqual(0, nested_dst.child.payload.optional_int64) + + # By default, fields missing in source are not cleared in destination. + nested_dst.payload.optional_int32 = 1234 + self.assertTrue(nested_dst.HasField('payload')) + mask.FromJsonString('payload') + mask.MergeMessage(nested_src, nested_dst) + self.assertTrue(nested_dst.HasField('payload')) + # But they are cleared when replacing message fields. + nested_dst.Clear() + nested_dst.payload.optional_int32 = 1234 + mask.FromJsonString('payload') + mask.MergeMessage(nested_src, nested_dst, True, False) + self.assertFalse(nested_dst.HasField('payload')) + + nested_src.payload.repeated_int32.append(1234) + nested_dst.payload.repeated_int32.append(5678) + # Repeated fields will be appended by default. + mask.FromJsonString('payload.repeated_int32') + mask.MergeMessage(nested_src, nested_dst) + self.assertEqual(2, len(nested_dst.payload.repeated_int32)) + self.assertEqual(5678, nested_dst.payload.repeated_int32[0]) + self.assertEqual(1234, nested_dst.payload.repeated_int32[1]) + # Change the behavior to replace repeated fields. + mask.FromJsonString('payload.repeated_int32') + mask.MergeMessage(nested_src, nested_dst, False, True) + self.assertEqual(1, len(nested_dst.payload.repeated_int32)) + self.assertEqual(1234, nested_dst.payload.repeated_int32[0]) + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py index d95557d7..cb76e116 100644 --- a/python/google/protobuf/json_format.py +++ b/python/google/protobuf/json_format.py @@ -28,22 +28,29 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -"""Contains routines for printing protocol messages in JSON format.""" +"""Contains routines for printing protocol messages in JSON format. + +Simple usage example: + + # Create a proto object and serialize it to a json format string. + message = my_proto_pb2.MyMessage(foo='bar') + json_string = json_format.MessageToJson(message) + + # Parse a json format string to proto object. + message = json_format.Parse(json_string, my_proto_pb2.MyMessage()) +""" __author__ = 'jieluo@google.com (Jie Luo)' import base64 -from datetime import datetime import json import math -import re +from six import text_type import sys from google.protobuf import descriptor _TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S' -_NUMBER = re.compile(u'[0-9+-][0-9e.+-]*') -_INTEGER = re.compile(u'[0-9+-]') _INT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT32, descriptor.FieldDescriptor.CPPTYPE_UINT32, descriptor.FieldDescriptor.CPPTYPE_INT64, @@ -52,17 +59,20 @@ _INT64_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT64, descriptor.FieldDescriptor.CPPTYPE_UINT64]) _FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) -if str is bytes: - _UNICODETYPE = unicode -else: - _UNICODETYPE = str +_INFINITY = 'Infinity' +_NEG_INFINITY = '-Infinity' +_NAN = 'NaN' + +class Error(Exception): + """Top-level module error for json_format.""" -class SerializeToJsonError(Exception): + +class SerializeToJsonError(Error): """Thrown if serialization to JSON fails.""" -class ParseError(Exception): +class ParseError(Error): """Thrown in case of parsing error.""" @@ -86,12 +96,8 @@ def MessageToJson(message, including_default_value_fields=False): def _MessageToJsonObject(message, including_default_value_fields): """Converts message to an object according to Proto3 JSON Specification.""" message_descriptor = message.DESCRIPTOR - if _IsTimestampMessage(message_descriptor): - return _TimestampMessageToJsonObject(message) - if _IsDurationMessage(message_descriptor): - return _DurationMessageToJsonObject(message) - if _IsFieldMaskMessage(message_descriptor): - return _FieldMaskMessageToJsonObject(message) + if hasattr(message, 'ToJsonString'): + return message.ToJsonString() if _IsWrapperMessage(message_descriptor): return _WrapperMessageToJsonObject(message) return _RegularMessageToJsonObject(message, including_default_value_fields) @@ -107,12 +113,14 @@ def _RegularMessageToJsonObject(message, including_default_value_fields): """Converts normal message according to Proto3 JSON Specification.""" js = {} fields = message.ListFields() + include_default = including_default_value_fields try: for field, value in fields: name = field.camelcase_name if _IsMapEntry(field): # Convert a map field. + v_field = field.message_type.fields_by_name['value'] js_map = {} for key in value: if isinstance(key, bool): @@ -122,20 +130,15 @@ def _RegularMessageToJsonObject(message, including_default_value_fields): recorded_key = 'false' else: recorded_key = key - js_map[recorded_key] = _ConvertFieldToJsonObject( - field.message_type.fields_by_name['value'], - value[key], including_default_value_fields) + js_map[recorded_key] = _FieldToJsonObject( + v_field, value[key], including_default_value_fields) js[name] = js_map elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: # Convert a repeated field. - repeated = [] - for element in value: - repeated.append(_ConvertFieldToJsonObject( - field, element, including_default_value_fields)) - js[name] = repeated + js[name] = [_FieldToJsonObject(field, k, include_default) + for k in value] else: - js[name] = _ConvertFieldToJsonObject( - field, value, including_default_value_fields) + js[name] = _FieldToJsonObject(field, value, include_default) # Serialize default value if including_default_value_fields is True. if including_default_value_fields: @@ -155,16 +158,16 @@ def _RegularMessageToJsonObject(message, including_default_value_fields): elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: js[name] = [] else: - js[name] = _ConvertFieldToJsonObject(field, field.default_value) + js[name] = _FieldToJsonObject(field, field.default_value) except ValueError as e: raise SerializeToJsonError( - 'Failed to serialize {0} field: {1}'.format(field.name, e)) + 'Failed to serialize {0} field: {1}.'.format(field.name, e)) return js -def _ConvertFieldToJsonObject( +def _FieldToJsonObject( field, value, including_default_value_fields=False): """Converts field value according to Proto3 JSON Specification.""" if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: @@ -183,101 +186,26 @@ def _ConvertFieldToJsonObject( else: return value elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: - if value: - return True - else: - return False + return bool(value) elif field.cpp_type in _INT64_TYPES: return str(value) elif field.cpp_type in _FLOAT_TYPES: if math.isinf(value): if value < 0.0: - return '-Infinity' + return _NEG_INFINITY else: - return 'Infinity' + return _INFINITY if math.isnan(value): - return 'NaN' + return _NAN return value -def _IsTimestampMessage(message_descriptor): - return (message_descriptor.name == 'Timestamp' and - message_descriptor.file.name == 'google/protobuf/timestamp.proto') - - -def _TimestampMessageToJsonObject(message): - """Converts Timestamp message according to Proto3 JSON Specification.""" - nanos = message.nanos % 1e9 - dt = datetime.utcfromtimestamp( - message.seconds + (message.nanos - nanos) / 1e9) - result = dt.isoformat() - if (nanos % 1e9) == 0: - # If there are 0 fractional digits, the fractional - # point '.' should be omitted when serializing. - return result + 'Z' - if (nanos % 1e6) == 0: - # Serialize 3 fractional digits. - return result + '.%03dZ' % (nanos / 1e6) - if (nanos % 1e3) == 0: - # Serialize 6 fractional digits. - return result + '.%06dZ' % (nanos / 1e3) - # Serialize 9 fractional digits. - return result + '.%09dZ' % nanos - - -def _IsDurationMessage(message_descriptor): - return (message_descriptor.name == 'Duration' and - message_descriptor.file.name == 'google/protobuf/duration.proto') - - -def _DurationMessageToJsonObject(message): - """Converts Duration message according to Proto3 JSON Specification.""" - if message.seconds < 0 or message.nanos < 0: - result = '-' - seconds = - message.seconds + int((0 - message.nanos) / 1e9) - nanos = (0 - message.nanos) % 1e9 - else: - result = '' - seconds = message.seconds + int(message.nanos / 1e9) - nanos = message.nanos % 1e9 - result += '%d' % seconds - if (nanos % 1e9) == 0: - # If there are 0 fractional digits, the fractional - # point '.' should be omitted when serializing. - return result + 's' - if (nanos % 1e6) == 0: - # Serialize 3 fractional digits. - return result + '.%03ds' % (nanos / 1e6) - if (nanos % 1e3) == 0: - # Serialize 6 fractional digits. - return result + '.%06ds' % (nanos / 1e3) - # Serialize 9 fractional digits. - return result + '.%09ds' % nanos - - -def _IsFieldMaskMessage(message_descriptor): - return (message_descriptor.name == 'FieldMask' and - message_descriptor.file.name == 'google/protobuf/field_mask.proto') - - -def _FieldMaskMessageToJsonObject(message): - """Converts FieldMask message according to Proto3 JSON Specification.""" - result = '' - first = True - for path in message.paths: - if not first: - result += ',' - result += path - first = False - return result - - def _IsWrapperMessage(message_descriptor): return message_descriptor.file.name == 'google/protobuf/wrappers.proto' def _WrapperMessageToJsonObject(message): - return _ConvertFieldToJsonObject( + return _FieldToJsonObject( message.DESCRIPTOR.fields_by_name['value'], message.value) @@ -285,7 +213,7 @@ def _DuplicateChecker(js): result = {} for name, value in js: if name in result: - raise ParseError('Failed to load JSON: duplicate key ' + name) + raise ParseError('Failed to load JSON: duplicate key {0}.'.format(name)) result[name] = value return result @@ -303,7 +231,7 @@ def Parse(text, message): Raises:: ParseError: On JSON parsing problems. """ - if not isinstance(text, _UNICODETYPE): text = text.decode('utf-8') + if not isinstance(text, text_type): text = text.decode('utf-8') try: if sys.version_info < (2, 7): # object_pair_hook is not supported before python2.7 @@ -311,7 +239,7 @@ def Parse(text, message): else: js = json.loads(text, object_pairs_hook=_DuplicateChecker) except ValueError as e: - raise ParseError('Failed to load JSON: ' + str(e)) + raise ParseError('Failed to load JSON: {0}.'.format(str(e))) _ConvertFieldValuePair(js, message) return message @@ -362,7 +290,7 @@ def _ConvertFieldValuePair(js, message): message.ClearField(field.name) if not isinstance(value, list): raise ParseError('repeated field {0} must be in [] which is ' - '{1}'.format(name, value)) + '{1}.'.format(name, value)) for item in value: if item is None: continue @@ -383,9 +311,9 @@ def _ConvertFieldValuePair(js, message): else: raise ParseError(str(e)) except ValueError as e: - raise ParseError('Failed to parse {0} field: {1}'.format(name, e)) + raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) except TypeError as e: - raise ParseError('Failed to parse {0} field: {1}'.format(name, e)) + raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) def _ConvertMessage(value, message): @@ -399,88 +327,13 @@ def _ConvertMessage(value, message): ParseError: In case of convert problems. """ message_descriptor = message.DESCRIPTOR - if _IsTimestampMessage(message_descriptor): - _ConvertTimestampMessage(value, message) - elif _IsDurationMessage(message_descriptor): - _ConvertDurationMessage(value, message) - elif _IsFieldMaskMessage(message_descriptor): - _ConvertFieldMaskMessage(value, message) + if hasattr(message, 'FromJsonString'): + message.FromJsonString(value) elif _IsWrapperMessage(message_descriptor): _ConvertWrapperMessage(value, message) else: _ConvertFieldValuePair(value, message) - -def _ConvertTimestampMessage(value, message): - """Convert a JSON representation into Timestamp message.""" - timezone_offset = value.find('Z') - if timezone_offset == -1: - timezone_offset = value.find('+') - if timezone_offset == -1: - timezone_offset = value.rfind('-') - if timezone_offset == -1: - raise ParseError( - 'Failed to parse timestamp: missing valid timezone offset.') - time_value = value[0:timezone_offset] - # Parse datetime and nanos - point_position = time_value.find('.') - if point_position == -1: - second_value = time_value - nano_value = '' - else: - second_value = time_value[:point_position] - nano_value = time_value[point_position + 1:] - date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT) - td = date_object - datetime(1970, 1, 1) - seconds = td.seconds + td.days * 24 * 3600 - if len(nano_value) > 9: - raise ParseError( - 'Failed to parse Timestamp: nanos {0} more than ' - '9 fractional digits.'.format(nano_value)) - if nano_value: - nanos = round(float('0.' + nano_value) * 1e9) - else: - nanos = 0 - # Parse timezone offsets - if value[timezone_offset] == 'Z': - if len(value) != timezone_offset + 1: - raise ParseError( - 'Failed to parse timestamp: invalid trailing data {0}.'.format(value)) - else: - timezone = value[timezone_offset:] - pos = timezone.find(':') - if pos == -1: - raise ParseError( - 'Invalid timezone offset value: ' + timezone) - if timezone[0] == '+': - seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 - else: - seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 - # Set seconds and nanos - message.seconds = int(seconds) - message.nanos = int(nanos) - - -def _ConvertDurationMessage(value, message): - """Convert a JSON representation into Duration message.""" - if value[-1] != 's': - raise ParseError( - 'Duration must end with letter "s": ' + value) - try: - duration = float(value[:-1]) - except ValueError: - raise ParseError( - 'Couldn\'t parse duration: ' + value) - message.seconds = int(duration) - message.nanos = int(round((duration - message.seconds) * 1e9)) - - -def _ConvertFieldMaskMessage(value, message): - """Convert a JSON representation into FieldMask message.""" - for path in value.split(','): - message.paths.append(path) - - def _ConvertWrapperMessage(value, message): """Convert a JSON representation into Wrapper message.""" field = message.DESCRIPTOR.fields_by_name['value'] @@ -512,13 +365,13 @@ def _ConvertMapFieldValue(value, message, field): value[key], value_field) -def _ConvertScalarFieldValue(value, field, require_quote=False): +def _ConvertScalarFieldValue(value, field, require_str=False): """Convert a single scalar field value. Args: value: A scalar value to convert the scalar field value. field: The descriptor of the field to convert. - require_quote: If True, '"' is required for the field value. + require_str: If True, the field value must be a str. Returns: The converted scalar field value @@ -531,7 +384,7 @@ def _ConvertScalarFieldValue(value, field, require_quote=False): elif field.cpp_type in _FLOAT_TYPES: return _ConvertFloat(value) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: - return _ConvertBool(value, require_quote) + return _ConvertBool(value, require_str) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: if field.type == descriptor.FieldDescriptor.TYPE_BYTES: return base64.b64decode(value) @@ -561,10 +414,10 @@ def _ConvertInteger(value): ParseError: If an integer couldn't be consumed. """ if isinstance(value, float): - raise ParseError('Couldn\'t parse integer: {0}'.format(value)) + raise ParseError('Couldn\'t parse integer: {0}.'.format(value)) - if isinstance(value, _UNICODETYPE) and not _INTEGER.match(value): - raise ParseError('Couldn\'t parse integer: "{0}"'.format(value)) + if isinstance(value, text_type) and value.find(' ') != -1: + raise ParseError('Couldn\'t parse integer: "{0}".'.format(value)) return int(value) @@ -572,28 +425,28 @@ def _ConvertInteger(value): def _ConvertFloat(value): """Convert an floating point number.""" if value == 'nan': - raise ParseError('Couldn\'t parse float "nan", use "NaN" instead') + raise ParseError('Couldn\'t parse float "nan", use "NaN" instead.') try: # Assume Python compatible syntax. return float(value) except ValueError: # Check alternative spellings. - if value == '-Infinity': + if value == _NEG_INFINITY: return float('-inf') - elif value == 'Infinity': + elif value == _INFINITY: return float('inf') - elif value == 'NaN': + elif value == _NAN: return float('nan') else: - raise ParseError('Couldn\'t parse float: {0}'.format(value)) + raise ParseError('Couldn\'t parse float: {0}.'.format(value)) -def _ConvertBool(value, require_quote): +def _ConvertBool(value, require_str): """Convert a boolean value. Args: value: A scalar value to convert. - require_quote: If True, '"' is required for the boolean value. + require_str: If True, value must be a str. Returns: The bool parsed. @@ -601,13 +454,13 @@ def _ConvertBool(value, require_quote): Raises: ParseError: If a boolean value couldn't be consumed. """ - if require_quote: + if require_str: if value == 'true': return True elif value == 'false': return False else: - raise ParseError('Expect "true" or "false", not {0}.'.format(value)) + raise ParseError('Expected "true" or "false", not {0}.'.format(value)) if not isinstance(value, bool): raise ParseError('Expected true or false without quotes.') diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index 9cd9c2a8..1b059d13 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -39,7 +39,6 @@ my_proto_instance = message_classes['some.proto.package.MessageName']() __author__ = 'matthewtoia@google.com (Matt Toia)' -from google.protobuf import descriptor_database from google.protobuf import descriptor_pool from google.protobuf import message from google.protobuf import reflection @@ -50,8 +49,7 @@ class MessageFactory(object): def __init__(self, pool=None): """Initializes a new factory.""" - self.pool = (pool or descriptor_pool.DescriptorPool( - descriptor_database.DescriptorDatabase())) + self.pool = pool or descriptor_pool.DescriptorPool() # local cache of all classes built from protobuf descriptors self._classes = {} diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 61a3d237..a875a7be 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -203,6 +203,14 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { PyObject* message_class(cdescriptor_pool::GetMessageClass( pool, message_type)); if (message_class == NULL) { + // The Options message was not found in the current DescriptorPool. + // In this case, there cannot be extensions to these options, and we can + // try to use the basic pool instead. + PyErr_Clear(); + message_class = cdescriptor_pool::GetMessageClass( + GetDefaultDescriptorPool(), message_type); + } + if (message_class == NULL) { PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s", message_type->full_name().c_str()); return NULL; @@ -211,6 +219,12 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { if (value == NULL) { return NULL; } + if (!PyObject_TypeCheck(value.get(), &CMessage_Type)) { + PyErr_Format(PyExc_TypeError, "Invalid class for %s: %s", + message_type->full_name().c_str(), + Py_TYPE(value.get())->tp_name); + return NULL; + } CMessage* cmsg = reinterpret_cast<CMessage*>(value.get()); const Reflection* reflection = options.GetReflection(); @@ -327,7 +341,8 @@ PyObject* NewInternedDescriptor(PyTypeObject* type, PyDescriptorPool* pool = GetDescriptorPool_FromPool( GetFileDescriptor(descriptor)->pool()); if (pool == NULL) { - Py_DECREF(py_descriptor); + // Don't DECREF, the object is not fully initialized. + PyObject_Del(py_descriptor); return NULL; } Py_INCREF(pool); @@ -1213,6 +1228,13 @@ static void Dealloc(PyFileDescriptor* self) { descriptor::Dealloc(&self->base); } +static PyObject* GetPool(PyFileDescriptor *self, void *closure) { + PyObject* pool = reinterpret_cast<PyObject*>( + GetDescriptorPool_FromPool(_GetDescriptor(self)->pool())); + Py_XINCREF(pool); + return pool; +} + static PyObject* GetName(PyFileDescriptor *self, void *closure) { return PyString_FromCppString(_GetDescriptor(self)->name()); } @@ -1292,6 +1314,7 @@ static PyObject* CopyToProto(PyFileDescriptor *self, PyObject *target) { } static PyGetSetDef Getters[] = { + { "pool", (getter)GetPool, NULL, "pool"}, { "name", (getter)GetName, NULL, "name"}, { "package", (getter)GetPackage, NULL, "package"}, { "serialized_pb", (getter)GetSerializedPb}, @@ -1354,8 +1377,8 @@ PyTypeObject PyFileDescriptor_Type = { 0, // tp_descr_set 0, // tp_dictoffset 0, // tp_init - PyType_GenericAlloc, // tp_alloc - PyType_GenericNew, // tp_new + 0, // tp_alloc + 0, // tp_new PyObject_Del, // tp_free }; diff --git a/python/google/protobuf/pyext/descriptor_database.cc b/python/google/protobuf/pyext/descriptor_database.cc new file mode 100644 index 00000000..514722b4 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_database.cc @@ -0,0 +1,145 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file defines a C++ DescriptorDatabase, which wraps a Python Database +// and delegate all its operations to Python methods. + +#include <google/protobuf/pyext/descriptor_database.h> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/descriptor.pb.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +namespace google { +namespace protobuf { +namespace python { + +PyDescriptorDatabase::PyDescriptorDatabase(PyObject* py_database) + : py_database_(py_database) { + Py_INCREF(py_database_); +} + +PyDescriptorDatabase::~PyDescriptorDatabase() { Py_DECREF(py_database_); } + +// Convert a Python object to a FileDescriptorProto pointer. +// Handles all kinds of Python errors, which are simply logged. +static bool GetFileDescriptorProto(PyObject* py_descriptor, + FileDescriptorProto* output) { + if (py_descriptor == NULL) { + if (PyErr_ExceptionMatches(PyExc_KeyError)) { + // Expected error: item was simply not found. + PyErr_Clear(); + } else { + GOOGLE_LOG(ERROR) << "DescriptorDatabase method raised an error"; + PyErr_Print(); + } + return false; + } + const Descriptor* filedescriptor_descriptor = + FileDescriptorProto::default_instance().GetDescriptor(); + CMessage* message = reinterpret_cast<CMessage*>(py_descriptor); + if (PyObject_TypeCheck(py_descriptor, &CMessage_Type) && + message->message->GetDescriptor() == filedescriptor_descriptor) { + // Fast path: Just use the pointer. + FileDescriptorProto* file_proto = + static_cast<FileDescriptorProto*>(message->message); + *output = *file_proto; + return true; + } else { + // Slow path: serialize the message. This allows to use databases which + // use a different implementation of FileDescriptorProto. + ScopedPyObjectPtr serialized_pb( + PyObject_CallMethod(py_descriptor, "SerializeToString", NULL)); + if (serialized_pb == NULL) { + GOOGLE_LOG(ERROR) + << "DescriptorDatabase method did not return a FileDescriptorProto"; + PyErr_Print(); + return false; + } + char* str; + Py_ssize_t len; + if (PyBytes_AsStringAndSize(serialized_pb.get(), &str, &len) < 0) { + GOOGLE_LOG(ERROR) + << "DescriptorDatabase method did not return a FileDescriptorProto"; + PyErr_Print(); + return false; + } + FileDescriptorProto file_proto; + if (!file_proto.ParseFromArray(str, len)) { + GOOGLE_LOG(ERROR) + << "DescriptorDatabase method did not return a FileDescriptorProto"; + return false; + } + *output = file_proto; + return true; + } +} + +// Find a file by file name. +bool PyDescriptorDatabase::FindFileByName(const string& filename, + FileDescriptorProto* output) { + ScopedPyObjectPtr py_descriptor(PyObject_CallMethod( + py_database_, "FindFileByName", "s#", filename.c_str(), filename.size())); + return GetFileDescriptorProto(py_descriptor.get(), output); +} + +// Find the file that declares the given fully-qualified symbol name. +bool PyDescriptorDatabase::FindFileContainingSymbol( + const string& symbol_name, FileDescriptorProto* output) { + ScopedPyObjectPtr py_descriptor( + PyObject_CallMethod(py_database_, "FindFileContainingSymbol", "s#", + symbol_name.c_str(), symbol_name.size())); + return GetFileDescriptorProto(py_descriptor.get(), output); +} + +// Find the file which defines an extension extending the given message type +// with the given field number. +// Python DescriptorDatabases are not required to implement this method. +bool PyDescriptorDatabase::FindFileContainingExtension( + const string& containing_type, int field_number, + FileDescriptorProto* output) { + ScopedPyObjectPtr py_method( + PyObject_GetAttrString(py_database_, "FindFileContainingExtension")); + if (py_method == NULL) { + // This method is not implemented, returns without error. + PyErr_Clear(); + return false; + } + ScopedPyObjectPtr py_descriptor( + PyObject_CallFunction(py_method.get(), "s#i", containing_type.c_str(), + containing_type.size(), field_number)); + return GetFileDescriptorProto(py_descriptor.get(), output); +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/descriptor_database.h b/python/google/protobuf/pyext/descriptor_database.h new file mode 100644 index 00000000..fc71c4bc --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_database.h @@ -0,0 +1,75 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ + +#include <Python.h> + +#include <google/protobuf/descriptor_database.h> + +namespace google { +namespace protobuf { +namespace python { + +class PyDescriptorDatabase : public DescriptorDatabase { + public: + explicit PyDescriptorDatabase(PyObject* py_database); + ~PyDescriptorDatabase(); + + // Implement the abstract interface. All these functions fill the output + // with a copy of FileDescriptorProto. + + // Find a file by file name. + bool FindFileByName(const string& filename, + FileDescriptorProto* output); + + // Find the file that declares the given fully-qualified symbol name. + bool FindFileContainingSymbol(const string& symbol_name, + FileDescriptorProto* output); + + // Find the file which defines an extension extending the given message type + // with the given field number. + // Containing_type must be a fully-qualified type name. + // Python objects are not required to implement this method. + bool FindFileContainingExtension(const string& containing_type, + int field_number, + FileDescriptorProto* output); + + private: + // The python object that implements the database. The reference is owned. + PyObject* py_database_; +}; + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index 0f7487fa..0bc76bc9 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -34,8 +34,9 @@ #include <google/protobuf/descriptor.pb.h> #include <google/protobuf/dynamic_message.h> -#include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/descriptor_database.h> +#include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> @@ -60,38 +61,93 @@ static hash_map<const DescriptorPool*, PyDescriptorPool*> descriptor_pool_map; namespace cdescriptor_pool { -static PyDescriptorPool* NewDescriptorPool() { - PyDescriptorPool* cdescriptor_pool = PyObject_New( +// Create a Python DescriptorPool object, but does not fill the "pool" +// attribute. +static PyDescriptorPool* _CreateDescriptorPool() { + PyDescriptorPool* cpool = PyObject_New( PyDescriptorPool, &PyDescriptorPool_Type); - if (cdescriptor_pool == NULL) { + if (cpool == NULL) { return NULL; } - // Build a DescriptorPool for messages only declared in Python libraries. - // generated_pool() contains all messages linked in C++ libraries, and is used - // as underlay. - cdescriptor_pool->pool = new DescriptorPool(DescriptorPool::generated_pool()); + cpool->underlay = NULL; + cpool->database = NULL; DynamicMessageFactory* message_factory = new DynamicMessageFactory(); // This option might be the default some day. message_factory->SetDelegateToGeneratedFactory(true); - cdescriptor_pool->message_factory = message_factory; + cpool->message_factory = message_factory; // TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same // storage. - cdescriptor_pool->classes_by_descriptor = + cpool->classes_by_descriptor = new PyDescriptorPool::ClassesByMessageMap(); - cdescriptor_pool->descriptor_options = + cpool->descriptor_options = new hash_map<const void*, PyObject *>(); + return cpool; +} + +// Create a Python DescriptorPool, using the given pool as an underlay: +// new messages will be added to a custom pool, not to the underlay. +// +// Ownership of the underlay is not transferred, its pointer should +// stay alive. +static PyDescriptorPool* PyDescriptorPool_NewWithUnderlay( + const DescriptorPool* underlay) { + PyDescriptorPool* cpool = _CreateDescriptorPool(); + if (cpool == NULL) { + return NULL; + } + cpool->pool = new DescriptorPool(underlay); + cpool->underlay = underlay; + if (!descriptor_pool_map.insert( - std::make_pair(cdescriptor_pool->pool, cdescriptor_pool)).second) { + std::make_pair(cpool->pool, cpool)).second) { // Should never happen -- would indicate an internal error / bug. PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered"); return NULL; } - return cdescriptor_pool; + return cpool; +} + +static PyDescriptorPool* PyDescriptorPool_NewWithDatabase( + DescriptorDatabase* database) { + PyDescriptorPool* cpool = _CreateDescriptorPool(); + if (cpool == NULL) { + return NULL; + } + if (database != NULL) { + cpool->pool = new DescriptorPool(database); + cpool->database = database; + } else { + cpool->pool = new DescriptorPool(); + } + + if (!descriptor_pool_map.insert(std::make_pair(cpool->pool, cpool)).second) { + // Should never happen -- would indicate an internal error / bug. + PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered"); + return NULL; + } + + return cpool; +} + +// The public DescriptorPool constructor. +static PyObject* New(PyTypeObject* type, + PyObject* args, PyObject* kwargs) { + static char* kwlist[] = {"descriptor_db", 0}; + PyObject* py_database = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|O", kwlist, &py_database)) { + return NULL; + } + DescriptorDatabase* database = NULL; + if (py_database && py_database != Py_None) { + database = new PyDescriptorDatabase(py_database); + } + return reinterpret_cast<PyObject*>( + PyDescriptorPool_NewWithDatabase(database)); } static void Dealloc(PyDescriptorPool* self) { @@ -108,8 +164,9 @@ static void Dealloc(PyDescriptorPool* self) { Py_DECREF(it->second); } delete self->descriptor_options; - delete self->pool; delete self->message_factory; + delete self->database; + delete self->pool; Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self)); } @@ -354,6 +411,14 @@ PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) { char* message_type; Py_ssize_t message_len; + if (self->database != NULL) { + PyErr_SetString( + PyExc_ValueError, + "Cannot call Add on a DescriptorPool that uses a DescriptorDatabase. " + "Add your file to the underlying database."); + return NULL; + } + if (PyBytes_AsStringAndSize(serialized_pb, &message_type, &message_len) < 0) { return NULL; } @@ -366,8 +431,10 @@ PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) { // If the file was already part of a C++ library, all its descriptors are in // the underlying pool. No need to do anything else. - const FileDescriptor* generated_file = - DescriptorPool::generated_pool()->FindFileByName(file_proto.name()); + const FileDescriptor* generated_file = NULL; + if (self->underlay) { + generated_file = self->underlay->FindFileByName(file_proto.name()); + } if (generated_file != NULL) { return PyFileDescriptor_FromDescriptorWithSerializedPb( generated_file, serialized_pb); @@ -470,7 +537,7 @@ PyTypeObject PyDescriptorPool_Type = { 0, // tp_dictoffset 0, // tp_init 0, // tp_alloc - 0, // tp_new + cdescriptor_pool::New, // tp_new PyObject_Del, // tp_free }; @@ -482,7 +549,11 @@ bool InitDescriptorPool() { if (PyType_Ready(&PyDescriptorPool_Type) < 0) return false; - python_generated_pool = cdescriptor_pool::NewDescriptorPool(); + // The Pool of messages declared in Python libraries. + // generated_pool() contains all messages already linked in C++ libraries, and + // is used as underlay. + python_generated_pool = cdescriptor_pool::PyDescriptorPool_NewWithUnderlay( + DescriptorPool::generated_pool()); if (python_generated_pool == NULL) { return false; } @@ -494,6 +565,10 @@ bool InitDescriptorPool() { return true; } +// The default DescriptorPool used everywhere in this module. +// Today it's the python_generated_pool. +// TODO(amauryfa): Remove all usages of this function: the pool should be +// derived from the context. PyDescriptorPool* GetDefaultDescriptorPool() { return python_generated_pool; } diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h index eda73d38..16bc910c 100644 --- a/python/google/protobuf/pyext/descriptor_pool.h +++ b/python/google/protobuf/pyext/descriptor_pool.h @@ -55,8 +55,17 @@ namespace python { typedef struct PyDescriptorPool { PyObject_HEAD + // The C++ pool containing Descriptors. DescriptorPool* pool; + // The C++ pool acting as an underlay. Can be NULL. + // This pointer is not owned and must stay alive. + const DescriptorPool* underlay; + + // The C++ descriptor database used to fetch unknown protos. Can be NULL. + // This pointer is owned. + const DescriptorDatabase* database; + // DynamicMessageFactory used to create C++ instances of messages. // This object cache the descriptors that were used, so the DescriptorPool // needs to get rid of it before it can delete itself. @@ -138,6 +147,7 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg); // Retrieve the global descriptor pool owned by the _message module. // This is the one used by pb2.py generated modules. // Returns a *borrowed* reference. +// "Default" pool used to register messages from _pb2.py modules. PyDescriptorPool* GetDefaultDescriptorPool(); // Retrieve the python descriptor pool owning a C++ descriptor pool. diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index b361b342..555bd293 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -94,13 +94,13 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor == NULL) { return NULL; } - if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { + if (!CheckFieldBelongsToMessage(descriptor, self->message)) { return NULL; } if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { - return cmessage::InternalGetScalar(self->parent->message, descriptor); + return cmessage::InternalGetScalar(self->message, descriptor); } PyObject* value = PyDict_GetItem(self->values, key); @@ -109,6 +109,14 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { return value; } + if (self->parent == NULL) { + // We are in "detached" state. Don't allow further modifications. + // TODO(amauryfa): Support adding non-scalars to a detached extension dict. + // This probably requires to store the type of the main message. + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject* sub_message = cmessage::InternalGetSubMessage( @@ -154,7 +162,7 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { if (descriptor == NULL) { return -1; } - if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { + if (!CheckFieldBelongsToMessage(descriptor, self->message)) { return -1; } @@ -164,9 +172,11 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { "type"); return -1; } - cmessage::AssureWritable(self->parent); - if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { - return -1; + if (self->parent) { + cmessage::AssureWritable(self->parent); + if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { + return -1; + } } // TODO(tibell): We shouldn't write scalars to the cache. PyDict_SetItem(self->values, key, value); @@ -180,15 +190,17 @@ PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { return NULL; } PyObject* value = PyDict_GetItem(self->values, extension); - if (value != NULL) { - if (ReleaseExtension(self, value, descriptor) < 0) { + if (self->parent) { + if (value != NULL) { + if (ReleaseExtension(self, value, descriptor) < 0) { + return NULL; + } + } + if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( + self->parent, descriptor)) == NULL) { return NULL; } } - if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( - self->parent, descriptor)) == NULL) { - return NULL; - } if (PyDict_DelItem(self->values, extension) < 0) { PyErr_Clear(); } @@ -201,8 +213,15 @@ PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { if (descriptor == NULL) { return NULL; } - PyObject* result = cmessage::HasFieldByDescriptor(self->parent, descriptor); - return result; + if (self->parent) { + return cmessage::HasFieldByDescriptor(self->parent, descriptor); + } else { + int exists = PyDict_Contains(self->values, extension); + if (exists < 0) { + return NULL; + } + return PyBool_FromLong(exists); + } } PyObject* _FindExtensionByName(ExtensionDict* self, PyObject* name) { diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h index 7a66cb23..d92cf956 100644 --- a/python/google/protobuf/pyext/extension_dict.h +++ b/python/google/protobuf/pyext/extension_dict.h @@ -117,11 +117,6 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value); PyObject* ClearExtension(ExtensionDict* self, PyObject* extension); -// Checks if the dict has an extension. -// -// Returns a new python boolean reference. -PyObject* HasExtension(ExtensionDict* self, PyObject* extension); - // Gets an extension from the dict given the extension name as opposed to // descriptor. // diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc new file mode 100644 index 00000000..c39f7b83 --- /dev/null +++ b/python/google/protobuf/pyext/map_container.cc @@ -0,0 +1,912 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: haberman@google.com (Josh Haberman) + +#include <google/protobuf/pyext/map_container.h> + +#include <google/protobuf/stubs/logging.h> +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/stubs/scoped_ptr.h> +#include <google/protobuf/map_field.h> +#include <google/protobuf/map.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +#if PY_MAJOR_VERSION >= 3 + #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t +#endif + +namespace google { +namespace protobuf { +namespace python { + +// Functions that need access to map reflection functionality. +// They need to be contained in this class because it is friended. +class MapReflectionFriend { + public: + // Methods that are in common between the map types. + static PyObject* Contains(PyObject* _self, PyObject* key); + static Py_ssize_t Length(PyObject* _self); + static PyObject* GetIterator(PyObject *_self); + static PyObject* IterNext(PyObject* _self); + + // Methods that differ between the map types. + static PyObject* ScalarMapGetItem(PyObject* _self, PyObject* key); + static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key); + static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v); + static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v); +}; + +struct MapIterator { + PyObject_HEAD; + + scoped_ptr< ::google::protobuf::MapIterator> iter; + + // A pointer back to the container, so we can notice changes to the version. + // We own a ref on this. + MapContainer* container; + + // We need to keep a ref on the Message* too, because + // MapIterator::~MapIterator() accesses it. Normally this would be ok because + // the ref on container (above) would guarantee outlive semantics. However in + // the case of ClearField(), InitializeAndCopyToParentContainer() resets the + // message pointer (and the owner) to a different message, a copy of the + // original. But our iterator still points to the original, which could now + // get deleted before us. + // + // To prevent this, we ensure that the Message will always stay alive as long + // as this iterator does. This is solely for the benefit of the MapIterator + // destructor -- we should never actually access the iterator in this state + // except to delete it. + shared_ptr<Message> owner; + + // The version of the map when we took the iterator to it. + // + // We store this so that if the map is modified during iteration we can throw + // an error. + uint64 version; + + // True if the container is empty. We signal this separately to avoid calling + // any of the iteration methods, which are non-const. + bool empty; +}; + +Message* MapContainer::GetMutableMessage() { + cmessage::AssureWritable(parent); + return const_cast<Message*>(message); +} + +// Consumes a reference on the Python string object. +static bool PyStringToSTL(PyObject* py_string, string* stl_string) { + char *value; + Py_ssize_t value_len; + + if (!py_string) { + return false; + } + if (PyBytes_AsStringAndSize(py_string, &value, &value_len) < 0) { + Py_DECREF(py_string); + return false; + } else { + stl_string->assign(value, value_len); + Py_DECREF(py_string); + return true; + } +} + +static bool PythonToMapKey(PyObject* obj, + const FieldDescriptor* field_descriptor, + MapKey* key) { + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(obj, value, false); + key->SetInt32Value(value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(obj, value, false); + key->SetInt64Value(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(obj, value, false); + key->SetUInt32Value(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(obj, value, false); + key->SetUInt64Value(value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(obj, value, false); + key->SetBoolValue(value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + string str; + if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) { + return false; + } + key->SetStringValue(str); + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Type %d cannot be a map key", + field_descriptor->cpp_type()); + return false; + } + return true; +} + +static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor, + const MapKey& key) { + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return PyInt_FromLong(key.GetInt32Value()); + case FieldDescriptor::CPPTYPE_INT64: + return PyLong_FromLongLong(key.GetInt64Value()); + case FieldDescriptor::CPPTYPE_UINT32: + return PyInt_FromSize_t(key.GetUInt32Value()); + case FieldDescriptor::CPPTYPE_UINT64: + return PyLong_FromUnsignedLongLong(key.GetUInt64Value()); + case FieldDescriptor::CPPTYPE_BOOL: + return PyBool_FromLong(key.GetBoolValue()); + case FieldDescriptor::CPPTYPE_STRING: + return ToStringObject(field_descriptor, key.GetStringValue()); + default: + PyErr_Format( + PyExc_SystemError, "Couldn't convert type %d to value", + field_descriptor->cpp_type()); + return NULL; + } +} + +// This is only used for ScalarMap, so we don't need to handle the +// CPPTYPE_MESSAGE case. +PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor, + MapValueRef* value) { + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return PyInt_FromLong(value->GetInt32Value()); + case FieldDescriptor::CPPTYPE_INT64: + return PyLong_FromLongLong(value->GetInt64Value()); + case FieldDescriptor::CPPTYPE_UINT32: + return PyInt_FromSize_t(value->GetUInt32Value()); + case FieldDescriptor::CPPTYPE_UINT64: + return PyLong_FromUnsignedLongLong(value->GetUInt64Value()); + case FieldDescriptor::CPPTYPE_FLOAT: + return PyFloat_FromDouble(value->GetFloatValue()); + case FieldDescriptor::CPPTYPE_DOUBLE: + return PyFloat_FromDouble(value->GetDoubleValue()); + case FieldDescriptor::CPPTYPE_BOOL: + return PyBool_FromLong(value->GetBoolValue()); + case FieldDescriptor::CPPTYPE_STRING: + return ToStringObject(field_descriptor, value->GetStringValue()); + case FieldDescriptor::CPPTYPE_ENUM: + return PyInt_FromLong(value->GetEnumValue()); + default: + PyErr_Format( + PyExc_SystemError, "Couldn't convert type %d to value", + field_descriptor->cpp_type()); + return NULL; + } +} + +// This is only used for ScalarMap, so we don't need to handle the +// CPPTYPE_MESSAGE case. +static bool PythonToMapValueRef(PyObject* obj, + const FieldDescriptor* field_descriptor, + bool allow_unknown_enum_values, + MapValueRef* value_ref) { + switch (field_descriptor->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + GOOGLE_CHECK_GET_INT32(obj, value, false); + value_ref->SetInt32Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_INT64: { + GOOGLE_CHECK_GET_INT64(obj, value, false); + value_ref->SetInt64Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_UINT32: { + GOOGLE_CHECK_GET_UINT32(obj, value, false); + value_ref->SetUInt32Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_UINT64: { + GOOGLE_CHECK_GET_UINT64(obj, value, false); + value_ref->SetUInt64Value(value); + return true; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + GOOGLE_CHECK_GET_FLOAT(obj, value, false); + value_ref->SetFloatValue(value); + return true; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + GOOGLE_CHECK_GET_DOUBLE(obj, value, false); + value_ref->SetDoubleValue(value); + return true; + } + case FieldDescriptor::CPPTYPE_BOOL: { + GOOGLE_CHECK_GET_BOOL(obj, value, false); + value_ref->SetBoolValue(value); + return true;; + } + case FieldDescriptor::CPPTYPE_STRING: { + string str; + if (!PyStringToSTL(CheckString(obj, field_descriptor), &str)) { + return false; + } + value_ref->SetStringValue(str); + return true; + } + case FieldDescriptor::CPPTYPE_ENUM: { + GOOGLE_CHECK_GET_INT32(obj, value, false); + if (allow_unknown_enum_values) { + value_ref->SetEnumValue(value); + return true; + } else { + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + value_ref->SetEnumValue(value); + return true; + } else { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); + return false; + } + } + break; + } + default: + PyErr_Format( + PyExc_SystemError, "Setting value to a field of unknown type %d", + field_descriptor->cpp_type()); + return false; + } +} + +// Map methods common to ScalarMap and MessageMap ////////////////////////////// + +static MapContainer* GetMap(PyObject* obj) { + return reinterpret_cast<MapContainer*>(obj); +} + +Py_ssize_t MapReflectionFriend::Length(PyObject* _self) { + MapContainer* self = GetMap(_self); + const google::protobuf::Message* message = self->message; + return message->GetReflection()->MapSize(*message, + self->parent_field_descriptor); +} + +PyObject* Clear(PyObject* _self) { + MapContainer* self = GetMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + + reflection->ClearField(message, self->parent_field_descriptor); + + Py_RETURN_NONE; +} + +PyObject* MapReflectionFriend::Contains(PyObject* _self, PyObject* key) { + MapContainer* self = GetMap(_self); + + const Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return NULL; + } + + if (reflection->ContainsMapKey(*message, self->parent_field_descriptor, + map_key)) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +// Initializes the underlying Message object of "to" so it becomes a new parent +// repeated scalar, and copies all the values from "from" to it. A child scalar +// container can be released by passing it as both from and to (e.g. making it +// the recipient of the new parent message and copying the values from itself). +static int InitializeAndCopyToParentContainer(MapContainer* from, + MapContainer* to) { + // For now we require from == to, re-evaluate if we want to support deep copy + // as in repeated_scalar_container.cc. + GOOGLE_DCHECK(from == to); + Message* new_message = from->message->New(); + + if (MapReflectionFriend::Length(reinterpret_cast<PyObject*>(from)) > 0) { + // A somewhat roundabout way of copying just one field from old_message to + // new_message. This is the best we can do with what Reflection gives us. + Message* mutable_old = from->GetMutableMessage(); + vector<const FieldDescriptor*> fields; + fields.push_back(from->parent_field_descriptor); + + // Move the map field into the new message. + mutable_old->GetReflection()->SwapFields(mutable_old, new_message, fields); + + // If/when we support from != to, this will be required also to copy the + // map field back into the existing message: + // mutable_old->MergeFrom(*new_message); + } + + // If from == to this could delete old_message. + to->owner.reset(new_message); + + to->parent = NULL; + to->parent_field_descriptor = from->parent_field_descriptor; + to->message = new_message; + + // Invalidate iterators, since they point to the old copy of the field. + to->version++; + + return 0; +} + +int MapContainer::Release() { + return InitializeAndCopyToParentContainer(this, this); +} + + +// ScalarMap /////////////////////////////////////////////////////////////////// + +PyObject *NewScalarMapContainer( + CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } + + ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0)); + if (obj.get() == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate new container."); + } + + MapContainer* self = GetMap(obj.get()); + + self->message = parent->message; + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->owner = parent->owner; + self->version = 0; + + self->key_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("key"); + self->value_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("value"); + + if (self->key_field_descriptor == NULL || + self->value_field_descriptor == NULL) { + return PyErr_Format(PyExc_KeyError, + "Map entry descriptor did not have key/value fields"); + } + + return obj.release(); +} + +PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self, + PyObject* key) { + MapContainer* self = GetMap(_self); + + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return NULL; + } + + if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value)) { + self->version++; + } + + return MapValueRefToPython(self->value_field_descriptor, &value); +} + +int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key, + PyObject* v) { + MapContainer* self = GetMap(_self); + + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return -1; + } + + self->version++; + + if (v) { + // Set item to v. + reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value); + + return PythonToMapValueRef(v, self->value_field_descriptor, + reflection->SupportsUnknownEnumValues(), &value) + ? 0 + : -1; + } else { + // Delete key from map. + if (reflection->DeleteMapValue(message, self->parent_field_descriptor, + map_key)) { + return 0; + } else { + PyErr_Format(PyExc_KeyError, "Key not present in map"); + return -1; + } + } +} + +static PyObject* ScalarMapGet(PyObject* self, PyObject* args) { + PyObject* key; + PyObject* default_value = NULL; + if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) { + return NULL; + } + + ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key)); + if (is_present.get() == NULL) { + return NULL; + } + + if (PyObject_IsTrue(is_present.get())) { + return MapReflectionFriend::ScalarMapGetItem(self, key); + } else { + if (default_value != NULL) { + Py_INCREF(default_value); + return default_value; + } else { + Py_RETURN_NONE; + } + } +} + +static void ScalarMapDealloc(PyObject* _self) { + MapContainer* self = GetMap(_self); + self->owner.reset(); + Py_TYPE(_self)->tp_free(_self); +} + +static PyMappingMethods ScalarMapMappingMethods = { + MapReflectionFriend::Length, // mp_length + MapReflectionFriend::ScalarMapGetItem, // mp_subscript + MapReflectionFriend::ScalarMapSetItem, // mp_ass_subscript +}; + +static PyMethodDef ScalarMapMethods[] = { + { "__contains__", MapReflectionFriend::Contains, METH_O, + "Tests whether a key is a member of the map." }, + { "clear", (PyCFunction)Clear, METH_NOARGS, + "Removes all elements from the map." }, + { "get", ScalarMapGet, METH_VARARGS, + "Gets the value for the given key if present, or otherwise a default" }, + /* + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + */ + {NULL, NULL}, +}; + +PyTypeObject ScalarMapContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".ScalarMapContainer", // tp_name + sizeof(MapContainer), // tp_basicsize + 0, // tp_itemsize + ScalarMapDealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &ScalarMapMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map container", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + MapReflectionFriend::GetIterator, // tp_iter + 0, // tp_iternext + ScalarMapMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + + +// MessageMap ////////////////////////////////////////////////////////////////// + +static MessageMapContainer* GetMessageMap(PyObject* obj) { + return reinterpret_cast<MessageMapContainer*>(obj); +} + +static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { + // Get or create the CMessage object corresponding to this message. + ScopedPyObjectPtr key(PyLong_FromVoidPtr(message)); + PyObject* ret = PyDict_GetItem(self->message_dict, key.get()); + + if (ret == NULL) { + CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init, + message->GetDescriptor()); + ret = reinterpret_cast<PyObject*>(cmsg); + + if (cmsg == NULL) { + return NULL; + } + cmsg->owner = self->owner; + cmsg->message = message; + cmsg->parent = self->parent; + + if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) { + Py_DECREF(ret); + return NULL; + } + } else { + Py_INCREF(ret); + } + + return ret; +} + +PyObject* NewMessageMapContainer( + CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor, + PyObject* concrete_class) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } + + PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0); + if (obj == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate new container."); + } + + MessageMapContainer* self = GetMessageMap(obj); + + self->message = parent->message; + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->owner = parent->owner; + self->version = 0; + + self->key_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("key"); + self->value_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("value"); + + self->message_dict = PyDict_New(); + if (self->message_dict == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate message dict."); + } + + Py_INCREF(concrete_class); + self->subclass_init = concrete_class; + + if (self->key_field_descriptor == NULL || + self->value_field_descriptor == NULL) { + Py_DECREF(obj); + return PyErr_Format(PyExc_KeyError, + "Map entry descriptor did not have key/value fields"); + } + + return obj; +} + +int MapReflectionFriend::MessageMapSetItem(PyObject* _self, PyObject* key, + PyObject* v) { + if (v) { + PyErr_Format(PyExc_ValueError, + "Direct assignment of submessage not allowed"); + return -1; + } + + // Now we know that this is a delete, not a set. + + MessageMapContainer* self = GetMessageMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + self->version++; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return -1; + } + + // Delete key from map. + if (reflection->DeleteMapValue(message, self->parent_field_descriptor, + map_key)) { + return 0; + } else { + PyErr_Format(PyExc_KeyError, "Key not present in map"); + return -1; + } +} + +PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self, + PyObject* key) { + MessageMapContainer* self = GetMessageMap(_self); + + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + MapKey map_key; + MapValueRef value; + + if (!PythonToMapKey(key, self->key_field_descriptor, &map_key)) { + return NULL; + } + + if (reflection->InsertOrLookupMapValue(message, self->parent_field_descriptor, + map_key, &value)) { + self->version++; + } + + return GetCMessage(self, value.MutableMessageValue()); +} + +PyObject* MessageMapGet(PyObject* self, PyObject* args) { + PyObject* key; + PyObject* default_value = NULL; + if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) { + return NULL; + } + + ScopedPyObjectPtr is_present(MapReflectionFriend::Contains(self, key)); + if (is_present.get() == NULL) { + return NULL; + } + + if (PyObject_IsTrue(is_present.get())) { + return MapReflectionFriend::MessageMapGetItem(self, key); + } else { + if (default_value != NULL) { + Py_INCREF(default_value); + return default_value; + } else { + Py_RETURN_NONE; + } + } +} + +static void MessageMapDealloc(PyObject* _self) { + MessageMapContainer* self = GetMessageMap(_self); + self->owner.reset(); + Py_DECREF(self->message_dict); + Py_TYPE(_self)->tp_free(_self); +} + +static PyMappingMethods MessageMapMappingMethods = { + MapReflectionFriend::Length, // mp_length + MapReflectionFriend::MessageMapGetItem, // mp_subscript + MapReflectionFriend::MessageMapSetItem, // mp_ass_subscript +}; + +static PyMethodDef MessageMapMethods[] = { + { "__contains__", (PyCFunction)MapReflectionFriend::Contains, METH_O, + "Tests whether the map contains this element."}, + { "clear", (PyCFunction)Clear, METH_NOARGS, + "Removes all elements from the map."}, + { "get", MessageMapGet, METH_VARARGS, + "Gets the value for the given key if present, or otherwise a default" }, + { "get_or_create", MapReflectionFriend::MessageMapGetItem, METH_O, + "Alias for getitem, useful to make explicit that the map is mutated." }, + /* + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + */ + {NULL, NULL}, +}; + +PyTypeObject MessageMapContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MessageMapContainer", // tp_name + sizeof(MessageMapContainer), // tp_basicsize + 0, // tp_itemsize + MessageMapDealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &MessageMapMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A map container for message", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + MapReflectionFriend::GetIterator, // tp_iter + 0, // tp_iternext + MessageMapMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +// MapIterator ///////////////////////////////////////////////////////////////// + +static MapIterator* GetIter(PyObject* obj) { + return reinterpret_cast<MapIterator*>(obj); +} + +PyObject* MapReflectionFriend::GetIterator(PyObject *_self) { + MapContainer* self = GetMap(_self); + + ScopedPyObjectPtr obj(PyType_GenericAlloc(&MapIterator_Type, 0)); + if (obj == NULL) { + return PyErr_Format(PyExc_KeyError, "Could not allocate iterator"); + } + + MapIterator* iter = GetIter(obj.get()); + + Py_INCREF(self); + iter->container = self; + iter->version = self->version; + iter->owner = self->owner; + + if (MapReflectionFriend::Length(_self) > 0) { + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + + iter->iter.reset(new ::google::protobuf::MapIterator( + reflection->MapBegin(message, self->parent_field_descriptor))); + } + + return obj.release(); +} + +PyObject* MapReflectionFriend::IterNext(PyObject* _self) { + MapIterator* self = GetIter(_self); + + // This won't catch mutations to the map performed by MergeFrom(); no easy way + // to address that. + if (self->version != self->container->version) { + return PyErr_Format(PyExc_RuntimeError, + "Map modified during iteration."); + } + + if (self->iter.get() == NULL) { + return NULL; + } + + Message* message = self->container->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + + if (*self->iter == + reflection->MapEnd(message, self->container->parent_field_descriptor)) { + return NULL; + } + + PyObject* ret = MapKeyToPython(self->container->key_field_descriptor, + self->iter->GetKey()); + + ++(*self->iter); + + return ret; +} + +static void DeallocMapIterator(PyObject* _self) { + MapIterator* self = GetIter(_self); + self->iter.reset(); + self->owner.reset(); + Py_XDECREF(self->container); + Py_TYPE(_self)->tp_free(_self); +} + +PyTypeObject MapIterator_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MapIterator", // tp_name + sizeof(MapIterator), // tp_basicsize + 0, // tp_itemsize + DeallocMapIterator, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map iterator", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + PyObject_SelfIter, // tp_iter + MapReflectionFriend::IterNext, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/message_map_container.h b/python/google/protobuf/pyext/map_container.h index 4f6cb26a..2de61187 100644 --- a/python/google/protobuf/pyext/message_map_container.h +++ b/python/google/protobuf/pyext/map_container.h @@ -28,8 +28,8 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__ -#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__ +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ #include <Python.h> @@ -39,6 +39,7 @@ #endif #include <google/protobuf/descriptor.h> +#include <google/protobuf/message.h> namespace google { namespace protobuf { @@ -55,18 +56,23 @@ namespace python { struct CMessage; -struct MessageMapContainer { +// This struct is used directly for ScalarMap, and is the base class of +// MessageMapContainer, which is used for MessageMap. +struct MapContainer { PyObject_HEAD; // This is the top-level C++ Message object that owns the whole - // proto tree. Every Python MessageMapContainer holds a + // proto tree. Every Python MapContainer holds a // reference to it in order to keep it alive as long as there's a // Python object that references any part of the tree. shared_ptr<Message> owner; // Pointer to the C++ Message that contains this container. The - // MessageMapContainer does not own this pointer. - Message* message; + // MapContainer does not own this pointer. + const Message* message; + + // Use to get a mutable message when necessary. + Message* GetMutableMessage(); // Weak reference to a parent CMessage object (i.e. may be NULL.) // @@ -82,45 +88,46 @@ struct MessageMapContainer { const FieldDescriptor* key_field_descriptor; const FieldDescriptor* value_field_descriptor; + // We bump this whenever we perform a mutation, to invalidate existing + // iterators. + uint64 version; + + // Releases the messages in the container to a new message. + // + // Returns 0 on success, -1 on failure. + int Release(); + + // Set the owner field of self and any children of self. + void SetOwner(const shared_ptr<Message>& new_owner) { + owner = new_owner; + } +}; + +struct MessageMapContainer : public MapContainer { // A callable that is used to create new child messages. PyObject* subclass_init; // A dict mapping Message* -> CMessage. PyObject* message_dict; - - // We bump this whenever we perform a mutation, to invalidate existing - // iterators. - uint64 version; }; -#if PY_MAJOR_VERSION >= 3 - extern PyObject *MessageMapContainer_Type; - extern PyType_Spec MessageMapContainer_Type_spec; -#else - extern PyTypeObject MessageMapContainer_Type; -#endif -extern PyTypeObject MessageMapIterator_Type; +extern PyTypeObject ScalarMapContainer_Type; +extern PyTypeObject MessageMapContainer_Type; +extern PyTypeObject MapIterator_Type; // Both map types use the same iterator. -namespace message_map_container { - -// Builds a MessageMapContainer object, from a parent message and a +// Builds a MapContainer object, from a parent message and a // field descriptor. -extern PyObject* NewContainer(CMessage* parent, - const FieldDescriptor* parent_field_descriptor, - PyObject* concrete_class); - -// Releases the messages in the container to a new message. -// -// Returns 0 on success, -1 on failure. -int Release(MessageMapContainer* self); +extern PyObject* NewScalarMapContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor); -// Set the owner field of self and any children of self. -void SetOwner(MessageMapContainer* self, - const shared_ptr<Message>& new_owner); +// Builds a MessageMap object, from a parent message and a +// field descriptor. +extern PyObject* NewMessageMapContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor, + PyObject* concrete_class); -} // namespace message_map_container } // namespace python } // namespace protobuf } // namespace google -#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__ +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 72f51ec1..863cde01 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -33,6 +33,7 @@ #include <google/protobuf/pyext/message.h> +#include <map> #include <memory> #ifndef _SHARED_PTR_H #include <google/protobuf/stubs/shared_ptr.h> @@ -61,8 +62,7 @@ #include <google/protobuf/pyext/extension_dict.h> #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> -#include <google/protobuf/pyext/message_map_container.h> -#include <google/protobuf/pyext/scalar_map_container.h> +#include <google/protobuf/pyext/map_container.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/strutil.h> @@ -96,6 +96,7 @@ static PyObject* k_extensions_by_number; PyObject* EnumTypeWrapper_class; static PyObject* PythonMessage_class; static PyObject* kEmptyWeakref; +static PyObject* WKT_classes = NULL; // Defines the Metaclass of all Message classes. // It allows us to cache some C++ pointers in the class object itself, they are @@ -274,8 +275,32 @@ static PyObject* New(PyTypeObject* type, // Build the arguments to the base metaclass. // We change the __bases__ classes. - ScopedPyObjectPtr new_args(Py_BuildValue( - "s(OO)O", name, &CMessage_Type, PythonMessage_class, dict)); + ScopedPyObjectPtr new_args; + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(py_descriptor); + if (message_descriptor == NULL) { + return NULL; + } + + if (WKT_classes == NULL) { + ScopedPyObjectPtr well_known_types(PyImport_ImportModule( + "google.protobuf.internal.well_known_types")); + GOOGLE_DCHECK(well_known_types != NULL); + + WKT_classes = PyObject_GetAttrString(well_known_types.get(), "WKTBASES"); + GOOGLE_DCHECK(WKT_classes != NULL); + } + + PyObject* well_known_class = PyDict_GetItemString( + WKT_classes, message_descriptor->full_name().c_str()); + if (well_known_class == NULL) { + new_args.reset(Py_BuildValue("s(OO)O", name, &CMessage_Type, + PythonMessage_class, dict)); + } else { + new_args.reset(Py_BuildValue("s(OOO)O", name, &CMessage_Type, + PythonMessage_class, well_known_class, dict)); + } + if (new_args == NULL) { return NULL; } @@ -448,21 +473,9 @@ static int VisitCompositeField(const FieldDescriptor* descriptor, if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (descriptor->is_map()) { - const Descriptor* entry_type = descriptor->message_type(); - const FieldDescriptor* value_type = - entry_type->FindFieldByName("value"); - if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - MessageMapContainer* container = - reinterpret_cast<MessageMapContainer*>(child); - if (visitor.VisitMessageMapContainer(container) == -1) { - return -1; - } - } else { - ScalarMapContainer* container = - reinterpret_cast<ScalarMapContainer*>(child); - if (visitor.VisitScalarMapContainer(container) == -1) { - return -1; - } + MapContainer* container = reinterpret_cast<MapContainer*>(child); + if (visitor.VisitMapContainer(container) == -1) { + return -1; } } else { RepeatedCompositeContainer* container = @@ -579,12 +592,14 @@ bool CheckAndGetInteger( if (PyObject_RichCompareBool(min, arg, Py_LE) != 1 || PyObject_RichCompareBool(max, arg, Py_GE) != 1) { #endif - PyObject *s = PyObject_Str(arg); - if (s) { - PyErr_Format(PyExc_ValueError, - "Value out of range: %s", - PyString_AsString(s)); - Py_DECREF(s); + if (!PyErr_Occurred()) { + PyObject *s = PyObject_Str(arg); + if (s) { + PyErr_Format(PyExc_ValueError, + "Value out of range: %s", + PyString_AsString(s)); + Py_DECREF(s); + } } return false; } @@ -642,38 +657,51 @@ bool CheckAndGetBool(PyObject* arg, bool* value) { return true; } -bool CheckAndSetString( - PyObject* arg, Message* message, - const FieldDescriptor* descriptor, - const Reflection* reflection, - bool append, - int index) { +// Checks whether the given object (which must be "bytes" or "unicode") contains +// valid UTF-8. +bool IsValidUTF8(PyObject* obj) { + if (PyBytes_Check(obj)) { + PyObject* unicode = PyUnicode_FromEncodedObject(obj, "utf-8", NULL); + + // Clear the error indicator; we report our own error when desired. + PyErr_Clear(); + + if (unicode) { + Py_DECREF(unicode); + return true; + } else { + return false; + } + } else { + // Unicode object, known to be valid UTF-8. + return true; + } +} + +bool AllowInvalidUTF8(const FieldDescriptor* field) { return false; } + +PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor) { GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING || descriptor->type() == FieldDescriptor::TYPE_BYTES); if (descriptor->type() == FieldDescriptor::TYPE_STRING) { if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) { FormatTypeError(arg, "bytes, unicode"); - return false; + return NULL; } - if (PyBytes_Check(arg)) { - PyObject* unicode = PyUnicode_FromEncodedObject(arg, "utf-8", NULL); - if (unicode == NULL) { - PyObject* repr = PyObject_Repr(arg); - PyErr_Format(PyExc_ValueError, - "%s has type str, but isn't valid UTF-8 " - "encoding. Non-UTF-8 strings must be converted to " - "unicode objects before being added.", - PyString_AsString(repr)); - Py_DECREF(repr); - return false; - } else { - Py_DECREF(unicode); - } + if (!IsValidUTF8(arg) && !AllowInvalidUTF8(descriptor)) { + PyObject* repr = PyObject_Repr(arg); + PyErr_Format(PyExc_ValueError, + "%s has type str, but isn't valid UTF-8 " + "encoding. Non-UTF-8 strings must be converted to " + "unicode objects before being added.", + PyString_AsString(repr)); + Py_DECREF(repr); + return NULL; } } else if (!PyBytes_Check(arg)) { FormatTypeError(arg, "bytes"); - return false; + return NULL; } PyObject* encoded_string = NULL; @@ -691,14 +719,24 @@ bool CheckAndSetString( Py_INCREF(encoded_string); } - if (encoded_string == NULL) { + return encoded_string; +} + +bool CheckAndSetString( + PyObject* arg, Message* message, + const FieldDescriptor* descriptor, + const Reflection* reflection, + bool append, + int index) { + ScopedPyObjectPtr encoded_string(CheckString(arg, descriptor)); + + if (encoded_string.get() == NULL) { return false; } char* value; Py_ssize_t value_len; - if (PyBytes_AsStringAndSize(encoded_string, &value, &value_len) < 0) { - Py_DECREF(encoded_string); + if (PyBytes_AsStringAndSize(encoded_string.get(), &value, &value_len) < 0) { return false; } @@ -710,7 +748,6 @@ bool CheckAndSetString( } else { reflection->SetRepeatedString(message, descriptor, index, value_string); } - Py_DECREF(encoded_string); return true; } @@ -823,12 +860,7 @@ struct FixupMessageReference : public ChildVisitor { return 0; } - int VisitScalarMapContainer(ScalarMapContainer* container) { - container->message = message_; - return 0; - } - - int VisitMessageMapContainer(MessageMapContainer* container) { + int VisitMapContainer(MapContainer* container) { container->message = message_; return 0; } @@ -870,9 +902,8 @@ int AssureWritable(CMessage* self) { // When a CMessage is made writable its Message pointer is updated // to point to a new mutable Message. When that happens we need to // update any references to the old, read-only CMessage. There are - // five places such references occur: RepeatedScalarContainer, - // RepeatedCompositeContainer, ScalarMapContainer, MessageMapContainer, - // and ExtensionDict. + // four places such references occur: RepeatedScalarContainer, + // RepeatedCompositeContainer, MapContainer, and ExtensionDict. if (self->extensions != NULL) self->extensions->message = self->message; if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1) @@ -1054,7 +1085,8 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { } const FieldDescriptor* descriptor = GetFieldDescriptor(self, name); if (descriptor == NULL) { - PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.", + PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.", + self->message->GetDescriptor()->name().c_str(), PyString_AsString(name)); return -1; } @@ -1203,18 +1235,6 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { self->composite_fields = NULL; - // If there are extension_ranges, the message is "extendable". Allocate a - // dictionary to store the extension fields. - if (descriptor->extension_range_count() > 0) { - // TODO(amauryfa): Delay the construction of this dict until extensions are - // really used on the object. - ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); - if (extension_dict == NULL) { - return NULL; - } - self->extensions = extension_dict; - } - return self; } @@ -1285,12 +1305,7 @@ struct ClearWeakReferences : public ChildVisitor { return 0; } - int VisitScalarMapContainer(ScalarMapContainer* container) { - container->parent = NULL; - return 0; - } - - int VisitMessageMapContainer(MessageMapContainer* container) { + int VisitMapContainer(MapContainer* container) { container->parent = NULL; return 0; } @@ -1305,6 +1320,9 @@ struct ClearWeakReferences : public ChildVisitor { static void Dealloc(CMessage* self) { // Null out all weak references from children to this message. GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); + if (self->extensions) { + self->extensions->parent = NULL; + } Py_CLEAR(self->extensions); Py_CLEAR(self->composite_fields); @@ -1466,20 +1484,27 @@ PyObject* HasField(CMessage* self, PyObject* arg) { Py_RETURN_FALSE; } -PyObject* ClearExtension(CMessage* self, PyObject* arg) { +PyObject* ClearExtension(CMessage* self, PyObject* extension) { if (self->extensions != NULL) { - return extension_dict::ClearExtension(self->extensions, arg); + return extension_dict::ClearExtension(self->extensions, extension); + } else { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; + } + if (ScopedPyObjectPtr(ClearFieldByDescriptor(self, descriptor)) == NULL) { + return NULL; + } } - PyErr_SetString(PyExc_TypeError, "Message is not extendable"); - return NULL; + Py_RETURN_NONE; } -PyObject* HasExtension(CMessage* self, PyObject* arg) { - if (self->extensions != NULL) { - return extension_dict::HasExtension(self->extensions, arg); +PyObject* HasExtension(CMessage* self, PyObject* extension) { + const FieldDescriptor* descriptor = GetExtensionDescriptor(extension); + if (descriptor == NULL) { + return NULL; } - PyErr_SetString(PyExc_TypeError, "Message is not extendable"); - return NULL; + return HasFieldByDescriptor(self, descriptor); } // --------------------------------------------------------------------- @@ -1529,13 +1554,8 @@ struct SetOwnerVisitor : public ChildVisitor { return 0; } - int VisitScalarMapContainer(ScalarMapContainer* container) { - scalar_map_container::SetOwner(container, new_owner_); - return 0; - } - - int VisitMessageMapContainer(MessageMapContainer* container) { - message_map_container::SetOwner(container, new_owner_); + int VisitMapContainer(MapContainer* container) { + container->SetOwner(new_owner_); return 0; } @@ -1608,14 +1628,8 @@ struct ReleaseChild : public ChildVisitor { reinterpret_cast<RepeatedScalarContainer*>(container)); } - int VisitScalarMapContainer(ScalarMapContainer* container) { - return scalar_map_container::Release( - reinterpret_cast<ScalarMapContainer*>(container)); - } - - int VisitMessageMapContainer(MessageMapContainer* container) { - return message_map_container::Release( - reinterpret_cast<MessageMapContainer*>(container)); + int VisitMapContainer(MapContainer* container) { + return reinterpret_cast<MapContainer*>(container)->Release(); } int VisitCMessage(CMessage* cmessage, @@ -1707,17 +1721,7 @@ PyObject* Clear(CMessage* self) { AssureWritable(self); if (ForEachCompositeField(self, ReleaseChild(self)) == -1) return NULL; - - // The old ExtensionDict still aliases this CMessage, but all its - // fields have been released. - if (self->extensions != NULL) { - Py_CLEAR(self->extensions); - ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); - if (extension_dict == NULL) { - return NULL; - } - self->extensions = extension_dict; - } + Py_CLEAR(self->extensions); if (self->composite_fields) { PyDict_Clear(self->composite_fields); } @@ -1997,7 +2001,6 @@ static PyObject* RegisterExtension(PyObject* cls, if (descriptor->is_extension() && descriptor->containing_type()->options().message_set_wire_format() && descriptor->type() == FieldDescriptor::TYPE_MESSAGE && - descriptor->message_type() == descriptor->extension_scope() && descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { ScopedPyObjectPtr message_name(PyString_FromStringAndSize( descriptor->message_type()->full_name().c_str(), @@ -2042,6 +2045,8 @@ static PyObject* WhichOneof(CMessage* self, PyObject* arg) { } } +static PyObject* GetExtensionDict(CMessage* self, void *closure); + static PyObject* ListFields(CMessage* self) { vector<const FieldDescriptor*> fields; self->message->GetReflection()->ListFields(*self->message, &fields); @@ -2079,12 +2084,13 @@ static PyObject* ListFields(CMessage* self) { PyErr_Clear(); continue; } - PyObject* extensions = reinterpret_cast<PyObject*>(self->extensions); + ScopedPyObjectPtr extensions(GetExtensionDict(self, NULL)); if (extensions == NULL) { return NULL; } // 'extension' reference later stolen by PyTuple_SET_ITEM. - PyObject* extension = PyObject_GetItem(extensions, extension_field.get()); + PyObject* extension = PyObject_GetItem( + extensions.get(), extension_field.get()); if (extension == NULL) { return NULL; } @@ -2493,9 +2499,31 @@ PyObject* _CheckCalledFromGeneratedFile(PyObject* unused, Py_RETURN_NONE; } -static PyMemberDef Members[] = { - {"Extensions", T_OBJECT_EX, offsetof(CMessage, extensions), 0, - "Extension dict"}, +static PyObject* GetExtensionDict(CMessage* self, void *closure) { + if (self->extensions) { + Py_INCREF(self->extensions); + return reinterpret_cast<PyObject*>(self->extensions); + } + + // If there are extension_ranges, the message is "extendable". Allocate a + // dictionary to store the extension fields. + const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self)); + if (descriptor->extension_range_count() > 0) { + ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); + if (extension_dict == NULL) { + return NULL; + } + self->extensions = extension_dict; + Py_INCREF(self->extensions); + return reinterpret_cast<PyObject*>(self->extensions); + } + + PyErr_SetNone(PyExc_AttributeError); + return NULL; +} + +static PyGetSetDef Getters[] = { + {"Extensions", (getter)GetExtensionDict, NULL, "Extension dict"}, {NULL} }; @@ -2592,10 +2620,10 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { if (value_class == NULL) { return NULL; } - py_container = message_map_container::NewContainer(self, field_descriptor, - value_class); + py_container = + NewMessageMapContainer(self, field_descriptor, value_class); } else { - py_container = scalar_map_container::NewContainer(self, field_descriptor); + py_container = NewScalarMapContainer(self, field_descriptor); } if (py_container == NULL) { return NULL; @@ -2672,7 +2700,10 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { } } - PyErr_Format(PyExc_AttributeError, "Assignment not allowed"); + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed " + "(no field \"%s\"in protocol message object).", + PyString_AsString(name)); return -1; } @@ -2707,8 +2738,8 @@ PyTypeObject CMessage_Type = { 0, // tp_iter 0, // tp_iternext cmessage::Methods, // tp_methods - cmessage::Members, // tp_members - 0, // tp_getset + 0, // tp_members + cmessage::Getters, // tp_getset 0, // tp_base 0, // tp_dict 0, // tp_descr_get @@ -2910,12 +2941,12 @@ bool InitProto2MessageModule(PyObject *m) { reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); #endif - if (PyType_Ready(&ScalarMapIterator_Type) < 0) { + if (PyType_Ready(&MapIterator_Type) < 0) { return false; } - PyModule_AddObject(m, "ScalarMapIterator", - reinterpret_cast<PyObject*>(&ScalarMapIterator_Type)); + PyModule_AddObject(m, "MapIterator", + reinterpret_cast<PyObject*>(&MapIterator_Type)); #if PY_MAJOR_VERSION >= 3 @@ -2934,13 +2965,6 @@ bool InitProto2MessageModule(PyObject *m) { PyModule_AddObject(m, "MessageMapContainer", reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); #endif - - if (PyType_Ready(&MessageMapIterator_Type) < 0) { - return false; - } - - PyModule_AddObject(m, "MessageMapIterator", - reinterpret_cast<PyObject*>(&MessageMapIterator_Type)); } if (PyType_Ready(&ExtensionDict_Type) < 0) { @@ -2957,6 +2981,9 @@ bool InitProto2MessageModule(PyObject *m) { PyModule_AddObject(m, "default_pool", reinterpret_cast<PyObject*>(GetDefaultDescriptorPool())); + PyModule_AddObject(m, "DescriptorPool", reinterpret_cast<PyObject*>( + &PyDescriptorPool_Type)); + // This implementation provides full Descriptor types, we advertise it so that // descriptor.py can use them in replacement of the Python classes. PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1); diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index 94de4551..cc0012e9 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -307,6 +307,7 @@ bool CheckAndGetInteger( bool CheckAndGetDouble(PyObject* arg, double* value); bool CheckAndGetFloat(PyObject* arg, float* value); bool CheckAndGetBool(PyObject* arg, bool* value); +PyObject* CheckString(PyObject* arg, const FieldDescriptor* descriptor); bool CheckAndSetString( PyObject* arg, Message* message, const FieldDescriptor* descriptor, diff --git a/python/google/protobuf/pyext/message_map_container.cc b/python/google/protobuf/pyext/message_map_container.cc deleted file mode 100644 index 8902fa00..00000000 --- a/python/google/protobuf/pyext/message_map_container.cc +++ /dev/null @@ -1,569 +0,0 @@ -// Protocol Buffers - Google's data interchange format -// Copyright 2008 Google Inc. All rights reserved. -// https://developers.google.com/protocol-buffers/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -// Author: haberman@google.com (Josh Haberman) - -#include <google/protobuf/pyext/message_map_container.h> - -#include <google/protobuf/stubs/logging.h> -#include <google/protobuf/stubs/common.h> -#include <google/protobuf/message.h> -#include <google/protobuf/pyext/message.h> -#include <google/protobuf/pyext/scoped_pyobject_ptr.h> - -namespace google { -namespace protobuf { -namespace python { - -struct MessageMapIterator { - PyObject_HEAD; - - // This dict contains the full contents of what we want to iterate over. - // There's no way to avoid building this, because the list representation - // (which is canonical) can contain duplicate keys. So at the very least we - // need a set that lets us skip duplicate keys. And at the point that we're - // doing that, we might as well just build the actual dict we're iterating - // over and use dict's built-in iterator. - PyObject* dict; - - // An iterator on dict. - PyObject* iter; - - // A pointer back to the container, so we can notice changes to the version. - MessageMapContainer* container; - - // The version of the map when we took the iterator to it. - // - // We store this so that if the map is modified during iteration we can throw - // an error. - uint64 version; -}; - -static MessageMapIterator* GetIter(PyObject* obj) { - return reinterpret_cast<MessageMapIterator*>(obj); -} - -namespace message_map_container { - -static MessageMapContainer* GetMap(PyObject* obj) { - return reinterpret_cast<MessageMapContainer*>(obj); -} - -// The private constructor of MessageMapContainer objects. -PyObject* NewContainer(CMessage* parent, - const google::protobuf::FieldDescriptor* parent_field_descriptor, - PyObject* concrete_class) { - if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { - return NULL; - } - -#if PY_MAJOR_VERSION >= 3 - PyObject* obj = PyType_GenericAlloc( - reinterpret_cast<PyTypeObject *>(MessageMapContainer_Type), 0); -#else - PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0); -#endif - if (obj == NULL) { - return PyErr_Format(PyExc_RuntimeError, - "Could not allocate new container."); - } - - MessageMapContainer* self = GetMap(obj); - - self->message = parent->message; - self->parent = parent; - self->parent_field_descriptor = parent_field_descriptor; - self->owner = parent->owner; - self->version = 0; - - self->key_field_descriptor = - parent_field_descriptor->message_type()->FindFieldByName("key"); - self->value_field_descriptor = - parent_field_descriptor->message_type()->FindFieldByName("value"); - - self->message_dict = PyDict_New(); - if (self->message_dict == NULL) { - return PyErr_Format(PyExc_RuntimeError, - "Could not allocate message dict."); - } - - Py_INCREF(concrete_class); - self->subclass_init = concrete_class; - - if (self->key_field_descriptor == NULL || - self->value_field_descriptor == NULL) { - Py_DECREF(obj); - return PyErr_Format(PyExc_KeyError, - "Map entry descriptor did not have key/value fields"); - } - - return obj; -} - -// Initializes the underlying Message object of "to" so it becomes a new parent -// repeated scalar, and copies all the values from "from" to it. A child scalar -// container can be released by passing it as both from and to (e.g. making it -// the recipient of the new parent message and copying the values from itself). -static int InitializeAndCopyToParentContainer( - MessageMapContainer* from, - MessageMapContainer* to) { - // For now we require from == to, re-evaluate if we want to support deep copy - // as in repeated_composite_container.cc. - GOOGLE_DCHECK(from == to); - Message* old_message = from->message; - Message* new_message = old_message->New(); - to->parent = NULL; - to->parent_field_descriptor = from->parent_field_descriptor; - to->message = new_message; - to->owner.reset(new_message); - - vector<const FieldDescriptor*> fields; - fields.push_back(from->parent_field_descriptor); - old_message->GetReflection()->SwapFields(old_message, new_message, fields); - return 0; -} - -static PyObject* GetCMessage(MessageMapContainer* self, Message* entry) { - // Get or create the CMessage object corresponding to this message. - Message* message = entry->GetReflection()->MutableMessage( - entry, self->value_field_descriptor); - ScopedPyObjectPtr key(PyLong_FromVoidPtr(message)); - PyObject* ret = PyDict_GetItem(self->message_dict, key.get()); - - if (ret == NULL) { - CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init, - message->GetDescriptor()); - ret = reinterpret_cast<PyObject*>(cmsg); - - if (cmsg == NULL) { - return NULL; - } - cmsg->owner = self->owner; - cmsg->message = message; - cmsg->parent = self->parent; - - if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) { - Py_DECREF(ret); - return NULL; - } - } else { - Py_INCREF(ret); - } - - return ret; -} - -int Release(MessageMapContainer* self) { - InitializeAndCopyToParentContainer(self, self); - return 0; -} - -void SetOwner(MessageMapContainer* self, - const shared_ptr<Message>& new_owner) { - self->owner = new_owner; -} - -Py_ssize_t Length(PyObject* _self) { - MessageMapContainer* self = GetMap(_self); - google::protobuf::Message* message = self->message; - return message->GetReflection()->FieldSize(*message, - self->parent_field_descriptor); -} - -int MapKeyMatches(MessageMapContainer* self, const Message* entry, - PyObject* key) { - // TODO(haberman): do we need more strict type checking? - ScopedPyObjectPtr entry_key( - cmessage::InternalGetScalar(entry, self->key_field_descriptor)); - int ret = PyObject_RichCompareBool(key, entry_key.get(), Py_EQ); - return ret; -} - -int SetItem(PyObject *_self, PyObject *key, PyObject *v) { - if (v) { - PyErr_Format(PyExc_ValueError, - "Direct assignment of submessage not allowed"); - return -1; - } - - // Now we know that this is a delete, not a set. - - MessageMapContainer* self = GetMap(_self); - cmessage::AssureWritable(self->parent); - - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - size_t size = - reflection->FieldSize(*message, self->parent_field_descriptor); - - // Right now the Reflection API doesn't support map lookup, so we implement it - // via linear search. We need to search from the end because the underlying - // representation can have duplicates if a user calls MergeFrom(); the last - // one needs to win. - // - // TODO(haberman): add lookup API to Reflection API. - bool found = false; - for (int i = size - 1; i >= 0; i--) { - Message* entry = reflection->MutableRepeatedMessage( - message, self->parent_field_descriptor, i); - int matches = MapKeyMatches(self, entry, key); - if (matches < 0) return -1; - if (matches) { - found = true; - if (i != (int)size - 1) { - reflection->SwapElements(message, self->parent_field_descriptor, i, - size - 1); - } - reflection->RemoveLast(message, self->parent_field_descriptor); - - // Can't exit now, the repeated field representation of maps allows - // duplicate keys, and we have to be sure to remove all of them. - } - } - - if (!found) { - PyErr_Format(PyExc_KeyError, "Key not present in map"); - return -1; - } - - self->version++; - - return 0; -} - -PyObject* GetIterator(PyObject *_self) { - MessageMapContainer* self = GetMap(_self); - - ScopedPyObjectPtr obj(PyType_GenericAlloc(&MessageMapIterator_Type, 0)); - if (obj == NULL) { - return PyErr_Format(PyExc_KeyError, "Could not allocate iterator"); - } - - MessageMapIterator* iter = GetIter(obj.get()); - - Py_INCREF(self); - iter->container = self; - iter->version = self->version; - iter->dict = PyDict_New(); - if (iter->dict == NULL) { - return PyErr_Format(PyExc_RuntimeError, - "Could not allocate dict for iterator."); - } - - // Build the entire map into a dict right now. Start from the beginning so - // that later entries win in the case of duplicates. - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - - // Right now the Reflection API doesn't support map lookup, so we implement it - // via linear search. We need to search from the end because the underlying - // representation can have duplicates if a user calls MergeFrom(); the last - // one needs to win. - // - // TODO(haberman): add lookup API to Reflection API. - size_t size = - reflection->FieldSize(*message, self->parent_field_descriptor); - for (int i = size - 1; i >= 0; i--) { - Message* entry = reflection->MutableRepeatedMessage( - message, self->parent_field_descriptor, i); - ScopedPyObjectPtr key( - cmessage::InternalGetScalar(entry, self->key_field_descriptor)); - if (PyDict_SetItem(iter->dict, key.get(), GetCMessage(self, entry)) < 0) { - return PyErr_Format(PyExc_RuntimeError, - "SetItem failed in iterator construction."); - } - } - - iter->iter = PyObject_GetIter(iter->dict); - - return obj.release(); -} - -PyObject* GetItem(PyObject* _self, PyObject* key) { - MessageMapContainer* self = GetMap(_self); - cmessage::AssureWritable(self->parent); - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - - // Right now the Reflection API doesn't support map lookup, so we implement it - // via linear search. We need to search from the end because the underlying - // representation can have duplicates if a user calls MergeFrom(); the last - // one needs to win. - // - // TODO(haberman): add lookup API to Reflection API. - size_t size = - reflection->FieldSize(*message, self->parent_field_descriptor); - for (int i = size - 1; i >= 0; i--) { - Message* entry = reflection->MutableRepeatedMessage( - message, self->parent_field_descriptor, i); - int matches = MapKeyMatches(self, entry, key); - if (matches < 0) return NULL; - if (matches) { - return GetCMessage(self, entry); - } - } - - // Key is not already present; insert a new entry. - Message* entry = - reflection->AddMessage(message, self->parent_field_descriptor); - - self->version++; - - if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor, - key) < 0) { - reflection->RemoveLast(message, self->parent_field_descriptor); - return NULL; - } - - return GetCMessage(self, entry); -} - -PyObject* Contains(PyObject* _self, PyObject* key) { - MessageMapContainer* self = GetMap(_self); - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - - // Right now the Reflection API doesn't support map lookup, so we implement it - // via linear search. - // - // TODO(haberman): add lookup API to Reflection API. - int size = - reflection->FieldSize(*message, self->parent_field_descriptor); - for (int i = 0; i < size; i++) { - Message* entry = reflection->MutableRepeatedMessage( - message, self->parent_field_descriptor, i); - int matches = MapKeyMatches(self, entry, key); - if (matches < 0) return NULL; - if (matches) { - Py_RETURN_TRUE; - } - } - - Py_RETURN_FALSE; -} - -PyObject* Clear(PyObject* _self) { - MessageMapContainer* self = GetMap(_self); - cmessage::AssureWritable(self->parent); - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - - self->version++; - reflection->ClearField(message, self->parent_field_descriptor); - - Py_RETURN_NONE; -} - -PyObject* Get(PyObject* self, PyObject* args) { - PyObject* key; - PyObject* default_value = NULL; - if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) { - return NULL; - } - - ScopedPyObjectPtr is_present(Contains(self, key)); - if (is_present.get() == NULL) { - return NULL; - } - - if (PyObject_IsTrue(is_present.get())) { - return GetItem(self, key); - } else { - if (default_value != NULL) { - Py_INCREF(default_value); - return default_value; - } else { - Py_RETURN_NONE; - } - } -} - -static void Dealloc(PyObject* _self) { - MessageMapContainer* self = GetMap(_self); - self->owner.reset(); - Py_DECREF(self->message_dict); - Py_TYPE(_self)->tp_free(_self); -} - -static PyMethodDef Methods[] = { - { "__contains__", (PyCFunction)Contains, METH_O, - "Tests whether the map contains this element."}, - { "clear", (PyCFunction)Clear, METH_NOARGS, - "Removes all elements from the map."}, - { "get", Get, METH_VARARGS, - "Gets the value for the given key if present, or otherwise a default" }, - { "get_or_create", GetItem, METH_O, - "Alias for getitem, useful to make explicit that the map is mutated." }, - /* - { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, - "Makes a deep copy of the class." }, - { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, - "Outputs picklable representation of the repeated field." }, - */ - {NULL, NULL}, -}; - -} // namespace message_map_container - -namespace message_map_iterator { - -static void Dealloc(PyObject* _self) { - MessageMapIterator* self = GetIter(_self); - Py_DECREF(self->dict); - Py_DECREF(self->iter); - Py_DECREF(self->container); - Py_TYPE(_self)->tp_free(_self); -} - -PyObject* IterNext(PyObject* _self) { - MessageMapIterator* self = GetIter(_self); - - // This won't catch mutations to the map performed by MergeFrom(); no easy way - // to address that. - if (self->version != self->container->version) { - return PyErr_Format(PyExc_RuntimeError, - "Map modified during iteration."); - } - - return PyIter_Next(self->iter); -} - -} // namespace message_map_iterator - -#if PY_MAJOR_VERSION >= 3 - static PyType_Slot MessageMapContainer_Type_slots[] = { - {Py_tp_dealloc, (void *)message_map_container::Dealloc}, - {Py_mp_length, (void *)message_map_container::Length}, - {Py_mp_subscript, (void *)message_map_container::GetItem}, - {Py_mp_ass_subscript, (void *)message_map_container::SetItem}, - {Py_tp_methods, (void *)message_map_container::Methods}, - {Py_tp_iter, (void *)message_map_container::GetIterator}, - {0, 0} - }; - - PyType_Spec MessageMapContainer_Type_spec = { - FULL_MODULE_NAME ".MessageMapContainer", - sizeof(MessageMapContainer), - 0, - Py_TPFLAGS_DEFAULT, - MessageMapContainer_Type_slots - }; - - PyObject *MessageMapContainer_Type; - -#else - static PyMappingMethods MpMethods = { - message_map_container::Length, // mp_length - message_map_container::GetItem, // mp_subscript - message_map_container::SetItem, // mp_ass_subscript - }; - - PyTypeObject MessageMapContainer_Type = { - PyVarObject_HEAD_INIT(&PyType_Type, 0) - FULL_MODULE_NAME ".MessageMapContainer", // tp_name - sizeof(MessageMapContainer), // tp_basicsize - 0, // tp_itemsize - message_map_container::Dealloc, // tp_dealloc - 0, // tp_print - 0, // tp_getattr - 0, // tp_setattr - 0, // tp_compare - 0, // tp_repr - 0, // tp_as_number - 0, // tp_as_sequence - &MpMethods, // tp_as_mapping - 0, // tp_hash - 0, // tp_call - 0, // tp_str - 0, // tp_getattro - 0, // tp_setattro - 0, // tp_as_buffer - Py_TPFLAGS_DEFAULT, // tp_flags - "A map container for message", // tp_doc - 0, // tp_traverse - 0, // tp_clear - 0, // tp_richcompare - 0, // tp_weaklistoffset - message_map_container::GetIterator, // tp_iter - 0, // tp_iternext - message_map_container::Methods, // tp_methods - 0, // tp_members - 0, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init - }; -#endif - -PyTypeObject MessageMapIterator_Type = { - PyVarObject_HEAD_INIT(&PyType_Type, 0) - FULL_MODULE_NAME ".MessageMapIterator", // tp_name - sizeof(MessageMapIterator), // tp_basicsize - 0, // tp_itemsize - message_map_iterator::Dealloc, // tp_dealloc - 0, // tp_print - 0, // tp_getattr - 0, // tp_setattr - 0, // tp_compare - 0, // tp_repr - 0, // tp_as_number - 0, // tp_as_sequence - 0, // tp_as_mapping - 0, // tp_hash - 0, // tp_call - 0, // tp_str - 0, // tp_getattro - 0, // tp_setattro - 0, // tp_as_buffer - Py_TPFLAGS_DEFAULT, // tp_flags - "A scalar map iterator", // tp_doc - 0, // tp_traverse - 0, // tp_clear - 0, // tp_richcompare - 0, // tp_weaklistoffset - PyObject_SelfIter, // tp_iter - message_map_iterator::IterNext, // tp_iternext - 0, // tp_methods - 0, // tp_members - 0, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init -}; - -} // namespace python -} // namespace protobuf -} // namespace google diff --git a/python/google/protobuf/pyext/scalar_map_container.cc b/python/google/protobuf/pyext/scalar_map_container.cc deleted file mode 100644 index 0b0d5a3d..00000000 --- a/python/google/protobuf/pyext/scalar_map_container.cc +++ /dev/null @@ -1,542 +0,0 @@ -// Protocol Buffers - Google's data interchange format -// Copyright 2008 Google Inc. All rights reserved. -// https://developers.google.com/protocol-buffers/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -// Author: haberman@google.com (Josh Haberman) - -#include <google/protobuf/pyext/scalar_map_container.h> - -#include <google/protobuf/stubs/logging.h> -#include <google/protobuf/stubs/common.h> -#include <google/protobuf/message.h> -#include <google/protobuf/pyext/message.h> -#include <google/protobuf/pyext/scoped_pyobject_ptr.h> - -namespace google { -namespace protobuf { -namespace python { - -struct ScalarMapIterator { - PyObject_HEAD; - - // This dict contains the full contents of what we want to iterate over. - // There's no way to avoid building this, because the list representation - // (which is canonical) can contain duplicate keys. So at the very least we - // need a set that lets us skip duplicate keys. And at the point that we're - // doing that, we might as well just build the actual dict we're iterating - // over and use dict's built-in iterator. - PyObject* dict; - - // An iterator on dict. - PyObject* iter; - - // A pointer back to the container, so we can notice changes to the version. - ScalarMapContainer* container; - - // The version of the map when we took the iterator to it. - // - // We store this so that if the map is modified during iteration we can throw - // an error. - uint64 version; -}; - -static ScalarMapIterator* GetIter(PyObject* obj) { - return reinterpret_cast<ScalarMapIterator*>(obj); -} - -namespace scalar_map_container { - -static ScalarMapContainer* GetMap(PyObject* obj) { - return reinterpret_cast<ScalarMapContainer*>(obj); -} - -// The private constructor of ScalarMapContainer objects. -PyObject *NewContainer( - CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) { - if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { - return NULL; - } - -#if PY_MAJOR_VERSION >= 3 - ScopedPyObjectPtr obj(PyType_GenericAlloc( - reinterpret_cast<PyTypeObject *>(ScalarMapContainer_Type), 0)); -#else - ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0)); -#endif - if (obj.get() == NULL) { - return PyErr_Format(PyExc_RuntimeError, - "Could not allocate new container."); - } - - ScalarMapContainer* self = GetMap(obj.get()); - - self->message = parent->message; - self->parent = parent; - self->parent_field_descriptor = parent_field_descriptor; - self->owner = parent->owner; - self->version = 0; - - self->key_field_descriptor = - parent_field_descriptor->message_type()->FindFieldByName("key"); - self->value_field_descriptor = - parent_field_descriptor->message_type()->FindFieldByName("value"); - - if (self->key_field_descriptor == NULL || - self->value_field_descriptor == NULL) { - return PyErr_Format(PyExc_KeyError, - "Map entry descriptor did not have key/value fields"); - } - - return obj.release(); -} - -// Initializes the underlying Message object of "to" so it becomes a new parent -// repeated scalar, and copies all the values from "from" to it. A child scalar -// container can be released by passing it as both from and to (e.g. making it -// the recipient of the new parent message and copying the values from itself). -static int InitializeAndCopyToParentContainer( - ScalarMapContainer* from, - ScalarMapContainer* to) { - // For now we require from == to, re-evaluate if we want to support deep copy - // as in repeated_scalar_container.cc. - GOOGLE_DCHECK(from == to); - Message* old_message = from->message; - Message* new_message = old_message->New(); - to->parent = NULL; - to->parent_field_descriptor = from->parent_field_descriptor; - to->message = new_message; - to->owner.reset(new_message); - - vector<const FieldDescriptor*> fields; - fields.push_back(from->parent_field_descriptor); - old_message->GetReflection()->SwapFields(old_message, new_message, fields); - return 0; -} - -int Release(ScalarMapContainer* self) { - return InitializeAndCopyToParentContainer(self, self); -} - -void SetOwner(ScalarMapContainer* self, - const shared_ptr<Message>& new_owner) { - self->owner = new_owner; -} - -Py_ssize_t Length(PyObject* _self) { - ScalarMapContainer* self = GetMap(_self); - google::protobuf::Message* message = self->message; - return message->GetReflection()->FieldSize(*message, - self->parent_field_descriptor); -} - -int MapKeyMatches(ScalarMapContainer* self, const Message* entry, - PyObject* key) { - // TODO(haberman): do we need more strict type checking? - ScopedPyObjectPtr entry_key( - cmessage::InternalGetScalar(entry, self->key_field_descriptor)); - int ret = PyObject_RichCompareBool(key, entry_key.get(), Py_EQ); - return ret; -} - -PyObject* GetItem(PyObject* _self, PyObject* key) { - ScalarMapContainer* self = GetMap(_self); - - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - - // Right now the Reflection API doesn't support map lookup, so we implement it - // via linear search. - // - // TODO(haberman): add lookup API to Reflection API. - size_t size = reflection->FieldSize(*message, self->parent_field_descriptor); - for (int i = size - 1; i >= 0; i--) { - const Message& entry = reflection->GetRepeatedMessage( - *message, self->parent_field_descriptor, i); - int matches = MapKeyMatches(self, &entry, key); - if (matches < 0) return NULL; - if (matches) { - return cmessage::InternalGetScalar(&entry, self->value_field_descriptor); - } - } - - // Need to add a new entry. - Message* entry = - reflection->AddMessage(message, self->parent_field_descriptor); - PyObject* ret = NULL; - - if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor, - key) >= 0) { - ret = cmessage::InternalGetScalar(entry, self->value_field_descriptor); - } - - self->version++; - - // If there was a type error above, it set the Python exception. - return ret; -} - -int SetItem(PyObject *_self, PyObject *key, PyObject *v) { - ScalarMapContainer* self = GetMap(_self); - cmessage::AssureWritable(self->parent); - - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - size_t size = - reflection->FieldSize(*message, self->parent_field_descriptor); - self->version++; - - if (v) { - // Set item. - // - // Right now the Reflection API doesn't support map lookup, so we implement - // it via linear search. - // - // TODO(haberman): add lookup API to Reflection API. - for (int i = size - 1; i >= 0; i--) { - Message* entry = reflection->MutableRepeatedMessage( - message, self->parent_field_descriptor, i); - int matches = MapKeyMatches(self, entry, key); - if (matches < 0) return -1; - if (matches) { - return cmessage::InternalSetNonOneofScalar( - entry, self->value_field_descriptor, v); - } - } - - // Key is not already present; insert a new entry. - Message* entry = - reflection->AddMessage(message, self->parent_field_descriptor); - - if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor, - key) < 0 || - cmessage::InternalSetNonOneofScalar(entry, self->value_field_descriptor, - v) < 0) { - reflection->RemoveLast(message, self->parent_field_descriptor); - return -1; - } - - return 0; - } else { - bool found = false; - for (int i = size - 1; i >= 0; i--) { - Message* entry = reflection->MutableRepeatedMessage( - message, self->parent_field_descriptor, i); - int matches = MapKeyMatches(self, entry, key); - if (matches < 0) return -1; - if (matches) { - found = true; - if (i != (int)size - 1) { - reflection->SwapElements(message, self->parent_field_descriptor, i, - size - 1); - } - reflection->RemoveLast(message, self->parent_field_descriptor); - - // Can't exit now, the repeated field representation of maps allows - // duplicate keys, and we have to be sure to remove all of them. - } - } - - if (found) { - return 0; - } else { - PyErr_Format(PyExc_KeyError, "Key not present in map"); - return -1; - } - } -} - -PyObject* GetIterator(PyObject *_self) { - ScalarMapContainer* self = GetMap(_self); - - ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapIterator_Type, 0)); - if (obj == NULL) { - return PyErr_Format(PyExc_KeyError, "Could not allocate iterator"); - } - - ScalarMapIterator* iter = GetIter(obj.get()); - - Py_INCREF(self); - iter->container = self; - iter->version = self->version; - iter->dict = PyDict_New(); - if (iter->dict == NULL) { - return PyErr_Format(PyExc_RuntimeError, - "Could not allocate dict for iterator."); - } - - // Build the entire map into a dict right now. Start from the beginning so - // that later entries win in the case of duplicates. - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - - // Right now the Reflection API doesn't support map lookup, so we implement it - // via linear search. We need to search from the end because the underlying - // representation can have duplicates if a user calls MergeFrom(); the last - // one needs to win. - // - // TODO(haberman): add lookup API to Reflection API. - size_t size = - reflection->FieldSize(*message, self->parent_field_descriptor); - for (size_t i = 0; i < size; i++) { - Message* entry = reflection->MutableRepeatedMessage( - message, self->parent_field_descriptor, i); - ScopedPyObjectPtr key( - cmessage::InternalGetScalar(entry, self->key_field_descriptor)); - ScopedPyObjectPtr val( - cmessage::InternalGetScalar(entry, self->value_field_descriptor)); - if (PyDict_SetItem(iter->dict, key.get(), val.get()) < 0) { - return PyErr_Format(PyExc_RuntimeError, - "SetItem failed in iterator construction."); - } - } - - - iter->iter = PyObject_GetIter(iter->dict); - - - return obj.release(); -} - -PyObject* Clear(PyObject* _self) { - ScalarMapContainer* self = GetMap(_self); - cmessage::AssureWritable(self->parent); - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - - reflection->ClearField(message, self->parent_field_descriptor); - - Py_RETURN_NONE; -} - -PyObject* Contains(PyObject* _self, PyObject* key) { - ScalarMapContainer* self = GetMap(_self); - - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - - // Right now the Reflection API doesn't support map lookup, so we implement it - // via linear search. - // - // TODO(haberman): add lookup API to Reflection API. - size_t size = reflection->FieldSize(*message, self->parent_field_descriptor); - for (int i = size - 1; i >= 0; i--) { - const Message& entry = reflection->GetRepeatedMessage( - *message, self->parent_field_descriptor, i); - int matches = MapKeyMatches(self, &entry, key); - if (matches < 0) return NULL; - if (matches) { - Py_RETURN_TRUE; - } - } - - Py_RETURN_FALSE; -} - -PyObject* Get(PyObject* self, PyObject* args) { - PyObject* key; - PyObject* default_value = NULL; - if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) { - return NULL; - } - - ScopedPyObjectPtr is_present(Contains(self, key)); - if (is_present.get() == NULL) { - return NULL; - } - - if (PyObject_IsTrue(is_present.get())) { - return GetItem(self, key); - } else { - if (default_value != NULL) { - Py_INCREF(default_value); - return default_value; - } else { - Py_RETURN_NONE; - } - } -} - -static void Dealloc(PyObject* _self) { - ScalarMapContainer* self = GetMap(_self); - self->owner.reset(); - Py_TYPE(_self)->tp_free(_self); -} - -static PyMethodDef Methods[] = { - { "__contains__", Contains, METH_O, - "Tests whether a key is a member of the map." }, - { "clear", (PyCFunction)Clear, METH_NOARGS, - "Removes all elements from the map." }, - { "get", Get, METH_VARARGS, - "Gets the value for the given key if present, or otherwise a default" }, - /* - { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, - "Makes a deep copy of the class." }, - { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, - "Outputs picklable representation of the repeated field." }, - */ - {NULL, NULL}, -}; - -} // namespace scalar_map_container - -namespace scalar_map_iterator { - -static void Dealloc(PyObject* _self) { - ScalarMapIterator* self = GetIter(_self); - Py_DECREF(self->dict); - Py_DECREF(self->iter); - Py_DECREF(self->container); - Py_TYPE(_self)->tp_free(_self); -} - -PyObject* IterNext(PyObject* _self) { - ScalarMapIterator* self = GetIter(_self); - - // This won't catch mutations to the map performed by MergeFrom(); no easy way - // to address that. - if (self->version != self->container->version) { - return PyErr_Format(PyExc_RuntimeError, - "Map modified during iteration."); - } - - return PyIter_Next(self->iter); -} - -} // namespace scalar_map_iterator - - -#if PY_MAJOR_VERSION >= 3 - static PyType_Slot ScalarMapContainer_Type_slots[] = { - {Py_tp_dealloc, (void *)scalar_map_container::Dealloc}, - {Py_mp_length, (void *)scalar_map_container::Length}, - {Py_mp_subscript, (void *)scalar_map_container::GetItem}, - {Py_mp_ass_subscript, (void *)scalar_map_container::SetItem}, - {Py_tp_methods, (void *)scalar_map_container::Methods}, - {Py_tp_iter, (void *)scalar_map_container::GetIterator}, - {0, 0}, - }; - - PyType_Spec ScalarMapContainer_Type_spec = { - FULL_MODULE_NAME ".ScalarMapContainer", - sizeof(ScalarMapContainer), - 0, - Py_TPFLAGS_DEFAULT, - ScalarMapContainer_Type_slots - }; - PyObject *ScalarMapContainer_Type; -#else - static PyMappingMethods MpMethods = { - scalar_map_container::Length, // mp_length - scalar_map_container::GetItem, // mp_subscript - scalar_map_container::SetItem, // mp_ass_subscript - }; - - PyTypeObject ScalarMapContainer_Type = { - PyVarObject_HEAD_INIT(&PyType_Type, 0) - FULL_MODULE_NAME ".ScalarMapContainer", // tp_name - sizeof(ScalarMapContainer), // tp_basicsize - 0, // tp_itemsize - scalar_map_container::Dealloc, // tp_dealloc - 0, // tp_print - 0, // tp_getattr - 0, // tp_setattr - 0, // tp_compare - 0, // tp_repr - 0, // tp_as_number - 0, // tp_as_sequence - &MpMethods, // tp_as_mapping - 0, // tp_hash - 0, // tp_call - 0, // tp_str - 0, // tp_getattro - 0, // tp_setattro - 0, // tp_as_buffer - Py_TPFLAGS_DEFAULT, // tp_flags - "A scalar map container", // tp_doc - 0, // tp_traverse - 0, // tp_clear - 0, // tp_richcompare - 0, // tp_weaklistoffset - scalar_map_container::GetIterator, // tp_iter - 0, // tp_iternext - scalar_map_container::Methods, // tp_methods - 0, // tp_members - 0, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init - }; -#endif - -PyTypeObject ScalarMapIterator_Type = { - PyVarObject_HEAD_INIT(&PyType_Type, 0) - FULL_MODULE_NAME ".ScalarMapIterator", // tp_name - sizeof(ScalarMapIterator), // tp_basicsize - 0, // tp_itemsize - scalar_map_iterator::Dealloc, // tp_dealloc - 0, // tp_print - 0, // tp_getattr - 0, // tp_setattr - 0, // tp_compare - 0, // tp_repr - 0, // tp_as_number - 0, // tp_as_sequence - 0, // tp_as_mapping - 0, // tp_hash - 0, // tp_call - 0, // tp_str - 0, // tp_getattro - 0, // tp_setattro - 0, // tp_as_buffer - Py_TPFLAGS_DEFAULT, // tp_flags - "A scalar map iterator", // tp_doc - 0, // tp_traverse - 0, // tp_clear - 0, // tp_richcompare - 0, // tp_weaklistoffset - PyObject_SelfIter, // tp_iter - scalar_map_iterator::IterNext, // tp_iternext - 0, // tp_methods - 0, // tp_members - 0, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init -}; - -} // namespace python -} // namespace protobuf -} // namespace google diff --git a/python/google/protobuf/pyext/scalar_map_container.h b/python/google/protobuf/pyext/scalar_map_container.h deleted file mode 100644 index 4d663b88..00000000 --- a/python/google/protobuf/pyext/scalar_map_container.h +++ /dev/null @@ -1,119 +0,0 @@ -// Protocol Buffers - Google's data interchange format -// Copyright 2008 Google Inc. All rights reserved. -// https://developers.google.com/protocol-buffers/ -// -// Redistribution and use in source and binary forms, with or without -// modification, are permitted provided that the following conditions are -// met: -// -// * Redistributions of source code must retain the above copyright -// notice, this list of conditions and the following disclaimer. -// * Redistributions in binary form must reproduce the above -// copyright notice, this list of conditions and the following disclaimer -// in the documentation and/or other materials provided with the -// distribution. -// * Neither the name of Google Inc. nor the names of its -// contributors may be used to endorse or promote products derived from -// this software without specific prior written permission. -// -// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__ -#define GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__ - -#include <Python.h> - -#include <memory> -#ifndef _SHARED_PTR_H -#include <google/protobuf/stubs/shared_ptr.h> -#endif - -#include <google/protobuf/descriptor.h> - -namespace google { -namespace protobuf { - -class Message; - -#ifdef _SHARED_PTR_H -using std::shared_ptr; -#else -using internal::shared_ptr; -#endif - -namespace python { - -struct CMessage; - -struct ScalarMapContainer { - PyObject_HEAD; - - // This is the top-level C++ Message object that owns the whole - // proto tree. Every Python ScalarMapContainer holds a - // reference to it in order to keep it alive as long as there's a - // Python object that references any part of the tree. - shared_ptr<Message> owner; - - // Pointer to the C++ Message that contains this container. The - // ScalarMapContainer does not own this pointer. - Message* message; - - // Weak reference to a parent CMessage object (i.e. may be NULL.) - // - // Used to make sure all ancestors are also mutable when first - // modifying the container. - CMessage* parent; - - // Pointer to the parent's descriptor that describes this - // field. Used together with the parent's message when making a - // default message instance mutable. - // The pointer is owned by the global DescriptorPool. - const FieldDescriptor* parent_field_descriptor; - const FieldDescriptor* key_field_descriptor; - const FieldDescriptor* value_field_descriptor; - - // We bump this whenever we perform a mutation, to invalidate existing - // iterators. - uint64 version; -}; - -#if PY_MAJOR_VERSION >= 3 - extern PyObject *ScalarMapContainer_Type; - extern PyType_Spec ScalarMapContainer_Type_spec; -#else - extern PyTypeObject ScalarMapContainer_Type; -#endif -extern PyTypeObject ScalarMapIterator_Type; - -namespace scalar_map_container { - -// Builds a ScalarMapContainer object, from a parent message and a -// field descriptor. -extern PyObject *NewContainer( - CMessage* parent, const FieldDescriptor* parent_field_descriptor); - -// Releases the messages in the container to a new message. -// -// Returns 0 on success, -1 on failure. -int Release(ScalarMapContainer* self); - -// Set the owner field of self and any children of self. -void SetOwner(ScalarMapContainer* self, - const shared_ptr<Message>& new_owner); - -} // namespace scalar_map_container -} // namespace python -} // namespace protobuf - -} // namespace google -#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__ diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py index b81ef4d7..87760f26 100644 --- a/python/google/protobuf/symbol_database.py +++ b/python/google/protobuf/symbol_database.py @@ -60,7 +60,6 @@ Example usage: """ -from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool @@ -73,37 +72,12 @@ class SymbolDatabase(object): buffer types used within a program. """ - # pylint: disable=protected-access - if _descriptor._USE_C_DESCRIPTORS: - - def __new__(cls): - raise TypeError("Instances of SymbolDatabase cannot be created") - - @classmethod - def _CreateDefaultDatabase(cls): - self = object.__new__(cls) # Bypass the __new__ above. - # Don't call __init__() and initialize here. - self._symbols = {} - self._symbols_by_file = {} - # As of today all descriptors are registered and retrieved from - # _message.default_pool (see FileDescriptor.__new__), so it's not - # necessary to use another pool. - self.pool = _descriptor._message.default_pool - return self - # pylint: enable=protected-access - - else: - - @classmethod - def _CreateDefaultDatabase(cls): - return cls() - - def __init__(self): + def __init__(self, pool=None): """Constructor.""" self._symbols = {} self._symbols_by_file = {} - self.pool = descriptor_pool.DescriptorPool() + self.pool = pool or descriptor_pool.Default() def RegisterMessage(self, message): """Registers the given message type in the local database. @@ -203,7 +177,7 @@ class SymbolDatabase(object): result.update(self._symbols_by_file[f]) return result -_DEFAULT = SymbolDatabase._CreateDefaultDatabase() +_DEFAULT = SymbolDatabase(pool=descriptor_pool.Default()) def Default(): diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index e4fadf09..8d256076 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -66,6 +66,7 @@ _FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE) _FLOAT_NAN = re.compile('nanf?', re.IGNORECASE) _FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) +_QUOTES = frozenset(("'", '"')) class Error(Exception): @@ -73,7 +74,8 @@ class Error(Exception): class ParseError(Error): - """Thrown in case of ASCII parsing error.""" + """Thrown in case of text parsing error.""" + class TextWriter(object): def __init__(self, as_utf8): @@ -102,7 +104,8 @@ def MessageToString(message, as_utf8=False, as_one_line=False, Floating point values can be formatted compactly with 15 digits of precision (which is the most that IEEE 754 "double" can guarantee) - using float_format='.15g'. + using float_format='.15g'. To ensure that converting to text and back to a + proto will result in an identical value, float_format='.17g' should be used. Args: message: The protocol buffers message. @@ -130,11 +133,13 @@ def MessageToString(message, as_utf8=False, as_one_line=False, return result.rstrip() return result + def _IsMapEntry(field): return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and field.message_type.has_options and field.message_type.GetOptions().map_entry) + def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, float_format=None): @@ -166,17 +171,18 @@ def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, use_index_order=use_index_order, float_format=float_format) + def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, float_format=None): """Print a single field name/value pair. For repeated fields, the value - should be a single element.""" + should be a single element. + """ out.write(' ' * indent) if field.is_extension: out.write('[') if (field.containing_type.GetOptions().message_set_wire_format and field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and - field.message_type == field.extension_scope and field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): out.write(field.message_type.full_name) else: @@ -262,95 +268,113 @@ def PrintFieldValue(field, value, out, indent=0, as_utf8=False, out.write(str(value)) -def Parse(text, message): - """Parses an ASCII representation of a protocol message into a message. +def Parse(text, message, allow_unknown_extension=False): + """Parses an text representation of a protocol message into a message. Args: - text: Message ASCII representation. + text: Message text representation. message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing Returns: The same message passed as argument. Raises: - ParseError: On ASCII parsing problems. + ParseError: On text parsing problems. """ - if not isinstance(text, str): text = text.decode('utf-8') - return ParseLines(text.split('\n'), message) + if not isinstance(text, str): + text = text.decode('utf-8') + return ParseLines(text.split('\n'), message, allow_unknown_extension) -def Merge(text, message): - """Parses an ASCII representation of a protocol message into a message. +def Merge(text, message, allow_unknown_extension=False): + """Parses an text representation of a protocol message into a message. Like Parse(), but allows repeated values for a non-repeated field, and uses the last one. Args: - text: Message ASCII representation. + text: Message text representation. message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing Returns: The same message passed as argument. Raises: - ParseError: On ASCII parsing problems. + ParseError: On text parsing problems. """ - return MergeLines(text.split('\n'), message) + return MergeLines(text.split('\n'), message, allow_unknown_extension) -def ParseLines(lines, message): - """Parses an ASCII representation of a protocol message into a message. +def ParseLines(lines, message, allow_unknown_extension=False): + """Parses an text representation of a protocol message into a message. Args: - lines: An iterable of lines of a message's ASCII representation. + lines: An iterable of lines of a message's text representation. message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing Returns: The same message passed as argument. Raises: - ParseError: On ASCII parsing problems. + ParseError: On text parsing problems. """ - _ParseOrMerge(lines, message, False) + _ParseOrMerge(lines, message, False, allow_unknown_extension) return message -def MergeLines(lines, message): - """Parses an ASCII representation of a protocol message into a message. +def MergeLines(lines, message, allow_unknown_extension=False): + """Parses an text representation of a protocol message into a message. Args: - lines: An iterable of lines of a message's ASCII representation. + lines: An iterable of lines of a message's text representation. message: A protocol buffer message to merge into. + allow_unknown_extension: if True, skip over missing extensions and keep + parsing Returns: The same message passed as argument. Raises: - ParseError: On ASCII parsing problems. + ParseError: On text parsing problems. """ - _ParseOrMerge(lines, message, True) + _ParseOrMerge(lines, message, True, allow_unknown_extension) return message -def _ParseOrMerge(lines, message, allow_multiple_scalars): - """Converts an ASCII representation of a protocol message into a message. +def _ParseOrMerge(lines, + message, + allow_multiple_scalars, + allow_unknown_extension=False): + """Converts an text representation of a protocol message into a message. Args: - lines: Lines of a message's ASCII representation. + lines: Lines of a message's text representation. message: A protocol buffer message to merge into. allow_multiple_scalars: Determines if repeated values for a non-repeated field are permitted, e.g., the string "foo: 1 foo: 2" for a required/optional field named "foo". + allow_unknown_extension: if True, skip over missing extensions and keep + parsing Raises: - ParseError: On ASCII parsing problems. + ParseError: On text parsing problems. """ tokenizer = _Tokenizer(lines) while not tokenizer.AtEnd(): - _MergeField(tokenizer, message, allow_multiple_scalars) + _MergeField(tokenizer, message, allow_multiple_scalars, + allow_unknown_extension) -def _MergeField(tokenizer, message, allow_multiple_scalars): +def _MergeField(tokenizer, + message, + allow_multiple_scalars, + allow_unknown_extension=False): """Merges a single protocol message field into a message. Args: @@ -359,9 +383,11 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): allow_multiple_scalars: Determines if repeated values for a non-repeated field are permitted, e.g., the string "foo: 1 foo: 2" for a required/optional field named "foo". + allow_unknown_extension: if True, skip over missing extensions and keep + parsing Raises: - ParseError: In case of ASCII parsing problems. + ParseError: In case of text parsing problems. """ message_descriptor = message.DESCRIPTOR if (hasattr(message_descriptor, 'syntax') and @@ -383,13 +409,18 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): field = message.Extensions._FindExtensionByName(name) # pylint: enable=protected-access if not field: - raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" not registered.' % name) + if allow_unknown_extension: + field = None + else: + raise tokenizer.ParseErrorPreviousToken( + 'Extension "%s" not registered.' % name) elif message_descriptor != field.containing_type: raise tokenizer.ParseErrorPreviousToken( 'Extension "%s" does not extend message type "%s".' % ( name, message_descriptor.full_name)) + tokenizer.Consume(']') + else: name = tokenizer.ConsumeIdentifier() field = message_descriptor.fields_by_name.get(name, None) @@ -411,7 +442,7 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): 'Message type "%s" has no field named "%s".' % ( message_descriptor.full_name, name)) - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + if field and field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: is_map_entry = _IsMapEntry(field) tokenizer.TryConsume(':') @@ -438,7 +469,8 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): while not tokenizer.TryConsume(end_token): if tokenizer.AtEnd(): raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token)) - _MergeField(tokenizer, sub_message, allow_multiple_scalars) + _MergeField(tokenizer, sub_message, allow_multiple_scalars, + allow_unknown_extension) if is_map_entry: value_cpptype = field.message_type.fields_by_name['value'].cpp_type @@ -447,8 +479,63 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): value.MergeFrom(sub_message.value) else: getattr(message, field.name)[sub_message.key] = sub_message.value + elif field: + tokenizer.Consume(':') + if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and + tokenizer.TryConsume('[')): + # Short repeated format, e.g. "foo: [1, 2, 3]" + while True: + _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) + if tokenizer.TryConsume(']'): + break + tokenizer.Consume(',') + else: + _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) + else: # Proto field is unknown. + assert allow_unknown_extension + _SkipFieldContents(tokenizer) + + # For historical reasons, fields may optionally be separated by commas or + # semicolons. + if not tokenizer.TryConsume(','): + tokenizer.TryConsume(';') + + +def _SkipFieldContents(tokenizer): + """Skips over contents (value or message) of a field. + + Args: + tokenizer: A tokenizer to parse the field name and values. + """ + # Try to guess the type of this field. + # If this field is not a message, there should be a ":" between the + # field name and the field value and also the field value should not + # start with "{" or "<" which indicates the beginning of a message body. + # If there is no ":" or there is a "{" or "<" after ":", this field has + # to be a message or the input is ill-formed. + if tokenizer.TryConsume(':') and not tokenizer.LookingAt( + '{') and not tokenizer.LookingAt('<'): + _SkipFieldValue(tokenizer) + else: + _SkipFieldMessage(tokenizer) + + +def _SkipField(tokenizer): + """Skips over a complete field (name and value/message). + + Args: + tokenizer: A tokenizer to parse the field name and values. + """ + if tokenizer.TryConsume('['): + # Consume extension name. + tokenizer.ConsumeIdentifier() + while tokenizer.TryConsume('.'): + tokenizer.ConsumeIdentifier() + tokenizer.Consume(']') else: - _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) + tokenizer.ConsumeIdentifier() + + _SkipFieldContents(tokenizer) # For historical reasons, fields may optionally be separated by commas or # semicolons. @@ -456,6 +543,48 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): tokenizer.TryConsume(';') +def _SkipFieldMessage(tokenizer): + """Skips over a field message. + + Args: + tokenizer: A tokenizer to parse the field name and values. + """ + + if tokenizer.TryConsume('<'): + delimiter = '>' + else: + tokenizer.Consume('{') + delimiter = '}' + + while not tokenizer.LookingAt('>') and not tokenizer.LookingAt('}'): + _SkipField(tokenizer) + + tokenizer.Consume(delimiter) + + +def _SkipFieldValue(tokenizer): + """Skips over a field value. + + Args: + tokenizer: A tokenizer to parse the field name and values. + + Raises: + ParseError: In case an invalid field value is found. + """ + # String tokens can come in multiple adjacent string literals. + # If we can consume one, consume as many as we can. + if tokenizer.TryConsumeString(): + while tokenizer.TryConsumeString(): + pass + return + + if (not tokenizer.TryConsumeIdentifier() and + not tokenizer.TryConsumeInt64() and + not tokenizer.TryConsumeUint64() and + not tokenizer.TryConsumeFloat()): + raise ParseError('Invalid field value: ' + tokenizer.token) + + def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars): """Merges a single protocol message scalar field into a message. @@ -468,10 +597,9 @@ def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars): required/optional field named "foo". Raises: - ParseError: In case of ASCII parsing problems. + ParseError: In case of text parsing problems. RuntimeError: On runtime errors. """ - tokenizer.Consume(':') value = None if field.type in (descriptor.FieldDescriptor.TYPE_INT32, @@ -525,7 +653,7 @@ def _MergeScalarField(tokenizer, message, field, allow_multiple_scalars): class _Tokenizer(object): - """Protocol buffer ASCII representation tokenizer. + """Protocol buffer text representation tokenizer. This class handles the lower level string parsing by splitting it into meaningful tokens. @@ -534,11 +662,13 @@ class _Tokenizer(object): """ _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE) - _TOKEN = re.compile( - '[a-zA-Z_][0-9a-zA-Z_+-]*|' # an identifier - '[0-9+-][0-9a-zA-Z_.+-]*|' # a number - '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string - '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string + _TOKEN = re.compile('|'.join([ + r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier + r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number + ] + [ # quoted str for each quote mark + r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES + ])) + _IDENTIFIER = re.compile(r'\w+') def __init__(self, lines): @@ -555,6 +685,9 @@ class _Tokenizer(object): self._SkipWhitespace() self.NextToken() + def LookingAt(self, token): + return self.token == token + def AtEnd(self): """Checks the end of the text was reached. @@ -610,6 +743,13 @@ class _Tokenizer(object): if not self.TryConsume(token): raise self._ParseError('Expected "%s".' % token) + def TryConsumeIdentifier(self): + try: + self.ConsumeIdentifier() + return True + except ParseError: + return False + def ConsumeIdentifier(self): """Consumes protocol message field identifier. @@ -657,6 +797,13 @@ class _Tokenizer(object): self.NextToken() return result + def TryConsumeInt64(self): + try: + self.ConsumeInt64() + return True + except ParseError: + return False + def ConsumeInt64(self): """Consumes a signed 64bit integer number. @@ -673,6 +820,13 @@ class _Tokenizer(object): self.NextToken() return result + def TryConsumeUint64(self): + try: + self.ConsumeUint64() + return True + except ParseError: + return False + def ConsumeUint64(self): """Consumes an unsigned 64bit integer number. @@ -689,6 +843,13 @@ class _Tokenizer(object): self.NextToken() return result + def TryConsumeFloat(self): + try: + self.ConsumeFloat() + return True + except ParseError: + return False + def ConsumeFloat(self): """Consumes an floating point number. @@ -721,6 +882,13 @@ class _Tokenizer(object): self.NextToken() return result + def TryConsumeString(self): + try: + self.ConsumeString() + return True + except ParseError: + return False + def ConsumeString(self): """Consumes a string value. @@ -746,7 +914,7 @@ class _Tokenizer(object): ParseError: If a byte array value couldn't be consumed. """ the_list = [self._ConsumeSingleByteString()] - while self.token and self.token[0] in ('\'', '"'): + while self.token and self.token[0] in _QUOTES: the_list.append(self._ConsumeSingleByteString()) return b''.join(the_list) @@ -757,11 +925,13 @@ class _Tokenizer(object): tokens which are automatically concatenated, like in C or Python. This method only consumes one token. + Returns: + The token parsed. Raises: ParseError: When the wrong format data is found. """ text = self.token - if len(text) < 1 or text[0] not in ('\'', '"'): + if len(text) < 1 or text[0] not in _QUOTES: raise self._ParseError('Expected string but found: %r' % (text,)) if len(text) < 2 or text[-1] != text[0]: diff --git a/python/setup.py b/python/setup.py index 18865e03..22f6e816 100755 --- a/python/setup.py +++ b/python/setup.py @@ -89,6 +89,7 @@ def GenerateUnittestProtos(): generate_proto("../src/google/protobuf/unittest_no_generic_services.proto", False) generate_proto("../src/google/protobuf/unittest_proto3_arena.proto", False) generate_proto("../src/google/protobuf/util/json_format_proto3.proto", False) + generate_proto("google/protobuf/internal/any_test.proto", False) generate_proto("google/protobuf/internal/descriptor_pool_test1.proto", False) generate_proto("google/protobuf/internal/descriptor_pool_test2.proto", False) generate_proto("google/protobuf/internal/factory_test1.proto", False) |