diff options
Diffstat (limited to 'python/google/protobuf/internal/text_format_test.py')
-rwxr-xr-x | python/google/protobuf/internal/text_format_test.py | 398 |
1 files changed, 356 insertions, 42 deletions
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 237a2d50..ccf8ac16 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -33,20 +33,19 @@ """Test for google.protobuf.text_format.""" -__author__ = 'kenton@google.com (Kenton Varda)' - - +import io import math import re -import six import string +import textwrap + +import six +# pylint: disable=g-import-not-at-top try: - import unittest2 as unittest # PY26, pylint: disable=g-import-not-at-top + import unittest2 as unittest # PY26 except ImportError: - import unittest # pylint: disable=g-import-not-at-top - -from google.protobuf.internal import _parameterized + import unittest from google.protobuf import any_pb2 from google.protobuf import any_test_pb2 @@ -54,12 +53,13 @@ from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 -from google.protobuf.internal import api_implementation from google.protobuf.internal import any_test_pb2 as test_extend_any from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import test_util from google.protobuf import descriptor_pool from google.protobuf import text_format +from google.protobuf.internal import _parameterized +# pylint: enable=g-import-not-at-top # Low-level nuts-n-bolts tests. @@ -100,8 +100,8 @@ class TextFormatBase(unittest.TestCase): return text -@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2)) -class TextFormatTest(TextFormatBase): +@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2) +class TextFormatMessageToStringTests(TextFormatBase): def testPrintExotic(self, message_module): message = message_module.TestAllTypes() @@ -154,6 +154,40 @@ class TextFormatTest(TextFormatBase): 'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 ' 'repeated_string: "Google" repeated_string: "Zurich"') + def VerifyPrintShortFormatRepeatedFields(self, message_module, as_one_line): + message = message_module.TestAllTypes() + message.repeated_int32.append(1) + message.repeated_string.append('Google') + message.repeated_string.append('Hello,World') + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_FOO) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) + message.optional_nested_message.bb = 3 + for i in (21, 32): + msg = message.repeated_nested_message.add() + msg.bb = i + expected_ascii = ( + 'optional_nested_message {\n bb: 3\n}\n' + 'repeated_int32: [1]\n' + 'repeated_string: "Google"\n' + 'repeated_string: "Hello,World"\n' + 'repeated_nested_message {\n bb: 21\n}\n' + 'repeated_nested_message {\n bb: 32\n}\n' + 'repeated_foreign_enum: [FOREIGN_FOO, FOREIGN_BAR, FOREIGN_BAZ]\n') + if as_one_line: + expected_ascii = expected_ascii.replace('\n ', '').replace('\n', '') + actual_ascii = text_format.MessageToString( + message, use_short_repeated_primitives=True, + as_one_line=as_one_line) + self.CompareToGoldenText(actual_ascii, expected_ascii) + parsed_message = message_module.TestAllTypes() + text_format.Parse(actual_ascii, parsed_message) + self.assertEqual(parsed_message, message) + + def tesPrintShortFormatRepeatedFields(self, message_module, as_one_line): + self.VerifyPrintShortFormatRepeatedFields(message_module, False) + self.VerifyPrintShortFormatRepeatedFields(message_module, True) + def testPrintNestedNewLineInStringAsOneLine(self, message_module): message = message_module.TestAllTypes() message.optional_string = 'a\nnew\nline' @@ -213,13 +247,18 @@ class TextFormatTest(TextFormatBase): def testPrintRawUtf8String(self, message_module): message = message_module.TestAllTypes() - message.repeated_string.append(u'\u00fc\ua71f') + message.repeated_string.append(u'\u00fc\t\ua71f') text = text_format.MessageToString(message, as_utf8=True) - self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n') + golden_unicode = u'repeated_string: "\u00fc\\t\ua71f"\n' + golden_text = golden_unicode if six.PY3 else golden_unicode.encode('utf-8') + # MessageToString always returns a native str. + self.CompareToGoldenText(text, golden_text) parsed_message = message_module.TestAllTypes() text_format.Parse(text, parsed_message) - self.assertEqual(message, parsed_message, - '\n%s != %s' % (message, parsed_message)) + self.assertEqual( + message, parsed_message, '\n%s != %s (%s != %s)' % + (message, parsed_message, message.repeated_string[0], + parsed_message.repeated_string[0])) def testPrintFloatFormat(self, message_module): # Check that float_format argument is passed to sub-message formatting. @@ -259,6 +298,36 @@ class TextFormatTest(TextFormatBase): message.c = 123 self.assertEqual('c: 123\n', str(message)) + def testMessageToStringUnicode(self, message_module): + golden_unicode = u'Á short desçription and a 🍌.' + golden_bytes = golden_unicode.encode('utf-8') + message = message_module.TestAllTypes() + message.optional_string = golden_unicode + message.optional_bytes = golden_bytes + text = text_format.MessageToString(message, as_utf8=True) + golden_message = textwrap.dedent( + 'optional_string: "Á short desçription and a 🍌."\n' + 'optional_bytes: ' + r'"\303\201 short des\303\247ription and a \360\237\215\214."' + '\n') + self.CompareToGoldenText(text, golden_message) + + def testMessageToStringASCII(self, message_module): + golden_unicode = u'Á short desçription and a 🍌.' + golden_bytes = golden_unicode.encode('utf-8') + message = message_module.TestAllTypes() + message.optional_string = golden_unicode + message.optional_bytes = golden_bytes + text = text_format.MessageToString(message, as_utf8=False) # ASCII + golden_message = ( + 'optional_string: ' + r'"\303\201 short des\303\247ription and a \360\237\215\214."' + '\n' + 'optional_bytes: ' + r'"\303\201 short des\303\247ription and a \360\237\215\214."' + '\n') + self.CompareToGoldenText(text, golden_message) + def testPrintField(self, message_module): message = message_module.TestAllTypes() field = message.DESCRIPTOR.fields_by_name['optional_float'] @@ -289,6 +358,45 @@ class TextFormatTest(TextFormatBase): self.assertEqual('0.0', out.getvalue()) out.close() + +@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2) +class TextFormatMessageToTextBytesTests(TextFormatBase): + + def testMessageToBytes(self, message_module): + message = message_module.ForeignMessage() + message.c = 123 + self.assertEqual(b'c: 123\n', text_format.MessageToBytes(message)) + + def testRawUtf8RoundTrip(self, message_module): + message = message_module.TestAllTypes() + message.repeated_string.append(u'\u00fc\t\ua71f') + utf8_text = text_format.MessageToBytes(message, as_utf8=True) + golden_bytes = b'repeated_string: "\xc3\xbc\\t\xea\x9c\x9f"\n' + self.CompareToGoldenText(utf8_text, golden_bytes) + parsed_message = message_module.TestAllTypes() + text_format.Parse(utf8_text, parsed_message) + self.assertEqual( + message, parsed_message, '\n%s != %s (%s != %s)' % + (message, parsed_message, message.repeated_string[0], + parsed_message.repeated_string[0])) + + def testEscapedUtf8ASCIIRoundTrip(self, message_module): + message = message_module.TestAllTypes() + message.repeated_string.append(u'\u00fc\t\ua71f') + ascii_text = text_format.MessageToBytes(message) # as_utf8=False default + golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n' + self.CompareToGoldenText(ascii_text, golden_bytes) + parsed_message = message_module.TestAllTypes() + text_format.Parse(ascii_text, parsed_message) + self.assertEqual( + message, parsed_message, '\n%s != %s (%s != %s)' % + (message, parsed_message, message.repeated_string[0], + parsed_message.repeated_string[0])) + + +@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2) +class TextFormatParserTests(TextFormatBase): + def testParseAllFields(self, message_module): message = message_module.TestAllTypes() test_util.SetAllFields(message) @@ -318,14 +426,14 @@ class TextFormatTest(TextFormatBase): if message_module is unittest_pb2: test_util.ExpectAllFieldsSet(self, message) - if six.PY2: - msg2 = message_module.TestAllTypes() - text = (u'optional_string: "café"') - text_format.Merge(text, msg2) - self.assertEqual(msg2.optional_string, u'café') - msg2.Clear() - text_format.Parse(text, msg2) - self.assertEqual(msg2.optional_string, u'café') + msg2 = message_module.TestAllTypes() + text = (u'optional_string: "café"') + text_format.Merge(text, msg2) + self.assertEqual(msg2.optional_string, u'café') + msg2.Clear() + self.assertEqual(msg2.optional_string, u'') + text_format.Parse(text, msg2) + self.assertEqual(msg2.optional_string, u'café') def testParseExotic(self, message_module): message = message_module.TestAllTypes() @@ -425,7 +533,8 @@ class TextFormatTest(TextFormatBase): message = message_module.TestAllTypes() text = 'optional_nested_enum: BARR' six.assertRaisesRegex(self, text_format.ParseError, - (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + (r'1:23 : \'optional_nested_enum: BARR\': ' + r'Enum type "\w+.TestAllTypes.NestedEnum" ' r'has no value named BARR.'), text_format.Parse, text, message) @@ -433,7 +542,8 @@ class TextFormatTest(TextFormatBase): message = message_module.TestAllTypes() text = 'optional_int32: bork' six.assertRaisesRegex(self, text_format.ParseError, - ('1:17 : Couldn\'t parse integer: bork'), + ('1:17 : \'optional_int32: bork\': ' + 'Couldn\'t parse integer: bork'), text_format.Parse, text, message) def testParseStringFieldUnescape(self, message_module): @@ -457,6 +567,96 @@ class TextFormatTest(TextFormatBase): message.repeated_string[4]) self.assertEqual(SLASH + 'x20', message.repeated_string[5]) + def testParseOneof(self, message_module): + m = message_module.TestAllTypes() + m.oneof_uint32 = 11 + m2 = message_module.TestAllTypes() + text_format.Parse(text_format.MessageToString(m), m2) + self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + + def testParseMultipleOneof(self, message_module): + m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) + m2 = message_module.TestAllTypes() + with six.assertRaisesRegex(self, text_format.ParseError, + ' is specified along with field '): + text_format.Parse(m_string, m2) + + # This example contains non-ASCII codepoint unicode data as literals + # which should come through as utf-8 for bytes, and as the unicode + # itself for string fields. It also demonstrates escaped binary data. + # The ur"" string prefix is unfortunately missing from Python 3 + # so we resort to double escaping our \s so that they come through. + _UNICODE_SAMPLE = u""" + optional_bytes: 'Á short desçription' + optional_string: 'Á short desçription' + repeated_bytes: '\\303\\201 short des\\303\\247ription' + repeated_bytes: '\\x12\\x34\\x56\\x78\\x90\\xab\\xcd\\xef' + repeated_string: '\\xd0\\x9f\\xd1\\x80\\xd0\\xb8\\xd0\\xb2\\xd0\\xb5\\xd1\\x82' + """ + _BYTES_SAMPLE = _UNICODE_SAMPLE.encode('utf-8') + _GOLDEN_UNICODE = u'Á short desçription' + _GOLDEN_BYTES = _GOLDEN_UNICODE.encode('utf-8') + _GOLDEN_BYTES_1 = b'\x12\x34\x56\x78\x90\xab\xcd\xef' + _GOLDEN_STR_0 = u'Привет' + + def testParseUnicode(self, message_module): + m = message_module.TestAllTypes() + text_format.Parse(self._UNICODE_SAMPLE, m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + # repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data. + self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1) + # repeated_string[0] contained \ escaped data representing the UTF-8 + # representation of _GOLDEN_STR_0 - it needs to decode as such. + self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0) + + def testParseBytes(self, message_module): + m = message_module.TestAllTypes() + text_format.Parse(self._BYTES_SAMPLE, m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + # repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data. + self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1) + # repeated_string[0] contained \ escaped data representing the UTF-8 + # representation of _GOLDEN_STR_0 - it needs to decode as such. + self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0) + + def testFromBytesFile(self, message_module): + m = message_module.TestAllTypes() + f = io.BytesIO(self._BYTES_SAMPLE) + text_format.ParseLines(f, m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + + def testFromUnicodeFile(self, message_module): + m = message_module.TestAllTypes() + f = io.StringIO(self._UNICODE_SAMPLE) + text_format.ParseLines(f, m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + + def testFromBytesLines(self, message_module): + m = message_module.TestAllTypes() + text_format.ParseLines(self._BYTES_SAMPLE.split(b'\n'), m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + + def testFromUnicodeLines(self, message_module): + m = message_module.TestAllTypes() + text_format.ParseLines(self._UNICODE_SAMPLE.split(u'\n'), m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + + +@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2) +class TextFormatMergeTests(TextFormatBase): + def testMergeDuplicateScalars(self, message_module): message = message_module.TestAllTypes() text = ('optional_int32: 42 ' 'optional_int32: 67') @@ -472,26 +672,12 @@ class TextFormatTest(TextFormatBase): self.assertTrue(r is message) self.assertEqual(2, message.optional_nested_message.bb) - def testParseOneof(self, message_module): - m = message_module.TestAllTypes() - m.oneof_uint32 = 11 - m2 = message_module.TestAllTypes() - text_format.Parse(text_format.MessageToString(m), m2) - self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) - def testMergeMultipleOneof(self, message_module): m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) m2 = message_module.TestAllTypes() text_format.Merge(m_string, m2) self.assertEqual('oneof_string', m2.WhichOneof('oneof_field')) - def testParseMultipleOneof(self, message_module): - m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) - m2 = message_module.TestAllTypes() - with self.assertRaisesRegexp(text_format.ParseError, - ' is specified along with field '): - text_format.Parse(m_string, m2) - # These are tests that aren't fundamentally specific to proto2, but are at # the moment because of differences between the proto2 and proto3 test schemas. @@ -649,6 +835,29 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): ' }\n' '}\n') + # In cpp implementation, __str__ calls the cpp implementation of text format. + def testPrintMapUsingCppImplementation(self): + message = map_unittest_pb2.TestMap() + inner_msg = message.map_int32_foreign_message[111] + inner_msg.c = 1 + self.assertEqual( + str(message), + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 1\n' + ' }\n' + '}\n') + inner_msg.c = 2 + self.assertEqual( + str(message), + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 2\n' + ' }\n' + '}\n') + def testMapOrderEnforcement(self): message = map_unittest_pb2.TestMap() for letter in string.ascii_uppercase[13:26]: @@ -938,7 +1147,7 @@ class Proto2Tests(TextFormatBase): '}\n') six.assertRaisesRegex(self, text_format.ParseError, - '5:1 : Expected ">".', + '5:1 : \'}\': Expected ">".', text_format.Parse, malformed, message, @@ -981,7 +1190,8 @@ class Proto2Tests(TextFormatBase): with self.assertRaises(text_format.ParseError) as e: text_format.Parse(text, message) self.assertEqual(str(e.exception), - '1:27 : Expected identifier or number, got "bb".') + '1:27 : \'optional_nested_message { "bb": 1 }\': ' + 'Expected identifier or number, got "bb".') def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() @@ -998,7 +1208,8 @@ class Proto2Tests(TextFormatBase): message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: 100' six.assertRaisesRegex(self, text_format.ParseError, - (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + (r'1:23 : \'optional_nested_enum: 100\': ' + r'Enum type "\w+.TestAllTypes.NestedEnum" ' r'has no value with number 100.'), text_format.Parse, text, message) @@ -1209,6 +1420,24 @@ class Proto3Tests(unittest.TestCase): ' < data: "string" > ' '>') + def testPrintAndParseMessageInvalidAny(self): + packed_message = unittest_pb2.OneString() + packed_message.data = 'string' + message = any_test_pb2.TestAny() + message.any_value.Pack(packed_message) + # Only include string after last '/' in type_url. + message.any_value.type_url = message.any_value.TypeName() + text = text_format.MessageToString(message) + self.assertEqual( + text, 'any_value {\n' + ' type_url: "protobuf_unittest.OneString"\n' + ' value: "\\n\\006string"\n' + '}\n') + + parsed_message = any_test_pb2.TestAny() + text_format.Parse(text, parsed_message) + self.assertEqual(message, parsed_message) + def testUnknownEnums(self): message = unittest_proto3_arena_pb2.TestAllTypes() message2 = unittest_proto3_arena_pb2.TestAllTypes() @@ -1448,6 +1677,26 @@ class TokenizerTest(unittest.TestCase): self.assertEqual(0, text_format._ConsumeUint64(tokenizer)) self.assertTrue(tokenizer.AtEnd()) + def testConsumeOctalIntegers(self): + """Test support for C style octal integers.""" + text = '00 -00 04 0755 -010 007 -0033 08 -09 01' + tokenizer = text_format.Tokenizer(text.splitlines()) + self.assertEqual(0, tokenizer.ConsumeInteger()) + self.assertEqual(0, tokenizer.ConsumeInteger()) + self.assertEqual(4, tokenizer.ConsumeInteger()) + self.assertEqual(0o755, tokenizer.ConsumeInteger()) + self.assertEqual(-0o10, tokenizer.ConsumeInteger()) + self.assertEqual(7, tokenizer.ConsumeInteger()) + self.assertEqual(-0o033, tokenizer.ConsumeInteger()) + with self.assertRaises(text_format.ParseError): + tokenizer.ConsumeInteger() # 08 + tokenizer.NextToken() + with self.assertRaises(text_format.ParseError): + tokenizer.ConsumeInteger() # -09 + tokenizer.NextToken() + self.assertEqual(1, tokenizer.ConsumeInteger()) + self.assertTrue(tokenizer.AtEnd()) + def testConsumeByteString(self): text = '"string1\'' tokenizer = text_format.Tokenizer(text.splitlines()) @@ -1556,6 +1805,12 @@ class TokenizerTest(unittest.TestCase): tokenizer.ConsumeCommentOrTrailingComment()) self.assertTrue(tokenizer.AtEnd()) + def testHugeString(self): + # With pathologic backtracking, fails with Forge OOM. + text = '"' + 'a' * (10 * 1024 * 1024) + '"' + tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False) + tokenizer.ConsumeString() + # Tests for pretty printer functionality. @_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2)) @@ -1652,5 +1907,64 @@ class PrettyPrinterTest(TextFormatBase): 'repeated_nested_message { My lucky number is 42 } ' 'repeated_nested_message { My lucky number is 99 }')) + +class WhitespaceTest(TextFormatBase): + + def setUp(self): + self.out = text_format.TextWriter(False) + self.addCleanup(self.out.close) + self.message = unittest_pb2.NestedTestAllTypes() + self.message.child.payload.optional_string = 'value' + self.field = self.message.DESCRIPTOR.fields_by_name['child'] + self.value = self.message.child + + def testMessageToString(self): + self.CompareToGoldenText( + text_format.MessageToString(self.message), + textwrap.dedent("""\ + child { + payload { + optional_string: "value" + } + } + """)) + + def testPrintMessage(self): + text_format.PrintMessage(self.message, self.out) + self.CompareToGoldenText( + self.out.getvalue(), + textwrap.dedent("""\ + child { + payload { + optional_string: "value" + } + } + """)) + + def testPrintField(self): + text_format.PrintField(self.field, self.value, self.out) + self.CompareToGoldenText( + self.out.getvalue(), + textwrap.dedent("""\ + child { + payload { + optional_string: "value" + } + } + """)) + + def testPrintFieldValue(self): + text_format.PrintFieldValue( + self.field, self.value, self.out) + self.CompareToGoldenText( + self.out.getvalue(), + textwrap.dedent("""\ + { + payload { + optional_string: "value" + } + }""")) + + if __name__ == '__main__': unittest.main() |