From d64a2d9941c36a7bc2a7959ea10ab8363192ac14 Mon Sep 17 00:00:00 2001 From: Adam Cozzette Date: Wed, 29 Jun 2016 15:23:27 -0700 Subject: Integrated internal changes from Google This includes all internal changes from around May 20 to now. --- python/google/protobuf/descriptor.py | 42 +- python/google/protobuf/descriptor_pool.py | 65 +- python/google/protobuf/internal/containers.py | 6 +- .../protobuf/internal/descriptor_pool_test.py | 18 + python/google/protobuf/internal/descriptor_test.py | 41 +- .../protobuf/internal/file_options_test.proto | 43 ++ .../google/protobuf/internal/json_format_test.py | 13 + python/google/protobuf/internal/message_test.py | 4 + .../google/protobuf/internal/text_format_test.py | 624 +++++++++++------ python/google/protobuf/json_format.py | 762 +++++++++++---------- python/google/protobuf/pyext/descriptor.cc | 277 +++++++- python/google/protobuf/pyext/descriptor.h | 6 + .../google/protobuf/pyext/descriptor_containers.cc | 158 ++++- .../google/protobuf/pyext/descriptor_containers.h | 8 + python/google/protobuf/pyext/descriptor_pool.cc | 38 + python/google/protobuf/pyext/map_container.cc | 1 - python/google/protobuf/pyext/message.cc | 66 +- python/google/protobuf/pyext/message.h | 6 +- python/google/protobuf/pyext/message_module.cc | 88 +++ python/google/protobuf/text_format.py | 613 ++++++++++++----- 20 files changed, 2007 insertions(+), 872 deletions(-) create mode 100644 python/google/protobuf/internal/file_options_test.proto create mode 100644 python/google/protobuf/pyext/message_module.cc (limited to 'python/google') diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index 3209b34d..2eba1232 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -258,7 +258,7 @@ class Descriptor(_NestedDescriptorBase): def __new__(cls, name, full_name, filename, containing_type, fields, nested_types, enum_types, extensions, options=None, is_extendable=True, extension_ranges=None, oneofs=None, - file=None, serialized_start=None, serialized_end=None, + file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin syntax=None): _message.Message._CheckCalledFromGeneratedFile() return _message.default_pool.FindMessageTypeByName(full_name) @@ -269,8 +269,8 @@ class Descriptor(_NestedDescriptorBase): def __init__(self, name, full_name, filename, containing_type, fields, nested_types, enum_types, extensions, options=None, is_extendable=True, extension_ranges=None, oneofs=None, - file=None, serialized_start=None, serialized_end=None, - syntax=None): # pylint:disable=redefined-builtin + file=None, serialized_start=None, serialized_end=None, # pylint: disable=redefined-builtin + syntax=None): """Arguments to __init__() are as described in the description of Descriptor fields above. @@ -665,7 +665,7 @@ class EnumValueDescriptor(DescriptorBase): self.type = type -class OneofDescriptor(object): +class OneofDescriptor(DescriptorBase): """Descriptor for a oneof field. name: (str) Name of the oneof field. @@ -682,12 +682,15 @@ class OneofDescriptor(object): if _USE_C_DESCRIPTORS: _C_DESCRIPTOR_CLASS = _message.OneofDescriptor - def __new__(cls, name, full_name, index, containing_type, fields): + def __new__( + cls, name, full_name, index, containing_type, fields, options=None): _message.Message._CheckCalledFromGeneratedFile() return _message.default_pool.FindOneofByName(full_name) - def __init__(self, name, full_name, index, containing_type, fields): + def __init__( + self, name, full_name, index, containing_type, fields, options=None): """Arguments are as described in the attribute description above.""" + super(OneofDescriptor, self).__init__(options, 'OneofOptions') self.name = name self.full_name = full_name self.index = index @@ -705,11 +708,22 @@ class ServiceDescriptor(_NestedDescriptorBase): definition appears withing the .proto file. methods: (list of MethodDescriptor) List of methods provided by this service. + methods_by_name: (dict str -> MethodDescriptor) Same MethodDescriptor + objects as in |methods_by_name|, but indexed by "name" attribute in each + MethodDescriptor. options: (descriptor_pb2.ServiceOptions) Service options message or None to use default service options. file: (FileDescriptor) Reference to file info. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.ServiceDescriptor + + def __new__(cls, name, full_name, index, methods, options=None, file=None, # pylint: disable=redefined-builtin + serialized_start=None, serialized_end=None): + _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access + return _message.default_pool.FindServiceByName(full_name) + def __init__(self, name, full_name, index, methods, options=None, file=None, serialized_start=None, serialized_end=None): super(ServiceDescriptor, self).__init__( @@ -718,16 +732,14 @@ class ServiceDescriptor(_NestedDescriptorBase): serialized_end=serialized_end) self.index = index self.methods = methods + self.methods_by_name = dict((m.name, m) for m in methods) # Set the containing service for each method in this service. for method in self.methods: method.containing_service = self def FindMethodByName(self, name): """Searches for the specified method, and returns its descriptor.""" - for method in self.methods: - if name == method.name: - return method - return None + return self.methods_by_name.get(name, None) def CopyToProto(self, proto): """Copies this to a descriptor_pb2.ServiceDescriptorProto. @@ -754,6 +766,14 @@ class MethodDescriptor(DescriptorBase): None to use default method options. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.MethodDescriptor + + def __new__(cls, name, full_name, index, containing_service, + input_type, output_type, options=None): + _message.Message._CheckCalledFromGeneratedFile() # pylint: disable=protected-access + return _message.default_pool.FindMethodByName(full_name) + def __init__(self, name, full_name, index, containing_service, input_type, output_type, options=None): """The arguments are as described in the description of MethodDescriptor @@ -788,6 +808,7 @@ class FileDescriptor(DescriptorBase): message_types_by_name: Dict of message names of their descriptors. enum_types_by_name: Dict of enum names and their descriptors. extensions_by_name: Dict of extension names and their descriptors. + services_by_name: Dict of services names and their descriptors. pool: the DescriptorPool this descriptor belongs to. When not passed to the constructor, the global default pool is used. """ @@ -825,6 +846,7 @@ class FileDescriptor(DescriptorBase): self.enum_types_by_name = {} self.extensions_by_name = {} + self.services_by_name = {} self.dependencies = (dependencies or []) self.public_dependencies = (public_dependencies or []) diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 20a33701..5c055ab9 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -394,6 +394,11 @@ class DescriptorPool(object): desc_proto_prefix, desc_proto.name, scope) file_descriptor.message_types_by_name[desc_proto.name] = desc + for index, service_proto in enumerate(file_proto.service): + file_descriptor.services_by_name[service_proto.name] = ( + self._MakeServiceDescriptor(service_proto, index, scope, + file_proto.package, file_descriptor)) + self.Add(file_proto) self._file_descriptors[file_proto.name] = file_descriptor @@ -441,7 +446,7 @@ class DescriptorPool(object): for index, extension in enumerate(desc_proto.extension)] oneofs = [ descriptor.OneofDescriptor(desc.name, '.'.join((desc_name, desc.name)), - index, None, []) + index, None, [], desc.options) for index, desc in enumerate(desc_proto.oneof_decl)] extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] if extension_ranges: @@ -679,6 +684,64 @@ class DescriptorPool(object): options=value_proto.options, type=None) + def _MakeServiceDescriptor(self, service_proto, service_index, scope, + package, file_desc): + """Make a protobuf ServiceDescriptor given a ServiceDescriptorProto. + + Args: + service_proto: The descriptor_pb2.ServiceDescriptorProto protobuf message. + service_index: The index of the service in the File. + scope: Dict mapping short and full symbols to message and enum types. + package: Optional package name for the new message EnumDescriptor. + file_desc: The file containing the service descriptor. + + Returns: + The added descriptor. + """ + + if package: + service_name = '.'.join((package, service_proto.name)) + else: + service_name = service_proto.name + + methods = [self._MakeMethodDescriptor(method_proto, service_name, package, + scope, index) + for index, method_proto in enumerate(service_proto.method)] + desc = descriptor.ServiceDescriptor(name=service_proto.name, + full_name=service_name, + index=service_index, + methods=methods, + options=service_proto.options, + file=file_desc) + return desc + + def _MakeMethodDescriptor(self, method_proto, service_name, package, scope, + index): + """Creates a method descriptor from a MethodDescriptorProto. + + Args: + method_proto: The proto describing the method. + service_name: The name of the containing service. + package: Optional package name to look up for types. + scope: Scope containing available types. + index: Index of the method in the service. + + Returns: + An initialized MethodDescriptor object. + """ + full_name = '.'.join((service_name, method_proto.name)) + input_type = self._GetTypeFromScope( + package, method_proto.input_type, scope) + output_type = self._GetTypeFromScope( + package, method_proto.output_type, scope) + return descriptor.MethodDescriptor(name=method_proto.name, + full_name=full_name, + index=index, + containing_service=None, + input_type=input_type, + output_type=output_type, + options=method_proto.options) + def _ExtractSymbols(self, descriptors): """Pulls out all the symbols from descriptor protos. diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 97cdd848..ce46d08c 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -594,7 +594,11 @@ class MessageMap(MutableMapping): def MergeFrom(self, other): for key in other: - self[key].MergeFrom(other[key]) + # According to documentation: "When parsing from the wire or when merging, + # if there are duplicate map keys the last key seen is used". + if key in self: + del self[key] + self[key].CopyFrom(other[key]) # self._message_listener.Modified() not required here, because # mutations to submessages already propagate. diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 6a13e0bc..3c8c7935 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -51,6 +51,7 @@ from google.protobuf.internal import descriptor_pool_test1_pb2 from google.protobuf.internal import descriptor_pool_test2_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 +from google.protobuf.internal import file_options_test_pb2 from google.protobuf.internal import more_messages_pb2 from google.protobuf import descriptor from google.protobuf import descriptor_database @@ -630,6 +631,23 @@ class AddDescriptorTest(unittest.TestCase): self.assertEqual(pool.FindMessageTypeByName('package.Message').name, 'Message') + def testFileDescriptorOptionsWithCustomDescriptorPool(self): + # Create a descriptor pool, and add a new FileDescriptorProto to it. + pool = descriptor_pool.DescriptorPool() + file_name = 'file_descriptor_options_with_custom_descriptor_pool.proto' + file_descriptor_proto = descriptor_pb2.FileDescriptorProto(name=file_name) + extension_id = file_options_test_pb2.foo_options + file_descriptor_proto.options.Extensions[extension_id].foo_name = 'foo' + pool.Add(file_descriptor_proto) + # The options set on the FileDescriptorProto should be available in the + # descriptor even if they contain extensions that cannot be deserialized + # using the pool. + file_descriptor = pool.FindFileByName(file_name) + options = file_descriptor.GetOptions() + self.assertEqual('foo', options.Extensions[extension_id].foo_name) + # The object returned by GetOptions() is cached. + self.assertIs(options, file_descriptor.GetOptions()) + @unittest.skipIf( api_implementation.Type() != 'cpp', diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index b8e75553..623198c8 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -77,27 +77,24 @@ class DescriptorTest(unittest.TestCase): enum_proto.value.add(name='FOREIGN_BAR', number=5) enum_proto.value.add(name='FOREIGN_BAZ', number=6) + file_proto.message_type.add(name='ResponseMessage') + service_proto = file_proto.service.add( + name='Service') + method_proto = service_proto.method.add( + name='CallMethod', + input_type='.protobuf_unittest.NestedMessage', + output_type='.protobuf_unittest.ResponseMessage') + + # Note: Calling DescriptorPool.Add() multiple times with the same file only + # works if the input is canonical; in particular, all type names must be + # fully qualified. self.pool = self.GetDescriptorPool() self.pool.Add(file_proto) self.my_file = self.pool.FindFileByName(file_proto.name) self.my_message = self.my_file.message_types_by_name[message_proto.name] self.my_enum = self.my_message.enum_types_by_name[enum_proto.name] - - self.my_method = descriptor.MethodDescriptor( - name='Bar', - full_name='protobuf_unittest.TestService.Bar', - index=0, - containing_service=None, - input_type=None, - output_type=None) - self.my_service = descriptor.ServiceDescriptor( - name='TestServiceWithOptions', - full_name='protobuf_unittest.TestServiceWithOptions', - file=self.my_file, - index=0, - methods=[ - self.my_method - ]) + self.my_service = self.my_file.services_by_name[service_proto.name] + self.my_method = self.my_service.methods_by_name[method_proto.name] def GetDescriptorPool(self): return symbol_database.Default().pool @@ -139,13 +136,14 @@ class DescriptorTest(unittest.TestCase): file_descriptor = unittest_custom_options_pb2.DESCRIPTOR message_descriptor =\ unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR - field_descriptor = message_descriptor.fields_by_name["field1"] - enum_descriptor = message_descriptor.enum_types_by_name["AnEnum"] + field_descriptor = message_descriptor.fields_by_name['field1'] + oneof_descriptor = message_descriptor.oneofs_by_name['AnOneof'] + enum_descriptor = message_descriptor.enum_types_by_name['AnEnum'] enum_value_descriptor =\ - message_descriptor.enum_values_by_name["ANENUM_VAL2"] + message_descriptor.enum_values_by_name['ANENUM_VAL2'] service_descriptor =\ unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR - method_descriptor = service_descriptor.FindMethodByName("Foo") + method_descriptor = service_descriptor.FindMethodByName('Foo') file_options = file_descriptor.GetOptions() file_opt1 = unittest_custom_options_pb2.file_opt1 @@ -158,6 +156,9 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(8765432109, field_options.Extensions[field_opt1]) field_opt2 = unittest_custom_options_pb2.field_opt2 self.assertEqual(42, field_options.Extensions[field_opt2]) + oneof_options = oneof_descriptor.GetOptions() + oneof_opt1 = unittest_custom_options_pb2.oneof_opt1 + self.assertEqual(-99, oneof_options.Extensions[oneof_opt1]) enum_options = enum_descriptor.GetOptions() enum_opt1 = unittest_custom_options_pb2.enum_opt1 self.assertEqual(-789, enum_options.Extensions[enum_opt1]) diff --git a/python/google/protobuf/internal/file_options_test.proto b/python/google/protobuf/internal/file_options_test.proto new file mode 100644 index 00000000..4eceeb07 --- /dev/null +++ b/python/google/protobuf/internal/file_options_test.proto @@ -0,0 +1,43 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +syntax = "proto2"; + +import "google/protobuf/descriptor.proto"; + +package google.protobuf.python.internal; + +message FooOptions { + optional string foo_name = 1; +} + +extend .google.protobuf.FileOptions { + optional FooOptions foo_options = 120436268; +} diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py index 9e32ea47..6df12bea 100644 --- a/python/google/protobuf/internal/json_format_test.py +++ b/python/google/protobuf/internal/json_format_test.py @@ -643,6 +643,19 @@ class JsonFormatTest(JsonFormatBase): 'Message type "proto3.TestMessage" has no field named ' '"unknownName".') + def testIgnoreUnknownField(self): + text = '{"unknownName": 1}' + parsed_message = json_format_proto3_pb2.TestMessage() + json_format.Parse(text, parsed_message, ignore_unknown_fields=True) + text = ('{\n' + ' "repeatedValue": [ {\n' + ' "@type": "type.googleapis.com/proto3.MessageType",\n' + ' "unknownName": 1\n' + ' }]\n' + '}\n') + parsed_message = json_format_proto3_pb2.TestAny() + json_format.Parse(text, parsed_message, ignore_unknown_fields=True) + def testDuplicateField(self): # Duplicate key check is not supported for python2.6 if sys.version_info < (2, 7): diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 4ee31d8e..1e95adf9 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -1435,6 +1435,8 @@ class Proto3Test(unittest.TestCase): msg2.map_int32_int32[12] = 55 msg2.map_int64_int64[88] = 99 msg2.map_int32_foreign_message[222].c = 15 + msg2.map_int32_foreign_message[222].d = 20 + old_map_value = msg2.map_int32_foreign_message[222] msg2.MergeFrom(msg) @@ -1444,6 +1446,8 @@ class Proto3Test(unittest.TestCase): self.assertEqual(99, msg2.map_int64_int64[88]) self.assertEqual(5, msg2.map_int32_foreign_message[111].c) self.assertEqual(10, msg2.map_int32_foreign_message[222].c) + self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d')) + self.assertEqual(15, old_map_value.c) # Verify that there is only one entry per key, even though the MergeFrom # may have internally created multiple entries for a single key in the diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index ab2bf05b..0e38e0e9 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -40,12 +40,13 @@ import six import string try: - import unittest2 as unittest #PY26 + import unittest2 as unittest # PY26, pylint: disable=g-import-not-at-top except ImportError: - import unittest + import unittest # pylint: disable=g-import-not-at-top from google.protobuf.internal import _parameterized +from google.protobuf import any_test_pb2 from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -53,6 +54,7 @@ from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util from google.protobuf.internal import message_set_extensions_pb2 +from google.protobuf import descriptor_pool from google.protobuf import text_format @@ -90,13 +92,11 @@ class TextFormatBase(unittest.TestCase): .replace('e-0','e-').replace('e-0','e-') # Floating point fields are printed with .0 suffix even if they are # actualy integer numbers. - text = re.compile('\.0$', re.MULTILINE).sub('', text) + text = re.compile(r'\.0$', re.MULTILINE).sub('', text) return text -@_parameterized.Parameters( - (unittest_pb2), - (unittest_proto3_arena_pb2)) +@_parameterized.Parameters((unittest_pb2), (unittest_proto3_arena_pb2)) class TextFormatTest(TextFormatBase): def testPrintExotic(self, message_module): @@ -120,8 +120,10 @@ class TextFormatTest(TextFormatBase): 'repeated_string: "\\303\\274\\352\\234\\237"\n') def testPrintExoticUnicodeSubclass(self, message_module): + class UnicodeSub(six.text_type): pass + message = message_module.TestAllTypes() message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) self.CompareToGoldenText( @@ -165,8 +167,8 @@ class TextFormatTest(TextFormatBase): message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') message.repeated_string.append(u'\u00fc\ua71f') self.CompareToGoldenText( - self.RemoveRedundantZeros( - text_format.MessageToString(message, as_one_line=True)), + self.RemoveRedundantZeros(text_format.MessageToString( + message, as_one_line=True)), 'repeated_int64: -9223372036854775808' ' repeated_uint64: 18446744073709551615' ' repeated_double: 123.456' @@ -187,21 +189,23 @@ class TextFormatTest(TextFormatBase): message.repeated_string.append(u'\u00fc\ua71f') # Test as_utf8 = False. - wire_text = text_format.MessageToString( - message, as_one_line=True, as_utf8=False) + wire_text = text_format.MessageToString(message, + as_one_line=True, + as_utf8=False) parsed_message = message_module.TestAllTypes() r = text_format.Parse(wire_text, parsed_message) self.assertIs(r, parsed_message) self.assertEqual(message, parsed_message) # Test as_utf8 = True. - wire_text = text_format.MessageToString( - message, as_one_line=True, as_utf8=True) + wire_text = text_format.MessageToString(message, + as_one_line=True, + as_utf8=True) parsed_message = message_module.TestAllTypes() r = text_format.Parse(wire_text, parsed_message) self.assertIs(r, parsed_message) self.assertEqual(message, parsed_message, - '\n%s != %s' % (message, parsed_message)) + '\n%s != %s' % (message, parsed_message)) def testPrintRawUtf8String(self, message_module): message = message_module.TestAllTypes() @@ -211,7 +215,7 @@ class TextFormatTest(TextFormatBase): parsed_message = message_module.TestAllTypes() text_format.Parse(text, parsed_message) self.assertEqual(message, parsed_message, - '\n%s != %s' % (message, parsed_message)) + '\n%s != %s' % (message, parsed_message)) def testPrintFloatFormat(self, message_module): # Check that float_format argument is passed to sub-message formatting. @@ -232,14 +236,15 @@ class TextFormatTest(TextFormatBase): message.payload.repeated_double.append(.000078900) formatted_fields = ['optional_float: 1.25', 'optional_double: -3.45678901234568e-6', - 'repeated_float: -5642', - 'repeated_double: 7.89e-5'] + 'repeated_float: -5642', 'repeated_double: 7.89e-5'] text_message = text_format.MessageToString(message, float_format='.15g') self.CompareToGoldenText( self.RemoveRedundantZeros(text_message), - 'payload {{\n {0}\n {1}\n {2}\n {3}\n}}\n'.format(*formatted_fields)) + 'payload {{\n {0}\n {1}\n {2}\n {3}\n}}\n'.format( + *formatted_fields)) # as_one_line=True is a separate code branch where float_format is passed. - text_message = text_format.MessageToString(message, as_one_line=True, + text_message = text_format.MessageToString(message, + as_one_line=True, float_format='.15g') self.CompareToGoldenText( self.RemoveRedundantZeros(text_message), @@ -311,8 +316,7 @@ class TextFormatTest(TextFormatBase): self.assertEqual(123.456, message.repeated_double[0]) self.assertEqual(1.23e22, message.repeated_double[1]) self.assertEqual(1.23e-18, message.repeated_double[2]) - self.assertEqual( - '\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0]) + self.assertEqual('\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0]) self.assertEqual('foocorgegrault', message.repeated_string[1]) self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2]) self.assertEqual(u'\u00fc', message.repeated_string[3]) @@ -371,45 +375,38 @@ class TextFormatTest(TextFormatBase): def testParseSingleWord(self, message_module): message = message_module.TestAllTypes() text = 'foo' - six.assertRaisesRegex(self, - text_format.ParseError, - (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' - r'"foo".'), - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, ( + r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"foo".'), text_format.Parse, text, message) def testParseUnknownField(self, message_module): message = message_module.TestAllTypes() text = 'unknown_field: 8\n' - six.assertRaisesRegex(self, - text_format.ParseError, - (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' - r'"unknown_field".'), - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, ( + r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"unknown_field".'), text_format.Parse, text, message) def testParseBadEnumValue(self, message_module): message = message_module.TestAllTypes() text = 'optional_nested_enum: BARR' - six.assertRaisesRegex(self, - text_format.ParseError, - (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' - r'has no value named BARR.'), - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, + (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + r'has no value named BARR.'), text_format.Parse, + text, message) message = message_module.TestAllTypes() text = 'optional_nested_enum: 100' - six.assertRaisesRegex(self, - text_format.ParseError, - (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' - r'has no value with number 100.'), - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, + (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + r'has no value with number 100.'), text_format.Parse, + text, message) def testParseBadIntValue(self, message_module): message = message_module.TestAllTypes() text = 'optional_int32: bork' - six.assertRaisesRegex(self, - text_format.ParseError, - ('1:17 : Couldn\'t parse integer: bork'), - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, + ('1:17 : Couldn\'t parse integer: bork'), + text_format.Parse, text, message) def testParseStringFieldUnescape(self, message_module): message = message_module.TestAllTypes() @@ -419,6 +416,7 @@ class TextFormatTest(TextFormatBase): repeated_string: "\\\\xf\\\\x62" repeated_string: "\\\\\xf\\\\\x62" repeated_string: "\x5cx20"''' + text_format.Parse(text, message) SLASH = '\\' @@ -433,8 +431,7 @@ class TextFormatTest(TextFormatBase): def testMergeDuplicateScalars(self, message_module): message = message_module.TestAllTypes() - text = ('optional_int32: 42 ' - 'optional_int32: 67') + text = ('optional_int32: 42 ' 'optional_int32: 67') r = text_format.Merge(text, message) self.assertIs(r, message) self.assertEqual(67, message.optional_int32) @@ -455,13 +452,11 @@ class TextFormatTest(TextFormatBase): self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) def testParseMultipleOneof(self, message_module): - m_string = '\n'.join([ - 'oneof_uint32: 11', - 'oneof_string: "foo"']) + m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) m2 = message_module.TestAllTypes() if message_module is unittest_pb2: - with self.assertRaisesRegexp( - text_format.ParseError, ' is specified along with field '): + with self.assertRaisesRegexp(text_format.ParseError, + ' is specified along with field '): text_format.Parse(m_string, m2) else: text_format.Parse(m_string, m2) @@ -477,8 +472,8 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.CompareToGoldenFile( - self.RemoveRedundantZeros( - text_format.MessageToString(message, pointy_brackets=True)), + self.RemoveRedundantZeros(text_format.MessageToString( + message, pointy_brackets=True)), 'text_format_unittest_data_pointy_oneof.txt') def testParseGolden(self): @@ -499,14 +494,6 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): self.RemoveRedundantZeros(text_format.MessageToString(message)), 'text_format_unittest_data_oneof_implemented.txt') - def testPrintAllFieldsPointy(self): - message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.CompareToGoldenFile( - self.RemoveRedundantZeros( - text_format.MessageToString(message, pointy_brackets=True)), - 'text_format_unittest_data_pointy_oneof.txt') - def testPrintInIndexOrder(self): message = unittest_pb2.TestFieldOrderings() message.my_string = '115' @@ -520,8 +507,7 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n' 'optional_nested_message {\n oo: 0\n bb: 1\n}\n') self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString( - message)), + self.RemoveRedundantZeros(text_format.MessageToString(message)), 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n' 'optional_nested_message {\n bb: 1\n oo: 0\n}\n') @@ -552,14 +538,13 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): message.map_int64_int64[-2**33] = -2**34 message.map_uint32_uint32[123] = 456 message.map_uint64_uint64[2**33] = 2**34 - message.map_string_string["abc"] = "123" + message.map_string_string['abc'] = '123' message.map_int32_foreign_message[111].c = 5 # Maps are serialized to text format using their underlying repeated # representation. self.CompareToGoldenText( - text_format.MessageToString(message), - 'map_int32_int32 {\n' + text_format.MessageToString(message), 'map_int32_int32 {\n' ' key: -123\n' ' value: -456\n' '}\n' @@ -592,9 +577,8 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): message.map_string_string[letter] = 'dummy' for letter in reversed(string.ascii_uppercase[0:13]): message.map_string_string[letter] = 'dummy' - golden = ''.join(( - 'map_string_string {\n key: "%c"\n value: "dummy"\n}\n' % (letter,) - for letter in string.ascii_uppercase)) + golden = ''.join(('map_string_string {\n key: "%c"\n value: "dummy"\n}\n' + % (letter,) for letter in string.ascii_uppercase)) self.CompareToGoldenText(text_format.MessageToString(message), golden) def testMapOrderSemantics(self): @@ -602,9 +586,7 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): # The C++ implementation emits defaulted-value fields, while the Python # implementation does not. Adjusting for this is awkward, but it is # valuable to test against a common golden file. - line_blacklist = (' key: 0\n', - ' value: 0\n', - ' key: false\n', + line_blacklist = (' key: 0\n', ' value: 0\n', ' key: false\n', ' value: false\n') golden_lines = [line for line in golden_lines if line not in line_blacklist] @@ -627,8 +609,7 @@ class Proto2Tests(TextFormatBase): message.message_set.Extensions[ext1].i = 23 message.message_set.Extensions[ext2].str = 'foo' self.CompareToGoldenText( - text_format.MessageToString(message), - 'message_set {\n' + text_format.MessageToString(message), 'message_set {\n' ' [protobuf_unittest.TestMessageSetExtension1] {\n' ' i: 23\n' ' }\n' @@ -654,16 +635,14 @@ class Proto2Tests(TextFormatBase): message.message_set.Extensions[ext1].i = 23 message.message_set.Extensions[ext2].str = 'foo' text_format.PrintMessage(message, out, use_field_number=True) - self.CompareToGoldenText( - out.getvalue(), - '1 {\n' - ' 1545008 {\n' - ' 15: 23\n' - ' }\n' - ' 1547769 {\n' - ' 25: \"foo\"\n' - ' }\n' - '}\n') + self.CompareToGoldenText(out.getvalue(), '1 {\n' + ' 1545008 {\n' + ' 15: 23\n' + ' }\n' + ' 1547769 {\n' + ' 25: \"foo\"\n' + ' }\n' + '}\n') out.close() def testPrintMessageSetAsOneLine(self): @@ -685,8 +664,7 @@ class Proto2Tests(TextFormatBase): def testParseMessageSet(self): message = unittest_pb2.TestAllTypes() - text = ('repeated_uint64: 1\n' - 'repeated_uint64: 2\n') + text = ('repeated_uint64: 1\n' 'repeated_uint64: 2\n') text_format.Parse(text, message) self.assertEqual(1, message.repeated_uint64[0]) self.assertEqual(2, message.repeated_uint64[1]) @@ -708,8 +686,7 @@ class Proto2Tests(TextFormatBase): def testParseMessageByFieldNumber(self): message = unittest_pb2.TestAllTypes() - text = ('34: 1\n' - 'repeated_uint64: 2\n') + text = ('34: 1\n' 'repeated_uint64: 2\n') text_format.Parse(text, message, allow_field_number=True) self.assertEqual(1, message.repeated_uint64[0]) self.assertEqual(2, message.repeated_uint64[1]) @@ -732,12 +709,9 @@ class Proto2Tests(TextFormatBase): # Can't parse field number without set allow_field_number=True. message = unittest_pb2.TestAllTypes() text = '34:1\n' - six.assertRaisesRegex( - self, - text_format.ParseError, - (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' - r'"34".'), - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, ( + r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"34".'), text_format.Parse, text, message) # Can't parse if field number is not found. text = '1234:1\n' @@ -746,7 +720,10 @@ class Proto2Tests(TextFormatBase): text_format.ParseError, (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' r'"1234".'), - text_format.Parse, text, message, allow_field_number=True) + text_format.Parse, + text, + message, + allow_field_number=True) def testPrintAllExtensions(self): message = unittest_pb2.TestAllExtensions() @@ -824,7 +801,9 @@ class Proto2Tests(TextFormatBase): six.assertRaisesRegex(self, text_format.ParseError, 'Invalid field value: }', - text_format.Parse, malformed, message, + text_format.Parse, + malformed, + message, allow_unknown_extension=True) message = unittest_mset_pb2.TestMessageSetContainer() @@ -836,7 +815,9 @@ class Proto2Tests(TextFormatBase): six.assertRaisesRegex(self, text_format.ParseError, 'Invalid field value: "', - text_format.Parse, malformed, message, + text_format.Parse, + malformed, + message, allow_unknown_extension=True) message = unittest_mset_pb2.TestMessageSetContainer() @@ -848,7 +829,9 @@ class Proto2Tests(TextFormatBase): six.assertRaisesRegex(self, text_format.ParseError, 'Invalid field value: "', - text_format.Parse, malformed, message, + text_format.Parse, + malformed, + message, allow_unknown_extension=True) message = unittest_mset_pb2.TestMessageSetContainer() @@ -860,7 +843,9 @@ class Proto2Tests(TextFormatBase): six.assertRaisesRegex(self, text_format.ParseError, '5:1 : Expected ">".', - text_format.Parse, malformed, message, + text_format.Parse, + malformed, + message, allow_unknown_extension=True) # Don't allow unknown fields with allow_unknown_extension=True. @@ -874,7 +859,9 @@ class Proto2Tests(TextFormatBase): ('2:3 : Message type ' '"proto2_wireformat_unittest.TestMessageSet" has no' ' field named "unknown_field".'), - text_format.Parse, malformed, message, + text_format.Parse, + malformed, + message, allow_unknown_extension=True) # Parse known extension correcty. @@ -896,67 +883,57 @@ class Proto2Tests(TextFormatBase): def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() text = '[unknown_extension]: 8\n' - six.assertRaisesRegex(self, - text_format.ParseError, - '1:2 : Extension "unknown_extension" not registered.', - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, + '1:2 : Extension "unknown_extension" not registered.', + text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() - six.assertRaisesRegex(self, - text_format.ParseError, - ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' - 'extensions.'), - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, ( + '1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' + 'extensions.'), text_format.Parse, text, message) def testMergeDuplicateExtensionScalars(self): message = unittest_pb2.TestAllExtensions() text = ('[protobuf_unittest.optional_int32_extension]: 42 ' '[protobuf_unittest.optional_int32_extension]: 67') text_format.Merge(text, message) - self.assertEqual( - 67, - message.Extensions[unittest_pb2.optional_int32_extension]) + self.assertEqual(67, + message.Extensions[unittest_pb2.optional_int32_extension]) def testParseDuplicateExtensionScalars(self): message = unittest_pb2.TestAllExtensions() text = ('[protobuf_unittest.optional_int32_extension]: 42 ' '[protobuf_unittest.optional_int32_extension]: 67') - six.assertRaisesRegex(self, - text_format.ParseError, - ('1:96 : Message type "protobuf_unittest.TestAllExtensions" ' - 'should not have multiple ' - '"protobuf_unittest.optional_int32_extension" extensions.'), - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, ( + '1:96 : Message type "protobuf_unittest.TestAllExtensions" ' + 'should not have multiple ' + '"protobuf_unittest.optional_int32_extension" extensions.'), + text_format.Parse, text, message) def testParseDuplicateNestedMessageScalars(self): message = unittest_pb2.TestAllTypes() text = ('optional_nested_message { bb: 1 } ' 'optional_nested_message { bb: 2 }') - six.assertRaisesRegex(self, - text_format.ParseError, - ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' - 'should not have multiple "bb" fields.'), - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, ( + '1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' + 'should not have multiple "bb" fields.'), text_format.Parse, text, + message) def testParseDuplicateScalars(self): message = unittest_pb2.TestAllTypes() - text = ('optional_int32: 42 ' - 'optional_int32: 67') - six.assertRaisesRegex(self, - text_format.ParseError, - ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' - 'have multiple "optional_int32" fields.'), - text_format.Parse, text, message) + text = ('optional_int32: 42 ' 'optional_int32: 67') + six.assertRaisesRegex(self, text_format.ParseError, ( + '1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' + 'have multiple "optional_int32" fields.'), text_format.Parse, text, + message) def testParseGroupNotClosed(self): message = unittest_pb2.TestAllTypes() text = 'RepeatedGroup: <' - six.assertRaisesRegex(self, - text_format.ParseError, '1:16 : Expected ">".', - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected ">".', + text_format.Parse, text, message) text = 'RepeatedGroup: {' - six.assertRaisesRegex(self, - text_format.ParseError, '1:16 : Expected "}".', - text_format.Parse, text, message) + six.assertRaisesRegex(self, text_format.ParseError, '1:16 : Expected "}".', + text_format.Parse, text, message) def testParseEmptyGroup(self): message = unittest_pb2.TestAllTypes() @@ -1007,10 +984,197 @@ class Proto2Tests(TextFormatBase): self.assertEqual(-2**34, message.map_int64_int64[-2**33]) self.assertEqual(456, message.map_uint32_uint32[123]) self.assertEqual(2**34, message.map_uint64_uint64[2**33]) - self.assertEqual("123", message.map_string_string["abc"]) + self.assertEqual('123', message.map_string_string['abc']) self.assertEqual(5, message.map_int32_foreign_message[111].c) +class Proto3Tests(unittest.TestCase): + + def testPrintMessageExpandAny(self): + packed_message = unittest_pb2.OneString() + packed_message.data = 'string' + message = any_test_pb2.TestAny() + message.any_value.Pack(packed_message) + self.assertEqual( + text_format.MessageToString(message, + descriptor_pool=descriptor_pool.Default()), + 'any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' data: "string"\n' + ' }\n' + '}\n') + + def testPrintMessageExpandAnyRepeated(self): + packed_message = unittest_pb2.OneString() + message = any_test_pb2.TestAny() + packed_message.data = 'string0' + message.repeated_any_value.add().Pack(packed_message) + packed_message.data = 'string1' + message.repeated_any_value.add().Pack(packed_message) + self.assertEqual( + text_format.MessageToString(message, + descriptor_pool=descriptor_pool.Default()), + 'repeated_any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' data: "string0"\n' + ' }\n' + '}\n' + 'repeated_any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' data: "string1"\n' + ' }\n' + '}\n') + + def testPrintMessageExpandAnyNoDescriptorPool(self): + packed_message = unittest_pb2.OneString() + packed_message.data = 'string' + message = any_test_pb2.TestAny() + message.any_value.Pack(packed_message) + self.assertEqual( + text_format.MessageToString(message, descriptor_pool=None), + 'any_value {\n' + ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n' + ' value: "\\n\\006string"\n' + '}\n') + + def testPrintMessageExpandAnyDescriptorPoolMissingType(self): + packed_message = unittest_pb2.OneString() + packed_message.data = 'string' + message = any_test_pb2.TestAny() + message.any_value.Pack(packed_message) + empty_pool = descriptor_pool.DescriptorPool() + self.assertEqual( + text_format.MessageToString(message, descriptor_pool=empty_pool), + 'any_value {\n' + ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n' + ' value: "\\n\\006string"\n' + '}\n') + + def testPrintMessageExpandAnyPointyBrackets(self): + packed_message = unittest_pb2.OneString() + packed_message.data = 'string' + message = any_test_pb2.TestAny() + message.any_value.Pack(packed_message) + self.assertEqual( + text_format.MessageToString(message, + pointy_brackets=True, + descriptor_pool=descriptor_pool.Default()), + 'any_value <\n' + ' [type.googleapis.com/protobuf_unittest.OneString] <\n' + ' data: "string"\n' + ' >\n' + '>\n') + + def testPrintMessageExpandAnyAsOneLine(self): + packed_message = unittest_pb2.OneString() + packed_message.data = 'string' + message = any_test_pb2.TestAny() + message.any_value.Pack(packed_message) + self.assertEqual( + text_format.MessageToString(message, + as_one_line=True, + descriptor_pool=descriptor_pool.Default()), + 'any_value {' + ' [type.googleapis.com/protobuf_unittest.OneString]' + ' { data: "string" } ' + '}') + + def testPrintMessageExpandAnyAsOneLinePointyBrackets(self): + packed_message = unittest_pb2.OneString() + packed_message.data = 'string' + message = any_test_pb2.TestAny() + message.any_value.Pack(packed_message) + self.assertEqual( + text_format.MessageToString(message, + as_one_line=True, + pointy_brackets=True, + descriptor_pool=descriptor_pool.Default()), + 'any_value <' + ' [type.googleapis.com/protobuf_unittest.OneString]' + ' < data: "string" > ' + '>') + + def testMergeExpandedAny(self): + message = any_test_pb2.TestAny() + text = ('any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' data: "string"\n' + ' }\n' + '}\n') + text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + packed_message = unittest_pb2.OneString() + message.any_value.Unpack(packed_message) + self.assertEqual('string', packed_message.data) + + def testMergeExpandedAnyRepeated(self): + message = any_test_pb2.TestAny() + text = ('repeated_any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' data: "string0"\n' + ' }\n' + '}\n' + 'repeated_any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' data: "string1"\n' + ' }\n' + '}\n') + text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + packed_message = unittest_pb2.OneString() + message.repeated_any_value[0].Unpack(packed_message) + self.assertEqual('string0', packed_message.data) + message.repeated_any_value[1].Unpack(packed_message) + self.assertEqual('string1', packed_message.data) + + def testMergeExpandedAnyPointyBrackets(self): + message = any_test_pb2.TestAny() + text = ('any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] <\n' + ' data: "string"\n' + ' >\n' + '}\n') + text_format.Merge(text, message, descriptor_pool=descriptor_pool.Default()) + packed_message = unittest_pb2.OneString() + message.any_value.Unpack(packed_message) + self.assertEqual('string', packed_message.data) + + def testMergeExpandedAnyNoDescriptorPool(self): + message = any_test_pb2.TestAny() + text = ('any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' data: "string"\n' + ' }\n' + '}\n') + with self.assertRaises(text_format.ParseError) as e: + text_format.Merge(text, message, descriptor_pool=None) + self.assertEqual(str(e.exception), + 'Descriptor pool required to parse expanded Any field') + + def testMergeExpandedAnyDescriptorPoolMissingType(self): + message = any_test_pb2.TestAny() + text = ('any_value {\n' + ' [type.googleapis.com/protobuf_unittest.OneString] {\n' + ' data: "string"\n' + ' }\n' + '}\n') + with self.assertRaises(text_format.ParseError) as e: + empty_pool = descriptor_pool.DescriptorPool() + text_format.Merge(text, message, descriptor_pool=empty_pool) + self.assertEqual( + str(e.exception), + 'Type protobuf_unittest.OneString not found in descriptor pool') + + def testMergeUnexpandedAny(self): + text = ('any_value {\n' + ' type_url: "type.googleapis.com/protobuf_unittest.OneString"\n' + ' value: "\\n\\006string"\n' + '}\n') + message = any_test_pb2.TestAny() + text_format.Merge(text, message) + packed_message = unittest_pb2.OneString() + message.any_value.Unpack(packed_message) + self.assertEqual('string', packed_message.data) + + class TokenizerTest(unittest.TestCase): def testSimpleTokenCases(self): @@ -1021,79 +1185,55 @@ class TokenizerTest(unittest.TestCase): 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f ' 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ') - tokenizer = text_format._Tokenizer(text.splitlines()) - methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), - ':', + tokenizer = text_format.Tokenizer(text.splitlines()) + methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':', (tokenizer.ConsumeString, 'string1'), - (tokenizer.ConsumeIdentifier, 'identifier2'), - ':', - (tokenizer.ConsumeInt32, 123), - (tokenizer.ConsumeIdentifier, 'identifier3'), - ':', + (tokenizer.ConsumeIdentifier, 'identifier2'), ':', + (tokenizer.ConsumeInteger, 123), + (tokenizer.ConsumeIdentifier, 'identifier3'), ':', (tokenizer.ConsumeString, 'string'), - (tokenizer.ConsumeIdentifier, 'identifiER_4'), - ':', + (tokenizer.ConsumeIdentifier, 'identifiER_4'), ':', (tokenizer.ConsumeFloat, 1.1e+2), - (tokenizer.ConsumeIdentifier, 'ID5'), - ':', + (tokenizer.ConsumeIdentifier, 'ID5'), ':', (tokenizer.ConsumeFloat, -0.23), - (tokenizer.ConsumeIdentifier, 'ID6'), - ':', + (tokenizer.ConsumeIdentifier, 'ID6'), ':', (tokenizer.ConsumeString, 'aaaa\'bbbb'), - (tokenizer.ConsumeIdentifier, 'ID7'), - ':', + (tokenizer.ConsumeIdentifier, 'ID7'), ':', (tokenizer.ConsumeString, 'aa\"bb'), - (tokenizer.ConsumeIdentifier, 'ID8'), - ':', - '{', - (tokenizer.ConsumeIdentifier, 'A'), - ':', + (tokenizer.ConsumeIdentifier, 'ID8'), ':', '{', + (tokenizer.ConsumeIdentifier, 'A'), ':', (tokenizer.ConsumeFloat, float('inf')), - (tokenizer.ConsumeIdentifier, 'B'), - ':', + (tokenizer.ConsumeIdentifier, 'B'), ':', (tokenizer.ConsumeFloat, -float('inf')), - (tokenizer.ConsumeIdentifier, 'C'), - ':', + (tokenizer.ConsumeIdentifier, 'C'), ':', (tokenizer.ConsumeBool, True), - (tokenizer.ConsumeIdentifier, 'D'), - ':', - (tokenizer.ConsumeBool, False), - '}', - (tokenizer.ConsumeIdentifier, 'ID9'), - ':', - (tokenizer.ConsumeUint32, 22), - (tokenizer.ConsumeIdentifier, 'ID10'), - ':', - (tokenizer.ConsumeInt64, -111111111111111111), - (tokenizer.ConsumeIdentifier, 'ID11'), - ':', - (tokenizer.ConsumeInt32, -22), - (tokenizer.ConsumeIdentifier, 'ID12'), - ':', - (tokenizer.ConsumeUint64, 2222222222222222222), - (tokenizer.ConsumeIdentifier, 'ID13'), - ':', + (tokenizer.ConsumeIdentifier, 'D'), ':', + (tokenizer.ConsumeBool, False), '}', + (tokenizer.ConsumeIdentifier, 'ID9'), ':', + (tokenizer.ConsumeInteger, 22), + (tokenizer.ConsumeIdentifier, 'ID10'), ':', + (tokenizer.ConsumeInteger, -111111111111111111), + (tokenizer.ConsumeIdentifier, 'ID11'), ':', + (tokenizer.ConsumeInteger, -22), + (tokenizer.ConsumeIdentifier, 'ID12'), ':', + (tokenizer.ConsumeInteger, 2222222222222222222), + (tokenizer.ConsumeIdentifier, 'ID13'), ':', (tokenizer.ConsumeFloat, 1.23456), - (tokenizer.ConsumeIdentifier, 'ID14'), - ':', + (tokenizer.ConsumeIdentifier, 'ID14'), ':', (tokenizer.ConsumeFloat, 1.2e+2), - (tokenizer.ConsumeIdentifier, 'false_bool'), - ':', + (tokenizer.ConsumeIdentifier, 'false_bool'), ':', (tokenizer.ConsumeBool, False), - (tokenizer.ConsumeIdentifier, 'true_BOOL'), - ':', + (tokenizer.ConsumeIdentifier, 'true_BOOL'), ':', (tokenizer.ConsumeBool, True), - (tokenizer.ConsumeIdentifier, 'true_bool1'), - ':', + (tokenizer.ConsumeIdentifier, 'true_bool1'), ':', (tokenizer.ConsumeBool, True), - (tokenizer.ConsumeIdentifier, 'false_BOOL1'), - ':', + (tokenizer.ConsumeIdentifier, 'false_BOOL1'), ':', (tokenizer.ConsumeBool, False)] i = 0 while not tokenizer.AtEnd(): m = methods[i] - if type(m) == str: + if isinstance(m, str): token = tokenizer.token self.assertEqual(token, m) tokenizer.NextToken() @@ -1101,59 +1241,119 @@ class TokenizerTest(unittest.TestCase): self.assertEqual(m[1], m[0]()) i += 1 - def testConsumeIntegers(self): + def testConsumeAbstractIntegers(self): # This test only tests the failures in the integer parsing methods as well # as the '0' special cases. int64_max = (1 << 63) - 1 uint32_max = (1 << 32) - 1 text = '-1 %d %d' % (uint32_max + 1, int64_max + 1) - tokenizer = text_format._Tokenizer(text.splitlines()) - self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) - self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64) - self.assertEqual(-1, tokenizer.ConsumeInt32()) + tokenizer = text_format.Tokenizer(text.splitlines()) + self.assertEqual(-1, tokenizer.ConsumeInteger()) - self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) - self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt32) - self.assertEqual(uint32_max + 1, tokenizer.ConsumeInt64()) + self.assertEqual(uint32_max + 1, tokenizer.ConsumeInteger()) - self.assertRaises(text_format.ParseError, tokenizer.ConsumeInt64) - self.assertEqual(int64_max + 1, tokenizer.ConsumeUint64()) + self.assertEqual(int64_max + 1, tokenizer.ConsumeInteger()) + self.assertTrue(tokenizer.AtEnd()) + + text = '-0 0' + tokenizer = text_format.Tokenizer(text.splitlines()) + self.assertEqual(0, tokenizer.ConsumeInteger()) + self.assertEqual(0, tokenizer.ConsumeInteger()) + self.assertTrue(tokenizer.AtEnd()) + + def testConsumeIntegers(self): + # This test only tests the failures in the integer parsing methods as well + # as the '0' special cases. + int64_max = (1 << 63) - 1 + uint32_max = (1 << 32) - 1 + text = '-1 %d %d' % (uint32_max + 1, int64_max + 1) + tokenizer = text_format.Tokenizer(text.splitlines()) + self.assertRaises(text_format.ParseError, + text_format._ConsumeUint32, tokenizer) + self.assertRaises(text_format.ParseError, + text_format._ConsumeUint64, tokenizer) + self.assertEqual(-1, text_format._ConsumeInt32(tokenizer)) + + self.assertRaises(text_format.ParseError, + text_format._ConsumeUint32, tokenizer) + self.assertRaises(text_format.ParseError, + text_format._ConsumeInt32, tokenizer) + self.assertEqual(uint32_max + 1, text_format._ConsumeInt64(tokenizer)) + + self.assertRaises(text_format.ParseError, + text_format._ConsumeInt64, tokenizer) + self.assertEqual(int64_max + 1, text_format._ConsumeUint64(tokenizer)) self.assertTrue(tokenizer.AtEnd()) text = '-0 -0 0 0' - tokenizer = text_format._Tokenizer(text.splitlines()) - self.assertEqual(0, tokenizer.ConsumeUint32()) - self.assertEqual(0, tokenizer.ConsumeUint64()) - self.assertEqual(0, tokenizer.ConsumeUint32()) - self.assertEqual(0, tokenizer.ConsumeUint64()) + tokenizer = text_format.Tokenizer(text.splitlines()) + self.assertEqual(0, text_format._ConsumeUint32(tokenizer)) + self.assertEqual(0, text_format._ConsumeUint64(tokenizer)) + self.assertEqual(0, text_format._ConsumeUint32(tokenizer)) + self.assertEqual(0, text_format._ConsumeUint64(tokenizer)) self.assertTrue(tokenizer.AtEnd()) def testConsumeByteString(self): text = '"string1\'' - tokenizer = text_format._Tokenizer(text.splitlines()) + tokenizer = text_format.Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = 'string1"' - tokenizer = text_format._Tokenizer(text.splitlines()) + tokenizer = text_format.Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\xt"' - tokenizer = text_format._Tokenizer(text.splitlines()) + tokenizer = text_format.Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\"' - tokenizer = text_format._Tokenizer(text.splitlines()) + tokenizer = text_format.Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\x"' - tokenizer = text_format._Tokenizer(text.splitlines()) + tokenizer = text_format.Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) def testConsumeBool(self): text = 'not-a-bool' - tokenizer = text_format._Tokenizer(text.splitlines()) + tokenizer = text_format.Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) + def testSkipComment(self): + tokenizer = text_format.Tokenizer('# some comment'.splitlines()) + self.assertTrue(tokenizer.AtEnd()) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment) + + def testConsumeComment(self): + tokenizer = text_format.Tokenizer('# some comment'.splitlines(), + skip_comments=False) + self.assertFalse(tokenizer.AtEnd()) + self.assertEqual('# some comment', tokenizer.ConsumeComment()) + self.assertTrue(tokenizer.AtEnd()) + + def testConsumeTwoComments(self): + text = '# some comment\n# another comment' + tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False) + self.assertEqual('# some comment', tokenizer.ConsumeComment()) + self.assertFalse(tokenizer.AtEnd()) + self.assertEqual('# another comment', tokenizer.ConsumeComment()) + self.assertTrue(tokenizer.AtEnd()) + + def testConsumeTrailingComment(self): + text = 'some_number: 4\n# some comment' + tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False) + self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment) + + self.assertEqual('some_number', tokenizer.ConsumeIdentifier()) + self.assertEqual(tokenizer.token, ':') + tokenizer.NextToken() + self.assertRaises(text_format.ParseError, tokenizer.ConsumeComment) + self.assertEqual(4, tokenizer.ConsumeInteger()) + self.assertFalse(tokenizer.AtEnd()) + + self.assertEqual('# some comment', tokenizer.ConsumeComment()) + self.assertTrue(tokenizer.AtEnd()) + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py index be6a9b63..bb6a1998 100644 --- a/python/google/protobuf/json_format.py +++ b/python/google/protobuf/json_format.py @@ -53,6 +53,7 @@ import re import six import sys +from operator import methodcaller from google.protobuf import descriptor from google.protobuf import symbol_database @@ -98,22 +99,8 @@ def MessageToJson(message, including_default_value_fields=False): Returns: A string containing the JSON formatted protocol buffer message. """ - js = _MessageToJsonObject(message, including_default_value_fields) - return json.dumps(js, indent=2) - - -def _MessageToJsonObject(message, including_default_value_fields): - """Converts message to an object according to Proto3 JSON Specification.""" - message_descriptor = message.DESCRIPTOR - full_name = message_descriptor.full_name - if _IsWrapperMessage(message_descriptor): - return _WrapperMessageToJsonObject(message) - if full_name in _WKTJSONMETHODS: - return _WKTJSONMETHODS[full_name][0]( - message, including_default_value_fields) - js = {} - return _RegularMessageToJsonObject( - message, js, including_default_value_fields) + printer = _Printer(including_default_value_fields) + return printer.ToJsonString(message) def _IsMapEntry(field): @@ -122,179 +109,179 @@ def _IsMapEntry(field): field.message_type.GetOptions().map_entry) -def _RegularMessageToJsonObject(message, js, including_default_value_fields): - """Converts normal message according to Proto3 JSON Specification.""" - fields = message.ListFields() - include_default = including_default_value_fields +class _Printer(object): + """JSON format printer for protocol message.""" - try: - for field, value in fields: - name = field.camelcase_name - if _IsMapEntry(field): - # Convert a map field. - v_field = field.message_type.fields_by_name['value'] - js_map = {} - for key in value: - if isinstance(key, bool): - if key: - recorded_key = 'true' - else: - recorded_key = 'false' - else: - recorded_key = key - js_map[recorded_key] = _FieldToJsonObject( - v_field, value[key], including_default_value_fields) - js[name] = js_map - elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - # Convert a repeated field. - js[name] = [_FieldToJsonObject(field, k, include_default) - for k in value] - else: - js[name] = _FieldToJsonObject(field, value, include_default) - - # Serialize default value if including_default_value_fields is True. - if including_default_value_fields: - message_descriptor = message.DESCRIPTOR - for field in message_descriptor.fields: - # Singular message fields and oneof fields will not be affected. - if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and - field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or - field.containing_oneof): - continue - name = field.camelcase_name - if name in js: - # Skip the field which has been serailized already. - continue - if _IsMapEntry(field): - js[name] = {} - elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - js[name] = [] - else: - js[name] = _FieldToJsonObject(field, field.default_value) + def __init__(self, + including_default_value_fields=False): + self.including_default_value_fields = including_default_value_fields - except ValueError as e: - raise SerializeToJsonError( - 'Failed to serialize {0} field: {1}.'.format(field.name, e)) + def ToJsonString(self, message): + js = self._MessageToJsonObject(message) + return json.dumps(js, indent=2) - return js + def _MessageToJsonObject(self, message): + """Converts message to an object according to Proto3 JSON Specification.""" + message_descriptor = message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + return self._WrapperMessageToJsonObject(message) + if full_name in _WKTJSONMETHODS: + return methodcaller(_WKTJSONMETHODS[full_name][0], message)(self) + js = {} + return self._RegularMessageToJsonObject(message, js) + def _RegularMessageToJsonObject(self, message, js): + """Converts normal message according to Proto3 JSON Specification.""" + fields = message.ListFields() -def _FieldToJsonObject( - field, value, including_default_value_fields=False): - """Converts field value according to Proto3 JSON Specification.""" - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - return _MessageToJsonObject(value, including_default_value_fields) - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: - enum_value = field.enum_type.values_by_number.get(value, None) - if enum_value is not None: - return enum_value.name - else: - raise SerializeToJsonError('Enum field contains an integer value ' - 'which can not mapped to an enum value.') - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: - if field.type == descriptor.FieldDescriptor.TYPE_BYTES: - # Use base64 Data encoding for bytes - return base64.b64encode(value).decode('utf-8') - else: - return value - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: - return bool(value) - elif field.cpp_type in _INT64_TYPES: - return str(value) - elif field.cpp_type in _FLOAT_TYPES: - if math.isinf(value): - if value < 0.0: - return _NEG_INFINITY - else: - return _INFINITY - if math.isnan(value): - return _NAN - return value + try: + for field, value in fields: + name = field.camelcase_name + if _IsMapEntry(field): + # Convert a map field. + v_field = field.message_type.fields_by_name['value'] + js_map = {} + for key in value: + if isinstance(key, bool): + if key: + recorded_key = 'true' + else: + recorded_key = 'false' + else: + recorded_key = key + js_map[recorded_key] = self._FieldToJsonObject( + v_field, value[key]) + js[name] = js_map + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + # Convert a repeated field. + js[name] = [self._FieldToJsonObject(field, k) + for k in value] + else: + js[name] = self._FieldToJsonObject(field, value) + + # Serialize default value if including_default_value_fields is True. + if self.including_default_value_fields: + message_descriptor = message.DESCRIPTOR + for field in message_descriptor.fields: + # Singular message fields and oneof fields will not be affected. + if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and + field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or + field.containing_oneof): + continue + name = field.camelcase_name + if name in js: + # Skip the field which has been serailized already. + continue + if _IsMapEntry(field): + js[name] = {} + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + js[name] = [] + else: + js[name] = self._FieldToJsonObject(field, field.default_value) + except ValueError as e: + raise SerializeToJsonError( + 'Failed to serialize {0} field: {1}.'.format(field.name, e)) -def _AnyMessageToJsonObject(message, including_default): - """Converts Any message according to Proto3 JSON Specification.""" - if not message.ListFields(): - return {} - # Must print @type first, use OrderedDict instead of {} - js = OrderedDict() - type_url = message.type_url - js['@type'] = type_url - sub_message = _CreateMessageFromTypeUrl(type_url) - sub_message.ParseFromString(message.value) - message_descriptor = sub_message.DESCRIPTOR - full_name = message_descriptor.full_name - if _IsWrapperMessage(message_descriptor): - js['value'] = _WrapperMessageToJsonObject(sub_message) return js - if full_name in _WKTJSONMETHODS: - js['value'] = _WKTJSONMETHODS[full_name][0](sub_message, including_default) - return js - return _RegularMessageToJsonObject(sub_message, js, including_default) - - -def _CreateMessageFromTypeUrl(type_url): - # TODO(jieluo): Should add a way that users can register the type resolver - # instead of the default one. - db = symbol_database.Default() - type_name = type_url.split('/')[-1] - try: - message_descriptor = db.pool.FindMessageTypeByName(type_name) - except KeyError: - raise TypeError( - 'Can not find message descriptor by type_url: {0}.'.format(type_url)) - message_class = db.GetPrototype(message_descriptor) - return message_class() - - -def _GenericMessageToJsonObject(message, unused_including_default): - """Converts message by ToJsonString according to Proto3 JSON Specification.""" - # Duration, Timestamp and FieldMask have ToJsonString method to do the - # convert. Users can also call the method directly. - return message.ToJsonString() - - -def _ValueMessageToJsonObject(message, unused_including_default=False): - """Converts Value message according to Proto3 JSON Specification.""" - which = message.WhichOneof('kind') - # If the Value message is not set treat as null_value when serialize - # to JSON. The parse back result will be different from original message. - if which is None or which == 'null_value': - return None - if which == 'list_value': - return _ListValueMessageToJsonObject(message.list_value) - if which == 'struct_value': - value = message.struct_value - else: - value = getattr(message, which) - oneof_descriptor = message.DESCRIPTOR.fields_by_name[which] - return _FieldToJsonObject(oneof_descriptor, value) + def _FieldToJsonObject(self, field, value): + """Converts field value according to Proto3 JSON Specification.""" + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + return self._MessageToJsonObject(value) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + return enum_value.name + else: + raise SerializeToJsonError('Enum field contains an integer value ' + 'which can not mapped to an enum value.') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # Use base64 Data encoding for bytes + return base64.b64encode(value).decode('utf-8') + else: + return value + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + return bool(value) + elif field.cpp_type in _INT64_TYPES: + return str(value) + elif field.cpp_type in _FLOAT_TYPES: + if math.isinf(value): + if value < 0.0: + return _NEG_INFINITY + else: + return _INFINITY + if math.isnan(value): + return _NAN + return value + + def _AnyMessageToJsonObject(self, message): + """Converts Any message according to Proto3 JSON Specification.""" + if not message.ListFields(): + return {} + # Must print @type first, use OrderedDict instead of {} + js = OrderedDict() + type_url = message.type_url + js['@type'] = type_url + sub_message = _CreateMessageFromTypeUrl(type_url) + sub_message.ParseFromString(message.value) + message_descriptor = sub_message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + js['value'] = self._WrapperMessageToJsonObject(sub_message) + return js + if full_name in _WKTJSONMETHODS: + js['value'] = methodcaller(_WKTJSONMETHODS[full_name][0], + sub_message)(self) + return js + return self._RegularMessageToJsonObject(sub_message, js) + + def _GenericMessageToJsonObject(self, message): + """Converts message according to Proto3 JSON Specification.""" + # Duration, Timestamp and FieldMask have ToJsonString method to do the + # convert. Users can also call the method directly. + return message.ToJsonString() + + def _ValueMessageToJsonObject(self, message): + """Converts Value message according to Proto3 JSON Specification.""" + which = message.WhichOneof('kind') + # If the Value message is not set treat as null_value when serialize + # to JSON. The parse back result will be different from original message. + if which is None or which == 'null_value': + return None + if which == 'list_value': + return self._ListValueMessageToJsonObject(message.list_value) + if which == 'struct_value': + value = message.struct_value + else: + value = getattr(message, which) + oneof_descriptor = message.DESCRIPTOR.fields_by_name[which] + return self._FieldToJsonObject(oneof_descriptor, value) -def _ListValueMessageToJsonObject(message, unused_including_default=False): - """Converts ListValue message according to Proto3 JSON Specification.""" - return [_ValueMessageToJsonObject(value) - for value in message.values] + def _ListValueMessageToJsonObject(self, message): + """Converts ListValue message according to Proto3 JSON Specification.""" + return [self._ValueMessageToJsonObject(value) + for value in message.values] + def _StructMessageToJsonObject(self, message): + """Converts Struct message according to Proto3 JSON Specification.""" + fields = message.fields + ret = {} + for key in fields: + ret[key] = self._ValueMessageToJsonObject(fields[key]) + return ret -def _StructMessageToJsonObject(message, unused_including_default=False): - """Converts Struct message according to Proto3 JSON Specification.""" - fields = message.fields - ret = {} - for key in fields: - ret[key] = _ValueMessageToJsonObject(fields[key]) - return ret + def _WrapperMessageToJsonObject(self, message): + return self._FieldToJsonObject( + message.DESCRIPTOR.fields_by_name['value'], message.value) def _IsWrapperMessage(message_descriptor): return message_descriptor.file.name == 'google/protobuf/wrappers.proto' -def _WrapperMessageToJsonObject(message): - return _FieldToJsonObject( - message.DESCRIPTOR.fields_by_name['value'], message.value) - - def _DuplicateChecker(js): result = {} for name, value in js: @@ -304,12 +291,27 @@ def _DuplicateChecker(js): return result -def Parse(text, message): +def _CreateMessageFromTypeUrl(type_url): + # TODO(jieluo): Should add a way that users can register the type resolver + # instead of the default one. + db = symbol_database.Default() + type_name = type_url.split('/')[-1] + try: + message_descriptor = db.pool.FindMessageTypeByName(type_name) + except KeyError: + raise TypeError( + 'Can not find message descriptor by type_url: {0}.'.format(type_url)) + message_class = db.GetPrototype(message_descriptor) + return message_class() + + +def Parse(text, message, ignore_unknown_fields=False): """Parses a JSON representation of a protocol message into a message. Args: text: Message JSON representation. message: A protocol beffer message to merge into. + ignore_unknown_fields: If True, do not raise errors for unknown fields. Returns: The same message passed as argument. @@ -326,213 +328,217 @@ def Parse(text, message): js = json.loads(text, object_pairs_hook=_DuplicateChecker) except ValueError as e: raise ParseError('Failed to load JSON: {0}.'.format(str(e))) - _ConvertMessage(js, message) + parser = _Parser(ignore_unknown_fields) + parser.ConvertMessage(js, message) return message -def _ConvertFieldValuePair(js, message): - """Convert field value pairs into regular message. +_INT_OR_FLOAT = six.integer_types + (float,) - Args: - js: A JSON object to convert the field value pairs. - message: A regular protocol message to record the data. - Raises: - ParseError: In case of problems converting. - """ - names = [] - message_descriptor = message.DESCRIPTOR - for name in js: - try: - field = message_descriptor.fields_by_camelcase_name.get(name, None) - if not field: - raise ParseError( - 'Message type "{0}" has no field named "{1}".'.format( - message_descriptor.full_name, name)) - if name in names: - raise ParseError( - 'Message type "{0}" should not have multiple "{1}" fields.'.format( - message.DESCRIPTOR.full_name, name)) - names.append(name) - # Check no other oneof field is parsed. - if field.containing_oneof is not None: - oneof_name = field.containing_oneof.name - if oneof_name in names: - raise ParseError('Message type "{0}" should not have multiple "{1}" ' - 'oneof fields.'.format( - message.DESCRIPTOR.full_name, oneof_name)) - names.append(oneof_name) - - value = js[name] - if value is None: - message.ClearField(field.name) - continue - - # Parse field value. - if _IsMapEntry(field): - message.ClearField(field.name) - _ConvertMapFieldValue(value, message, field) - elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - message.ClearField(field.name) - if not isinstance(value, list): - raise ParseError('repeated field {0} must be in [] which is ' - '{1}.'.format(name, value)) - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - # Repeated message field. - for item in value: - sub_message = getattr(message, field.name).add() - # None is a null_value in Value. - if (item is None and - sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'): - raise ParseError('null is not allowed to be used as an element' - ' in a repeated field.') - _ConvertMessage(item, sub_message) - else: - # Repeated scalar field. - for item in value: - if item is None: - raise ParseError('null is not allowed to be used as an element' - ' in a repeated field.') - getattr(message, field.name).append( - _ConvertScalarFieldValue(item, field)) - elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - sub_message = getattr(message, field.name) - _ConvertMessage(value, sub_message) - else: - setattr(message, field.name, _ConvertScalarFieldValue(value, field)) - except ParseError as e: - if field and field.containing_oneof is None: - raise ParseError('Failed to parse {0} field: {1}'.format(name, e)) - else: - raise ParseError(str(e)) - except ValueError as e: - raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) - except TypeError as e: - raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) +class _Parser(object): + """JSON format parser for protocol message.""" + def __init__(self, + ignore_unknown_fields): + self.ignore_unknown_fields = ignore_unknown_fields -def _ConvertMessage(value, message): - """Convert a JSON object into a message. + def ConvertMessage(self, value, message): + """Convert a JSON object into a message. - Args: - value: A JSON object. - message: A WKT or regular protocol message to record the data. + Args: + value: A JSON object. + message: A WKT or regular protocol message to record the data. - Raises: - ParseError: In case of convert problems. - """ - message_descriptor = message.DESCRIPTOR - full_name = message_descriptor.full_name - if _IsWrapperMessage(message_descriptor): - _ConvertWrapperMessage(value, message) - elif full_name in _WKTJSONMETHODS: - _WKTJSONMETHODS[full_name][1](value, message) - else: - _ConvertFieldValuePair(value, message) - - -def _ConvertAnyMessage(value, message): - """Convert a JSON representation into Any message.""" - if isinstance(value, dict) and not value: - return - try: - type_url = value['@type'] - except KeyError: - raise ParseError('@type is missing when parsing any message.') - - sub_message = _CreateMessageFromTypeUrl(type_url) - message_descriptor = sub_message.DESCRIPTOR - full_name = message_descriptor.full_name - if _IsWrapperMessage(message_descriptor): - _ConvertWrapperMessage(value['value'], sub_message) - elif full_name in _WKTJSONMETHODS: - _WKTJSONMETHODS[full_name][1](value['value'], sub_message) - else: - del value['@type'] - _ConvertFieldValuePair(value, sub_message) - # Sets Any message - message.value = sub_message.SerializeToString() - message.type_url = type_url - - -def _ConvertGenericMessage(value, message): - """Convert a JSON representation into message with FromJsonString.""" - # Durantion, Timestamp, FieldMask have FromJsonString method to do the - # convert. Users can also call the method directly. - message.FromJsonString(value) + Raises: + ParseError: In case of convert problems. + """ + message_descriptor = message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + self._ConvertWrapperMessage(value, message) + elif full_name in _WKTJSONMETHODS: + methodcaller(_WKTJSONMETHODS[full_name][1], value, message)(self) + else: + self._ConvertFieldValuePair(value, message) + + def _ConvertFieldValuePair(self, js, message): + """Convert field value pairs into regular message. + + Args: + js: A JSON object to convert the field value pairs. + message: A regular protocol message to record the data. + + Raises: + ParseError: In case of problems converting. + """ + names = [] + message_descriptor = message.DESCRIPTOR + for name in js: + try: + field = message_descriptor.fields_by_camelcase_name.get(name, None) + if not field: + if self.ignore_unknown_fields: + continue + raise ParseError( + 'Message type "{0}" has no field named "{1}".'.format( + message_descriptor.full_name, name)) + if name in names: + raise ParseError('Message type "{0}" should not have multiple ' + '"{1}" fields.'.format( + message.DESCRIPTOR.full_name, name)) + names.append(name) + # Check no other oneof field is parsed. + if field.containing_oneof is not None: + oneof_name = field.containing_oneof.name + if oneof_name in names: + raise ParseError('Message type "{0}" should not have multiple ' + '"{1}" oneof fields.'.format( + message.DESCRIPTOR.full_name, oneof_name)) + names.append(oneof_name) + + value = js[name] + if value is None: + message.ClearField(field.name) + continue + # Parse field value. + if _IsMapEntry(field): + message.ClearField(field.name) + self._ConvertMapFieldValue(value, message, field) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + message.ClearField(field.name) + if not isinstance(value, list): + raise ParseError('repeated field {0} must be in [] which is ' + '{1}.'.format(name, value)) + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # Repeated message field. + for item in value: + sub_message = getattr(message, field.name).add() + # None is a null_value in Value. + if (item is None and + sub_message.DESCRIPTOR.full_name != 'google.protobuf.Value'): + raise ParseError('null is not allowed to be used as an element' + ' in a repeated field.') + self.ConvertMessage(item, sub_message) + else: + # Repeated scalar field. + for item in value: + if item is None: + raise ParseError('null is not allowed to be used as an element' + ' in a repeated field.') + getattr(message, field.name).append( + _ConvertScalarFieldValue(item, field)) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + sub_message = getattr(message, field.name) + self.ConvertMessage(value, sub_message) + else: + setattr(message, field.name, _ConvertScalarFieldValue(value, field)) + except ParseError as e: + if field and field.containing_oneof is None: + raise ParseError('Failed to parse {0} field: {1}'.format(name, e)) + else: + raise ParseError(str(e)) + except ValueError as e: + raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) + except TypeError as e: + raise ParseError('Failed to parse {0} field: {1}.'.format(name, e)) + + def _ConvertAnyMessage(self, value, message): + """Convert a JSON representation into Any message.""" + if isinstance(value, dict) and not value: + return + try: + type_url = value['@type'] + except KeyError: + raise ParseError('@type is missing when parsing any message.') + + sub_message = _CreateMessageFromTypeUrl(type_url) + message_descriptor = sub_message.DESCRIPTOR + full_name = message_descriptor.full_name + if _IsWrapperMessage(message_descriptor): + self._ConvertWrapperMessage(value['value'], sub_message) + elif full_name in _WKTJSONMETHODS: + methodcaller( + _WKTJSONMETHODS[full_name][1], value['value'], sub_message)(self) + else: + del value['@type'] + self._ConvertFieldValuePair(value, sub_message) + # Sets Any message + message.value = sub_message.SerializeToString() + message.type_url = type_url + + def _ConvertGenericMessage(self, value, message): + """Convert a JSON representation into message with FromJsonString.""" + # Durantion, Timestamp, FieldMask have FromJsonString method to do the + # convert. Users can also call the method directly. + message.FromJsonString(value) + + def _ConvertValueMessage(self, value, message): + """Convert a JSON representation into Value message.""" + if isinstance(value, dict): + self._ConvertStructMessage(value, message.struct_value) + elif isinstance(value, list): + self. _ConvertListValueMessage(value, message.list_value) + elif value is None: + message.null_value = 0 + elif isinstance(value, bool): + message.bool_value = value + elif isinstance(value, six.string_types): + message.string_value = value + elif isinstance(value, _INT_OR_FLOAT): + message.number_value = value + else: + raise ParseError('Unexpected type for Value message.') -_INT_OR_FLOAT = six.integer_types + (float,) + def _ConvertListValueMessage(self, value, message): + """Convert a JSON representation into ListValue message.""" + if not isinstance(value, list): + raise ParseError( + 'ListValue must be in [] which is {0}.'.format(value)) + message.ClearField('values') + for item in value: + self._ConvertValueMessage(item, message.values.add()) + + def _ConvertStructMessage(self, value, message): + """Convert a JSON representation into Struct message.""" + if not isinstance(value, dict): + raise ParseError( + 'Struct must be in a dict which is {0}.'.format(value)) + for key in value: + self._ConvertValueMessage(value[key], message.fields[key]) + return + def _ConvertWrapperMessage(self, value, message): + """Convert a JSON representation into Wrapper message.""" + field = message.DESCRIPTOR.fields_by_name['value'] + setattr(message, 'value', _ConvertScalarFieldValue(value, field)) -def _ConvertValueMessage(value, message): - """Convert a JSON representation into Value message.""" - if isinstance(value, dict): - _ConvertStructMessage(value, message.struct_value) - elif isinstance(value, list): - _ConvertListValueMessage(value, message.list_value) - elif value is None: - message.null_value = 0 - elif isinstance(value, bool): - message.bool_value = value - elif isinstance(value, six.string_types): - message.string_value = value - elif isinstance(value, _INT_OR_FLOAT): - message.number_value = value - else: - raise ParseError('Unexpected type for Value message.') - - -def _ConvertListValueMessage(value, message): - """Convert a JSON representation into ListValue message.""" - if not isinstance(value, list): - raise ParseError( - 'ListValue must be in [] which is {0}.'.format(value)) - message.ClearField('values') - for item in value: - _ConvertValueMessage(item, message.values.add()) - - -def _ConvertStructMessage(value, message): - """Convert a JSON representation into Struct message.""" - if not isinstance(value, dict): - raise ParseError( - 'Struct must be in a dict which is {0}.'.format(value)) - for key in value: - _ConvertValueMessage(value[key], message.fields[key]) - return - - -def _ConvertWrapperMessage(value, message): - """Convert a JSON representation into Wrapper message.""" - field = message.DESCRIPTOR.fields_by_name['value'] - setattr(message, 'value', _ConvertScalarFieldValue(value, field)) - - -def _ConvertMapFieldValue(value, message, field): - """Convert map field value for a message map field. + def _ConvertMapFieldValue(self, value, message, field): + """Convert map field value for a message map field. - Args: - value: A JSON object to convert the map field value. - message: A protocol message to record the converted data. - field: The descriptor of the map field to be converted. + Args: + value: A JSON object to convert the map field value. + message: A protocol message to record the converted data. + field: The descriptor of the map field to be converted. - Raises: - ParseError: In case of convert problems. - """ - if not isinstance(value, dict): - raise ParseError( - 'Map field {0} must be in a dict which is {1}.'.format( - field.name, value)) - key_field = field.message_type.fields_by_name['key'] - value_field = field.message_type.fields_by_name['value'] - for key in value: - key_value = _ConvertScalarFieldValue(key, key_field, True) - if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - _ConvertMessage(value[key], getattr(message, field.name)[key_value]) - else: - getattr(message, field.name)[key_value] = _ConvertScalarFieldValue( - value[key], value_field) + Raises: + ParseError: In case of convert problems. + """ + if not isinstance(value, dict): + raise ParseError( + 'Map field {0} must be in a dict which is {1}.'.format( + field.name, value)) + key_field = field.message_type.fields_by_name['key'] + value_field = field.message_type.fields_by_name['value'] + for key in value: + key_value = _ConvertScalarFieldValue(key, key_field, True) + if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + self.ConvertMessage(value[key], getattr( + message, field.name)[key_value]) + else: + getattr(message, field.name)[key_value] = _ConvertScalarFieldValue( + value[key], value_field) def _ConvertScalarFieldValue(value, field, require_str=False): @@ -641,18 +647,18 @@ def _ConvertBool(value, require_str): return value _WKTJSONMETHODS = { - 'google.protobuf.Any': [_AnyMessageToJsonObject, - _ConvertAnyMessage], - 'google.protobuf.Duration': [_GenericMessageToJsonObject, - _ConvertGenericMessage], - 'google.protobuf.FieldMask': [_GenericMessageToJsonObject, - _ConvertGenericMessage], - 'google.protobuf.ListValue': [_ListValueMessageToJsonObject, - _ConvertListValueMessage], - 'google.protobuf.Struct': [_StructMessageToJsonObject, - _ConvertStructMessage], - 'google.protobuf.Timestamp': [_GenericMessageToJsonObject, - _ConvertGenericMessage], - 'google.protobuf.Value': [_ValueMessageToJsonObject, - _ConvertValueMessage] + 'google.protobuf.Any': ['_AnyMessageToJsonObject', + '_ConvertAnyMessage'], + 'google.protobuf.Duration': ['_GenericMessageToJsonObject', + '_ConvertGenericMessage'], + 'google.protobuf.FieldMask': ['_GenericMessageToJsonObject', + '_ConvertGenericMessage'], + 'google.protobuf.ListValue': ['_ListValueMessageToJsonObject', + '_ConvertListValueMessage'], + 'google.protobuf.Struct': ['_StructMessageToJsonObject', + '_ConvertStructMessage'], + 'google.protobuf.Timestamp': ['_GenericMessageToJsonObject', + '_ConvertGenericMessage'], + 'google.protobuf.Value': ['_ValueMessageToJsonObject', + '_ConvertValueMessage'] } diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 23557538..e6ef5ef5 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -172,12 +172,16 @@ template<> const FileDescriptor* GetFileDescriptor(const OneofDescriptor* descriptor) { return descriptor->containing_type()->file(); } +template<> +const FileDescriptor* GetFileDescriptor(const MethodDescriptor* descriptor) { + return descriptor->service()->file(); +} // Converts options into a Python protobuf, and cache the result. // // This is a bit tricky because options can contain extension fields defined in // the same proto file. In this case the options parsed from the serialized_pb -// have unkown fields, and we need to parse them again. +// have unknown fields, and we need to parse them again. // // Always returns a new reference. template @@ -204,11 +208,12 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { cdescriptor_pool::GetMessageClass(pool, message_type)); if (message_class == NULL) { // The Options message was not found in the current DescriptorPool. - // In this case, there cannot be extensions to these options, and we can - // try to use the basic pool instead. + // This means that the pool cannot contain any extensions to the Options + // message either, so falling back to the basic pool we can only increase + // the chances of successfully parsing the options. PyErr_Clear(); - message_class = cdescriptor_pool::GetMessageClass( - GetDefaultDescriptorPool(), message_type); + pool = GetDefaultDescriptorPool(); + message_class = cdescriptor_pool::GetMessageClass(pool, message_type); } if (message_class == NULL) { PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s", @@ -248,7 +253,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { // Cache the result. Py_INCREF(value.get()); - (*pool->descriptor_options)[descriptor] = value.get(); + (*descriptor_options)[descriptor] = value.get(); return value.release(); } @@ -1091,7 +1096,7 @@ PyTypeObject PyEnumDescriptor_Type = { 0, // tp_weaklistoffset 0, // tp_iter 0, // tp_iternext - enum_descriptor::Methods, // tp_getset + enum_descriptor::Methods, // tp_methods 0, // tp_members enum_descriptor::Getters, // tp_getset &descriptor::PyBaseDescriptor_Type, // tp_base @@ -1275,6 +1280,10 @@ static PyObject* GetExtensionsByName(PyFileDescriptor* self, void *closure) { return NewFileExtensionsByName(_GetDescriptor(self)); } +static PyObject* GetServicesByName(PyFileDescriptor* self, void *closure) { + return NewFileServicesByName(_GetDescriptor(self)); +} + static PyObject* GetDependencies(PyFileDescriptor* self, void *closure) { return NewFileDependencies(_GetDescriptor(self)); } @@ -1324,6 +1333,7 @@ static PyGetSetDef Getters[] = { { "enum_types_by_name", (getter)GetEnumTypesByName, NULL, "Enums by name"}, { "extensions_by_name", (getter)GetExtensionsByName, NULL, "Extensions by name"}, + { "services_by_name", (getter)GetServicesByName, NULL, "Services by name"}, { "dependencies", (getter)GetDependencies, NULL, "Dependencies"}, { "public_dependencies", (getter)GetPublicDependencies, NULL, "Dependencies"}, @@ -1452,16 +1462,45 @@ static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { } } +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const OneofOptions& options(_GetDescriptor(self)->options()); + if (&options != &OneofOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + static PyGetSetDef Getters[] = { { "name", (getter)GetName, NULL, "Name"}, { "full_name", (getter)GetFullName, NULL, "Full name"}, { "index", (getter)GetIndex, NULL, "Index"}, { "containing_type", (getter)GetContainingType, NULL, "Containing type"}, + { "has_options", (getter)GetHasOptions, (setter)SetHasOptions, "Has Options"}, + { "_options", (getter)NULL, (setter)SetOptions, "Options"}, { "fields", (getter)GetFields, NULL, "Fields"}, {NULL} }; +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS }, + {NULL} +}; + } // namespace oneof_descriptor PyTypeObject PyOneofDescriptor_Type = { @@ -1492,7 +1531,7 @@ PyTypeObject PyOneofDescriptor_Type = { 0, // tp_weaklistoffset 0, // tp_iter 0, // tp_iternext - 0, // tp_methods + oneof_descriptor::Methods, // tp_methods 0, // tp_members oneof_descriptor::Getters, // tp_getset &descriptor::PyBaseDescriptor_Type, // tp_base @@ -1504,6 +1543,222 @@ PyObject* PyOneofDescriptor_FromDescriptor( &PyOneofDescriptor_Type, oneof_descriptor, NULL); } +namespace service_descriptor { + +// Unchecked accessor to the C++ pointer. +static const ServiceDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetMethods(PyBaseDescriptor* self, void *closure) { + return NewServiceMethodsSeq(_GetDescriptor(self)); +} + +static PyObject* GetMethodsByName(PyBaseDescriptor* self, void *closure) { + return NewServiceMethodsByName(_GetDescriptor(self)); +} + +static PyObject* FindMethodByName(PyBaseDescriptor *self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const MethodDescriptor* method_descriptor = + _GetDescriptor(self)->FindMethodByName(string(name, name_size)); + if (method_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find method %.200s", name); + return NULL; + } + + return PyMethodDescriptor_FromDescriptor(method_descriptor); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto(_GetDescriptor(self), + target); +} + +static PyGetSetDef Getters[] = { + { "name", (getter)GetName, NULL, "Name", NULL}, + { "full_name", (getter)GetFullName, NULL, "Full name", NULL}, + { "index", (getter)GetIndex, NULL, "Index", NULL}, + + { "methods", (getter)GetMethods, NULL, "Methods", NULL}, + { "methods_by_name", (getter)GetMethodsByName, NULL, "Methods by name", NULL}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + { "FindMethodByName", (PyCFunction)FindMethodByName, METH_O }, + {NULL} +}; + +} // namespace service_descriptor + +PyTypeObject PyServiceDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".ServiceDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Service Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + service_descriptor::Methods, // tp_methods + 0, // tp_members + service_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyServiceDescriptor_FromDescriptor( + const ServiceDescriptor* service_descriptor) { + return descriptor::NewInternedDescriptor( + &PyServiceDescriptor_Type, service_descriptor, NULL); +} + +namespace method_descriptor { + +// Unchecked accessor to the C++ pointer. +static const MethodDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetContainingService(PyBaseDescriptor *self, void *closure) { + const ServiceDescriptor* containing_service = + _GetDescriptor(self)->service(); + return PyServiceDescriptor_FromDescriptor(containing_service); +} + +static PyObject* GetInputType(PyBaseDescriptor *self, void *closure) { + const Descriptor* input_type = _GetDescriptor(self)->input_type(); + return PyMessageDescriptor_FromDescriptor(input_type); +} + +static PyObject* GetOutputType(PyBaseDescriptor *self, void *closure) { + const Descriptor* output_type = _GetDescriptor(self)->output_type(); + return PyMessageDescriptor_FromDescriptor(output_type); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto(_GetDescriptor(self), target); +} + +static PyGetSetDef Getters[] = { + { "name", (getter)GetName, NULL, "Name", NULL}, + { "full_name", (getter)GetFullName, NULL, "Full name", NULL}, + { "index", (getter)GetIndex, NULL, "Index", NULL}, + { "containing_service", (getter)GetContainingService, NULL, + "Containing service", NULL}, + { "input_type", (getter)GetInputType, NULL, "Input type", NULL}, + { "output_type", (getter)GetOutputType, NULL, "Output type", NULL}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + {NULL} +}; + +} // namespace method_descriptor + +PyTypeObject PyMethodDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + FULL_MODULE_NAME ".MethodDescriptor", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "A Method Descriptor", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + method_descriptor::Methods, // tp_methods + 0, // tp_members + method_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyMethodDescriptor_FromDescriptor( + const MethodDescriptor* method_descriptor) { + return descriptor::NewInternedDescriptor( + &PyMethodDescriptor_Type, method_descriptor, NULL); +} + // Add a enum values to a type dictionary. static bool AddEnumValues(PyTypeObject *type, const EnumDescriptor* enum_descriptor) { @@ -1573,6 +1828,12 @@ bool InitDescriptor() { if (PyType_Ready(&PyOneofDescriptor_Type) < 0) return false; + if (PyType_Ready(&PyServiceDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyMethodDescriptor_Type) < 0) + return false; + if (!InitDescriptorMappingTypes()) return false; diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h index eb99df18..1ae0e672 100644 --- a/python/google/protobuf/pyext/descriptor.h +++ b/python/google/protobuf/pyext/descriptor.h @@ -47,6 +47,8 @@ extern PyTypeObject PyEnumDescriptor_Type; extern PyTypeObject PyEnumValueDescriptor_Type; extern PyTypeObject PyFileDescriptor_Type; extern PyTypeObject PyOneofDescriptor_Type; +extern PyTypeObject PyServiceDescriptor_Type; +extern PyTypeObject PyMethodDescriptor_Type; // Wraps a Descriptor in a Python object. // The C++ pointer is usually borrowed from the global DescriptorPool. @@ -60,6 +62,10 @@ PyObject* PyEnumValueDescriptor_FromDescriptor( PyObject* PyOneofDescriptor_FromDescriptor(const OneofDescriptor* descriptor); PyObject* PyFileDescriptor_FromDescriptor( const FileDescriptor* file_descriptor); +PyObject* PyServiceDescriptor_FromDescriptor( + const ServiceDescriptor* descriptor); +PyObject* PyMethodDescriptor_FromDescriptor( + const MethodDescriptor* descriptor); // Alternate constructor of PyFileDescriptor, used when we already have a // serialized FileDescriptorProto that can be cached. diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc index e505d812..d0aae9c9 100644 --- a/python/google/protobuf/pyext/descriptor_containers.cc +++ b/python/google/protobuf/pyext/descriptor_containers.cc @@ -608,6 +608,24 @@ static PyObject* GetItem(PyContainer* self, Py_ssize_t index) { return _NewObj_ByIndex(self, index); } +static PyObject * +SeqSubscript(PyContainer* self, PyObject* item) { + if (PyIndex_Check(item)) { + Py_ssize_t index; + index = PyNumber_AsSsize_t(item, PyExc_IndexError); + if (index == -1 && PyErr_Occurred()) + return NULL; + return GetItem(self, index); + } + // Materialize the list and delegate the operation to it. + ScopedPyObjectPtr list(PyObject_CallFunctionObjArgs( + reinterpret_cast(&PyList_Type), self, NULL)); + if (list == NULL) { + return NULL; + } + return Py_TYPE(list.get())->tp_as_mapping->mp_subscript(list.get(), item); +} + // Returns the position of the item in the sequence, of -1 if not found. // This function never fails. int Find(PyContainer* self, PyObject* item) { @@ -703,14 +721,20 @@ static PyMethodDef SeqMethods[] = { }; static PySequenceMethods SeqSequenceMethods = { - (lenfunc)Length, // sq_length - 0, // sq_concat - 0, // sq_repeat - (ssizeargfunc)GetItem, // sq_item - 0, // sq_slice - 0, // sq_ass_item - 0, // sq_ass_slice - (objobjproc)SeqContains, // sq_contains + (lenfunc)Length, // sq_length + 0, // sq_concat + 0, // sq_repeat + (ssizeargfunc)GetItem, // sq_item + 0, // sq_slice + 0, // sq_ass_item + 0, // sq_ass_slice + (objobjproc)SeqContains, // sq_contains +}; + +static PyMappingMethods SeqMappingMethods = { + (lenfunc)Length, // mp_length + (binaryfunc)SeqSubscript, // mp_subscript + 0, // mp_ass_subscript }; PyTypeObject DescriptorSequence_Type = { @@ -726,7 +750,7 @@ PyTypeObject DescriptorSequence_Type = { (reprfunc)ContainerRepr, // tp_repr 0, // tp_as_number &SeqSequenceMethods, // tp_as_sequence - 0, // tp_as_mapping + &SeqMappingMethods, // tp_as_mapping 0, // tp_hash 0, // tp_call 0, // tp_str @@ -1407,6 +1431,68 @@ PyObject* NewOneofFieldsSeq(ParentDescriptor descriptor) { } // namespace oneof_descriptor +namespace service_descriptor { + +typedef const ServiceDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast(self->descriptor); +} + +namespace methods { + +typedef const MethodDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->method_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindMethodByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->method(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyMethodDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "ServiceMethods", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace methods + +PyObject* NewServiceMethodsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&methods::ContainerDef, descriptor); +} + +PyObject* NewServiceMethodsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&methods::ContainerDef, descriptor); +} + +} // namespace service_descriptor + namespace file_descriptor { typedef const FileDescriptor* ParentDescriptor; @@ -1459,7 +1545,7 @@ static DescriptorContainerDef ContainerDef = { } // namespace messages -PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor) { +PyObject* NewFileMessageTypesByName(ParentDescriptor descriptor) { return descriptor::NewMappingByName(&messages::ContainerDef, descriptor); } @@ -1507,7 +1593,7 @@ static DescriptorContainerDef ContainerDef = { } // namespace enums -PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor) { +PyObject* NewFileEnumTypesByName(ParentDescriptor descriptor) { return descriptor::NewMappingByName(&enums::ContainerDef, descriptor); } @@ -1555,10 +1641,58 @@ static DescriptorContainerDef ContainerDef = { } // namespace extensions -PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor) { +PyObject* NewFileExtensionsByName(ParentDescriptor descriptor) { return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor); } +namespace services { + +typedef const ServiceDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->service_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindServiceByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->service(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyServiceDescriptor_FromDescriptor(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileServices", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace services + +PyObject* NewFileServicesByName(const FileDescriptor* descriptor) { + return descriptor::NewMappingByName(&services::ContainerDef, descriptor); +} + namespace dependencies { typedef const FileDescriptor* ItemDescriptor; diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h index ce40747d..83de07b6 100644 --- a/python/google/protobuf/pyext/descriptor_containers.h +++ b/python/google/protobuf/pyext/descriptor_containers.h @@ -43,6 +43,7 @@ class Descriptor; class FileDescriptor; class EnumDescriptor; class OneofDescriptor; +class ServiceDescriptor; namespace python { @@ -89,10 +90,17 @@ PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor); PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor); +PyObject* NewFileServicesByName(const FileDescriptor* descriptor); + PyObject* NewFileDependencies(const FileDescriptor* descriptor); PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor); } // namespace file_descriptor +namespace service_descriptor { +PyObject* NewServiceMethodsSeq(const ServiceDescriptor* descriptor); +PyObject* NewServiceMethodsByName(const ServiceDescriptor* descriptor); +} // namespace service_descriptor + } // namespace python } // namespace protobuf diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index 1faff96b..cfd98690 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -305,6 +305,40 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) { return PyOneofDescriptor_FromDescriptor(oneof_descriptor); } +PyObject* FindServiceByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const ServiceDescriptor* service_descriptor = + self->pool->FindServiceByName(string(name, name_size)); + if (service_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find service %.200s", name); + return NULL; + } + + return PyServiceDescriptor_FromDescriptor(service_descriptor); +} + +PyObject* FindMethodByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const MethodDescriptor* method_descriptor = + self->pool->FindMethodByName(string(name, name_size)); + if (method_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find method %.200s", name); + return NULL; + } + + return PyMethodDescriptor_FromDescriptor(method_descriptor); +} + PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) { Py_ssize_t name_size; char* name; @@ -491,6 +525,10 @@ static PyMethodDef Methods[] = { "Searches for enum type descriptor by full name." }, { "FindOneofByName", (PyCFunction)FindOneofByName, METH_O, "Searches for oneof descriptor by full name." }, + { "FindServiceByName", (PyCFunction)FindServiceByName, METH_O, + "Searches for service descriptor by full name." }, + { "FindMethodByName", (PyCFunction)FindMethodByName, METH_O, + "Searches for method descriptor by full name." }, { "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O, "Gets the FileDescriptor containing the specified symbol." }, diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index e022406d..90438df1 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -39,7 +39,6 @@ #include #include -#include #include #include #include diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 83c151ff..a9261f20 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -1593,23 +1593,20 @@ struct ReleaseChild : public ChildVisitor { parent_(parent) {} int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { - return repeated_composite_container::Release( - reinterpret_cast(container)); + return repeated_composite_container::Release(container); } int VisitRepeatedScalarContainer(RepeatedScalarContainer* container) { - return repeated_scalar_container::Release( - reinterpret_cast(container)); + return repeated_scalar_container::Release(container); } int VisitMapContainer(MapContainer* container) { - return reinterpret_cast(container)->Release(); + return container->Release(); } int VisitCMessage(CMessage* cmessage, const FieldDescriptor* field_descriptor) { - return ReleaseSubMessage(parent_, field_descriptor, - reinterpret_cast(cmessage)); + return ReleaseSubMessage(parent_, field_descriptor, cmessage); } CMessage* parent_; @@ -1903,7 +1900,7 @@ static bool allow_oversize_protos = false; // Provide a method in the module to set allow_oversize_protos to a boolean // value. This method returns the newly value of allow_oversize_protos. -static PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) { +PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg) { if (!arg || !PyBool_Check(arg)) { PyErr_SetString(PyExc_TypeError, "Argument to SetAllowOversizeProtos must be boolean"); @@ -3044,6 +3041,10 @@ bool InitProto2MessageModule(PyObject *m) { &PyFileDescriptor_Type)); PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast( &PyOneofDescriptor_Type)); + PyModule_AddObject(m, "ServiceDescriptor", reinterpret_cast( + &PyServiceDescriptor_Type)); + PyModule_AddObject(m, "MethodDescriptor", reinterpret_cast( + &PyMethodDescriptor_Type)); PyObject* enum_type_wrapper = PyImport_ImportModule( "google.protobuf.internal.enum_type_wrapper"); @@ -3081,53 +3082,4 @@ bool InitProto2MessageModule(PyObject *m) { } // namespace python } // namespace protobuf -static PyMethodDef ModuleMethods[] = { - {"SetAllowOversizeProtos", - (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, - METH_O, "Enable/disable oversize proto parsing."}, - { NULL, NULL} -}; - -#if PY_MAJOR_VERSION >= 3 -static struct PyModuleDef _module = { - PyModuleDef_HEAD_INIT, - "_message", - google::protobuf::python::module_docstring, - -1, - ModuleMethods, /* m_methods */ - NULL, - NULL, - NULL, - NULL -}; -#define INITFUNC PyInit__message -#define INITFUNC_ERRORVAL NULL -#else // Python 2 -#define INITFUNC init_message -#define INITFUNC_ERRORVAL -#endif - -extern "C" { - PyMODINIT_FUNC INITFUNC(void) { - PyObject* m; -#if PY_MAJOR_VERSION >= 3 - m = PyModule_Create(&_module); -#else - m = Py_InitModule3("_message", ModuleMethods, - google::protobuf::python::module_docstring); -#endif - if (m == NULL) { - return INITFUNC_ERRORVAL; - } - - if (!google::protobuf::python::InitProto2MessageModule(m)) { - Py_DECREF(m); - return INITFUNC_ERRORVAL; - } - -#if PY_MAJOR_VERSION >= 3 - return m; -#endif - } -} } // namespace google diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index 3a4bec81..8b399e05 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -54,7 +54,7 @@ class MessageFactory; #ifdef _SHARED_PTR_H using std::shared_ptr; -using ::std::string; +using std::string; #else using internal::shared_ptr; #endif @@ -269,6 +269,8 @@ int AssureWritable(CMessage* self); // even in the case of extensions. PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message); +PyObject* SetAllowOversizeProtos(PyObject* m, PyObject* arg); + } // namespace cmessage @@ -354,6 +356,8 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, extern PyObject* PickleError_class; +bool InitProto2MessageModule(PyObject *m); + } // namespace python } // namespace protobuf diff --git a/python/google/protobuf/pyext/message_module.cc b/python/google/protobuf/pyext/message_module.cc new file mode 100644 index 00000000..d90d9de3 --- /dev/null +++ b/python/google/protobuf/pyext/message_module.cc @@ -0,0 +1,88 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include + +static const char module_docstring[] = +"python-proto2 is a module that can be used to enhance proto2 Python API\n" +"performance.\n" +"\n" +"It provides access to the protocol buffers C++ reflection API that\n" +"implements the basic protocol buffer functions."; + +static PyMethodDef ModuleMethods[] = { + {"SetAllowOversizeProtos", + (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, + METH_O, "Enable/disable oversize proto parsing."}, + { NULL, NULL} +}; + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef _module = { + PyModuleDef_HEAD_INIT, + "_message", + module_docstring, + -1, + ModuleMethods, /* m_methods */ + NULL, + NULL, + NULL, + NULL +}; +#define INITFUNC PyInit__message +#define INITFUNC_ERRORVAL NULL +#else // Python 2 +#define INITFUNC init_message +#define INITFUNC_ERRORVAL +#endif + +extern "C" { + PyMODINIT_FUNC INITFUNC(void) { + PyObject* m; +#if PY_MAJOR_VERSION >= 3 + m = PyModule_Create(&_module); +#else + m = Py_InitModule3("_message", ModuleMethods, + module_docstring); +#endif + if (m == NULL) { + return INITFUNC_ERRORVAL; + } + + if (!google::protobuf::python::InitProto2MessageModule(m)) { + Py_DECREF(m); + return INITFUNC_ERRORVAL; + } + +#if PY_MAJOR_VERSION >= 3 + return m; +#endif + } +} diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 6f1e3c8b..c4b23c37 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -48,15 +48,15 @@ import re import six if six.PY3: - long = int + long = int # pylint: disable=redefined-builtin,invalid-name +# pylint: disable=g-import-not-at-top from google.protobuf.internal import type_checkers from google.protobuf import descriptor from google.protobuf import text_encoding -__all__ = ['MessageToString', 'PrintMessage', 'PrintField', - 'PrintFieldValue', 'Merge'] - +__all__ = ['MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue', + 'Merge'] _INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(), type_checkers.Int32ValueChecker(), @@ -67,6 +67,7 @@ _FLOAT_NAN = re.compile('nanf?', re.IGNORECASE) _FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) _QUOTES = frozenset(("'", '"')) +_ANY_FULL_TYPE_NAME = 'google.protobuf.Any' class Error(Exception): @@ -74,10 +75,30 @@ class Error(Exception): class ParseError(Error): - """Thrown in case of text parsing error.""" + """Thrown in case of text parsing or tokenizing error.""" + + def __init__(self, message=None, line=None, column=None): + if message is not None and line is not None: + loc = str(line) + if column is not None: + loc += ':{}'.format(column) + message = '{} : {}'.format(loc, message) + if message is not None: + super(ParseError, self).__init__(message) + else: + super(ParseError, self).__init__() + self._line = line + self._column = column + + def GetLine(self): + return self._line + + def GetColumn(self): + return self._column class TextWriter(object): + def __init__(self, as_utf8): if six.PY2: self._writer = io.BytesIO() @@ -97,9 +118,15 @@ class TextWriter(object): return self._writer.getvalue() -def MessageToString(message, as_utf8=False, as_one_line=False, - pointy_brackets=False, use_index_order=False, - float_format=None, use_field_number=False): +def MessageToString(message, + as_utf8=False, + as_one_line=False, + pointy_brackets=False, + use_index_order=False, + float_format=None, + use_field_number=False, + descriptor_pool=None, + indent=0): """Convert protobuf message to text format. Floating point values can be formatted compactly with 15 digits of @@ -119,14 +146,16 @@ def MessageToString(message, as_utf8=False, as_one_line=False, float_format: If set, use this to specify floating point number formatting (per the "Format Specification Mini-Language"); otherwise, str() is used. use_field_number: If True, print field numbers instead of names. + descriptor_pool: A DescriptorPool used to resolve Any types. + indent: The indent level, in terms of spaces, for pretty print. Returns: A string of the text formatted protocol buffer message. """ out = TextWriter(as_utf8) - printer = _Printer(out, 0, as_utf8, as_one_line, - pointy_brackets, use_index_order, float_format, - use_field_number) + printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + use_index_order, float_format, use_field_number, + descriptor_pool) printer.PrintMessage(message) result = out.getvalue() out.close() @@ -141,39 +170,87 @@ def _IsMapEntry(field): field.message_type.GetOptions().map_entry) -def PrintMessage(message, out, indent=0, as_utf8=False, as_one_line=False, - pointy_brackets=False, use_index_order=False, - float_format=None, use_field_number=False): - printer = _Printer(out, indent, as_utf8, as_one_line, - pointy_brackets, use_index_order, float_format, - use_field_number) +def PrintMessage(message, + out, + indent=0, + as_utf8=False, + as_one_line=False, + pointy_brackets=False, + use_index_order=False, + float_format=None, + use_field_number=False, + descriptor_pool=None): + printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + use_index_order, float_format, use_field_number, + descriptor_pool) printer.PrintMessage(message) -def PrintField(field, value, out, indent=0, as_utf8=False, as_one_line=False, - pointy_brackets=False, use_index_order=False, float_format=None): +def PrintField(field, + value, + out, + indent=0, + as_utf8=False, + as_one_line=False, + pointy_brackets=False, + use_index_order=False, + float_format=None): """Print a single field name/value pair.""" - printer = _Printer(out, indent, as_utf8, as_one_line, - pointy_brackets, use_index_order, float_format) + printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + use_index_order, float_format) printer.PrintField(field, value) -def PrintFieldValue(field, value, out, indent=0, as_utf8=False, - as_one_line=False, pointy_brackets=False, +def PrintFieldValue(field, + value, + out, + indent=0, + as_utf8=False, + as_one_line=False, + pointy_brackets=False, use_index_order=False, float_format=None): """Print a single field value (not including name).""" - printer = _Printer(out, indent, as_utf8, as_one_line, - pointy_brackets, use_index_order, float_format) + printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + use_index_order, float_format) printer.PrintFieldValue(field, value) +def _BuildMessageFromTypeName(type_name, descriptor_pool): + """Returns a protobuf message instance. + + Args: + type_name: Fully-qualified protobuf message type name string. + descriptor_pool: DescriptorPool instance. + + Returns: + A Message instance of type matching type_name, or None if the a Descriptor + wasn't found matching type_name. + """ + # pylint: disable=g-import-not-at-top + from google.protobuf import message_factory + factory = message_factory.MessageFactory(descriptor_pool) + try: + message_descriptor = descriptor_pool.FindMessageTypeByName(type_name) + except KeyError: + return None + message_type = factory.GetPrototype(message_descriptor) + return message_type() + + class _Printer(object): """Text format printer for protocol message.""" - def __init__(self, out, indent=0, as_utf8=False, as_one_line=False, - pointy_brackets=False, use_index_order=False, float_format=None, - use_field_number=False): + def __init__(self, + out, + indent=0, + as_utf8=False, + as_one_line=False, + pointy_brackets=False, + use_index_order=False, + float_format=None, + use_field_number=False, + descriptor_pool=None): """Initialize the Printer. Floating point values can be formatted compactly with 15 digits of @@ -195,6 +272,7 @@ class _Printer(object): (per the "Format Specification Mini-Language"); otherwise, str() is used. use_field_number: If True, print field numbers instead of names. + descriptor_pool: A DescriptorPool used to resolve Any types. """ self.out = out self.indent = indent @@ -204,6 +282,20 @@ class _Printer(object): self.use_index_order = use_index_order self.float_format = float_format self.use_field_number = use_field_number + self.descriptor_pool = descriptor_pool + + def _TryPrintAsAnyMessage(self, message): + """Serializes if message is a google.protobuf.Any field.""" + packed_message = _BuildMessageFromTypeName(message.TypeName(), + self.descriptor_pool) + if packed_message: + packed_message.MergeFromString(message.value) + self.out.write('%s[%s]' % (self.indent * ' ', message.type_url)) + self._PrintMessageFieldValue(packed_message) + self.out.write(' ' if self.as_one_line else '\n') + return True + else: + return False def PrintMessage(self, message): """Convert protobuf message to text format. @@ -211,6 +303,9 @@ class _Printer(object): Args: message: The protocol buffers message. """ + if (message.DESCRIPTOR.full_name == _ANY_FULL_TYPE_NAME and + self.descriptor_pool and self._TryPrintAsAnyMessage(message)): + return fields = message.ListFields() if self.use_index_order: fields.sort(key=lambda x: x[0].index) @@ -222,8 +317,8 @@ class _Printer(object): # of this file to work around. # # TODO(haberman): refactor and optimize if this becomes an issue. - entry_submsg = field.message_type._concrete_class( - key=key, value=value[key]) + entry_submsg = field.message_type._concrete_class(key=key, + value=value[key]) self.PrintField(field, entry_submsg) elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: for element in value: @@ -264,6 +359,25 @@ class _Printer(object): else: out.write('\n') + def _PrintMessageFieldValue(self, value): + if self.pointy_brackets: + openb = '<' + closeb = '>' + else: + openb = '{' + closeb = '}' + + if self.as_one_line: + self.out.write(' %s ' % openb) + self.PrintMessage(value) + self.out.write(closeb) + else: + self.out.write(' %s\n' % openb) + self.indent += 2 + self.PrintMessage(value) + self.indent -= 2 + self.out.write(' ' * self.indent + closeb) + def PrintFieldValue(self, field, value): """Print a single field value (not including name). @@ -274,24 +388,8 @@ class _Printer(object): value: The value of the field. """ out = self.out - if self.pointy_brackets: - openb = '<' - closeb = '>' - else: - openb = '{' - closeb = '}' - if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: - if self.as_one_line: - out.write(' %s ' % openb) - self.PrintMessage(value) - out.write(closeb) - else: - out.write(' %s\n' % openb) - self.indent += 2 - self.PrintMessage(value) - self.indent -= 2 - out.write(' ' * self.indent + closeb) + self._PrintMessageFieldValue(value) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: enum_value = field.enum_type.values_by_number.get(value, None) if enum_value is not None: @@ -322,9 +420,11 @@ class _Printer(object): out.write(str(value)) -def Parse(text, message, - allow_unknown_extension=False, allow_field_number=False): - """Parses an text representation of a protocol message into a message. +def Parse(text, + message, + allow_unknown_extension=False, + allow_field_number=False): + """Parses a text representation of a protocol message into a message. Args: text: Message text representation. @@ -341,13 +441,16 @@ def Parse(text, message, """ if not isinstance(text, str): text = text.decode('utf-8') - return ParseLines(text.split('\n'), message, allow_unknown_extension, - allow_field_number) + return ParseLines( + text.split('\n'), message, allow_unknown_extension, allow_field_number) -def Merge(text, message, allow_unknown_extension=False, - allow_field_number=False): - """Parses an text representation of a protocol message into a message. +def Merge(text, + message, + allow_unknown_extension=False, + allow_field_number=False, + descriptor_pool=None): + """Parses a text representation of a protocol message into a message. Like Parse(), but allows repeated values for a non-repeated field, and uses the last one. @@ -358,6 +461,7 @@ def Merge(text, message, allow_unknown_extension=False, allow_unknown_extension: if True, skip over missing extensions and keep parsing allow_field_number: if True, both field number and field name are allowed. + descriptor_pool: A DescriptorPool used to resolve Any types. Returns: The same message passed as argument. @@ -365,13 +469,19 @@ def Merge(text, message, allow_unknown_extension=False, Raises: ParseError: On text parsing problems. """ - return MergeLines(text.split('\n'), message, allow_unknown_extension, - allow_field_number) + return MergeLines( + text.split('\n'), + message, + allow_unknown_extension, + allow_field_number, + descriptor_pool=descriptor_pool) -def ParseLines(lines, message, allow_unknown_extension=False, +def ParseLines(lines, + message, + allow_unknown_extension=False, allow_field_number=False): - """Parses an text representation of a protocol message into a message. + """Parses a text representation of a protocol message into a message. Args: lines: An iterable of lines of a message's text representation. @@ -379,6 +489,7 @@ def ParseLines(lines, message, allow_unknown_extension=False, allow_unknown_extension: if True, skip over missing extensions and keep parsing allow_field_number: if True, both field number and field name are allowed. + descriptor_pool: A DescriptorPool used to resolve Any types. Returns: The same message passed as argument. @@ -390,9 +501,12 @@ def ParseLines(lines, message, allow_unknown_extension=False, return parser.ParseLines(lines, message) -def MergeLines(lines, message, allow_unknown_extension=False, - allow_field_number=False): - """Parses an text representation of a protocol message into a message. +def MergeLines(lines, + message, + allow_unknown_extension=False, + allow_field_number=False, + descriptor_pool=None): + """Parses a text representation of a protocol message into a message. Args: lines: An iterable of lines of a message's text representation. @@ -407,41 +521,47 @@ def MergeLines(lines, message, allow_unknown_extension=False, Raises: ParseError: On text parsing problems. """ - parser = _Parser(allow_unknown_extension, allow_field_number) + parser = _Parser(allow_unknown_extension, + allow_field_number, + descriptor_pool=descriptor_pool) return parser.MergeLines(lines, message) class _Parser(object): """Text format parser for protocol message.""" - def __init__(self, allow_unknown_extension=False, allow_field_number=False): + def __init__(self, + allow_unknown_extension=False, + allow_field_number=False, + descriptor_pool=None): self.allow_unknown_extension = allow_unknown_extension self.allow_field_number = allow_field_number + self.descriptor_pool = descriptor_pool def ParseFromString(self, text, message): - """Parses an text representation of a protocol message into a message.""" + """Parses a text representation of a protocol message into a message.""" if not isinstance(text, str): text = text.decode('utf-8') return self.ParseLines(text.split('\n'), message) def ParseLines(self, lines, message): - """Parses an text representation of a protocol message into a message.""" + """Parses a text representation of a protocol message into a message.""" self._allow_multiple_scalars = False self._ParseOrMerge(lines, message) return message def MergeFromString(self, text, message): - """Merges an text representation of a protocol message into a message.""" + """Merges a text representation of a protocol message into a message.""" return self._MergeLines(text.split('\n'), message) def MergeLines(self, lines, message): - """Merges an text representation of a protocol message into a message.""" + """Merges a text representation of a protocol message into a message.""" self._allow_multiple_scalars = True self._ParseOrMerge(lines, message) return message def _ParseOrMerge(self, lines, message): - """Converts an text representation of a protocol message into a message. + """Converts a text representation of a protocol message into a message. Args: lines: Lines of a message's text representation. @@ -450,7 +570,7 @@ class _Parser(object): Raises: ParseError: On text parsing problems. """ - tokenizer = _Tokenizer(lines) + tokenizer = Tokenizer(lines) while not tokenizer.AtEnd(): self._MergeField(tokenizer, message) @@ -491,13 +611,13 @@ class _Parser(object): 'Extension "%s" not registered.' % name) elif message_descriptor != field.containing_type: raise tokenizer.ParseErrorPreviousToken( - 'Extension "%s" does not extend message type "%s".' % ( - name, message_descriptor.full_name)) + 'Extension "%s" does not extend message type "%s".' % + (name, message_descriptor.full_name)) tokenizer.Consume(']') else: - name = tokenizer.ConsumeIdentifier() + name = tokenizer.ConsumeIdentifierOrNumber() if self.allow_field_number and name.isdigit(): number = ParseInteger(name, True, True) field = message_descriptor.fields_by_number.get(number, None) @@ -520,8 +640,8 @@ class _Parser(object): if not field: raise tokenizer.ParseErrorPreviousToken( - 'Message type "%s" has no field named "%s".' % ( - message_descriptor.full_name, name)) + 'Message type "%s" has no field named "%s".' % + (message_descriptor.full_name, name)) if field: if not self._allow_multiple_scalars and field.containing_oneof: @@ -532,9 +652,9 @@ class _Parser(object): if which_oneof is not None and which_oneof != field.name: raise tokenizer.ParseErrorPreviousToken( 'Field "%s" is specified along with field "%s", another member ' - 'of oneof "%s" for message type "%s".' % ( - field.name, which_oneof, field.containing_oneof.name, - message_descriptor.full_name)) + 'of oneof "%s" for message type "%s".' % + (field.name, which_oneof, field.containing_oneof.name, + message_descriptor.full_name)) if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: tokenizer.TryConsume(':') @@ -543,12 +663,13 @@ class _Parser(object): tokenizer.Consume(':') merger = self._MergeScalarField - if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED - and tokenizer.TryConsume('[')): + if (field.label == descriptor.FieldDescriptor.LABEL_REPEATED and + tokenizer.TryConsume('[')): # Short repeated format, e.g. "foo: [1, 2, 3]" while True: merger(tokenizer, message, field) - if tokenizer.TryConsume(']'): break + if tokenizer.TryConsume(']'): + break tokenizer.Consume(',') else: @@ -563,6 +684,21 @@ class _Parser(object): if not tokenizer.TryConsume(','): tokenizer.TryConsume(';') + def _ConsumeAnyTypeUrl(self, tokenizer): + """Consumes a google.protobuf.Any type URL and returns the type name.""" + # Consume "type.googleapis.com/". + tokenizer.ConsumeIdentifier() + tokenizer.Consume('.') + tokenizer.ConsumeIdentifier() + tokenizer.Consume('.') + tokenizer.ConsumeIdentifier() + tokenizer.Consume('/') + # Consume the fully-qualified type name. + name = [tokenizer.ConsumeIdentifier()] + while tokenizer.TryConsume('.'): + name.append(tokenizer.ConsumeIdentifier()) + return '.'.join(name) + def _MergeMessageField(self, tokenizer, message, field): """Merges a single scalar field into a message. @@ -582,7 +718,34 @@ class _Parser(object): tokenizer.Consume('{') end_token = '}' - if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + if (field.message_type.full_name == _ANY_FULL_TYPE_NAME and + tokenizer.TryConsume('[')): + packed_type_name = self._ConsumeAnyTypeUrl(tokenizer) + tokenizer.Consume(']') + tokenizer.TryConsume(':') + if tokenizer.TryConsume('<'): + expanded_any_end_token = '>' + else: + tokenizer.Consume('{') + expanded_any_end_token = '}' + if not self.descriptor_pool: + raise ParseError('Descriptor pool required to parse expanded Any field') + expanded_any_sub_message = _BuildMessageFromTypeName(packed_type_name, + self.descriptor_pool) + if not expanded_any_sub_message: + raise ParseError('Type %s not found in descriptor pool' % + packed_type_name) + while not tokenizer.TryConsume(expanded_any_end_token): + if tokenizer.AtEnd(): + raise tokenizer.ParseErrorPreviousToken('Expected "%s".' % + (expanded_any_end_token,)) + self._MergeField(tokenizer, expanded_any_sub_message) + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + any_message = getattr(message, field.name).add() + else: + any_message = getattr(message, field.name) + any_message.Pack(expanded_any_sub_message) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: if field.is_extension: sub_message = message.Extensions[field].add() elif is_map_entry: @@ -628,17 +791,17 @@ class _Parser(object): if field.type in (descriptor.FieldDescriptor.TYPE_INT32, descriptor.FieldDescriptor.TYPE_SINT32, descriptor.FieldDescriptor.TYPE_SFIXED32): - value = tokenizer.ConsumeInt32() + value = _ConsumeInt32(tokenizer) elif field.type in (descriptor.FieldDescriptor.TYPE_INT64, descriptor.FieldDescriptor.TYPE_SINT64, descriptor.FieldDescriptor.TYPE_SFIXED64): - value = tokenizer.ConsumeInt64() + value = _ConsumeInt64(tokenizer) elif field.type in (descriptor.FieldDescriptor.TYPE_UINT32, descriptor.FieldDescriptor.TYPE_FIXED32): - value = tokenizer.ConsumeUint32() + value = _ConsumeUint32(tokenizer) elif field.type in (descriptor.FieldDescriptor.TYPE_UINT64, descriptor.FieldDescriptor.TYPE_FIXED64): - value = tokenizer.ConsumeUint64() + value = _ConsumeUint64(tokenizer) elif field.type in (descriptor.FieldDescriptor.TYPE_FLOAT, descriptor.FieldDescriptor.TYPE_DOUBLE): value = tokenizer.ConsumeFloat() @@ -753,13 +916,12 @@ def _SkipFieldValue(tokenizer): return if (not tokenizer.TryConsumeIdentifier() and - not tokenizer.TryConsumeInt64() and - not tokenizer.TryConsumeUint64() and + not _TryConsumeInt64(tokenizer) and not _TryConsumeUint64(tokenizer) and not tokenizer.TryConsumeFloat()): raise ParseError('Invalid field value: ' + tokenizer.token) -class _Tokenizer(object): +class Tokenizer(object): """Protocol buffer text representation tokenizer. This class handles the lower level string parsing by splitting it into @@ -768,17 +930,20 @@ class _Tokenizer(object): It was directly ported from the Java protocol buffer API. """ - _WHITESPACE = re.compile('(\\s|(#.*$))+', re.MULTILINE) + _WHITESPACE = re.compile(r'\s+') + _COMMENT = re.compile(r'(\s*#.*$)', re.MULTILINE) + _WHITESPACE_OR_COMMENT = re.compile(r'(\s|(#.*$))+', re.MULTILINE) _TOKEN = re.compile('|'.join([ - r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier + r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number - ] + [ # quoted str for each quote mark + ] + [ # quoted str for each quote mark r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES ])) - _IDENTIFIER = re.compile(r'\w+') + _IDENTIFIER = re.compile(r'[^\d\W]\w*') + _IDENTIFIER_OR_NUMBER = re.compile(r'\w+') - def __init__(self, lines): + def __init__(self, lines, skip_comments=True): self._position = 0 self._line = -1 self._column = 0 @@ -789,6 +954,9 @@ class _Tokenizer(object): self._previous_line = 0 self._previous_column = 0 self._more_lines = True + self._skip_comments = skip_comments + self._whitespace_pattern = (skip_comments and self._WHITESPACE_OR_COMMENT + or self._WHITESPACE) self._SkipWhitespace() self.NextToken() @@ -818,7 +986,7 @@ class _Tokenizer(object): def _SkipWhitespace(self): while True: self._PopLine() - match = self._WHITESPACE.match(self._current_line, self._column) + match = self._whitespace_pattern.match(self._current_line, self._column) if not match: break length = len(match.group(0)) @@ -848,7 +1016,14 @@ class _Tokenizer(object): ParseError: If the text couldn't be consumed. """ if not self.TryConsume(token): - raise self._ParseError('Expected "%s".' % token) + raise self.ParseError('Expected "%s".' % token) + + def ConsumeComment(self): + result = self.token + if not self._COMMENT.match(result): + raise self.ParseError('Expected comment.') + self.NextToken() + return result def TryConsumeIdentifier(self): try: @@ -868,85 +1043,55 @@ class _Tokenizer(object): """ result = self.token if not self._IDENTIFIER.match(result): - raise self._ParseError('Expected identifier.') + raise self.ParseError('Expected identifier.') self.NextToken() return result - def ConsumeInt32(self): - """Consumes a signed 32bit integer number. - - Returns: - The integer parsed. - - Raises: - ParseError: If a signed 32bit integer couldn't be consumed. - """ + def TryConsumeIdentifierOrNumber(self): try: - result = ParseInteger(self.token, is_signed=True, is_long=False) - except ValueError as e: - raise self._ParseError(str(e)) - self.NextToken() - return result - - def ConsumeUint32(self): - """Consumes an unsigned 32bit integer number. - - Returns: - The integer parsed. - - Raises: - ParseError: If an unsigned 32bit integer couldn't be consumed. - """ - try: - result = ParseInteger(self.token, is_signed=False, is_long=False) - except ValueError as e: - raise self._ParseError(str(e)) - self.NextToken() - return result - - def TryConsumeInt64(self): - try: - self.ConsumeInt64() + self.ConsumeIdentifierOrNumber() return True except ParseError: return False - def ConsumeInt64(self): - """Consumes a signed 64bit integer number. + def ConsumeIdentifierOrNumber(self): + """Consumes protocol message field identifier. Returns: - The integer parsed. + Identifier string. Raises: - ParseError: If a signed 64bit integer couldn't be consumed. + ParseError: If an identifier couldn't be consumed. """ - try: - result = ParseInteger(self.token, is_signed=True, is_long=True) - except ValueError as e: - raise self._ParseError(str(e)) + result = self.token + if not self._IDENTIFIER_OR_NUMBER.match(result): + raise self.ParseError('Expected identifier or number.') self.NextToken() return result - def TryConsumeUint64(self): + def TryConsumeInteger(self): try: - self.ConsumeUint64() + # Note: is_long only affects value type, not whether an error is raised. + self.ConsumeInteger() return True except ParseError: return False - def ConsumeUint64(self): - """Consumes an unsigned 64bit integer number. + def ConsumeInteger(self, is_long=False): + """Consumes an integer number. + Args: + is_long: True if the value should be returned as a long integer. Returns: The integer parsed. Raises: - ParseError: If an unsigned 64bit integer couldn't be consumed. + ParseError: If an integer couldn't be consumed. """ try: - result = ParseInteger(self.token, is_signed=False, is_long=True) + result = _ParseAbstractInteger(self.token, is_long=is_long) except ValueError as e: - raise self._ParseError(str(e)) + raise self.ParseError(str(e)) self.NextToken() return result @@ -969,7 +1114,7 @@ class _Tokenizer(object): try: result = ParseFloat(self.token) except ValueError as e: - raise self._ParseError(str(e)) + raise self.ParseError(str(e)) self.NextToken() return result @@ -985,7 +1130,7 @@ class _Tokenizer(object): try: result = ParseBool(self.token) except ValueError as e: - raise self._ParseError(str(e)) + raise self.ParseError(str(e)) self.NextToken() return result @@ -1039,15 +1184,15 @@ class _Tokenizer(object): """ text = self.token if len(text) < 1 or text[0] not in _QUOTES: - raise self._ParseError('Expected string but found: %r' % (text,)) + raise self.ParseError('Expected string but found: %r' % (text,)) if len(text) < 2 or text[-1] != text[0]: - raise self._ParseError('String missing ending quote: %r' % (text,)) + raise self.ParseError('String missing ending quote: %r' % (text,)) try: result = text_encoding.CUnescape(text[1:-1]) except ValueError as e: - raise self._ParseError(str(e)) + raise self.ParseError(str(e)) self.NextToken() return result @@ -1055,7 +1200,7 @@ class _Tokenizer(object): try: result = ParseEnum(field, self.token) except ValueError as e: - raise self._ParseError(str(e)) + raise self.ParseError(str(e)) self.NextToken() return result @@ -1068,16 +1213,15 @@ class _Tokenizer(object): Returns: A ParseError instance. """ - return ParseError('%d:%d : %s' % ( - self._previous_line + 1, self._previous_column + 1, message)) + return ParseError(message, self._previous_line + 1, + self._previous_column + 1) - def _ParseError(self, message): + def ParseError(self, message): """Creates and *returns* a ParseError for the current token.""" - return ParseError('%d:%d : %s' % ( - self._line + 1, self._column + 1, message)) + return ParseError(message, self._line + 1, self._column + 1) def _StringParseError(self, e): - return self._ParseError('Couldn\'t parse string: ' + str(e)) + return self.ParseError('Couldn\'t parse string: ' + str(e)) def NextToken(self): """Reads the next meaningful token.""" @@ -1092,12 +1236,124 @@ class _Tokenizer(object): return match = self._TOKEN.match(self._current_line, self._column) + if not match and not self._skip_comments: + match = self._COMMENT.match(self._current_line, self._column) if match: token = match.group(0) self.token = token else: self.token = self._current_line[self._column] +# Aliased so it can still be accessed by current visibility violators. +# TODO(dbarnett): Migrate violators to textformat_tokenizer. +_Tokenizer = Tokenizer # pylint: disable=invalid-name + + +def _ConsumeInt32(tokenizer): + """Consumes a signed 32bit integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + + Returns: + The integer parsed. + + Raises: + ParseError: If a signed 32bit integer couldn't be consumed. + """ + return _ConsumeInteger(tokenizer, is_signed=True, is_long=False) + + +def _ConsumeUint32(tokenizer): + """Consumes an unsigned 32bit integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + + Returns: + The integer parsed. + + Raises: + ParseError: If an unsigned 32bit integer couldn't be consumed. + """ + return _ConsumeInteger(tokenizer, is_signed=False, is_long=False) + + +def _TryConsumeInt64(tokenizer): + try: + _ConsumeInt64(tokenizer) + return True + except ParseError: + return False + + +def _ConsumeInt64(tokenizer): + """Consumes a signed 32bit integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + + Returns: + The integer parsed. + + Raises: + ParseError: If a signed 32bit integer couldn't be consumed. + """ + return _ConsumeInteger(tokenizer, is_signed=True, is_long=True) + + +def _TryConsumeUint64(tokenizer): + try: + _ConsumeUint64(tokenizer) + return True + except ParseError: + return False + + +def _ConsumeUint64(tokenizer): + """Consumes an unsigned 64bit integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + + Returns: + The integer parsed. + + Raises: + ParseError: If an unsigned 64bit integer couldn't be consumed. + """ + return _ConsumeInteger(tokenizer, is_signed=False, is_long=True) + + +def _TryConsumeInteger(tokenizer, is_signed=False, is_long=False): + try: + _ConsumeInteger(tokenizer, is_signed=is_signed, is_long=is_long) + return True + except ParseError: + return False + + +def _ConsumeInteger(tokenizer, is_signed=False, is_long=False): + """Consumes an integer number from tokenizer. + + Args: + tokenizer: A tokenizer used to parse the number. + is_signed: True if a signed integer must be parsed. + is_long: True if a long integer must be parsed. + + Returns: + The integer parsed. + + Raises: + ParseError: If an integer with given characteristics couldn't be consumed. + """ + try: + result = ParseInteger(tokenizer.token, is_signed=is_signed, is_long=is_long) + except ValueError as e: + raise tokenizer.ParseError(str(e)) + tokenizer.NextToken() + return result + def ParseInteger(text, is_signed=False, is_long=False): """Parses an integer. @@ -1110,6 +1366,28 @@ def ParseInteger(text, is_signed=False, is_long=False): Returns: The integer value. + Raises: + ValueError: Thrown Iff the text is not a valid integer. + """ + # Do the actual parsing. Exception handling is propagated to caller. + result = _ParseAbstractInteger(text, is_long=is_long) + + # Check if the integer is sane. Exceptions handled by callers. + checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)] + checker.CheckValue(result) + return result + + +def _ParseAbstractInteger(text, is_long=False): + """Parses an integer without checking size/signedness. + + Args: + text: The text to parse. + is_long: True if the value should be returned as a long integer. + + Returns: + The integer value. + Raises: ValueError: Thrown Iff the text is not a valid integer. """ @@ -1119,17 +1397,12 @@ def ParseInteger(text, is_signed=False, is_long=False): # alternate implementations where the distinction is more significant # (e.g. the C++ implementation) simpler. if is_long: - result = long(text, 0) + return long(text, 0) else: - result = int(text, 0) + return int(text, 0) except ValueError: raise ValueError('Couldn\'t parse integer: %s' % text) - # Check if the integer is sane. Exceptions handled by callers. - checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)] - checker.CheckValue(result) - return result - def ParseFloat(text): """Parse a floating point number. @@ -1206,14 +1479,12 @@ def ParseEnum(field, value): # Identifier. enum_value = enum_descriptor.values_by_name.get(value, None) if enum_value is None: - raise ValueError( - 'Enum type "%s" has no value named %s.' % ( - enum_descriptor.full_name, value)) + raise ValueError('Enum type "%s" has no value named %s.' % + (enum_descriptor.full_name, value)) else: # Numeric value. enum_value = enum_descriptor.values_by_number.get(number, None) if enum_value is None: - raise ValueError( - 'Enum type "%s" has no value with number %d.' % ( - enum_descriptor.full_name, number)) + raise ValueError('Enum type "%s" has no value with number %d.' % + (enum_descriptor.full_name, number)) return enum_value.number -- cgit v1.2.3