From eee38b0c018b3279f77d03dff796f440f40d3516 Mon Sep 17 00:00:00 2001 From: Feng Xiao Date: Sat, 22 Aug 2015 18:25:48 -0700 Subject: Down-integrate from google3. --- python/google/protobuf/internal/containers.py | 11 +- python/google/protobuf/internal/generator_test.py | 3 +- python/google/protobuf/internal/message_test.py | 115 ++++ .../protobuf/internal/packed_field_test.proto | 73 +++ python/google/protobuf/internal/python_message.py | 151 +++-- python/google/protobuf/internal/reflection_test.py | 48 +- python/google/protobuf/internal/test_util.py | 3 +- .../google/protobuf/internal/text_format_test.py | 31 + .../protobuf/internal/unknown_fields_test.py | 72 +- python/google/protobuf/pyext/cpp_message.py | 36 +- python/google/protobuf/pyext/descriptor.cc | 2 +- python/google/protobuf/pyext/descriptor_pool.cc | 7 + python/google/protobuf/pyext/descriptor_pool.h | 10 + python/google/protobuf/pyext/extension_dict.cc | 6 +- python/google/protobuf/pyext/message.cc | 724 +++++++++++++-------- python/google/protobuf/pyext/message.h | 4 - .../google/protobuf/pyext/message_map_container.cc | 1 + .../protobuf/pyext/repeated_composite_container.cc | 211 +----- .../protobuf/pyext/repeated_scalar_container.cc | 11 +- .../google/protobuf/pyext/scalar_map_container.cc | 1 + python/google/protobuf/pyext/scoped_pyobject_ptr.h | 16 +- python/google/protobuf/reflection.py | 96 +-- python/google/protobuf/text_format.py | 2 +- 23 files changed, 969 insertions(+), 665 deletions(-) (limited to 'python/google') diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 72c2fa01..9c8275eb 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -41,6 +41,7 @@ are: __author__ = 'petar@google.com (Petar Petrov)' +import collections import sys if sys.version_info[0] < 3: @@ -63,7 +64,6 @@ if sys.version_info[0] < 3: # Note: deriving from object is critical. It is the only thing that makes # this a true type, allowing us to derive from it in C++ cleanly and making # __slots__ properly disallow arbitrary element assignment. - from collections import Mapping as _Mapping class Mapping(object): __slots__ = () @@ -106,7 +106,7 @@ if sys.version_info[0] < 3: __hash__ = None def __eq__(self, other): - if not isinstance(other, _Mapping): + if not isinstance(other, collections.Mapping): return NotImplemented return dict(self.items()) == dict(other.items()) @@ -173,12 +173,13 @@ if sys.version_info[0] < 3: self[key] = default return default - _Mapping.register(Mapping) + collections.Mapping.register(Mapping) + collections.MutableMapping.register(MutableMapping) else: # In Python 3 we can just use MutableMapping directly, because it defines # __slots__. - from collections import MutableMapping + MutableMapping = collections.MutableMapping class BaseContainer(object): @@ -336,6 +337,8 @@ class RepeatedScalarFieldContainer(BaseContainer): # We are presumably comparing against some other sequence type. return other == self._values +collections.MutableSequence.register(BaseContainer) + class RepeatedCompositeFieldContainer(BaseContainer): diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index 5c07cbe6..c30f633d 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -47,6 +47,7 @@ from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_mset_wire_format_pb2 from google.protobuf import unittest_no_generic_services_pb2 from google.protobuf import unittest_pb2 from google.protobuf import service @@ -142,7 +143,7 @@ class GeneratorTest(unittest.TestCase): self.assertTrue(not non_extension_descriptor.is_extension) def testOptions(self): - proto = unittest_mset_pb2.TestMessageSet() + proto = unittest_mset_wire_format_pb2.TestMessageSet() self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) def testMessageWithCustomOptions(self): diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 320ff0d2..62abf1be 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -43,6 +43,7 @@ abstract interface. __author__ = 'gps@google.com (Gregory P. Smith)' +import collections import copy import math import operator @@ -56,6 +57,7 @@ from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation +from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf import message @@ -421,6 +423,31 @@ class MessageTest(unittest.TestCase): self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[5].bb, 6) + def testSortingRepeatedCompositeFieldsStable(self, message_module): + """Check passing a custom comparator to sort a repeated composite field.""" + message = message_module.TestAllTypes() + + message.repeated_nested_message.add().bb = 21 + message.repeated_nested_message.add().bb = 20 + message.repeated_nested_message.add().bb = 13 + message.repeated_nested_message.add().bb = 33 + message.repeated_nested_message.add().bb = 11 + message.repeated_nested_message.add().bb = 24 + message.repeated_nested_message.add().bb = 10 + message.repeated_nested_message.sort(key=lambda z: z.bb // 10) + self.assertEquals( + [13, 11, 10, 21, 20, 24, 33], + [n.bb for n in message.repeated_nested_message]) + + # Make sure that for the C++ implementation, the underlying fields + # are actually reordered. + pb = message.SerializeToString() + message.Clear() + message.MergeFromString(pb) + self.assertEquals( + [13, 11, 10, 21, 20, 24, 33], + [n.bb for n in message.repeated_nested_message]) + def testRepeatedCompositeFieldSortArguments(self, message_module): """Check sorting a repeated composite field using list.sort() arguments.""" message = message_module.TestAllTypes() @@ -514,6 +541,12 @@ class MessageTest(unittest.TestCase): # TODO(anuraag): Implement extensiondict comparison in C++ and then add test + def testRepeatedFieldsAreSequences(self, message_module): + m = message_module.TestAllTypes() + self.assertIsInstance(m.repeated_int32, collections.MutableSequence) + self.assertIsInstance(m.repeated_nested_message, + collections.MutableSequence) + def ensureNestedMessageExists(self, msg, attribute): """Make sure that a nested message object exists. @@ -556,6 +589,18 @@ class MessageTest(unittest.TestCase): self.assertFalse(m.HasField('oneof_uint32')) self.assertTrue(m.HasField('oneof_string')) + # Read nested message accessor without accessing submessage. + m.oneof_nested_message + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_string')) + self.assertFalse(m.HasField('oneof_nested_message')) + + # Read accessor of nested message without accessing submessage. + m.oneof_nested_message.bb + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_string')) + self.assertFalse(m.HasField('oneof_nested_message')) + m.oneof_nested_message.bb = 11 self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) self.assertFalse(m.HasField('oneof_string')) @@ -1583,6 +1628,21 @@ class Proto3Test(unittest.TestCase): del msg.map_int32_int32[4] self.assertEqual(0, len(msg.map_int32_int32)) + def testMapsAreMapping(self): + msg = map_unittest_pb2.TestMap() + self.assertIsInstance(msg.map_int32_int32, collections.Mapping) + self.assertIsInstance(msg.map_int32_int32, collections.MutableMapping) + self.assertIsInstance(msg.map_int32_foreign_message, collections.Mapping) + self.assertIsInstance(msg.map_int32_foreign_message, + collections.MutableMapping) + + def testMapFindInitializationErrorsSmokeTest(self): + msg = map_unittest_pb2.TestMap() + msg.map_string_string['abc'] = '123' + msg.map_int32_int32[35] = 64 + msg.map_string_foreign_message['foo'].c = 5 + self.assertEqual(0, len(msg.FindInitializationErrors())) + class ValidTypeNamesTest(unittest.TestCase): @@ -1606,6 +1666,61 @@ class ValidTypeNamesTest(unittest.TestCase): self.assertImportFromName(pb.repeated_int32, 'Scalar') self.assertImportFromName(pb.repeated_nested_message, 'Composite') +class PackedFieldTest(unittest.TestCase): + + def setMessage(self, message): + message.repeated_int32.append(1) + message.repeated_int64.append(1) + message.repeated_uint32.append(1) + message.repeated_uint64.append(1) + message.repeated_sint32.append(1) + message.repeated_sint64.append(1) + message.repeated_fixed32.append(1) + message.repeated_fixed64.append(1) + message.repeated_sfixed32.append(1) + message.repeated_sfixed64.append(1) + message.repeated_float.append(1.0) + message.repeated_double.append(1.0) + message.repeated_bool.append(True) + message.repeated_nested_enum.append(1) + + def testPackedFields(self): + message = packed_field_test_pb2.TestPackedTypes() + self.setMessage(message) + golden_data = (b'\x0A\x01\x01' + b'\x12\x01\x01' + b'\x1A\x01\x01' + b'\x22\x01\x01' + b'\x2A\x01\x02' + b'\x32\x01\x02' + b'\x3A\x04\x01\x00\x00\x00' + b'\x42\x08\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x4A\x04\x01\x00\x00\x00' + b'\x52\x08\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x5A\x04\x00\x00\x80\x3f' + b'\x62\x08\x00\x00\x00\x00\x00\x00\xf0\x3f' + b'\x6A\x01\x01' + b'\x72\x01\x01') + self.assertEqual(golden_data, message.SerializeToString()) + + def testUnpackedFields(self): + message = packed_field_test_pb2.TestUnpackedTypes() + self.setMessage(message) + golden_data = (b'\x08\x01' + b'\x10\x01' + b'\x18\x01' + b'\x20\x01' + b'\x28\x02' + b'\x30\x02' + b'\x3D\x01\x00\x00\x00' + b'\x41\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x4D\x01\x00\x00\x00' + b'\x51\x01\x00\x00\x00\x00\x00\x00\x00' + b'\x5D\x00\x00\x80\x3f' + b'\x61\x00\x00\x00\x00\x00\x00\xf0\x3f' + b'\x68\x01' + b'\x70\x01') + self.assertEqual(golden_data, message.SerializeToString()) if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/packed_field_test.proto b/python/google/protobuf/internal/packed_field_test.proto index e69de29b..0dfdc10a 100644 --- a/python/google/protobuf/internal/packed_field_test.proto +++ b/python/google/protobuf/internal/packed_field_test.proto @@ -0,0 +1,73 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +syntax = "proto3"; + +package google.protobuf.python.internal; + +message TestPackedTypes { + enum NestedEnum { + FOO = 0; + BAR = 1; + BAZ = 2; + } + + repeated int32 repeated_int32 = 1; + repeated int64 repeated_int64 = 2; + repeated uint32 repeated_uint32 = 3; + repeated uint64 repeated_uint64 = 4; + repeated sint32 repeated_sint32 = 5; + repeated sint64 repeated_sint64 = 6; + repeated fixed32 repeated_fixed32 = 7; + repeated fixed64 repeated_fixed64 = 8; + repeated sfixed32 repeated_sfixed32 = 9; + repeated sfixed64 repeated_sfixed64 = 10; + repeated float repeated_float = 11; + repeated double repeated_double = 12; + repeated bool repeated_bool = 13; + repeated NestedEnum repeated_nested_enum = 14; +} + +message TestUnpackedTypes { + repeated int32 repeated_int32 = 1 [packed = false]; + repeated int64 repeated_int64 = 2 [packed = false]; + repeated uint32 repeated_uint32 = 3 [packed = false]; + repeated uint64 repeated_uint64 = 4 [packed = false]; + repeated sint32 repeated_sint32 = 5 [packed = false]; + repeated sint64 repeated_sint64 = 6 [packed = false]; + repeated fixed32 repeated_fixed32 = 7 [packed = false]; + repeated fixed64 repeated_fixed64 = 8 [packed = false]; + repeated sfixed32 repeated_sfixed32 = 9 [packed = false]; + repeated sfixed64 repeated_sfixed64 = 10 [packed = false]; + repeated float repeated_float = 11 [packed = false]; + repeated double repeated_double = 12 [packed = false]; + repeated bool repeated_bool = 13 [packed = false]; + repeated TestPackedTypes.NestedEnum repeated_nested_enum = 14 [packed = false]; +} diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index ca9f7675..a3e98467 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -85,34 +85,108 @@ from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor -def NewMessage(bases, descriptor, dictionary): - _AddClassAttributesForNestedExtensions(descriptor, dictionary) - _AddSlots(descriptor, dictionary) - return bases - - -def InitMessage(descriptor, cls): - cls._decoders_by_tag = {} - cls._extensions_by_name = {} - cls._extensions_by_number = {} - if (descriptor.has_options and - descriptor.GetOptions().message_set_wire_format): - cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( - decoder.MessageSetItemDecoder(cls._extensions_by_number), None) - - # Attach stuff to each FieldDescriptor for quick lookup later on. - for field in descriptor.fields: - _AttachFieldHelpers(cls, field) +class GeneratedProtocolMessageType(type): + + """Metaclass for protocol message classes created at runtime from Descriptors. + + We add implementations for all methods described in the Message class. We + also create properties to allow getting/setting all fields in the protocol + message. Finally, we create slots to prevent users from accidentally + "setting" nonexistent fields in the protocol message, which then wouldn't get + serialized / deserialized properly. + + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: + + mydescriptor = Descriptor(.....) + class MyProtoClass(Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = mydescriptor + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + + The above example will not work for nested types. If you wish to include them, + use reflection.MakeClass() instead of manually instantiating the class in + order to create the appropriate class structure. + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __new__(cls, name, bases, dictionary): + """Custom allocation for runtime-generated class types. + + We override __new__ because this is apparently the only place + where we can meaningfully set __slots__ on the class we're creating(?). + (The interplay between metaclasses and slots is not very well-documented). + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + + Returns: + Newly-allocated class. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + _AddSlots(descriptor, dictionary) + + superclass = super(GeneratedProtocolMessageType, cls) + new_class = superclass.__new__(cls, name, bases, dictionary) + return new_class + + def __init__(cls, name, bases, dictionary): + """Here we perform the majority of our work on the class. + We add enum getters, an __init__ method, implementations + of all Message methods, and properties for all fields + in the protocol type. - descriptor._concrete_class = cls # pylint: disable=protected-access - _AddEnumValues(descriptor, cls) - _AddInitMethod(descriptor, cls) - _AddPropertiesForFields(descriptor, cls) - _AddPropertiesForExtensions(descriptor, cls) - _AddStaticMethods(cls) - _AddMessageMethods(descriptor, cls) - _AddPrivateHelperMethods(descriptor, cls) - copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + cls._decoders_by_tag = {} + cls._extensions_by_name = {} + cls._extensions_by_number = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(cls._extensions_by_number), None) + + # Attach stuff to each FieldDescriptor for quick lookup later on. + for field in descriptor.fields: + _AttachFieldHelpers(cls, field) + + descriptor._concrete_class = cls # pylint: disable=protected-access + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddPropertiesForExtensions(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(descriptor, cls) + copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + + superclass = super(GeneratedProtocolMessageType, cls) + superclass.__init__(name, bases, dictionary) # Stateless helpers for GeneratedProtocolMessageType below. @@ -362,9 +436,10 @@ def _DefaultValueConstructorForField(field): message_type = field.message_type def MakeSubMessageDefault(message): result = message_type._concrete_class() - result._SetListener(message._listener_for_children) - if field.containing_oneof: - message._UpdateOneofState(field) + result._SetListener( + _OneofListener(message, field) + if field.containing_oneof is not None + else message._listener_for_children) return result return MakeSubMessageDefault @@ -634,21 +709,11 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): proto_field_name = field.name property_name = _PropertyName(proto_field_name) - # TODO(komarek): Can anyone explain to me why we cache the message_type this - # way, instead of referring to field.message_type inside of getter(self)? - # What if someone sets message_type later on (which makes for simpler - # dyanmic proto descriptor and class creation code). - message_type = field.message_type - def getter(self): field_value = self._fields.get(field) if field_value is None: # Construct a new object to represent this field. - field_value = message_type._concrete_class() # use field.message_type? - field_value._SetListener( - _OneofListener(self, field) - if field.containing_oneof is not None - else self._listener_for_children) + field_value = field._default_constructor(self) # Atomically check if another thread has preempted us and, if not, swap # in the new object we just created. If someone has preempted us, we @@ -1121,7 +1186,7 @@ def _AddIsInitializedMethod(message_descriptor, cls): if _IsMessageMapField(field): for key in value: element = value[key] - prefix = "%s[%d]." % (name, key) + prefix = "%s[%s]." % (name, key) sub_errors = element.FindInitializationErrors() errors += [prefix + error for error in sub_errors] else: @@ -1173,8 +1238,6 @@ def _AddMergeFromMethod(cls): # Construct a new object to represent this field. field_value = field._default_constructor(self) fields[field] = field_value - if field.containing_oneof: - self._UpdateOneofState(field) field_value.MergeFrom(value) else: self._fields[field] = value diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 4eca4989..ef1ced4e 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -52,6 +52,7 @@ from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import more_messages_pb2 +from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import wire_format from google.protobuf.internal import test_util from google.protobuf.internal import decoder @@ -1682,8 +1683,8 @@ class ReflectionTest(unittest.TestCase): proto.optional_string = 'abc' def testStringUTF8Serialization(self): - proto = unittest_mset_pb2.TestMessageSet() - extension_message = unittest_mset_pb2.TestMessageSetExtension2 + proto = message_set_extensions_pb2.TestMessageSet() + extension_message = message_set_extensions_pb2.TestMessageSetExtension2 extension = extension_message.message_set_extension test_utf8 = u'Тест' @@ -1703,15 +1704,14 @@ class ReflectionTest(unittest.TestCase): bytes_read = raw.MergeFromString(serialized) self.assertEqual(len(serialized), bytes_read) - message2 = unittest_mset_pb2.TestMessageSetExtension2() + message2 = message_set_extensions_pb2.TestMessageSetExtension2() self.assertEqual(1, len(raw.item)) # Check that the type_id is the same as the tag ID in the .proto file. - self.assertEqual(raw.item[0].type_id, 1547769) + self.assertEqual(raw.item[0].type_id, 98418634) # Check the actual bytes on the wire. - self.assertTrue( - raw.item[0].message.endswith(test_utf8_bytes)) + self.assertTrue(raw.item[0].message.endswith(test_utf8_bytes)) bytes_read = message2.MergeFromString(raw.item[0].message) self.assertEqual(len(raw.item[0].message), bytes_read) @@ -2395,9 +2395,9 @@ class SerializationTest(unittest.TestCase): self.assertEqual(42, second_proto.optional_nested_message.bb) def testMessageSetWireFormat(self): - proto = unittest_mset_pb2.TestMessageSet() - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 - extension_message2 = unittest_mset_pb2.TestMessageSetExtension2 + proto = message_set_extensions_pb2.TestMessageSet() + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 + extension_message2 = message_set_extensions_pb2.TestMessageSetExtension2 extension1 = extension_message1.message_set_extension extension2 = extension_message2.message_set_extension proto.Extensions[extension1].i = 123 @@ -2415,20 +2415,20 @@ class SerializationTest(unittest.TestCase): raw.MergeFromString(serialized)) self.assertEqual(2, len(raw.item)) - message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1 = message_set_extensions_pb2.TestMessageSetExtension1() self.assertEqual( len(raw.item[0].message), message1.MergeFromString(raw.item[0].message)) self.assertEqual(123, message1.i) - message2 = unittest_mset_pb2.TestMessageSetExtension2() + message2 = message_set_extensions_pb2.TestMessageSetExtension2() self.assertEqual( len(raw.item[1].message), message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) # Deserialize using the MessageSet wire format. - proto2 = unittest_mset_pb2.TestMessageSet() + proto2 = message_set_extensions_pb2.TestMessageSet() self.assertEqual( len(serialized), proto2.MergeFromString(serialized)) @@ -2446,37 +2446,37 @@ class SerializationTest(unittest.TestCase): # Add an item. item = raw.item.add() - item.type_id = 1545008 - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 - message1 = unittest_mset_pb2.TestMessageSetExtension1() + item.type_id = 98418603 + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 + message1 = message_set_extensions_pb2.TestMessageSetExtension1() message1.i = 12345 item.message = message1.SerializeToString() # Add a second, unknown extension. item = raw.item.add() - item.type_id = 1545009 - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 - message1 = unittest_mset_pb2.TestMessageSetExtension1() + item.type_id = 98418604 + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 + message1 = message_set_extensions_pb2.TestMessageSetExtension1() message1.i = 12346 item.message = message1.SerializeToString() # Add another unknown extension. item = raw.item.add() - item.type_id = 1545010 - message1 = unittest_mset_pb2.TestMessageSetExtension2() + item.type_id = 98418605 + message1 = message_set_extensions_pb2.TestMessageSetExtension2() message1.str = 'foo' item.message = message1.SerializeToString() serialized = raw.SerializeToString() # Parse message using the message set wire format. - proto = unittest_mset_pb2.TestMessageSet() + proto = message_set_extensions_pb2.TestMessageSet() self.assertEqual( len(serialized), proto.MergeFromString(serialized)) # Check that the message parsed well. - extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + extension_message1 = message_set_extensions_pb2.TestMessageSetExtension1 extension1 = extension_message1.message_set_extension self.assertEquals(12345, proto.Extensions[extension1].i) @@ -2805,7 +2805,7 @@ class SerializationTest(unittest.TestCase): class OptionsTest(unittest.TestCase): def testMessageOptions(self): - proto = unittest_mset_pb2.TestMessageSet() + proto = message_set_extensions_pb2.TestMessageSet() self.assertEqual(True, proto.DESCRIPTOR.GetOptions().message_set_wire_format) proto = unittest_pb2.TestAllTypes() @@ -2824,7 +2824,7 @@ class OptionsTest(unittest.TestCase): proto.packed_double.append(3.0) for field_descriptor, _ in proto.ListFields(): self.assertEqual(True, field_descriptor.GetOptions().packed) - self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED, + self.assertEqual(descriptor.FieldDescriptor.LABEL_REPEATED, field_descriptor.label) diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index fec65382..ac88fa81 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -604,7 +604,8 @@ def GoldenFile(filename): # Search internally. path = '.' - full_path = os.path.join(path, 'third_party/py/google/protobuf/testdata', filename) + full_path = os.path.join(path, 'third_party/py/google/protobuf/testdata', + filename) if os.path.exists(full_path): # Found it. Load the golden file from the testdata directory. return open(full_path, 'rb') diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 06bd1ee5..00e67654 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -35,6 +35,7 @@ __author__ = 'kenton@google.com (Kenton Varda)' import re +import string import unittest import unittest @@ -497,6 +498,36 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): ' }\n' '}\n') + def testMapOrderEnforcement(self): + message = map_unittest_pb2.TestMap() + for letter in string.ascii_uppercase[13:26]: + message.map_string_string[letter] = 'dummy' + for letter in reversed(string.ascii_uppercase[0:13]): + message.map_string_string[letter] = 'dummy' + golden = ''.join(( + 'map_string_string {\n key: "%c"\n value: "dummy"\n}\n' % (letter,) + for letter in string.ascii_uppercase)) + self.CompareToGoldenText(text_format.MessageToString(message), golden) + + def testMapOrderSemantics(self): + golden_lines = self.ReadGolden('map_test_data.txt') + # The C++ implementation emits defaulted-value fields, while the Python + # implementation does not. Adjusting for this is awkward, but it is + # valuable to test against a common golden file. + line_blacklist = (' key: 0\n', + ' value: 0\n', + ' key: false\n', + ' value: false\n') + golden_lines = [line for line in golden_lines if line not in line_blacklist] + + message = map_unittest_pb2.TestMap() + text_format.ParseLines(golden_lines, message) + candidate = text_format.MessageToString(message) + # The Python implementation emits "1.0" for the double value that the C++ + # implementation emits as "1". + candidate = candidate.replace('1.0', '1', 2) + self.assertMultiLineEqual(candidate, ''.join(golden_lines)) + # Tests of proto2-only features (MessageSet, extensions, etc.). class Proto2Tests(TextFormatBase): diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 1b81ae79..0dda805b 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -41,11 +41,18 @@ from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import encoder +from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import type_checkers +def SkipIfCppImplementation(func): + return unittest.skipIf( + api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, + 'C++ implementation does not expose unknown fields to Python')(func) + + class UnknownFieldsTest(unittest.TestCase): def setUp(self): @@ -83,15 +90,15 @@ class UnknownFieldsTest(unittest.TestCase): # Add an unknown extension. item = raw.item.add() - item.type_id = 1545009 - message1 = unittest_mset_pb2.TestMessageSetExtension1() + item.type_id = 98418603 + message1 = message_set_extensions_pb2.TestMessageSetExtension1() message1.i = 12345 item.message = message1.SerializeToString() serialized = raw.SerializeToString() # Parse message using the message set wire format. - proto = unittest_mset_pb2.TestMessageSet() + proto = message_set_extensions_pb2.TestMessageSet() proto.MergeFromString(serialized) # Verify that the unknown extension is serialized unchanged @@ -100,13 +107,6 @@ class UnknownFieldsTest(unittest.TestCase): new_raw.MergeFromString(reserialized) self.assertEqual(raw, new_raw) - # C++ implementation for proto2 does not currently take into account unknown - # fields when checking equality. - # - # TODO(haberman): fix this. - @unittest.skipIf( - api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'C++ implementation does not expose unknown fields to Python') def testEquals(self): message = unittest_pb2.TestEmptyMessage() message.ParseFromString(self.all_fields_data) @@ -117,9 +117,6 @@ class UnknownFieldsTest(unittest.TestCase): self.assertNotEqual(self.empty_message, message) -@unittest.skipIf( - api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'C++ implementation does not expose unknown fields to Python') class UnknownFieldsAccessorsTest(unittest.TestCase): def setUp(self): @@ -129,7 +126,14 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): self.all_fields_data = self.all_fields.SerializeToString() self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - self.unknown_fields = self.empty_message._unknown_fields + if api_implementation.Type() != 'cpp': + # _unknown_fields is an implementation detail. + self.unknown_fields = self.empty_message._unknown_fields + + # All the tests that use GetField() check an implementation detail of the + # Python implementation, which stores unknown fields as serialized strings. + # These tests are skipped by the C++ implementation: it's enough to check that + # the message is correctly serialized. def GetField(self, name): field_descriptor = self.descriptor.fields_by_name[name] @@ -142,30 +146,37 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): decoder(value, 0, len(value), self.all_fields, result_dict) return result_dict[field_descriptor] + @SkipIfCppImplementation def testEnum(self): value = self.GetField('optional_nested_enum') self.assertEqual(self.all_fields.optional_nested_enum, value) + @SkipIfCppImplementation def testRepeatedEnum(self): value = self.GetField('repeated_nested_enum') self.assertEqual(self.all_fields.repeated_nested_enum, value) + @SkipIfCppImplementation def testVarint(self): value = self.GetField('optional_int32') self.assertEqual(self.all_fields.optional_int32, value) + @SkipIfCppImplementation def testFixed32(self): value = self.GetField('optional_fixed32') self.assertEqual(self.all_fields.optional_fixed32, value) + @SkipIfCppImplementation def testFixed64(self): value = self.GetField('optional_fixed64') self.assertEqual(self.all_fields.optional_fixed64, value) + @SkipIfCppImplementation def testLengthDelimited(self): value = self.GetField('optional_string') self.assertEqual(self.all_fields.optional_string, value) + @SkipIfCppImplementation def testGroup(self): value = self.GetField('optionalgroup') self.assertEqual(self.all_fields.optionalgroup, value) @@ -173,7 +184,7 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): def testCopyFrom(self): message = unittest_pb2.TestEmptyMessage() message.CopyFrom(self.empty_message) - self.assertEqual(self.unknown_fields, message._unknown_fields) + self.assertEqual(message.SerializeToString(), self.all_fields_data) def testMergeFrom(self): message = unittest_pb2.TestAllTypes() @@ -187,27 +198,26 @@ class UnknownFieldsAccessorsTest(unittest.TestCase): message.optional_uint32 = 4 destination = unittest_pb2.TestEmptyMessage() destination.ParseFromString(message.SerializeToString()) - unknown_fields = destination._unknown_fields[:] destination.MergeFrom(source) - self.assertEqual(unknown_fields + source._unknown_fields, - destination._unknown_fields) + # Check that the fields where correctly merged, even stored in the unknown + # fields set. + message.ParseFromString(destination.SerializeToString()) + self.assertEqual(message.optional_int32, 1) + self.assertEqual(message.optional_uint32, 2) + self.assertEqual(message.optional_int64, 3) def testClear(self): self.empty_message.Clear() - self.assertEqual(0, len(self.empty_message._unknown_fields)) + # All cleared, even unknown fields. + self.assertEqual(self.empty_message.SerializeToString(), b'') def testUnknownExtensions(self): message = unittest_pb2.TestEmptyMessageWithExtensions() message.ParseFromString(self.all_fields_data) - self.assertEqual(self.empty_message._unknown_fields, - message._unknown_fields) - + self.assertEqual(message.SerializeToString(), self.all_fields_data) -@unittest.skipIf( - api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'C++ implementation does not expose unknown fields to Python') class UnknownEnumValuesTest(unittest.TestCase): def setUp(self): @@ -227,7 +237,14 @@ class UnknownEnumValuesTest(unittest.TestCase): self.message_data = self.message.SerializeToString() self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() self.missing_message.ParseFromString(self.message_data) - self.unknown_fields = self.missing_message._unknown_fields + if api_implementation.Type() != 'cpp': + # _unknown_fields is an implementation detail. + self.unknown_fields = self.missing_message._unknown_fields + + # All the tests that use GetField() check an implementation detail of the + # Python implementation, which stores unknown fields as serialized strings. + # These tests are skipped by the C++ implementation: it's enough to check that + # the message is correctly serialized. def GetField(self, name): field_descriptor = self.descriptor.fields_by_name[name] @@ -241,15 +258,18 @@ class UnknownEnumValuesTest(unittest.TestCase): decoder(value, 0, len(value), self.message, result_dict) return result_dict[field_descriptor] + @SkipIfCppImplementation def testUnknownEnumValue(self): self.assertFalse(self.missing_message.HasField('optional_nested_enum')) value = self.GetField('optional_nested_enum') self.assertEqual(self.message.optional_nested_enum, value) + @SkipIfCppImplementation def testUnknownRepeatedEnumValue(self): value = self.GetField('repeated_nested_enum') self.assertEqual(self.message.repeated_nested_enum, value) + @SkipIfCppImplementation def testUnknownPackedEnumValue(self): value = self.GetField('packed_nested_enum') self.assertEqual(self.message.packed_nested_enum, value) diff --git a/python/google/protobuf/pyext/cpp_message.py b/python/google/protobuf/pyext/cpp_message.py index 037bb72c..b215211e 100644 --- a/python/google/protobuf/pyext/cpp_message.py +++ b/python/google/protobuf/pyext/cpp_message.py @@ -37,21 +37,29 @@ Descriptor objects at runtime backed by the protocol buffer C++ API. __author__ = 'tibell@google.com (Johan Tibell)' from google.protobuf.pyext import _message -from google.protobuf import message -def NewMessage(bases, message_descriptor, dictionary): - """Creates a new protocol message *class*.""" - new_bases = [] - for base in bases: - if base is message.Message: - # _message.Message must come before message.Message as it - # overrides methods in that class. - new_bases.append(_message.Message) - new_bases.append(base) - return tuple(new_bases) +class GeneratedProtocolMessageType(_message.MessageMeta): + """Metaclass for protocol message classes created at runtime from Descriptors. -def InitMessage(message_descriptor, cls): - """Finalizes the creation of a message class.""" - cls.AddDescriptors(message_descriptor) + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: + + mydescriptor = Descriptor(.....) + class MyProtoClass(Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = mydescriptor + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + + The above example will not work for nested types. If you wish to include them, + use reflection.MakeClass() instead of manually instantiating the class in + order to create the appropriate class structure. + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 2160757b..8581f529 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -193,7 +193,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { io::CodedInputStream input( reinterpret_cast(serialized.c_str()), serialized.size()); input.SetExtensionRegistry(GetDescriptorPool()->pool, - cmessage::GetMessageFactory()); + GetDescriptorPool()->message_factory); bool success = cmsg->message->MergePartialFromCodedStream(&input); if (!success) { PyErr_Format(PyExc_ValueError, "Error parsing Options message"); diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index ecd90847..d5ba2b6f 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -33,6 +33,7 @@ #include #include +#include #include #include #include @@ -67,6 +68,11 @@ PyDescriptorPool* NewDescriptorPool() { // as underlay. cdescriptor_pool->pool = new DescriptorPool(DescriptorPool::generated_pool()); + DynamicMessageFactory* message_factory = new DynamicMessageFactory(); + // This option might be the default some day. + message_factory->SetDelegateToGeneratedFactory(true); + cdescriptor_pool->message_factory = message_factory; + // TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same // storage. cdescriptor_pool->classes_by_descriptor = @@ -93,6 +99,7 @@ static void Dealloc(PyDescriptorPool* self) { Py_DECREF(it->second); } delete self->descriptor_options; + delete self->message_factory; Py_TYPE(self)->tp_free(reinterpret_cast(self)); } diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h index efb1abeb..6f6c5cdb 100644 --- a/python/google/protobuf/pyext/descriptor_pool.h +++ b/python/google/protobuf/pyext/descriptor_pool.h @@ -38,6 +38,8 @@ namespace google { namespace protobuf { +class MessageFactory; + namespace python { // Wraps operations to the global DescriptorPool which contains information @@ -55,6 +57,14 @@ typedef struct PyDescriptorPool { DescriptorPool* pool; + // DynamicMessageFactory used to create C++ instances of messages. + // This object cache the descriptors that were used, so the DescriptorPool + // needs to get rid of it before it can delete itself. + // + // Note: A C++ MessageFactory is different from the Python MessageFactory. + // The C++ one creates messages, when the Python one creates classes. + MessageFactory* message_factory; + // Make our own mapping to retrieve Python classes from C++ descriptors. // // Descriptor pointers stored here are owned by the DescriptorPool above. diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index b8d18f8d..8ebbb27c 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -33,6 +33,7 @@ #include +#include #include #include #include @@ -183,7 +184,8 @@ PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { return NULL; } } - if (cmessage::ClearFieldByDescriptor(self->parent, descriptor) == NULL) { + if (ScopedPyObjectPtr(cmessage::ClearFieldByDescriptor( + self->parent, descriptor)) == NULL) { return NULL; } if (PyDict_DelItem(self->values, extension) < 0) { @@ -268,7 +270,7 @@ PyTypeObject ExtensionDict_Type = { 0, // tp_as_number 0, // tp_as_sequence &extension_dict::MpMethods, // tp_as_mapping - 0, // tp_hash + PyObject_HashNotImplemented, // tp_hash 0, // tp_call 0, // tp_str 0, // tp_getattro diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index a4843e8d..aa3ab97a 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -49,9 +49,10 @@ #endif #include #include +#include #include +#include #include -#include #include #include #include @@ -88,12 +89,308 @@ namespace google { namespace protobuf { namespace python { +static PyObject* kDESCRIPTOR; +static PyObject* k_extensions_by_name; +static PyObject* k_extensions_by_number; +PyObject* EnumTypeWrapper_class; +static PyObject* PythonMessage_class; +static PyObject* kEmptyWeakref; + +// Defines the Metaclass of all Message classes. +// It allows us to cache some C++ pointers in the class object itself, they are +// faster to extract than from the type's dictionary. + +struct PyMessageMeta { + // This is how CPython subclasses C structures: the base structure must be + // the first member of the object. + PyHeapTypeObject super; + + // C++ descriptor of this message. + const Descriptor* message_descriptor; + // Owned reference, used to keep the pointer above alive. + PyObject* py_message_descriptor; +}; + +namespace message_meta { + +static int InsertEmptyWeakref(PyTypeObject* base); + +// Add the number of a field descriptor to the containing message class. +// Equivalent to: +// _cls._FIELD_NUMBER = +static bool AddFieldNumberToClass( + PyObject* cls, const FieldDescriptor* field_descriptor) { + string constant_name = field_descriptor->name() + "_FIELD_NUMBER"; + UpperString(&constant_name); + ScopedPyObjectPtr attr_name(PyString_FromStringAndSize( + constant_name.c_str(), constant_name.size())); + if (attr_name == NULL) { + return false; + } + ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number())); + if (number == NULL) { + return false; + } + if (PyObject_SetAttr(cls, attr_name, number) == -1) { + return false; + } + return true; +} + + +// Finalize the creation of the Message class. +// Called from its metaclass: GeneratedProtocolMessageType.__init__(). +static int AddDescriptors(PyObject* cls, PyObject* descriptor) { + const Descriptor* message_descriptor = + cdescriptor_pool::RegisterMessageClass( + GetDescriptorPool(), cls, descriptor); + if (message_descriptor == NULL) { + return -1; + } + + // If there are extension_ranges, the message is "extendable", and extension + // classes will register themselves in this class. + if (message_descriptor->extension_range_count() > 0) { + ScopedPyObjectPtr by_name(PyDict_New()); + if (PyObject_SetAttr(cls, k_extensions_by_name, by_name) < 0) { + return -1; + } + ScopedPyObjectPtr by_number(PyDict_New()); + if (PyObject_SetAttr(cls, k_extensions_by_number, by_number) < 0) { + return -1; + } + } + + // For each field set: cls._FIELD_NUMBER = + for (int i = 0; i < message_descriptor->field_count(); ++i) { + if (!AddFieldNumberToClass(cls, message_descriptor->field(i))) { + return -1; + } + } + + // For each enum set cls. = EnumTypeWrapper(). + // + // The enum descriptor we get from + // .enum_types_by_name[name] + // which was built previously. + for (int i = 0; i < message_descriptor->enum_type_count(); ++i) { + const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i); + ScopedPyObjectPtr enum_type( + PyEnumDescriptor_FromDescriptor(enum_descriptor)); + if (enum_type == NULL) { + return -1; + } + // Add wrapped enum type to message class. + ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( + EnumTypeWrapper_class, enum_type.get(), NULL)); + if (wrapped == NULL) { + return -1; + } + if (PyObject_SetAttrString( + cls, enum_descriptor->name().c_str(), wrapped) == -1) { + return -1; + } + + // For each enum value add cls. = + for (int j = 0; j < enum_descriptor->value_count(); ++j) { + const EnumValueDescriptor* enum_value_descriptor = + enum_descriptor->value(j); + ScopedPyObjectPtr value_number(PyInt_FromLong( + enum_value_descriptor->number())); + if (value_number == NULL) { + return -1; + } + if (PyObject_SetAttrString( + cls, enum_value_descriptor->name().c_str(), value_number) == -1) { + return -1; + } + } + } + + // For each extension set cls. = . + // + // Extension descriptors come from + // .extensions_by_name[name] + // which was defined previously. + for (int i = 0; i < message_descriptor->extension_count(); ++i) { + const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i); + ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); + if (extension_field == NULL) { + return -1; + } + + // Add the extension field to the message class. + if (PyObject_SetAttrString( + cls, field->name().c_str(), extension_field) == -1) { + return -1; + } + + // For each extension set cls._FIELD_NUMBER = . + if (!AddFieldNumberToClass(cls, field)) { + return -1; + } + } + + return 0; +} + +static PyObject* New(PyTypeObject* type, + PyObject* args, PyObject* kwargs) { + static char *kwlist[] = {"name", "bases", "dict", 0}; + PyObject *bases, *dict; + const char* name; + + // Check arguments: (name, bases, dict) + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "sO!O!:type", kwlist, + &name, + &PyTuple_Type, &bases, + &PyDict_Type, &dict)) { + return NULL; + } + + // Check bases: only (), or (message.Message,) are allowed + if (!(PyTuple_GET_SIZE(bases) == 0 || + (PyTuple_GET_SIZE(bases) == 1 && + PyTuple_GET_ITEM(bases, 0) == PythonMessage_class))) { + PyErr_SetString(PyExc_TypeError, + "A Message class can only inherit from Message"); + return NULL; + } + + // Check dict['DESCRIPTOR'] + PyObject* descriptor = PyDict_GetItem(dict, kDESCRIPTOR); + if (descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR"); + return NULL; + } + if (!PyObject_TypeCheck(descriptor, &PyMessageDescriptor_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s", + descriptor->ob_type->tp_name); + return NULL; + } + + // Build the arguments to the base metaclass. + // We change the __bases__ classes. + ScopedPyObjectPtr new_args(Py_BuildValue( + "s(OO)O", name, &CMessage_Type, PythonMessage_class, dict)); + if (new_args == NULL) { + return NULL; + } + // Call the base metaclass. + ScopedPyObjectPtr result(PyType_Type.tp_new(type, new_args, NULL)); + if (result == NULL) { + return NULL; + } + PyMessageMeta* newtype = reinterpret_cast(result.get()); + + // Insert the empty weakref into the base classes. + if (InsertEmptyWeakref( + reinterpret_cast(PythonMessage_class)) < 0 || + InsertEmptyWeakref(&CMessage_Type) < 0) { + return NULL; + } + + // Cache the descriptor, both as Python object and as C++ pointer. + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(descriptor); + if (message_descriptor == NULL) { + return NULL; + } + Py_INCREF(descriptor); + newtype->py_message_descriptor = descriptor; + newtype->message_descriptor = message_descriptor; + + // Continue with type initialization: add other descriptors, enum values... + if (AddDescriptors(result, descriptor) < 0) { + return NULL; + } + return result.release(); +} + +static void Dealloc(PyMessageMeta *self) { + Py_DECREF(self->py_message_descriptor); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +static PyObject* GetDescriptor(PyMessageMeta *self, void *closure) { + Py_INCREF(self->py_message_descriptor); + return self->py_message_descriptor; +} + + +// This function inserts and empty weakref at the end of the list of +// subclasses for the main protocol buffer Message class. +// +// This eliminates a O(n^2) behaviour in the internal add_subclass +// routine. +static int InsertEmptyWeakref(PyTypeObject *base_type) { +#if PY_MAJOR_VERSION >= 3 + // Python 3.4 has already included the fix for the issue that this + // hack addresses. For further background and the fix please see + // https://bugs.python.org/issue17936. + return 0; +#else + PyObject *subclasses = base_type->tp_subclasses; + if (subclasses && PyList_CheckExact(subclasses)) { + return PyList_Append(subclasses, kEmptyWeakref); + } + return 0; +#endif // PY_MAJOR_VERSION >= 3 +} + +} // namespace message_meta + +PyTypeObject PyMessageMeta_Type { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MessageMeta", // tp_name + sizeof(PyMessageMeta), // tp_basicsize + 0, // tp_itemsize + (destructor)message_meta::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags + "The metaclass of ProtocolMessages", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + message_meta::New, // tp_new +}; + +static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { + if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) { + PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); + return NULL; + } + return reinterpret_cast(cls)->message_descriptor; +} + // Forward declarations namespace cmessage { -static const FieldDescriptor* GetFieldDescriptor( - CMessage* self, PyObject* name); -static const Descriptor* GetMessageDescriptor(PyTypeObject* cls); -static string GetMessageName(CMessage* self); int InternalReleaseFieldByDescriptor( CMessage* self, const FieldDescriptor* field_descriptor, @@ -180,7 +477,7 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { if (self->composite_fields) { // Never use self->message in this function, it may be already freed. const Descriptor* message_descriptor = - cmessage::GetMessageDescriptor(Py_TYPE(self)); + GetMessageDescriptor(Py_TYPE(self)); while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { Py_ssize_t key_str_size; char *key_str_data; @@ -213,8 +510,6 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { // --------------------------------------------------------------------- -static DynamicMessageFactory* message_factory; - // Constants used for integer type range checking. PyObject* kPythonZero; PyObject* kint32min_py; @@ -224,17 +519,13 @@ PyObject* kint64min_py; PyObject* kint64max_py; PyObject* kuint64max_py; -PyObject* EnumTypeWrapper_class; PyObject* EncodeError_class; PyObject* DecodeError_class; PyObject* PickleError_class; // Constant PyString values used for GetAttr/GetItem. -static PyObject* kDESCRIPTOR; static PyObject* k_cdescriptor; static PyObject* kfull_name; -static PyObject* k_extensions_by_name; -static PyObject* k_extensions_by_number; /* Is 64bit */ void FormatTypeError(PyObject* arg, char* expected_types) { @@ -432,10 +723,6 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, namespace cmessage { -DynamicMessageFactory* GetMessageFactory() { - return message_factory; -} - static int MaybeReleaseOverlappingOneofField( CMessage* cmessage, const FieldDescriptor* field) { @@ -486,7 +773,7 @@ static Message* GetMutableMessage( return NULL; } return reflection->MutableMessage( - parent_message, parent_field, message_factory); + parent_message, parent_field, GetDescriptorPool()->message_factory); } struct FixupMessageReference : public ChildVisitor { @@ -527,8 +814,9 @@ int AssureWritable(CMessage* self) { // If parent is NULL but we are trying to modify a read-only message, this // is a reference to a constant default instance that needs to be replaced // with a mutable top-level message. - const Message* prototype = message_factory->GetPrototype( - self->message->GetDescriptor()); + const Message* prototype = + GetDescriptorPool()->message_factory->GetPrototype( + self->message->GetDescriptor()); self->message = prototype->New(); self->owner.reset(self->message); // Cascade the new owner to eventual children: even if this message is @@ -567,23 +855,6 @@ int AssureWritable(CMessage* self) { // --- Globals: -// Retrieve the C++ Descriptor of a message class. -// On error, returns NULL with an exception set. -static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { - ScopedPyObjectPtr descriptor(PyObject_GetAttr( - reinterpret_cast(cls), kDESCRIPTOR)); - if (descriptor == NULL) { - PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR"); - return NULL; - } - if (!PyObject_TypeCheck(descriptor, &PyMessageDescriptor_Type)) { - PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s", - descriptor->ob_type->tp_name); - return NULL; - } - return PyMessageDescriptor_AsDescriptor(descriptor); -} - // Retrieve a C++ FieldDescriptor for a message attribute. // The C++ message must be valid. // TODO(amauryfa): This function should stay internal, because exception @@ -846,9 +1117,9 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { return -1; } } else { - if (repeated_scalar_container::Extend( + if (ScopedPyObjectPtr(repeated_scalar_container::Extend( reinterpret_cast(container.get()), - value) == + value)) == NULL) { return -1; } @@ -927,7 +1198,7 @@ static PyObject* New(PyTypeObject* type, return NULL; } const Message* default_message = - message_factory->GetPrototype(message_descriptor); + GetDescriptorPool()->message_factory->GetPrototype(message_descriptor); if (default_message == NULL) { PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); return NULL; @@ -1257,6 +1528,7 @@ int SetOwner(CMessage* self, const shared_ptr& new_owner) { Message* ReleaseMessage(CMessage* self, const Descriptor* descriptor, const FieldDescriptor* field_descriptor) { + MessageFactory* message_factory = GetDescriptorPool()->message_factory; Message* released_message = self->message->GetReflection()->ReleaseMessage( self->message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from @@ -1492,34 +1764,35 @@ static PyObject* SerializePartialToString(CMessage* self) { // appropriate. class PythonFieldValuePrinter : public TextFormat::FieldValuePrinter { public: - PythonFieldValuePrinter() : float_holder_(PyFloat_FromDouble(0)) {} - // Python has some differences from C++ when printing floating point numbers. // // 1) Trailing .0 is always printed. - // 2) Outputted is rounded to 12 digits. + // 2) (Python2) Output is rounded to 12 digits. + // 3) (Python3) The full precision of the double is preserved (and Python uses + // David M. Gay's dtoa(), when the C++ code uses SimpleDtoa. There are some + // differences, but they rarely happen) // // We override floating point printing with the C-API function for printing // Python floats to ensure consistency. string PrintFloat(float value) const { return PrintDouble(value); } string PrintDouble(double value) const { - reinterpret_cast(float_holder_.get())->ob_fval = value; - ScopedPyObjectPtr s(PyObject_Str(float_holder_.get())); - if (s == NULL) return string(); + // Same as float.__str__() + char* buf = PyOS_double_to_string( + value, #if PY_MAJOR_VERSION < 3 - char *cstr = PyBytes_AS_STRING(static_cast(s)); + 'g', PyFloat_STR_PRECISION, // Output is rounded to 12 digits. #else - char *cstr = PyUnicode_AsUTF8(s); + 'r', 0, #endif - return string(cstr); + Py_DTSF_ADD_DOT_0, // Trailing .0 is always printed. + NULL); + if (!buf) { + return string(); + } + string result(buf); + PyMem_Free(buf); + return result; } - - private: - // Holder for a python float object which we use to allow us to use - // the Python API for printing doubles. We initialize once and then - // directly modify it for every float printed to save on allocations - // and refcounting. - ScopedPyObjectPtr float_holder_; }; static PyObject* ToStr(CMessage* self) { @@ -1590,7 +1863,7 @@ static PyObject* CopyFrom(CMessage* self, PyObject* arg) { // CopyFrom on the message will not clean up self->composite_fields, // which can leave us in an inconsistent state, so clear it out here. - Clear(self); + (void)ScopedPyObjectPtr(Clear(self)); self->message->CopyFrom(*other_message->message); @@ -1607,7 +1880,8 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { AssureWritable(self); io::CodedInputStream input( reinterpret_cast(data), data_length); - input.SetExtensionRegistry(GetDescriptorPool()->pool, message_factory); + input.SetExtensionRegistry(GetDescriptorPool()->pool, + GetDescriptorPool()->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { return PyInt_FromLong(input.CurrentPosition()); @@ -1618,7 +1892,7 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { } static PyObject* ParseFromString(CMessage* self, PyObject* arg) { - if (Clear(self) == NULL) { + if (ScopedPyObjectPtr(Clear(self)) == NULL) { return NULL; } return MergeFromString(self, arg); @@ -1790,6 +2064,7 @@ static PyObject* ListFields(CMessage* self) { // Steals reference to 'extension' PyTuple_SET_ITEM(t.get(), 1, extension); } else { + // Normal field const string& field_name = fields[i]->name(); ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize( field_name.c_str(), field_name.length())); @@ -1841,28 +2116,34 @@ PyObject* FindInitializationErrors(CMessage* self) { } static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { - if (!PyObject_TypeCheck(other, &CMessage_Type)) { - if (opid == Py_EQ) { - Py_RETURN_FALSE; - } else if (opid == Py_NE) { - Py_RETURN_TRUE; - } - } - if (opid == Py_EQ || opid == Py_NE) { - ScopedPyObjectPtr self_fields(ListFields(self)); - if (!self_fields) { - return NULL; - } - ScopedPyObjectPtr other_fields(ListFields( - reinterpret_cast(other))); - if (!other_fields) { - return NULL; - } - return PyObject_RichCompare(self_fields, other_fields, opid); - } else { + // Only equality comparisons are implemented. + if (opid != Py_EQ && opid != Py_NE) { Py_INCREF(Py_NotImplemented); return Py_NotImplemented; } + bool equals = true; + // If other is not a message, it cannot be equal. + if (!PyObject_TypeCheck(other, &CMessage_Type)) { + equals = false; + } + const google::protobuf::Message* other_message = + reinterpret_cast(other)->message; + // If messages don't have the same descriptors, they are not equal. + if (equals && + self->message->GetDescriptor() != other_message->GetDescriptor()) { + equals = false; + } + // Check the message contents. + if (equals && !google::protobuf::util::MessageDifferencer::Equals( + *self->message, + *reinterpret_cast(other)->message)) { + equals = false; + } + if (equals ^ (opid == Py_EQ)) { + Py_RETURN_FALSE; + } else { + Py_RETURN_TRUE; + } } PyObject* InternalGetScalar(const Message* message, @@ -1950,7 +2231,7 @@ PyObject* InternalGetSubMessage( CMessage* self, const FieldDescriptor* field_descriptor) { const Reflection* reflection = self->message->GetReflection(); const Message& sub_message = reflection->GetMessage( - *self->message, field_descriptor, message_factory); + *self->message, field_descriptor, GetDescriptorPool()->message_factory); PyObject *message_class = cdescriptor_pool::GetMessageClass( GetDescriptorPool(), field_descriptor->message_type()); @@ -2085,125 +2366,6 @@ PyObject* FromString(PyTypeObject* cls, PyObject* serialized) { return py_cmsg; } -// Add the number of a field descriptor to the containing message class. -// Equivalent to: -// _cls._FIELD_NUMBER = -static bool AddFieldNumberToClass( - PyObject* cls, const FieldDescriptor* field_descriptor) { - string constant_name = field_descriptor->name() + "_FIELD_NUMBER"; - UpperString(&constant_name); - ScopedPyObjectPtr attr_name(PyString_FromStringAndSize( - constant_name.c_str(), constant_name.size())); - if (attr_name == NULL) { - return false; - } - ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number())); - if (number == NULL) { - return false; - } - if (PyObject_SetAttr(cls, attr_name, number) == -1) { - return false; - } - return true; -} - - -// Finalize the creation of the Message class. -// Called from its metaclass: GeneratedProtocolMessageType.__init__(). -static PyObject* AddDescriptors(PyObject* cls, PyObject* descriptor) { - const Descriptor* message_descriptor = - cdescriptor_pool::RegisterMessageClass( - GetDescriptorPool(), cls, descriptor); - if (message_descriptor == NULL) { - return NULL; - } - - // If there are extension_ranges, the message is "extendable", and extension - // classes will register themselves in this class. - if (message_descriptor->extension_range_count() > 0) { - ScopedPyObjectPtr by_name(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_name, by_name) < 0) { - return NULL; - } - ScopedPyObjectPtr by_number(PyDict_New()); - if (PyObject_SetAttr(cls, k_extensions_by_number, by_number) < 0) { - return NULL; - } - } - - // For each field set: cls._FIELD_NUMBER = - for (int i = 0; i < message_descriptor->field_count(); ++i) { - if (!AddFieldNumberToClass(cls, message_descriptor->field(i))) { - return NULL; - } - } - - // For each enum set cls. = EnumTypeWrapper(). - // - // The enum descriptor we get from - // .enum_types_by_name[name] - // which was built previously. - for (int i = 0; i < message_descriptor->enum_type_count(); ++i) { - const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i); - ScopedPyObjectPtr enum_type( - PyEnumDescriptor_FromDescriptor(enum_descriptor)); - if (enum_type == NULL) { - return NULL; - } - // Add wrapped enum type to message class. - ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( - EnumTypeWrapper_class, enum_type.get(), NULL)); - if (wrapped == NULL) { - return NULL; - } - if (PyObject_SetAttrString( - cls, enum_descriptor->name().c_str(), wrapped) == -1) { - return NULL; - } - - // For each enum value add cls. = - for (int j = 0; j < enum_descriptor->value_count(); ++j) { - const EnumValueDescriptor* enum_value_descriptor = - enum_descriptor->value(j); - ScopedPyObjectPtr value_number(PyInt_FromLong( - enum_value_descriptor->number())); - if (value_number == NULL) { - return NULL; - } - if (PyObject_SetAttrString( - cls, enum_value_descriptor->name().c_str(), value_number) == -1) { - return NULL; - } - } - } - - // For each extension set cls. = . - // - // Extension descriptors come from - // .extensions_by_name[name] - // which was defined previously. - for (int i = 0; i < message_descriptor->extension_count(); ++i) { - const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i); - ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); - if (extension_field == NULL) { - return NULL; - } - - // Add the extension field to the message class. - if (PyObject_SetAttrString( - cls, field->name().c_str(), extension_field) == -1) { - return NULL; - } - - // For each extension set cls._FIELD_NUMBER = . - if (!AddFieldNumberToClass(cls, field)) { - return NULL; - } - } - - Py_RETURN_NONE; -} - PyObject* DeepCopy(CMessage* self, PyObject* arg) { PyObject* clone = PyObject_CallObject( reinterpret_cast(Py_TYPE(self)), NULL); @@ -2214,8 +2376,9 @@ PyObject* DeepCopy(CMessage* self, PyObject* arg) { Py_DECREF(clone); return NULL; } - if (MergeFrom(reinterpret_cast(clone), - reinterpret_cast(self)) == NULL) { + if (ScopedPyObjectPtr(MergeFrom( + reinterpret_cast(clone), + reinterpret_cast(self))) == NULL) { Py_DECREF(clone); return NULL; } @@ -2281,7 +2444,7 @@ PyObject* SetState(CMessage* self, PyObject* state) { if (serialized == NULL) { return NULL; } - if (ParseFromString(self, serialized) == NULL) { + if (ScopedPyObjectPtr(ParseFromString(self, serialized)) == NULL) { return NULL; } Py_RETURN_NONE; @@ -2314,8 +2477,6 @@ static PyMethodDef Methods[] = { "Inputs picklable representation of the message." }, { "__unicode__", (PyCFunction)ToUnicode, METH_NOARGS, "Outputs a unicode representation of the message." }, - { "AddDescriptors", (PyCFunction)AddDescriptors, METH_O | METH_CLASS, - "Adds field descriptors to the class" }, { "ByteSize", (PyCFunction)ByteSize, METH_NOARGS, "Returns the size of the message in bytes." }, { "Clear", (PyCFunction)Clear, METH_NOARGS, @@ -2441,6 +2602,9 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject* sub_message = InternalGetSubMessage(self, field_descriptor); + if (sub_message == NULL) { + return NULL; + } if (!SetCompositeField(self, name, sub_message)) { Py_DECREF(sub_message); return NULL; @@ -2484,7 +2648,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { } // namespace cmessage PyTypeObject CMessage_Type = { - PyVarObject_HEAD_INIT(&PyType_Type, 0) + PyVarObject_HEAD_INIT(&PyMessageMeta_Type, 0) FULL_MODULE_NAME ".CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize @@ -2497,7 +2661,7 @@ PyTypeObject CMessage_Type = { 0, // tp_as_number 0, // tp_as_sequence 0, // tp_as_mapping - 0, // tp_hash + PyObject_HashNotImplemented, // tp_hash 0, // tp_call (reprfunc)cmessage::ToStr, // tp_str (getattrofunc)cmessage::GetAttr, // tp_getattro @@ -2580,8 +2744,9 @@ void InitGlobals() { k_extensions_by_name = PyString_FromString("_extensions_by_name"); k_extensions_by_number = PyString_FromString("_extensions_by_number"); - message_factory = new DynamicMessageFactory(); - message_factory->SetDelegateToGeneratedFactory(true); + PyObject *dummy_obj = PySet_New(NULL); + kEmptyWeakref = PyWeakref_NewRef(dummy_obj, NULL); + Py_DECREF(dummy_obj); } bool InitProto2MessageModule(PyObject *m) { @@ -2598,7 +2763,13 @@ bool InitProto2MessageModule(PyObject *m) { // Initialize constants defined in this file. InitGlobals(); - CMessage_Type.tp_hash = PyObject_HashNotImplemented; + PyMessageMeta_Type.tp_base = &PyType_Type; + if (PyType_Ready(&PyMessageMeta_Type) < 0) { + return false; + } + PyModule_AddObject(m, "MessageMeta", + reinterpret_cast(&PyMessageMeta_Type)); + if (PyType_Ready(&CMessage_Type) < 0) { return false; } @@ -2628,86 +2799,106 @@ bool InitProto2MessageModule(PyObject *m) { PyModule_AddObject(m, "Message", reinterpret_cast(&CMessage_Type)); - RepeatedScalarContainer_Type.tp_hash = - PyObject_HashNotImplemented; - if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) { - return false; - } + // Initialize Repeated container types. + { + if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) { + return false; + } - PyModule_AddObject(m, "RepeatedScalarContainer", - reinterpret_cast( - &RepeatedScalarContainer_Type)); + PyModule_AddObject(m, "RepeatedScalarContainer", + reinterpret_cast( + &RepeatedScalarContainer_Type)); - RepeatedCompositeContainer_Type.tp_hash = PyObject_HashNotImplemented; - if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) { - return false; - } + if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) { + return false; + } - PyModule_AddObject( - m, "RepeatedCompositeContainer", - reinterpret_cast( - &RepeatedCompositeContainer_Type)); - - // ScalarMapContainer_Type derives from our MutableMapping type. - PyObject* containers = - PyImport_ImportModule("google.protobuf.internal.containers"); - if (containers == NULL) { - return false; + PyModule_AddObject( + m, "RepeatedCompositeContainer", + reinterpret_cast( + &RepeatedCompositeContainer_Type)); + + // Register them as collections.Sequence + ScopedPyObjectPtr collections(PyImport_ImportModule("collections")); + if (collections == NULL) { + return false; + } + ScopedPyObjectPtr mutable_sequence(PyObject_GetAttrString( + collections, "MutableSequence")); + if (mutable_sequence == NULL) { + return false; + } + if (ScopedPyObjectPtr(PyObject_CallMethod(mutable_sequence, "register", "O", + &RepeatedScalarContainer_Type)) + == NULL) { + return false; + } + if (ScopedPyObjectPtr(PyObject_CallMethod(mutable_sequence, "register", "O", + &RepeatedCompositeContainer_Type)) + == NULL) { + return false; + } } - PyObject* mutable_mapping = - PyObject_GetAttrString(containers, "MutableMapping"); - Py_DECREF(containers); + // Initialize Map container types. + { + // ScalarMapContainer_Type derives from our MutableMapping type. + ScopedPyObjectPtr containers(PyImport_ImportModule( + "google.protobuf.internal.containers")); + if (containers == NULL) { + return false; + } - if (mutable_mapping == NULL) { - return false; - } + ScopedPyObjectPtr mutable_mapping( + PyObject_GetAttrString(containers, "MutableMapping")); + if (mutable_mapping == NULL) { + return false; + } - if (!PyObject_TypeCheck(mutable_mapping, &PyType_Type)) { - Py_DECREF(mutable_mapping); - return false; - } + if (!PyObject_TypeCheck(mutable_mapping, &PyType_Type)) { + return false; + } - ScalarMapContainer_Type.tp_base = - reinterpret_cast(mutable_mapping); + Py_INCREF(mutable_mapping); + ScalarMapContainer_Type.tp_base = + reinterpret_cast(mutable_mapping.get()); - if (PyType_Ready(&ScalarMapContainer_Type) < 0) { - return false; - } + if (PyType_Ready(&ScalarMapContainer_Type) < 0) { + return false; + } - PyModule_AddObject(m, "ScalarMapContainer", - reinterpret_cast(&ScalarMapContainer_Type)); + PyModule_AddObject(m, "ScalarMapContainer", + reinterpret_cast(&ScalarMapContainer_Type)); - if (PyType_Ready(&ScalarMapIterator_Type) < 0) { - return false; - } + if (PyType_Ready(&ScalarMapIterator_Type) < 0) { + return false; + } - PyModule_AddObject(m, "ScalarMapIterator", - reinterpret_cast(&ScalarMapIterator_Type)); + PyModule_AddObject(m, "ScalarMapIterator", + reinterpret_cast(&ScalarMapIterator_Type)); - Py_INCREF(mutable_mapping); - MessageMapContainer_Type.tp_base = - reinterpret_cast(mutable_mapping); + Py_INCREF(mutable_mapping); + MessageMapContainer_Type.tp_base = + reinterpret_cast(mutable_mapping.get()); - if (PyType_Ready(&MessageMapContainer_Type) < 0) { - return false; - } + if (PyType_Ready(&MessageMapContainer_Type) < 0) { + return false; + } - PyModule_AddObject(m, "MessageMapContainer", - reinterpret_cast(&MessageMapContainer_Type)); + PyModule_AddObject(m, "MessageMapContainer", + reinterpret_cast(&MessageMapContainer_Type)); - if (PyType_Ready(&MessageMapIterator_Type) < 0) { - return false; - } + if (PyType_Ready(&MessageMapIterator_Type) < 0) { + return false; + } - PyModule_AddObject(m, "MessageMapIterator", - reinterpret_cast(&MessageMapIterator_Type)); + PyModule_AddObject(m, "MessageMapIterator", + reinterpret_cast(&MessageMapIterator_Type)); + } - ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented; if (PyType_Ready(&ExtensionDict_Type) < 0) { return false; } - PyModule_AddObject( m, "ExtensionDict", reinterpret_cast(&ExtensionDict_Type)); @@ -2751,6 +2942,7 @@ bool InitProto2MessageModule(PyObject *m) { } EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError"); DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError"); + PythonMessage_class = PyObject_GetAttrString(message_module, "Message"); Py_DECREF(message_module); PyObject* pickle_module = PyImport_ImportModule("pickle"); diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index 7360b207..f147d433 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -49,7 +49,6 @@ class Message; class Reflection; class FieldDescriptor; class Descriptor; -class DynamicMessageFactory; using internal::shared_ptr; @@ -221,9 +220,6 @@ PyObject* FindInitializationErrors(CMessage* self); int SetOwner(CMessage* self, const shared_ptr& new_owner); int AssureWritable(CMessage* self); - -DynamicMessageFactory* GetMessageFactory(); - } // namespace cmessage diff --git a/python/google/protobuf/pyext/message_map_container.cc b/python/google/protobuf/pyext/message_map_container.cc index ab8d8fb9..a4a7fbfe 100644 --- a/python/google/protobuf/pyext/message_map_container.cc +++ b/python/google/protobuf/pyext/message_map_container.cc @@ -32,6 +32,7 @@ #include +#include #include #include #include diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index 86b75d0f..fe2e600b 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -38,11 +38,13 @@ #include #endif +#include #include #include #include #include #include +#include #include #include @@ -74,125 +76,6 @@ namespace repeated_composite_container { GOOGLE_CHECK((self)->parent == NULL); \ } while (0); -// Returns a new reference. -static PyObject* GetKey(PyObject* x) { - // Just the identity function. - Py_INCREF(x); - return x; -} - -#define GET_KEY(keyfunc, value) \ - ((keyfunc) == NULL ? \ - GetKey((value)) : \ - PyObject_CallFunctionObjArgs((keyfunc), (value), NULL)) - -// Converts a comparison function that returns -1, 0, or 1 into a -// less-than predicate. -// -// Returns -1 on error, 1 if x < y, 0 if x >= y. -static int islt(PyObject *x, PyObject *y, PyObject *compare) { - if (compare == NULL) - return PyObject_RichCompareBool(x, y, Py_LT); - - ScopedPyObjectPtr res(PyObject_CallFunctionObjArgs(compare, x, y, NULL)); - if (res == NULL) - return -1; - if (!PyInt_Check(res)) { - PyErr_Format(PyExc_TypeError, - "comparison function must return int, not %.200s", - Py_TYPE(res)->tp_name); - return -1; - } - return PyInt_AsLong(res) < 0; -} - -// Copied from uarrsort.c but swaps memcpy swaps with protobuf/python swaps -// TODO(anuraag): Is there a better way to do this then reinventing the wheel? -static int InternalQuickSort(RepeatedCompositeContainer* self, - Py_ssize_t start, - Py_ssize_t limit, - PyObject* cmp, - PyObject* keyfunc) { - if (limit - start <= 1) - return 0; // Nothing to sort. - - GOOGLE_CHECK_ATTACHED(self); - - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - const FieldDescriptor* descriptor = self->parent_field_descriptor; - Py_ssize_t left; - Py_ssize_t right; - - PyObject* children = self->child_messages; - - do { - left = start; - right = limit; - ScopedPyObjectPtr mid( - GET_KEY(keyfunc, PyList_GET_ITEM(children, (start + limit) / 2))); - do { - ScopedPyObjectPtr key(GET_KEY(keyfunc, PyList_GET_ITEM(children, left))); - int is_lt = islt(key, mid, cmp); - if (is_lt == -1) - return -1; - /* array[left]SwapElements(message, descriptor, left, right); - PyObject* tmp = PyList_GET_ITEM(children, left); - PyList_SET_ITEM(children, left, PyList_GET_ITEM(children, right)); - PyList_SET_ITEM(children, right, tmp); - } - ++left; - } - } while (left < right); - - if ((right - start) < (limit - left)) { - /* sort [start..right[ */ - if (start < (right - 1)) { - InternalQuickSort(self, start, right, cmp, keyfunc); - } - - /* sort [left..limit[ */ - start = left; - } else { - /* sort [left..limit[ */ - if (left < (limit - 1)) { - InternalQuickSort(self, left, limit, cmp, keyfunc); - } - - /* sort [start..right[ */ - limit = right; - } - } while (start < (limit - 1)); - - return 0; -} - -#undef GET_KEY - // --------------------------------------------------------------------- // len() @@ -329,7 +212,7 @@ PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) { return NULL; } CMessage* new_cmessage = reinterpret_cast(new_message.get()); - if (cmessage::MergeFrom(new_cmessage, next) == NULL) { + if (ScopedPyObjectPtr(cmessage::MergeFrom(new_cmessage, next)) == NULL) { return NULL; } } @@ -455,58 +338,39 @@ static PyObject* RichCompare(RepeatedCompositeContainer* self, // --------------------------------------------------------------------- // sort() -static PyObject* SortAttached(RepeatedCompositeContainer* self, - PyObject* args, - PyObject* kwds) { - // Sort the underlying Message array. - PyObject *compare = NULL; - int reverse = 0; - PyObject *keyfunc = NULL; - static char *kwlist[] = {"cmp", "key", "reverse", 0}; - - if (args != NULL) { - if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOi:sort", - kwlist, &compare, &keyfunc, &reverse)) - return NULL; - } - if (compare == Py_None) - compare = NULL; - if (keyfunc == Py_None) - keyfunc = NULL; - +static void ReorderAttached(RepeatedCompositeContainer* self) { + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + const FieldDescriptor* descriptor = self->parent_field_descriptor; const Py_ssize_t length = Length(self); - if (InternalQuickSort(self, 0, length, compare, keyfunc) < 0) - return NULL; - - // Finally reverse the result if requested. - if (reverse) { - Message* message = self->message; - const Reflection* reflection = message->GetReflection(); - const FieldDescriptor* descriptor = self->parent_field_descriptor; - // Reverse the Message array. - for (int i = 0; i < length / 2; ++i) - reflection->SwapElements(message, descriptor, i, length - i - 1); + // Since Python protobuf objects are never arena-allocated, adding and + // removing message pointers to the underlying array is just updating + // pointers. + for (Py_ssize_t i = 0; i < length; ++i) + reflection->ReleaseLast(message, descriptor); - // Reverse the Python list. - ScopedPyObjectPtr res(PyObject_CallMethod(self->child_messages, - "reverse", NULL)); - if (res == NULL) - return NULL; + for (Py_ssize_t i = 0; i < length; ++i) { + CMessage* py_cmsg = reinterpret_cast( + PyList_GET_ITEM(self->child_messages, i)); + reflection->AddAllocatedMessage(message, descriptor, py_cmsg->message); } - - Py_RETURN_NONE; } -static PyObject* SortReleased(RepeatedCompositeContainer* self, - PyObject* args, - PyObject* kwds) { +// Returns 0 if successful; returns -1 and sets an exception if +// unsuccessful. +static int SortPythonMessages(RepeatedCompositeContainer* self, + PyObject* args, + PyObject* kwds) { ScopedPyObjectPtr m(PyObject_GetAttrString(self->child_messages, "sort")); if (m == NULL) - return NULL; + return -1; if (PyObject_Call(m, args, kwds) == NULL) - return NULL; - Py_RETURN_NONE; + return -1; + if (self->message != NULL) { + ReorderAttached(self); + } + return 0; } static PyObject* Sort(RepeatedCompositeContainer* self, @@ -527,11 +391,10 @@ static PyObject* Sort(RepeatedCompositeContainer* self, if (UpdateChildMessages(self) < 0) { return NULL; } - if (self->message == NULL) { - return SortReleased(self, args, kwds); - } else { - return SortAttached(self, args, kwds); + if (SortPythonMessages(self, args, kwds) < 0) { + return NULL; } + Py_RETURN_NONE; } // --------------------------------------------------------------------- @@ -584,18 +447,6 @@ void ReleaseLastTo(CMessage* parent, parent->message->GetReflection()->ReleaseLast(parent->message, field)); // TODO(tibell): Deal with proto1. - // ReleaseMessage will return NULL which differs from - // child_cmessage->message, if the field does not exist. In this case, - // the latter points to the default instance via a const_cast<>, so we - // have to reset it to a new mutable object since we are taking ownership. - if (released_message.get() == NULL) { - const Message* prototype = - cmessage::GetMessageFactory()->GetPrototype( - target->message->GetDescriptor()); - GOOGLE_CHECK_NOTNULL(prototype); - released_message.reset(prototype->New()); - } - target->parent = NULL; target->parent_field_descriptor = NULL; target->message = released_message.get(); @@ -732,7 +583,7 @@ PyTypeObject RepeatedCompositeContainer_Type = { 0, // tp_as_number &repeated_composite_container::SqMethods, // tp_as_sequence &repeated_composite_container::MpMethods, // tp_as_mapping - 0, // tp_hash + PyObject_HashNotImplemented, // tp_hash 0, // tp_call 0, // tp_str 0, // tp_getattro diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc index fd196836..7565c6fd 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.cc +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -39,10 +39,12 @@ #endif #include +#include #include #include #include #include +#include #include #include @@ -68,7 +70,7 @@ static int InternalAssignRepeatedField( self->parent_field_descriptor); for (Py_ssize_t i = 0; i < PyList_GET_SIZE(list); ++i) { PyObject* value = PyList_GET_ITEM(list, i); - if (Append(self, value) == NULL) { + if (ScopedPyObjectPtr(Append(self, value)) == NULL) { return -1; } } @@ -510,7 +512,7 @@ PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) { } ScopedPyObjectPtr next; while ((next.reset(PyIter_Next(iter))) != NULL) { - if (Append(self, next) == NULL) { + if (ScopedPyObjectPtr(Append(self, next)) == NULL) { return NULL; } } @@ -690,8 +692,7 @@ static int InitializeAndCopyToParentContainer( if (values == NULL) { return -1; } - Message* new_message = cmessage::GetMessageFactory()->GetPrototype( - from->message->GetDescriptor())->New(); + Message* new_message = from->message->New(); to->parent = NULL; to->parent_field_descriptor = from->parent_field_descriptor; to->message = new_message; @@ -781,7 +782,7 @@ PyTypeObject RepeatedScalarContainer_Type = { 0, // tp_as_number &repeated_scalar_container::SqMethods, // tp_as_sequence &repeated_scalar_container::MpMethods, // tp_as_mapping - 0, // tp_hash + PyObject_HashNotImplemented, // tp_hash 0, // tp_call 0, // tp_str 0, // tp_getattro diff --git a/python/google/protobuf/pyext/scalar_map_container.cc b/python/google/protobuf/pyext/scalar_map_container.cc index 6f731d27..80d29425 100644 --- a/python/google/protobuf/pyext/scalar_map_container.cc +++ b/python/google/protobuf/pyext/scalar_map_container.cc @@ -32,6 +32,7 @@ #include +#include #include #include #include diff --git a/python/google/protobuf/pyext/scoped_pyobject_ptr.h b/python/google/protobuf/pyext/scoped_pyobject_ptr.h index 18ddd5cd..9979b83b 100644 --- a/python/google/protobuf/pyext/scoped_pyobject_ptr.h +++ b/python/google/protobuf/pyext/scoped_pyobject_ptr.h @@ -51,16 +51,22 @@ class ScopedPyObjectPtr { // Reset. Deletes the current owned object, if any. // Then takes ownership of a new object, if given. - // this->reset(this->get()) works. + // This function must be called with a reference that you own. + // this->reset(this->get()) is wrong! + // this->reset(this->release()) is OK. PyObject* reset(PyObject* p = NULL) { - if (p != ptr_) { - Py_XDECREF(ptr_); - ptr_ = p; - } + Py_XDECREF(ptr_); + ptr_ = p; return ptr_; } + // ScopedPyObjectPtr should not be copied. + // We explicitly list and delete this overload to avoid automatic conversion + // to PyObject*, which is wrong in this case. + PyObject* reset(const ScopedPyObjectPtr& other) = delete; + // Releases ownership of the object. + // The caller now owns the returned reference. PyObject* release() { PyObject* p = ptr_; ptr_ = NULL; diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 82fca661..0c757264 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -49,101 +49,23 @@ __author__ = 'robinson@google.com (Will Robinson)' from google.protobuf.internal import api_implementation -from google.protobuf import descriptor as descriptor_mod from google.protobuf import message -_FieldDescriptor = descriptor_mod.FieldDescriptor - if api_implementation.Type() == 'cpp': from google.protobuf.pyext import cpp_message as message_impl else: from google.protobuf.internal import python_message as message_impl -_NewMessage = message_impl.NewMessage -_InitMessage = message_impl.InitMessage - - -class GeneratedProtocolMessageType(type): - - """Metaclass for protocol message classes created at runtime from Descriptors. - - We add implementations for all methods described in the Message class. We - also create properties to allow getting/setting all fields in the protocol - message. Finally, we create slots to prevent users from accidentally - "setting" nonexistent fields in the protocol message, which then wouldn't get - serialized / deserialized properly. - - The protocol compiler currently uses this metaclass to create protocol - message classes at runtime. Clients can also manually create their own - classes at runtime, as in this example: - - mydescriptor = Descriptor(.....) - class MyProtoClass(Message): - __metaclass__ = GeneratedProtocolMessageType - DESCRIPTOR = mydescriptor - myproto_instance = MyProtoClass() - myproto.foo_field = 23 - ... - - The above example will not work for nested types. If you wish to include them, - use reflection.MakeClass() instead of manually instantiating the class in - order to create the appropriate class structure. - """ - - # Must be consistent with the protocol-compiler code in - # proto2/compiler/internal/generator.*. - _DESCRIPTOR_KEY = 'DESCRIPTOR' - - def __new__(cls, name, bases, dictionary): - """Custom allocation for runtime-generated class types. - - We override __new__ because this is apparently the only place - where we can meaningfully set __slots__ on the class we're creating(?). - (The interplay between metaclasses and slots is not very well-documented). - - Args: - name: Name of the class (ignored, but required by the - metaclass protocol). - bases: Base classes of the class we're constructing. - (Should be message.Message). We ignore this field, but - it's required by the metaclass protocol - dictionary: The class dictionary of the class we're - constructing. dictionary[_DESCRIPTOR_KEY] must contain - a Descriptor object describing this protocol message - type. - - Returns: - Newly-allocated class. - """ - descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] - bases = _NewMessage(bases, descriptor, dictionary) - superclass = super(GeneratedProtocolMessageType, cls) - - new_class = superclass.__new__(cls, name, bases, dictionary) - return new_class - - def __init__(cls, name, bases, dictionary): - """Here we perform the majority of our work on the class. - We add enum getters, an __init__ method, implementations - of all Message methods, and properties for all fields - in the protocol type. - - Args: - name: Name of the class (ignored, but required by the - metaclass protocol). - bases: Base classes of the class we're constructing. - (Should be message.Message). We ignore this field, but - it's required by the metaclass protocol - dictionary: The class dictionary of the class we're - constructing. dictionary[_DESCRIPTOR_KEY] must contain - a Descriptor object describing this protocol message - type. - """ - descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] - _InitMessage(descriptor, cls) - superclass = super(GeneratedProtocolMessageType, cls) - superclass.__init__(name, bases, dictionary) +# The type of all Message classes. +# Part of the public interface. +# +# Used by generated files, but clients can also use it at runtime: +# mydescriptor = pool.FindDescriptor(.....) +# class MyProtoClass(Message): +# __metaclass__ = GeneratedProtocolMessageType +# DESCRIPTOR = mydescriptor +GeneratedProtocolMessageType = message_impl.GeneratedProtocolMessageType def ParseMessage(descriptor, byte_str): diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 8cbd6822..82133765 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -113,7 +113,7 @@ def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, fields.sort(key=lambda x: x[0].index) for field, value in fields: if _IsMapEntry(field): - for key in value: + for key in sorted(value): # This is slow for maps with submessage entires because it copies the # entire tree. Unfortunately this would take significant refactoring # of this file to work around. -- cgit v1.2.3