diff options
author | Bo Yang <teboring@google.com> | 2015-05-21 14:28:59 -0700 |
---|---|---|
committer | Bo Yang <teboring@google.com> | 2015-05-21 19:32:02 -0700 |
commit | 5db217305f37a79eeccd70f000088a06ec82fcec (patch) | |
tree | be53dcf0c0b47ef9178ab8a6fa5c1946ee84a28f /python | |
parent | 56095026ccc2f755a6fdb296e30c3ddec8f556a2 (diff) | |
download | protobuf-5db217305f37a79eeccd70f000088a06ec82fcec.tar.gz protobuf-5db217305f37a79eeccd70f000088a06ec82fcec.tar.bz2 protobuf-5db217305f37a79eeccd70f000088a06ec82fcec.zip |
down-integrate internal changes
Diffstat (limited to 'python')
42 files changed, 3282 insertions, 544 deletions
diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index f7a58ca0..970b1a88 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -245,9 +245,6 @@ class Descriptor(_NestedDescriptorBase): is_extendable: Does this type define any extension ranges? - options: (descriptor_pb2.MessageOptions) Protocol message options or None - to use default message options. - oneofs: (list of OneofDescriptor) The list of descriptors for oneof fields in this message. oneofs_by_name: (dict str -> OneofDescriptor) Same objects as in |oneofs|, @@ -265,7 +262,7 @@ class Descriptor(_NestedDescriptorBase): file=None, serialized_start=None, serialized_end=None, syntax=None): _message.Message._CheckCalledFromGeneratedFile() - return _message.Message._GetMessageDescriptor(full_name) + return _message.default_pool.FindMessageTypeByName(full_name) # NOTE(tmarek): The file argument redefining a builtin is nothing we can # fix right now since we don't know how many clients already rely on the @@ -495,9 +492,9 @@ class FieldDescriptor(DescriptorBase): has_default_value=True, containing_oneof=None): _message.Message._CheckCalledFromGeneratedFile() if is_extension: - return _message.Message._GetExtensionDescriptor(full_name) + return _message.default_pool.FindExtensionByName(full_name) else: - return _message.Message._GetFieldDescriptor(full_name) + return _message.default_pool.FindFieldByName(full_name) def __init__(self, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, @@ -528,14 +525,9 @@ class FieldDescriptor(DescriptorBase): self.containing_oneof = containing_oneof if api_implementation.Type() == 'cpp': if is_extension: - # pylint: disable=protected-access - self._cdescriptor = ( - _message.Message._GetExtensionDescriptor(full_name)) - # pylint: enable=protected-access + self._cdescriptor = _message.default_pool.FindExtensionByName(full_name) else: - # pylint: disable=protected-access - self._cdescriptor = _message.Message._GetFieldDescriptor(full_name) - # pylint: enable=protected-access + self._cdescriptor = _message.default_pool.FindFieldByName(full_name) else: self._cdescriptor = None @@ -592,7 +584,7 @@ class EnumDescriptor(_NestedDescriptorBase): containing_type=None, options=None, file=None, serialized_start=None, serialized_end=None): _message.Message._CheckCalledFromGeneratedFile() - return _message.Message._GetEnumDescriptor(full_name) + return _message.default_pool.FindEnumTypeByName(full_name) def __init__(self, name, full_name, filename, values, containing_type=None, options=None, file=None, @@ -677,7 +669,7 @@ class OneofDescriptor(object): def __new__(cls, name, full_name, index, containing_type, fields): _message.Message._CheckCalledFromGeneratedFile() - return _message.Message._GetOneofDescriptor(full_name) + return _message.default_pool.FindOneofByName(full_name) def __init__(self, name, full_name, index, containing_type, fields): """Arguments are as described in the attribute description above.""" @@ -788,12 +780,8 @@ class FileDescriptor(DescriptorBase): dependencies=None, syntax=None): # FileDescriptor() is called from various places, not only from generated # files, to register dynamic proto files and messages. - # TODO(amauryfa): Expose BuildFile() as a public function and make this - # constructor an implementation detail. if serialized_pb: - # pylint: disable=protected-access2 - return _message.Message._BuildFile(serialized_pb) - # pylint: enable=protected-access + return _message.default_pool.AddSerializedFile(serialized_pb) else: return super(FileDescriptor, cls).__new__(cls) @@ -814,9 +802,7 @@ class FileDescriptor(DescriptorBase): if (api_implementation.Type() == 'cpp' and self.serialized_pb is not None): - # pylint: disable=protected-access - _message.Message._BuildFile(self.serialized_pb) - # pylint: enable=protected-access + _message.default_pool.AddSerializedFile(self.serialized_pb) def CopyToProto(self, proto): """Copies this to a descriptor_pb2.FileDescriptorProto. @@ -864,10 +850,10 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, file_descriptor_proto = descriptor_pb2.FileDescriptorProto() file_descriptor_proto.message_type.add().MergeFrom(desc_proto) - # Generate a random name for this proto file to prevent conflicts with - # any imported ones. We need to specify a file name so BuildFile accepts - # our FileDescriptorProto, but it is not important what that file name - # is actually set to. + # Generate a random name for this proto file to prevent conflicts with any + # imported ones. We need to specify a file name so the descriptor pool + # accepts our FileDescriptorProto, but it is not important what that file + # name is actually set to. proto_name = str(uuid.uuid4()) if package: @@ -877,10 +863,8 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, else: file_descriptor_proto.name = proto_name + '.proto' - # pylint: disable=protected-access - result = _message.Message._BuildFile( - file_descriptor_proto.SerializeToString()) - # pylint: enable=protected-access + _message.default_pool.Add(file_descriptor_proto) + result = _message.default_pool.FindFileByName(file_descriptor_proto.name) if _USE_C_DESCRIPTORS: return result.message_types_by_name[desc_proto.name] diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 7e7701f8..1244ba7c 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -113,6 +113,20 @@ class DescriptorPool(object): self._internal_db.Add(file_desc_proto) + def AddSerializedFile(self, serialized_file_desc_proto): + """Adds the FileDescriptorProto and its types to this pool. + + Args: + serialized_file_desc_proto: A bytes string, serialization of the + FileDescriptorProto to add. + """ + + # pylint: disable=g-import-not-at-top + from google.protobuf import descriptor_pb2 + file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString( + serialized_file_desc_proto) + self.Add(file_desc_proto) + def AddDescriptor(self, desc): """Adds a Descriptor to the pool, non-recursively. @@ -320,17 +334,17 @@ class DescriptorPool(object): file_descriptor, None, scope)) for index, extension_proto in enumerate(file_proto.extension): - extension_desc = self.MakeFieldDescriptor( + extension_desc = self._MakeFieldDescriptor( extension_proto, file_proto.package, index, is_extension=True) extension_desc.containing_type = self._GetTypeFromScope( file_descriptor.package, extension_proto.extendee, scope) - self.SetFieldType(extension_proto, extension_desc, + self._SetFieldType(extension_proto, extension_desc, file_descriptor.package, scope) file_descriptor.extensions_by_name[extension_desc.name] = ( extension_desc) for desc_proto in file_proto.message_type: - self.SetAllFieldTypes(file_proto.package, desc_proto, scope) + self._SetAllFieldTypes(file_proto.package, desc_proto, scope) if file_proto.package: desc_proto_prefix = _PrefixWithDot(file_proto.package) @@ -381,10 +395,11 @@ class DescriptorPool(object): enums = [ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) for enum in desc_proto.enum_type] - fields = [self.MakeFieldDescriptor(field, desc_name, index) + fields = [self._MakeFieldDescriptor(field, desc_name, index) for index, field in enumerate(desc_proto.field)] extensions = [ - self.MakeFieldDescriptor(extension, desc_name, index, is_extension=True) + self._MakeFieldDescriptor(extension, desc_name, index, + is_extension=True) for index, extension in enumerate(desc_proto.extension)] oneofs = [ descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), @@ -464,8 +479,8 @@ class DescriptorPool(object): self._enum_descriptors[enum_name] = desc return desc - def MakeFieldDescriptor(self, field_proto, message_name, index, - is_extension=False): + def _MakeFieldDescriptor(self, field_proto, message_name, index, + is_extension=False): """Creates a field descriptor from a FieldDescriptorProto. For message and enum type fields, this method will do a look up @@ -506,7 +521,7 @@ class DescriptorPool(object): extension_scope=None, options=field_proto.options) - def SetAllFieldTypes(self, package, desc_proto, scope): + def _SetAllFieldTypes(self, package, desc_proto, scope): """Sets all the descriptor's fields's types. This method also sets the containing types on any extensions. @@ -527,18 +542,18 @@ class DescriptorPool(object): nested_package = '.'.join([package, desc_proto.name]) for field_proto, field_desc in zip(desc_proto.field, main_desc.fields): - self.SetFieldType(field_proto, field_desc, nested_package, scope) + self._SetFieldType(field_proto, field_desc, nested_package, scope) for extension_proto, extension_desc in ( zip(desc_proto.extension, main_desc.extensions)): extension_desc.containing_type = self._GetTypeFromScope( nested_package, extension_proto.extendee, scope) - self.SetFieldType(extension_proto, extension_desc, nested_package, scope) + self._SetFieldType(extension_proto, extension_desc, nested_package, scope) for nested_type in desc_proto.nested_type: - self.SetAllFieldTypes(nested_package, nested_type, scope) + self._SetAllFieldTypes(nested_package, nested_type, scope) - def SetFieldType(self, field_proto, field_desc, package, scope): + def _SetFieldType(self, field_proto, field_desc, package, scope): """Sets the field's type, cpp_type, message_type and enum_type. Args: diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index d976f9e1..72c2fa01 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -41,6 +41,146 @@ are: __author__ = 'petar@google.com (Petar Petrov)' +import sys + +if sys.version_info[0] < 3: + # We would use collections.MutableMapping all the time, but in Python 2 it + # doesn't define __slots__. This causes two significant problems: + # + # 1. we can't disallow arbitrary attribute assignment, even if our derived + # classes *do* define __slots__. + # + # 2. we can't safely derive a C type from it without __slots__ defined (the + # interpreter expects to find a dict at tp_dictoffset, which we can't + # robustly provide. And we don't want an instance dict anyway. + # + # So this is the Python 2.7 definition of Mapping/MutableMapping functions + # verbatim, except that: + # 1. We declare __slots__. + # 2. We don't declare this as a virtual base class. The classes defined + # in collections are the interesting base classes, not us. + # + # Note: deriving from object is critical. It is the only thing that makes + # this a true type, allowing us to derive from it in C++ cleanly and making + # __slots__ properly disallow arbitrary element assignment. + from collections import Mapping as _Mapping + + class Mapping(object): + __slots__ = () + + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + + def __contains__(self, key): + try: + self[key] + except KeyError: + return False + else: + return True + + def iterkeys(self): + return iter(self) + + def itervalues(self): + for key in self: + yield self[key] + + def iteritems(self): + for key in self: + yield (key, self[key]) + + def keys(self): + return list(self) + + def items(self): + return [(key, self[key]) for key in self] + + def values(self): + return [self[key] for key in self] + + # Mappings are not hashable by default, but subclasses can change this + __hash__ = None + + def __eq__(self, other): + if not isinstance(other, _Mapping): + return NotImplemented + return dict(self.items()) == dict(other.items()) + + def __ne__(self, other): + return not (self == other) + + class MutableMapping(Mapping): + __slots__ = () + + __marker = object() + + def pop(self, key, default=__marker): + try: + value = self[key] + except KeyError: + if default is self.__marker: + raise + return default + else: + del self[key] + return value + + def popitem(self): + try: + key = next(iter(self)) + except StopIteration: + raise KeyError + value = self[key] + del self[key] + return key, value + + def clear(self): + try: + while True: + self.popitem() + except KeyError: + pass + + def update(*args, **kwds): + if len(args) > 2: + raise TypeError("update() takes at most 2 positional " + "arguments ({} given)".format(len(args))) + elif not args: + raise TypeError("update() takes at least 1 argument (0 given)") + self = args[0] + other = args[1] if len(args) >= 2 else () + + if isinstance(other, Mapping): + for key in other: + self[key] = other[key] + elif hasattr(other, "keys"): + for key in other.keys(): + self[key] = other[key] + else: + for key, value in other: + self[key] = value + for key, value in kwds.items(): + self[key] = value + + def setdefault(self, key, default=None): + try: + return self[key] + except KeyError: + self[key] = default + return default + + _Mapping.register(Mapping) + +else: + # In Python 3 we can just use MutableMapping directly, because it defines + # __slots__. + from collections import MutableMapping + + class BaseContainer(object): """Base container class.""" @@ -286,3 +426,160 @@ class RepeatedCompositeFieldContainer(BaseContainer): raise TypeError('Can only compare repeated composite fields against ' 'other repeated composite fields.') return self._values == other._values + + +class ScalarMap(MutableMapping): + + """Simple, type-checked, dict-like container for holding repeated scalars.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_value_checker', '_values', '_message_listener'] + + def __init__(self, message_listener, key_checker, value_checker): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. + """ + self._message_listener = message_listener + self._key_checker = key_checker + self._value_checker = value_checker + self._values = {} + + def __getitem__(self, key): + try: + return self._values[key] + except KeyError: + key = self._key_checker.CheckValue(key) + val = self._value_checker.DefaultValue() + self._values[key] = val + return val + + def __contains__(self, item): + return item in self._values + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __setitem__(self, key, value): + checked_key = self._key_checker.CheckValue(key) + checked_value = self._value_checker.CheckValue(value) + self._values[checked_key] = checked_value + self._message_listener.Modified() + + def __delitem__(self, key): + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def MergeFrom(self, other): + self._values.update(other._values) + self._message_listener.Modified() + + # This is defined in the abstract base, but we can do it much more cheaply. + def clear(self): + self._values.clear() + self._message_listener.Modified() + + +class MessageMap(MutableMapping): + + """Simple, type-checked, dict-like container for with submessage values.""" + + # Disallows assignment to other attributes. + __slots__ = ['_key_checker', '_values', '_message_listener', + '_message_descriptor'] + + def __init__(self, message_listener, message_descriptor, key_checker): + """ + Args: + message_listener: A MessageListener implementation. + The ScalarMap will call this object's Modified() method when it + is modified. + key_checker: A type_checkers.ValueChecker instance to run on keys + inserted into this container. + value_checker: A type_checkers.ValueChecker instance to run on values + inserted into this container. + """ + self._message_listener = message_listener + self._message_descriptor = message_descriptor + self._key_checker = key_checker + self._values = {} + + def __getitem__(self, key): + try: + return self._values[key] + except KeyError: + key = self._key_checker.CheckValue(key) + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + self._values[key] = new_element + self._message_listener.Modified() + + return new_element + + def get_or_create(self, key): + """get_or_create() is an alias for getitem (ie. map[key]). + + Args: + key: The key to get or create in the map. + + This is useful in cases where you want to be explicit that the call is + mutating the map. This can avoid lint errors for statements like this + that otherwise would appear to be pointless statements: + + msg.my_map[key] + """ + return self[key] + + # We need to override this explicitly, because our defaultdict-like behavior + # will make the default implementation (from our base class) always insert + # the key. + def get(self, key, default=None): + if key in self: + return self[key] + else: + return default + + def __contains__(self, item): + return item in self._values + + def __setitem__(self, key, value): + raise ValueError('May not set values directly, call my_map[key].foo = 5') + + def __delitem__(self, key): + del self._values[key] + self._message_listener.Modified() + + def __len__(self): + return len(self._values) + + def __iter__(self): + return iter(self._values) + + def MergeFrom(self, other): + for key in other: + self[key].MergeFrom(other[key]) + # self._message_listener.Modified() not required here, because + # mutations to submessages already propagate. + + # This is defined in the abstract base, but we can do it much more cheaply. + def clear(self): + self._values.clear() + self._message_listener.Modified() diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 0f500606..3837eaea 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -733,6 +733,50 @@ def MessageSetItemDecoder(extensions_by_number): return DecodeItem # -------------------------------------------------------------------- + +def MapDecoder(field_descriptor, new_default, is_message_map): + """Returns a decoder for a map field.""" + + key = field_descriptor + tag_bytes = encoder.TagBytes(field_descriptor.number, + wire_format.WIRETYPE_LENGTH_DELIMITED) + tag_len = len(tag_bytes) + local_DecodeVarint = _DecodeVarint + # Can't read _concrete_class yet; might not be initialized. + message_type = field_descriptor.message_type + + def DecodeMap(buffer, pos, end, message, field_dict): + submsg = message_type._concrete_class() + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + # Read length. + (size, pos) = local_DecodeVarint(buffer, pos) + new_pos = pos + size + if new_pos > end: + raise _DecodeError('Truncated message.') + # Read sub-message. + submsg.Clear() + if submsg._InternalParse(buffer, pos, new_pos) != new_pos: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise _DecodeError('Unexpected end-group tag.') + + if is_message_map: + value[submsg.key].MergeFrom(submsg.value) + else: + value[submsg.key] = submsg.value + + # Predict that the next tag is another copy of the same repeated field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos == end: + # Prediction failed. Return. + return new_pos + + return DecodeMap + +# -------------------------------------------------------------------- # Optimization is not as heavy here because calls to SkipField() are rare, # except for handling end-group tags. diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py index 56fe14e9..8416e157 100644 --- a/python/google/protobuf/internal/descriptor_database_test.py +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -35,7 +35,6 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import unittest - from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor_database diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 7d145f42..d159cc62 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -37,6 +37,7 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import os import unittest +import unittest from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 from google.protobuf.internal import api_implementation @@ -226,6 +227,13 @@ class DescriptorPoolTest(unittest.TestCase): db.Add(self.factory_test2_fd) self.testFindMessageTypeByName() + def testAddSerializedFile(self): + db = descriptor_database.DescriptorDatabase() + self.pool = descriptor_pool.DescriptorPool(db) + self.pool.AddSerializedFile(self.factory_test1_fd.SerializeToString()) + self.pool.AddSerializedFile(self.factory_test2_fd.SerializeToString()) + self.testFindMessageTypeByName() + def testComplexNesting(self): test1_desc = descriptor_pb2.FileDescriptorProto.FromString( descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) @@ -510,6 +518,43 @@ class AddDescriptorTest(unittest.TestCase): 'protobuf_unittest.TestAllTypes') +@unittest.skipIf( + api_implementation.Type() != 'cpp', + 'default_pool is only supported by the C++ implementation') +class DefaultPoolTest(unittest.TestCase): + + def testFindMethods(self): + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + pool = _message.default_pool + self.assertIs( + pool.FindFileByName('google/protobuf/unittest.proto'), + unittest_pb2.DESCRIPTOR) + self.assertIs( + pool.FindMessageTypeByName('protobuf_unittest.TestAllTypes'), + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertIs( + pool.FindFieldByName('protobuf_unittest.TestAllTypes.optional_int32'), + unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']) + self.assertIs( + pool.FindExtensionByName('protobuf_unittest.optional_int32_extension'), + unittest_pb2.DESCRIPTOR.extensions_by_name['optional_int32_extension']) + self.assertIs( + pool.FindEnumTypeByName('protobuf_unittest.ForeignEnum'), + unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertIs( + pool.FindOneofByName('protobuf_unittest.TestAllTypes.oneof_field'), + unittest_pb2.TestAllTypes.DESCRIPTOR.oneofs_by_name['oneof_field']) + + def testAddFileDescriptor(self): + # pylint: disable=g-import-not-at-top + from google.protobuf.pyext import _message + pool = _message.default_pool + file_desc = descriptor_pb2.FileDescriptorProto(name='some/file.proto') + pool.Add(file_desc) + pool.AddSerializedFile(file_desc.SerializeToString()) + + TEST1_FILE = ProtoFile( 'google/protobuf/internal/descriptor_pool_test1.proto', 'google.protobuf.python.internal', diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 335caee6..26866f3a 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -35,8 +35,8 @@ __author__ = 'robinson@google.com (Will Robinson)' import sys -import unittest +import unittest from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index 38a5138a..752f4eab 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -314,7 +314,7 @@ def MessageSizer(field_number, is_repeated, is_packed): # -------------------------------------------------------------------- -# MessageSet is special. +# MessageSet is special: it needs custom logic to compute its size properly. def MessageSetItemSizer(field_number): @@ -339,6 +339,32 @@ def MessageSetItemSizer(field_number): return FieldSize +# -------------------------------------------------------------------- +# Map is special: it needs custom logic to compute its size properly. + + +def MapSizer(field_descriptor): + """Returns a sizer for a map field.""" + + # Can't look at field_descriptor.message_type._concrete_class because it may + # not have been initialized yet. + message_type = field_descriptor.message_type + message_sizer = MessageSizer(field_descriptor.number, False, False) + + def FieldSize(map_value): + total = 0 + for key in map_value: + value = map_value[key] + # It's wasteful to create the messages and throw them away one second + # later since we'll do the same for the actual encode. But there's not an + # obvious way to avoid this within the current design without tons of code + # duplication. + entry_msg = message_type._concrete_class(key=key, value=value) + total += message_sizer(entry_msg) + return total + + return FieldSize + # ==================================================================== # Encoders! @@ -786,3 +812,30 @@ def MessageSetItemEncoder(field_number): return write(end_bytes) return EncodeField + + +# -------------------------------------------------------------------- +# As before, Map is special. + + +def MapEncoder(field_descriptor): + """Encoder for extensions of MessageSet. + + Maps always have a wire format like this: + message MapEntry { + key_type key = 1; + value_type value = 2; + } + repeated MapEntry map = N; + """ + # Can't look at field_descriptor.message_type._concrete_class because it may + # not have been initialized yet. + message_type = field_descriptor.message_type + encode_message = MessageEncoder(field_descriptor.number, False, False) + + def EncodeField(write, value): + for key in value: + entry_msg = message_type._concrete_class(key=key, value=value[key]) + encode_message(write, entry_msg) + + return EncodeField diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index ccc5860b..5c07cbe6 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -42,7 +42,6 @@ further ensures that we can use Python protocol message objects as we expect. __author__ = 'robinson@google.com (Will Robinson)' import unittest - from google.protobuf.internal import test_bad_identifiers_pb2 from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 626c3fc9..b8694f96 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -35,7 +35,6 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import unittest - from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 4ecaa1c7..320ff0d2 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -50,7 +50,9 @@ import pickle import sys import unittest +import unittest from google.protobuf.internal import _parameterized +from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation @@ -125,10 +127,17 @@ class MessageTest(unittest.TestCase): self.assertEqual(unpickled_message, golden_message) def testPositiveInfinity(self, message_module): - golden_data = (b'\x5D\x00\x00\x80\x7F' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' - b'\xCD\x02\x00\x00\x80\x7F' - b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') + if message_module is unittest_pb2: + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCD\x02\x00\x00\x80\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') + else: + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCA\x02\x04\x00\x00\x80\x7F' + b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.optional_float)) @@ -138,10 +147,17 @@ class MessageTest(unittest.TestCase): self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinity(self, message_module): - golden_data = (b'\x5D\x00\x00\x80\xFF' - b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' - b'\xCD\x02\x00\x00\x80\xFF' - b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') + if message_module is unittest_pb2: + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCD\x02\x00\x00\x80\xFF' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') + else: + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCA\x02\x04\x00\x00\x80\xFF' + b'\xD2\x02\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.optional_float)) @@ -1034,64 +1050,132 @@ class Proto2Test(unittest.TestCase): self.assertEqual(len(parsing_merge.Extensions[ unittest_pb2.TestParsingMerge.repeated_ext]), 3) + def testPythonicInit(self): + message = unittest_pb2.TestAllTypes( + optional_int32=100, + optional_fixed32=200, + optional_float=300.5, + optional_bytes=b'x', + optionalgroup={'a': 400}, + optional_nested_message={'bb': 500}, + optional_nested_enum='BAZ', + repeatedgroup=[{'a': 600}, + {'a': 700}], + repeated_nested_enum=['FOO', unittest_pb2.TestAllTypes.BAR], + default_int32=800, + oneof_string='y') + self.assertTrue(isinstance(message, unittest_pb2.TestAllTypes)) + self.assertEqual(100, message.optional_int32) + self.assertEqual(200, message.optional_fixed32) + self.assertEqual(300.5, message.optional_float) + self.assertEqual(b'x', message.optional_bytes) + self.assertEqual(400, message.optionalgroup.a) + self.assertTrue(isinstance(message.optional_nested_message, + unittest_pb2.TestAllTypes.NestedMessage)) + self.assertEqual(500, message.optional_nested_message.bb) + self.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.optional_nested_enum) + self.assertEqual(2, len(message.repeatedgroup)) + self.assertEqual(600, message.repeatedgroup[0].a) + self.assertEqual(700, message.repeatedgroup[1].a) + self.assertEqual(2, len(message.repeated_nested_enum)) + self.assertEqual(unittest_pb2.TestAllTypes.FOO, + message.repeated_nested_enum[0]) + self.assertEqual(unittest_pb2.TestAllTypes.BAR, + message.repeated_nested_enum[1]) + self.assertEqual(800, message.default_int32) + self.assertEqual('y', message.oneof_string) + self.assertFalse(message.HasField('optional_int64')) + self.assertEqual(0, len(message.repeated_float)) + self.assertEqual(42, message.default_int64) + + message = unittest_pb2.TestAllTypes(optional_nested_enum=u'BAZ') + self.assertEqual(unittest_pb2.TestAllTypes.BAZ, + message.optional_nested_enum) + + with self.assertRaises(ValueError): + unittest_pb2.TestAllTypes( + optional_nested_message={'INVALID_NESTED_FIELD': 17}) + + with self.assertRaises(TypeError): + unittest_pb2.TestAllTypes( + optional_nested_message={'bb': 'INVALID_VALUE_TYPE'}) + + with self.assertRaises(ValueError): + unittest_pb2.TestAllTypes(optional_nested_enum='INVALID_LABEL') + + with self.assertRaises(ValueError): + unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') + # Class to test proto3-only features/behavior (updated field presence & enums) class Proto3Test(unittest.TestCase): + # Utility method for comparing equality with a map. + def assertMapIterEquals(self, map_iter, dict_value): + # Avoid mutating caller's copy. + dict_value = dict(dict_value) + + for k, v in map_iter: + self.assertEqual(v, dict_value[k]) + del dict_value[k] + + self.assertEqual({}, dict_value) + def testFieldPresence(self): message = unittest_proto3_arena_pb2.TestAllTypes() # We can't test presence of non-repeated, non-submessage fields. with self.assertRaises(ValueError): - message.HasField("optional_int32") + message.HasField('optional_int32') with self.assertRaises(ValueError): - message.HasField("optional_float") + message.HasField('optional_float') with self.assertRaises(ValueError): - message.HasField("optional_string") + message.HasField('optional_string') with self.assertRaises(ValueError): - message.HasField("optional_bool") + message.HasField('optional_bool') # But we can still test presence of submessage fields. - self.assertFalse(message.HasField("optional_nested_message")) + self.assertFalse(message.HasField('optional_nested_message')) # As with proto2, we can't test presence of fields that don't exist, or # repeated fields. with self.assertRaises(ValueError): - message.HasField("field_doesnt_exist") + message.HasField('field_doesnt_exist') with self.assertRaises(ValueError): - message.HasField("repeated_int32") + message.HasField('repeated_int32') with self.assertRaises(ValueError): - message.HasField("repeated_nested_message") + message.HasField('repeated_nested_message') # Fields should default to their type-specific default. self.assertEqual(0, message.optional_int32) self.assertEqual(0, message.optional_float) - self.assertEqual("", message.optional_string) + self.assertEqual('', message.optional_string) self.assertEqual(False, message.optional_bool) self.assertEqual(0, message.optional_nested_message.bb) # Setting a submessage should still return proper presence information. message.optional_nested_message.bb = 0 - self.assertTrue(message.HasField("optional_nested_message")) + self.assertTrue(message.HasField('optional_nested_message')) # Set the fields to non-default values. message.optional_int32 = 5 message.optional_float = 1.1 - message.optional_string = "abc" + message.optional_string = 'abc' message.optional_bool = True message.optional_nested_message.bb = 15 # Clearing the fields unsets them and resets their value to default. - message.ClearField("optional_int32") - message.ClearField("optional_float") - message.ClearField("optional_string") - message.ClearField("optional_bool") - message.ClearField("optional_nested_message") + message.ClearField('optional_int32') + message.ClearField('optional_float') + message.ClearField('optional_string') + message.ClearField('optional_bool') + message.ClearField('optional_nested_message') self.assertEqual(0, message.optional_int32) self.assertEqual(0, message.optional_float) - self.assertEqual("", message.optional_string) + self.assertEqual('', message.optional_string) self.assertEqual(False, message.optional_bool) self.assertEqual(0, message.optional_nested_message.bb) @@ -1113,6 +1197,393 @@ class Proto3Test(unittest.TestCase): self.assertEqual(1234567, m2.optional_nested_enum) self.assertEqual(7654321, m2.repeated_nested_enum[0]) + # Map isn't really a proto3-only feature. But there is no proto2 equivalent + # of google/protobuf/map_unittest.proto right now, so it's not easy to + # test both with the same test like we do for the other proto2/proto3 tests. + # (google/protobuf/map_protobuf_unittest.proto is very different in the set + # of messages and fields it contains). + def testScalarMapDefaults(self): + msg = map_unittest_pb2.TestMap() + + # Scalars start out unset. + self.assertFalse(-123 in msg.map_int32_int32) + self.assertFalse(-2**33 in msg.map_int64_int64) + self.assertFalse(123 in msg.map_uint32_uint32) + self.assertFalse(2**33 in msg.map_uint64_uint64) + self.assertFalse('abc' in msg.map_string_string) + self.assertFalse(888 in msg.map_int32_enum) + + # Accessing an unset key returns the default. + self.assertEqual(0, msg.map_int32_int32[-123]) + self.assertEqual(0, msg.map_int64_int64[-2**33]) + self.assertEqual(0, msg.map_uint32_uint32[123]) + self.assertEqual(0, msg.map_uint64_uint64[2**33]) + self.assertEqual('', msg.map_string_string['abc']) + self.assertEqual(0, msg.map_int32_enum[888]) + + # It also sets the value in the map + self.assertTrue(-123 in msg.map_int32_int32) + self.assertTrue(-2**33 in msg.map_int64_int64) + self.assertTrue(123 in msg.map_uint32_uint32) + self.assertTrue(2**33 in msg.map_uint64_uint64) + self.assertTrue('abc' in msg.map_string_string) + self.assertTrue(888 in msg.map_int32_enum) + + self.assertTrue(isinstance(msg.map_string_string['abc'], unicode)) + + # Accessing an unset key still throws TypeError of the type of the key + # is incorrect. + with self.assertRaises(TypeError): + msg.map_string_string[123] + + self.assertFalse(123 in msg.map_string_string) + + def testMapGet(self): + # Need to test that get() properly returns the default, even though the dict + # has defaultdict-like semantics. + msg = map_unittest_pb2.TestMap() + + self.assertIsNone(msg.map_int32_int32.get(5)) + self.assertEquals(10, msg.map_int32_int32.get(5, 10)) + self.assertIsNone(msg.map_int32_int32.get(5)) + + msg.map_int32_int32[5] = 15 + self.assertEquals(15, msg.map_int32_int32.get(5)) + + self.assertIsNone(msg.map_int32_foreign_message.get(5)) + self.assertEquals(10, msg.map_int32_foreign_message.get(5, 10)) + + submsg = msg.map_int32_foreign_message[5] + self.assertIs(submsg, msg.map_int32_foreign_message.get(5)) + + def testScalarMap(self): + msg = map_unittest_pb2.TestMap() + + self.assertEqual(0, len(msg.map_int32_int32)) + self.assertFalse(5 in msg.map_int32_int32) + + msg.map_int32_int32[-123] = -456 + msg.map_int64_int64[-2**33] = -2**34 + msg.map_uint32_uint32[123] = 456 + msg.map_uint64_uint64[2**33] = 2**34 + msg.map_string_string['abc'] = '123' + msg.map_int32_enum[888] = 2 + + self.assertEqual([], msg.FindInitializationErrors()) + + self.assertEqual(1, len(msg.map_string_string)) + + # Bad key. + with self.assertRaises(TypeError): + msg.map_string_string[123] = '123' + + # Verify that trying to assign a bad key doesn't actually add a member to + # the map. + self.assertEqual(1, len(msg.map_string_string)) + + # Bad value. + with self.assertRaises(TypeError): + msg.map_string_string['123'] = 123 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + # Bad key. + with self.assertRaises(TypeError): + msg2.map_string_string[123] = '123' + + # Bad value. + with self.assertRaises(TypeError): + msg2.map_string_string['123'] = 123 + + self.assertEqual(-456, msg2.map_int32_int32[-123]) + self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) + self.assertEqual(456, msg2.map_uint32_uint32[123]) + self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) + self.assertEqual('123', msg2.map_string_string['abc']) + self.assertEqual(2, msg2.map_int32_enum[888]) + + def testStringUnicodeConversionInMap(self): + msg = map_unittest_pb2.TestMap() + + unicode_obj = u'\u1234' + bytes_obj = unicode_obj.encode('utf8') + + msg.map_string_string[bytes_obj] = bytes_obj + + (key, value) = msg.map_string_string.items()[0] + + self.assertEqual(key, unicode_obj) + self.assertEqual(value, unicode_obj) + + self.assertTrue(isinstance(key, unicode)) + self.assertTrue(isinstance(value, unicode)) + + def testMessageMap(self): + msg = map_unittest_pb2.TestMap() + + self.assertEqual(0, len(msg.map_int32_foreign_message)) + self.assertFalse(5 in msg.map_int32_foreign_message) + + msg.map_int32_foreign_message[123] + # get_or_create() is an alias for getitem. + msg.map_int32_foreign_message.get_or_create(-456) + + self.assertEqual(2, len(msg.map_int32_foreign_message)) + self.assertIn(123, msg.map_int32_foreign_message) + self.assertIn(-456, msg.map_int32_foreign_message) + self.assertEqual(2, len(msg.map_int32_foreign_message)) + + # Bad key. + with self.assertRaises(TypeError): + msg.map_int32_foreign_message['123'] + + # Can't assign directly to submessage. + with self.assertRaises(ValueError): + msg.map_int32_foreign_message[999] = msg.map_int32_foreign_message[123] + + # Verify that trying to assign a bad key doesn't actually add a member to + # the map. + self.assertEqual(2, len(msg.map_int32_foreign_message)) + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + self.assertEqual(2, len(msg2.map_int32_foreign_message)) + self.assertIn(123, msg2.map_int32_foreign_message) + self.assertIn(-456, msg2.map_int32_foreign_message) + self.assertEqual(2, len(msg2.map_int32_foreign_message)) + + def testMergeFrom(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_int32[12] = 34 + msg.map_int32_int32[56] = 78 + msg.map_int64_int64[22] = 33 + msg.map_int32_foreign_message[111].c = 5 + msg.map_int32_foreign_message[222].c = 10 + + msg2 = map_unittest_pb2.TestMap() + msg2.map_int32_int32[12] = 55 + msg2.map_int64_int64[88] = 99 + msg2.map_int32_foreign_message[222].c = 15 + + msg2.MergeFrom(msg) + + self.assertEqual(34, msg2.map_int32_int32[12]) + self.assertEqual(78, msg2.map_int32_int32[56]) + self.assertEqual(33, msg2.map_int64_int64[22]) + self.assertEqual(99, msg2.map_int64_int64[88]) + self.assertEqual(5, msg2.map_int32_foreign_message[111].c) + self.assertEqual(10, msg2.map_int32_foreign_message[222].c) + + # Verify that there is only one entry per key, even though the MergeFrom + # may have internally created multiple entries for a single key in the + # list representation. + as_dict = {} + for key in msg2.map_int32_foreign_message: + self.assertFalse(key in as_dict) + as_dict[key] = msg2.map_int32_foreign_message[key].c + + self.assertEqual({111: 5, 222: 10}, as_dict) + + # Special case: test that delete of item really removes the item, even if + # there might have physically been duplicate keys due to the previous merge. + # This is only a special case for the C++ implementation which stores the + # map as an array. + del msg2.map_int32_int32[12] + self.assertFalse(12 in msg2.map_int32_int32) + + del msg2.map_int32_foreign_message[222] + self.assertFalse(222 in msg2.map_int32_foreign_message) + + def testIntegerMapWithLongs(self): + msg = map_unittest_pb2.TestMap() + msg.map_int32_int32[long(-123)] = long(-456) + msg.map_int64_int64[long(-2**33)] = long(-2**34) + msg.map_uint32_uint32[long(123)] = long(456) + msg.map_uint64_uint64[long(2**33)] = long(2**34) + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + self.assertEqual(-456, msg2.map_int32_int32[-123]) + self.assertEqual(-2**34, msg2.map_int64_int64[-2**33]) + self.assertEqual(456, msg2.map_uint32_uint32[123]) + self.assertEqual(2**34, msg2.map_uint64_uint64[2**33]) + + def testMapAssignmentCausesPresence(self): + msg = map_unittest_pb2.TestMapSubmessage() + msg.test_map.map_int32_int32[123] = 456 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMapSubmessage() + msg2.ParseFromString(serialized) + + self.assertEqual(msg, msg2) + + # Now test that various mutations of the map properly invalidate the + # cached size of the submessage. + msg.test_map.map_int32_int32[888] = 999 + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + msg.test_map.map_int32_int32.clear() + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + def testMapAssignmentCausesPresenceForSubmessages(self): + msg = map_unittest_pb2.TestMapSubmessage() + msg.test_map.map_int32_foreign_message[123].c = 5 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMapSubmessage() + msg2.ParseFromString(serialized) + + self.assertEqual(msg, msg2) + + # Now test that various mutations of the map properly invalidate the + # cached size of the submessage. + msg.test_map.map_int32_foreign_message[888].c = 7 + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + msg.test_map.map_int32_foreign_message[888].MergeFrom( + msg.test_map.map_int32_foreign_message[123]) + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + msg.test_map.map_int32_foreign_message.clear() + serialized = msg.SerializeToString() + msg2.ParseFromString(serialized) + self.assertEqual(msg, msg2) + + def testModifyMapWhileIterating(self): + msg = map_unittest_pb2.TestMap() + + string_string_iter = iter(msg.map_string_string) + int32_foreign_iter = iter(msg.map_int32_foreign_message) + + msg.map_string_string['abc'] = '123' + msg.map_int32_foreign_message[5].c = 5 + + with self.assertRaises(RuntimeError): + for key in string_string_iter: + pass + + with self.assertRaises(RuntimeError): + for key in int32_foreign_iter: + pass + + def testSubmessageMap(self): + msg = map_unittest_pb2.TestMap() + + submsg = msg.map_int32_foreign_message[111] + self.assertIs(submsg, msg.map_int32_foreign_message[111]) + self.assertTrue(isinstance(submsg, unittest_pb2.ForeignMessage)) + + submsg.c = 5 + + serialized = msg.SerializeToString() + msg2 = map_unittest_pb2.TestMap() + msg2.ParseFromString(serialized) + + self.assertEqual(5, msg2.map_int32_foreign_message[111].c) + + # Doesn't allow direct submessage assignment. + with self.assertRaises(ValueError): + msg.map_int32_foreign_message[88] = unittest_pb2.ForeignMessage() + + def testMapIteration(self): + msg = map_unittest_pb2.TestMap() + + for k, v in msg.map_int32_int32.iteritems(): + # Should not be reached. + self.assertTrue(False) + + msg.map_int32_int32[2] = 4 + msg.map_int32_int32[3] = 6 + msg.map_int32_int32[4] = 8 + self.assertEqual(3, len(msg.map_int32_int32)) + + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(msg.map_int32_int32.iteritems(), matching_dict) + + def testMapIterationClearMessage(self): + # Iterator needs to work even if message and map are deleted. + msg = map_unittest_pb2.TestMap() + + msg.map_int32_int32[2] = 4 + msg.map_int32_int32[3] = 6 + msg.map_int32_int32[4] = 8 + + it = msg.map_int32_int32.iteritems() + del msg + + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(it, matching_dict) + + def testMapConstruction(self): + msg = map_unittest_pb2.TestMap(map_int32_int32={1: 2, 3: 4}) + self.assertEqual(2, msg.map_int32_int32[1]) + self.assertEqual(4, msg.map_int32_int32[3]) + + msg = map_unittest_pb2.TestMap( + map_int32_foreign_message={3: unittest_pb2.ForeignMessage(c=5)}) + self.assertEqual(5, msg.map_int32_foreign_message[3].c) + + def testMapValidAfterFieldCleared(self): + # Map needs to work even if field is cleared. + # For the C++ implementation this tests the correctness of + # ScalarMapContainer::Release() + msg = map_unittest_pb2.TestMap() + map = msg.map_int32_int32 + + map[2] = 4 + map[3] = 6 + map[4] = 8 + + msg.ClearField('map_int32_int32') + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(map.iteritems(), matching_dict) + + def testMapIterValidAfterFieldCleared(self): + # Map iterator needs to work even if field is cleared. + # For the C++ implementation this tests the correctness of + # ScalarMapContainer::Release() + msg = map_unittest_pb2.TestMap() + + msg.map_int32_int32[2] = 4 + msg.map_int32_int32[3] = 6 + msg.map_int32_int32[4] = 8 + + it = msg.map_int32_int32.iteritems() + + msg.ClearField('map_int32_int32') + matching_dict = {2: 4, 3: 6, 4: 8} + self.assertMapIterEquals(it, matching_dict) + + def testMapDelete(self): + msg = map_unittest_pb2.TestMap() + + self.assertEqual(0, len(msg.map_int32_int32)) + + msg.map_int32_int32[4] = 6 + self.assertEqual(1, len(msg.map_int32_int32)) + + with self.assertRaises(KeyError): + del msg.map_int32_int32[88] + + del msg.map_int32_int32[4] + self.assertEqual(0, len(msg.map_int32_int32)) + + class ValidTypeNamesTest(unittest.TestCase): diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py index b1e57f35..edaf3fa3 100644 --- a/python/google/protobuf/internal/proto_builder_test.py +++ b/python/google/protobuf/internal/proto_builder_test.py @@ -32,6 +32,7 @@ """Tests for google.protobuf.proto_builder.""" +import collections import unittest from google.protobuf import descriptor_pb2 @@ -43,10 +44,11 @@ from google.protobuf import text_format class ProtoBuilderTest(unittest.TestCase): def setUp(self): - self._fields = { - 'foo': descriptor_pb2.FieldDescriptorProto.TYPE_INT64, - 'bar': descriptor_pb2.FieldDescriptorProto.TYPE_STRING, - } + self.ordered_fields = collections.OrderedDict([ + ('foo', descriptor_pb2.FieldDescriptorProto.TYPE_INT64), + ('bar', descriptor_pb2.FieldDescriptorProto.TYPE_STRING), + ]) + self._fields = dict(self.ordered_fields) def testMakeSimpleProtoClass(self): """Test that we can create a proto class.""" @@ -59,6 +61,17 @@ class ProtoBuilderTest(unittest.TestCase): self.assertMultiLineEqual( 'bar: "asdf"\nfoo: 12345\n', text_format.MessageToString(proto)) + def testOrderedFields(self): + """Test that the field order is maintained when given an OrderedDict.""" + proto_cls = proto_builder.MakeSimpleProtoClass( + self.ordered_fields, + full_name='net.proto2.python.public.proto_builder_test.OrderedTest') + proto = proto_cls() + proto.foo = 12345 + proto.bar = 'asdf' + self.assertMultiLineEqual( + 'foo: 12345\nbar: "asdf"\n', text_format.MessageToString(proto)) + def testMakeSameProtoClassTwice(self): """Test that the DescriptorPool is used.""" pool = descriptor_pool.DescriptorPool() diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 54f584ae..ca9f7675 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -61,9 +61,11 @@ if sys.version_info[0] < 3: except ImportError: from StringIO import StringIO as BytesIO import copy_reg as copyreg + _basestring = basestring else: from io import BytesIO import copyreg + _basestring = str import struct import weakref @@ -77,6 +79,7 @@ from google.protobuf.internal import type_checkers from google.protobuf.internal import wire_format from google.protobuf import descriptor as descriptor_mod from google.protobuf import message as message_mod +from google.protobuf import symbol_database from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor @@ -101,6 +104,7 @@ def InitMessage(descriptor, cls): for field in descriptor.fields: _AttachFieldHelpers(cls, field) + descriptor._concrete_class = cls # pylint: disable=protected-access _AddEnumValues(descriptor, cls) _AddInitMethod(descriptor, cls) _AddPropertiesForFields(descriptor, cls) @@ -198,12 +202,37 @@ def _IsMessageSetExtension(field): field.label == _FieldDescriptor.LABEL_OPTIONAL) +def _IsMapField(field): + return (field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + +def _IsMessageMapField(field): + value_type = field.message_type.fields_by_name["value"] + return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE + + def _AttachFieldHelpers(cls, field_descriptor): is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) - is_packed = (field_descriptor.has_options and - field_descriptor.GetOptions().packed) - - if _IsMessageSetExtension(field_descriptor): + is_packable = (is_repeated and + wire_format.IsTypePackable(field_descriptor.type)) + if not is_packable: + is_packed = False + elif field_descriptor.containing_type.syntax == "proto2": + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + else: + has_packed_false = (field_descriptor.has_options and + field_descriptor.GetOptions().HasField("packed") and + field_descriptor.GetOptions().packed == False) + is_packed = not has_packed_false + is_map_entry = _IsMapField(field_descriptor) + + if is_map_entry: + field_encoder = encoder.MapEncoder(field_descriptor) + sizer = encoder.MapSizer(field_descriptor) + elif _IsMessageSetExtension(field_descriptor): field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) sizer = encoder.MessageSetItemSizer(field_descriptor.number) else: @@ -228,9 +257,16 @@ def _AttachFieldHelpers(cls, field_descriptor): if field_descriptor.containing_oneof is not None: oneof_descriptor = field_descriptor - field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( - field_descriptor.number, is_repeated, is_packed, - field_descriptor, field_descriptor._default_constructor) + if is_map_entry: + is_message_map = _IsMessageMapField(field_descriptor) + + field_decoder = decoder.MapDecoder( + field_descriptor, _GetInitializeDefaultForMap(field_descriptor), + is_message_map) + else: + field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor) cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) @@ -265,6 +301,26 @@ def _AddEnumValues(descriptor, cls): setattr(cls, enum_value.name, enum_value.number) +def _GetInitializeDefaultForMap(field): + if field.label != _FieldDescriptor.LABEL_REPEATED: + raise ValueError('map_entry set on non-repeated field %s' % ( + field.name)) + fields_by_name = field.message_type.fields_by_name + key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) + + value_field = fields_by_name['value'] + if _IsMessageMapField(field): + def MakeMessageMapDefault(message): + return containers.MessageMap( + message._listener_for_children, value_field.message_type, key_checker) + return MakeMessageMapDefault + else: + value_checker = type_checkers.GetTypeChecker(value_field) + def MakePrimitiveMapDefault(message): + return containers.ScalarMap( + message._listener_for_children, key_checker, value_checker) + return MakePrimitiveMapDefault + def _DefaultValueConstructorForField(field): """Returns a function which returns a default value for a field. @@ -279,6 +335,9 @@ def _DefaultValueConstructorForField(field): value may refer back to |message| via a weak reference. """ + if _IsMapField(field): + return _GetInitializeDefaultForMap(field) + if field.label == _FieldDescriptor.LABEL_REPEATED: if field.has_default_value and field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( @@ -329,7 +388,22 @@ def _ReraiseTypeErrorWithFieldName(message_name, field_name): def _AddInitMethod(message_descriptor, cls): """Adds an __init__ method to cls.""" - fields = message_descriptor.fields + + def _GetIntegerEnumValue(enum_type, value): + """Convert a string or integer enum value to an integer. + + If the value is a string, it is converted to the enum value in + enum_type with the same name. If the value is not a string, it's + returned as-is. (No conversion or bounds-checking is done.) + """ + if isinstance(value, _basestring): + try: + return enum_type.values_by_name[value].number + except KeyError: + raise ValueError('Enum type %s: unknown label "%s"' % ( + enum_type.full_name, value)) + return value + def init(self, **kwargs): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 @@ -352,19 +426,37 @@ def _AddInitMethod(message_descriptor, cls): if field.label == _FieldDescriptor.LABEL_REPEATED: copy = field._default_constructor(self) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite - for val in field_value: - copy.add().MergeFrom(val) + if _IsMapField(field): + if _IsMessageMapField(field): + for key in field_value: + copy[key].MergeFrom(field_value[key]) + else: + copy.update(field_value) + else: + for val in field_value: + if isinstance(val, dict): + copy.add(**val) + else: + copy.add().MergeFrom(val) else: # Scalar + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = [_GetIntegerEnumValue(field.enum_type, val) + for val in field_value] copy.extend(field_value) self._fields[field] = copy elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: copy = field._default_constructor(self) + new_val = field_value + if isinstance(field_value, dict): + new_val = field.message_type._concrete_class(**field_value) try: - copy.MergeFrom(field_value) + copy.MergeFrom(new_val) except TypeError: _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) self._fields[field] = copy else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = _GetIntegerEnumValue(field.enum_type, field_value) try: setattr(self, field_name, field_value) except TypeError: @@ -758,6 +850,26 @@ def _AddHasExtensionMethod(cls): return extension_handle in self._fields cls.HasExtension = HasExtension +def _UnpackAny(msg): + type_url = msg.type_url + db = symbol_database.Default() + + if not type_url: + return None + + # TODO(haberman): For now we just strip the hostname. Better logic will be + # required. + type_name = type_url.split("/")[-1] + descriptor = db.pool.FindMessageTypeByName(type_name) + + if descriptor is None: + return None + + message_class = db.GetPrototype(descriptor) + message = message_class() + + message.ParseFromString(msg.value) + return message def _AddEqualsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" @@ -769,6 +881,12 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True + if self.DESCRIPTOR.full_name == "google.protobuf.Any": + any_a = _UnpackAny(self) + any_b = _UnpackAny(other) + if any_a and any_b: + return any_a == any_b + if not self.ListFields() == other.ListFields(): return False @@ -961,6 +1079,9 @@ def _AddIsInitializedMethod(message_descriptor, cls): for field, value in list(self._fields.items()): # dict can change size! if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: + if (field.message_type.has_options and + field.message_type.GetOptions().map_entry): + continue for element in value: if not element.IsInitialized(): if errors is not None: @@ -996,16 +1117,26 @@ def _AddIsInitializedMethod(message_descriptor, cls): else: name = field.name - if field.label == _FieldDescriptor.LABEL_REPEATED: + if _IsMapField(field): + if _IsMessageMapField(field): + for key in value: + element = value[key] + prefix = "%s[%d]." % (name, key) + sub_errors = element.FindInitializationErrors() + errors += [prefix + error for error in sub_errors] + else: + # ScalarMaps can't have any initialization errors. + pass + elif field.label == _FieldDescriptor.LABEL_REPEATED: for i in xrange(len(value)): element = value[i] prefix = "%s[%d]." % (name, i) sub_errors = element.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] + errors += [prefix + error for error in sub_errors] else: prefix = name + "." sub_errors = value.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] + errors += [prefix + error for error in sub_errors] return errors diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index ae79c78b..4eca4989 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -39,8 +39,8 @@ import copy import gc import operator import struct -import unittest +import unittest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -1798,8 +1798,8 @@ class ReflectionTest(unittest.TestCase): def testBadArguments(self): # Some of these assertions used to segfault. from google.protobuf.pyext import _message - self.assertRaises(TypeError, _message.Message._GetFieldDescriptor, 3) - self.assertRaises(TypeError, _message.Message._GetExtensionDescriptor, 42) + self.assertRaises(TypeError, _message.default_pool.FindFieldByName, 3) + self.assertRaises(TypeError, _message.default_pool.FindExtensionByName, 42) self.assertRaises(TypeError, unittest_pb2.TestAllTypes().__getattribute__, 42) diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index de462124..9967255a 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -35,7 +35,6 @@ __author__ = 'petar@google.com (Petar Petrov)' import unittest - from google.protobuf import unittest_pb2 from google.protobuf import service_reflection from google.protobuf import service @@ -81,7 +80,7 @@ class FooUnitTest(unittest.TestCase): self.assertEqual('Method Bar not implemented.', rpc_controller.failure_message) self.assertEqual(None, self.callback_response) - + class MyServiceImpl(unittest_pb2.TestService): def Foo(self, rpc_controller, request, done): self.foo_called = True diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py index c888aff7..80b83bc2 100644 --- a/python/google/protobuf/internal/symbol_database_test.py +++ b/python/google/protobuf/internal/symbol_database_test.py @@ -33,7 +33,6 @@ """Tests for google.protobuf.symbol_database.""" import unittest - from google.protobuf import unittest_pb2 from google.protobuf import symbol_database diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index d84e3836..0cbdbad9 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -75,7 +75,8 @@ def SetAllNonLazyFields(message): message.optional_string = u'115' message.optional_bytes = b'116' - message.optionalgroup.a = 117 + if IsProto2(message): + message.optionalgroup.a = 117 message.optional_nested_message.bb = 118 message.optional_foreign_message.c = 119 message.optional_import_message.d = 120 @@ -109,7 +110,8 @@ def SetAllNonLazyFields(message): message.repeated_string.append(u'215') message.repeated_bytes.append(b'216') - message.repeatedgroup.add().a = 217 + if IsProto2(message): + message.repeatedgroup.add().a = 217 message.repeated_nested_message.add().bb = 218 message.repeated_foreign_message.add().c = 219 message.repeated_import_message.add().d = 220 @@ -140,7 +142,8 @@ def SetAllNonLazyFields(message): message.repeated_string.append(u'315') message.repeated_bytes.append(b'316') - message.repeatedgroup.add().a = 317 + if IsProto2(message): + message.repeatedgroup.add().a = 317 message.repeated_nested_message.add().bb = 318 message.repeated_foreign_message.add().c = 319 message.repeated_import_message.add().d = 320 @@ -396,7 +399,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertTrue(message.HasField('optional_string')) test_case.assertTrue(message.HasField('optional_bytes')) - test_case.assertTrue(message.HasField('optionalgroup')) + if IsProto2(message): + test_case.assertTrue(message.HasField('optionalgroup')) test_case.assertTrue(message.HasField('optional_nested_message')) test_case.assertTrue(message.HasField('optional_foreign_message')) test_case.assertTrue(message.HasField('optional_import_message')) @@ -430,7 +434,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual('115', message.optional_string) test_case.assertEqual(b'116', message.optional_bytes) - test_case.assertEqual(117, message.optionalgroup.a) + if IsProto2(message): + test_case.assertEqual(117, message.optionalgroup.a) test_case.assertEqual(118, message.optional_nested_message.bb) test_case.assertEqual(119, message.optional_foreign_message.c) test_case.assertEqual(120, message.optional_import_message.d) @@ -463,7 +468,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(2, len(message.repeated_string)) test_case.assertEqual(2, len(message.repeated_bytes)) - test_case.assertEqual(2, len(message.repeatedgroup)) + if IsProto2(message): + test_case.assertEqual(2, len(message.repeatedgroup)) test_case.assertEqual(2, len(message.repeated_nested_message)) test_case.assertEqual(2, len(message.repeated_foreign_message)) test_case.assertEqual(2, len(message.repeated_import_message)) @@ -491,7 +497,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual('215', message.repeated_string[0]) test_case.assertEqual(b'216', message.repeated_bytes[0]) - test_case.assertEqual(217, message.repeatedgroup[0].a) + if IsProto2(message): + test_case.assertEqual(217, message.repeatedgroup[0].a) test_case.assertEqual(218, message.repeated_nested_message[0].bb) test_case.assertEqual(219, message.repeated_foreign_message[0].c) test_case.assertEqual(220, message.repeated_import_message[0].d) @@ -521,7 +528,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual('315', message.repeated_string[1]) test_case.assertEqual(b'316', message.repeated_bytes[1]) - test_case.assertEqual(317, message.repeatedgroup[1].a) + if IsProto2(message): + test_case.assertEqual(317, message.repeatedgroup[1].a) test_case.assertEqual(318, message.repeated_nested_message[1].bb) test_case.assertEqual(319, message.repeated_foreign_message[1].c) test_case.assertEqual(320, message.repeated_import_message[1].d) diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py index 27896b94..5df13b78 100755 --- a/python/google/protobuf/internal/text_encoding_test.py +++ b/python/google/protobuf/internal/text_encoding_test.py @@ -33,7 +33,6 @@ """Tests for google.protobuf.text_encoding.""" import unittest - from google.protobuf import text_encoding TEST_VALUES = [ diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index bf7e06ee..06bd1ee5 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -37,8 +37,10 @@ __author__ = 'kenton@google.com (Kenton Varda)' import re import unittest +import unittest from google.protobuf.internal import _parameterized +from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 @@ -309,31 +311,6 @@ class TextFormatTest(TextFormatBase): r'"unknown_field".'), text_format.Parse, text, message) - def testParseGroupNotClosed(self, message_module): - message = message_module.TestAllTypes() - text = 'RepeatedGroup: <' - self.assertRaisesRegexp( - text_format.ParseError, '1:16 : Expected ">".', - text_format.Parse, text, message) - - text = 'RepeatedGroup: {' - self.assertRaisesRegexp( - text_format.ParseError, '1:16 : Expected "}".', - text_format.Parse, text, message) - - def testParseEmptyGroup(self, message_module): - message = message_module.TestAllTypes() - text = 'OptionalGroup: {}' - text_format.Parse(text, message) - self.assertTrue(message.HasField('optionalgroup')) - - message.Clear() - - message = message_module.TestAllTypes() - text = 'OptionalGroup: <>' - text_format.Parse(text, message) - self.assertTrue(message.HasField('optionalgroup')) - def testParseBadEnumValue(self, message_module): message = message_module.TestAllTypes() text = 'optional_nested_enum: BARR' @@ -408,6 +385,14 @@ class TextFormatTest(TextFormatBase): # Ideally the schemas would be made more similar so these tests could pass. class OnlyWorksWithProto2RightNowTests(TextFormatBase): + def testPrintAllFieldsPointy(self, message_module): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros( + text_format.MessageToString(message, pointy_brackets=True)), + 'text_format_unittest_data_pointy_oneof.txt') + def testParseGolden(self): golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) parsed_message = unittest_pb2.TestAllTypes() @@ -471,8 +456,49 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): test_util.SetAllFields(message) self.assertEquals(message, parsed_message) + def testPrintMap(self): + message = map_unittest_pb2.TestMap() -# Tests of proto2-only features (MessageSet and extensions). + message.map_int32_int32[-123] = -456 + message.map_int64_int64[-2**33] = -2**34 + message.map_uint32_uint32[123] = 456 + message.map_uint64_uint64[2**33] = 2**34 + message.map_string_string["abc"] = "123" + message.map_int32_foreign_message[111].c = 5 + + # Maps are serialized to text format using their underlying repeated + # representation. + self.CompareToGoldenText( + text_format.MessageToString(message), + 'map_int32_int32 {\n' + ' key: -123\n' + ' value: -456\n' + '}\n' + 'map_int64_int64 {\n' + ' key: -8589934592\n' + ' value: -17179869184\n' + '}\n' + 'map_uint32_uint32 {\n' + ' key: 123\n' + ' value: 456\n' + '}\n' + 'map_uint64_uint64 {\n' + ' key: 8589934592\n' + ' value: 17179869184\n' + '}\n' + 'map_string_string {\n' + ' key: "abc"\n' + ' value: "123"\n' + '}\n' + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 5\n' + ' }\n' + '}\n') + + +# Tests of proto2-only features (MessageSet, extensions, etc.). class Proto2Tests(TextFormatBase): def testPrintMessageSet(self): @@ -620,6 +646,69 @@ class Proto2Tests(TextFormatBase): 'have multiple "optional_int32" fields.'), text_format.Parse, text, message) + def testParseGroupNotClosed(self): + message = unittest_pb2.TestAllTypes() + text = 'RepeatedGroup: <' + self.assertRaisesRegexp( + text_format.ParseError, '1:16 : Expected ">".', + text_format.Parse, text, message) + text = 'RepeatedGroup: {' + self.assertRaisesRegexp( + text_format.ParseError, '1:16 : Expected "}".', + text_format.Parse, text, message) + + def testParseEmptyGroup(self): + message = unittest_pb2.TestAllTypes() + text = 'OptionalGroup: {}' + text_format.Parse(text, message) + self.assertTrue(message.HasField('optionalgroup')) + + message.Clear() + + message = unittest_pb2.TestAllTypes() + text = 'OptionalGroup: <>' + text_format.Parse(text, message) + self.assertTrue(message.HasField('optionalgroup')) + + # Maps aren't really proto2-only, but our test schema only has maps for + # proto2. + def testParseMap(self): + text = ('map_int32_int32 {\n' + ' key: -123\n' + ' value: -456\n' + '}\n' + 'map_int64_int64 {\n' + ' key: -8589934592\n' + ' value: -17179869184\n' + '}\n' + 'map_uint32_uint32 {\n' + ' key: 123\n' + ' value: 456\n' + '}\n' + 'map_uint64_uint64 {\n' + ' key: 8589934592\n' + ' value: 17179869184\n' + '}\n' + 'map_string_string {\n' + ' key: "abc"\n' + ' value: "123"\n' + '}\n' + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 5\n' + ' }\n' + '}\n') + message = map_unittest_pb2.TestMap() + text_format.Parse(text, message) + + self.assertEqual(-456, message.map_int32_int32[-123]) + self.assertEqual(-2**34, message.map_int64_int64[-2**33]) + self.assertEqual(456, message.map_uint32_uint32[123]) + self.assertEqual(2**34, message.map_uint64_uint64[2**33]) + self.assertEqual("123", message.map_string_string["abc"]) + self.assertEqual(5, message.map_int32_foreign_message[111].c) + class TokenizerTest(unittest.TestCase): diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 76c056c4..f20e526a 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -129,6 +129,9 @@ class IntValueChecker(object): proposed_value = self._TYPE(proposed_value) return proposed_value + def DefaultValue(self): + return 0 + class EnumValueChecker(object): @@ -146,6 +149,9 @@ class EnumValueChecker(object): raise ValueError('Unknown enum value: %d' % proposed_value) return proposed_value + def DefaultValue(self): + return self._enum_type.values[0].number + class UnicodeValueChecker(object): @@ -171,6 +177,9 @@ class UnicodeValueChecker(object): (proposed_value)) return proposed_value + def DefaultValue(self): + return u"" + class Int32ValueChecker(IntValueChecker): # We're sure to use ints instead of longs here since comparison may be more diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 9337ae8a..1b81ae79 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -36,7 +36,6 @@ __author__ = 'bohdank@google.com (Bohdan Koval)' import unittest - from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py index 5cd7fcb9..78dc1167 100755 --- a/python/google/protobuf/internal/wire_format_test.py +++ b/python/google/protobuf/internal/wire_format_test.py @@ -35,7 +35,6 @@ __author__ = 'robinson@google.com (Will Robinson)' import unittest - from google.protobuf import message from google.protobuf.internal import wire_format diff --git a/python/google/protobuf/proto_builder.py b/python/google/protobuf/proto_builder.py index 1fa28f1a..7489cf63 100644 --- a/python/google/protobuf/proto_builder.py +++ b/python/google/protobuf/proto_builder.py @@ -30,6 +30,7 @@ """Dynamic Protobuf class creator.""" +import collections import hashlib import os @@ -59,7 +60,9 @@ def MakeSimpleProtoClass(fields, full_name, pool=None): Note: this doesn't validate field names! Args: - fields: dict of {name: field_type} mappings for each field in the proto. + fields: dict of {name: field_type} mappings for each field in the proto. If + this is an OrderedDict the order will be maintained, otherwise the + fields will be sorted by name. full_name: str, the fully-qualified name of the proto type. pool: optional DescriptorPool instance. Returns: @@ -73,12 +76,19 @@ def MakeSimpleProtoClass(fields, full_name, pool=None): # The factory's DescriptorPool doesn't know about this class yet. pass + # Get a list of (name, field_type) tuples from the fields dict. If fields was + # an OrderedDict we keep the order, but otherwise we sort the field to ensure + # consistent ordering. + field_items = fields.items() + if not isinstance(fields, collections.OrderedDict): + field_items = sorted(field_items) + # Use a consistent file name that is unlikely to conflict with any imported # proto files. fields_hash = hashlib.sha1() - for f_name, f_type in sorted(fields.items()): - fields_hash.update(f_name.encode('utf8')) - fields_hash.update(str(f_type).encode('utf8')) + for f_name, f_type in field_items: + fields_hash.update(f_name.encode('utf-8')) + fields_hash.update(str(f_type).encode('utf-8')) proto_file_name = fields_hash.hexdigest() + '.proto' package, name = full_name.rsplit('.', 1) @@ -87,7 +97,7 @@ def MakeSimpleProtoClass(fields, full_name, pool=None): file_proto.package = package desc_proto = file_proto.message_type.add() desc_proto.name = name - for f_number, (f_name, f_type) in enumerate(sorted(fields.items()), 1): + for f_number, (f_name, f_type) in enumerate(field_items, 1): field_proto = desc_proto.field.add() field_proto.name = f_name field_proto.number = f_number diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index e77d0bb9..2160757b 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -43,8 +43,6 @@ #include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> -#define C(str) const_cast<char*>(str) - #if PY_MAJOR_VERSION >= 3 #define PyString_FromStringAndSize PyUnicode_FromStringAndSize #define PyString_Check PyUnicode_Check @@ -257,8 +255,14 @@ namespace descriptor { // Creates or retrieve a Python descriptor of the specified type. // Objects are interned: the same descriptor will return the same object if it // was kept alive. +// 'was_created' is an optional pointer to a bool, and is set to true if a new +// object was allocated. // Always return a new reference. -PyObject* NewInternedDescriptor(PyTypeObject* type, const void* descriptor) { +PyObject* NewInternedDescriptor(PyTypeObject* type, const void* descriptor, + bool* was_created) { + if (was_created) { + *was_created = false; + } if (descriptor == NULL) { PyErr_BadInternalCall(); return NULL; @@ -283,6 +287,9 @@ PyObject* NewInternedDescriptor(PyTypeObject* type, const void* descriptor) { GetDescriptorPool()->interned_descriptors->insert( std::make_pair(descriptor, reinterpret_cast<PyObject*>(py_descriptor))); + if (was_created) { + *was_created = true; + } return reinterpret_cast<PyObject*>(py_descriptor); } @@ -298,9 +305,7 @@ static PyGetSetDef Getters[] = { PyTypeObject PyBaseDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - // Keep the fully qualified _message symbol in a line for opensource. - "google.protobuf.internal._message." - "DescriptorBase", // tp_name + FULL_MODULE_NAME ".DescriptorBase", // tp_name sizeof(PyBaseDescriptor), // tp_basicsize 0, // tp_itemsize (destructor)Dealloc, // tp_dealloc @@ -357,7 +362,7 @@ static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { } static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { - return PyFileDescriptor_New(_GetDescriptor(self)->file()); + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); } static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) { @@ -367,17 +372,6 @@ static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) { return concrete_class; } -static int SetConcreteClass(PyBaseDescriptor *self, PyObject *value, - void *closure) { - // This attribute is also set from reflection.py. Check that it's actually a - // no-op. - if (value != cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), _GetDescriptor(self))) { - PyErr_SetString(PyExc_AttributeError, "Cannot change _concrete_class"); - } - return 0; -} - static PyObject* GetFieldsByName(PyBaseDescriptor* self, void *closure) { return NewMessageFieldsByName(_GetDescriptor(self)); } @@ -452,7 +446,7 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { const Descriptor* containing_type = _GetDescriptor(self)->containing_type(); if (containing_type) { - return PyMessageDescriptor_New(containing_type); + return PyMessageDescriptor_FromDescriptor(containing_type); } else { Py_RETURN_NONE; } @@ -515,29 +509,34 @@ static PyObject* GetSyntax(PyBaseDescriptor *self, void *closure) { } static PyGetSetDef Getters[] = { - { C("name"), (getter)GetName, NULL, "Last name", NULL}, - { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, - { C("_concrete_class"), (getter)GetConcreteClass, (setter)SetConcreteClass, "concrete class", NULL}, - { C("file"), (getter)GetFile, NULL, "File descriptor", NULL}, - - { C("fields"), (getter)GetFieldsSeq, NULL, "Fields sequence", NULL}, - { C("fields_by_name"), (getter)GetFieldsByName, NULL, "Fields by name", NULL}, - { C("fields_by_number"), (getter)GetFieldsByNumber, NULL, "Fields by number", NULL}, - { C("nested_types"), (getter)GetNestedTypesSeq, NULL, "Nested types sequence", NULL}, - { C("nested_types_by_name"), (getter)GetNestedTypesByName, NULL, "Nested types by name", NULL}, - { C("extensions"), (getter)GetExtensions, NULL, "Extensions Sequence", NULL}, - { C("extensions_by_name"), (getter)GetExtensionsByName, NULL, "Extensions by name", NULL}, - { C("extension_ranges"), (getter)GetExtensionRanges, NULL, "Extension ranges", NULL}, - { C("enum_types"), (getter)GetEnumsSeq, NULL, "Enum sequence", NULL}, - { C("enum_types_by_name"), (getter)GetEnumTypesByName, NULL, "Enum types by name", NULL}, - { C("enum_values_by_name"), (getter)GetEnumValuesByName, NULL, "Enum values by name", NULL}, - { C("oneofs_by_name"), (getter)GetOneofsByName, NULL, "Oneofs by name", NULL}, - { C("oneofs"), (getter)GetOneofsSeq, NULL, "Oneofs by name", NULL}, - { C("containing_type"), (getter)GetContainingType, (setter)SetContainingType, "Containing type", NULL}, - { C("is_extendable"), (getter)IsExtendable, (setter)NULL, NULL, NULL}, - { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, - { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, - { C("syntax"), (getter)GetSyntax, (setter)NULL, "Syntax", NULL}, + { "name", (getter)GetName, NULL, "Last name"}, + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "_concrete_class", (getter)GetConcreteClass, NULL, "concrete class"}, + { "file", (getter)GetFile, NULL, "File descriptor"}, + + { "fields", (getter)GetFieldsSeq, NULL, "Fields sequence"}, + { "fields_by_name", (getter)GetFieldsByName, NULL, "Fields by name"}, + { "fields_by_number", (getter)GetFieldsByNumber, NULL, "Fields by number"}, + { "nested_types", (getter)GetNestedTypesSeq, NULL, "Nested types sequence"}, + { "nested_types_by_name", (getter)GetNestedTypesByName, NULL, + "Nested types by name"}, + { "extensions", (getter)GetExtensions, NULL, "Extensions Sequence"}, + { "extensions_by_name", (getter)GetExtensionsByName, NULL, + "Extensions by name"}, + { "extension_ranges", (getter)GetExtensionRanges, NULL, "Extension ranges"}, + { "enum_types", (getter)GetEnumsSeq, NULL, "Enum sequence"}, + { "enum_types_by_name", (getter)GetEnumTypesByName, NULL, + "Enum types by name"}, + { "enum_values_by_name", (getter)GetEnumValuesByName, NULL, + "Enum values by name"}, + { "oneofs_by_name", (getter)GetOneofsByName, NULL, "Oneofs by name"}, + { "oneofs", (getter)GetOneofsSeq, NULL, "Oneofs by name"}, + { "containing_type", (getter)GetContainingType, (setter)SetContainingType, + "Containing type"}, + { "is_extendable", (getter)IsExtendable, (setter)NULL}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"}, {NULL} }; @@ -552,9 +551,7 @@ static PyMethodDef Methods[] = { PyTypeObject PyMessageDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - // Keep the fully qualified _message symbol in a line for opensource. - C("google.protobuf.internal._message." - "MessageDescriptor"), // tp_name + FULL_MODULE_NAME ".MessageDescriptor", // tp_name sizeof(PyBaseDescriptor), // tp_basicsize 0, // tp_itemsize 0, // tp_dealloc @@ -573,7 +570,7 @@ PyTypeObject PyMessageDescriptor_Type = { 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT, // tp_flags - C("A Message Descriptor"), // tp_doc + "A Message Descriptor", // tp_doc 0, // tp_traverse 0, // tp_clear 0, // tp_richcompare @@ -586,10 +583,10 @@ PyTypeObject PyMessageDescriptor_Type = { &descriptor::PyBaseDescriptor_Type, // tp_base }; -PyObject* PyMessageDescriptor_New( +PyObject* PyMessageDescriptor_FromDescriptor( const Descriptor* message_descriptor) { return descriptor::NewInternedDescriptor( - &PyMessageDescriptor_Type, message_descriptor); + &PyMessageDescriptor_Type, message_descriptor, NULL); } const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj) { @@ -715,7 +712,7 @@ static PyObject* GetCDescriptor(PyObject *self, void *closure) { static PyObject *GetEnumType(PyBaseDescriptor *self, void *closure) { const EnumDescriptor* enum_type = _GetDescriptor(self)->enum_type(); if (enum_type) { - return PyEnumDescriptor_New(enum_type); + return PyEnumDescriptor_FromDescriptor(enum_type); } else { Py_RETURN_NONE; } @@ -728,7 +725,7 @@ static int SetEnumType(PyBaseDescriptor *self, PyObject *value, void *closure) { static PyObject *GetMessageType(PyBaseDescriptor *self, void *closure) { const Descriptor* message_type = _GetDescriptor(self)->message_type(); if (message_type) { - return PyMessageDescriptor_New(message_type); + return PyMessageDescriptor_FromDescriptor(message_type); } else { Py_RETURN_NONE; } @@ -743,7 +740,7 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { const Descriptor* containing_type = _GetDescriptor(self)->containing_type(); if (containing_type) { - return PyMessageDescriptor_New(containing_type); + return PyMessageDescriptor_FromDescriptor(containing_type); } else { Py_RETURN_NONE; } @@ -758,7 +755,7 @@ static PyObject* GetExtensionScope(PyBaseDescriptor *self, void *closure) { const Descriptor* extension_scope = _GetDescriptor(self)->extension_scope(); if (extension_scope) { - return PyMessageDescriptor_New(extension_scope); + return PyMessageDescriptor_FromDescriptor(extension_scope); } else { Py_RETURN_NONE; } @@ -768,7 +765,7 @@ static PyObject* GetContainingOneof(PyBaseDescriptor *self, void *closure) { const OneofDescriptor* containing_oneof = _GetDescriptor(self)->containing_oneof(); if (containing_oneof) { - return PyOneofDescriptor_New(containing_oneof); + return PyOneofDescriptor_FromDescriptor(containing_oneof); } else { Py_RETURN_NONE; } @@ -803,26 +800,30 @@ static int SetOptions(PyBaseDescriptor *self, PyObject *value, static PyGetSetDef Getters[] = { - { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, - { C("name"), (getter)GetName, NULL, "Unqualified name", NULL}, - { C("type"), (getter)GetType, NULL, "C++ Type", NULL}, - { C("cpp_type"), (getter)GetCppType, NULL, "C++ Type", NULL}, - { C("label"), (getter)GetLabel, NULL, "Label", NULL}, - { C("number"), (getter)GetNumber, NULL, "Number", NULL}, - { C("index"), (getter)GetIndex, NULL, "Index", NULL}, - { C("default_value"), (getter)GetDefaultValue, NULL, "Default Value", NULL}, - { C("has_default_value"), (getter)HasDefaultValue, NULL, NULL, NULL}, - { C("is_extension"), (getter)IsExtension, NULL, "ID", NULL}, - { C("id"), (getter)GetID, NULL, "ID", NULL}, - { C("_cdescriptor"), (getter)GetCDescriptor, NULL, "HAACK REMOVE ME", NULL}, - - { C("message_type"), (getter)GetMessageType, (setter)SetMessageType, "Message type", NULL}, - { C("enum_type"), (getter)GetEnumType, (setter)SetEnumType, "Enum type", NULL}, - { C("containing_type"), (getter)GetContainingType, (setter)SetContainingType, "Containing type", NULL}, - { C("extension_scope"), (getter)GetExtensionScope, (setter)NULL, "Extension scope", NULL}, - { C("containing_oneof"), (getter)GetContainingOneof, (setter)SetContainingOneof, "Containing oneof", NULL}, - { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, - { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "name", (getter)GetName, NULL, "Unqualified name"}, + { "type", (getter)GetType, NULL, "C++ Type"}, + { "cpp_type", (getter)GetCppType, NULL, "C++ Type"}, + { "label", (getter)GetLabel, NULL, "Label"}, + { "number", (getter)GetNumber, NULL, "Number"}, + { "index", (getter)GetIndex, NULL, "Index"}, + { "default_value", (getter)GetDefaultValue, NULL, "Default Value"}, + { "has_default_value", (getter)HasDefaultValue}, + { "is_extension", (getter)IsExtension, NULL, "ID"}, + { "id", (getter)GetID, NULL, "ID"}, + { "_cdescriptor", (getter)GetCDescriptor, NULL, "HAACK REMOVE ME"}, + + { "message_type", (getter)GetMessageType, (setter)SetMessageType, + "Message type"}, + { "enum_type", (getter)GetEnumType, (setter)SetEnumType, "Enum type"}, + { "containing_type", (getter)GetContainingType, (setter)SetContainingType, + "Containing type"}, + { "extension_scope", (getter)GetExtensionScope, (setter)NULL, + "Extension scope"}, + { "containing_oneof", (getter)GetContainingOneof, (setter)SetContainingOneof, + "Containing oneof"}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, {NULL} }; @@ -835,8 +836,7 @@ static PyMethodDef Methods[] = { PyTypeObject PyFieldDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - C("google.protobuf.internal." - "_message.FieldDescriptor"), // tp_name + FULL_MODULE_NAME ".FieldDescriptor", // tp_name sizeof(PyBaseDescriptor), // tp_basicsize 0, // tp_itemsize 0, // tp_dealloc @@ -855,7 +855,7 @@ PyTypeObject PyFieldDescriptor_Type = { 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT, // tp_flags - C("A Field Descriptor"), // tp_doc + "A Field Descriptor", // tp_doc 0, // tp_traverse 0, // tp_clear 0, // tp_richcompare @@ -868,10 +868,10 @@ PyTypeObject PyFieldDescriptor_Type = { &descriptor::PyBaseDescriptor_Type, // tp_base }; -PyObject* PyFieldDescriptor_New( +PyObject* PyFieldDescriptor_FromDescriptor( const FieldDescriptor* field_descriptor) { return descriptor::NewInternedDescriptor( - &PyFieldDescriptor_Type, field_descriptor); + &PyFieldDescriptor_Type, field_descriptor, NULL); } const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj) { @@ -900,7 +900,7 @@ static PyObject* GetName(PyBaseDescriptor *self, void *closure) { } static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { - return PyFileDescriptor_New(_GetDescriptor(self)->file()); + return PyFileDescriptor_FromDescriptor(_GetDescriptor(self)->file()); } static PyObject* GetEnumvaluesByName(PyBaseDescriptor* self, void *closure) { @@ -919,7 +919,7 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { const Descriptor* containing_type = _GetDescriptor(self)->containing_type(); if (containing_type) { - return PyMessageDescriptor_New(containing_type); + return PyMessageDescriptor_FromDescriptor(containing_type); } else { Py_RETURN_NONE; } @@ -964,16 +964,19 @@ static PyMethodDef Methods[] = { }; static PyGetSetDef Getters[] = { - { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, - { C("name"), (getter)GetName, NULL, "last name", NULL}, - { C("file"), (getter)GetFile, NULL, "File descriptor", NULL}, - { C("values"), (getter)GetEnumvaluesSeq, NULL, "values", NULL}, - { C("values_by_name"), (getter)GetEnumvaluesByName, NULL, "Enumvalues by name", NULL}, - { C("values_by_number"), (getter)GetEnumvaluesByNumber, NULL, "Enumvalues by number", NULL}, - - { C("containing_type"), (getter)GetContainingType, (setter)SetContainingType, "Containing type", NULL}, - { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, - { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "name", (getter)GetName, NULL, "last name"}, + { "file", (getter)GetFile, NULL, "File descriptor"}, + { "values", (getter)GetEnumvaluesSeq, NULL, "values"}, + { "values_by_name", (getter)GetEnumvaluesByName, NULL, + "Enum values by name"}, + { "values_by_number", (getter)GetEnumvaluesByNumber, NULL, + "Enum values by number"}, + + { "containing_type", (getter)GetContainingType, (setter)SetContainingType, + "Containing type"}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, {NULL} }; @@ -981,9 +984,7 @@ static PyGetSetDef Getters[] = { PyTypeObject PyEnumDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - // Keep the fully qualified _message symbol in a line for opensource. - C("google.protobuf.internal._message." - "EnumDescriptor"), // tp_name + FULL_MODULE_NAME ".EnumDescriptor", // tp_name sizeof(PyBaseDescriptor), // tp_basicsize 0, // tp_itemsize 0, // tp_dealloc @@ -1002,7 +1003,7 @@ PyTypeObject PyEnumDescriptor_Type = { 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT, // tp_flags - C("A Enum Descriptor"), // tp_doc + "A Enum Descriptor", // tp_doc 0, // tp_traverse 0, // tp_clear 0, // tp_richcompare @@ -1015,10 +1016,10 @@ PyTypeObject PyEnumDescriptor_Type = { &descriptor::PyBaseDescriptor_Type, // tp_base }; -PyObject* PyEnumDescriptor_New( +PyObject* PyEnumDescriptor_FromDescriptor( const EnumDescriptor* enum_descriptor) { return descriptor::NewInternedDescriptor( - &PyEnumDescriptor_Type, enum_descriptor); + &PyEnumDescriptor_Type, enum_descriptor, NULL); } namespace enumvalue_descriptor { @@ -1042,7 +1043,7 @@ static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { } static PyObject* GetType(PyBaseDescriptor *self, void *closure) { - return PyEnumDescriptor_New(_GetDescriptor(self)->type()); + return PyEnumDescriptor_FromDescriptor(_GetDescriptor(self)->type()); } static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { @@ -1069,13 +1070,13 @@ static int SetOptions(PyBaseDescriptor *self, PyObject *value, static PyGetSetDef Getters[] = { - { C("name"), (getter)GetName, NULL, "name", NULL}, - { C("number"), (getter)GetNumber, NULL, "number", NULL}, - { C("index"), (getter)GetIndex, NULL, "index", NULL}, - { C("type"), (getter)GetType, NULL, "index", NULL}, + { "name", (getter)GetName, NULL, "name"}, + { "number", (getter)GetNumber, NULL, "number"}, + { "index", (getter)GetIndex, NULL, "index"}, + { "type", (getter)GetType, NULL, "index"}, - { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, - { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, {NULL} }; @@ -1088,8 +1089,7 @@ static PyMethodDef Methods[] = { PyTypeObject PyEnumValueDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - C("google.protobuf.internal." - "_message.EnumValueDescriptor"), // tp_name + FULL_MODULE_NAME ".EnumValueDescriptor", // tp_name sizeof(PyBaseDescriptor), // tp_basicsize 0, // tp_itemsize 0, // tp_dealloc @@ -1108,7 +1108,7 @@ PyTypeObject PyEnumValueDescriptor_Type = { 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT, // tp_flags - C("A EnumValue Descriptor"), // tp_doc + "A EnumValue Descriptor", // tp_doc 0, // tp_traverse 0, // tp_clear 0, // tp_richcompare @@ -1121,10 +1121,10 @@ PyTypeObject PyEnumValueDescriptor_Type = { &descriptor::PyBaseDescriptor_Type, // tp_base }; -PyObject* PyEnumValueDescriptor_New( +PyObject* PyEnumValueDescriptor_FromDescriptor( const EnumValueDescriptor* enumvalue_descriptor) { return descriptor::NewInternedDescriptor( - &PyEnumValueDescriptor_Type, enumvalue_descriptor); + &PyEnumValueDescriptor_Type, enumvalue_descriptor, NULL); } namespace file_descriptor { @@ -1218,18 +1218,20 @@ static PyObject* CopyToProto(PyFileDescriptor *self, PyObject *target) { } static PyGetSetDef Getters[] = { - { C("name"), (getter)GetName, NULL, "name", NULL}, - { C("package"), (getter)GetPackage, NULL, "package", NULL}, - { C("serialized_pb"), (getter)GetSerializedPb, NULL, NULL, NULL}, - { C("message_types_by_name"), (getter)GetMessageTypesByName, NULL, "Messages by name", NULL}, - { C("enum_types_by_name"), (getter)GetEnumTypesByName, NULL, "Enums by name", NULL}, - { C("extensions_by_name"), (getter)GetExtensionsByName, NULL, "Extensions by name", NULL}, - { C("dependencies"), (getter)GetDependencies, NULL, "Dependencies", NULL}, - { C("public_dependencies"), (getter)GetPublicDependencies, NULL, "Dependencies", NULL}, - - { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, - { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, - { C("syntax"), (getter)GetSyntax, (setter)NULL, "Syntax", NULL}, + { "name", (getter)GetName, NULL, "name"}, + { "package", (getter)GetPackage, NULL, "package"}, + { "serialized_pb", (getter)GetSerializedPb}, + { "message_types_by_name", (getter)GetMessageTypesByName, NULL, + "Messages by name"}, + { "enum_types_by_name", (getter)GetEnumTypesByName, NULL, "Enums by name"}, + { "extensions_by_name", (getter)GetExtensionsByName, NULL, + "Extensions by name"}, + { "dependencies", (getter)GetDependencies, NULL, "Dependencies"}, + { "public_dependencies", (getter)GetPublicDependencies, NULL, "Dependencies"}, + + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, + { "syntax", (getter)GetSyntax, (setter)NULL, "Syntax"}, {NULL} }; @@ -1243,11 +1245,10 @@ static PyMethodDef Methods[] = { PyTypeObject PyFileDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - C("google.protobuf.internal." - "_message.FileDescriptor"), // tp_name + FULL_MODULE_NAME ".FileDescriptor", // tp_name sizeof(PyFileDescriptor), // tp_basicsize 0, // tp_itemsize - (destructor)file_descriptor::Dealloc, // tp_dealloc + (destructor)file_descriptor::Dealloc, // tp_dealloc 0, // tp_print 0, // tp_getattr 0, // tp_setattr @@ -1263,7 +1264,7 @@ PyTypeObject PyFileDescriptor_Type = { 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT, // tp_flags - C("A File Descriptor"), // tp_doc + "A File Descriptor", // tp_doc 0, // tp_traverse 0, // tp_clear 0, // tp_richcompare @@ -1284,23 +1285,28 @@ PyTypeObject PyFileDescriptor_Type = { PyObject_Del, // tp_free }; -PyObject* PyFileDescriptor_New(const FileDescriptor* file_descriptor) { - return descriptor::NewInternedDescriptor( - &PyFileDescriptor_Type, file_descriptor); +PyObject* PyFileDescriptor_FromDescriptor( + const FileDescriptor* file_descriptor) { + return PyFileDescriptor_FromDescriptorWithSerializedPb(file_descriptor, + NULL); } -PyObject* PyFileDescriptor_NewWithPb( +PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb( const FileDescriptor* file_descriptor, PyObject *serialized_pb) { - PyObject* py_descriptor = PyFileDescriptor_New(file_descriptor); + bool was_created; + PyObject* py_descriptor = descriptor::NewInternedDescriptor( + &PyFileDescriptor_Type, file_descriptor, &was_created); if (py_descriptor == NULL) { return NULL; } - if (serialized_pb != NULL) { + if (was_created) { PyFileDescriptor* cfile_descriptor = reinterpret_cast<PyFileDescriptor*>(py_descriptor); Py_XINCREF(serialized_pb); cfile_descriptor->serialized_pb = serialized_pb; } + // TODO(amauryfa): In the case of a cached object, check that serialized_pb + // is the same as before. return py_descriptor; } @@ -1333,19 +1339,19 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { const Descriptor* containing_type = _GetDescriptor(self)->containing_type(); if (containing_type) { - return PyMessageDescriptor_New(containing_type); + return PyMessageDescriptor_FromDescriptor(containing_type); } else { Py_RETURN_NONE; } } static PyGetSetDef Getters[] = { - { C("name"), (getter)GetName, NULL, "Name", NULL}, - { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, - { C("index"), (getter)GetIndex, NULL, "Index", NULL}, + { "name", (getter)GetName, NULL, "Name"}, + { "full_name", (getter)GetFullName, NULL, "Full name"}, + { "index", (getter)GetIndex, NULL, "Index"}, - { C("containing_type"), (getter)GetContainingType, NULL, "Containing type", NULL}, - { C("fields"), (getter)GetFields, NULL, "Fields", NULL}, + { "containing_type", (getter)GetContainingType, NULL, "Containing type"}, + { "fields", (getter)GetFields, NULL, "Fields"}, {NULL} }; @@ -1353,8 +1359,7 @@ static PyGetSetDef Getters[] = { PyTypeObject PyOneofDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - C("google.protobuf.internal." - "_message.OneofDescriptor"), // tp_name + FULL_MODULE_NAME ".OneofDescriptor", // tp_name sizeof(PyBaseDescriptor), // tp_basicsize 0, // tp_itemsize 0, // tp_dealloc @@ -1373,7 +1378,7 @@ PyTypeObject PyOneofDescriptor_Type = { 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT, // tp_flags - C("A Oneof Descriptor"), // tp_doc + "A Oneof Descriptor", // tp_doc 0, // tp_traverse 0, // tp_clear 0, // tp_richcompare @@ -1386,10 +1391,10 @@ PyTypeObject PyOneofDescriptor_Type = { &descriptor::PyBaseDescriptor_Type, // tp_base }; -PyObject* PyOneofDescriptor_New( +PyObject* PyOneofDescriptor_FromDescriptor( const OneofDescriptor* oneof_descriptor) { return descriptor::NewInternedDescriptor( - &PyOneofDescriptor_Type, oneof_descriptor); + &PyOneofDescriptor_Type, oneof_descriptor, NULL); } // Add a enum values to a type dictionary. diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h index ba6e7298..b2550406 100644 --- a/python/google/protobuf/pyext/descriptor.h +++ b/python/google/protobuf/pyext/descriptor.h @@ -48,21 +48,24 @@ extern PyTypeObject PyEnumValueDescriptor_Type; extern PyTypeObject PyFileDescriptor_Type; extern PyTypeObject PyOneofDescriptor_Type; -// Return a new reference to a Descriptor object. +// Wraps a Descriptor in a Python object. // The C++ pointer is usually borrowed from the global DescriptorPool. // In any case, it must stay alive as long as the Python object. -PyObject* PyMessageDescriptor_New(const Descriptor* descriptor); -PyObject* PyFieldDescriptor_New(const FieldDescriptor* descriptor); -PyObject* PyEnumDescriptor_New(const EnumDescriptor* descriptor); -PyObject* PyEnumValueDescriptor_New(const EnumValueDescriptor* descriptor); -PyObject* PyOneofDescriptor_New(const OneofDescriptor* descriptor); -PyObject* PyFileDescriptor_New(const FileDescriptor* file_descriptor); +// Returns a new reference. +PyObject* PyMessageDescriptor_FromDescriptor(const Descriptor* descriptor); +PyObject* PyFieldDescriptor_FromDescriptor(const FieldDescriptor* descriptor); +PyObject* PyEnumDescriptor_FromDescriptor(const EnumDescriptor* descriptor); +PyObject* PyEnumValueDescriptor_FromDescriptor( + const EnumValueDescriptor* descriptor); +PyObject* PyOneofDescriptor_FromDescriptor(const OneofDescriptor* descriptor); +PyObject* PyFileDescriptor_FromDescriptor( + const FileDescriptor* file_descriptor); // Alternate constructor of PyFileDescriptor, used when we already have a // serialized FileDescriptorProto that can be cached. // Returns a new reference. -PyObject* PyFileDescriptor_NewWithPb(const FileDescriptor* file_descriptor, - PyObject* serialized_pb); +PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb( + const FileDescriptor* file_descriptor, PyObject* serialized_pb); // Return the C++ descriptor pointer. // This function checks the parameter type; on error, return NULL with a Python diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc index 06edebf8..92e11e31 100644 --- a/python/google/protobuf/pyext/descriptor_containers.cc +++ b/python/google/protobuf/pyext/descriptor_containers.cc @@ -898,7 +898,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyFieldDescriptor_New(item); + return PyFieldDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -956,7 +956,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyMessageDescriptor_New(item); + return PyMessageDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -1006,7 +1006,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyEnumDescriptor_New(item); + return PyEnumDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -1082,7 +1082,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyEnumValueDescriptor_New(item); + return PyEnumValueDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -1124,7 +1124,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyFieldDescriptor_New(item); + return PyFieldDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -1174,7 +1174,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyOneofDescriptor_New(item); + return PyOneofDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -1238,7 +1238,7 @@ static ItemDescriptor GetByNumber(PyContainer* self, int number) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyEnumValueDescriptor_New(item); + return PyEnumValueDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -1302,7 +1302,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyFieldDescriptor_New(item); + return PyFieldDescriptor_FromDescriptor(item); } static int GetItemIndex(ItemDescriptor item) { @@ -1354,7 +1354,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyMessageDescriptor_New(item); + return PyMessageDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -1400,7 +1400,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyEnumDescriptor_New(item); + return PyEnumDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -1446,7 +1446,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyFieldDescriptor_New(item); + return PyFieldDescriptor_FromDescriptor(item); } static const string& GetItemName(ItemDescriptor item) { @@ -1488,7 +1488,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyFileDescriptor_New(item); + return PyFileDescriptor_FromDescriptor(item); } static DescriptorContainerDef ContainerDef = { @@ -1522,7 +1522,7 @@ static ItemDescriptor GetByIndex(PyContainer* self, int index) { } static PyObject* NewObjectFromItem(ItemDescriptor item) { - return PyFileDescriptor_New(item); + return PyFileDescriptor_FromDescriptor(item); } static DescriptorContainerDef ContainerDef = { diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h index d81537de..8fbdaff9 100644 --- a/python/google/protobuf/pyext/descriptor_containers.h +++ b/python/google/protobuf/pyext/descriptor_containers.h @@ -28,6 +28,9 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ + // Mappings and Sequences of descriptors. // They implement containers like fields_by_name, EnumDescriptor.values... // See descriptor_containers.cc for more description. @@ -92,4 +95,6 @@ PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor); } // namespace python } // namespace protobuf + } // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index bc3077bc..ecd90847 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -35,10 +35,9 @@ #include <google/protobuf/descriptor.pb.h> #include <google/protobuf/pyext/descriptor_pool.h> #include <google/protobuf/pyext/descriptor.h> +#include <google/protobuf/pyext/message.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> -#define C(str) const_cast<char*>(str) - #if PY_MAJOR_VERSION >= 3 #define PyString_FromStringAndSize PyUnicode_FromStringAndSize #if PY_VERSION_HEX < 0x03030000 @@ -108,11 +107,11 @@ PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) { self->pool->FindMessageTypeByName(string(name, name_size)); if (message_descriptor == NULL) { - PyErr_Format(PyExc_TypeError, "Couldn't find message %.200s", name); + PyErr_Format(PyExc_KeyError, "Couldn't find message %.200s", name); return NULL; } - return PyMessageDescriptor_New(message_descriptor); + return PyMessageDescriptor_FromDescriptor(message_descriptor); } // Add a message class to our database. @@ -158,6 +157,24 @@ PyObject *GetMessageClass(PyDescriptorPool* self, } } +PyObject* FindFileByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FileDescriptor* file_descriptor = + self->pool->FindFileByName(string(name, name_size)); + if (file_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find file %.200s", + name); + return NULL; + } + + return PyFileDescriptor_FromDescriptor(file_descriptor); +} + PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) { Py_ssize_t name_size; char* name; @@ -168,12 +185,12 @@ PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) { const FieldDescriptor* field_descriptor = self->pool->FindFieldByName(string(name, name_size)); if (field_descriptor == NULL) { - PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", + PyErr_Format(PyExc_KeyError, "Couldn't find field %.200s", name); return NULL; } - return PyFieldDescriptor_New(field_descriptor); + return PyFieldDescriptor_FromDescriptor(field_descriptor); } PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) { @@ -186,11 +203,11 @@ PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) { const FieldDescriptor* field_descriptor = self->pool->FindExtensionByName(string(name, name_size)); if (field_descriptor == NULL) { - PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", name); + PyErr_Format(PyExc_KeyError, "Couldn't find extension field %.200s", name); return NULL; } - return PyFieldDescriptor_New(field_descriptor); + return PyFieldDescriptor_FromDescriptor(field_descriptor); } PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) { @@ -203,11 +220,11 @@ PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) { const EnumDescriptor* enum_descriptor = self->pool->FindEnumTypeByName(string(name, name_size)); if (enum_descriptor == NULL) { - PyErr_Format(PyExc_TypeError, "Couldn't find enum %.200s", name); + PyErr_Format(PyExc_KeyError, "Couldn't find enum %.200s", name); return NULL; } - return PyEnumDescriptor_New(enum_descriptor); + return PyEnumDescriptor_FromDescriptor(enum_descriptor); } PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) { @@ -220,70 +237,13 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) { const OneofDescriptor* oneof_descriptor = self->pool->FindOneofByName(string(name, name_size)); if (oneof_descriptor == NULL) { - PyErr_Format(PyExc_TypeError, "Couldn't find oneof %.200s", name); + PyErr_Format(PyExc_KeyError, "Couldn't find oneof %.200s", name); return NULL; } - return PyOneofDescriptor_New(oneof_descriptor); + return PyOneofDescriptor_FromDescriptor(oneof_descriptor); } -static PyMethodDef Methods[] = { - { C("FindFieldByName"), - (PyCFunction)FindFieldByName, - METH_O, - C("Searches for a field descriptor by full name.") }, - { C("FindExtensionByName"), - (PyCFunction)FindExtensionByName, - METH_O, - C("Searches for extension descriptor by full name.") }, - {NULL} -}; - -} // namespace cdescriptor_pool - -PyTypeObject PyDescriptorPool_Type = { - PyVarObject_HEAD_INIT(&PyType_Type, 0) - C("google.protobuf.internal." - "_message.DescriptorPool"), // tp_name - sizeof(PyDescriptorPool), // tp_basicsize - 0, // tp_itemsize - (destructor)cdescriptor_pool::Dealloc, // tp_dealloc - 0, // tp_print - 0, // tp_getattr - 0, // tp_setattr - 0, // tp_compare - 0, // tp_repr - 0, // tp_as_number - 0, // tp_as_sequence - 0, // tp_as_mapping - 0, // tp_hash - 0, // tp_call - 0, // tp_str - 0, // tp_getattro - 0, // tp_setattro - 0, // tp_as_buffer - Py_TPFLAGS_DEFAULT, // tp_flags - C("A Descriptor Pool"), // tp_doc - 0, // tp_traverse - 0, // tp_clear - 0, // tp_richcompare - 0, // tp_weaklistoffset - 0, // tp_iter - 0, // tp_iternext - cdescriptor_pool::Methods, // tp_methods - 0, // tp_members - 0, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init - 0, // tp_alloc - 0, // tp_new - PyObject_Del, // tp_free -}; - // The code below loads new Descriptors from a serialized FileDescriptorProto. @@ -301,6 +261,7 @@ class BuildFileErrorCollector : public DescriptorPool::ErrorCollector { if (!had_errors) { error_message += ("Invalid proto descriptor for file \"" + filename + "\":\n"); + had_errors = true; } // As this only happens on failure and will result in the program not // running at all, no effort is made to optimize this string manipulation. @@ -311,7 +272,7 @@ class BuildFileErrorCollector : public DescriptorPool::ErrorCollector { bool had_errors; }; -PyObject* Python_BuildFile(PyObject* ignored, PyObject* serialized_pb) { +PyObject* AddSerializedFile(PyDescriptorPool* self, PyObject* serialized_pb) { char* message_type; Py_ssize_t message_len; @@ -330,13 +291,14 @@ PyObject* Python_BuildFile(PyObject* ignored, PyObject* serialized_pb) { const FileDescriptor* generated_file = DescriptorPool::generated_pool()->FindFileByName(file_proto.name()); if (generated_file != NULL) { - return PyFileDescriptor_NewWithPb(generated_file, serialized_pb); + return PyFileDescriptor_FromDescriptorWithSerializedPb( + generated_file, serialized_pb); } BuildFileErrorCollector error_collector; const FileDescriptor* descriptor = - GetDescriptorPool()->pool->BuildFileCollectingErrors(file_proto, - &error_collector); + self->pool->BuildFileCollectingErrors(file_proto, + &error_collector); if (descriptor == NULL) { PyErr_Format(PyExc_TypeError, "Couldn't build proto file into descriptor pool!\n%s", @@ -344,9 +306,84 @@ PyObject* Python_BuildFile(PyObject* ignored, PyObject* serialized_pb) { return NULL; } - return PyFileDescriptor_NewWithPb(descriptor, serialized_pb); + return PyFileDescriptor_FromDescriptorWithSerializedPb( + descriptor, serialized_pb); +} + +PyObject* Add(PyDescriptorPool* self, PyObject* file_descriptor_proto) { + ScopedPyObjectPtr serialized_pb( + PyObject_CallMethod(file_descriptor_proto, "SerializeToString", NULL)); + if (serialized_pb == NULL) { + return NULL; + } + return AddSerializedFile(self, serialized_pb); } +static PyMethodDef Methods[] = { + { "Add", (PyCFunction)Add, METH_O, + "Adds the FileDescriptorProto and its types to this pool." }, + { "AddSerializedFile", (PyCFunction)AddSerializedFile, METH_O, + "Adds a serialized FileDescriptorProto to this pool." }, + + { "FindFileByName", (PyCFunction)FindFileByName, METH_O, + "Searches for a file descriptor by its .proto name." }, + { "FindMessageTypeByName", (PyCFunction)FindMessageByName, METH_O, + "Searches for a message descriptor by full name." }, + { "FindFieldByName", (PyCFunction)FindFieldByName, METH_O, + "Searches for a field descriptor by full name." }, + { "FindExtensionByName", (PyCFunction)FindExtensionByName, METH_O, + "Searches for extension descriptor by full name." }, + { "FindEnumTypeByName", (PyCFunction)FindEnumTypeByName, METH_O, + "Searches for enum type descriptor by full name." }, + { "FindOneofByName", (PyCFunction)FindOneofByName, METH_O, + "Searches for oneof descriptor by full name." }, + {NULL} +}; + +} // namespace cdescriptor_pool + +PyTypeObject PyDescriptorPool_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".DescriptorPool", // tp_name + sizeof(PyDescriptorPool), // tp_basicsize + 0, // tp_itemsize + (destructor)cdescriptor_pool::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Descriptor Pool", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + cdescriptor_pool::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + PyObject_Del, // tp_free +}; + static PyDescriptorPool* global_cdescriptor_pool = NULL; bool InitDescriptorPool() { diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h index 4e494b89..efb1abeb 100644 --- a/python/google/protobuf/pyext/descriptor_pool.h +++ b/python/google/protobuf/pyext/descriptor_pool.h @@ -95,6 +95,8 @@ const Descriptor* FindMessageTypeByName(PyDescriptorPool* self, const Descriptor* RegisterMessageClass( PyDescriptorPool* self, PyObject* message_class, PyObject* descriptor); +// The function below are also exposed as methods of the DescriptorPool type. + // Retrieves the Python class registered with the given message descriptor. // // Returns a *borrowed* reference if found, otherwise returns NULL with an @@ -134,12 +136,8 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg); } // namespace cdescriptor_pool -// Implement the Python "_BuildFile" method, it takes a serialized -// FileDescriptorProto, and adds it to the C++ DescriptorPool. -// It returns a new FileDescriptor object, or NULL when an exception is raised. -PyObject* Python_BuildFile(PyObject* ignored, PyObject* args); - // Retrieve the global descriptor pool owned by the _message module. +// Returns a *borrowed* reference. PyDescriptorPool* GetDescriptorPool(); // Initialize objects used by this module. diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 8e38fc42..b8d18f8d 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -51,16 +51,6 @@ namespace python { namespace extension_dict { -// TODO(tibell): Always use self->message for clarity, just like in -// RepeatedCompositeContainer. -static Message* GetMessage(ExtensionDict* self) { - if (self->parent != NULL) { - return self->parent->message; - } else { - return self->message; - } -} - PyObject* len(ExtensionDict* self) { #if PY_MAJOR_VERSION >= 3 return PyLong_FromLong(PyDict_Size(self->values)); @@ -89,7 +79,7 @@ int ReleaseExtension(ExtensionDict* self, } } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (cmessage::ReleaseSubMessage( - GetMessage(self), descriptor, + self->parent, descriptor, reinterpret_cast<CMessage*>(extension)) < 0) { return -1; } @@ -109,7 +99,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { - return cmessage::InternalGetScalar(self->parent, descriptor); + return cmessage::InternalGetScalar(self->parent->message, descriptor); } PyObject* value = PyDict_GetItem(self->values, key); @@ -266,8 +256,7 @@ static PyMethodDef Methods[] = { PyTypeObject ExtensionDict_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - "google.protobuf.internal." - "cpp._message.ExtensionDict", // tp_name + FULL_MODULE_NAME ".ExtensionDict", // tp_name sizeof(ExtensionDict), // tp_basicsize 0, // tp_itemsize (destructor)extension_dict::dealloc, // tp_dealloc diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index a2b357b2..a4843e8d 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -59,6 +59,8 @@ #include <google/protobuf/pyext/extension_dict.h> #include <google/protobuf/pyext/repeated_composite_container.h> #include <google/protobuf/pyext/repeated_scalar_container.h> +#include <google/protobuf/pyext/message_map_container.h> +#include <google/protobuf/pyext/scalar_map_container.h> #include <google/protobuf/pyext/scoped_pyobject_ptr.h> #include <google/protobuf/stubs/strutil.h> @@ -93,9 +95,9 @@ static const FieldDescriptor* GetFieldDescriptor( static const Descriptor* GetMessageDescriptor(PyTypeObject* cls); static string GetMessageName(CMessage* self); int InternalReleaseFieldByDescriptor( + CMessage* self, const FieldDescriptor* field_descriptor, - PyObject* composite_field, - Message* parent_message); + PyObject* composite_field); } // namespace cmessage // --------------------------------------------------------------------- @@ -127,10 +129,29 @@ static int VisitCompositeField(const FieldDescriptor* descriptor, Visitor visitor) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - RepeatedCompositeContainer* container = - reinterpret_cast<RepeatedCompositeContainer*>(child); - if (visitor.VisitRepeatedCompositeContainer(container) == -1) - return -1; + if (descriptor->is_map()) { + const Descriptor* entry_type = descriptor->message_type(); + const FieldDescriptor* value_type = + entry_type->FindFieldByName("value"); + if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + MessageMapContainer* container = + reinterpret_cast<MessageMapContainer*>(child); + if (visitor.VisitMessageMapContainer(container) == -1) { + return -1; + } + } else { + ScalarMapContainer* container = + reinterpret_cast<ScalarMapContainer*>(child); + if (visitor.VisitScalarMapContainer(container) == -1) { + return -1; + } + } + } else { + RepeatedCompositeContainer* container = + reinterpret_cast<RepeatedCompositeContainer*>(child); + if (visitor.VisitRepeatedCompositeContainer(container) == -1) + return -1; + } } else { RepeatedScalarContainer* container = reinterpret_cast<RepeatedScalarContainer*>(child); @@ -444,7 +465,7 @@ static int MaybeReleaseOverlappingOneofField( } if (InternalReleaseFieldByDescriptor( - existing_field, child_message, message) < 0) { + cmessage, existing_field, child_message) < 0) { return -1; } return PyDict_DelItemString(cmessage->composite_fields, field_name); @@ -483,6 +504,16 @@ struct FixupMessageReference : public ChildVisitor { return 0; } + int VisitScalarMapContainer(ScalarMapContainer* container) { + container->message = message_; + return 0; + } + + int VisitMessageMapContainer(MessageMapContainer* container) { + container->message = message_; + return 0; + } + private: Message* message_; }; @@ -500,6 +531,9 @@ int AssureWritable(CMessage* self) { self->message->GetDescriptor()); self->message = prototype->New(); self->owner.reset(self->message); + // Cascade the new owner to eventual children: even if this message is + // empty, some submessages or repeated containers might exist already. + SetOwner(self, self->owner); } else { // Otherwise, we need a mutable child message. if (AssureWritable(self->parent) == -1) @@ -520,8 +554,9 @@ int AssureWritable(CMessage* self) { // When a CMessage is made writable its Message pointer is updated // to point to a new mutable Message. When that happens we need to // update any references to the old, read-only CMessage. There are - // three places such references occur: RepeatedScalarContainer, - // RepeatedCompositeContainer, and ExtensionDict. + // five places such references occur: RepeatedScalarContainer, + // RepeatedCompositeContainer, ScalarMapContainer, MessageMapContainer, + // and ExtensionDict. if (self->extensions != NULL) self->extensions->message = self->message; if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1) @@ -583,15 +618,43 @@ const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) { return PyFieldDescriptor_AsDescriptor(extension); } +// If value is a string, convert it into an enum value based on the labels in +// descriptor, otherwise simply return value. Always returns a new reference. +static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor, + PyObject* value) { + if (PyString_Check(value) || PyUnicode_Check(value)) { + const EnumDescriptor* enum_descriptor = descriptor.enum_type(); + if (enum_descriptor == NULL) { + PyErr_SetString(PyExc_TypeError, "not an enum field"); + return NULL; + } + char* enum_label; + Py_ssize_t size; + if (PyString_AsStringAndSize(value, &enum_label, &size) < 0) { + return NULL; + } + const EnumValueDescriptor* enum_value_descriptor = + enum_descriptor->FindValueByName(string(enum_label, size)); + if (enum_value_descriptor == NULL) { + PyErr_SetString(PyExc_ValueError, "unknown enum label"); + return NULL; + } + return PyInt_FromLong(enum_value_descriptor->number()); + } + Py_INCREF(value); + return value; +} + // If cmessage_list is not NULL, this function releases values into the // container CMessages instead of just removing. Repeated composite container // needs to do this to make sure CMessages stay alive if they're still // referenced after deletion. Repeated scalar container doesn't need to worry. int InternalDeleteRepeatedField( - Message* message, + CMessage* self, const FieldDescriptor* field_descriptor, PyObject* slice, PyObject* cmessage_list) { + Message* message = self->message; Py_ssize_t length, from, to, step, slice_length; const Reflection* reflection = message->GetReflection(); int min, max; @@ -665,7 +728,7 @@ int InternalDeleteRepeatedField( CMessage* last_cmessage = reinterpret_cast<CMessage*>( PyList_GET_ITEM(cmessage_list, PyList_GET_SIZE(cmessage_list) - 1)); repeated_composite_container::ReleaseLastTo( - field_descriptor, message, last_cmessage); + self, field_descriptor, last_cmessage); if (PySequence_DelItem(cmessage_list, -1) < 0) { return -1; } @@ -696,16 +759,90 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { PyString_AsString(name)); return -1; } - if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + if (descriptor->is_map()) { + ScopedPyObjectPtr map(GetAttr(self, name)); + const FieldDescriptor* value_descriptor = + descriptor->message_type()->FindFieldByName("value"); + if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + Py_ssize_t map_pos = 0; + PyObject* map_key; + PyObject* map_value; + while (PyDict_Next(value, &map_pos, &map_key, &map_value)) { + ScopedPyObjectPtr function_return; + function_return.reset(PyObject_GetItem(map.get(), map_key)); + if (function_return.get() == NULL) { + return -1; + } + ScopedPyObjectPtr ok(PyObject_CallMethod( + function_return.get(), "MergeFrom", "O", map_value)); + if (ok.get() == NULL) { + return -1; + } + } + } else { + ScopedPyObjectPtr function_return; + function_return.reset( + PyObject_CallMethod(map.get(), "update", "O", value)); + if (function_return.get() == NULL) { + return -1; + } + } + } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { ScopedPyObjectPtr container(GetAttr(self, name)); if (container == NULL) { return -1; } if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - if (repeated_composite_container::Extend( - reinterpret_cast<RepeatedCompositeContainer*>(container.get()), - value) - == NULL) { + RepeatedCompositeContainer* rc_container = + reinterpret_cast<RepeatedCompositeContainer*>(container.get()); + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return -1; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter))) != NULL) { + PyObject* kwargs = (PyDict_Check(next) ? next.get() : NULL); + ScopedPyObjectPtr new_msg( + repeated_composite_container::Add(rc_container, NULL, kwargs)); + if (new_msg == NULL) { + return -1; + } + if (kwargs == NULL) { + // next was not a dict, it's a message we need to merge + ScopedPyObjectPtr merged( + MergeFrom(reinterpret_cast<CMessage*>(new_msg.get()), next)); + if (merged == NULL) { + return -1; + } + } + } + if (PyErr_Occurred()) { + // Check to see how PyIter_Next() exited. + return -1; + } + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + RepeatedScalarContainer* rs_container = + reinterpret_cast<RepeatedScalarContainer*>(container.get()); + ScopedPyObjectPtr iter(PyObject_GetIter(value)); + if (iter == NULL) { + PyErr_SetString(PyExc_TypeError, "Value must be iterable"); + return -1; + } + ScopedPyObjectPtr next; + while ((next.reset(PyIter_Next(iter))) != NULL) { + ScopedPyObjectPtr enum_value(GetIntegerEnumValue(*descriptor, next)); + if (enum_value == NULL) { + return -1; + } + ScopedPyObjectPtr new_msg( + repeated_scalar_container::Append(rs_container, enum_value)); + if (new_msg == NULL) { + return -1; + } + } + if (PyErr_Occurred()) { + // Check to see how PyIter_Next() exited. return -1; } } else { @@ -721,12 +858,26 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { if (message == NULL) { return -1; } - if (MergeFrom(reinterpret_cast<CMessage*>(message.get()), - value) == NULL) { - return -1; + CMessage* cmessage = reinterpret_cast<CMessage*>(message.get()); + if (PyDict_Check(value)) { + if (InitAttributes(cmessage, value) < 0) { + return -1; + } + } else { + ScopedPyObjectPtr merged(MergeFrom(cmessage, value)); + if (merged == NULL) { + return -1; + } } } else { - if (SetAttr(self, name, value) < 0) { + ScopedPyObjectPtr new_val; + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + new_val.reset(GetIntegerEnumValue(*descriptor, value)); + if (new_val == NULL) { + return -1; + } + } + if (SetAttr(self, name, (new_val == NULL) ? value : new_val) < 0) { return -1; } } @@ -789,7 +940,6 @@ static PyObject* New(PyTypeObject* type, } self->message = default_message->New(); self->owner.reset(self->message); - return reinterpret_cast<PyObject*>(self); } @@ -830,6 +980,16 @@ struct ClearWeakReferences : public ChildVisitor { return 0; } + int VisitScalarMapContainer(ScalarMapContainer* container) { + container->parent = NULL; + return 0; + } + + int VisitMessageMapContainer(MessageMapContainer* container) { + container->parent = NULL; + return 0; + } + int VisitCMessage(CMessage* cmessage, const FieldDescriptor* field_descriptor) { cmessage->parent = NULL; @@ -1064,6 +1224,16 @@ struct SetOwnerVisitor : public ChildVisitor { return 0; } + int VisitScalarMapContainer(ScalarMapContainer* container) { + scalar_map_container::SetOwner(container, new_owner_); + return 0; + } + + int VisitMessageMapContainer(MessageMapContainer* container) { + message_map_container::SetOwner(container, new_owner_); + return 0; + } + int VisitCMessage(CMessage* cmessage, const FieldDescriptor* field_descriptor) { return SetOwner(cmessage, new_owner_); @@ -1084,11 +1254,11 @@ int SetOwner(CMessage* self, const shared_ptr<Message>& new_owner) { // Releases the message specified by 'field' and returns the // pointer. If the field does not exist a new message is created using // 'descriptor'. The caller takes ownership of the returned pointer. -Message* ReleaseMessage(Message* message, +Message* ReleaseMessage(CMessage* self, const Descriptor* descriptor, const FieldDescriptor* field_descriptor) { - Message* released_message = message->GetReflection()->ReleaseMessage( - message, field_descriptor, message_factory); + Message* released_message = self->message->GetReflection()->ReleaseMessage( + self->message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from // child_cmessage->message, if the field does not exist. In this case, // the latter points to the default instance via a const_cast<>, so we @@ -1102,12 +1272,12 @@ Message* ReleaseMessage(Message* message, return released_message; } -int ReleaseSubMessage(Message* message, +int ReleaseSubMessage(CMessage* self, const FieldDescriptor* field_descriptor, CMessage* child_cmessage) { // Release the Message shared_ptr<Message> released_message(ReleaseMessage( - message, child_cmessage->message->GetDescriptor(), field_descriptor)); + self, child_cmessage->message->GetDescriptor(), field_descriptor)); child_cmessage->message = released_message.get(); child_cmessage->owner.swap(released_message); child_cmessage->parent = NULL; @@ -1119,8 +1289,8 @@ int ReleaseSubMessage(Message* message, struct ReleaseChild : public ChildVisitor { // message must outlive this object. - explicit ReleaseChild(Message* parent_message) : - parent_message_(parent_message) {} + explicit ReleaseChild(CMessage* parent) : + parent_(parent) {} int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { return repeated_composite_container::Release( @@ -1132,23 +1302,33 @@ struct ReleaseChild : public ChildVisitor { reinterpret_cast<RepeatedScalarContainer*>(container)); } + int VisitScalarMapContainer(ScalarMapContainer* container) { + return scalar_map_container::Release( + reinterpret_cast<ScalarMapContainer*>(container)); + } + + int VisitMessageMapContainer(MessageMapContainer* container) { + return message_map_container::Release( + reinterpret_cast<MessageMapContainer*>(container)); + } + int VisitCMessage(CMessage* cmessage, const FieldDescriptor* field_descriptor) { - return ReleaseSubMessage(parent_message_, field_descriptor, + return ReleaseSubMessage(parent_, field_descriptor, reinterpret_cast<CMessage*>(cmessage)); } - Message* parent_message_; + CMessage* parent_; }; int InternalReleaseFieldByDescriptor( + CMessage* self, const FieldDescriptor* field_descriptor, - PyObject* composite_field, - Message* parent_message) { + PyObject* composite_field) { return VisitCompositeField( field_descriptor, composite_field, - ReleaseChild(parent_message)); + ReleaseChild(self)); } PyObject* ClearFieldByDescriptor( @@ -1200,8 +1380,8 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { // Only release the field if there's a possibility that there are // references to it. if (composite_field != NULL) { - if (InternalReleaseFieldByDescriptor(field_descriptor, - composite_field, message) < 0) { + if (InternalReleaseFieldByDescriptor(self, field_descriptor, + composite_field) < 0) { return NULL; } PyDict_DelItem(self->composite_fields, arg); @@ -1219,7 +1399,7 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { PyObject* Clear(CMessage* self) { AssureWritable(self); - if (ForEachCompositeField(self, ReleaseChild(self->message)) == -1) + if (ForEachCompositeField(self, ReleaseChild(self)) == -1) return NULL; // The old ExtensionDict still aliases this CMessage, but all its @@ -1582,7 +1762,8 @@ static PyObject* ListFields(CMessage* self) { } if (fields[i]->is_extension()) { - ScopedPyObjectPtr extension_field(PyFieldDescriptor_New(fields[i])); + ScopedPyObjectPtr extension_field( + PyFieldDescriptor_FromDescriptor(fields[i])); if (extension_field == NULL) { return NULL; } @@ -1616,7 +1797,8 @@ static PyObject* ListFields(CMessage* self) { PyErr_SetString(PyExc_ValueError, "bad string"); return NULL; } - ScopedPyObjectPtr field_descriptor(PyFieldDescriptor_New(fields[i])); + ScopedPyObjectPtr field_descriptor( + PyFieldDescriptor_FromDescriptor(fields[i])); if (field_descriptor == NULL) { return NULL; } @@ -1683,10 +1865,8 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { } } -PyObject* InternalGetScalar( - CMessage* self, - const FieldDescriptor* field_descriptor) { - Message* message = self->message; +PyObject* InternalGetScalar(const Message* message, + const FieldDescriptor* field_descriptor) { const Reflection* reflection = message->GetReflection(); if (!CheckFieldBelongsToMessage(field_descriptor, message)) { @@ -1739,12 +1919,12 @@ PyObject* InternalGetScalar( if (!message->GetReflection()->SupportsUnknownEnumValues() && !message->GetReflection()->HasField(*message, field_descriptor)) { // Look for the value in the unknown fields. - UnknownFieldSet* unknown_field_set = - message->GetReflection()->MutableUnknownFields(message); - for (int i = 0; i < unknown_field_set->field_count(); ++i) { - if (unknown_field_set->field(i).number() == + const UnknownFieldSet& unknown_field_set = + message->GetReflection()->GetUnknownFields(*message); + for (int i = 0; i < unknown_field_set.field_count(); ++i) { + if (unknown_field_set.field(i).number() == field_descriptor->number()) { - result = PyInt_FromLong(unknown_field_set->field(i).varint()); + result = PyInt_FromLong(unknown_field_set.field(i).varint()); break; } } @@ -1793,21 +1973,16 @@ PyObject* InternalGetSubMessage( return reinterpret_cast<PyObject*>(cmsg); } -int InternalSetScalar( - CMessage* self, +int InternalSetNonOneofScalar( + Message* message, const FieldDescriptor* field_descriptor, PyObject* arg) { - Message* message = self->message; const Reflection* reflection = message->GetReflection(); if (!CheckFieldBelongsToMessage(field_descriptor, message)) { return -1; } - if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) { - return -1; - } - switch (field_descriptor->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: { GOOGLE_CHECK_GET_INT32(arg, value, -1); @@ -1878,6 +2053,21 @@ int InternalSetScalar( return 0; } +int InternalSetScalar( + CMessage* self, + const FieldDescriptor* field_descriptor, + PyObject* arg) { + if (!CheckFieldBelongsToMessage(field_descriptor, self->message)) { + return -1; + } + + if (MaybeReleaseOverlappingOneofField(self, field_descriptor) < 0) { + return -1; + } + + return InternalSetNonOneofScalar(self->message, field_descriptor, arg); +} + PyObject* FromString(PyTypeObject* cls, PyObject* serialized) { PyObject* py_cmsg = PyObject_CallObject( reinterpret_cast<PyObject*>(cls), NULL); @@ -1955,7 +2145,8 @@ static PyObject* AddDescriptors(PyObject* cls, PyObject* descriptor) { // which was built previously. for (int i = 0; i < message_descriptor->enum_type_count(); ++i) { const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i); - ScopedPyObjectPtr enum_type(PyEnumDescriptor_New(enum_descriptor)); + ScopedPyObjectPtr enum_type( + PyEnumDescriptor_FromDescriptor(enum_descriptor)); if (enum_type == NULL) { return NULL; } @@ -1993,7 +2184,7 @@ static PyObject* AddDescriptors(PyObject* cls, PyObject* descriptor) { // which was defined previously. for (int i = 0; i < message_descriptor->extension_count(); ++i) { const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i); - ScopedPyObjectPtr extension_field(PyFieldDescriptor_New(field)); + ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); if (extension_field == NULL) { return NULL; } @@ -2097,26 +2288,6 @@ PyObject* SetState(CMessage* self, PyObject* state) { } // CMessage static methods: -PyObject* _GetMessageDescriptor(PyObject* unused, PyObject* arg) { - return cdescriptor_pool::FindMessageByName(GetDescriptorPool(), arg); -} - -PyObject* _GetFieldDescriptor(PyObject* unused, PyObject* arg) { - return cdescriptor_pool::FindFieldByName(GetDescriptorPool(), arg); -} - -PyObject* _GetExtensionDescriptor(PyObject* unused, PyObject* arg) { - return cdescriptor_pool::FindExtensionByName(GetDescriptorPool(), arg); -} - -PyObject* _GetEnumDescriptor(PyObject* unused, PyObject* arg) { - return cdescriptor_pool::FindEnumTypeByName(GetDescriptorPool(), arg); -} - -PyObject* _GetOneofDescriptor(PyObject* unused, PyObject* arg) { - return cdescriptor_pool::FindOneofByName(GetDescriptorPool(), arg); -} - PyObject* _CheckCalledFromGeneratedFile(PyObject* unused, PyObject* unused_arg) { if (!_CalledFromGeneratedFile(1)) { @@ -2188,21 +2359,6 @@ static PyMethodDef Methods[] = { "or None if no field is set." }, // Static Methods. - { "_BuildFile", (PyCFunction)Python_BuildFile, METH_O | METH_STATIC, - "Registers a new protocol buffer file in the global C++ descriptor pool." }, - { "_GetMessageDescriptor", (PyCFunction)_GetMessageDescriptor, - METH_O | METH_STATIC, "Finds a message descriptor in the message pool." }, - { "_GetFieldDescriptor", (PyCFunction)_GetFieldDescriptor, - METH_O | METH_STATIC, "Finds a field descriptor in the message pool." }, - { "_GetExtensionDescriptor", (PyCFunction)_GetExtensionDescriptor, - METH_O | METH_STATIC, - "Finds a extension descriptor in the message pool." }, - { "_GetEnumDescriptor", (PyCFunction)_GetEnumDescriptor, - METH_O | METH_STATIC, - "Finds an enum descriptor in the message pool." }, - { "_GetOneofDescriptor", (PyCFunction)_GetOneofDescriptor, - METH_O | METH_STATIC, - "Finds an oneof descriptor in the message pool." }, { "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile, METH_NOARGS | METH_STATIC, "Raises TypeError if the caller is not in a _pb2.py file."}, @@ -2234,6 +2390,31 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { reinterpret_cast<PyObject*>(self), name); } + if (field_descriptor->is_map()) { + PyObject* py_container = NULL; + const Descriptor* entry_type = field_descriptor->message_type(); + const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); + if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject* value_class = cdescriptor_pool::GetMessageClass( + GetDescriptorPool(), value_type->message_type()); + if (value_class == NULL) { + return NULL; + } + py_container = message_map_container::NewContainer(self, field_descriptor, + value_class); + } else { + py_container = scalar_map_container::NewContainer(self, field_descriptor); + } + if (py_container == NULL) { + return NULL; + } + if (!SetCompositeField(self, name, py_container)) { + Py_DECREF(py_container); + return NULL; + } + return py_container; + } + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyObject* py_container = NULL; if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { @@ -2267,7 +2448,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { return sub_message; } - return InternalGetScalar(self, field_descriptor); + return InternalGetScalar(self->message, field_descriptor); } int SetAttr(CMessage* self, PyObject* name, PyObject* value) { @@ -2304,9 +2485,7 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { PyTypeObject CMessage_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - // Keep the fully qualified _message symbol in a line for opensource. - "google.protobuf.pyext._message." - "CMessage", // tp_name + FULL_MODULE_NAME ".CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize (destructor)cmessage::Dealloc, // tp_dealloc @@ -2401,7 +2580,7 @@ void InitGlobals() { k_extensions_by_name = PyString_FromString("_extensions_by_name"); k_extensions_by_number = PyString_FromString("_extensions_by_number"); - message_factory = new DynamicMessageFactory(GetDescriptorPool()->pool); + message_factory = new DynamicMessageFactory(); message_factory->SetDelegateToGeneratedFactory(true); } @@ -2469,6 +2648,61 @@ bool InitProto2MessageModule(PyObject *m) { reinterpret_cast<PyObject*>( &RepeatedCompositeContainer_Type)); + // ScalarMapContainer_Type derives from our MutableMapping type. + PyObject* containers = + PyImport_ImportModule("google.protobuf.internal.containers"); + if (containers == NULL) { + return false; + } + + PyObject* mutable_mapping = + PyObject_GetAttrString(containers, "MutableMapping"); + Py_DECREF(containers); + + if (mutable_mapping == NULL) { + return false; + } + + if (!PyObject_TypeCheck(mutable_mapping, &PyType_Type)) { + Py_DECREF(mutable_mapping); + return false; + } + + ScalarMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping); + + if (PyType_Ready(&ScalarMapContainer_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "ScalarMapContainer", + reinterpret_cast<PyObject*>(&ScalarMapContainer_Type)); + + if (PyType_Ready(&ScalarMapIterator_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "ScalarMapIterator", + reinterpret_cast<PyObject*>(&ScalarMapIterator_Type)); + + Py_INCREF(mutable_mapping); + MessageMapContainer_Type.tp_base = + reinterpret_cast<PyTypeObject*>(mutable_mapping); + + if (PyType_Ready(&MessageMapContainer_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "MessageMapContainer", + reinterpret_cast<PyObject*>(&MessageMapContainer_Type)); + + if (PyType_Ready(&MessageMapIterator_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "MessageMapIterator", + reinterpret_cast<PyObject*>(&MessageMapIterator_Type)); + ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented; if (PyType_Ready(&ExtensionDict_Type) < 0) { return false; @@ -2478,6 +2712,12 @@ bool InitProto2MessageModule(PyObject *m) { m, "ExtensionDict", reinterpret_cast<PyObject*>(&ExtensionDict_Type)); + // Expose the DescriptorPool used to hold all descriptors added from generated + // pb2.py files. + Py_INCREF(GetDescriptorPool()); // PyModule_AddObject steals a reference. + PyModule_AddObject( + m, "default_pool", reinterpret_cast<PyObject*>(GetDescriptorPool())); + // This implementation provides full Descriptor types, we advertise it so that // descriptor.py can use them in replacement of the Python classes. PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1); diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index 2f2da795..7360b207 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -120,7 +120,7 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor* descriptor); // A new message will be created if this is a read-only default instance. // // Corresponds to reflection api method ReleaseMessage. -int ReleaseSubMessage(Message* message, +int ReleaseSubMessage(CMessage* self, const FieldDescriptor* field_descriptor, CMessage* child_cmessage); @@ -144,7 +144,7 @@ PyObject* InternalGetSubMessage( // by slice will be removed from cmessage_list by this function. // // Corresponds to reflection api method RemoveLast. -int InternalDeleteRepeatedField(Message* message, +int InternalDeleteRepeatedField(CMessage* self, const FieldDescriptor* field_descriptor, PyObject* slice, PyObject* cmessage_list); @@ -153,10 +153,15 @@ int InternalSetScalar(CMessage* self, const FieldDescriptor* field_descriptor, PyObject* value); +// Sets the specified scalar value to the message. Requires it is not a Oneof. +int InternalSetNonOneofScalar(Message* message, + const FieldDescriptor* field_descriptor, + PyObject* arg); + // Retrieves the specified scalar value from the message. // // Returns a new python reference. -PyObject* InternalGetScalar(CMessage* self, +PyObject* InternalGetScalar(const Message* message, const FieldDescriptor* field_descriptor); // Clears the message, removing all contained data. Extension dictionary and @@ -279,7 +284,7 @@ extern PyObject* kint64min_py; extern PyObject* kint64max_py; extern PyObject* kuint64max_py; -#define C(str) const_cast<char*>(str) +#define FULL_MODULE_NAME "google.protobuf.pyext._message" void FormatTypeError(PyObject* arg, char* expected_types); template<class T> diff --git a/python/google/protobuf/pyext/message_map_container.cc b/python/google/protobuf/pyext/message_map_container.cc new file mode 100644 index 00000000..ab8d8fb9 --- /dev/null +++ b/python/google/protobuf/pyext/message_map_container.cc @@ -0,0 +1,540 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: haberman@google.com (Josh Haberman) + +#include <google/protobuf/pyext/message_map_container.h> + +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +namespace google { +namespace protobuf { +namespace python { + +struct MessageMapIterator { + PyObject_HEAD; + + // This dict contains the full contents of what we want to iterate over. + // There's no way to avoid building this, because the list representation + // (which is canonical) can contain duplicate keys. So at the very least we + // need a set that lets us skip duplicate keys. And at the point that we're + // doing that, we might as well just build the actual dict we're iterating + // over and use dict's built-in iterator. + PyObject* dict; + + // An iterator on dict. + PyObject* iter; + + // A pointer back to the container, so we can notice changes to the version. + MessageMapContainer* container; + + // The version of the map when we took the iterator to it. + // + // We store this so that if the map is modified during iteration we can throw + // an error. + uint64 version; +}; + +static MessageMapIterator* GetIter(PyObject* obj) { + return reinterpret_cast<MessageMapIterator*>(obj); +} + +namespace message_map_container { + +static MessageMapContainer* GetMap(PyObject* obj) { + return reinterpret_cast<MessageMapContainer*>(obj); +} + +// The private constructor of MessageMapContainer objects. +PyObject* NewContainer(CMessage* parent, + const google::protobuf::FieldDescriptor* parent_field_descriptor, + PyObject* concrete_class) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } + + PyObject* obj = PyType_GenericAlloc(&MessageMapContainer_Type, 0); + if (obj == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate new container."); + } + + MessageMapContainer* self = GetMap(obj); + + self->message = parent->message; + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->owner = parent->owner; + self->version = 0; + + self->key_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("key"); + self->value_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("value"); + + self->message_dict = PyDict_New(); + if (self->message_dict == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate message dict."); + } + + Py_INCREF(concrete_class); + self->subclass_init = concrete_class; + + if (self->key_field_descriptor == NULL || + self->value_field_descriptor == NULL) { + Py_DECREF(obj); + return PyErr_Format(PyExc_KeyError, + "Map entry descriptor did not have key/value fields"); + } + + return obj; +} + +// Initializes the underlying Message object of "to" so it becomes a new parent +// repeated scalar, and copies all the values from "from" to it. A child scalar +// container can be released by passing it as both from and to (e.g. making it +// the recipient of the new parent message and copying the values from itself). +static int InitializeAndCopyToParentContainer( + MessageMapContainer* from, + MessageMapContainer* to) { + // For now we require from == to, re-evaluate if we want to support deep copy + // as in repeated_composite_container.cc. + GOOGLE_DCHECK(from == to); + Message* old_message = from->message; + Message* new_message = old_message->New(); + to->parent = NULL; + to->parent_field_descriptor = from->parent_field_descriptor; + to->message = new_message; + to->owner.reset(new_message); + + vector<const FieldDescriptor*> fields; + fields.push_back(from->parent_field_descriptor); + old_message->GetReflection()->SwapFields(old_message, new_message, fields); + return 0; +} + +static PyObject* GetCMessage(MessageMapContainer* self, Message* entry) { + // Get or create the CMessage object corresponding to this message. + Message* message = entry->GetReflection()->MutableMessage( + entry, self->value_field_descriptor); + ScopedPyObjectPtr key(PyLong_FromVoidPtr(message)); + PyObject* ret = PyDict_GetItem(self->message_dict, key); + + if (ret == NULL) { + CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init, + message->GetDescriptor()); + ret = reinterpret_cast<PyObject*>(cmsg); + + if (cmsg == NULL) { + return NULL; + } + cmsg->owner = self->owner; + cmsg->message = message; + cmsg->parent = self->parent; + + if (PyDict_SetItem(self->message_dict, key, ret) < 0) { + Py_DECREF(ret); + return NULL; + } + } else { + Py_INCREF(ret); + } + + return ret; +} + +int Release(MessageMapContainer* self) { + InitializeAndCopyToParentContainer(self, self); + return 0; +} + +void SetOwner(MessageMapContainer* self, + const shared_ptr<Message>& new_owner) { + self->owner = new_owner; +} + +Py_ssize_t Length(PyObject* _self) { + MessageMapContainer* self = GetMap(_self); + google::protobuf::Message* message = self->message; + return message->GetReflection()->FieldSize(*message, + self->parent_field_descriptor); +} + +int MapKeyMatches(MessageMapContainer* self, const Message* entry, + PyObject* key) { + // TODO(haberman): do we need more strict type checking? + ScopedPyObjectPtr entry_key( + cmessage::InternalGetScalar(entry, self->key_field_descriptor)); + int ret = PyObject_RichCompareBool(key, entry_key, Py_EQ); + return ret; +} + +int SetItem(PyObject *_self, PyObject *key, PyObject *v) { + if (v) { + PyErr_Format(PyExc_ValueError, + "Direct assignment of submessage not allowed"); + return -1; + } + + // Now we know that this is a delete, not a set. + + MessageMapContainer* self = GetMap(_self); + cmessage::AssureWritable(self->parent); + + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + size_t size = + reflection->FieldSize(*message, self->parent_field_descriptor); + + // Right now the Reflection API doesn't support map lookup, so we implement it + // via linear search. We need to search from the end because the underlying + // representation can have duplicates if a user calls MergeFrom(); the last + // one needs to win. + // + // TODO(haberman): add lookup API to Reflection API. + bool found = false; + for (int i = size - 1; i >= 0; i--) { + Message* entry = reflection->MutableRepeatedMessage( + message, self->parent_field_descriptor, i); + int matches = MapKeyMatches(self, entry, key); + if (matches < 0) return -1; + if (matches) { + found = true; + if (i != size - 1) { + reflection->SwapElements(message, self->parent_field_descriptor, i, + size - 1); + } + reflection->RemoveLast(message, self->parent_field_descriptor); + + // Can't exit now, the repeated field representation of maps allows + // duplicate keys, and we have to be sure to remove all of them. + } + } + + if (!found) { + PyErr_Format(PyExc_KeyError, "Key not present in map"); + return -1; + } + + self->version++; + + return 0; +} + +PyObject* GetIterator(PyObject *_self) { + MessageMapContainer* self = GetMap(_self); + + ScopedPyObjectPtr obj(PyType_GenericAlloc(&MessageMapIterator_Type, 0)); + if (obj == NULL) { + return PyErr_Format(PyExc_KeyError, "Could not allocate iterator"); + } + + MessageMapIterator* iter = GetIter(obj); + + Py_INCREF(self); + iter->container = self; + iter->version = self->version; + iter->dict = PyDict_New(); + if (iter->dict == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate dict for iterator."); + } + + // Build the entire map into a dict right now. Start from the beginning so + // that later entries win in the case of duplicates. + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + + // Right now the Reflection API doesn't support map lookup, so we implement it + // via linear search. We need to search from the end because the underlying + // representation can have duplicates if a user calls MergeFrom(); the last + // one needs to win. + // + // TODO(haberman): add lookup API to Reflection API. + size_t size = + reflection->FieldSize(*message, self->parent_field_descriptor); + for (int i = size - 1; i >= 0; i--) { + Message* entry = reflection->MutableRepeatedMessage( + message, self->parent_field_descriptor, i); + ScopedPyObjectPtr key( + cmessage::InternalGetScalar(entry, self->key_field_descriptor)); + if (PyDict_SetItem(iter->dict, key.get(), GetCMessage(self, entry)) < 0) { + return PyErr_Format(PyExc_RuntimeError, + "SetItem failed in iterator construction."); + } + } + + iter->iter = PyObject_GetIter(iter->dict); + + return obj.release(); +} + +PyObject* GetItem(PyObject* _self, PyObject* key) { + MessageMapContainer* self = GetMap(_self); + cmessage::AssureWritable(self->parent); + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + + // Right now the Reflection API doesn't support map lookup, so we implement it + // via linear search. We need to search from the end because the underlying + // representation can have duplicates if a user calls MergeFrom(); the last + // one needs to win. + // + // TODO(haberman): add lookup API to Reflection API. + size_t size = + reflection->FieldSize(*message, self->parent_field_descriptor); + for (int i = size - 1; i >= 0; i--) { + Message* entry = reflection->MutableRepeatedMessage( + message, self->parent_field_descriptor, i); + int matches = MapKeyMatches(self, entry, key); + if (matches < 0) return NULL; + if (matches) { + return GetCMessage(self, entry); + } + } + + // Key is not already present; insert a new entry. + Message* entry = + reflection->AddMessage(message, self->parent_field_descriptor); + + self->version++; + + if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor, + key) < 0) { + reflection->RemoveLast(message, self->parent_field_descriptor); + return NULL; + } + + return GetCMessage(self, entry); +} + +PyObject* Contains(PyObject* _self, PyObject* key) { + MessageMapContainer* self = GetMap(_self); + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + + // Right now the Reflection API doesn't support map lookup, so we implement it + // via linear search. + // + // TODO(haberman): add lookup API to Reflection API. + size_t size = + reflection->FieldSize(*message, self->parent_field_descriptor); + for (int i = 0; i < size; i++) { + Message* entry = reflection->MutableRepeatedMessage( + message, self->parent_field_descriptor, i); + int matches = MapKeyMatches(self, entry, key); + if (matches < 0) return NULL; + if (matches) { + Py_RETURN_TRUE; + } + } + + Py_RETURN_FALSE; +} + +PyObject* Clear(PyObject* _self) { + MessageMapContainer* self = GetMap(_self); + cmessage::AssureWritable(self->parent); + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + + self->version++; + reflection->ClearField(message, self->parent_field_descriptor); + + Py_RETURN_NONE; +} + +PyObject* Get(PyObject* self, PyObject* args) { + PyObject* key; + PyObject* default_value = NULL; + if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) { + return NULL; + } + + ScopedPyObjectPtr is_present(Contains(self, key)); + if (is_present.get() == NULL) { + return NULL; + } + + if (PyObject_IsTrue(is_present.get())) { + return GetItem(self, key); + } else { + if (default_value != NULL) { + Py_INCREF(default_value); + return default_value; + } else { + Py_RETURN_NONE; + } + } +} + +static PyMappingMethods MpMethods = { + Length, // mp_length + GetItem, // mp_subscript + SetItem, // mp_ass_subscript +}; + +static void Dealloc(PyObject* _self) { + MessageMapContainer* self = GetMap(_self); + self->owner.reset(); + Py_DECREF(self->message_dict); + Py_TYPE(_self)->tp_free(_self); +} + +static PyMethodDef Methods[] = { + { "__contains__", (PyCFunction)Contains, METH_O, + "Tests whether the map contains this element."}, + { "clear", (PyCFunction)Clear, METH_NOARGS, + "Removes all elements from the map."}, + { "get", Get, METH_VARARGS, + "Gets the value for the given key if present, or otherwise a default" }, + { "get_or_create", GetItem, METH_O, + "Alias for getitem, useful to make explicit that the map is mutated." }, + /* + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + */ + {NULL, NULL}, +}; + +} // namespace message_map_container + +namespace message_map_iterator { + +static void Dealloc(PyObject* _self) { + MessageMapIterator* self = GetIter(_self); + Py_DECREF(self->dict); + Py_DECREF(self->iter); + Py_DECREF(self->container); + Py_TYPE(_self)->tp_free(_self); +} + +PyObject* IterNext(PyObject* _self) { + MessageMapIterator* self = GetIter(_self); + + // This won't catch mutations to the map performed by MergeFrom(); no easy way + // to address that. + if (self->version != self->container->version) { + return PyErr_Format(PyExc_RuntimeError, + "Map modified during iteration."); + } + + return PyIter_Next(self->iter); +} + +} // namespace message_map_iterator + +PyTypeObject MessageMapContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MessageMapContainer", // tp_name + sizeof(MessageMapContainer), // tp_basicsize + 0, // tp_itemsize + message_map_container::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &message_map_container::MpMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A map container for message", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + message_map_container::GetIterator, // tp_iter + 0, // tp_iternext + message_map_container::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +PyTypeObject MessageMapIterator_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MessageMapIterator", // tp_name + sizeof(MessageMapIterator), // tp_basicsize + 0, // tp_itemsize + message_map_iterator::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map iterator", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + PyObject_SelfIter, // tp_iter + message_map_iterator::IterNext, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/message_map_container.h b/python/google/protobuf/pyext/message_map_container.h new file mode 100644 index 00000000..4ca0aecc --- /dev/null +++ b/python/google/protobuf/pyext/message_map_container.h @@ -0,0 +1,117 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__ + +#include <Python.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif + +#include <google/protobuf/descriptor.h> + +namespace google { +namespace protobuf { + +class Message; + +using internal::shared_ptr; + +namespace python { + +struct CMessage; + +struct MessageMapContainer { + PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python MessageMapContainer holds a + // reference to it in order to keep it alive as long as there's a + // Python object that references any part of the tree. + shared_ptr<Message> owner; + + // Pointer to the C++ Message that contains this container. The + // MessageMapContainer does not own this pointer. + Message* message; + + // Weak reference to a parent CMessage object (i.e. may be NULL.) + // + // Used to make sure all ancestors are also mutable when first + // modifying the container. + CMessage* parent; + + // Pointer to the parent's descriptor that describes this + // field. Used together with the parent's message when making a + // default message instance mutable. + // The pointer is owned by the global DescriptorPool. + const FieldDescriptor* parent_field_descriptor; + const FieldDescriptor* key_field_descriptor; + const FieldDescriptor* value_field_descriptor; + + // A callable that is used to create new child messages. + PyObject* subclass_init; + + // A dict mapping Message* -> CMessage. + PyObject* message_dict; + + // We bump this whenever we perform a mutation, to invalidate existing + // iterators. + uint64 version; +}; + +extern PyTypeObject MessageMapContainer_Type; +extern PyTypeObject MessageMapIterator_Type; + +namespace message_map_container { + +// Builds a MessageMapContainer object, from a parent message and a +// field descriptor. +extern PyObject* NewContainer(CMessage* parent, + const FieldDescriptor* parent_field_descriptor, + PyObject* concrete_class); + +// Releases the messages in the container to a new message. +// +// Returns 0 on success, -1 on failure. +int Release(MessageMapContainer* self); + +// Set the owner field of self and any children of self. +void SetOwner(MessageMapContainer* self, + const shared_ptr<Message>& new_owner); + +} // namespace message_map_container +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_MAP_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index 0fe98e73..86b75d0f 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -367,8 +367,8 @@ int AssignSubscript(RepeatedCompositeContainer* self, } // Delete from the underlying Message, if any. - if (self->message != NULL) { - if (cmessage::InternalDeleteRepeatedField(self->message, + if (self->parent != NULL) { + if (cmessage::InternalDeleteRepeatedField(self->parent, self->parent_field_descriptor, slice, self->child_messages) < 0) { @@ -572,47 +572,35 @@ static PyObject* Pop(RepeatedCompositeContainer* self, return item; } -// The caller takes ownership of the returned Message. -Message* ReleaseLast(const FieldDescriptor* field, - const Descriptor* type, - Message* message) { +// Release field of parent message and transfer the ownership to target. +void ReleaseLastTo(CMessage* parent, + const FieldDescriptor* field, + CMessage* target) { + GOOGLE_CHECK_NOTNULL(parent); GOOGLE_CHECK_NOTNULL(field); - GOOGLE_CHECK_NOTNULL(type); - GOOGLE_CHECK_NOTNULL(message); + GOOGLE_CHECK_NOTNULL(target); - Message* released_message = message->GetReflection()->ReleaseLast( - message, field); + shared_ptr<Message> released_message( + parent->message->GetReflection()->ReleaseLast(parent->message, field)); // TODO(tibell): Deal with proto1. // ReleaseMessage will return NULL which differs from // child_cmessage->message, if the field does not exist. In this case, // the latter points to the default instance via a const_cast<>, so we // have to reset it to a new mutable object since we are taking ownership. - if (released_message == NULL) { + if (released_message.get() == NULL) { const Message* prototype = - cmessage::GetMessageFactory()->GetPrototype(type); + cmessage::GetMessageFactory()->GetPrototype( + target->message->GetDescriptor()); GOOGLE_CHECK_NOTNULL(prototype); - return prototype->New(); - } else { - return released_message; + released_message.reset(prototype->New()); } -} -// Release field of message and transfer the ownership to cmessage. -void ReleaseLastTo(const FieldDescriptor* field, - Message* message, - CMessage* cmessage) { - GOOGLE_CHECK_NOTNULL(field); - GOOGLE_CHECK_NOTNULL(message); - GOOGLE_CHECK_NOTNULL(cmessage); - - shared_ptr<Message> released_message( - ReleaseLast(field, cmessage->message->GetDescriptor(), message)); - cmessage->parent = NULL; - cmessage->parent_field_descriptor = NULL; - cmessage->message = released_message.get(); - cmessage->read_only = false; - cmessage::SetOwner(cmessage, released_message); + target->parent = NULL; + target->parent_field_descriptor = NULL; + target->message = released_message.get(); + target->read_only = false; + cmessage::SetOwner(target, released_message); } // Called to release a container using @@ -635,7 +623,7 @@ int Release(RepeatedCompositeContainer* self) { for (Py_ssize_t i = size - 1; i >= 0; --i) { CMessage* child_cmessage = reinterpret_cast<CMessage*>( PyList_GET_ITEM(self->child_messages, i)); - ReleaseLastTo(field, message, child_cmessage); + ReleaseLastTo(self->parent, field, child_cmessage); } // Detach from containing message. @@ -732,9 +720,7 @@ static PyMethodDef Methods[] = { PyTypeObject RepeatedCompositeContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - // Keep the fully qualified _message symbol in a line for opensource. - "google.protobuf.pyext._message." - "RepeatedCompositeContainer", // tp_name + FULL_MODULE_NAME ".RepeatedCompositeContainer", // tp_name sizeof(RepeatedCompositeContainer), // tp_basicsize 0, // tp_itemsize (destructor)repeated_composite_container::Dealloc, // tp_dealloc diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h index ce7cee0f..e0f21360 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.h +++ b/python/google/protobuf/pyext/repeated_composite_container.h @@ -161,13 +161,13 @@ int SetOwner(RepeatedCompositeContainer* self, const shared_ptr<Message>& new_owner); // Removes the last element of the repeated message field 'field' on -// the Message 'message', and transfers the ownership of the released -// Message to 'cmessage'. +// the Message 'parent', and transfers the ownership of the released +// Message to 'target'. // // Corresponds to reflection api method ReleaseMessage. -void ReleaseLastTo(const FieldDescriptor* field, - Message* message, - CMessage* cmessage); +void ReleaseLastTo(CMessage* parent, + const FieldDescriptor* field, + CMessage* target); } // namespace repeated_composite_container } // namespace python diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc index 110a4c85..fd196836 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.cc +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -102,7 +102,7 @@ static int AssignItem(RepeatedScalarContainer* self, if (arg == NULL) { ScopedPyObjectPtr py_index(PyLong_FromLong(index)); - return cmessage::InternalDeleteRepeatedField(message, field_descriptor, + return cmessage::InternalDeleteRepeatedField(self->parent, field_descriptor, py_index, NULL); } @@ -470,7 +470,7 @@ static int AssSubscript(RepeatedScalarContainer* self, if (value == NULL) { return cmessage::InternalDeleteRepeatedField( - message, field_descriptor, slice, NULL); + self->parent, field_descriptor, slice, NULL); } if (!create_list) { @@ -769,9 +769,7 @@ static PyMethodDef Methods[] = { PyTypeObject RepeatedScalarContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - // Keep the fully qualified _message symbol in a line for opensource. - "google.protobuf.pyext._message." - "RepeatedScalarContainer", // tp_name + FULL_MODULE_NAME ".RepeatedScalarContainer", // tp_name sizeof(RepeatedScalarContainer), // tp_basicsize 0, // tp_itemsize (destructor)repeated_scalar_container::Dealloc, // tp_dealloc diff --git a/python/google/protobuf/pyext/scalar_map_container.cc b/python/google/protobuf/pyext/scalar_map_container.cc new file mode 100644 index 00000000..6f731d27 --- /dev/null +++ b/python/google/protobuf/pyext/scalar_map_container.cc @@ -0,0 +1,514 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: haberman@google.com (Josh Haberman) + +#include <google/protobuf/pyext/scalar_map_container.h> + +#include <google/protobuf/stubs/common.h> +#include <google/protobuf/message.h> +#include <google/protobuf/pyext/message.h> +#include <google/protobuf/pyext/scoped_pyobject_ptr.h> + +namespace google { +namespace protobuf { +namespace python { + +struct ScalarMapIterator { + PyObject_HEAD; + + // This dict contains the full contents of what we want to iterate over. + // There's no way to avoid building this, because the list representation + // (which is canonical) can contain duplicate keys. So at the very least we + // need a set that lets us skip duplicate keys. And at the point that we're + // doing that, we might as well just build the actual dict we're iterating + // over and use dict's built-in iterator. + PyObject* dict; + + // An iterator on dict. + PyObject* iter; + + // A pointer back to the container, so we can notice changes to the version. + ScalarMapContainer* container; + + // The version of the map when we took the iterator to it. + // + // We store this so that if the map is modified during iteration we can throw + // an error. + uint64 version; +}; + +static ScalarMapIterator* GetIter(PyObject* obj) { + return reinterpret_cast<ScalarMapIterator*>(obj); +} + +namespace scalar_map_container { + +static ScalarMapContainer* GetMap(PyObject* obj) { + return reinterpret_cast<ScalarMapContainer*>(obj); +} + +// The private constructor of ScalarMapContainer objects. +PyObject *NewContainer( + CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) { + if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { + return NULL; + } + + ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapContainer_Type, 0)); + if (obj.get() == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate new container."); + } + + ScalarMapContainer* self = GetMap(obj); + + self->message = parent->message; + self->parent = parent; + self->parent_field_descriptor = parent_field_descriptor; + self->owner = parent->owner; + self->version = 0; + + self->key_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("key"); + self->value_field_descriptor = + parent_field_descriptor->message_type()->FindFieldByName("value"); + + if (self->key_field_descriptor == NULL || + self->value_field_descriptor == NULL) { + return PyErr_Format(PyExc_KeyError, + "Map entry descriptor did not have key/value fields"); + } + + return obj.release(); +} + +// Initializes the underlying Message object of "to" so it becomes a new parent +// repeated scalar, and copies all the values from "from" to it. A child scalar +// container can be released by passing it as both from and to (e.g. making it +// the recipient of the new parent message and copying the values from itself). +static int InitializeAndCopyToParentContainer( + ScalarMapContainer* from, + ScalarMapContainer* to) { + // For now we require from == to, re-evaluate if we want to support deep copy + // as in repeated_scalar_container.cc. + GOOGLE_DCHECK(from == to); + Message* old_message = from->message; + Message* new_message = old_message->New(); + to->parent = NULL; + to->parent_field_descriptor = from->parent_field_descriptor; + to->message = new_message; + to->owner.reset(new_message); + + vector<const FieldDescriptor*> fields; + fields.push_back(from->parent_field_descriptor); + old_message->GetReflection()->SwapFields(old_message, new_message, fields); + return 0; +} + +int Release(ScalarMapContainer* self) { + return InitializeAndCopyToParentContainer(self, self); +} + +void SetOwner(ScalarMapContainer* self, + const shared_ptr<Message>& new_owner) { + self->owner = new_owner; +} + +Py_ssize_t Length(PyObject* _self) { + ScalarMapContainer* self = GetMap(_self); + google::protobuf::Message* message = self->message; + return message->GetReflection()->FieldSize(*message, + self->parent_field_descriptor); +} + +int MapKeyMatches(ScalarMapContainer* self, const Message* entry, + PyObject* key) { + // TODO(haberman): do we need more strict type checking? + ScopedPyObjectPtr entry_key( + cmessage::InternalGetScalar(entry, self->key_field_descriptor)); + int ret = PyObject_RichCompareBool(key, entry_key, Py_EQ); + return ret; +} + +PyObject* GetItem(PyObject* _self, PyObject* key) { + ScalarMapContainer* self = GetMap(_self); + + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + + // Right now the Reflection API doesn't support map lookup, so we implement it + // via linear search. + // + // TODO(haberman): add lookup API to Reflection API. + size_t size = reflection->FieldSize(*message, self->parent_field_descriptor); + for (int i = size - 1; i >= 0; i--) { + const Message& entry = reflection->GetRepeatedMessage( + *message, self->parent_field_descriptor, i); + int matches = MapKeyMatches(self, &entry, key); + if (matches < 0) return NULL; + if (matches) { + return cmessage::InternalGetScalar(&entry, self->value_field_descriptor); + } + } + + // Need to add a new entry. + Message* entry = + reflection->AddMessage(message, self->parent_field_descriptor); + PyObject* ret = NULL; + + if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor, + key) >= 0) { + ret = cmessage::InternalGetScalar(entry, self->value_field_descriptor); + } + + self->version++; + + // If there was a type error above, it set the Python exception. + return ret; +} + +int SetItem(PyObject *_self, PyObject *key, PyObject *v) { + ScalarMapContainer* self = GetMap(_self); + cmessage::AssureWritable(self->parent); + + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + size_t size = + reflection->FieldSize(*message, self->parent_field_descriptor); + self->version++; + + if (v) { + // Set item. + // + // Right now the Reflection API doesn't support map lookup, so we implement + // it via linear search. + // + // TODO(haberman): add lookup API to Reflection API. + for (int i = size - 1; i >= 0; i--) { + Message* entry = reflection->MutableRepeatedMessage( + message, self->parent_field_descriptor, i); + int matches = MapKeyMatches(self, entry, key); + if (matches < 0) return -1; + if (matches) { + return cmessage::InternalSetNonOneofScalar( + entry, self->value_field_descriptor, v); + } + } + + // Key is not already present; insert a new entry. + Message* entry = + reflection->AddMessage(message, self->parent_field_descriptor); + + if (cmessage::InternalSetNonOneofScalar(entry, self->key_field_descriptor, + key) < 0 || + cmessage::InternalSetNonOneofScalar(entry, self->value_field_descriptor, + v) < 0) { + reflection->RemoveLast(message, self->parent_field_descriptor); + return -1; + } + + return 0; + } else { + bool found = false; + for (int i = size - 1; i >= 0; i--) { + Message* entry = reflection->MutableRepeatedMessage( + message, self->parent_field_descriptor, i); + int matches = MapKeyMatches(self, entry, key); + if (matches < 0) return -1; + if (matches) { + found = true; + if (i != size - 1) { + reflection->SwapElements(message, self->parent_field_descriptor, i, + size - 1); + } + reflection->RemoveLast(message, self->parent_field_descriptor); + + // Can't exit now, the repeated field representation of maps allows + // duplicate keys, and we have to be sure to remove all of them. + } + } + + if (found) { + return 0; + } else { + PyErr_Format(PyExc_KeyError, "Key not present in map"); + return -1; + } + } +} + +PyObject* GetIterator(PyObject *_self) { + ScalarMapContainer* self = GetMap(_self); + + ScopedPyObjectPtr obj(PyType_GenericAlloc(&ScalarMapIterator_Type, 0)); + if (obj == NULL) { + return PyErr_Format(PyExc_KeyError, "Could not allocate iterator"); + } + + ScalarMapIterator* iter = GetIter(obj.get()); + + Py_INCREF(self); + iter->container = self; + iter->version = self->version; + iter->dict = PyDict_New(); + if (iter->dict == NULL) { + return PyErr_Format(PyExc_RuntimeError, + "Could not allocate dict for iterator."); + } + + // Build the entire map into a dict right now. Start from the beginning so + // that later entries win in the case of duplicates. + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + + // Right now the Reflection API doesn't support map lookup, so we implement it + // via linear search. We need to search from the end because the underlying + // representation can have duplicates if a user calls MergeFrom(); the last + // one needs to win. + // + // TODO(haberman): add lookup API to Reflection API. + size_t size = + reflection->FieldSize(*message, self->parent_field_descriptor); + for (int i = 0; i < size; i++) { + Message* entry = reflection->MutableRepeatedMessage( + message, self->parent_field_descriptor, i); + ScopedPyObjectPtr key( + cmessage::InternalGetScalar(entry, self->key_field_descriptor)); + ScopedPyObjectPtr val( + cmessage::InternalGetScalar(entry, self->value_field_descriptor)); + if (PyDict_SetItem(iter->dict, key.get(), val.get()) < 0) { + return PyErr_Format(PyExc_RuntimeError, + "SetItem failed in iterator construction."); + } + } + + + iter->iter = PyObject_GetIter(iter->dict); + + + return obj.release(); +} + +PyObject* Clear(PyObject* _self) { + ScalarMapContainer* self = GetMap(_self); + cmessage::AssureWritable(self->parent); + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + + reflection->ClearField(message, self->parent_field_descriptor); + + Py_RETURN_NONE; +} + +PyObject* Contains(PyObject* _self, PyObject* key) { + ScalarMapContainer* self = GetMap(_self); + + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + + // Right now the Reflection API doesn't support map lookup, so we implement it + // via linear search. + // + // TODO(haberman): add lookup API to Reflection API. + size_t size = reflection->FieldSize(*message, self->parent_field_descriptor); + for (int i = size - 1; i >= 0; i--) { + const Message& entry = reflection->GetRepeatedMessage( + *message, self->parent_field_descriptor, i); + int matches = MapKeyMatches(self, &entry, key); + if (matches < 0) return NULL; + if (matches) { + Py_RETURN_TRUE; + } + } + + Py_RETURN_FALSE; +} + +PyObject* Get(PyObject* self, PyObject* args) { + PyObject* key; + PyObject* default_value = NULL; + if (PyArg_ParseTuple(args, "O|O", &key, &default_value) < 0) { + return NULL; + } + + ScopedPyObjectPtr is_present(Contains(self, key)); + if (is_present.get() == NULL) { + return NULL; + } + + if (PyObject_IsTrue(is_present.get())) { + return GetItem(self, key); + } else { + if (default_value != NULL) { + Py_INCREF(default_value); + return default_value; + } else { + Py_RETURN_NONE; + } + } +} + +static PyMappingMethods MpMethods = { + Length, // mp_length + GetItem, // mp_subscript + SetItem, // mp_ass_subscript +}; + +static void Dealloc(PyObject* _self) { + ScalarMapContainer* self = GetMap(_self); + self->owner.reset(); + Py_TYPE(_self)->tp_free(_self); +} + +static PyMethodDef Methods[] = { + { "__contains__", Contains, METH_O, + "Tests whether a key is a member of the map." }, + { "clear", (PyCFunction)Clear, METH_NOARGS, + "Removes all elements from the map." }, + { "get", Get, METH_VARARGS, + "Gets the value for the given key if present, or otherwise a default" }, + /* + { "__deepcopy__", (PyCFunction)DeepCopy, METH_VARARGS, + "Makes a deep copy of the class." }, + { "__reduce__", (PyCFunction)Reduce, METH_NOARGS, + "Outputs picklable representation of the repeated field." }, + */ + {NULL, NULL}, +}; + +} // namespace scalar_map_container + +namespace scalar_map_iterator { + +static void Dealloc(PyObject* _self) { + ScalarMapIterator* self = GetIter(_self); + Py_DECREF(self->dict); + Py_DECREF(self->iter); + Py_DECREF(self->container); + Py_TYPE(_self)->tp_free(_self); +} + +PyObject* IterNext(PyObject* _self) { + ScalarMapIterator* self = GetIter(_self); + + // This won't catch mutations to the map performed by MergeFrom(); no easy way + // to address that. + if (self->version != self->container->version) { + return PyErr_Format(PyExc_RuntimeError, + "Map modified during iteration."); + } + + return PyIter_Next(self->iter); +} + +} // namespace scalar_map_iterator + +PyTypeObject ScalarMapContainer_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".ScalarMapContainer", // tp_name + sizeof(ScalarMapContainer), // tp_basicsize + 0, // tp_itemsize + scalar_map_container::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + &scalar_map_container::MpMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map container", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + scalar_map_container::GetIterator, // tp_iter + 0, // tp_iternext + scalar_map_container::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +PyTypeObject ScalarMapIterator_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".ScalarMapIterator", // tp_name + sizeof(ScalarMapIterator), // tp_basicsize + 0, // tp_itemsize + scalar_map_iterator::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A scalar map iterator", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + PyObject_SelfIter, // tp_iter + scalar_map_iterator::IterNext, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init +}; + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/scalar_map_container.h b/python/google/protobuf/pyext/scalar_map_container.h new file mode 100644 index 00000000..254e6e98 --- /dev/null +++ b/python/google/protobuf/pyext/scalar_map_container.h @@ -0,0 +1,110 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__ + +#include <Python.h> + +#include <memory> +#ifndef _SHARED_PTR_H +#include <google/protobuf/stubs/shared_ptr.h> +#endif + +#include <google/protobuf/descriptor.h> + +namespace google { +namespace protobuf { + +class Message; + +using internal::shared_ptr; + +namespace python { + +struct CMessage; + +struct ScalarMapContainer { + PyObject_HEAD; + + // This is the top-level C++ Message object that owns the whole + // proto tree. Every Python ScalarMapContainer holds a + // reference to it in order to keep it alive as long as there's a + // Python object that references any part of the tree. + shared_ptr<Message> owner; + + // Pointer to the C++ Message that contains this container. The + // ScalarMapContainer does not own this pointer. + Message* message; + + // Weak reference to a parent CMessage object (i.e. may be NULL.) + // + // Used to make sure all ancestors are also mutable when first + // modifying the container. + CMessage* parent; + + // Pointer to the parent's descriptor that describes this + // field. Used together with the parent's message when making a + // default message instance mutable. + // The pointer is owned by the global DescriptorPool. + const FieldDescriptor* parent_field_descriptor; + const FieldDescriptor* key_field_descriptor; + const FieldDescriptor* value_field_descriptor; + + // We bump this whenever we perform a mutation, to invalidate existing + // iterators. + uint64 version; +}; + +extern PyTypeObject ScalarMapContainer_Type; +extern PyTypeObject ScalarMapIterator_Type; + +namespace scalar_map_container { + +// Builds a ScalarMapContainer object, from a parent message and a +// field descriptor. +extern PyObject *NewContainer( + CMessage* parent, const FieldDescriptor* parent_field_descriptor); + +// Releases the messages in the container to a new message. +// +// Returns 0 on success, -1 on failure. +int Release(ScalarMapContainer* self); + +// Set the owner field of self and any children of self. +void SetOwner(ScalarMapContainer* self, + const shared_ptr<Message>& new_owner); + +} // namespace scalar_map_container +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SCALAR_MAP_CONTAINER_H__ diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 55e653a0..82fca661 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -144,7 +144,6 @@ class GeneratedProtocolMessageType(type): _InitMessage(descriptor, cls) superclass = super(GeneratedProtocolMessageType, cls) superclass.__init__(name, bases, dictionary) - setattr(descriptor, '_concrete_class', cls) def ParseMessage(descriptor, byte_str): diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index a47ce3e3..8cbd6822 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -100,6 +100,10 @@ def MessageToString(message, as_utf8=False, as_one_line=False, return result.rstrip() return result +def _IsMapEntry(field): + return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, pointy_brackets=False, use_index_order=False, @@ -108,7 +112,19 @@ def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, if use_index_order: fields.sort(key=lambda x: x[0].index) for field, value in fields: - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if _IsMapEntry(field): + for key in value: + # This is slow for maps with submessage entires because it copies the + # entire tree. Unfortunately this would take significant refactoring + # of this file to work around. + # + # TODO(haberman): refactor and optimize if this becomes an issue. + entry_submsg = field.message_type._concrete_class( + key=key, value=value[key]) + PrintField(field, entry_submsg, out, indent, as_utf8, as_one_line, + pointy_brackets=pointy_brackets, + use_index_order=use_index_order, float_format=float_format) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: for element in value: PrintField(field, element, out, indent, as_utf8, as_one_line, pointy_brackets=pointy_brackets, @@ -367,6 +383,7 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): message_descriptor.full_name, name)) if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + is_map_entry = _IsMapEntry(field) tokenizer.TryConsume(':') if tokenizer.TryConsume('<'): @@ -378,6 +395,8 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: if field.is_extension: sub_message = message.Extensions[field].add() + elif is_map_entry: + sub_message = field.message_type._concrete_class() else: sub_message = getattr(message, field.name).add() else: @@ -391,6 +410,14 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): if tokenizer.AtEnd(): raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % (end_token)) _MergeField(tokenizer, sub_message, allow_multiple_scalars) + + if is_map_entry: + value_cpptype = field.message_type.fields_by_name['value'].cpp_type + if value_cpptype == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + value = getattr(message, field.name)[sub_message.key] + value.MergeFrom(sub_message.value) + else: + getattr(message, field.name)[sub_message.key] = sub_message.value else: _MergeScalarField(tokenizer, message, field, allow_multiple_scalars) @@ -701,13 +728,16 @@ class _Tokenizer(object): String literals (whether bytes or text) can come in multiple adjacent tokens which are automatically concatenated, like in C or Python. This method only consumes one token. + + Raises: + ParseError: When the wrong format data is found. """ text = self.token if len(text) < 1 or text[0] not in ('\'', '"'): - raise self._ParseError('Expected string but found: "%r"' % text) + raise self._ParseError('Expected string but found: %r' % (text,)) if len(text) < 2 or text[-1] != text[0]: - raise self._ParseError('String missing ending quote.') + raise self._ParseError('String missing ending quote: %r' % (text,)) try: result = text_encoding.CUnescape(text[1:-1]) diff --git a/python/setup.py b/python/setup.py index a1365fba..5c321f50 100755 --- a/python/setup.py +++ b/python/setup.py @@ -91,6 +91,7 @@ def GenerateUnittestProtos(): if not os.path.exists("../.git"): return + generate_proto("../src/google/protobuf/map_unittest.proto") generate_proto("../src/google/protobuf/unittest.proto") generate_proto("../src/google/protobuf/unittest_custom_options.proto") generate_proto("../src/google/protobuf/unittest_import.proto") |