diff options
Diffstat (limited to 'python/google/protobuf/internal/unknown_fields_test.py')
-rwxr-xr-x | python/google/protobuf/internal/unknown_fields_test.py | 165 |
1 files changed, 123 insertions, 42 deletions
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 8b7de2e7..fceadf71 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -49,20 +49,12 @@ from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import testing_refleaks from google.protobuf.internal import type_checkers +from google.protobuf import descriptor BaseTestCase = testing_refleaks.BaseTestCase -# CheckUnknownField() cannot be used by the C++ implementation because -# some protect members are called. It is not a behavior difference -# for python and C++ implementation. -def SkipCheckUnknownFieldIfCppImplementation(func): - return unittest.skipIf( - api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'Addtional test for pure python involved protect members')(func) - - class UnknownFieldsTest(BaseTestCase): def setUp(self): @@ -80,23 +72,11 @@ class UnknownFieldsTest(BaseTestCase): # stdout. self.assertTrue(data == self.all_fields_data) - def expectSerializeProto3(self, preserve): + def testSerializeProto3(self): + # Verify proto3 unknown fields behavior. message = unittest_proto3_arena_pb2.TestEmptyMessage() message.ParseFromString(self.all_fields_data) - if preserve: - self.assertEqual(self.all_fields_data, message.SerializeToString()) - else: - self.assertEqual(0, len(message.SerializeToString())) - - def testSerializeProto3(self): - # Verify that proto3 unknown fields behavior. - default_preserve = (api_implementation - .GetPythonProto3PreserveUnknownsDefault()) - self.expectSerializeProto3(default_preserve) - api_implementation.SetPythonProto3PreserveUnknownsDefault( - not default_preserve) - self.expectSerializeProto3(not default_preserve) - api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve) + self.assertEqual(self.all_fields_data, message.SerializeToString()) def testByteSize(self): self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) @@ -169,13 +149,15 @@ class UnknownFieldsAccessorsTest(BaseTestCase): self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - # CheckUnknownField() is an additional Pure Python check which checks + # InternalCheckUnknownField() is an additional Pure Python check which checks # a detail of unknown fields. It cannot be used by the C++ # implementation because some protect members are called. # The test is added for historical reasons. It is not necessary as # serialized string is checked. - - def CheckUnknownField(self, name, expected_value): + # TODO(jieluo): Remove message._unknown_fields. + def InternalCheckUnknownField(self, name, expected_value): + if api_implementation.Type() == 'cpp': + return field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) @@ -183,36 +165,80 @@ class UnknownFieldsAccessorsTest(BaseTestCase): for tag_bytes, value in self.empty_message._unknown_fields: if tag_bytes == field_tag: decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] - decoder(value, 0, len(value), self.all_fields, result_dict) + decoder(memoryview(value), 0, len(value), self.all_fields, result_dict) self.assertEqual(expected_value, result_dict[field_descriptor]) - @SkipCheckUnknownFieldIfCppImplementation + def CheckUnknownField(self, name, unknown_fields, expected_value): + field_descriptor = self.descriptor.fields_by_name[name] + expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[ + field_descriptor.type] + for unknown_field in unknown_fields: + if unknown_field.field_number == field_descriptor.number: + self.assertEqual(expected_type, unknown_field.wire_type) + if expected_type == 3: + # Check group + self.assertEqual(expected_value[0], + unknown_field.data[0].field_number) + self.assertEqual(expected_value[1], unknown_field.data[0].wire_type) + self.assertEqual(expected_value[2], unknown_field.data[0].data) + continue + if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED: + self.assertIn(unknown_field.data, expected_value) + else: + self.assertEqual(expected_value, unknown_field.data) + def testCheckUnknownFieldValue(self): + unknown_fields = self.empty_message.UnknownFields() # Test enum. self.CheckUnknownField('optional_nested_enum', + unknown_fields, self.all_fields.optional_nested_enum) + self.InternalCheckUnknownField('optional_nested_enum', + self.all_fields.optional_nested_enum) + # Test repeated enum. self.CheckUnknownField('repeated_nested_enum', + unknown_fields, self.all_fields.repeated_nested_enum) + self.InternalCheckUnknownField('repeated_nested_enum', + self.all_fields.repeated_nested_enum) # Test varint. self.CheckUnknownField('optional_int32', + unknown_fields, self.all_fields.optional_int32) + self.InternalCheckUnknownField('optional_int32', + self.all_fields.optional_int32) + # Test fixed32. self.CheckUnknownField('optional_fixed32', + unknown_fields, self.all_fields.optional_fixed32) + self.InternalCheckUnknownField('optional_fixed32', + self.all_fields.optional_fixed32) # Test fixed64. self.CheckUnknownField('optional_fixed64', + unknown_fields, self.all_fields.optional_fixed64) + self.InternalCheckUnknownField('optional_fixed64', + self.all_fields.optional_fixed64) # Test lengthd elimited. self.CheckUnknownField('optional_string', - self.all_fields.optional_string) + unknown_fields, + self.all_fields.optional_string.encode('utf-8')) + self.InternalCheckUnknownField('optional_string', + self.all_fields.optional_string) # Test group. self.CheckUnknownField('optionalgroup', - self.all_fields.optionalgroup) + unknown_fields, + (17, 0, 117)) + self.InternalCheckUnknownField('optionalgroup', + self.all_fields.optionalgroup) + + self.assertEqual(97, len(unknown_fields)) def testCopyFrom(self): message = unittest_pb2.TestEmptyMessage() @@ -230,9 +256,18 @@ class UnknownFieldsAccessorsTest(BaseTestCase): message.optional_int64 = 3 message.optional_uint32 = 4 destination = unittest_pb2.TestEmptyMessage() + unknown_fields = destination.UnknownFields() + self.assertEqual(0, len(unknown_fields)) destination.ParseFromString(message.SerializeToString()) - + # ParseFromString clears the message thus unknown fields is invalid. + with self.assertRaises(ValueError) as context: + len(unknown_fields) + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + unknown_fields = destination.UnknownFields() + self.assertEqual(2, len(unknown_fields)) destination.MergeFrom(source) + self.assertEqual(4, len(unknown_fields)) # Check that the fields where correctly merged, even stored in the unknown # fields set. message.ParseFromString(destination.SerializeToString()) @@ -241,9 +276,58 @@ class UnknownFieldsAccessorsTest(BaseTestCase): self.assertEqual(message.optional_int64, 3) def testClear(self): + unknown_fields = self.empty_message.UnknownFields() self.empty_message.Clear() # All cleared, even unknown fields. self.assertEqual(self.empty_message.SerializeToString(), b'') + with self.assertRaises(ValueError) as context: + len(unknown_fields) + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + + def testSubUnknownFields(self): + message = unittest_pb2.TestAllTypes() + message.optionalgroup.a = 123 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + sub_unknown_fields = destination.UnknownFields()[0].data + self.assertEqual(1, len(sub_unknown_fields)) + self.assertEqual(sub_unknown_fields[0].data, 123) + destination.Clear() + with self.assertRaises(ValueError) as context: + len(sub_unknown_fields) + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + with self.assertRaises(ValueError) as context: + # pylint: disable=pointless-statement + sub_unknown_fields[0] + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + message.Clear() + message.optional_uint32 = 456 + nested_message = unittest_pb2.NestedTestAllTypes() + nested_message.payload.optional_nested_message.ParseFromString( + message.SerializeToString()) + unknown_fields = ( + nested_message.payload.optional_nested_message.UnknownFields()) + self.assertEqual(unknown_fields[0].data, 456) + nested_message.ClearField('payload') + self.assertEqual(unknown_fields[0].data, 456) + unknown_fields = ( + nested_message.payload.optional_nested_message.UnknownFields()) + self.assertEqual(0, len(unknown_fields)) + + def testUnknownField(self): + message = unittest_pb2.TestAllTypes() + message.optional_int32 = 123 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + unknown_field = destination.UnknownFields()[0] + destination.Clear() + with self.assertRaises(ValueError) as context: + unknown_field.data # pylint: disable=pointless-statement + self.assertIn('The parent message might be cleared.', + str(context.exception)) def testUnknownExtensions(self): message = unittest_pb2.TestEmptyMessageWithExtensions() @@ -280,15 +364,13 @@ class UnknownEnumValuesTest(BaseTestCase): def CheckUnknownField(self, name, expected_value): field_descriptor = self.descriptor.fields_by_name[name] - wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] - field_tag = encoder.TagBytes(field_descriptor.number, wire_type) - result_dict = {} - for tag_bytes, value in self.missing_message._unknown_fields: - if tag_bytes == field_tag: - decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ - tag_bytes][0] - decoder(value, 0, len(value), self.message, result_dict) - self.assertEqual(expected_value, result_dict[field_descriptor]) + unknown_fields = self.missing_message.UnknownFields() + for field in unknown_fields: + if field.field_number == field_descriptor.number: + if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED: + self.assertIn(field.data, expected_value) + else: + self.assertEqual(expected_value, field.data) def testUnknownParseMismatchEnumValue(self): just_string = missing_enum_values_pb2.JustString() @@ -317,7 +399,6 @@ class UnknownEnumValuesTest(BaseTestCase): def testUnknownPackedEnumValue(self): self.assertEqual([], self.missing_message.packed_nested_enum) - @SkipCheckUnknownFieldIfCppImplementation def testCheckUnknownFieldValueForEnum(self): self.CheckUnknownField('optional_nested_enum', self.message.optional_nested_enum) |