diff options
Diffstat (limited to 'python/google/protobuf/internal/message_test.py')
-rwxr-xr-x | python/google/protobuf/internal/message_test.py | 157 |
1 files changed, 140 insertions, 17 deletions
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 13c3caa6..4ee31d8e 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -53,21 +53,27 @@ import six import sys try: - import unittest2 as unittest + import unittest2 as unittest #PY26 except ImportError: import unittest -from google.protobuf.internal import _parameterized + from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor_pool +from google.protobuf import message_factory +from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf import message +from google.protobuf.internal import _parameterized if six.PY3: long = int + # Python pre-2.6 does not have isinf() or isnan() functions, so we have # to provide our own. def isnan(val): @@ -1156,6 +1162,7 @@ class Proto2Test(unittest.TestCase): unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') + # Class to test proto3-only features/behavior (updated field presence & enums) class Proto3Test(unittest.TestCase): @@ -1258,7 +1265,10 @@ class Proto3Test(unittest.TestCase): self.assertFalse(-2**33 in msg.map_int64_int64) self.assertFalse(123 in msg.map_uint32_uint32) self.assertFalse(2**33 in msg.map_uint64_uint64) + self.assertFalse(123 in msg.map_int32_double) + self.assertFalse(False in msg.map_bool_bool) self.assertFalse('abc' in msg.map_string_string) + self.assertFalse(111 in msg.map_int32_bytes) self.assertFalse(888 in msg.map_int32_enum) # Accessing an unset key returns the default. @@ -1266,7 +1276,12 @@ class Proto3Test(unittest.TestCase): self.assertEqual(0, msg.map_int64_int64[-2**33]) self.assertEqual(0, msg.map_uint32_uint32[123]) self.assertEqual(0, msg.map_uint64_uint64[2**33]) + self.assertEqual(0.0, msg.map_int32_double[123]) + self.assertTrue(isinstance(msg.map_int32_double[123], float)) + self.assertEqual(False, msg.map_bool_bool[False]) + self.assertTrue(isinstance(msg.map_bool_bool[False], bool)) self.assertEqual('', msg.map_string_string['abc']) + self.assertEqual(b'', msg.map_int32_bytes[111]) self.assertEqual(0, msg.map_int32_enum[888]) # It also sets the value in the map @@ -1274,17 +1289,21 @@ class Proto3Test(unittest.TestCase): self.assertTrue(-2**33 in msg.map_int64_int64) self.assertTrue(123 in msg.map_uint32_uint32) self.assertTrue(2**33 in msg.map_uint64_uint64) + self.assertTrue(123 in msg.map_int32_double) + self.assertTrue(False in msg.map_bool_bool) self.assertTrue('abc' in msg.map_string_string) + self.assertTrue(111 in msg.map_int32_bytes) self.assertTrue(888 in msg.map_int32_enum) self.assertIsInstance(msg.map_string_string['abc'], six.text_type) - # Accessing an unset key still throws TypeError of the type of the key + # Accessing an unset key still throws TypeError if the type of the key # is incorrect. with self.assertRaises(TypeError): msg.map_string_string[123] - self.assertFalse(123 in msg.map_string_string) + with self.assertRaises(TypeError): + 123 in msg.map_string_string def testMapGet(self): # Need to test that get() properly returns the default, even though the dict @@ -1446,6 +1465,22 @@ class Proto3Test(unittest.TestCase): del msg2.map_int32_foreign_message[222] self.assertFalse(222 in msg2.map_int32_foreign_message) + def testMergeFromBadType(self): + msg = map_unittest_pb2.TestMap() + with self.assertRaisesRegexp( + TypeError, + r'Parameter to MergeFrom\(\) must be instance of same class: expected ' + r'.*TestMap got int\.'): + msg.MergeFrom(1) + + def testCopyFromBadType(self): + msg = map_unittest_pb2.TestMap() + with self.assertRaisesRegexp( + TypeError, + r'Parameter to [A-Za-z]*From\(\) must be instance of same class: ' + r'expected .*TestMap got int\.'): + msg.CopyFrom(1) + def testIntegerMapWithLongs(self): msg = map_unittest_pb2.TestMap() msg.map_int32_int32[long(-123)] = long(-456) @@ -1563,6 +1598,21 @@ class Proto3Test(unittest.TestCase): matching_dict = {2: 4, 3: 6, 4: 8} self.assertMapIterEquals(msg.map_int32_int32.items(), matching_dict) + def testMapItems(self): + # Map items used to have strange behaviors when use c extension. Because + # [] may reorder the map and invalidate any exsting iterators. + # TODO(jieluo): Check if [] reordering the map is a bug or intended + # behavior. + msg = map_unittest_pb2.TestMap() + msg.map_string_string['local_init_op'] = '' + msg.map_string_string['trainable_variables'] = '' + msg.map_string_string['variables'] = '' + msg.map_string_string['init_op'] = '' + msg.map_string_string['summaries'] = '' + items1 = msg.map_string_string.items() + items2 = msg.map_string_string.items() + self.assertEqual(items1, items2) + def testMapIterationClearMessage(self): # Iterator needs to work even if message and map are deleted. msg = map_unittest_pb2.TestMap() @@ -1591,31 +1641,49 @@ class Proto3Test(unittest.TestCase): # For the C++ implementation this tests the correctness of # ScalarMapContainer::Release() msg = map_unittest_pb2.TestMap() - map = msg.map_int32_int32 + int32_map = msg.map_int32_int32 - map[2] = 4 - map[3] = 6 - map[4] = 8 + int32_map[2] = 4 + int32_map[3] = 6 + int32_map[4] = 8 msg.ClearField('map_int32_int32') + self.assertEqual(b'', msg.SerializeToString()) matching_dict = {2: 4, 3: 6, 4: 8} - self.assertMapIterEquals(map.items(), matching_dict) + self.assertMapIterEquals(int32_map.items(), matching_dict) - def testMapIterValidAfterFieldCleared(self): - # Map iterator needs to work even if field is cleared. + def testMessageMapValidAfterFieldCleared(self): + # Map needs to work even if field is cleared. # For the C++ implementation this tests the correctness of # ScalarMapContainer::Release() msg = map_unittest_pb2.TestMap() + int32_foreign_message = msg.map_int32_foreign_message - msg.map_int32_int32[2] = 4 - msg.map_int32_int32[3] = 6 - msg.map_int32_int32[4] = 8 + int32_foreign_message[2].c = 5 - it = msg.map_int32_int32.items() + msg.ClearField('map_int32_foreign_message') + self.assertEqual(b'', msg.SerializeToString()) + self.assertTrue(2 in int32_foreign_message.keys()) + + def testMapIterInvalidatedByClearField(self): + # Map iterator is invalidated when field is cleared. + # But this case does need to not crash the interpreter. + # For the C++ implementation this tests the correctness of + # ScalarMapContainer::Release() + msg = map_unittest_pb2.TestMap() + + it = iter(msg.map_int32_int32) msg.ClearField('map_int32_int32') - matching_dict = {2: 4, 3: 6, 4: 8} - self.assertMapIterEquals(it, matching_dict) + with self.assertRaises(RuntimeError): + for _ in it: + pass + + it = iter(msg.map_int32_foreign_message) + msg.ClearField('map_int32_foreign_message') + with self.assertRaises(RuntimeError): + for _ in it: + pass def testMapDelete(self): msg = map_unittest_pb2.TestMap() @@ -1725,5 +1793,60 @@ class PackedFieldTest(unittest.TestCase): b'\x70\x01') self.assertEqual(golden_data, message.SerializeToString()) + +@unittest.skipIf(api_implementation.Type() != 'cpp', + 'explicit tests of the C++ implementation') +class OversizeProtosTest(unittest.TestCase): + + def setUp(self): + self.file_desc = """ + name: "f/f.msg2" + package: "f" + message_type { + name: "msg1" + field { + name: "payload" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_STRING + } + } + message_type { + name: "msg2" + field { + name: "field" + number: 1 + label: LABEL_OPTIONAL + type: TYPE_MESSAGE + type_name: "msg1" + } + } + """ + pool = descriptor_pool.DescriptorPool() + desc = descriptor_pb2.FileDescriptorProto() + text_format.Parse(self.file_desc, desc) + pool.Add(desc) + self.proto_cls = message_factory.MessageFactory(pool).GetPrototype( + pool.FindMessageTypeByName('f.msg2')) + self.p = self.proto_cls() + self.p.field.payload = 'c' * (1024 * 1024 * 64 + 1) + self.p_serialized = self.p.SerializeToString() + + def testAssertOversizeProto(self): + from google.protobuf.pyext._message import SetAllowOversizeProtos + SetAllowOversizeProtos(False) + q = self.proto_cls() + try: + q.ParseFromString(self.p_serialized) + except message.DecodeError as e: + self.assertEqual(str(e), 'Error parsing message') + + def testSucceedOversizeProto(self): + from google.protobuf.pyext._message import SetAllowOversizeProtos + SetAllowOversizeProtos(True) + q = self.proto_cls() + q.ParseFromString(self.p_serialized) + self.assertEqual(self.p.field.payload, q.field.payload) + if __name__ == '__main__': unittest.main() |