diff options
Diffstat (limited to 'python/google/protobuf/internal/python_message.py')
-rwxr-xr-x | python/google/protobuf/internal/python_message.py | 66 |
1 files changed, 59 insertions, 7 deletions
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 66fca918..4bea57ac 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -54,6 +54,7 @@ try: from cStringIO import StringIO except ImportError: from StringIO import StringIO +import copy_reg import struct import weakref @@ -61,6 +62,7 @@ import weakref from google.protobuf.internal import containers from google.protobuf.internal import decoder from google.protobuf.internal import encoder +from google.protobuf.internal import enum_type_wrapper from google.protobuf.internal import message_listener as message_listener_mod from google.protobuf.internal import type_checkers from google.protobuf.internal import wire_format @@ -71,9 +73,10 @@ from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor -def NewMessage(descriptor, dictionary): +def NewMessage(bases, descriptor, dictionary): _AddClassAttributesForNestedExtensions(descriptor, dictionary) _AddSlots(descriptor, dictionary) + return bases def InitMessage(descriptor, cls): @@ -96,6 +99,7 @@ def InitMessage(descriptor, cls): _AddStaticMethods(cls) _AddMessageMethods(descriptor, cls) _AddPrivateHelperMethods(cls) + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) # Stateless helpers for GeneratedProtocolMessageType below. @@ -145,6 +149,10 @@ def _VerifyExtensionHandle(message, extension_handle): if not extension_handle.is_extension: raise KeyError('"%s" is not an extension.' % extension_handle.full_name) + if not extension_handle.containing_type: + raise KeyError('"%s" is missing a containing_type.' + % extension_handle.full_name) + if extension_handle.containing_type is not message.DESCRIPTOR: raise KeyError('Extension "%s" extends message type "%s", but this ' 'message is of type "%s".' % @@ -164,6 +172,7 @@ def _AddSlots(message_descriptor, dictionary): dictionary['__slots__'] = ['_cached_byte_size', '_cached_byte_size_dirty', '_fields', + '_unknown_fields', '_is_present_in_parent', '_listener', '_listener_for_children', @@ -224,11 +233,14 @@ def _AddClassAttributesForNestedExtensions(descriptor, dictionary): def _AddEnumValues(descriptor, cls): """Sets class-level attributes for all enum fields defined in this message. + Also exporting a class-level object that can name enum values. + Args: descriptor: Descriptor object for this message type. cls: Class we're constructing for this message type. """ for enum_type in descriptor.enum_types: + setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) for enum_value in enum_type.values: setattr(cls, enum_value.name, enum_value.number) @@ -248,7 +260,7 @@ def _DefaultValueConstructorForField(field): """ if field.label == _FieldDescriptor.LABEL_REPEATED: - if field.default_value != []: + if field.has_default_value and field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( field.default_value)) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: @@ -276,6 +288,8 @@ def _DefaultValueConstructorForField(field): return MakeSubMessageDefault def MakeScalarDefault(message): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. return field.default_value return MakeScalarDefault @@ -287,6 +301,9 @@ def _AddInitMethod(message_descriptor, cls): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 self._fields = {} + # _unknown_fields is () when empty for efficiency, and will be turned into + # a list if fields are added. + self._unknown_fields = () self._is_present_in_parent = False self._listener = message_listener_mod.NullMessageListener() self._listener_for_children = _Listener(self) @@ -428,6 +445,8 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): valid_values = set() def getter(self): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name @@ -462,13 +481,18 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): # for non-repeated scalars. proto_field_name = field.name property_name = _PropertyName(proto_field_name) + + # TODO(komarek): Can anyone explain to me why we cache the message_type this + # way, instead of referring to field.message_type inside of getter(self)? + # What if someone sets message_type later on (which makes for simpler + # dyanmic proto descriptor and class creation code). message_type = field.message_type def getter(self): field_value = self._fields.get(field) if field_value is None: # Construct a new object to represent this field. - field_value = message_type._concrete_class() + field_value = message_type._concrete_class() # use field.message_type? field_value._SetListener(self._listener_for_children) # Atomically check if another thread has preempted us and, if not, swap @@ -620,6 +644,7 @@ def _AddClearMethod(message_descriptor, cls): def Clear(self): # Clear fields. self._fields = {} + self._unknown_fields = () self._Modified() cls.Clear = Clear @@ -649,7 +674,16 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True - return self.ListFields() == other.ListFields() + if not self.ListFields() == other.ListFields(): + return False + + # Sort unknown fields because their order shouldn't affect equality test. + unknown_fields = list(self._unknown_fields) + unknown_fields.sort() + other_unknown_fields = list(other._unknown_fields) + other_unknown_fields.sort() + + return unknown_fields == other_unknown_fields cls.__eq__ = __eq__ @@ -710,6 +744,9 @@ def _AddByteSizeMethod(message_descriptor, cls): for field_descriptor, field_value in self.ListFields(): size += field_descriptor._sizer(field_value) + for tag_bytes, value_bytes in self._unknown_fields: + size += len(tag_bytes) + len(value_bytes) + self._cached_byte_size = size self._cached_byte_size_dirty = False self._listener_for_children.dirty = False @@ -726,8 +763,8 @@ def _AddSerializeToStringMethod(message_descriptor, cls): errors = [] if not self.IsInitialized(): raise message_mod.EncodeError( - 'Message is missing required fields: ' + - ','.join(self.FindInitializationErrors())) + 'Message %s is missing required fields: %s' % ( + self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) return self.SerializePartialToString() cls.SerializeToString = SerializeToString @@ -744,6 +781,9 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): def InternalSerialize(self, write_bytes): for field_descriptor, field_value in self.ListFields(): field_descriptor._encoder(write_bytes, field_value) + for tag_bytes, value_bytes in self._unknown_fields: + write_bytes(tag_bytes) + write_bytes(value_bytes) cls._InternalSerialize = InternalSerialize @@ -770,13 +810,18 @@ def _AddMergeFromStringMethod(message_descriptor, cls): def InternalParse(self, buffer, pos, end): self._Modified() field_dict = self._fields + unknown_field_list = self._unknown_fields while pos != end: (tag_bytes, new_pos) = local_ReadTag(buffer, pos) field_decoder = decoders_by_tag.get(tag_bytes) if field_decoder is None: + value_start_pos = new_pos new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) if new_pos == -1: return pos + if not unknown_field_list: + unknown_field_list = self._unknown_fields = [] + unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) pos = new_pos else: pos = field_decoder(buffer, new_pos, end, self, field_dict) @@ -873,7 +918,8 @@ def _AddMergeFromMethod(cls): def MergeFrom(self, msg): if not isinstance(msg, cls): raise TypeError( - "Parameter to MergeFrom() must be instance of same class.") + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) assert msg is not self self._Modified() @@ -898,6 +944,12 @@ def _AddMergeFromMethod(cls): field_value.MergeFrom(value) else: self._fields[field] = value + + if msg._unknown_fields: + if not self._unknown_fields: + self._unknown_fields = [] + self._unknown_fields.extend(msg._unknown_fields) + cls.MergeFrom = MergeFrom |