diff options
Diffstat (limited to 'python/google/protobuf/internal/python_message.py')
-rwxr-xr-x | python/google/protobuf/internal/python_message.py | 58 |
1 files changed, 40 insertions, 18 deletions
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 4b701039..975e3b4d 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -58,6 +58,7 @@ import weakref import six # We use "as" to avoid name collisions with variables. +from google.protobuf.internal import api_implementation from google.protobuf.internal import containers from google.protobuf.internal import decoder from google.protobuf.internal import encoder @@ -288,7 +289,8 @@ def _AttachFieldHelpers(cls, field_descriptor): if is_map_entry: field_encoder = encoder.MapEncoder(field_descriptor) - sizer = encoder.MapSizer(field_descriptor) + sizer = encoder.MapSizer(field_descriptor, + _IsMessageMapField(field_descriptor)) elif _IsMessageSetExtension(field_descriptor): field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) sizer = encoder.MessageSetItemSizer(field_descriptor.number) @@ -891,7 +893,7 @@ def _AddHasExtensionMethod(cls): def _InternalUnpackAny(msg): """Unpacks Any message and returns the unpacked message. - This internal method is differnt from public Any Unpack method which takes + This internal method is different from public Any Unpack method which takes the target message as argument. _InternalUnpackAny method does not have target message type and need to find the message type in descriptor pool. @@ -1008,11 +1010,16 @@ def _AddByteSizeMethod(message_descriptor, cls): return self._cached_byte_size size = 0 - for field_descriptor, field_value in self.ListFields(): - size += field_descriptor._sizer(field_value) - - for tag_bytes, value_bytes in self._unknown_fields: - size += len(tag_bytes) + len(value_bytes) + descriptor = self.DESCRIPTOR + if descriptor.GetOptions().map_entry: + # Fields of map entry should always be serialized. + size = descriptor.fields_by_name['key']._sizer(self.key) + size += descriptor.fields_by_name['value']._sizer(self.value) + else: + for field_descriptor, field_value in self.ListFields(): + size += field_descriptor._sizer(field_value) + for tag_bytes, value_bytes in self._unknown_fields: + size += len(tag_bytes) + len(value_bytes) self._cached_byte_size = size self._cached_byte_size_dirty = False @@ -1025,32 +1032,46 @@ def _AddByteSizeMethod(message_descriptor, cls): def _AddSerializeToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - def SerializeToString(self): + def SerializeToString(self, **kwargs): # Check if the message has all of its required fields set. errors = [] if not self.IsInitialized(): raise message_mod.EncodeError( 'Message %s is missing required fields: %s' % ( self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) - return self.SerializePartialToString() + return self.SerializePartialToString(**kwargs) cls.SerializeToString = SerializeToString def _AddSerializePartialToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - def SerializePartialToString(self): + def SerializePartialToString(self, **kwargs): out = BytesIO() - self._InternalSerialize(out.write) + self._InternalSerialize(out.write, **kwargs) return out.getvalue() cls.SerializePartialToString = SerializePartialToString - def InternalSerialize(self, write_bytes): - for field_descriptor, field_value in self.ListFields(): - field_descriptor._encoder(write_bytes, field_value) - for tag_bytes, value_bytes in self._unknown_fields: - write_bytes(tag_bytes) - write_bytes(value_bytes) + def InternalSerialize(self, write_bytes, deterministic=None): + if deterministic is None: + deterministic = ( + api_implementation.IsPythonDefaultSerializationDeterministic()) + else: + deterministic = bool(deterministic) + + descriptor = self.DESCRIPTOR + if descriptor.GetOptions().map_entry: + # Fields of map entry should always be serialized. + descriptor.fields_by_name['key']._encoder( + write_bytes, self.key, deterministic) + descriptor.fields_by_name['value']._encoder( + write_bytes, self.value, deterministic) + else: + for field_descriptor, field_value in self.ListFields(): + field_descriptor._encoder(write_bytes, field_value, deterministic) + for tag_bytes, value_bytes in self._unknown_fields: + write_bytes(tag_bytes) + write_bytes(value_bytes) cls._InternalSerialize = InternalSerialize @@ -1088,7 +1109,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls): new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) if new_pos == -1: return pos - if not is_proto3: + if (not is_proto3 or + api_implementation.GetPythonProto3PreserveUnknownsDefault()): if not unknown_field_list: unknown_field_list = self._unknown_fields = [] unknown_field_list.append( |