aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/internal/python_message.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/internal/python_message.py')
-rwxr-xr-xpython/google/protobuf/internal/python_message.py58
1 files changed, 40 insertions, 18 deletions
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 4b701039..975e3b4d 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -58,6 +58,7 @@ import weakref
import six
# We use "as" to avoid name collisions with variables.
+from google.protobuf.internal import api_implementation
from google.protobuf.internal import containers
from google.protobuf.internal import decoder
from google.protobuf.internal import encoder
@@ -288,7 +289,8 @@ def _AttachFieldHelpers(cls, field_descriptor):
if is_map_entry:
field_encoder = encoder.MapEncoder(field_descriptor)
- sizer = encoder.MapSizer(field_descriptor)
+ sizer = encoder.MapSizer(field_descriptor,
+ _IsMessageMapField(field_descriptor))
elif _IsMessageSetExtension(field_descriptor):
field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number)
sizer = encoder.MessageSetItemSizer(field_descriptor.number)
@@ -891,7 +893,7 @@ def _AddHasExtensionMethod(cls):
def _InternalUnpackAny(msg):
"""Unpacks Any message and returns the unpacked message.
- This internal method is differnt from public Any Unpack method which takes
+ This internal method is different from public Any Unpack method which takes
the target message as argument. _InternalUnpackAny method does not have
target message type and need to find the message type in descriptor pool.
@@ -1008,11 +1010,16 @@ def _AddByteSizeMethod(message_descriptor, cls):
return self._cached_byte_size
size = 0
- for field_descriptor, field_value in self.ListFields():
- size += field_descriptor._sizer(field_value)
-
- for tag_bytes, value_bytes in self._unknown_fields:
- size += len(tag_bytes) + len(value_bytes)
+ descriptor = self.DESCRIPTOR
+ if descriptor.GetOptions().map_entry:
+ # Fields of map entry should always be serialized.
+ size = descriptor.fields_by_name['key']._sizer(self.key)
+ size += descriptor.fields_by_name['value']._sizer(self.value)
+ else:
+ for field_descriptor, field_value in self.ListFields():
+ size += field_descriptor._sizer(field_value)
+ for tag_bytes, value_bytes in self._unknown_fields:
+ size += len(tag_bytes) + len(value_bytes)
self._cached_byte_size = size
self._cached_byte_size_dirty = False
@@ -1025,32 +1032,46 @@ def _AddByteSizeMethod(message_descriptor, cls):
def _AddSerializeToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- def SerializeToString(self):
+ def SerializeToString(self, **kwargs):
# Check if the message has all of its required fields set.
errors = []
if not self.IsInitialized():
raise message_mod.EncodeError(
'Message %s is missing required fields: %s' % (
self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors())))
- return self.SerializePartialToString()
+ return self.SerializePartialToString(**kwargs)
cls.SerializeToString = SerializeToString
def _AddSerializePartialToStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
- def SerializePartialToString(self):
+ def SerializePartialToString(self, **kwargs):
out = BytesIO()
- self._InternalSerialize(out.write)
+ self._InternalSerialize(out.write, **kwargs)
return out.getvalue()
cls.SerializePartialToString = SerializePartialToString
- def InternalSerialize(self, write_bytes):
- for field_descriptor, field_value in self.ListFields():
- field_descriptor._encoder(write_bytes, field_value)
- for tag_bytes, value_bytes in self._unknown_fields:
- write_bytes(tag_bytes)
- write_bytes(value_bytes)
+ def InternalSerialize(self, write_bytes, deterministic=None):
+ if deterministic is None:
+ deterministic = (
+ api_implementation.IsPythonDefaultSerializationDeterministic())
+ else:
+ deterministic = bool(deterministic)
+
+ descriptor = self.DESCRIPTOR
+ if descriptor.GetOptions().map_entry:
+ # Fields of map entry should always be serialized.
+ descriptor.fields_by_name['key']._encoder(
+ write_bytes, self.key, deterministic)
+ descriptor.fields_by_name['value']._encoder(
+ write_bytes, self.value, deterministic)
+ else:
+ for field_descriptor, field_value in self.ListFields():
+ field_descriptor._encoder(write_bytes, field_value, deterministic)
+ for tag_bytes, value_bytes in self._unknown_fields:
+ write_bytes(tag_bytes)
+ write_bytes(value_bytes)
cls._InternalSerialize = InternalSerialize
@@ -1088,7 +1109,8 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
if new_pos == -1:
return pos
- if not is_proto3:
+ if (not is_proto3 or
+ api_implementation.GetPythonProto3PreserveUnknownsDefault()):
if not unknown_field_list:
unknown_field_list = self._unknown_fields = []
unknown_field_list.append(