diff options
Diffstat (limited to 'python/google/protobuf/internal/python_message.py')
-rwxr-xr-x | python/google/protobuf/internal/python_message.py | 145 |
1 files changed, 128 insertions, 17 deletions
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 975e3b4d..4e0f545c 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -56,6 +56,7 @@ import sys import weakref import six +from six.moves import range # We use "as" to avoid name collisions with variables. from google.protobuf.internal import api_implementation @@ -124,6 +125,21 @@ class GeneratedProtocolMessageType(type): Newly-allocated class. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + + # If a concrete class already exists for this descriptor, don't try to + # create another. Doing so will break any messages that already exist with + # the existing class. + # + # The C++ implementation appears to have its own internal `PyMessageFactory` + # to achieve similar results. + # + # This most commonly happens in `text_format.py` when using descriptors from + # a custom pool; it calls symbol_database.Global().getPrototype() on a + # descriptor which already has an existing concrete class. + new_class = getattr(descriptor, '_concrete_class', None) + if new_class: + return new_class + if descriptor.full_name in well_known_types.WKTBASES: bases += (well_known_types.WKTBASES[descriptor.full_name],) _AddClassAttributesForNestedExtensions(descriptor, dictionary) @@ -151,6 +167,16 @@ class GeneratedProtocolMessageType(type): type. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + + # If this is an _existing_ class looked up via `_concrete_class` in the + # __new__ method above, then we don't need to re-initialize anything. + existing_class = getattr(descriptor, '_concrete_class', None) + if existing_class: + assert existing_class is cls, ( + 'Duplicate `GeneratedProtocolMessageType` created for descriptor %r' + % (descriptor.full_name)) + return + cls._decoders_by_tag = {} if (descriptor.has_options and descriptor.GetOptions().message_set_wire_format): @@ -245,6 +271,7 @@ def _AddSlots(message_descriptor, dictionary): '_cached_byte_size_dirty', '_fields', '_unknown_fields', + '_unknown_field_set', '_is_present_in_parent', '_listener', '_listener_for_children', @@ -271,6 +298,13 @@ def _IsMessageMapField(field): return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE +def _IsStrictUtf8Check(field): + if field.containing_type.syntax != 'proto3': + return False + enforce_utf8 = True + return enforce_utf8 + + def _AttachFieldHelpers(cls, field_descriptor): is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) is_packable = (is_repeated and @@ -322,10 +356,16 @@ def _AttachFieldHelpers(cls, field_descriptor): field_decoder = decoder.MapDecoder( field_descriptor, _GetInitializeDefaultForMap(field_descriptor), is_message_map) + elif decode_type == _FieldDescriptor.TYPE_STRING: + is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor) + field_decoder = decoder.StringDecoder( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor, + is_strict_utf8_check) else: field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( - field_descriptor.number, is_repeated, is_packed, - field_descriptor, field_descriptor._default_constructor) + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor) cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) @@ -422,6 +462,9 @@ def _DefaultValueConstructorForField(field): # _concrete_class may not yet be initialized. message_type = field.message_type def MakeSubMessageDefault(message): + assert getattr(message_type, '_concrete_class', None), ( + 'Uninitialized concrete class found for field %r (message type %r)' + % (field.full_name, message_type.full_name)) result = message_type._concrete_class() result._SetListener( _OneofListener(message, field) @@ -477,6 +520,9 @@ def _AddInitMethod(message_descriptor, cls): # _unknown_fields is () when empty for efficiency, and will be turned into # a list if fields are added. self._unknown_fields = () + # _unknown_field_set is None when empty for efficiency, and will be + # turned into UnknownFieldSet struct if fields are added. + self._unknown_field_set = None # pylint: disable=protected-access self._is_present_in_parent = False self._listener = message_listener_mod.NullMessageListener() self._listener_for_children = _Listener(self) @@ -584,6 +630,14 @@ def _AddPropertiesForField(field, cls): _AddPropertiesForNonRepeatedScalarField(field, cls) +class _FieldProperty(property): + __slots__ = ('DESCRIPTOR',) + + def __init__(self, descriptor, getter, setter, doc): + property.__init__(self, getter, setter, doc=doc) + self.DESCRIPTOR = descriptor + + def _AddPropertiesForRepeatedField(field, cls): """Adds a public property for a "repeated" protocol message field. Clients can use this property to get the value of the field, which will be either a @@ -625,7 +679,7 @@ def _AddPropertiesForRepeatedField(field, cls): '"%s" in protocol message object.' % proto_field_name) doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name - setattr(cls, property_name, property(getter, setter, doc=doc)) + setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc)) def _AddPropertiesForNonRepeatedScalarField(field, cls): @@ -681,7 +735,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): # Add a property to encapsulate the getter/setter. doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name - setattr(cls, property_name, property(getter, setter, doc=doc)) + setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc)) def _AddPropertiesForNonRepeatedCompositeField(field, cls): @@ -725,7 +779,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): # Add a property to encapsulate the getter. doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name - setattr(cls, property_name, property(getter, setter, doc=doc)) + setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc)) def _AddPropertiesForExtensions(descriptor, cls): @@ -949,12 +1003,12 @@ def _AddEqualsMethod(message_descriptor, cls): if not self.ListFields() == other.ListFields(): return False - # Sort unknown fields because their order shouldn't affect equality test. + # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions, + # then use it for the comparison. unknown_fields = list(self._unknown_fields) unknown_fields.sort() other_unknown_fields = list(other._unknown_fields) other_unknown_fields.sort() - return unknown_fields == other_unknown_fields cls.__eq__ = __eq__ @@ -1078,6 +1132,13 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): def _AddMergeFromStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def MergeFromString(self, serialized): + if isinstance(serialized, memoryview) and six.PY2: + raise TypeError( + 'memoryview not supported in Python 2 with the pure Python proto ' + 'implementation: this is to maintain compatibility with the C++ ' + 'implementation') + + serialized = memoryview(serialized) length = len(serialized) try: if self._InternalParse(serialized, 0, length) != length: @@ -1095,26 +1156,54 @@ def _AddMergeFromStringMethod(message_descriptor, cls): local_ReadTag = decoder.ReadTag local_SkipField = decoder.SkipField decoders_by_tag = cls._decoders_by_tag - is_proto3 = message_descriptor.syntax == "proto3" def InternalParse(self, buffer, pos, end): + """Create a message from serialized bytes. + + Args: + self: Message, instance of the proto message object. + buffer: memoryview of the serialized data. + pos: int, position to start in the serialized data. + end: int, end position of the serialized data. + + Returns: + Message object. + """ + # Guard against internal misuse, since this function is called internally + # quite extensively, and its easy to accidentally pass bytes. + assert isinstance(buffer, memoryview) self._Modified() field_dict = self._fields - unknown_field_list = self._unknown_fields + # pylint: disable=protected-access + unknown_field_set = self._unknown_field_set while pos != end: (tag_bytes, new_pos) = local_ReadTag(buffer, pos) field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) if field_decoder is None: - value_start_pos = new_pos - new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) + if not self._unknown_fields: # pylint: disable=protected-access + self._unknown_fields = [] # pylint: disable=protected-access + if unknown_field_set is None: + # pylint: disable=protected-access + self._unknown_field_set = containers.UnknownFieldSet() + # pylint: disable=protected-access + unknown_field_set = self._unknown_field_set + # pylint: disable=protected-access + (tag, _) = decoder._DecodeVarint(tag_bytes, 0) + field_number, wire_type = wire_format.UnpackTag(tag) + # TODO(jieluo): remove old_pos. + old_pos = new_pos + (data, new_pos) = decoder._DecodeUnknownField( + buffer, new_pos, wire_type) # pylint: disable=protected-access if new_pos == -1: return pos - if (not is_proto3 or - api_implementation.GetPythonProto3PreserveUnknownsDefault()): - if not unknown_field_list: - unknown_field_list = self._unknown_fields = [] - unknown_field_list.append( - (tag_bytes, buffer[value_start_pos:new_pos])) + # pylint: disable=protected-access + unknown_field_set._add(field_number, wire_type, data) + # TODO(jieluo): remove _unknown_fields. + new_pos = local_SkipField(buffer, old_pos, end, tag_bytes) + if new_pos == -1: + return pos + self._unknown_fields.append( + (tag_bytes, buffer[old_pos:new_pos].tobytes())) pos = new_pos else: pos = field_decoder(buffer, new_pos, end, self, field_dict) @@ -1259,6 +1348,10 @@ def _AddMergeFromMethod(cls): if not self._unknown_fields: self._unknown_fields = [] self._unknown_fields.extend(msg._unknown_fields) + # pylint: disable=protected-access + if self._unknown_field_set is None: + self._unknown_field_set = containers.UnknownFieldSet() + self._unknown_field_set._extend(msg._unknown_field_set) cls.MergeFrom = MergeFrom @@ -1291,12 +1384,25 @@ def _Clear(self): # Clear fields. self._fields = {} self._unknown_fields = () + # pylint: disable=protected-access + if self._unknown_field_set is not None: + self._unknown_field_set._clear() + self._unknown_field_set = None + self._oneofs = {} self._Modified() +def _UnknownFields(self): + if self._unknown_field_set is None: # pylint: disable=protected-access + # pylint: disable=protected-access + self._unknown_field_set = containers.UnknownFieldSet() + return self._unknown_field_set # pylint: disable=protected-access + + def _DiscardUnknownFields(self): self._unknown_fields = [] + self._unknown_field_set = None # pylint: disable=protected-access for field, value in self.ListFields(): if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: @@ -1335,6 +1441,7 @@ def _AddMessageMethods(message_descriptor, cls): _AddReduceMethod(cls) # Adds methods which do not depend on cls. cls.Clear = _Clear + cls.UnknownFields = _UnknownFields cls.DiscardUnknownFields = _DiscardUnknownFields cls._SetListener = _SetListener @@ -1471,6 +1578,10 @@ class _ExtensionDict(object): if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: result = extension_handle._default_constructor(self._extended_message) elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + assert getattr(extension_handle.message_type, '_concrete_class', None), ( + 'Uninitialized concrete class found for field %r (message type %r)' + % (extension_handle.full_name, + extension_handle.message_type.full_name)) result = extension_handle.message_type._concrete_class() try: result._SetListener(self._extended_message._listener_for_children) |