diff options
Diffstat (limited to 'python/google/protobuf/internal/message_test.py')
-rwxr-xr-x | python/google/protobuf/internal/message_test.py | 279 |
1 files changed, 259 insertions, 20 deletions
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 61a56a67..4dd1104a 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -1,4 +1,5 @@ #! /usr/bin/env python +# -*- coding: utf-8 -*- # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. @@ -49,6 +50,7 @@ import copy import math import operator import pickle +import pydoc import six import sys import warnings @@ -72,12 +74,14 @@ from google.protobuf import message_factory from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import encoder +from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import testing_refleaks from google.protobuf import message from google.protobuf.internal import _parameterized +UCS2_MAXUNICODE = 65535 if six.PY3: long = int @@ -415,6 +419,37 @@ class MessageTest(BaseTestCase): empty.ParseFromString(populated.SerializeToString()) self.assertEqual(str(empty), '') + def testMergeFromRepeatedField(self, message_module): + msg = message_module.TestAllTypes() + msg.repeated_int32.append(1) + msg.repeated_int32.append(3) + msg.repeated_nested_message.add(bb=1) + msg.repeated_nested_message.add(bb=2) + other_msg = message_module.TestAllTypes() + other_msg.repeated_nested_message.add(bb=3) + other_msg.repeated_nested_message.add(bb=4) + other_msg.repeated_int32.append(5) + other_msg.repeated_int32.append(7) + + msg.repeated_int32.MergeFrom(other_msg.repeated_int32) + self.assertEqual(4, len(msg.repeated_int32)) + + msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message) + self.assertEqual([1, 2, 3, 4], + [m.bb for m in msg.repeated_nested_message]) + + def testAddWrongRepeatedNestedField(self, message_module): + msg = message_module.TestAllTypes() + try: + msg.repeated_nested_message.add('wrong') + except TypeError: + pass + try: + msg.repeated_nested_message.add(value_field='wrong') + except ValueError: + pass + self.assertEqual(len(msg.repeated_nested_message), 0) + def testRepeatedNestedFieldIteration(self, message_module): msg = message_module.TestAllTypes() msg.repeated_nested_message.add(bb=1) @@ -645,6 +680,82 @@ class MessageTest(BaseTestCase): m.payload.repeated_int32.extend([]) self.assertTrue(m.HasField('payload')) + def testMergeFrom(self, message_module): + m1 = message_module.TestAllTypes() + m2 = message_module.TestAllTypes() + # Cpp extension will lazily create a sub message which is immutable. + self.assertEqual(0, m1.optional_nested_message.bb) + m2.optional_nested_message.bb = 1 + # Make sure cmessage pointing to a mutable message after merge instead of + # the lazily created message. + m1.MergeFrom(m2) + self.assertEqual(1, m1.optional_nested_message.bb) + + # Test more nested sub message. + msg1 = message_module.NestedTestAllTypes() + msg2 = message_module.NestedTestAllTypes() + self.assertEqual(0, msg1.child.payload.optional_nested_message.bb) + msg2.child.payload.optional_nested_message.bb = 1 + msg1.MergeFrom(msg2) + self.assertEqual(1, msg1.child.payload.optional_nested_message.bb) + + # Test repeated field. + self.assertEqual(msg1.payload.repeated_nested_message, + msg1.payload.repeated_nested_message) + msg2.payload.repeated_nested_message.add().bb = 1 + msg1.MergeFrom(msg2) + self.assertEqual(1, len(msg1.payload.repeated_nested_message)) + self.assertEqual(1, msg1.payload.repeated_nested_message[0].bb) + + def testMergeFromString(self, message_module): + m1 = message_module.TestAllTypes() + m2 = message_module.TestAllTypes() + # Cpp extension will lazily create a sub message which is immutable. + self.assertEqual(0, m1.optional_nested_message.bb) + m2.optional_nested_message.bb = 1 + # Make sure cmessage pointing to a mutable message after merge instead of + # the lazily created message. + m1.MergeFromString(m2.SerializeToString()) + self.assertEqual(1, m1.optional_nested_message.bb) + + @unittest.skipIf(six.PY2, 'memoryview objects are not supported on py2') + def testMergeFromStringUsingMemoryViewWorksInPy3(self, message_module): + m2 = message_module.TestAllTypes() + m2.optional_string = 'scalar string' + m2.repeated_string.append('repeated string') + m2.optional_bytes = b'scalar bytes' + m2.repeated_bytes.append(b'repeated bytes') + + serialized = m2.SerializeToString() + memview = memoryview(serialized) + m1 = message_module.TestAllTypes.FromString(memview) + + self.assertEqual(m1.optional_bytes, b'scalar bytes') + self.assertEqual(m1.repeated_bytes, [b'repeated bytes']) + self.assertEqual(m1.optional_string, 'scalar string') + self.assertEqual(m1.repeated_string, ['repeated string']) + # Make sure that the memoryview was correctly converted to bytes, and + # that a sub-sliced memoryview is not being used. + self.assertIsInstance(m1.optional_bytes, bytes) + self.assertIsInstance(m1.repeated_bytes[0], bytes) + self.assertIsInstance(m1.optional_string, six.text_type) + self.assertIsInstance(m1.repeated_string[0], six.text_type) + + @unittest.skipIf(six.PY3, 'memoryview is supported by py3') + def testMergeFromStringUsingMemoryViewIsPy2Error(self, message_module): + memview = memoryview(b'') + with self.assertRaises(TypeError): + message_module.TestAllTypes.FromString(memview) + + def testMergeFromEmpty(self, message_module): + m1 = message_module.TestAllTypes() + # Cpp extension will lazily create a sub message which is immutable. + self.assertEqual(0, m1.optional_nested_message.bb) + self.assertFalse(m1.HasField('optional_nested_message')) + # Make sure the sub message is still immutable after merge from empty. + m1.MergeFromString(b'') # field state should not change + self.assertFalse(m1.HasField('optional_nested_message')) + def ensureNestedMessageExists(self, msg, attribute): """Make sure that a nested message object exists. @@ -1067,14 +1178,8 @@ class MessageTest(BaseTestCase): with self.assertRaises(AttributeError): m.repeated_int32 = [] m.repeated_int32.append(1) - if api_implementation.Type() == 'cpp': - # For test coverage: cpp has a different path if composite - # field is in cache - with self.assertRaises(TypeError): - m.repeated_int32 = [] - else: - with self.assertRaises(AttributeError): - m.repeated_int32 = [] + with self.assertRaises(AttributeError): + m.repeated_int32 = [] # Class to test proto2-only features (required, extensions, etc.) @@ -1112,13 +1217,13 @@ class Proto2Test(BaseTestCase): message.optional_bool = True message.optional_nested_message.bb = 15 - self.assertTrue(message.HasField("optional_int32")) + self.assertTrue(message.HasField(u"optional_int32")) self.assertTrue(message.HasField("optional_bool")) self.assertTrue(message.HasField("optional_nested_message")) # Clearing the fields unsets them and resets their value to default. message.ClearField("optional_int32") - message.ClearField("optional_bool") + message.ClearField(u"optional_bool") message.ClearField("optional_nested_message") self.assertFalse(message.HasField("optional_int32")) @@ -1169,6 +1274,21 @@ class Proto2Test(BaseTestCase): msg = unittest_pb2.TestAllTypes() self.assertRaises(AttributeError, getattr, msg, 'Extensions') + def testMergeFromExtensions(self): + msg1 = more_extensions_pb2.TopLevelMessage() + msg2 = more_extensions_pb2.TopLevelMessage() + # Cpp extension will lazily create a sub message which is immutable. + self.assertEqual(0, msg1.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + self.assertFalse(msg1.HasField('submessage')) + msg2.submessage.Extensions[ + more_extensions_pb2.optional_int_extension] = 123 + # Make sure cmessage and extensions pointing to a mutable message + # after merge instead of the lazily created message. + msg1.MergeFrom(msg2) + self.assertEqual(123, msg1.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + def testGoldenExtensions(self): golden_data = test_util.GoldenFileData('golden_message') golden_message = unittest_pb2.TestAllExtensions() @@ -1315,6 +1435,25 @@ class Proto2Test(BaseTestCase): with self.assertRaises(ValueError): unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') + def testPythonicInitWithDict(self): + # Both string/unicode field name keys should work. + kwargs = { + 'optional_int32': 100, + u'optional_fixed32': 200, + } + msg = unittest_pb2.TestAllTypes(**kwargs) + self.assertEqual(100, msg.optional_int32) + self.assertEqual(200, msg.optional_fixed32) + + + def test_documentation(self): + # Also used by the interactive help() function. + doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message') + self.assertIn('class TestAllTypes', doc) + self.assertIn('SerializePartialToString', doc) + self.assertIn('repeated_float', doc) + base = unittest_pb2.TestAllTypes.__bases__[0] + self.assertRaises(AttributeError, getattr, base, '_extensions_by_name') # Class to test proto3-only features/behavior (updated field presence & enums) @@ -1539,10 +1678,8 @@ class Proto3Test(BaseTestCase): self.assertEqual(True, msg2.map_bool_bool[True]) self.assertEqual(2, msg2.map_int32_enum[888]) self.assertEqual(456, msg2.map_int32_enum[123]) - # TODO(jieluo): Add cpp extension support. - if api_implementation.Type() == 'python': - self.assertEqual('{-123: -456}', - str(msg2.map_int32_int32)) + self.assertEqual('{-123: -456}', + str(msg2.map_int32_int32)) def testMapEntryAlwaysSerialized(self): msg = map_unittest_pb2.TestMap() @@ -1603,11 +1740,10 @@ class Proto3Test(BaseTestCase): self.assertIn(123, msg2.map_int32_foreign_message) self.assertIn(-456, msg2.map_int32_foreign_message) self.assertEqual(2, len(msg2.map_int32_foreign_message)) + msg2.map_int32_foreign_message[123].c = 1 # TODO(jieluo): Fix text format for message map. - # TODO(jieluo): Add cpp extension support. - if api_implementation.Type() == 'python': - self.assertEqual(15, - len(str(msg2.map_int32_foreign_message))) + self.assertIn(str(msg2.map_int32_foreign_message), + ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }')) def testNestedMessageMapItemDelete(self): msg = map_unittest_pb2.TestMap() @@ -1721,6 +1857,15 @@ class Proto3Test(BaseTestCase): self.assertEqual(10, msg2.map_int32_foreign_message[222].c) self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d')) + # Test when cpp extension cache a map. + m1 = map_unittest_pb2.TestMap() + m2 = map_unittest_pb2.TestMap() + self.assertEqual(m1.map_int32_foreign_message, + m1.map_int32_foreign_message) + m2.map_int32_foreign_message[123].c = 10 + m1.MergeFrom(m2) + self.assertEqual(10, m2.map_int32_foreign_message[123].c) + def testMergeFromBadType(self): msg = map_unittest_pb2.TestMap() with self.assertRaisesRegexp( @@ -1972,7 +2117,7 @@ class Proto3Test(BaseTestCase): def testMapValidAfterFieldCleared(self): # Map needs to work even if field is cleared. # For the C++ implementation this tests the correctness of - # ScalarMapContainer::Release() + # MapContainer::Release() msg = map_unittest_pb2.TestMap() int32_map = msg.map_int32_int32 @@ -1988,7 +2133,7 @@ class Proto3Test(BaseTestCase): def testMessageMapValidAfterFieldCleared(self): # Map needs to work even if field is cleared. # For the C++ implementation this tests the correctness of - # ScalarMapContainer::Release() + # MapContainer::Release() msg = map_unittest_pb2.TestMap() int32_foreign_message = msg.map_int32_foreign_message @@ -1998,6 +2143,24 @@ class Proto3Test(BaseTestCase): self.assertEqual(b'', msg.SerializeToString()) self.assertTrue(2 in int32_foreign_message.keys()) + def testMessageMapItemValidAfterTopMessageCleared(self): + # Message map item needs to work even if it is cleared. + # For the C++ implementation this tests the correctness of + # MapContainer::Release() + msg = map_unittest_pb2.TestMap() + msg.map_int32_all_types[2].optional_string = 'bar' + + if api_implementation.Type() == 'cpp': + # Need to keep the map reference because of b/27942626. + # TODO(jieluo): Remove it. + unused_map = msg.map_int32_all_types # pylint: disable=unused-variable + msg_value = msg.map_int32_all_types[2] + msg.Clear() + + # Reset to trigger sync between repeated field and map in c++. + msg.map_int32_all_types[3].optional_string = 'foo' + self.assertEqual(msg_value.optional_string, 'bar') + def testMapIterInvalidatedByClearField(self): # Map iterator is invalidated when field is cleared. # But this case does need to not crash the interpreter. @@ -2058,6 +2221,82 @@ class Proto3Test(BaseTestCase): msg.map_string_foreign_message['foo'].c = 5 self.assertEqual(0, len(msg.FindInitializationErrors())) + @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2') + def testStrictUtf8Check(self): + # Test u'\ud801' is rejected at parser in both python2 and python3. + serialized = (b'r\x03\xed\xa0\x81') + msg = unittest_proto3_arena_pb2.TestAllTypes() + with self.assertRaises(Exception) as context: + msg.MergeFromString(serialized) + if api_implementation.Type() == 'python': + self.assertIn('optional_string', str(context.exception)) + else: + self.assertIn('Error parsing message', str(context.exception)) + + # Test optional_string=u'😍' is accepted. + serialized = unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'😍').SerializeToString() + msg2 = unittest_proto3_arena_pb2.TestAllTypes() + msg2.MergeFromString(serialized) + self.assertEqual(msg2.optional_string, u'😍') + + msg = unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud001') + self.assertEqual(msg.optional_string, u'\ud001') + + @unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2') + def testSurrogatesInPython3(self): + # Surrogates like U+D83D is an invalid unicode character, it is + # supported by Python2 only because in some builds, unicode strings + # use 2-bytes code units. Since Python 3.3, we don't have this problem. + # + # Surrogates are utf16 code units, in a unicode string they are invalid + # characters even when they appear in pairs like u'\ud801\udc01'. Protobuf + # Python3 reject such cases at setters and parsers. Python2 accpect it + # to keep same features with the language itself. 'Unpaired pairs' + # like u'\ud801' are rejected at parsers when strict utf8 check is enabled + # in proto3 to keep same behavior with c extension. + + # Surrogates are rejected at setters in Python3. + with self.assertRaises(ValueError): + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801\udc01') + with self.assertRaises(ValueError): + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=b'\xed\xa0\x81') + with self.assertRaises(ValueError): + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801') + with self.assertRaises(ValueError): + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801\ud801') + + @unittest.skipIf(six.PY3 or sys.maxunicode == UCS2_MAXUNICODE, + 'Surrogates are rejected at setters in Python3') + def testSurrogatesInPython2(self): + # Test optional_string=u'\ud801\udc01'. + # surrogate pair is acceptable in python2. + msg = unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801\udc01') + # TODO(jieluo): Change pure python to have same behavior with c extension. + # Some build in python2 consider u'\ud801\udc01' and u'\U00010401' are + # equal, some are not equal. + if api_implementation.Type() == 'python': + self.assertEqual(msg.optional_string, u'\ud801\udc01') + else: + self.assertEqual(msg.optional_string, u'\U00010401') + serialized = msg.SerializeToString() + msg2 = unittest_proto3_arena_pb2.TestAllTypes() + msg2.MergeFromString(serialized) + self.assertEqual(msg2.optional_string, u'\U00010401') + + # Python2 does not reject surrogates at setters. + msg = unittest_proto3_arena_pb2.TestAllTypes( + optional_string=b'\xed\xa0\x81') + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801') + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801\ud801') class ValidTypeNamesTest(BaseTestCase): |