aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/internal/unknown_fields_test.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/internal/unknown_fields_test.py')
-rwxr-xr-xpython/google/protobuf/internal/unknown_fields_test.py165
1 files changed, 123 insertions, 42 deletions
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index 8b7de2e7..fceadf71 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -49,20 +49,12 @@ from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
+from google.protobuf import descriptor
BaseTestCase = testing_refleaks.BaseTestCase
-# CheckUnknownField() cannot be used by the C++ implementation because
-# some protect members are called. It is not a behavior difference
-# for python and C++ implementation.
-def SkipCheckUnknownFieldIfCppImplementation(func):
- return unittest.skipIf(
- api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
- 'Addtional test for pure python involved protect members')(func)
-
-
class UnknownFieldsTest(BaseTestCase):
def setUp(self):
@@ -80,23 +72,11 @@ class UnknownFieldsTest(BaseTestCase):
# stdout.
self.assertTrue(data == self.all_fields_data)
- def expectSerializeProto3(self, preserve):
+ def testSerializeProto3(self):
+ # Verify proto3 unknown fields behavior.
message = unittest_proto3_arena_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
- if preserve:
- self.assertEqual(self.all_fields_data, message.SerializeToString())
- else:
- self.assertEqual(0, len(message.SerializeToString()))
-
- def testSerializeProto3(self):
- # Verify that proto3 unknown fields behavior.
- default_preserve = (api_implementation
- .GetPythonProto3PreserveUnknownsDefault())
- self.expectSerializeProto3(default_preserve)
- api_implementation.SetPythonProto3PreserveUnknownsDefault(
- not default_preserve)
- self.expectSerializeProto3(not default_preserve)
- api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve)
+ self.assertEqual(self.all_fields_data, message.SerializeToString())
def testByteSize(self):
self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
@@ -169,13 +149,15 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
- # CheckUnknownField() is an additional Pure Python check which checks
+ # InternalCheckUnknownField() is an additional Pure Python check which checks
# a detail of unknown fields. It cannot be used by the C++
# implementation because some protect members are called.
# The test is added for historical reasons. It is not necessary as
# serialized string is checked.
-
- def CheckUnknownField(self, name, expected_value):
+ # TODO(jieluo): Remove message._unknown_fields.
+ def InternalCheckUnknownField(self, name, expected_value):
+ if api_implementation.Type() == 'cpp':
+ return
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
@@ -183,36 +165,80 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
for tag_bytes, value in self.empty_message._unknown_fields:
if tag_bytes == field_tag:
decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
- decoder(value, 0, len(value), self.all_fields, result_dict)
+ decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
self.assertEqual(expected_value, result_dict[field_descriptor])
- @SkipCheckUnknownFieldIfCppImplementation
+ def CheckUnknownField(self, name, unknown_fields, expected_value):
+ field_descriptor = self.descriptor.fields_by_name[name]
+ expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
+ field_descriptor.type]
+ for unknown_field in unknown_fields:
+ if unknown_field.field_number == field_descriptor.number:
+ self.assertEqual(expected_type, unknown_field.wire_type)
+ if expected_type == 3:
+ # Check group
+ self.assertEqual(expected_value[0],
+ unknown_field.data[0].field_number)
+ self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
+ self.assertEqual(expected_value[2], unknown_field.data[0].data)
+ continue
+ if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ self.assertIn(unknown_field.data, expected_value)
+ else:
+ self.assertEqual(expected_value, unknown_field.data)
+
def testCheckUnknownFieldValue(self):
+ unknown_fields = self.empty_message.UnknownFields()
# Test enum.
self.CheckUnknownField('optional_nested_enum',
+ unknown_fields,
self.all_fields.optional_nested_enum)
+ self.InternalCheckUnknownField('optional_nested_enum',
+ self.all_fields.optional_nested_enum)
+
# Test repeated enum.
self.CheckUnknownField('repeated_nested_enum',
+ unknown_fields,
self.all_fields.repeated_nested_enum)
+ self.InternalCheckUnknownField('repeated_nested_enum',
+ self.all_fields.repeated_nested_enum)
# Test varint.
self.CheckUnknownField('optional_int32',
+ unknown_fields,
self.all_fields.optional_int32)
+ self.InternalCheckUnknownField('optional_int32',
+ self.all_fields.optional_int32)
+
# Test fixed32.
self.CheckUnknownField('optional_fixed32',
+ unknown_fields,
self.all_fields.optional_fixed32)
+ self.InternalCheckUnknownField('optional_fixed32',
+ self.all_fields.optional_fixed32)
# Test fixed64.
self.CheckUnknownField('optional_fixed64',
+ unknown_fields,
self.all_fields.optional_fixed64)
+ self.InternalCheckUnknownField('optional_fixed64',
+ self.all_fields.optional_fixed64)
# Test lengthd elimited.
self.CheckUnknownField('optional_string',
- self.all_fields.optional_string)
+ unknown_fields,
+ self.all_fields.optional_string.encode('utf-8'))
+ self.InternalCheckUnknownField('optional_string',
+ self.all_fields.optional_string)
# Test group.
self.CheckUnknownField('optionalgroup',
- self.all_fields.optionalgroup)
+ unknown_fields,
+ (17, 0, 117))
+ self.InternalCheckUnknownField('optionalgroup',
+ self.all_fields.optionalgroup)
+
+ self.assertEqual(97, len(unknown_fields))
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
@@ -230,9 +256,18 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
message.optional_int64 = 3
message.optional_uint32 = 4
destination = unittest_pb2.TestEmptyMessage()
+ unknown_fields = destination.UnknownFields()
+ self.assertEqual(0, len(unknown_fields))
destination.ParseFromString(message.SerializeToString())
-
+ # ParseFromString clears the message thus unknown fields is invalid.
+ with self.assertRaises(ValueError) as context:
+ len(unknown_fields)
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+ unknown_fields = destination.UnknownFields()
+ self.assertEqual(2, len(unknown_fields))
destination.MergeFrom(source)
+ self.assertEqual(4, len(unknown_fields))
# Check that the fields where correctly merged, even stored in the unknown
# fields set.
message.ParseFromString(destination.SerializeToString())
@@ -241,9 +276,58 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
self.assertEqual(message.optional_int64, 3)
def testClear(self):
+ unknown_fields = self.empty_message.UnknownFields()
self.empty_message.Clear()
# All cleared, even unknown fields.
self.assertEqual(self.empty_message.SerializeToString(), b'')
+ with self.assertRaises(ValueError) as context:
+ len(unknown_fields)
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+
+ def testSubUnknownFields(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optionalgroup.a = 123
+ destination = unittest_pb2.TestEmptyMessage()
+ destination.ParseFromString(message.SerializeToString())
+ sub_unknown_fields = destination.UnknownFields()[0].data
+ self.assertEqual(1, len(sub_unknown_fields))
+ self.assertEqual(sub_unknown_fields[0].data, 123)
+ destination.Clear()
+ with self.assertRaises(ValueError) as context:
+ len(sub_unknown_fields)
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+ with self.assertRaises(ValueError) as context:
+ # pylint: disable=pointless-statement
+ sub_unknown_fields[0]
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+ message.Clear()
+ message.optional_uint32 = 456
+ nested_message = unittest_pb2.NestedTestAllTypes()
+ nested_message.payload.optional_nested_message.ParseFromString(
+ message.SerializeToString())
+ unknown_fields = (
+ nested_message.payload.optional_nested_message.UnknownFields())
+ self.assertEqual(unknown_fields[0].data, 456)
+ nested_message.ClearField('payload')
+ self.assertEqual(unknown_fields[0].data, 456)
+ unknown_fields = (
+ nested_message.payload.optional_nested_message.UnknownFields())
+ self.assertEqual(0, len(unknown_fields))
+
+ def testUnknownField(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optional_int32 = 123
+ destination = unittest_pb2.TestEmptyMessage()
+ destination.ParseFromString(message.SerializeToString())
+ unknown_field = destination.UnknownFields()[0]
+ destination.Clear()
+ with self.assertRaises(ValueError) as context:
+ unknown_field.data # pylint: disable=pointless-statement
+ self.assertIn('The parent message might be cleared.',
+ str(context.exception))
def testUnknownExtensions(self):
message = unittest_pb2.TestEmptyMessageWithExtensions()
@@ -280,15 +364,13 @@ class UnknownEnumValuesTest(BaseTestCase):
def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
- wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
- field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
- result_dict = {}
- for tag_bytes, value in self.missing_message._unknown_fields:
- if tag_bytes == field_tag:
- decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
- tag_bytes][0]
- decoder(value, 0, len(value), self.message, result_dict)
- self.assertEqual(expected_value, result_dict[field_descriptor])
+ unknown_fields = self.missing_message.UnknownFields()
+ for field in unknown_fields:
+ if field.field_number == field_descriptor.number:
+ if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ self.assertIn(field.data, expected_value)
+ else:
+ self.assertEqual(expected_value, field.data)
def testUnknownParseMismatchEnumValue(self):
just_string = missing_enum_values_pb2.JustString()
@@ -317,7 +399,6 @@ class UnknownEnumValuesTest(BaseTestCase):
def testUnknownPackedEnumValue(self):
self.assertEqual([], self.missing_message.packed_nested_enum)
- @SkipCheckUnknownFieldIfCppImplementation
def testCheckUnknownFieldValueForEnum(self):
self.CheckUnknownField('optional_nested_enum',
self.message.optional_nested_enum)