diff options
Diffstat (limited to 'python/google/protobuf/internal/python_message.py')
-rwxr-xr-x | python/google/protobuf/internal/python_message.py | 235 |
1 files changed, 200 insertions, 35 deletions
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 58c65db9..bb06beb3 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -59,6 +59,7 @@ import weakref import six import six.moves.copyreg as copyreg +import six.string_types # We use "as" to avoid name collisions with variables. from google.protobuf.internal import containers @@ -70,6 +71,7 @@ from google.protobuf.internal import type_checkers from google.protobuf.internal import wire_format from google.protobuf import descriptor as descriptor_mod from google.protobuf import message as message_mod +from google.protobuf import symbol_database from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor @@ -94,6 +96,7 @@ def InitMessage(descriptor, cls): for field in descriptor.fields: _AttachFieldHelpers(cls, field) + descriptor._concrete_class = cls # pylint: disable=protected-access _AddEnumValues(descriptor, cls) _AddInitMethod(descriptor, cls) _AddPropertiesForFields(descriptor, cls) @@ -191,12 +194,37 @@ def _IsMessageSetExtension(field): field.label == _FieldDescriptor.LABEL_OPTIONAL) +def _IsMapField(field): + return (field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + +def _IsMessageMapField(field): + value_type = field.message_type.fields_by_name["value"] + return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE + + def _AttachFieldHelpers(cls, field_descriptor): is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) - is_packed = (field_descriptor.has_options and - field_descriptor.GetOptions().packed) - - if _IsMessageSetExtension(field_descriptor): + is_packable = (is_repeated and + wire_format.IsTypePackable(field_descriptor.type)) + if not is_packable: + is_packed = False + elif field_descriptor.containing_type.syntax == "proto2": + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + else: + has_packed_false = (field_descriptor.has_options and + field_descriptor.GetOptions().HasField("packed") and + field_descriptor.GetOptions().packed == False) + is_packed = not has_packed_false + is_map_entry = _IsMapField(field_descriptor) + + if is_map_entry: + field_encoder = encoder.MapEncoder(field_descriptor) + sizer = encoder.MapSizer(field_descriptor) + elif _IsMessageSetExtension(field_descriptor): field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) sizer = encoder.MessageSetItemSizer(field_descriptor.number) else: @@ -212,12 +240,27 @@ def _AttachFieldHelpers(cls, field_descriptor): def AddDecoder(wiretype, is_packed): tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) - cls._decoders_by_tag[tag_bytes] = ( - type_checkers.TYPE_TO_DECODER[field_descriptor.type]( - field_descriptor.number, is_repeated, is_packed, - field_descriptor, field_descriptor._default_constructor), - field_descriptor if field_descriptor.containing_oneof is not None - else None) + decode_type = field_descriptor.type + if (decode_type == _FieldDescriptor.TYPE_ENUM and + type_checkers.SupportsOpenEnums(field_descriptor)): + decode_type = _FieldDescriptor.TYPE_INT32 + + oneof_descriptor = None + if field_descriptor.containing_oneof is not None: + oneof_descriptor = field_descriptor + + if is_map_entry: + is_message_map = _IsMessageMapField(field_descriptor) + + field_decoder = decoder.MapDecoder( + field_descriptor, _GetInitializeDefaultForMap(field_descriptor), + is_message_map) + else: + field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor) + + cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False) @@ -250,6 +293,26 @@ def _AddEnumValues(descriptor, cls): setattr(cls, enum_value.name, enum_value.number) +def _GetInitializeDefaultForMap(field): + if field.label != _FieldDescriptor.LABEL_REPEATED: + raise ValueError('map_entry set on non-repeated field %s' % ( + field.name)) + fields_by_name = field.message_type.fields_by_name + key_checker = type_checkers.GetTypeChecker(fields_by_name['key']) + + value_field = fields_by_name['value'] + if _IsMessageMapField(field): + def MakeMessageMapDefault(message): + return containers.MessageMap( + message._listener_for_children, value_field.message_type, key_checker) + return MakeMessageMapDefault + else: + value_checker = type_checkers.GetTypeChecker(value_field) + def MakePrimitiveMapDefault(message): + return containers.ScalarMap( + message._listener_for_children, key_checker, value_checker) + return MakePrimitiveMapDefault + def _DefaultValueConstructorForField(field): """Returns a function which returns a default value for a field. @@ -264,6 +327,9 @@ def _DefaultValueConstructorForField(field): value may refer back to |message| via a weak reference. """ + if _IsMapField(field): + return _GetInitializeDefaultForMap(field) + if field.label == _FieldDescriptor.LABEL_REPEATED: if field.has_default_value and field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( @@ -289,6 +355,8 @@ def _DefaultValueConstructorForField(field): def MakeSubMessageDefault(message): result = message_type._concrete_class() result._SetListener(message._listener_for_children) + if field.containing_oneof: + message._UpdateOneofState(field) return result return MakeSubMessageDefault @@ -312,7 +380,22 @@ def _ReraiseTypeErrorWithFieldName(message_name, field_name): def _AddInitMethod(message_descriptor, cls): """Adds an __init__ method to cls.""" - fields = message_descriptor.fields + + def _GetIntegerEnumValue(enum_type, value): + """Convert a string or integer enum value to an integer. + + If the value is a string, it is converted to the enum value in + enum_type with the same name. If the value is not a string, it's + returned as-is. (No conversion or bounds-checking is done.) + """ + if isinstance(value, six.string_types): + try: + return enum_type.values_by_name[value].number + except KeyError: + raise ValueError('Enum type %s: unknown label "%s"' % ( + enum_type.full_name, value)) + return value + def init(self, **kwargs): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 @@ -335,19 +418,37 @@ def _AddInitMethod(message_descriptor, cls): if field.label == _FieldDescriptor.LABEL_REPEATED: copy = field._default_constructor(self) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite - for val in field_value: - copy.add().MergeFrom(val) + if _IsMapField(field): + if _IsMessageMapField(field): + for key in field_value: + copy[key].MergeFrom(field_value[key]) + else: + copy.update(field_value) + else: + for val in field_value: + if isinstance(val, dict): + copy.add(**val) + else: + copy.add().MergeFrom(val) else: # Scalar + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = [_GetIntegerEnumValue(field.enum_type, val) + for val in field_value] copy.extend(field_value) self._fields[field] = copy elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: copy = field._default_constructor(self) + new_val = field_value + if isinstance(field_value, dict): + new_val = field.message_type._concrete_class(**field_value) try: - copy.MergeFrom(field_value) + copy.MergeFrom(new_val) except TypeError: _ReraiseTypeErrorWithFieldName(message_descriptor.name, field_name) self._fields[field] = copy else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + field_value = _GetIntegerEnumValue(field.enum_type, field_value) try: setattr(self, field_name, field_value) except TypeError: @@ -469,6 +570,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): type_checker = type_checkers.GetTypeChecker(field) default_value = field.default_value valid_values = set() + is_proto3 = field.containing_type.syntax == "proto3" def getter(self): # TODO(protobuf-team): This may be broken since there may not be @@ -476,15 +578,24 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name + + clear_when_set_to_default = is_proto3 and not field.containing_oneof + def field_setter(self, new_value): # pylint: disable=protected-access - self._fields[field] = type_checker.CheckValue(new_value) + # Testing the value for truthiness captures all of the proto3 defaults + # (0, 0.0, enum 0, and False). + new_value = type_checker.CheckValue(new_value) + if clear_when_set_to_default and not new_value: + self._fields.pop(field, None) + else: + self._fields[field] = new_value # Check _cached_byte_size_dirty inline to improve performance, since scalar # setters are called frequently. if not self._cached_byte_size_dirty: self._Modified() - if field.containing_oneof is not None: + if field.containing_oneof: def setter(self, new_value): field_setter(self, new_value) self._UpdateOneofState(field) @@ -617,24 +728,35 @@ def _AddListFieldsMethod(message_descriptor, cls): cls.ListFields = ListFields +_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"' +_Proto2HasError = 'Protocol message has no non-repeated field "%s"' def _AddHasFieldMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - singular_fields = {} + is_proto3 = (message_descriptor.syntax == "proto3") + error_msg = _Proto3HasError if is_proto3 else _Proto2HasError + + hassable_fields = {} for field in message_descriptor.fields: - if field.label != _FieldDescriptor.LABEL_REPEATED: - singular_fields[field.name] = field - # Fields inside oneofs are never repeated (enforced by the compiler). - for field in message_descriptor.oneofs: - singular_fields[field.name] = field + if field.label == _FieldDescriptor.LABEL_REPEATED: + continue + # For proto3, only submessages and fields inside a oneof have presence. + if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and + not field.containing_oneof): + continue + hassable_fields[field.name] = field + + if not is_proto3: + # Fields inside oneofs are never repeated (enforced by the compiler). + for oneof in message_descriptor.oneofs: + hassable_fields[oneof.name] = oneof def HasField(self, field_name): try: - field = singular_fields[field_name] + field = hassable_fields[field_name] except KeyError: - raise ValueError( - 'Protocol message has no singular "%s" field.' % field_name) + raise ValueError(error_msg % field_name) if isinstance(field, descriptor_mod.OneofDescriptor): try: @@ -720,6 +842,26 @@ def _AddHasExtensionMethod(cls): return extension_handle in self._fields cls.HasExtension = HasExtension +def _UnpackAny(msg): + type_url = msg.type_url + db = symbol_database.Default() + + if not type_url: + return None + + # TODO(haberman): For now we just strip the hostname. Better logic will be + # required. + type_name = type_url.split("/")[-1] + descriptor = db.pool.FindMessageTypeByName(type_name) + + if descriptor is None: + return None + + message_class = db.GetPrototype(descriptor) + message = message_class() + + message.ParseFromString(msg.value) + return message def _AddEqualsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" @@ -731,6 +873,12 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True + if self.DESCRIPTOR.full_name == "google.protobuf.Any": + any_a = _UnpackAny(self) + any_b = _UnpackAny(other) + if any_a and any_b: + return any_a == any_b + if not self.ListFields() == other.ListFields(): return False @@ -864,6 +1012,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): local_ReadTag = decoder.ReadTag local_SkipField = decoder.SkipField decoders_by_tag = cls._decoders_by_tag + is_proto3 = message_descriptor.syntax == "proto3" def InternalParse(self, buffer, pos, end): self._Modified() @@ -877,9 +1026,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls): new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) if new_pos == -1: return pos - if not unknown_field_list: - unknown_field_list = self._unknown_fields = [] - unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) + if not is_proto3: + if not unknown_field_list: + unknown_field_list = self._unknown_fields = [] + unknown_field_list.append( + (tag_bytes, buffer[value_start_pos:new_pos])) pos = new_pos else: pos = field_decoder(buffer, new_pos, end, self, field_dict) @@ -920,6 +1071,9 @@ def _AddIsInitializedMethod(message_descriptor, cls): for field, value in list(self._fields.items()): # dict can change size! if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: + if (field.message_type.has_options and + field.message_type.GetOptions().map_entry): + continue for element in value: if not element.IsInitialized(): if errors is not None: @@ -955,16 +1109,26 @@ def _AddIsInitializedMethod(message_descriptor, cls): else: name = field.name - if field.label == _FieldDescriptor.LABEL_REPEATED: - for i in range(len(value)): + if _IsMapField(field): + if _IsMessageMapField(field): + for key in value: + element = value[key] + prefix = "%s[%d]." % (name, key) + sub_errors = element.FindInitializationErrors() + errors += [prefix + error for error in sub_errors] + else: + # ScalarMaps can't have any initialization errors. + pass + elif field.label == _FieldDescriptor.LABEL_REPEATED: + for i in xrange(len(value)): element = value[i] prefix = "%s[%d]." % (name, i) sub_errors = element.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] + errors += [prefix + error for error in sub_errors] else: prefix = name + "." sub_errors = value.FindInitializationErrors() - errors += [ prefix + error for error in sub_errors ] + errors += [prefix + error for error in sub_errors] return errors @@ -1001,6 +1165,8 @@ def _AddMergeFromMethod(cls): # Construct a new object to represent this field. field_value = field._default_constructor(self) fields[field] = field_value + if field.containing_oneof: + self._UpdateOneofState(field) field_value.MergeFrom(value) else: self._fields[field] = value @@ -1245,11 +1411,10 @@ class _ExtensionDict(object): # It's slightly wasteful to lookup the type checker each time, # but we expect this to be a vanishingly uncommon case anyway. - type_checker = type_checkers.GetTypeChecker( - extension_handle) + type_checker = type_checkers.GetTypeChecker(extension_handle) # pylint: disable=protected-access self._extended_message._fields[extension_handle] = ( - type_checker.CheckValue(value)) + type_checker.CheckValue(value)) self._extended_message._Modified() def _FindExtensionByName(self, name): |