diff options
Diffstat (limited to 'python/google/protobuf/internal/message_test.py')
-rwxr-xr-x | python/google/protobuf/internal/message_test.py | 45 |
1 files changed, 30 insertions, 15 deletions
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 1e95adf9..9986c0d9 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -67,6 +67,7 @@ from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util +from google.protobuf.internal import testing_refleaks from google.protobuf import message from google.protobuf.internal import _parameterized @@ -88,10 +89,13 @@ def IsNegInf(val): return isinf(val) and (val < 0) -@_parameterized.Parameters( - (unittest_pb2), - (unittest_proto3_arena_pb2)) -class MessageTest(unittest.TestCase): +BaseTestCase = testing_refleaks.BaseTestCase + + +@_parameterized.NamedParameters( + ('_proto2', unittest_pb2), + ('_proto3', unittest_proto3_arena_pb2)) +class MessageTest(BaseTestCase): def testBadUtf8String(self, message_module): if api_implementation.Type() != 'python': @@ -957,7 +961,7 @@ class MessageTest(unittest.TestCase): # Class to test proto2-only features (required, extensions, etc.) -class Proto2Test(unittest.TestCase): +class Proto2Test(BaseTestCase): def testFieldPresence(self): message = unittest_pb2.TestAllTypes() @@ -1113,6 +1117,7 @@ class Proto2Test(unittest.TestCase): optional_bytes=b'x', optionalgroup={'a': 400}, optional_nested_message={'bb': 500}, + optional_foreign_message={}, optional_nested_enum='BAZ', repeatedgroup=[{'a': 600}, {'a': 700}], @@ -1125,8 +1130,12 @@ class Proto2Test(unittest.TestCase): self.assertEqual(300.5, message.optional_float) self.assertEqual(b'x', message.optional_bytes) self.assertEqual(400, message.optionalgroup.a) - self.assertIsInstance(message.optional_nested_message, unittest_pb2.TestAllTypes.NestedMessage) + self.assertIsInstance(message.optional_nested_message, + unittest_pb2.TestAllTypes.NestedMessage) self.assertEqual(500, message.optional_nested_message.bb) + self.assertTrue(message.HasField('optional_foreign_message')) + self.assertEqual(message.optional_foreign_message, + unittest_pb2.ForeignMessage()) self.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.optional_nested_enum) self.assertEqual(2, len(message.repeatedgroup)) @@ -1164,7 +1173,7 @@ class Proto2Test(unittest.TestCase): # Class to test proto3-only features/behavior (updated field presence & enums) -class Proto3Test(unittest.TestCase): +class Proto3Test(BaseTestCase): # Utility method for comparing equality with a map. def assertMapIterEquals(self, map_iter, dict_value): @@ -1720,7 +1729,7 @@ class Proto3Test(unittest.TestCase): -class ValidTypeNamesTest(unittest.TestCase): +class ValidTypeNamesTest(BaseTestCase): def assertImportFromName(self, msg, base_name): # Parse <type 'module.class_name'> to extra 'some.name' as a string. @@ -1741,7 +1750,7 @@ class ValidTypeNamesTest(unittest.TestCase): self.assertImportFromName(pb.repeated_int32, 'Scalar') self.assertImportFromName(pb.repeated_nested_message, 'Composite') -class PackedFieldTest(unittest.TestCase): +class PackedFieldTest(BaseTestCase): def setMessage(self, message): message.repeated_int32.append(1) @@ -1800,10 +1809,14 @@ class PackedFieldTest(unittest.TestCase): @unittest.skipIf(api_implementation.Type() != 'cpp', 'explicit tests of the C++ implementation') -class OversizeProtosTest(unittest.TestCase): - - def setUp(self): - self.file_desc = """ +class OversizeProtosTest(BaseTestCase): + + @classmethod + def setUpClass(cls): + # At the moment, reference cycles between DescriptorPool and Message classes + # are not detected and these objects are never freed. + # To avoid errors with ReferenceLeakChecker, we create the class only once. + file_desc = """ name: "f/f.msg2" package: "f" message_type { @@ -1828,10 +1841,12 @@ class OversizeProtosTest(unittest.TestCase): """ pool = descriptor_pool.DescriptorPool() desc = descriptor_pb2.FileDescriptorProto() - text_format.Parse(self.file_desc, desc) + text_format.Parse(file_desc, desc) pool.Add(desc) - self.proto_cls = message_factory.MessageFactory(pool).GetPrototype( + cls.proto_cls = message_factory.MessageFactory(pool).GetPrototype( pool.FindMessageTypeByName('f.msg2')) + + def setUp(self): self.p = self.proto_cls() self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) self.p_serialized = self.p.SerializeToString() |