aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf')
-rwxr-xr-xpython/google/protobuf/internal/containers.py13
-rwxr-xr-xpython/google/protobuf/internal/decoder_test.py2
-rwxr-xr-xpython/google/protobuf/internal/encoder.py140
-rwxr-xr-xpython/google/protobuf/internal/encoder_test.py70
-rwxr-xr-xpython/google/protobuf/internal/reflection_test.py120
-rwxr-xr-xpython/google/protobuf/internal/service_reflection_test.py6
-rwxr-xr-xpython/google/protobuf/internal/test_util.py48
-rwxr-xr-xpython/google/protobuf/internal/type_checkers.py17
-rwxr-xr-xpython/google/protobuf/internal/wire_format.py35
-rwxr-xr-xpython/google/protobuf/reflection.py99
-rwxr-xr-xpython/google/protobuf/service.py17
-rwxr-xr-xpython/google/protobuf/service_reflection.py23
12 files changed, 494 insertions, 96 deletions
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index 14fe863e..fa1e3402 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -106,6 +106,19 @@ class RepeatedScalarFieldContainer(BaseContainer):
if len(self._values) == 1:
self._message_listener.TransitionToNonempty()
+ def extend(self, elem_seq):
+ """Extends by appending the given sequence. Similar to list.extend()."""
+ if not elem_seq:
+ return
+
+ orig_empty = len(self._values) == 0
+ for elem in elem_seq:
+ self._type_checker.CheckValue(elem)
+ self._values.extend(elem_seq)
+ self._message_listener.ByteSizeDirty()
+ if orig_empty:
+ self._message_listener.TransitionToNonempty()
+
def remove(self, elem):
"""Removes an item from the list. Similar to list.remove()."""
self._values.remove(elem)
diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py
index abcc07fc..e186e14d 100755
--- a/python/google/protobuf/internal/decoder_test.py
+++ b/python/google/protobuf/internal/decoder_test.py
@@ -57,7 +57,7 @@ class DecoderTest(unittest.TestCase):
for expected_field_number in (1, 15, 16, 2047, 2048):
for expected_wire_type in range(6): # Highest-numbered wiretype is 5.
e = encoder.Encoder()
- e._AppendTag(expected_field_number, expected_wire_type)
+ e.AppendTag(expected_field_number, expected_wire_type)
s = e.ToString()
d = decoder.Decoder(s)
field_number, wire_type = d.ReadFieldNumberAndWireType()
diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py
index 7071241c..eed8c8bd 100755
--- a/python/google/protobuf/internal/encoder.py
+++ b/python/google/protobuf/internal/encoder.py
@@ -58,89 +58,161 @@ class Encoder(object):
"""Returns all values encoded in this object as a string."""
return self._stream.ToString()
- # All the Append*() methods below first append a tag+type pair to the buffer
- # before appending the specified value.
-
- def AppendInt32(self, field_number, value):
+ # Append*NoTag methods. These are necessary for serializing packed
+ # repeated fields. The Append*() methods call these methods to do
+ # the actual serialization.
+ def AppendInt32NoTag(self, value):
"""Appends a 32-bit integer to our buffer, varint-encoded."""
- self._AppendTag(field_number, wire_format.WIRETYPE_VARINT)
self._stream.AppendVarint32(value)
- def AppendInt64(self, field_number, value):
+ def AppendInt64NoTag(self, value):
"""Appends a 64-bit integer to our buffer, varint-encoded."""
- self._AppendTag(field_number, wire_format.WIRETYPE_VARINT)
self._stream.AppendVarint64(value)
- def AppendUInt32(self, field_number, unsigned_value):
+ def AppendUInt32NoTag(self, unsigned_value):
"""Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
- self._AppendTag(field_number, wire_format.WIRETYPE_VARINT)
self._stream.AppendVarUInt32(unsigned_value)
- def AppendUInt64(self, field_number, unsigned_value):
+ def AppendUInt64NoTag(self, unsigned_value):
"""Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
- self._AppendTag(field_number, wire_format.WIRETYPE_VARINT)
self._stream.AppendVarUInt64(unsigned_value)
- def AppendSInt32(self, field_number, value):
+ def AppendSInt32NoTag(self, value):
"""Appends a 32-bit integer to our buffer, zigzag-encoded and then
varint-encoded.
"""
- self._AppendTag(field_number, wire_format.WIRETYPE_VARINT)
zigzag_value = wire_format.ZigZagEncode(value)
self._stream.AppendVarUInt32(zigzag_value)
- def AppendSInt64(self, field_number, value):
+ def AppendSInt64NoTag(self, value):
"""Appends a 64-bit integer to our buffer, zigzag-encoded and then
varint-encoded.
"""
- self._AppendTag(field_number, wire_format.WIRETYPE_VARINT)
zigzag_value = wire_format.ZigZagEncode(value)
self._stream.AppendVarUInt64(zigzag_value)
- def AppendFixed32(self, field_number, unsigned_value):
+ def AppendFixed32NoTag(self, unsigned_value):
"""Appends an unsigned 32-bit integer to our buffer, in little-endian
byte-order.
"""
- self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
self._stream.AppendLittleEndian32(unsigned_value)
- def AppendFixed64(self, field_number, unsigned_value):
+ def AppendFixed64NoTag(self, unsigned_value):
"""Appends an unsigned 64-bit integer to our buffer, in little-endian
byte-order.
"""
- self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
self._stream.AppendLittleEndian64(unsigned_value)
- def AppendSFixed32(self, field_number, value):
+ def AppendSFixed32NoTag(self, value):
"""Appends a signed 32-bit integer to our buffer, in little-endian
byte-order.
"""
sign = (value & 0x80000000) and -1 or 0
if value >> 32 != sign:
raise message.EncodeError('SFixed32 out of range: %d' % value)
- self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
self._stream.AppendLittleEndian32(value & 0xffffffff)
- def AppendSFixed64(self, field_number, value):
+ def AppendSFixed64NoTag(self, value):
"""Appends a signed 64-bit integer to our buffer, in little-endian
byte-order.
"""
sign = (value & 0x8000000000000000) and -1 or 0
if value >> 64 != sign:
raise message.EncodeError('SFixed64 out of range: %d' % value)
- self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
self._stream.AppendLittleEndian64(value & 0xffffffffffffffff)
- def AppendFloat(self, field_number, value):
+ def AppendFloatNoTag(self, value):
"""Appends a floating-point number to our buffer."""
- self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
self._stream.AppendRawBytes(struct.pack('f', value))
- def AppendDouble(self, field_number, value):
+ def AppendDoubleNoTag(self, value):
"""Appends a double-precision floating-point number to our buffer."""
- self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
self._stream.AppendRawBytes(struct.pack('d', value))
+ def AppendBoolNoTag(self, value):
+ """Appends a boolean to our buffer."""
+ self.AppendInt32NoTag(value)
+
+ def AppendEnumNoTag(self, value):
+ """Appends an enum value to our buffer."""
+ self.AppendInt32NoTag(value)
+
+
+ # All the Append*() methods below first append a tag+type pair to the buffer
+ # before appending the specified value.
+
+ def AppendInt32(self, field_number, value):
+ """Appends a 32-bit integer to our buffer, varint-encoded."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendInt32NoTag(value)
+
+ def AppendInt64(self, field_number, value):
+ """Appends a 64-bit integer to our buffer, varint-encoded."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendInt64NoTag(value)
+
+ def AppendUInt32(self, field_number, unsigned_value):
+ """Appends an unsigned 32-bit integer to our buffer, varint-encoded."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendUInt32NoTag(unsigned_value)
+
+ def AppendUInt64(self, field_number, unsigned_value):
+ """Appends an unsigned 64-bit integer to our buffer, varint-encoded."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendUInt64NoTag(unsigned_value)
+
+ def AppendSInt32(self, field_number, value):
+ """Appends a 32-bit integer to our buffer, zigzag-encoded and then
+ varint-encoded.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendSInt32NoTag(value)
+
+ def AppendSInt64(self, field_number, value):
+ """Appends a 64-bit integer to our buffer, zigzag-encoded and then
+ varint-encoded.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_VARINT)
+ self.AppendSInt64NoTag(value)
+
+ def AppendFixed32(self, field_number, unsigned_value):
+ """Appends an unsigned 32-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
+ self.AppendFixed32NoTag(unsigned_value)
+
+ def AppendFixed64(self, field_number, unsigned_value):
+ """Appends an unsigned 64-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
+ self.AppendFixed64NoTag(unsigned_value)
+
+ def AppendSFixed32(self, field_number, value):
+ """Appends a signed 32-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
+ self.AppendSFixed32NoTag(value)
+
+ def AppendSFixed64(self, field_number, value):
+ """Appends a signed 64-bit integer to our buffer, in little-endian
+ byte-order.
+ """
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
+ self.AppendSFixed64NoTag(value)
+
+ def AppendFloat(self, field_number, value):
+ """Appends a floating-point number to our buffer."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32)
+ self.AppendFloatNoTag(value)
+
+ def AppendDouble(self, field_number, value):
+ """Appends a double-precision floating-point number to our buffer."""
+ self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64)
+ self.AppendDoubleNoTag(value)
+
def AppendBool(self, field_number, value):
"""Appends a boolean to our buffer."""
self.AppendInt32(field_number, value)
@@ -159,7 +231,7 @@ class Encoder(object):
"""Appends a length-prefixed sequence of bytes to our buffer, with the
length varint-encoded.
"""
- self._AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
self._stream.AppendVarUInt32(len(value))
self._stream.AppendRawBytes(value)
@@ -174,14 +246,14 @@ class Encoder(object):
def AppendGroup(self, field_number, group):
"""Appends a group to our buffer.
"""
- self._AppendTag(field_number, wire_format.WIRETYPE_START_GROUP)
+ self.AppendTag(field_number, wire_format.WIRETYPE_START_GROUP)
self._stream.AppendRawBytes(group.SerializeToString())
- self._AppendTag(field_number, wire_format.WIRETYPE_END_GROUP)
+ self.AppendTag(field_number, wire_format.WIRETYPE_END_GROUP)
def AppendMessage(self, field_number, msg):
"""Appends a nested message to our buffer.
"""
- self._AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
self._stream.AppendVarUInt32(msg.ByteSize())
self._stream.AppendRawBytes(msg.SerializeToString())
@@ -196,11 +268,11 @@ class Encoder(object):
}
}
"""
- self._AppendTag(1, wire_format.WIRETYPE_START_GROUP)
+ self.AppendTag(1, wire_format.WIRETYPE_START_GROUP)
self.AppendInt32(2, field_number)
self.AppendMessage(3, msg)
- self._AppendTag(1, wire_format.WIRETYPE_END_GROUP)
+ self.AppendTag(1, wire_format.WIRETYPE_END_GROUP)
- def _AppendTag(self, field_number, wire_type):
+ def AppendTag(self, field_number, wire_type):
"""Appends a tag containing field number and wire type information."""
self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type))
diff --git a/python/google/protobuf/internal/encoder_test.py b/python/google/protobuf/internal/encoder_test.py
index 61668223..83a21c39 100755
--- a/python/google/protobuf/internal/encoder_test.py
+++ b/python/google/protobuf/internal/encoder_test.py
@@ -59,7 +59,8 @@ class EncoderTest(unittest.TestCase):
def AppendScalarTestHelper(self, test_name, encoder_method,
expected_stream_method_name,
wire_type, field_value,
- expected_value=None, expected_length=None):
+ expected_value=None, expected_length=None,
+ is_tag_test=True):
"""Helper for testAppendScalars.
Calls one of the Encoder methods, and ensures that the Encoder
@@ -67,9 +68,10 @@ class EncoderTest(unittest.TestCase):
Args:
test_name: Name of this test, used only for logging.
- encoder_method: Callable on self.encoder, which should
- accept |field_value| as an argument. This is the Encoder
- method we're testing.
+ encoder_method: Callable on self.encoder. This is the Encoder
+ method we're testing. If is_tag_test=True, the encoder method
+ accepts a field_number and field_value. if is_tag_test=False,
+ the encoder method accepts a field_value.
expected_stream_method_name: (string) Name of the OutputStream
method we expect Encoder to call to actually put the value
on the wire.
@@ -83,6 +85,9 @@ class EncoderTest(unittest.TestCase):
expected_length: The length we expect Encoder to pass to the
AppendVarUInt32 method. If None we expect the length of the
field_value.
+ is_tag_test: A Boolean. If True (the default), we append the
+ the packed field number and wire_type to the stream before
+ the field value.
"""
if expected_value is None:
expected_value = field_value
@@ -93,14 +98,16 @@ class EncoderTest(unittest.TestCase):
test_name, encoder_method, field_value,
expected_stream_method_name, expected_value))
- field_number = 10
- # Should first append the field number and type information.
- self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type))
- # If we're length-delimited, we should then append the length.
- if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
- if expected_length is None:
- expected_length = len(field_value)
- self.mock_stream.AppendVarUInt32(expected_length)
+ if is_tag_test:
+ field_number = 10
+ # Should first append the field number and type information.
+ self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type))
+ # If we're length-delimited, we should then append the length.
+ if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
+ if expected_length is None:
+ expected_length = len(field_value)
+ self.mock_stream.AppendVarUInt32(expected_length)
+
# Should then append the value itself.
# We have to use names instead of methods to work around some
# mox weirdness. (ResetAll() is overzealous).
@@ -109,7 +116,10 @@ class EncoderTest(unittest.TestCase):
expected_stream_method(expected_value)
self.mox.ReplayAll()
- encoder_method(field_number, field_value)
+ if is_tag_test:
+ encoder_method(field_number, field_value)
+ else:
+ encoder_method(field_value)
self.mox.VerifyAll()
self.mox.ResetAll()
@@ -160,6 +170,40 @@ class EncoderTest(unittest.TestCase):
for args in scalar_tests:
self.AppendScalarTestHelper(*args)
+ def testAppendScalarsWithoutTags(self):
+ scalar_no_tag_tests = [
+ ['int32', self.encoder.AppendInt32NoTag, 'AppendVarint32', None, 0],
+ ['int64', self.encoder.AppendInt64NoTag, 'AppendVarint64', None, 0],
+ ['uint32', self.encoder.AppendUInt32NoTag, 'AppendVarUInt32', None, 0],
+ ['uint64', self.encoder.AppendUInt64NoTag, 'AppendVarUInt64', None, 0],
+ ['fixed32', self.encoder.AppendFixed32NoTag,
+ 'AppendLittleEndian32', None, 0],
+ ['fixed64', self.encoder.AppendFixed64NoTag,
+ 'AppendLittleEndian64', None, 0],
+ ['sfixed32', self.encoder.AppendSFixed32NoTag,
+ 'AppendLittleEndian32', None, 0],
+ ['sfixed64', self.encoder.AppendSFixed64NoTag,
+ 'AppendLittleEndian64', None, 0],
+ ['float', self.encoder.AppendFloatNoTag,
+ 'AppendRawBytes', None, 0.0, struct.pack('f', 0.0)],
+ ['double', self.encoder.AppendDoubleNoTag,
+ 'AppendRawBytes', None, 0.0, struct.pack('d', 0.0)],
+ ['bool', self.encoder.AppendBoolNoTag, 'AppendVarint32', None, 0],
+ ['enum', self.encoder.AppendEnumNoTag, 'AppendVarint32', None, 0],
+ ['sint32', self.encoder.AppendSInt32NoTag,
+ 'AppendVarUInt32', None, -1, 1],
+ ['sint64', self.encoder.AppendSInt64NoTag,
+ 'AppendVarUInt64', None, -1, 1],
+ ]
+
+ self.assertEqual(len(scalar_no_tag_tests),
+ len(set(t[0] for t in scalar_no_tag_tests)))
+ self.assert_(len(scalar_no_tag_tests) >=
+ len(set(t[1] for t in scalar_no_tag_tests)))
+ for args in scalar_no_tag_tests:
+ # For no tag tests, the wire_type is not used, so we put in None.
+ self.AppendScalarTestHelper(is_tag_test=False, *args)
+
def testAppendGroup(self):
field_number = 23
# Should first append the start-group marker.
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 14062762..e405f60b 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -229,8 +229,8 @@ class ReflectionTest(unittest.TestCase):
proto.repeated_fixed32.append(1)
proto.repeated_int32.append(5)
proto.repeated_int32.append(11)
- proto.repeated_string.append('foo')
- proto.repeated_string.append('bar')
+ proto.repeated_string.extend(['foo', 'bar'])
+ proto.repeated_string.extend([])
proto.repeated_string.append('baz')
proto.optional_int32 = 21
self.assertEqual(
@@ -757,6 +757,16 @@ class ReflectionTest(unittest.TestCase):
self.assertRaises(KeyError, extendee_proto.HasExtension,
unittest_pb2.repeated_string_extension)
+ def testStaticParseFrom(self):
+ proto1 = unittest_pb2.TestAllTypes()
+ test_util.SetAllFields(proto1)
+
+ string1 = proto1.SerializeToString()
+ proto2 = unittest_pb2.TestAllTypes.FromString(string1)
+
+ # Messages should be equal.
+ self.assertEqual(proto2, proto1)
+
def testMergeFromSingularField(self):
# Test merge with just a singular field.
proto1 = unittest_pb2.TestAllTypes()
@@ -1209,6 +1219,8 @@ class ByteSizeTest(unittest.TestCase):
def setUp(self):
self.proto = unittest_pb2.TestAllTypes()
self.extended_proto = more_extensions_pb2.ExtendedMessage()
+ self.packed_proto = unittest_pb2.TestPackedTypes()
+ self.packed_extended_proto = unittest_pb2.TestPackedExtensions()
def Size(self):
return self.proto.ByteSize()
@@ -1291,6 +1303,11 @@ class ByteSizeTest(unittest.TestCase):
# Also need 2 bytes for each entry for tag.
self.assertEqual(1 + 2 + 2*2, self.Size())
+ def testRepeatedScalarsExtend(self):
+ self.proto.repeated_int32.extend([10, 128]) # 3 bytes.
+ # Also need 2 bytes for each entry for tag.
+ self.assertEqual(1 + 2 + 2*2, self.Size())
+
def testRepeatedScalarsRemove(self):
self.proto.repeated_int32.append(10) # 1 byte.
self.proto.repeated_int32.append(128) # 2 bytes.
@@ -1443,6 +1460,33 @@ class ByteSizeTest(unittest.TestCase):
self.extended_proto.ClearExtension(extension)
self.assertEqual(0, self.extended_proto.ByteSize())
+ def testPackedRepeatedScalars(self):
+ self.assertEqual(0, self.packed_proto.ByteSize())
+
+ self.packed_proto.packed_int32.append(10) # 1 byte.
+ self.packed_proto.packed_int32.append(128) # 2 bytes.
+ # The tag is 2 bytes (the field number is 90), and the varint
+ # storing the length is 1 byte.
+ int_size = 1 + 2 + 3
+ self.assertEqual(int_size, self.packed_proto.ByteSize())
+
+ self.packed_proto.packed_double.append(4.2) # 8 bytes
+ self.packed_proto.packed_double.append(3.25) # 8 bytes
+ # 2 more tag bytes, 1 more length byte.
+ double_size = 8 + 8 + 3
+ self.assertEqual(int_size+double_size, self.packed_proto.ByteSize())
+
+ self.packed_proto.ClearField('packed_int32')
+ self.assertEqual(double_size, self.packed_proto.ByteSize())
+
+ def testPackedExtensions(self):
+ self.assertEqual(0, self.packed_extended_proto.ByteSize())
+ extension = self.packed_extended_proto.Extensions[
+ unittest_pb2.packed_fixed32_extension]
+ extension.extend([1, 2, 3, 4]) # 16 bytes
+ # Tag is 3 bytes.
+ self.assertEqual(19, self.packed_extended_proto.ByteSize())
+
# TODO(robinson): We need cross-language serialization consistency tests.
# Issues to be sure to cover include:
@@ -1686,6 +1730,63 @@ class SerializationTest(unittest.TestCase):
self.assertEqual(2, proto2.b)
self.assertEqual(3, proto2.c)
+ def testSerializedAllPackedFields(self):
+ first_proto = unittest_pb2.TestPackedTypes()
+ second_proto = unittest_pb2.TestPackedTypes()
+ test_util.SetAllPackedFields(first_proto)
+ serialized = first_proto.SerializeToString()
+ self.assertEqual(first_proto.ByteSize(), len(serialized))
+ second_proto.MergeFromString(serialized)
+ self.assertEqual(first_proto, second_proto)
+
+ def testSerializeAllPackedExtensions(self):
+ first_proto = unittest_pb2.TestPackedExtensions()
+ second_proto = unittest_pb2.TestPackedExtensions()
+ test_util.SetAllPackedExtensions(first_proto)
+ serialized = first_proto.SerializeToString()
+ second_proto.MergeFromString(serialized)
+ self.assertEqual(first_proto, second_proto)
+
+ def testMergePackedFromStringWhenSomeFieldsAlreadySet(self):
+ first_proto = unittest_pb2.TestPackedTypes()
+ first_proto.packed_int32.extend([1, 2])
+ first_proto.packed_double.append(3.0)
+ serialized = first_proto.SerializeToString()
+
+ second_proto = unittest_pb2.TestPackedTypes()
+ second_proto.packed_int32.append(3)
+ second_proto.packed_double.extend([1.0, 2.0])
+ second_proto.packed_sint32.append(4)
+
+ second_proto.MergeFromString(serialized)
+ self.assertEqual([3, 1, 2], second_proto.packed_int32)
+ self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double)
+ self.assertEqual([4], second_proto.packed_sint32)
+
+ def testPackedFieldsWireFormat(self):
+ proto = unittest_pb2.TestPackedTypes()
+ proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes
+ proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes
+ proto.packed_float.append(2.0) # 4 bytes, will be before double
+ serialized = proto.SerializeToString()
+ self.assertEqual(proto.ByteSize(), len(serialized))
+ d = decoder.Decoder(serialized)
+ ReadTag = d.ReadFieldNumberAndWireType
+ self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
+ self.assertEqual(1+1+1+2, d.ReadInt32())
+ self.assertEqual(1, d.ReadInt32())
+ self.assertEqual(2, d.ReadInt32())
+ self.assertEqual(150, d.ReadInt32())
+ self.assertEqual(3, d.ReadInt32())
+ self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
+ self.assertEqual(4, d.ReadInt32())
+ self.assertEqual(2.0, d.ReadFloat())
+ self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag())
+ self.assertEqual(8+8, d.ReadInt32())
+ self.assertEqual(1.0, d.ReadDouble())
+ self.assertEqual(1000.0, d.ReadDouble())
+ self.assertTrue(d.EndOfStream())
+
class OptionsTest(unittest.TestCase):
@@ -1697,6 +1798,21 @@ class OptionsTest(unittest.TestCase):
self.assertEqual(False,
proto.DESCRIPTOR.GetOptions().message_set_wire_format)
+ def testPackedOptions(self):
+ proto = unittest_pb2.TestAllTypes()
+ proto.optional_int32 = 1
+ proto.optional_double = 3.0
+ for field_descriptor, _ in proto.ListFields():
+ self.assertEqual(False, field_descriptor.GetOptions().packed)
+
+ proto = unittest_pb2.TestPackedTypes()
+ proto.packed_int32.append(1)
+ proto.packed_double.append(3.0)
+ for field_descriptor, _ in proto.ListFields():
+ self.assertEqual(True, field_descriptor.GetOptions().packed)
+ self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED,
+ field_descriptor.label)
+
class UtilityTest(unittest.TestCase):
diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py
index 29492e16..e04f8252 100755
--- a/python/google/protobuf/internal/service_reflection_test.py
+++ b/python/google/protobuf/internal/service_reflection_test.py
@@ -74,7 +74,7 @@ class FooUnitTest(unittest.TestCase):
rpc_controller.failure_message = None
- service_descriptor = unittest_pb2.TestService.DESCRIPTOR
+ service_descriptor = unittest_pb2.TestService.GetDescriptor()
srvc.CallMethod(service_descriptor.methods[1], rpc_controller,
unittest_pb2.BarRequest(), MyCallback)
self.assertEqual('Method Bar not implemented.',
@@ -118,6 +118,10 @@ class FooUnitTest(unittest.TestCase):
rpc_controller = 'controller'
request = 'request'
+ # GetDescriptor now static, still works as instance method for compatability
+ self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(),
+ stub.GetDescriptor())
+
# Invoke method.
stub.Foo(rpc_controller, request, MyCallback)
diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py
index 14dfbc51..2d50bc4a 100755
--- a/python/google/protobuf/internal/test_util.py
+++ b/python/google/protobuf/internal/test_util.py
@@ -366,3 +366,51 @@ def GoldenFile(filename):
'Could not find golden files. This test must be run from within the '
'protobuf source package so that it can read test data files from the '
'C++ source tree.')
+
+
+def SetAllPackedFields(message):
+ """Sets every field in the message to a unique value.
+
+ Args:
+ message: A unittest_pb2.TestPackedTypes instance.
+ """
+ message.packed_int32.extend([101, 102])
+ message.packed_int64.extend([103, 104])
+ message.packed_uint32.extend([105, 106])
+ message.packed_uint64.extend([107, 108])
+ message.packed_sint32.extend([109, 110])
+ message.packed_sint64.extend([111, 112])
+ message.packed_fixed32.extend([113, 114])
+ message.packed_fixed64.extend([115, 116])
+ message.packed_sfixed32.extend([117, 118])
+ message.packed_sfixed64.extend([119, 120])
+ message.packed_float.extend([121.0, 122.0])
+ message.packed_double.extend([122.0, 123.0])
+ message.packed_bool.extend([True, False])
+ message.packed_enum.extend([unittest_pb2.FOREIGN_FOO,
+ unittest_pb2.FOREIGN_BAR])
+
+
+def SetAllPackedExtensions(message):
+ """Sets every extension in the message to a unique value.
+
+ Args:
+ message: A unittest_pb2.TestPackedExtensions instance.
+ """
+ extensions = message.Extensions
+ pb2 = unittest_pb2
+
+ extensions[pb2.packed_int32_extension].append(101)
+ extensions[pb2.packed_int64_extension].append(102)
+ extensions[pb2.packed_uint32_extension].append(103)
+ extensions[pb2.packed_uint64_extension].append(104)
+ extensions[pb2.packed_sint32_extension].append(105)
+ extensions[pb2.packed_sint64_extension].append(106)
+ extensions[pb2.packed_fixed32_extension].append(107)
+ extensions[pb2.packed_fixed64_extension].append(108)
+ extensions[pb2.packed_sfixed32_extension].append(109)
+ extensions[pb2.packed_sfixed64_extension].append(110)
+ extensions[pb2.packed_float_extension].append(111.0)
+ extensions[pb2.packed_double_extension].append(112.0)
+ extensions[pb2.packed_bool_extension].append(True)
+ extensions[pb2.packed_enum_extension].append(pb2.FOREIGN_BAZ)
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index ba470cf2..c009627f 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -216,6 +216,23 @@ TYPE_TO_SERIALIZE_METHOD = {
}
+TYPE_TO_NOTAG_SERIALIZE_METHOD = {
+ _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDoubleNoTag,
+ _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloatNoTag,
+ _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64NoTag,
+ _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64NoTag,
+ _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32NoTag,
+ _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64NoTag,
+ _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32NoTag,
+ _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBoolNoTag,
+ _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32NoTag,
+ _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnumNoTag,
+ _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32NoTag,
+ _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64NoTag,
+ _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32NoTag,
+ _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64NoTag,
+ }
+
# Maps from field type to expected wiretype.
FIELD_TYPE_TO_WIRE_TYPE = {
_FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64,
diff --git a/python/google/protobuf/internal/wire_format.py b/python/google/protobuf/internal/wire_format.py
index 5f0af11e..531c9b85 100755
--- a/python/google/protobuf/internal/wire_format.py
+++ b/python/google/protobuf/internal/wire_format.py
@@ -120,6 +120,10 @@ def Int32ByteSize(field_number, int32):
return Int64ByteSize(field_number, int32)
+def Int32ByteSizeNoTag(int32):
+ return _VarUInt64ByteSizeNoTag(0xffffffffffffffff & int32)
+
+
def Int64ByteSize(field_number, int64):
# Have to convert to uint before calling UInt64ByteSize().
return UInt64ByteSize(field_number, 0xffffffffffffffff & int64)
@@ -130,7 +134,7 @@ def UInt32ByteSize(field_number, uint32):
def UInt64ByteSize(field_number, uint64):
- return _TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64)
+ return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64)
def SInt32ByteSize(field_number, int32):
@@ -142,31 +146,31 @@ def SInt64ByteSize(field_number, int64):
def Fixed32ByteSize(field_number, fixed32):
- return _TagByteSize(field_number) + 4
+ return TagByteSize(field_number) + 4
def Fixed64ByteSize(field_number, fixed64):
- return _TagByteSize(field_number) + 8
+ return TagByteSize(field_number) + 8
def SFixed32ByteSize(field_number, sfixed32):
- return _TagByteSize(field_number) + 4
+ return TagByteSize(field_number) + 4
def SFixed64ByteSize(field_number, sfixed64):
- return _TagByteSize(field_number) + 8
+ return TagByteSize(field_number) + 8
def FloatByteSize(field_number, flt):
- return _TagByteSize(field_number) + 4
+ return TagByteSize(field_number) + 4
def DoubleByteSize(field_number, double):
- return _TagByteSize(field_number) + 8
+ return TagByteSize(field_number) + 8
def BoolByteSize(field_number, b):
- return _TagByteSize(field_number) + 1
+ return TagByteSize(field_number) + 1
def EnumByteSize(field_number, enum):
@@ -178,18 +182,18 @@ def StringByteSize(field_number, string):
def BytesByteSize(field_number, b):
- return (_TagByteSize(field_number)
+ return (TagByteSize(field_number)
+ _VarUInt64ByteSizeNoTag(len(b))
+ len(b))
def GroupByteSize(field_number, message):
- return (2 * _TagByteSize(field_number) # START and END group.
+ return (2 * TagByteSize(field_number) # START and END group.
+ message.ByteSize())
def MessageByteSize(field_number, message):
- return (_TagByteSize(field_number)
+ return (TagByteSize(field_number)
+ _VarUInt64ByteSizeNoTag(message.ByteSize())
+ message.ByteSize())
@@ -199,7 +203,7 @@ def MessageSetItemByteSize(field_number, msg):
# There are 2 tags for the beginning and ending of the repeated group, that
# is field number 1, one with field number 2 (type_id) and one with field
# number 3 (message).
- total_size = (2 * _TagByteSize(1) + _TagByteSize(2) + _TagByteSize(3))
+ total_size = (2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3))
# Add the number of bytes for type_id.
total_size += _VarUInt64ByteSizeNoTag(field_number)
@@ -214,15 +218,14 @@ def MessageSetItemByteSize(field_number, msg):
return total_size
-# Private helper functions for the *ByteSize() functions above.
-
-
-def _TagByteSize(field_number):
+def TagByteSize(field_number):
"""Returns the bytes required to serialize a tag with this field number."""
# Just pass in type 0, since the type won't affect the tag+type size.
return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0))
+# Private helper function for the *ByteSize() functions above.
+
def _VarUInt64ByteSizeNoTag(uint64):
"""Returns the bytes required to serialize a single varint.
uint64 must be unsigned.
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index 9ba752e4..0d5191be 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -462,6 +462,12 @@ def _AddStaticMethods(cls):
cls._known_extensions.append(extension_handle)
cls.RegisterExtension = staticmethod(RegisterExtension)
+ def FromString(s):
+ message = cls()
+ message.MergeFromString(s)
+ return message
+ cls.FromString = staticmethod(FromString)
+
def _AddListFieldsMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
@@ -665,9 +671,36 @@ def _AddByteSizeMethod(message_descriptor, cls):
else:
elements = [value]
- size = sum(_BytesForNonRepeatedElement(element, field_number, field_type)
- for element in elements)
- return size
+ if field.GetOptions().packed:
+ content_size = _ContentBytesForPackedField(message, field, elements)
+ if content_size:
+ tag_size = wire_format.TagByteSize(field_number)
+ length_size = wire_format.Int32ByteSizeNoTag(content_size)
+ return tag_size + length_size + content_size
+ else:
+ return 0
+ else:
+ return sum(_BytesForNonRepeatedElement(element, field_number, field_type)
+ for element in elements)
+
+ def _ContentBytesForPackedField(self, field, value):
+ """Returns the number of bytes required to serialize the actual
+ content of a packed field (not including the tag or the encoding
+ of the length.
+
+ Args:
+ self: The Message instance containing a field of the given type.
+ field: A FieldDescriptor describing the field of interest.
+ value: The value whose byte size we're interested in.
+
+ Returns: The number of bytes required to serialize the current value
+ of the packed "field" in "message", excluding space for tags and the
+ length encoding.
+ """
+ size = sum(_BytesForNonRepeatedElement(element, field.number, field.type)
+ for element in value)
+ # In the packed case, there are no per element tags.
+ return size - wire_format.TagByteSize(field.number) * len(value)
fields = message_descriptor.fields
has_field_names = (_HasFieldName(f.name) for f in fields)
@@ -691,6 +724,8 @@ def _AddByteSizeMethod(message_descriptor, cls):
self._cached_byte_size = size
self._cached_byte_size_dirty = False
return size
+
+ cls._ContentBytesForPackedField = _ContentBytesForPackedField
cls.ByteSize = ByteSize
@@ -788,10 +823,29 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
repeated_value = field_value
else:
repeated_value = [field_value]
- for element in repeated_value:
- _SerializeValueToEncoder(element, field_descriptor.number,
- field_descriptor, encoder)
+ if field_descriptor.GetOptions().packed:
+ # First, write the field number and WIRETYPE_LENGTH_DELIMITED.
+ field_number = field_descriptor.number
+ encoder.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)
+ # Next, write the number of bytes.
+ content_bytes = self._ContentBytesForPackedField(
+ field_descriptor, field_value)
+ encoder.AppendInt32NoTag(content_bytes)
+ # Finally, write the actual values.
+ try:
+ method = type_checkers.TYPE_TO_NOTAG_SERIALIZE_METHOD[
+ field_descriptor.type]
+ for value in repeated_value:
+ method(encoder, value)
+ except KeyError:
+ raise message_mod.EncodeError('Unrecognized field type: %d' %
+ field_descriptor.type)
+ else:
+ for element in repeated_value:
+ _SerializeValueToEncoder(element, field_descriptor.number,
+ field_descriptor, encoder)
return encoder.ToString()
+
cls.SerializePartialToString = SerializePartialToString
@@ -803,6 +857,14 @@ def _WireTypeForFieldType(field_type):
raise message_mod.DecodeError('Unknown field type: %d' % field_type)
+def _WireTypeForField(field_descriptor):
+ """Given a field descriptor, returns the expected wire type."""
+ if field_descriptor.GetOptions().packed:
+ return wire_format.WIRETYPE_LENGTH_DELIMITED
+ else:
+ return _WireTypeForFieldType(field_descriptor.type)
+
+
def _RecursivelyMerge(field_number, field_type, decoder, message):
"""Decodes a message from decoder into message.
message is either a group or a nested message within some containing
@@ -918,9 +980,11 @@ def _DeserializeMessageSetItem(message, decoder):
def _DeserializeOneEntity(message_descriptor, message, decoder):
"""Deserializes the next wire entity from decoder into message.
- The next wire entity is either a scalar or a nested message,
- and may also be an element in a repeated field (the wire encoding
- is the same).
+
+ The next wire entity is either a scalar or a nested message, an
+ element in a repeated field (the wire encoding in this case is the
+ same), or a packed repeated field (in this case, the entire repeated
+ field is read by a single call to _DeserializeOneEntity).
Args:
message_descriptor: A Descriptor instance describing all fields
@@ -973,14 +1037,14 @@ def _DeserializeOneEntity(message_descriptor, message, decoder):
# if this field is a nonrepeated scalar.
field_number = field_descriptor.number
- field_type = field_descriptor.type
- expected_wire_type = _WireTypeForFieldType(field_type)
+ expected_wire_type = _WireTypeForField(field_descriptor)
if wire_type != expected_wire_type:
# Need to fill in uninterpreted_bytes. Work for the next CL.
raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.')
property_name = _PropertyName(field_descriptor.name)
label = field_descriptor.label
+ field_type = field_descriptor.type
cpp_type = field_descriptor.cpp_type
# Nonrepeated scalar. Just set the field directly.
@@ -1000,8 +1064,17 @@ def _DeserializeOneEntity(message_descriptor, message, decoder):
if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE:
# Repeated scalar.
- element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
- return decoder.Position() - initial_position
+ if not field_descriptor.GetOptions().packed:
+ element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
+ return decoder.Position() - initial_position
+ else:
+ # Packed repeated field.
+ length = _DeserializeScalarFromDecoder(
+ _FieldDescriptor.TYPE_INT32, decoder)
+ content_start = decoder.Position()
+ while decoder.Position() - content_start < length:
+ element_list.append(_DeserializeScalarFromDecoder(field_type, decoder))
+ return decoder.Position() - content_start
else:
# Repeated composite.
composite = element_list.add()
diff --git a/python/google/protobuf/service.py b/python/google/protobuf/service.py
index 3989216a..9ec42fe3 100755
--- a/python/google/protobuf/service.py
+++ b/python/google/protobuf/service.py
@@ -31,7 +31,7 @@
"""Declares the RPC service interfaces.
This module declares the abstract interfaces underlying proto2 RPC
-services. These are intented to be independent of any particular RPC
+services. These are intended to be independent of any particular RPC
implementation, so that proto2 services can be used on top of a variety
of implementations.
"""
@@ -39,6 +39,11 @@ of implementations.
__author__ = 'petar@google.com (Petar Petrov)'
+class RpcException(Exception):
+ """Exception raised on failed blocking RPC method call."""
+ pass
+
+
class Service(object):
"""Abstract base interface for protocol-buffer-based RPC services.
@@ -49,7 +54,7 @@ class Service(object):
its exact type at compile time (analogous to the Message interface).
"""
- def GetDescriptor(self):
+ def GetDescriptor():
"""Retrieves this service's descriptor."""
raise NotImplementedError
@@ -57,6 +62,14 @@ class Service(object):
request, done):
"""Calls a method of the service specified by method_descriptor.
+ If "done" is None then the call is blocking and the response
+ message will be returned directly. Otherwise the call is asynchronous
+ and "done" will later be called with the response value.
+
+ In the blocking case, RpcException will be raised on error.
+ Asynchronous calls must check status via the Failed method of the
+ RpcController.
+
Preconditions:
* method_descriptor.service == GetDescriptor
* request is of the exact same classes as returned by
diff --git a/python/google/protobuf/service_reflection.py b/python/google/protobuf/service_reflection.py
index bdd6bad5..851e83e7 100755
--- a/python/google/protobuf/service_reflection.py
+++ b/python/google/protobuf/service_reflection.py
@@ -142,24 +142,17 @@ class _ServiceBuilder(object):
# instance to the method that does the real CallMethod work.
def _WrapCallMethod(srvc, method_descriptor,
rpc_controller, request, callback):
- self._CallMethod(srvc, method_descriptor,
+ return self._CallMethod(srvc, method_descriptor,
rpc_controller, request, callback)
self.cls = cls
cls.CallMethod = _WrapCallMethod
- cls.GetDescriptor = self._GetDescriptor
+ cls.GetDescriptor = staticmethod(lambda: self.descriptor)
+ cls.GetDescriptor.__doc__ = "Returns the service descriptor."
cls.GetRequestClass = self._GetRequestClass
cls.GetResponseClass = self._GetResponseClass
for method in self.descriptor.methods:
setattr(cls, method.name, self._GenerateNonImplementedMethod(method))
- def _GetDescriptor(self):
- """Retrieves the service descriptor.
-
- Returns:
- The descriptor of the service (of type ServiceDescriptor).
- """
- return self.descriptor
-
def _CallMethod(self, srvc, method_descriptor,
rpc_controller, request, callback):
"""Calls the method described by a given method descriptor.
@@ -175,7 +168,7 @@ class _ServiceBuilder(object):
raise RuntimeError(
'CallMethod() given method descriptor for wrong service type.')
method = getattr(srvc, method_descriptor.name)
- method(rpc_controller, request, callback)
+ return method(rpc_controller, request, callback)
def _GetRequestClass(self, method_descriptor):
"""Returns the class of the request protocol message.
@@ -270,8 +263,8 @@ class _ServiceStubBuilder(object):
setattr(cls, method.name, self._GenerateStubMethod(method))
def _GenerateStubMethod(self, method):
- return lambda inst, rpc_controller, request, callback: self._StubMethod(
- inst, method, rpc_controller, request, callback)
+ return (lambda inst, rpc_controller, request, callback=None:
+ self._StubMethod(inst, method, rpc_controller, request, callback))
def _StubMethod(self, stub, method_descriptor,
rpc_controller, request, callback):
@@ -283,7 +276,9 @@ class _ServiceStubBuilder(object):
rpc_controller: Rpc controller to execute the method.
request: Request protocol message.
callback: A callback to execute when the method finishes.
+ Returns:
+ Response message (in case of blocking call).
"""
- stub.rpc_channel.CallMethod(
+ return stub.rpc_channel.CallMethod(
method_descriptor, rpc_controller, request,
method_descriptor.output_type._concrete_class, callback)