aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/reflection.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/reflection.py')
-rwxr-xr-xpython/google/protobuf/reflection.py99
1 files changed, 86 insertions, 13 deletions
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index 9ba752e4..0d5191be 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -462,6 +462,12 @@ def _AddStaticMethods(cls):
cls._known_extensions.append(extension_handle)
cls.RegisterExtension = staticmethod(RegisterExtension)
+ def FromString(s):
+ message = cls()
+ message.MergeFromString(s)
+ return message
+ cls.FromString = staticmethod(FromString)
+
def _AddListFieldsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -665,9 +671,36 @@ def _AddByteSizeMethod(message_descriptor, cls):
else:
elements = [value]
- size = sum(_BytesForNonRepeatedElement(element, field_number, field_type)
- for element in elements)
- return size
+ if field.GetOptions().packed:
+ content_size = _ContentBytesForPackedField(message, field, elements)
+ if content_size:
+ tag_size = wire_format.TagByteSize(field_number)
+ length_size = wire_format.Int32ByteSizeNoTag(content_size)
+ return tag_size + length_size + content_size
+ else:
+ return 0
+ else:
+ return sum(_BytesForNonRepeatedElement(element, field_number, field_type)
+ for element in elements)
+
+ def _ContentBytesForPackedField(self, field, value):
+ """Returns the number of bytes required to serialize the actual
+ content of a packed field (not including the tag or the encoding
+ of the length.
+
+ Args:
+ self: The Message instance containing a field of the given type.
+ field: A FieldDescriptor describing the field of interest.
+ value: The value whose byte size we're interested in.
+
+ Returns: The number of bytes required to serialize the current value
+ of the packed "field" in "message", excluding space for tags and the
+ length encoding.
+ """
+ size = sum(_BytesForNonRepeatedElement(element, field.number, field.type)
+ for element in value)
+ # In the packed case, there are no per element tags.
+ return size - wire_format.TagByteSize(field.number) * len(value)
fields = message_descriptor.fields
has_field_names = (_HasFieldName(f.name) for f in fields)
@@ -691,6 +724,8 @@ def _AddByteSizeMethod(message_descriptor, cls):
self._cached_byte_size = size
self._cached_byte_size_dirty = False
return size
+
+ cls._ContentBytesForPackedField = _ContentBytesForPackedField
cls.ByteSize = ByteSize
@@ -788,10 +823,29 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
repeated_value = field_value
else:
repeated_value = [field_value]
- for element in repeated_value:
- _SerializeValueToEncoder(element, field_descriptor.number,
- field_descriptor, encoder)
+ if field_descriptor.GetOptions().packed:
+ # First, write the field number and WIRETYPE_LENGTH_DELIMITED.
+ field_number = field_descriptor.number
+ encoder.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ # Next, write the number of bytes.
+ content_bytes = self._ContentBytesForPackedField(
+ field_descriptor, field_value)
+ encoder.AppendInt32NoTag(content_bytes)
+ # Finally, write the actual values.
+ try:
+ method = type_checkers.TYPE_TO_NOTAG_SERIALIZE_METHOD[
+ field_descriptor.type]
+ for value in repeated_value:
+ method(encoder, value)
+ except KeyError:
+ raise message_mod.EncodeError('Unrecognized field type: %d' %
+ field_descriptor.type)
+ else:
+ for element in repeated_value:
+ _SerializeValueToEncoder(element, field_descriptor.number,
+ field_descriptor, encoder)
return encoder.ToString()
+
cls.SerializePartialToString = SerializePartialToString
@@ -803,6 +857,14 @@ def _WireTypeForFieldType(field_type):
raise message_mod.DecodeError('Unknown field type: %d' % field_type)
+def _WireTypeForField(field_descriptor):
+ """Given a field descriptor, returns the expected wire type."""
+ if field_descriptor.GetOptions().packed:
+ return wire_format.WIRETYPE_LENGTH_DELIMITED
+ else:
+ return _WireTypeForFieldType(field_descriptor.type)
+
+
def _RecursivelyMerge(field_number, field_type, decoder, message):
"""Decodes a message from decoder into message.
message is either a group or a nested message within some containing
@@ -918,9 +980,11 @@ def _DeserializeMessageSetItem(message, decoder):
def _DeserializeOneEntity(message_descriptor, message, decoder):
"""Deserializes the next wire entity from decoder into message.
- The next wire entity is either a scalar or a nested message,
- and may also be an element in a repeated field (the wire encoding
- is the same).
+
+ The next wire entity is either a scalar or a nested message, an
+ element in a repeated field (the wire encoding in this case is the
+ same), or a packed repeated field (in this case, the entire repeated
+ field is read by a single call to _DeserializeOneEntity).
Args:
message_descriptor: A Descriptor instance describing all fields
@@ -973,14 +1037,14 @@ def _DeserializeOneEntity(message_descriptor, message, decoder):
# if this field is a nonrepeated scalar.
field_number = field_descriptor.number
- field_type = field_descriptor.type
- expected_wire_type = _WireTypeForFieldType(field_type)
+ expected_wire_type = _WireTypeForField(field_descriptor)
if wire_type != expected_wire_type:
# Need to fill in uninterpreted_bytes. Work for the next CL.
raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.')
property_name = _PropertyName(field_descriptor.name)
label = field_descriptor.label
+ field_type = field_descriptor.type
cpp_type = field_descriptor.cpp_type
# Nonrepeated scalar. Just set the field directly.
@@ -1000,8 +1064,17 @@ def _DeserializeOneEntity(message_descriptor, message, decoder):
if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
# Repeated scalar.
- element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
- return decoder.Position() - initial_position
+ if not field_descriptor.GetOptions().packed:
+ element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
+ return decoder.Position() - initial_position
+ else:
+ # Packed repeated field.
+ length = _DeserializeScalarFromDecoder(
+ _FieldDescriptor.TYPE_INT32, decoder)
+ content_start = decoder.Position()
+ while decoder.Position() - content_start < length:
+ element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
+ return decoder.Position() - content_start
else:
# Repeated composite.
composite = element_list.add()