diff options
Diffstat (limited to 'python')
-rwxr-xr-x | python/google/protobuf/internal/containers.py | 13 | ||||
-rwxr-xr-x | python/google/protobuf/internal/decoder_test.py | 2 | ||||
-rwxr-xr-x | python/google/protobuf/internal/encoder.py | 140 | ||||
-rwxr-xr-x | python/google/protobuf/internal/encoder_test.py | 70 | ||||
-rwxr-xr-x | python/google/protobuf/internal/reflection_test.py | 120 | ||||
-rwxr-xr-x | python/google/protobuf/internal/service_reflection_test.py | 6 | ||||
-rwxr-xr-x | python/google/protobuf/internal/test_util.py | 48 | ||||
-rwxr-xr-x | python/google/protobuf/internal/type_checkers.py | 17 | ||||
-rwxr-xr-x | python/google/protobuf/internal/wire_format.py | 35 | ||||
-rwxr-xr-x | python/google/protobuf/reflection.py | 99 | ||||
-rwxr-xr-x | python/google/protobuf/service.py | 17 | ||||
-rwxr-xr-x | python/google/protobuf/service_reflection.py | 23 |
12 files changed, 494 insertions, 96 deletions
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 14fe863e..fa1e3402 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -106,6 +106,19 @@ class RepeatedScalarFieldContainer(BaseContainer): if len(self._values) == 1: self._message_listener.TransitionToNonempty() + def extend(self, elem_seq): + """Extends by appending the given sequence. Similar to list.extend().""" + if not elem_seq: + return + + orig_empty = len(self._values) == 0 + for elem in elem_seq: + self._type_checker.CheckValue(elem) + self._values.extend(elem_seq) + self._message_listener.ByteSizeDirty() + if orig_empty: + self._message_listener.TransitionToNonempty() + def remove(self, elem): """Removes an item from the list. Similar to list.remove().""" self._values.remove(elem) diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py index abcc07fc..e186e14d 100755 --- a/python/google/protobuf/internal/decoder_test.py +++ b/python/google/protobuf/internal/decoder_test.py @@ -57,7 +57,7 @@ class DecoderTest(unittest.TestCase): for expected_field_number in (1, 15, 16, 2047, 2048): for expected_wire_type in range(6): # Highest-numbered wiretype is 5. e = encoder.Encoder() - e._AppendTag(expected_field_number, expected_wire_type) + e.AppendTag(expected_field_number, expected_wire_type) s = e.ToString() d = decoder.Decoder(s) field_number, wire_type = d.ReadFieldNumberAndWireType() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index 7071241c..eed8c8bd 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -58,89 +58,161 @@ class Encoder(object): """Returns all values encoded in this object as a string.""" return self._stream.ToString() - # All the Append*() methods below first append a tag+type pair to the buffer - # before appending the specified value. - - def AppendInt32(self, field_number, value): + # Append*NoTag methods. These are necessary for serializing packed + # repeated fields. The Append*() methods call these methods to do + # the actual serialization. + def AppendInt32NoTag(self, value): """Appends a 32-bit integer to our buffer, varint-encoded.""" - self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) self._stream.AppendVarint32(value) - def AppendInt64(self, field_number, value): + def AppendInt64NoTag(self, value): """Appends a 64-bit integer to our buffer, varint-encoded.""" - self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) self._stream.AppendVarint64(value) - def AppendUInt32(self, field_number, unsigned_value): + def AppendUInt32NoTag(self, unsigned_value): """Appends an unsigned 32-bit integer to our buffer, varint-encoded.""" - self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) self._stream.AppendVarUInt32(unsigned_value) - def AppendUInt64(self, field_number, unsigned_value): + def AppendUInt64NoTag(self, unsigned_value): """Appends an unsigned 64-bit integer to our buffer, varint-encoded.""" - self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) self._stream.AppendVarUInt64(unsigned_value) - def AppendSInt32(self, field_number, value): + def AppendSInt32NoTag(self, value): """Appends a 32-bit integer to our buffer, zigzag-encoded and then varint-encoded. """ - self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) zigzag_value = wire_format.ZigZagEncode(value) self._stream.AppendVarUInt32(zigzag_value) - def AppendSInt64(self, field_number, value): + def AppendSInt64NoTag(self, value): """Appends a 64-bit integer to our buffer, zigzag-encoded and then varint-encoded. """ - self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) zigzag_value = wire_format.ZigZagEncode(value) self._stream.AppendVarUInt64(zigzag_value) - def AppendFixed32(self, field_number, unsigned_value): + def AppendFixed32NoTag(self, unsigned_value): """Appends an unsigned 32-bit integer to our buffer, in little-endian byte-order. """ - self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32) self._stream.AppendLittleEndian32(unsigned_value) - def AppendFixed64(self, field_number, unsigned_value): + def AppendFixed64NoTag(self, unsigned_value): """Appends an unsigned 64-bit integer to our buffer, in little-endian byte-order. """ - self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64) self._stream.AppendLittleEndian64(unsigned_value) - def AppendSFixed32(self, field_number, value): + def AppendSFixed32NoTag(self, value): """Appends a signed 32-bit integer to our buffer, in little-endian byte-order. """ sign = (value & 0x80000000) and -1 or 0 if value >> 32 != sign: raise message.EncodeError('SFixed32 out of range: %d' % value) - self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32) self._stream.AppendLittleEndian32(value & 0xffffffff) - def AppendSFixed64(self, field_number, value): + def AppendSFixed64NoTag(self, value): """Appends a signed 64-bit integer to our buffer, in little-endian byte-order. """ sign = (value & 0x8000000000000000) and -1 or 0 if value >> 64 != sign: raise message.EncodeError('SFixed64 out of range: %d' % value) - self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64) self._stream.AppendLittleEndian64(value & 0xffffffffffffffff) - def AppendFloat(self, field_number, value): + def AppendFloatNoTag(self, value): """Appends a floating-point number to our buffer.""" - self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32) self._stream.AppendRawBytes(struct.pack('f', value)) - def AppendDouble(self, field_number, value): + def AppendDoubleNoTag(self, value): """Appends a double-precision floating-point number to our buffer.""" - self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64) self._stream.AppendRawBytes(struct.pack('d', value)) + def AppendBoolNoTag(self, value): + """Appends a boolean to our buffer.""" + self.AppendInt32NoTag(value) + + def AppendEnumNoTag(self, value): + """Appends an enum value to our buffer.""" + self.AppendInt32NoTag(value) + + + # All the Append*() methods below first append a tag+type pair to the buffer + # before appending the specified value. + + def AppendInt32(self, field_number, value): + """Appends a 32-bit integer to our buffer, varint-encoded.""" + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendInt32NoTag(value) + + def AppendInt64(self, field_number, value): + """Appends a 64-bit integer to our buffer, varint-encoded.""" + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendInt64NoTag(value) + + def AppendUInt32(self, field_number, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, varint-encoded.""" + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendUInt32NoTag(unsigned_value) + + def AppendUInt64(self, field_number, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, varint-encoded.""" + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendUInt64NoTag(unsigned_value) + + def AppendSInt32(self, field_number, value): + """Appends a 32-bit integer to our buffer, zigzag-encoded and then + varint-encoded. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendSInt32NoTag(value) + + def AppendSInt64(self, field_number, value): + """Appends a 64-bit integer to our buffer, zigzag-encoded and then + varint-encoded. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self.AppendSInt64NoTag(value) + + def AppendFixed32(self, field_number, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, in little-endian + byte-order. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self.AppendFixed32NoTag(unsigned_value) + + def AppendFixed64(self, field_number, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, in little-endian + byte-order. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self.AppendFixed64NoTag(unsigned_value) + + def AppendSFixed32(self, field_number, value): + """Appends a signed 32-bit integer to our buffer, in little-endian + byte-order. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self.AppendSFixed32NoTag(value) + + def AppendSFixed64(self, field_number, value): + """Appends a signed 64-bit integer to our buffer, in little-endian + byte-order. + """ + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self.AppendSFixed64NoTag(value) + + def AppendFloat(self, field_number, value): + """Appends a floating-point number to our buffer.""" + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self.AppendFloatNoTag(value) + + def AppendDouble(self, field_number, value): + """Appends a double-precision floating-point number to our buffer.""" + self.AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self.AppendDoubleNoTag(value) + def AppendBool(self, field_number, value): """Appends a boolean to our buffer.""" self.AppendInt32(field_number, value) @@ -159,7 +231,7 @@ class Encoder(object): """Appends a length-prefixed sequence of bytes to our buffer, with the length varint-encoded. """ - self._AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) self._stream.AppendVarUInt32(len(value)) self._stream.AppendRawBytes(value) @@ -174,14 +246,14 @@ class Encoder(object): def AppendGroup(self, field_number, group): """Appends a group to our buffer. """ - self._AppendTag(field_number, wire_format.WIRETYPE_START_GROUP) + self.AppendTag(field_number, wire_format.WIRETYPE_START_GROUP) self._stream.AppendRawBytes(group.SerializeToString()) - self._AppendTag(field_number, wire_format.WIRETYPE_END_GROUP) + self.AppendTag(field_number, wire_format.WIRETYPE_END_GROUP) def AppendMessage(self, field_number, msg): """Appends a nested message to our buffer. """ - self._AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + self.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) self._stream.AppendVarUInt32(msg.ByteSize()) self._stream.AppendRawBytes(msg.SerializeToString()) @@ -196,11 +268,11 @@ class Encoder(object): } } """ - self._AppendTag(1, wire_format.WIRETYPE_START_GROUP) + self.AppendTag(1, wire_format.WIRETYPE_START_GROUP) self.AppendInt32(2, field_number) self.AppendMessage(3, msg) - self._AppendTag(1, wire_format.WIRETYPE_END_GROUP) + self.AppendTag(1, wire_format.WIRETYPE_END_GROUP) - def _AppendTag(self, field_number, wire_type): + def AppendTag(self, field_number, wire_type): """Appends a tag containing field number and wire type information.""" self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type)) diff --git a/python/google/protobuf/internal/encoder_test.py b/python/google/protobuf/internal/encoder_test.py index 61668223..83a21c39 100755 --- a/python/google/protobuf/internal/encoder_test.py +++ b/python/google/protobuf/internal/encoder_test.py @@ -59,7 +59,8 @@ class EncoderTest(unittest.TestCase): def AppendScalarTestHelper(self, test_name, encoder_method, expected_stream_method_name, wire_type, field_value, - expected_value=None, expected_length=None): + expected_value=None, expected_length=None, + is_tag_test=True): """Helper for testAppendScalars. Calls one of the Encoder methods, and ensures that the Encoder @@ -67,9 +68,10 @@ class EncoderTest(unittest.TestCase): Args: test_name: Name of this test, used only for logging. - encoder_method: Callable on self.encoder, which should - accept |field_value| as an argument. This is the Encoder - method we're testing. + encoder_method: Callable on self.encoder. This is the Encoder + method we're testing. If is_tag_test=True, the encoder method + accepts a field_number and field_value. if is_tag_test=False, + the encoder method accepts a field_value. expected_stream_method_name: (string) Name of the OutputStream method we expect Encoder to call to actually put the value on the wire. @@ -83,6 +85,9 @@ class EncoderTest(unittest.TestCase): expected_length: The length we expect Encoder to pass to the AppendVarUInt32 method. If None we expect the length of the field_value. + is_tag_test: A Boolean. If True (the default), we append the + the packed field number and wire_type to the stream before + the field value. """ if expected_value is None: expected_value = field_value @@ -93,14 +98,16 @@ class EncoderTest(unittest.TestCase): test_name, encoder_method, field_value, expected_stream_method_name, expected_value)) - field_number = 10 - # Should first append the field number and type information. - self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type)) - # If we're length-delimited, we should then append the length. - if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: - if expected_length is None: - expected_length = len(field_value) - self.mock_stream.AppendVarUInt32(expected_length) + if is_tag_test: + field_number = 10 + # Should first append the field number and type information. + self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type)) + # If we're length-delimited, we should then append the length. + if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: + if expected_length is None: + expected_length = len(field_value) + self.mock_stream.AppendVarUInt32(expected_length) + # Should then append the value itself. # We have to use names instead of methods to work around some # mox weirdness. (ResetAll() is overzealous). @@ -109,7 +116,10 @@ class EncoderTest(unittest.TestCase): expected_stream_method(expected_value) self.mox.ReplayAll() - encoder_method(field_number, field_value) + if is_tag_test: + encoder_method(field_number, field_value) + else: + encoder_method(field_value) self.mox.VerifyAll() self.mox.ResetAll() @@ -160,6 +170,40 @@ class EncoderTest(unittest.TestCase): for args in scalar_tests: self.AppendScalarTestHelper(*args) + def testAppendScalarsWithoutTags(self): + scalar_no_tag_tests = [ + ['int32', self.encoder.AppendInt32NoTag, 'AppendVarint32', None, 0], + ['int64', self.encoder.AppendInt64NoTag, 'AppendVarint64', None, 0], + ['uint32', self.encoder.AppendUInt32NoTag, 'AppendVarUInt32', None, 0], + ['uint64', self.encoder.AppendUInt64NoTag, 'AppendVarUInt64', None, 0], + ['fixed32', self.encoder.AppendFixed32NoTag, + 'AppendLittleEndian32', None, 0], + ['fixed64', self.encoder.AppendFixed64NoTag, + 'AppendLittleEndian64', None, 0], + ['sfixed32', self.encoder.AppendSFixed32NoTag, + 'AppendLittleEndian32', None, 0], + ['sfixed64', self.encoder.AppendSFixed64NoTag, + 'AppendLittleEndian64', None, 0], + ['float', self.encoder.AppendFloatNoTag, + 'AppendRawBytes', None, 0.0, struct.pack('f', 0.0)], + ['double', self.encoder.AppendDoubleNoTag, + 'AppendRawBytes', None, 0.0, struct.pack('d', 0.0)], + ['bool', self.encoder.AppendBoolNoTag, 'AppendVarint32', None, 0], + ['enum', self.encoder.AppendEnumNoTag, 'AppendVarint32', None, 0], + ['sint32', self.encoder.AppendSInt32NoTag, + 'AppendVarUInt32', None, -1, 1], + ['sint64', self.encoder.AppendSInt64NoTag, + 'AppendVarUInt64', None, -1, 1], + ] + + self.assertEqual(len(scalar_no_tag_tests), + len(set(t[0] for t in scalar_no_tag_tests))) + self.assert_(len(scalar_no_tag_tests) >= + len(set(t[1] for t in scalar_no_tag_tests))) + for args in scalar_no_tag_tests: + # For no tag tests, the wire_type is not used, so we put in None. + self.AppendScalarTestHelper(is_tag_test=False, *args) + def testAppendGroup(self): field_number = 23 # Should first append the start-group marker. diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 14062762..e405f60b 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -229,8 +229,8 @@ class ReflectionTest(unittest.TestCase): proto.repeated_fixed32.append(1) proto.repeated_int32.append(5) proto.repeated_int32.append(11) - proto.repeated_string.append('foo') - proto.repeated_string.append('bar') + proto.repeated_string.extend(['foo', 'bar']) + proto.repeated_string.extend([]) proto.repeated_string.append('baz') proto.optional_int32 = 21 self.assertEqual( @@ -757,6 +757,16 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(KeyError, extendee_proto.HasExtension, unittest_pb2.repeated_string_extension) + def testStaticParseFrom(self): + proto1 = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto1) + + string1 = proto1.SerializeToString() + proto2 = unittest_pb2.TestAllTypes.FromString(string1) + + # Messages should be equal. + self.assertEqual(proto2, proto1) + def testMergeFromSingularField(self): # Test merge with just a singular field. proto1 = unittest_pb2.TestAllTypes() @@ -1209,6 +1219,8 @@ class ByteSizeTest(unittest.TestCase): def setUp(self): self.proto = unittest_pb2.TestAllTypes() self.extended_proto = more_extensions_pb2.ExtendedMessage() + self.packed_proto = unittest_pb2.TestPackedTypes() + self.packed_extended_proto = unittest_pb2.TestPackedExtensions() def Size(self): return self.proto.ByteSize() @@ -1291,6 +1303,11 @@ class ByteSizeTest(unittest.TestCase): # Also need 2 bytes for each entry for tag. self.assertEqual(1 + 2 + 2*2, self.Size()) + def testRepeatedScalarsExtend(self): + self.proto.repeated_int32.extend([10, 128]) # 3 bytes. + # Also need 2 bytes for each entry for tag. + self.assertEqual(1 + 2 + 2*2, self.Size()) + def testRepeatedScalarsRemove(self): self.proto.repeated_int32.append(10) # 1 byte. self.proto.repeated_int32.append(128) # 2 bytes. @@ -1443,6 +1460,33 @@ class ByteSizeTest(unittest.TestCase): self.extended_proto.ClearExtension(extension) self.assertEqual(0, self.extended_proto.ByteSize()) + def testPackedRepeatedScalars(self): + self.assertEqual(0, self.packed_proto.ByteSize()) + + self.packed_proto.packed_int32.append(10) # 1 byte. + self.packed_proto.packed_int32.append(128) # 2 bytes. + # The tag is 2 bytes (the field number is 90), and the varint + # storing the length is 1 byte. + int_size = 1 + 2 + 3 + self.assertEqual(int_size, self.packed_proto.ByteSize()) + + self.packed_proto.packed_double.append(4.2) # 8 bytes + self.packed_proto.packed_double.append(3.25) # 8 bytes + # 2 more tag bytes, 1 more length byte. + double_size = 8 + 8 + 3 + self.assertEqual(int_size+double_size, self.packed_proto.ByteSize()) + + self.packed_proto.ClearField('packed_int32') + self.assertEqual(double_size, self.packed_proto.ByteSize()) + + def testPackedExtensions(self): + self.assertEqual(0, self.packed_extended_proto.ByteSize()) + extension = self.packed_extended_proto.Extensions[ + unittest_pb2.packed_fixed32_extension] + extension.extend([1, 2, 3, 4]) # 16 bytes + # Tag is 3 bytes. + self.assertEqual(19, self.packed_extended_proto.ByteSize()) + # TODO(robinson): We need cross-language serialization consistency tests. # Issues to be sure to cover include: @@ -1686,6 +1730,63 @@ class SerializationTest(unittest.TestCase): self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) + def testSerializedAllPackedFields(self): + first_proto = unittest_pb2.TestPackedTypes() + second_proto = unittest_pb2.TestPackedTypes() + test_util.SetAllPackedFields(first_proto) + serialized = first_proto.SerializeToString() + self.assertEqual(first_proto.ByteSize(), len(serialized)) + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testSerializeAllPackedExtensions(self): + first_proto = unittest_pb2.TestPackedExtensions() + second_proto = unittest_pb2.TestPackedExtensions() + test_util.SetAllPackedExtensions(first_proto) + serialized = first_proto.SerializeToString() + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testMergePackedFromStringWhenSomeFieldsAlreadySet(self): + first_proto = unittest_pb2.TestPackedTypes() + first_proto.packed_int32.extend([1, 2]) + first_proto.packed_double.append(3.0) + serialized = first_proto.SerializeToString() + + second_proto = unittest_pb2.TestPackedTypes() + second_proto.packed_int32.append(3) + second_proto.packed_double.extend([1.0, 2.0]) + second_proto.packed_sint32.append(4) + + second_proto.MergeFromString(serialized) + self.assertEqual([3, 1, 2], second_proto.packed_int32) + self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) + self.assertEqual([4], second_proto.packed_sint32) + + def testPackedFieldsWireFormat(self): + proto = unittest_pb2.TestPackedTypes() + proto.packed_int32.extend([1, 2, 150, 3]) # 1 + 1 + 2 + 1 bytes + proto.packed_double.extend([1.0, 1000.0]) # 8 + 8 bytes + proto.packed_float.append(2.0) # 4 bytes, will be before double + serialized = proto.SerializeToString() + self.assertEqual(proto.ByteSize(), len(serialized)) + d = decoder.Decoder(serialized) + ReadTag = d.ReadFieldNumberAndWireType + self.assertEqual((90, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) + self.assertEqual(1+1+1+2, d.ReadInt32()) + self.assertEqual(1, d.ReadInt32()) + self.assertEqual(2, d.ReadInt32()) + self.assertEqual(150, d.ReadInt32()) + self.assertEqual(3, d.ReadInt32()) + self.assertEqual((100, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) + self.assertEqual(4, d.ReadInt32()) + self.assertEqual(2.0, d.ReadFloat()) + self.assertEqual((101, wire_format.WIRETYPE_LENGTH_DELIMITED), ReadTag()) + self.assertEqual(8+8, d.ReadInt32()) + self.assertEqual(1.0, d.ReadDouble()) + self.assertEqual(1000.0, d.ReadDouble()) + self.assertTrue(d.EndOfStream()) + class OptionsTest(unittest.TestCase): @@ -1697,6 +1798,21 @@ class OptionsTest(unittest.TestCase): self.assertEqual(False, proto.DESCRIPTOR.GetOptions().message_set_wire_format) + def testPackedOptions(self): + proto = unittest_pb2.TestAllTypes() + proto.optional_int32 = 1 + proto.optional_double = 3.0 + for field_descriptor, _ in proto.ListFields(): + self.assertEqual(False, field_descriptor.GetOptions().packed) + + proto = unittest_pb2.TestPackedTypes() + proto.packed_int32.append(1) + proto.packed_double.append(3.0) + for field_descriptor, _ in proto.ListFields(): + self.assertEqual(True, field_descriptor.GetOptions().packed) + self.assertEqual(reflection._FieldDescriptor.LABEL_REPEATED, + field_descriptor.label) + class UtilityTest(unittest.TestCase): diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index 29492e16..e04f8252 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -74,7 +74,7 @@ class FooUnitTest(unittest.TestCase): rpc_controller.failure_message = None - service_descriptor = unittest_pb2.TestService.DESCRIPTOR + service_descriptor = unittest_pb2.TestService.GetDescriptor() srvc.CallMethod(service_descriptor.methods[1], rpc_controller, unittest_pb2.BarRequest(), MyCallback) self.assertEqual('Method Bar not implemented.', @@ -118,6 +118,10 @@ class FooUnitTest(unittest.TestCase): rpc_controller = 'controller' request = 'request' + # GetDescriptor now static, still works as instance method for compatability + self.assertEqual(unittest_pb2.TestService_Stub.GetDescriptor(), + stub.GetDescriptor()) + # Invoke method. stub.Foo(rpc_controller, request, MyCallback) diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 14dfbc51..2d50bc4a 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -366,3 +366,51 @@ def GoldenFile(filename): 'Could not find golden files. This test must be run from within the ' 'protobuf source package so that it can read test data files from the ' 'C++ source tree.') + + +def SetAllPackedFields(message): + """Sets every field in the message to a unique value. + + Args: + message: A unittest_pb2.TestPackedTypes instance. + """ + message.packed_int32.extend([101, 102]) + message.packed_int64.extend([103, 104]) + message.packed_uint32.extend([105, 106]) + message.packed_uint64.extend([107, 108]) + message.packed_sint32.extend([109, 110]) + message.packed_sint64.extend([111, 112]) + message.packed_fixed32.extend([113, 114]) + message.packed_fixed64.extend([115, 116]) + message.packed_sfixed32.extend([117, 118]) + message.packed_sfixed64.extend([119, 120]) + message.packed_float.extend([121.0, 122.0]) + message.packed_double.extend([122.0, 123.0]) + message.packed_bool.extend([True, False]) + message.packed_enum.extend([unittest_pb2.FOREIGN_FOO, + unittest_pb2.FOREIGN_BAR]) + + +def SetAllPackedExtensions(message): + """Sets every extension in the message to a unique value. + + Args: + message: A unittest_pb2.TestPackedExtensions instance. + """ + extensions = message.Extensions + pb2 = unittest_pb2 + + extensions[pb2.packed_int32_extension].append(101) + extensions[pb2.packed_int64_extension].append(102) + extensions[pb2.packed_uint32_extension].append(103) + extensions[pb2.packed_uint64_extension].append(104) + extensions[pb2.packed_sint32_extension].append(105) + extensions[pb2.packed_sint64_extension].append(106) + extensions[pb2.packed_fixed32_extension].append(107) + extensions[pb2.packed_fixed64_extension].append(108) + extensions[pb2.packed_sfixed32_extension].append(109) + extensions[pb2.packed_sfixed64_extension].append(110) + extensions[pb2.packed_float_extension].append(111.0) + extensions[pb2.packed_double_extension].append(112.0) + extensions[pb2.packed_bool_extension].append(True) + extensions[pb2.packed_enum_extension].append(pb2.FOREIGN_BAZ) diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index ba470cf2..c009627f 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -216,6 +216,23 @@ TYPE_TO_SERIALIZE_METHOD = { } +TYPE_TO_NOTAG_SERIALIZE_METHOD = { + _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDoubleNoTag, + _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloatNoTag, + _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64NoTag, + _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64NoTag, + _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32NoTag, + _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64NoTag, + _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32NoTag, + _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBoolNoTag, + _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32NoTag, + _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnumNoTag, + _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32NoTag, + _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64NoTag, + _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32NoTag, + _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64NoTag, + } + # Maps from field type to expected wiretype. FIELD_TYPE_TO_WIRE_TYPE = { _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64, diff --git a/python/google/protobuf/internal/wire_format.py b/python/google/protobuf/internal/wire_format.py index 5f0af11e..531c9b85 100755 --- a/python/google/protobuf/internal/wire_format.py +++ b/python/google/protobuf/internal/wire_format.py @@ -120,6 +120,10 @@ def Int32ByteSize(field_number, int32): return Int64ByteSize(field_number, int32) +def Int32ByteSizeNoTag(int32): + return _VarUInt64ByteSizeNoTag(0xffffffffffffffff & int32) + + def Int64ByteSize(field_number, int64): # Have to convert to uint before calling UInt64ByteSize(). return UInt64ByteSize(field_number, 0xffffffffffffffff & int64) @@ -130,7 +134,7 @@ def UInt32ByteSize(field_number, uint32): def UInt64ByteSize(field_number, uint64): - return _TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64) + return TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64) def SInt32ByteSize(field_number, int32): @@ -142,31 +146,31 @@ def SInt64ByteSize(field_number, int64): def Fixed32ByteSize(field_number, fixed32): - return _TagByteSize(field_number) + 4 + return TagByteSize(field_number) + 4 def Fixed64ByteSize(field_number, fixed64): - return _TagByteSize(field_number) + 8 + return TagByteSize(field_number) + 8 def SFixed32ByteSize(field_number, sfixed32): - return _TagByteSize(field_number) + 4 + return TagByteSize(field_number) + 4 def SFixed64ByteSize(field_number, sfixed64): - return _TagByteSize(field_number) + 8 + return TagByteSize(field_number) + 8 def FloatByteSize(field_number, flt): - return _TagByteSize(field_number) + 4 + return TagByteSize(field_number) + 4 def DoubleByteSize(field_number, double): - return _TagByteSize(field_number) + 8 + return TagByteSize(field_number) + 8 def BoolByteSize(field_number, b): - return _TagByteSize(field_number) + 1 + return TagByteSize(field_number) + 1 def EnumByteSize(field_number, enum): @@ -178,18 +182,18 @@ def StringByteSize(field_number, string): def BytesByteSize(field_number, b): - return (_TagByteSize(field_number) + return (TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(len(b)) + len(b)) def GroupByteSize(field_number, message): - return (2 * _TagByteSize(field_number) # START and END group. + return (2 * TagByteSize(field_number) # START and END group. + message.ByteSize()) def MessageByteSize(field_number, message): - return (_TagByteSize(field_number) + return (TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(message.ByteSize()) + message.ByteSize()) @@ -199,7 +203,7 @@ def MessageSetItemByteSize(field_number, msg): # There are 2 tags for the beginning and ending of the repeated group, that # is field number 1, one with field number 2 (type_id) and one with field # number 3 (message). - total_size = (2 * _TagByteSize(1) + _TagByteSize(2) + _TagByteSize(3)) + total_size = (2 * TagByteSize(1) + TagByteSize(2) + TagByteSize(3)) # Add the number of bytes for type_id. total_size += _VarUInt64ByteSizeNoTag(field_number) @@ -214,15 +218,14 @@ def MessageSetItemByteSize(field_number, msg): return total_size -# Private helper functions for the *ByteSize() functions above. - - -def _TagByteSize(field_number): +def TagByteSize(field_number): """Returns the bytes required to serialize a tag with this field number.""" # Just pass in type 0, since the type won't affect the tag+type size. return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0)) +# Private helper function for the *ByteSize() functions above. + def _VarUInt64ByteSizeNoTag(uint64): """Returns the bytes required to serialize a single varint. uint64 must be unsigned. diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 9ba752e4..0d5191be 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -462,6 +462,12 @@ def _AddStaticMethods(cls): cls._known_extensions.append(extension_handle) cls.RegisterExtension = staticmethod(RegisterExtension) + def FromString(s): + message = cls() + message.MergeFromString(s) + return message + cls.FromString = staticmethod(FromString) + def _AddListFieldsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" @@ -665,9 +671,36 @@ def _AddByteSizeMethod(message_descriptor, cls): else: elements = [value] - size = sum(_BytesForNonRepeatedElement(element, field_number, field_type) - for element in elements) - return size + if field.GetOptions().packed: + content_size = _ContentBytesForPackedField(message, field, elements) + if content_size: + tag_size = wire_format.TagByteSize(field_number) + length_size = wire_format.Int32ByteSizeNoTag(content_size) + return tag_size + length_size + content_size + else: + return 0 + else: + return sum(_BytesForNonRepeatedElement(element, field_number, field_type) + for element in elements) + + def _ContentBytesForPackedField(self, field, value): + """Returns the number of bytes required to serialize the actual + content of a packed field (not including the tag or the encoding + of the length. + + Args: + self: The Message instance containing a field of the given type. + field: A FieldDescriptor describing the field of interest. + value: The value whose byte size we're interested in. + + Returns: The number of bytes required to serialize the current value + of the packed "field" in "message", excluding space for tags and the + length encoding. + """ + size = sum(_BytesForNonRepeatedElement(element, field.number, field.type) + for element in value) + # In the packed case, there are no per element tags. + return size - wire_format.TagByteSize(field.number) * len(value) fields = message_descriptor.fields has_field_names = (_HasFieldName(f.name) for f in fields) @@ -691,6 +724,8 @@ def _AddByteSizeMethod(message_descriptor, cls): self._cached_byte_size = size self._cached_byte_size_dirty = False return size + + cls._ContentBytesForPackedField = _ContentBytesForPackedField cls.ByteSize = ByteSize @@ -788,10 +823,29 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): repeated_value = field_value else: repeated_value = [field_value] - for element in repeated_value: - _SerializeValueToEncoder(element, field_descriptor.number, - field_descriptor, encoder) + if field_descriptor.GetOptions().packed: + # First, write the field number and WIRETYPE_LENGTH_DELIMITED. + field_number = field_descriptor.number + encoder.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + # Next, write the number of bytes. + content_bytes = self._ContentBytesForPackedField( + field_descriptor, field_value) + encoder.AppendInt32NoTag(content_bytes) + # Finally, write the actual values. + try: + method = type_checkers.TYPE_TO_NOTAG_SERIALIZE_METHOD[ + field_descriptor.type] + for value in repeated_value: + method(encoder, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % + field_descriptor.type) + else: + for element in repeated_value: + _SerializeValueToEncoder(element, field_descriptor.number, + field_descriptor, encoder) return encoder.ToString() + cls.SerializePartialToString = SerializePartialToString @@ -803,6 +857,14 @@ def _WireTypeForFieldType(field_type): raise message_mod.DecodeError('Unknown field type: %d' % field_type) +def _WireTypeForField(field_descriptor): + """Given a field descriptor, returns the expected wire type.""" + if field_descriptor.GetOptions().packed: + return wire_format.WIRETYPE_LENGTH_DELIMITED + else: + return _WireTypeForFieldType(field_descriptor.type) + + def _RecursivelyMerge(field_number, field_type, decoder, message): """Decodes a message from decoder into message. message is either a group or a nested message within some containing @@ -918,9 +980,11 @@ def _DeserializeMessageSetItem(message, decoder): def _DeserializeOneEntity(message_descriptor, message, decoder): """Deserializes the next wire entity from decoder into message. - The next wire entity is either a scalar or a nested message, - and may also be an element in a repeated field (the wire encoding - is the same). + + The next wire entity is either a scalar or a nested message, an + element in a repeated field (the wire encoding in this case is the + same), or a packed repeated field (in this case, the entire repeated + field is read by a single call to _DeserializeOneEntity). Args: message_descriptor: A Descriptor instance describing all fields @@ -973,14 +1037,14 @@ def _DeserializeOneEntity(message_descriptor, message, decoder): # if this field is a nonrepeated scalar. field_number = field_descriptor.number - field_type = field_descriptor.type - expected_wire_type = _WireTypeForFieldType(field_type) + expected_wire_type = _WireTypeForField(field_descriptor) if wire_type != expected_wire_type: # Need to fill in uninterpreted_bytes. Work for the next CL. raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.') property_name = _PropertyName(field_descriptor.name) label = field_descriptor.label + field_type = field_descriptor.type cpp_type = field_descriptor.cpp_type # Nonrepeated scalar. Just set the field directly. @@ -1000,8 +1064,17 @@ def _DeserializeOneEntity(message_descriptor, message, decoder): if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: # Repeated scalar. - element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) - return decoder.Position() - initial_position + if not field_descriptor.GetOptions().packed: + element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) + return decoder.Position() - initial_position + else: + # Packed repeated field. + length = _DeserializeScalarFromDecoder( + _FieldDescriptor.TYPE_INT32, decoder) + content_start = decoder.Position() + while decoder.Position() - content_start < length: + element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) + return decoder.Position() - content_start else: # Repeated composite. composite = element_list.add() diff --git a/python/google/protobuf/service.py b/python/google/protobuf/service.py index 3989216a..9ec42fe3 100755 --- a/python/google/protobuf/service.py +++ b/python/google/protobuf/service.py @@ -31,7 +31,7 @@ """Declares the RPC service interfaces. This module declares the abstract interfaces underlying proto2 RPC -services. These are intented to be independent of any particular RPC +services. These are intended to be independent of any particular RPC implementation, so that proto2 services can be used on top of a variety of implementations. """ @@ -39,6 +39,11 @@ of implementations. __author__ = 'petar@google.com (Petar Petrov)' +class RpcException(Exception): + """Exception raised on failed blocking RPC method call.""" + pass + + class Service(object): """Abstract base interface for protocol-buffer-based RPC services. @@ -49,7 +54,7 @@ class Service(object): its exact type at compile time (analogous to the Message interface). """ - def GetDescriptor(self): + def GetDescriptor(): """Retrieves this service's descriptor.""" raise NotImplementedError @@ -57,6 +62,14 @@ class Service(object): request, done): """Calls a method of the service specified by method_descriptor. + If "done" is None then the call is blocking and the response + message will be returned directly. Otherwise the call is asynchronous + and "done" will later be called with the response value. + + In the blocking case, RpcException will be raised on error. + Asynchronous calls must check status via the Failed method of the + RpcController. + Preconditions: * method_descriptor.service == GetDescriptor * request is of the exact same classes as returned by diff --git a/python/google/protobuf/service_reflection.py b/python/google/protobuf/service_reflection.py index bdd6bad5..851e83e7 100755 --- a/python/google/protobuf/service_reflection.py +++ b/python/google/protobuf/service_reflection.py @@ -142,24 +142,17 @@ class _ServiceBuilder(object): # instance to the method that does the real CallMethod work. def _WrapCallMethod(srvc, method_descriptor, rpc_controller, request, callback): - self._CallMethod(srvc, method_descriptor, + return self._CallMethod(srvc, method_descriptor, rpc_controller, request, callback) self.cls = cls cls.CallMethod = _WrapCallMethod - cls.GetDescriptor = self._GetDescriptor + cls.GetDescriptor = staticmethod(lambda: self.descriptor) + cls.GetDescriptor.__doc__ = "Returns the service descriptor." cls.GetRequestClass = self._GetRequestClass cls.GetResponseClass = self._GetResponseClass for method in self.descriptor.methods: setattr(cls, method.name, self._GenerateNonImplementedMethod(method)) - def _GetDescriptor(self): - """Retrieves the service descriptor. - - Returns: - The descriptor of the service (of type ServiceDescriptor). - """ - return self.descriptor - def _CallMethod(self, srvc, method_descriptor, rpc_controller, request, callback): """Calls the method described by a given method descriptor. @@ -175,7 +168,7 @@ class _ServiceBuilder(object): raise RuntimeError( 'CallMethod() given method descriptor for wrong service type.') method = getattr(srvc, method_descriptor.name) - method(rpc_controller, request, callback) + return method(rpc_controller, request, callback) def _GetRequestClass(self, method_descriptor): """Returns the class of the request protocol message. @@ -270,8 +263,8 @@ class _ServiceStubBuilder(object): setattr(cls, method.name, self._GenerateStubMethod(method)) def _GenerateStubMethod(self, method): - return lambda inst, rpc_controller, request, callback: self._StubMethod( - inst, method, rpc_controller, request, callback) + return (lambda inst, rpc_controller, request, callback=None: + self._StubMethod(inst, method, rpc_controller, request, callback)) def _StubMethod(self, stub, method_descriptor, rpc_controller, request, callback): @@ -283,7 +276,9 @@ class _ServiceStubBuilder(object): rpc_controller: Rpc controller to execute the method. request: Request protocol message. callback: A callback to execute when the method finishes. + Returns: + Response message (in case of blocking call). """ - stub.rpc_channel.CallMethod( + return stub.rpc_channel.CallMethod( method_descriptor, rpc_controller, request, method_descriptor.output_type._concrete_class, callback) |