aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/reflection.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/reflection.py')
-rwxr-xr-xpython/google/protobuf/reflection.py66
1 files changed, 56 insertions, 10 deletions
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index 5ab7a1b1..d65d8b67 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -62,6 +62,7 @@ from google.protobuf.internal import type_checkers
from google.protobuf.internal import wire_format
from google.protobuf import descriptor as descriptor_mod
from google.protobuf import message as message_mod
+from google.protobuf import text_format
_FieldDescriptor = descriptor_mod.FieldDescriptor
@@ -291,7 +292,7 @@ def _DefaultValueForField(message, field):
def _AddInitMethod(message_descriptor, cls):
"""Adds an __init__ method to cls."""
fields = message_descriptor.fields
- def init(self):
+ def init(self, **kwargs):
self._cached_byte_size = 0
self._cached_byte_size_dirty = False
self._listener = message_listener_mod.NullMessageListener()
@@ -306,12 +307,30 @@ def _AddInitMethod(message_descriptor, cls):
if field.label != _FieldDescriptor.LABEL_REPEATED:
setattr(self, _HasFieldName(field.name), False)
self.Extensions = _ExtensionDict(self, cls._known_extensions)
+ for field_name, field_value in kwargs.iteritems():
+ field = _GetFieldByName(message_descriptor, field_name)
+ _MergeFieldOrExtension(self, field, field_value)
init.__module__ = None
init.__doc__ = None
cls.__init__ = init
+def _GetFieldByName(message_descriptor, field_name):
+ """Returns a field descriptor by field name.
+
+ Args:
+ message_descriptor: A Descriptor describing all fields in message.
+ field_name: The name of the field to retrieve.
+ Returns:
+ The field descriptor associated with the field name.
+ """
+ try:
+ return message_descriptor.fields_by_name[field_name]
+ except KeyError:
+ raise ValueError('Protocol message has no "%s" field.' % field_name)
+
+
def _AddPropertiesForFields(descriptor, cls):
"""Adds properties for all fields in this protocol message type."""
for field in descriptor.fields:
@@ -543,10 +562,7 @@ def _AddHasFieldMethod(cls):
def _AddClearFieldMethod(cls):
"""Helper for _AddMessageMethods()."""
def ClearField(self, field_name):
- try:
- field = self.DESCRIPTOR.fields_by_name[field_name]
- except KeyError:
- raise ValueError('Protocol message has no "%s" field.' % field_name)
+ field = _GetFieldByName(self.DESCRIPTOR, field_name)
proto_field_name = field.name
python_field_name = _ValueFieldName(proto_field_name)
has_field_name = _HasFieldName(proto_field_name)
@@ -629,6 +645,13 @@ def _AddEqualsMethod(message_descriptor, cls):
cls.__eq__ = __eq__
+def _AddStrMethod(message_descriptor, cls):
+ """Helper for _AddMessageMethods()."""
+ def __str__(self):
+ return text_format.MessageToString(self)
+ cls.__str__ = __str__
+
+
def _AddSetListenerMethod(cls):
"""Helper for _AddMessageMethods()."""
def SetListener(self, listener):
@@ -1090,7 +1113,7 @@ def _DeserializeOneEntity(message_descriptor, message, decoder):
content_start = decoder.Position()
while decoder.Position() - content_start < length:
element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
- return decoder.Position() - content_start
+ return decoder.Position() - initial_position
else:
# Repeated composite.
composite = element_list.add()
@@ -1275,6 +1298,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddClearMethod(cls)
_AddHasExtensionMethod(cls)
_AddEqualsMethod(message_descriptor, cls)
+ _AddStrMethod(message_descriptor, cls)
_AddSetListenerMethod(cls)
_AddByteSizeMethod(message_descriptor, cls)
_AddSerializeToStringMethod(message_descriptor, cls)
@@ -1436,6 +1460,20 @@ class _ExtensionDict(object):
if extension.label != _FieldDescriptor.LABEL_REPEATED)
self._has_bits = dict.fromkeys(keys, False)
+ self._extensions_by_number = dict(
+ (f.number, f) for f in self._known_extensions.itervalues())
+
+ self._extensions_by_name = {}
+ for extension in self._known_extensions.itervalues():
+ if (extension.containing_type.GetOptions().message_set_wire_format and
+ extension.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE and
+ extension.message_type == extension.extension_scope and
+ extension.label == descriptor_mod.FieldDescriptor.LABEL_OPTIONAL):
+ extension_name = extension.message_type.full_name
+ else:
+ extension_name = extension.full_name
+ self._extensions_by_name[extension_name] = extension
+
def __getitem__(self, extension_handle):
"""Returns the current value of the given extension handle."""
# We don't care as much about keeping critical sections short in the
@@ -1609,7 +1647,15 @@ class _ExtensionDict(object):
Returns: A dict mapping field_number to (handle, field_descriptor),
for *all* registered extensions for this dict.
"""
- # TODO(robinson): Precompute and store this away. Note that we'll have to
- # be careful when we move away from having _known_extensions as a
- # deep-copied member of this object.
- return dict((f.number, f) for f in self._known_extensions.itervalues())
+ return self._extensions_by_number
+
+ def _FindExtensionByName(self, name):
+ """Tries to find a known extension with the specified name.
+
+ Args:
+ name: Extension full name.
+
+ Returns:
+ Extension field descriptor.
+ """
+ return self._extensions_by_name.get(name, None)