aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/internal/python_message.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/internal/python_message.py')
-rwxr-xr-xpython/google/protobuf/internal/python_message.py145
1 files changed, 128 insertions, 17 deletions
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 975e3b4d..4e0f545c 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -56,6 +56,7 @@ import sys
import weakref
import six
+from six.moves import range
# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import api_implementation
@@ -124,6 +125,21 @@ class GeneratedProtocolMessageType(type):
Newly-allocated class.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+
+ # If a concrete class already exists for this descriptor, don't try to
+ # create another. Doing so will break any messages that already exist with
+ # the existing class.
+ #
+ # The C++ implementation appears to have its own internal `PyMessageFactory`
+ # to achieve similar results.
+ #
+ # This most commonly happens in `text_format.py` when using descriptors from
+ # a custom pool; it calls symbol_database.Global().getPrototype() on a
+ # descriptor which already has an existing concrete class.
+ new_class = getattr(descriptor, '_concrete_class', None)
+ if new_class:
+ return new_class
+
if descriptor.full_name in well_known_types.WKTBASES:
bases += (well_known_types.WKTBASES[descriptor.full_name],)
_AddClassAttributesForNestedExtensions(descriptor, dictionary)
@@ -151,6 +167,16 @@ class GeneratedProtocolMessageType(type):
type.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+
+ # If this is an _existing_ class looked up via `_concrete_class` in the
+ # __new__ method above, then we don't need to re-initialize anything.
+ existing_class = getattr(descriptor, '_concrete_class', None)
+ if existing_class:
+ assert existing_class is cls, (
+ 'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
+ % (descriptor.full_name))
+ return
+
cls._decoders_by_tag = {}
if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format):
@@ -245,6 +271,7 @@ def _AddSlots(message_descriptor, dictionary):
'_cached_byte_size_dirty',
'_fields',
'_unknown_fields',
+ '_unknown_field_set',
'_is_present_in_parent',
'_listener',
'_listener_for_children',
@@ -271,6 +298,13 @@ def _IsMessageMapField(field):
return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
+def _IsStrictUtf8Check(field):
+ if field.containing_type.syntax != 'proto3':
+ return False
+ enforce_utf8 = True
+ return enforce_utf8
+
+
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
is_packable = (is_repeated and
@@ -322,10 +356,16 @@ def _AttachFieldHelpers(cls, field_descriptor):
field_decoder = decoder.MapDecoder(
field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
is_message_map)
+ elif decode_type == _FieldDescriptor.TYPE_STRING:
+ is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor)
+ field_decoder = decoder.StringDecoder(
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor,
+ is_strict_utf8_check)
else:
field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor)
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
@@ -422,6 +462,9 @@ def _DefaultValueConstructorForField(field):
# _concrete_class may not yet be initialized.
message_type = field.message_type
def MakeSubMessageDefault(message):
+ assert getattr(message_type, '_concrete_class', None), (
+ 'Uninitialized concrete class found for field %r (message type %r)'
+ % (field.full_name, message_type.full_name))
result = message_type._concrete_class()
result._SetListener(
_OneofListener(message, field)
@@ -477,6 +520,9 @@ def _AddInitMethod(message_descriptor, cls):
# _unknown_fields is () when empty for efficiency, and will be turned into
# a list if fields are added.
self._unknown_fields = ()
+ # _unknown_field_set is None when empty for efficiency, and will be
+ # turned into UnknownFieldSet struct if fields are added.
+ self._unknown_field_set = None # pylint: disable=protected-access
self._is_present_in_parent = False
self._listener = message_listener_mod.NullMessageListener()
self._listener_for_children = _Listener(self)
@@ -584,6 +630,14 @@ def _AddPropertiesForField(field, cls):
_AddPropertiesForNonRepeatedScalarField(field, cls)
+class _FieldProperty(property):
+ __slots__ = ('DESCRIPTOR',)
+
+ def __init__(self, descriptor, getter, setter, doc):
+ property.__init__(self, getter, setter, doc=doc)
+ self.DESCRIPTOR = descriptor
+
+
def _AddPropertiesForRepeatedField(field, cls):
"""Adds a public property for a "repeated" protocol message field. Clients
can use this property to get the value of the field, which will be either a
@@ -625,7 +679,7 @@ def _AddPropertiesForRepeatedField(field, cls):
'"%s" in protocol message object.' % proto_field_name)
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
+ setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
@@ -681,7 +735,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
# Add a property to encapsulate the getter/setter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
+ setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
@@ -725,7 +779,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
# Add a property to encapsulate the getter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
+ setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
@@ -949,12 +1003,12 @@ def _AddEqualsMethod(message_descriptor, cls):
if not self.ListFields() == other.ListFields():
return False
- # Sort unknown fields because their order shouldn't affect equality test.
+ # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions,
+ # then use it for the comparison.
unknown_fields = list(self._unknown_fields)
unknown_fields.sort()
other_unknown_fields = list(other._unknown_fields)
other_unknown_fields.sort()
-
return unknown_fields == other_unknown_fields
cls.__eq__ = __eq__
@@ -1078,6 +1132,13 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
def _AddMergeFromStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def MergeFromString(self, serialized):
+ if isinstance(serialized, memoryview) and six.PY2:
+ raise TypeError(
+ 'memoryview not supported in Python 2 with the pure Python proto '
+ 'implementation: this is to maintain compatibility with the C++ '
+ 'implementation')
+
+ serialized = memoryview(serialized)
length = len(serialized)
try:
if self._InternalParse(serialized, 0, length) != length:
@@ -1095,26 +1156,54 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
- is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
+ """Create a message from serialized bytes.
+
+ Args:
+ self: Message, instance of the proto message object.
+ buffer: memoryview of the serialized data.
+ pos: int, position to start in the serialized data.
+ end: int, end position of the serialized data.
+
+ Returns:
+ Message object.
+ """
+ # Guard against internal misuse, since this function is called internally
+ # quite extensively, and its easy to accidentally pass bytes.
+ assert isinstance(buffer, memoryview)
self._Modified()
field_dict = self._fields
- unknown_field_list = self._unknown_fields
+ # pylint: disable=protected-access
+ unknown_field_set = self._unknown_field_set
while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
if field_decoder is None:
- value_start_pos = new_pos
- new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
+ if not self._unknown_fields: # pylint: disable=protected-access
+ self._unknown_fields = [] # pylint: disable=protected-access
+ if unknown_field_set is None:
+ # pylint: disable=protected-access
+ self._unknown_field_set = containers.UnknownFieldSet()
+ # pylint: disable=protected-access
+ unknown_field_set = self._unknown_field_set
+ # pylint: disable=protected-access
+ (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
+ field_number, wire_type = wire_format.UnpackTag(tag)
+ # TODO(jieluo): remove old_pos.
+ old_pos = new_pos
+ (data, new_pos) = decoder._DecodeUnknownField(
+ buffer, new_pos, wire_type) # pylint: disable=protected-access
if new_pos == -1:
return pos
- if (not is_proto3 or
- api_implementation.GetPythonProto3PreserveUnknownsDefault()):
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append(
- (tag_bytes, buffer[value_start_pos:new_pos]))
+ # pylint: disable=protected-access
+ unknown_field_set._add(field_number, wire_type, data)
+ # TODO(jieluo): remove _unknown_fields.
+ new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
+ if new_pos == -1:
+ return pos
+ self._unknown_fields.append(
+ (tag_bytes, buffer[old_pos:new_pos].tobytes()))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
@@ -1259,6 +1348,10 @@ def _AddMergeFromMethod(cls):
if not self._unknown_fields:
self._unknown_fields = []
self._unknown_fields.extend(msg._unknown_fields)
+ # pylint: disable=protected-access
+ if self._unknown_field_set is None:
+ self._unknown_field_set = containers.UnknownFieldSet()
+ self._unknown_field_set._extend(msg._unknown_field_set)
cls.MergeFrom = MergeFrom
@@ -1291,12 +1384,25 @@ def _Clear(self):
# Clear fields.
self._fields = {}
self._unknown_fields = ()
+ # pylint: disable=protected-access
+ if self._unknown_field_set is not None:
+ self._unknown_field_set._clear()
+ self._unknown_field_set = None
+
self._oneofs = {}
self._Modified()
+def _UnknownFields(self):
+ if self._unknown_field_set is None: # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ self._unknown_field_set = containers.UnknownFieldSet()
+ return self._unknown_field_set # pylint: disable=protected-access
+
+
def _DiscardUnknownFields(self):
self._unknown_fields = []
+ self._unknown_field_set = None # pylint: disable=protected-access
for field, value in self.ListFields():
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
@@ -1335,6 +1441,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddReduceMethod(cls)
# Adds methods which do not depend on cls.
cls.Clear = _Clear
+ cls.UnknownFields = _UnknownFields
cls.DiscardUnknownFields = _DiscardUnknownFields
cls._SetListener = _SetListener
@@ -1471,6 +1578,10 @@ class _ExtensionDict(object):
if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
result = extension_handle._default_constructor(self._extended_message)
elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ assert getattr(extension_handle.message_type, '_concrete_class', None), (
+ 'Uninitialized concrete class found for field %r (message type %r)'
+ % (extension_handle.full_name,
+ extension_handle.message_type.full_name))
result = extension_handle.message_type._concrete_class()
try:
result._SetListener(self._extended_message._listener_for_children)