diff options
Diffstat (limited to 'python/google/protobuf/reflection.py')
-rwxr-xr-x | python/google/protobuf/reflection.py | 66 |
1 files changed, 56 insertions, 10 deletions
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 5ab7a1b1..d65d8b67 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -62,6 +62,7 @@ from google.protobuf.internal import type_checkers from google.protobuf.internal import wire_format from google.protobuf import descriptor as descriptor_mod from google.protobuf import message as message_mod +from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor @@ -291,7 +292,7 @@ def _DefaultValueForField(message, field): def _AddInitMethod(message_descriptor, cls): """Adds an __init__ method to cls.""" fields = message_descriptor.fields - def init(self): + def init(self, **kwargs): self._cached_byte_size = 0 self._cached_byte_size_dirty = False self._listener = message_listener_mod.NullMessageListener() @@ -306,12 +307,30 @@ def _AddInitMethod(message_descriptor, cls): if field.label != _FieldDescriptor.LABEL_REPEATED: setattr(self, _HasFieldName(field.name), False) self.Extensions = _ExtensionDict(self, cls._known_extensions) + for field_name, field_value in kwargs.iteritems(): + field = _GetFieldByName(message_descriptor, field_name) + _MergeFieldOrExtension(self, field, field_value) init.__module__ = None init.__doc__ = None cls.__init__ = init +def _GetFieldByName(message_descriptor, field_name): + """Returns a field descriptor by field name. + + Args: + message_descriptor: A Descriptor describing all fields in message. + field_name: The name of the field to retrieve. + Returns: + The field descriptor associated with the field name. + """ + try: + return message_descriptor.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + def _AddPropertiesForFields(descriptor, cls): """Adds properties for all fields in this protocol message type.""" for field in descriptor.fields: @@ -543,10 +562,7 @@ def _AddHasFieldMethod(cls): def _AddClearFieldMethod(cls): """Helper for _AddMessageMethods().""" def ClearField(self, field_name): - try: - field = self.DESCRIPTOR.fields_by_name[field_name] - except KeyError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + field = _GetFieldByName(self.DESCRIPTOR, field_name) proto_field_name = field.name python_field_name = _ValueFieldName(proto_field_name) has_field_name = _HasFieldName(proto_field_name) @@ -629,6 +645,13 @@ def _AddEqualsMethod(message_descriptor, cls): cls.__eq__ = __eq__ +def _AddStrMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __str__(self): + return text_format.MessageToString(self) + cls.__str__ = __str__ + + def _AddSetListenerMethod(cls): """Helper for _AddMessageMethods().""" def SetListener(self, listener): @@ -1090,7 +1113,7 @@ def _DeserializeOneEntity(message_descriptor, message, decoder): content_start = decoder.Position() while decoder.Position() - content_start < length: element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) - return decoder.Position() - content_start + return decoder.Position() - initial_position else: # Repeated composite. composite = element_list.add() @@ -1275,6 +1298,7 @@ def _AddMessageMethods(message_descriptor, cls): _AddClearMethod(cls) _AddHasExtensionMethod(cls) _AddEqualsMethod(message_descriptor, cls) + _AddStrMethod(message_descriptor, cls) _AddSetListenerMethod(cls) _AddByteSizeMethod(message_descriptor, cls) _AddSerializeToStringMethod(message_descriptor, cls) @@ -1436,6 +1460,20 @@ class _ExtensionDict(object): if extension.label != _FieldDescriptor.LABEL_REPEATED) self._has_bits = dict.fromkeys(keys, False) + self._extensions_by_number = dict( + (f.number, f) for f in self._known_extensions.itervalues()) + + self._extensions_by_name = {} + for extension in self._known_extensions.itervalues(): + if (extension.containing_type.GetOptions().message_set_wire_format and + extension.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE and + extension.message_type == extension.extension_scope and + extension.label == descriptor_mod.FieldDescriptor.LABEL_OPTIONAL): + extension_name = extension.message_type.full_name + else: + extension_name = extension.full_name + self._extensions_by_name[extension_name] = extension + def __getitem__(self, extension_handle): """Returns the current value of the given extension handle.""" # We don't care as much about keeping critical sections short in the @@ -1609,7 +1647,15 @@ class _ExtensionDict(object): Returns: A dict mapping field_number to (handle, field_descriptor), for *all* registered extensions for this dict. """ - # TODO(robinson): Precompute and store this away. Note that we'll have to - # be careful when we move away from having _known_extensions as a - # deep-copied member of this object. - return dict((f.number, f) for f in self._known_extensions.itervalues()) + return self._extensions_by_number + + def _FindExtensionByName(self, name): + """Tries to find a known extension with the specified name. + + Args: + name: Extension full name. + + Returns: + Extension field descriptor. + """ + return self._extensions_by_name.get(name, None) |