diff options
Diffstat (limited to 'python/google/protobuf/internal/cpp_message.py')
-rwxr-xr-x | python/google/protobuf/internal/cpp_message.py | 117 |
1 files changed, 82 insertions, 35 deletions
diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py index 3f426502..23ab9ba4 100755 --- a/python/google/protobuf/internal/cpp_message.py +++ b/python/google/protobuf/internal/cpp_message.py @@ -34,8 +34,10 @@ Descriptor objects at runtime backed by the protocol buffer C++ API. __author__ = 'petar@google.com (Petar Petrov)' +import copy_reg import operator from google.protobuf.internal import _net_proto2___python +from google.protobuf.internal import enum_type_wrapper from google.protobuf import message @@ -156,10 +158,12 @@ class RepeatedScalarContainer(object): def __hash__(self): raise TypeError('unhashable object') - def sort(self, sort_function=cmp): - values = self[slice(None, None, None)] - values.sort(sort_function) - self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) + def sort(self, *args, **kwargs): + # Maintain compatibility with the previous interface. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, + sorted(self, *args, **kwargs)) def RepeatedScalarProperty(cdescriptor): @@ -202,6 +206,12 @@ class RepeatedCompositeContainer(object): for message in elem_seq: self.add().MergeFrom(message) + def remove(self, value): + # TODO(protocol-devel): This is inefficient as it needs to generate a + # message pointer for each message only to do index(). Move this to a C++ + # extension function. + self.__delitem__(self[slice(None, None, None)].index(value)) + def MergeFrom(self, other): for message in other[:]: self.add().MergeFrom(message) @@ -236,27 +246,29 @@ class RepeatedCompositeContainer(object): def __hash__(self): raise TypeError('unhashable object') - def sort(self, sort_function=cmp): - messages = [] - for index in range(len(self)): - # messages[i][0] is where the i-th element of the new array has to come - # from. - # messages[i][1] is where the i-th element of the old array has to go. - messages.append([index, 0, self[index]]) - messages.sort(lambda x,y: sort_function(x[2], y[2])) + def sort(self, cmp=None, key=None, reverse=False, **kwargs): + # Maintain compatibility with the old interface. + if cmp is None and 'sort_function' in kwargs: + cmp = kwargs.pop('sort_function') - # Remember which position each elements has to move to. - for i in range(len(messages)): - messages[messages[i][0]][1] = i + # The cmp function, if provided, is passed the results of the key function, + # so we only need to wrap one of them. + if key is None: + index_key = self.__getitem__ + else: + index_key = lambda i: key(self[i]) + + # Sort the list of current indexes by the underlying object. + indexes = range(len(self)) + indexes.sort(cmp=cmp, key=index_key, reverse=reverse) # Apply the transposition. - for i in range(len(messages)): - from_position = messages[i][0] - if i == from_position: + for dest, src in enumerate(indexes): + if dest == src: continue - self._cmsg.SwapRepeatedFieldElements( - self._cfield_descriptor, i, from_position) - messages[messages[i][1]][0] = from_position + self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src) + # Don't swap the same value twice. + indexes[src] = src def RepeatedCompositeProperty(cdescriptor, message_type): @@ -359,11 +371,12 @@ class ExtensionDict(object): return None -def NewMessage(message_descriptor, dictionary): +def NewMessage(bases, message_descriptor, dictionary): """Creates a new protocol message *class*.""" _AddClassAttributesForNestedExtensions(message_descriptor, dictionary) _AddEnumValues(message_descriptor, dictionary) _AddDescriptors(message_descriptor, dictionary) + return bases def InitMessage(message_descriptor, cls): @@ -372,6 +385,7 @@ def InitMessage(message_descriptor, cls): _AddInitMethod(message_descriptor, cls) _AddMessageMethods(message_descriptor, cls) _AddPropertiesForExtensions(message_descriptor, cls) + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) def _AddDescriptors(message_descriptor, dictionary): @@ -387,7 +401,7 @@ def _AddDescriptors(message_descriptor, dictionary): field.full_name) dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [ - '_cmsg', '_owner', '_composite_fields', 'Extensions'] + '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS'] def _AddEnumValues(message_descriptor, dictionary): @@ -398,6 +412,7 @@ def _AddEnumValues(message_descriptor, dictionary): dictionary: Class dictionary that should be populated. """ for enum_type in message_descriptor.enum_types: + dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type) for enum_value in enum_type.values: dictionary[enum_value.name] = enum_value.number @@ -439,28 +454,35 @@ def _AddInitMethod(message_descriptor, cls): def Init(self, **kwargs): """Message constructor.""" cmessage = kwargs.pop('__cmessage', None) - if cmessage is None: - self._cmsg = NewCMessage(message_descriptor.full_name) - else: + if cmessage: self._cmsg = cmessage + else: + self._cmsg = NewCMessage(message_descriptor.full_name) # Keep a reference to the owner, as the owner keeps a reference to the # underlying protocol buffer message. owner = kwargs.pop('__owner', None) - if owner is not None: + if owner: self._owner = owner - self.Extensions = ExtensionDict(self) + if message_descriptor.is_extendable: + self.Extensions = ExtensionDict(self) + else: + # Reference counting in the C++ code is broken and depends on + # the Extensions reference to keep this object alive during unit + # tests (see b/4856052). Remove this once b/4945904 is fixed. + self._HACK_REFCOUNTS = self self._composite_fields = {} for field_name, field_value in kwargs.iteritems(): field_cdescriptor = self.__descriptors.get(field_name, None) - if field_cdescriptor is None: + if not field_cdescriptor: raise ValueError('Protocol message has no "%s" field.' % field_name) if field_cdescriptor.label == _LABEL_REPEATED: if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: + field_name = getattr(self, field_name) for val in field_value: - getattr(self, field_name).add().MergeFrom(val) + field_name.add().MergeFrom(val) else: getattr(self, field_name).extend(field_value) elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: @@ -497,12 +519,34 @@ def _AddMessageMethods(message_descriptor, cls): return self._cmsg.HasField(field_name) def ClearField(self, field_name): + child_cmessage = None if field_name in self._composite_fields: + child_field = self._composite_fields[field_name] del self._composite_fields[field_name] - self._cmsg.ClearField(field_name) + + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. + if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + child_cmessage = child_field._cmsg + + if child_cmessage is not None: + self._cmsg.ClearField(field_name, child_cmessage) + else: + self._cmsg.ClearField(field_name) def Clear(self): - return self._cmsg.Clear() + cmessages_to_release = [] + for field_name, child_field in self._composite_fields.iteritems(): + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. + if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + cmessages_to_release.append((child_cdescriptor, child_field._cmsg)) + self._composite_fields.clear() + self._cmsg.Clear(cmessages_to_release) def IsInitialized(self, errors=None): if self._cmsg.IsInitialized(): @@ -514,8 +558,8 @@ def _AddMessageMethods(message_descriptor, cls): def SerializeToString(self): if not self.IsInitialized(): raise message.EncodeError( - 'Message is missing required fields: ' + - ','.join(self.FindInitializationErrors())) + 'Message %s is missing required fields: %s' % ( + self._cmsg.full_name, ','.join(self.FindInitializationErrors()))) return self._cmsg.SerializeToString() def SerializePartialToString(self): @@ -534,7 +578,8 @@ def _AddMessageMethods(message_descriptor, cls): def MergeFrom(self, msg): if not isinstance(msg, cls): raise TypeError( - "Parameter to MergeFrom() must be instance of same class.") + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) self._cmsg.MergeFrom(msg._cmsg) def CopyFrom(self, msg): @@ -581,6 +626,8 @@ def _AddMessageMethods(message_descriptor, cls): raise TypeError('unhashable object') def __unicode__(self): + # Lazy import to prevent circular import when text_format imports this file. + from google.protobuf import text_format return text_format.MessageToString(self, as_utf8=True).decode('utf-8') # Attach the local methods to the message class. |