From 46e8ff63cb67a6520711da5317aaaef04d0414d0 Mon Sep 17 00:00:00 2001 From: Jisi Liu Date: Mon, 5 Oct 2015 11:59:43 -0700 Subject: Down-integrate from google internal. --- python/google/protobuf/descriptor.py | 43 +- python/google/protobuf/internal/decoder.py | 2 - .../protobuf/internal/descriptor_database_test.py | 5 +- .../protobuf/internal/descriptor_pool_test.py | 6 +- python/google/protobuf/internal/descriptor_test.py | 35 +- python/google/protobuf/internal/encoder.py | 2 - python/google/protobuf/internal/generator_test.py | 5 +- .../google/protobuf/internal/json_format_test.py | 522 ++++++++++++++++++ .../protobuf/internal/message_factory_test.py | 6 +- python/google/protobuf/internal/message_test.py | 20 +- .../protobuf/internal/missing_enum_values.proto | 4 + .../google/protobuf/internal/proto_builder_test.py | 10 +- python/google/protobuf/internal/python_message.py | 2 - python/google/protobuf/internal/reflection_test.py | 13 +- .../protobuf/internal/service_reflection_test.py | 5 +- .../protobuf/internal/symbol_database_test.py | 32 +- python/google/protobuf/internal/test_util.py | 7 + .../google/protobuf/internal/text_encoding_test.py | 5 +- .../google/protobuf/internal/text_format_test.py | 8 +- python/google/protobuf/internal/type_checkers.py | 3 +- .../protobuf/internal/unknown_fields_test.py | 21 +- .../google/protobuf/internal/wire_format_test.py | 5 +- python/google/protobuf/json_format.py | 601 +++++++++++++++++++++ python/google/protobuf/message_factory.py | 2 - python/google/protobuf/proto_builder.py | 43 +- python/google/protobuf/pyext/descriptor.cc | 41 +- python/google/protobuf/pyext/descriptor.h | 2 + .../google/protobuf/pyext/descriptor_containers.cc | 88 +++ .../google/protobuf/pyext/descriptor_containers.h | 1 + python/google/protobuf/pyext/descriptor_pool.cc | 112 +++- python/google/protobuf/pyext/descriptor_pool.h | 15 +- python/google/protobuf/pyext/extension_dict.cc | 3 +- python/google/protobuf/pyext/message.cc | 153 +++--- python/google/protobuf/pyext/message.h | 13 + python/google/protobuf/symbol_database.py | 28 +- python/google/protobuf/text_encoding.py | 1 + python/google/protobuf/text_format.py | 12 +- python/google/protobuf/util/__init__.py | 0 38 files changed, 1667 insertions(+), 209 deletions(-) create mode 100644 python/google/protobuf/internal/json_format_test.py create mode 100644 python/google/protobuf/json_format.py create mode 100644 python/google/protobuf/util/__init__.py (limited to 'python/google') diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index 95b703fc..2bf36532 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -28,8 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Copyright 2007 Google Inc. All Rights Reserved. - """Descriptors essentially contain exactly the information found in a .proto file, in types that make this information accessible in Python. """ @@ -40,7 +38,6 @@ import six from google.protobuf.internal import api_implementation - _USE_C_DESCRIPTORS = False if api_implementation.Type() == 'cpp': # Used by MakeDescriptor in cpp mode @@ -221,6 +218,9 @@ class Descriptor(_NestedDescriptorBase): fields_by_name: (dict str -> FieldDescriptor) Same FieldDescriptor objects as in |fields|, but indexed by "name" attribute in each FieldDescriptor. + fields_by_camelcase_name: (dict str -> FieldDescriptor) Same + FieldDescriptor objects as in |fields|, but indexed by + "camelcase_name" attribute in each FieldDescriptor. nested_types: (list of Descriptors) Descriptor references for all protocol message types nested within this one. @@ -292,6 +292,7 @@ class Descriptor(_NestedDescriptorBase): field.containing_type = self self.fields_by_number = dict((f.number, f) for f in fields) self.fields_by_name = dict((f.name, f) for f in fields) + self._fields_by_camelcase_name = None self.nested_types = nested_types for nested_type in nested_types: @@ -317,6 +318,13 @@ class Descriptor(_NestedDescriptorBase): oneof.containing_type = self self.syntax = syntax or "proto2" + @property + def fields_by_camelcase_name(self): + if self._fields_by_camelcase_name is None: + self._fields_by_camelcase_name = dict( + (f.camelcase_name, f) for f in self.fields) + return self._fields_by_camelcase_name + def EnumValueName(self, enum, value): """Returns the string name of an enum value. @@ -365,6 +373,7 @@ class FieldDescriptor(DescriptorBase): name: (str) Name of this field, exactly as it appears in .proto. full_name: (str) Name of this field, including containing scope. This is particularly relevant for extensions. + camelcase_name: (str) Camelcase name of this field. index: (int) Dense, 0-indexed index giving the order that this field textually appears within its message in the .proto file. number: (int) Tag number declared for this field in the .proto file. @@ -509,6 +518,7 @@ class FieldDescriptor(DescriptorBase): super(FieldDescriptor, self).__init__(options, 'FieldOptions') self.name = name self.full_name = full_name + self._camelcase_name = None self.index = index self.number = number self.type = type @@ -530,6 +540,12 @@ class FieldDescriptor(DescriptorBase): else: self._cdescriptor = None + @property + def camelcase_name(self): + if self._camelcase_name is None: + self._camelcase_name = _ToCamelCase(self.name) + return self._camelcase_name + @staticmethod def ProtoTypeToCppProtoType(proto_type): """Converts from a Python proto type to a C++ Proto Type. @@ -822,6 +838,27 @@ def _ParseOptions(message, string): return message +def _ToCamelCase(name): + """Converts name to camel-case and returns it.""" + capitalize_next = False + result = [] + + for c in name: + if c == '_': + if result: + capitalize_next = True + elif capitalize_next: + result.append(c.upper()) + capitalize_next = False + else: + result += c + + # Lower-case the first letter. + if result and result[0].isupper(): + result[0] = result[0].lower() + return ''.join(result) + + def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, syntax=None): """Make a protobuf Descriptor given a DescriptorProto protobuf. diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 4fd7a864..31869e45 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -28,8 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Copyright 2009 Google Inc. All Rights Reserved. - """Code for decoding protocol buffer primitives. This code is very similar to encoder.py -- read the docs for that module first. diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py index 1baff7d1..3241cb72 100644 --- a/python/google/protobuf/internal/descriptor_database_test.py +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -34,10 +34,7 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf import descriptor_database diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 2a482fba..6bbc8233 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -35,11 +35,8 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import os -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 from google.protobuf.internal import api_implementation @@ -47,6 +44,7 @@ from google.protobuf.internal import descriptor_pool_test1_pb2 from google.protobuf.internal import descriptor_pool_test2_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 +from google.protobuf.internal import test_util from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import descriptor_pool diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 34843a61..f94f9f14 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -1,4 +1,4 @@ -#! /usr/bin/python +#! /usr/bin/env python # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. @@ -36,20 +36,17 @@ __author__ = 'robinson@google.com (Will Robinson)' import sys +import unittest from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 from google.protobuf.internal import api_implementation +from google.protobuf.internal import test_util from google.protobuf import descriptor from google.protobuf import symbol_database from google.protobuf import text_format -try: - import unittest2 as unittest -except ImportError: - import unittest - TEST_EMPTY_MESSAGE_DESCRIPTOR_ASCII = """ name: 'TestEmptyMessage' @@ -394,7 +391,7 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_file.name, 'some/filename/some.proto') self.assertEqual(self.my_file.package, 'protobuf_unittest') - @unittest.skipIf( + @test_util.skipIf( api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, 'Immutability of descriptors is only enforced in v2 implementation') def testImmutableCppDescriptor(self): @@ -425,10 +422,12 @@ class GeneratedDescriptorTest(unittest.TestCase): self.CheckDescriptorSequence(message_descriptor.fields) self.CheckDescriptorMapping(message_descriptor.fields_by_name) self.CheckDescriptorMapping(message_descriptor.fields_by_number) + self.CheckDescriptorMapping(message_descriptor.fields_by_camelcase_name) def CheckFieldDescriptor(self, field_descriptor): # Basic properties self.assertEqual(field_descriptor.name, 'optional_int32') + self.assertEqual(field_descriptor.camelcase_name, 'optionalInt32') self.assertEqual(field_descriptor.full_name, 'protobuf_unittest.TestAllTypes.optional_int32') self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes') @@ -437,6 +436,10 @@ class GeneratedDescriptorTest(unittest.TestCase): self.assertEqual( field_descriptor.containing_type.fields_by_name['optional_int32'], field_descriptor) + self.assertEqual( + field_descriptor.containing_type.fields_by_camelcase_name[ + 'optionalInt32'], + field_descriptor) self.assertIn(field_descriptor, [field_descriptor]) self.assertIn(field_descriptor, {field_descriptor: None}) @@ -481,6 +484,9 @@ class GeneratedDescriptorTest(unittest.TestCase): self.CheckMessageDescriptor(message_descriptor) field_descriptor = message_descriptor.fields_by_name['optional_int32'] self.CheckFieldDescriptor(field_descriptor) + field_descriptor = message_descriptor.fields_by_camelcase_name[ + 'optionalInt32'] + self.CheckFieldDescriptor(field_descriptor) def testCppDescriptorContainer(self): # Check that the collection is still valid even if the parent disappeared. @@ -779,5 +785,20 @@ class MakeDescriptorTest(unittest.TestCase): self.assertEqual(101, options.Extensions[unittest_custom_options_pb2.msgopt].i) + def testCamelcaseName(self): + descriptor_proto = descriptor_pb2.DescriptorProto() + descriptor_proto.name = 'Bar' + names = ['foo_foo', 'FooBar', 'fooBaz', 'fooFoo', 'foobar'] + camelcase_names = ['fooFoo', 'fooBar', 'fooBaz', 'fooFoo', 'foobar'] + for index in range(len(names)): + field = descriptor_proto.field.add() + field.number = index + 1 + field.name = names[index] + result = descriptor.MakeDescriptor(descriptor_proto) + for index in range(len(camelcase_names)): + self.assertEqual(result.fields[index].camelcase_name, + camelcase_names[index]) + + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index d72cd29d..48ef2df3 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -28,8 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Copyright 2009 Google Inc. All Rights Reserved. - """Code for encoding protocol message primitives. Contains the logic for encoding every logical protocol field type diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index 9956da59..7fcb1377 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -41,10 +41,7 @@ further ensures that we can use Python protocol message objects as we expect. __author__ = 'robinson@google.com (Will Robinson)' -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf.internal import test_bad_identifiers_pb2 from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 diff --git a/python/google/protobuf/internal/json_format_test.py b/python/google/protobuf/internal/json_format_test.py new file mode 100644 index 00000000..6d0071be --- /dev/null +++ b/python/google/protobuf/internal/json_format_test.py @@ -0,0 +1,522 @@ +#! /usr/bin/env python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for google.protobuf.json_format.""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +import json +import math +import sys + +import unittest +from google.protobuf import json_format +from google.protobuf.util import json_format_proto3_pb2 + + +class JsonFormatBase(unittest.TestCase): + + def FillAllFields(self, message): + message.int32_value = 20 + message.int64_value = -20 + message.uint32_value = 3120987654 + message.uint64_value = 12345678900 + message.float_value = float('-inf') + message.double_value = 3.1415 + message.bool_value = True + message.string_value = 'foo' + message.bytes_value = b'bar' + message.message_value.value = 10 + message.enum_value = json_format_proto3_pb2.BAR + # Repeated + message.repeated_int32_value.append(0x7FFFFFFF) + message.repeated_int32_value.append(-2147483648) + message.repeated_int64_value.append(9007199254740992) + message.repeated_int64_value.append(-9007199254740992) + message.repeated_uint32_value.append(0xFFFFFFF) + message.repeated_uint32_value.append(0x7FFFFFF) + message.repeated_uint64_value.append(9007199254740992) + message.repeated_uint64_value.append(9007199254740991) + message.repeated_float_value.append(0) + + message.repeated_double_value.append(1E-15) + message.repeated_double_value.append(float('inf')) + message.repeated_bool_value.append(True) + message.repeated_bool_value.append(False) + message.repeated_string_value.append('Few symbols!#$,;') + message.repeated_string_value.append('bar') + message.repeated_bytes_value.append(b'foo') + message.repeated_bytes_value.append(b'bar') + message.repeated_message_value.add().value = 10 + message.repeated_message_value.add().value = 11 + message.repeated_enum_value.append(json_format_proto3_pb2.FOO) + message.repeated_enum_value.append(json_format_proto3_pb2.BAR) + self.message = message + + def CheckParseBack(self, message, parsed_message): + json_format.Parse(json_format.MessageToJson(message), + parsed_message) + self.assertEqual(message, parsed_message) + + def CheckError(self, text, error_message): + message = json_format_proto3_pb2.TestMessage() + self.assertRaisesRegexp( + json_format.ParseError, + error_message, + json_format.Parse, text, message) + + +class JsonFormatTest(JsonFormatBase): + + def testEmptyMessageToJson(self): + message = json_format_proto3_pb2.TestMessage() + self.assertEqual(json_format.MessageToJson(message), + '{}') + parsed_message = json_format_proto3_pb2.TestMessage() + self.CheckParseBack(message, parsed_message) + + def testPartialMessageToJson(self): + message = json_format_proto3_pb2.TestMessage( + string_value='test', + repeated_int32_value=[89, 4]) + self.assertEqual(json.loads(json_format.MessageToJson(message)), + json.loads('{"stringValue": "test", ' + '"repeatedInt32Value": [89, 4]}')) + parsed_message = json_format_proto3_pb2.TestMessage() + self.CheckParseBack(message, parsed_message) + + def testAllFieldsToJson(self): + message = json_format_proto3_pb2.TestMessage() + text = ('{"int32Value": 20, ' + '"int64Value": "-20", ' + '"uint32Value": 3120987654,' + '"uint64Value": "12345678900",' + '"floatValue": "-Infinity",' + '"doubleValue": 3.1415,' + '"boolValue": true,' + '"stringValue": "foo",' + '"bytesValue": "YmFy",' + '"messageValue": {"value": 10},' + '"enumValue": "BAR",' + '"repeatedInt32Value": [2147483647, -2147483648],' + '"repeatedInt64Value": ["9007199254740992", "-9007199254740992"],' + '"repeatedUint32Value": [268435455, 134217727],' + '"repeatedUint64Value": ["9007199254740992", "9007199254740991"],' + '"repeatedFloatValue": [0],' + '"repeatedDoubleValue": [1e-15, "Infinity"],' + '"repeatedBoolValue": [true, false],' + '"repeatedStringValue": ["Few symbols!#$,;", "bar"],' + '"repeatedBytesValue": ["Zm9v", "YmFy"],' + '"repeatedMessageValue": [{"value": 10}, {"value": 11}],' + '"repeatedEnumValue": ["FOO", "BAR"]' + '}') + self.FillAllFields(message) + self.assertEqual( + json.loads(json_format.MessageToJson(message)), + json.loads(text)) + parsed_message = json_format_proto3_pb2.TestMessage() + json_format.Parse(text, parsed_message) + self.assertEqual(message, parsed_message) + + def testJsonEscapeString(self): + message = json_format_proto3_pb2.TestMessage() + if sys.version_info[0] < 3: + message.string_value = '&\n<\"\r>\b\t\f\\\001/\xe2\x80\xa8\xe2\x80\xa9' + else: + message.string_value = '&\n<\"\r>\b\t\f\\\001/' + message.string_value += (b'\xe2\x80\xa8\xe2\x80\xa9').decode('utf-8') + self.assertEqual( + json_format.MessageToJson(message), + '{\n "stringValue": ' + '"&\\n<\\\"\\r>\\b\\t\\f\\\\\\u0001/\\u2028\\u2029"\n}') + parsed_message = json_format_proto3_pb2.TestMessage() + self.CheckParseBack(message, parsed_message) + text = u'{"int32Value": "\u0031"}' + json_format.Parse(text, message) + self.assertEqual(message.int32_value, 1) + + def testAlwaysSeriliaze(self): + message = json_format_proto3_pb2.TestMessage( + string_value='foo') + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{' + '"repeatedStringValue": [],' + '"stringValue": "foo",' + '"repeatedBoolValue": [],' + '"repeatedUint32Value": [],' + '"repeatedInt32Value": [],' + '"enumValue": "FOO",' + '"int32Value": 0,' + '"floatValue": 0,' + '"int64Value": "0",' + '"uint32Value": 0,' + '"repeatedBytesValue": [],' + '"repeatedUint64Value": [],' + '"repeatedDoubleValue": [],' + '"bytesValue": "",' + '"boolValue": false,' + '"repeatedEnumValue": [],' + '"uint64Value": "0",' + '"doubleValue": 0,' + '"repeatedFloatValue": [],' + '"repeatedInt64Value": [],' + '"repeatedMessageValue": []}')) + parsed_message = json_format_proto3_pb2.TestMessage() + self.CheckParseBack(message, parsed_message) + + def testMapFields(self): + message = json_format_proto3_pb2.TestMap() + message.bool_map[True] = 1 + message.bool_map[False] = 2 + message.int32_map[1] = 2 + message.int32_map[2] = 3 + message.int64_map[1] = 2 + message.int64_map[2] = 3 + message.uint32_map[1] = 2 + message.uint32_map[2] = 3 + message.uint64_map[1] = 2 + message.uint64_map[2] = 3 + message.string_map['1'] = 2 + message.string_map['null'] = 3 + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{' + '"boolMap": {"false": 2, "true": 1},' + '"int32Map": {"1": 2, "2": 3},' + '"int64Map": {"1": 2, "2": 3},' + '"uint32Map": {"1": 2, "2": 3},' + '"uint64Map": {"1": 2, "2": 3},' + '"stringMap": {"1": 2, "null": 3}' + '}')) + parsed_message = json_format_proto3_pb2.TestMap() + self.CheckParseBack(message, parsed_message) + + def testOneofFields(self): + message = json_format_proto3_pb2.TestOneof() + # Always print does not affect oneof fields. + self.assertEqual( + json_format.MessageToJson(message, True), + '{}') + message.oneof_int32_value = 0 + self.assertEqual( + json_format.MessageToJson(message, True), + '{\n' + ' "oneofInt32Value": 0\n' + '}') + parsed_message = json_format_proto3_pb2.TestOneof() + self.CheckParseBack(message, parsed_message) + + def testTimestampMessage(self): + message = json_format_proto3_pb2.TestTimestamp() + message.value.seconds = 0 + message.value.nanos = 0 + message.repeated_value.add().seconds = 20 + message.repeated_value[0].nanos = 1 + message.repeated_value.add().seconds = 0 + message.repeated_value[1].nanos = 10000 + message.repeated_value.add().seconds = 100000000 + message.repeated_value[2].nanos = 0 + # Maximum time + message.repeated_value.add().seconds = 253402300799 + message.repeated_value[3].nanos = 999999999 + # Minimum time + message.repeated_value.add().seconds = -62135596800 + message.repeated_value[4].nanos = 0 + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{' + '"value": "1970-01-01T00:00:00Z",' + '"repeatedValue": [' + ' "1970-01-01T00:00:20.000000001Z",' + ' "1970-01-01T00:00:00.000010Z",' + ' "1973-03-03T09:46:40Z",' + ' "9999-12-31T23:59:59.999999999Z",' + ' "0001-01-01T00:00:00Z"' + ']' + '}')) + parsed_message = json_format_proto3_pb2.TestTimestamp() + self.CheckParseBack(message, parsed_message) + text = (r'{"value": "1972-01-01T01:00:00.01+08:00",' + r'"repeatedValue":[' + r' "1972-01-01T01:00:00.01+08:30",' + r' "1972-01-01T01:00:00.01-01:23"]}') + json_format.Parse(text, parsed_message) + self.assertEqual(parsed_message.value.seconds, 63104400) + self.assertEqual(parsed_message.value.nanos, 10000000) + self.assertEqual(parsed_message.repeated_value[0].seconds, 63106200) + self.assertEqual(parsed_message.repeated_value[1].seconds, 63070620) + + def testDurationMessage(self): + message = json_format_proto3_pb2.TestDuration() + message.value.seconds = 1 + message.repeated_value.add().seconds = 0 + message.repeated_value[0].nanos = 10 + message.repeated_value.add().seconds = -1 + message.repeated_value[1].nanos = -1000 + message.repeated_value.add().seconds = 10 + message.repeated_value[2].nanos = 11000000 + message.repeated_value.add().seconds = -315576000000 + message.repeated_value.add().seconds = 315576000000 + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{' + '"value": "1s",' + '"repeatedValue": [' + ' "0.000000010s",' + ' "-1.000001s",' + ' "10.011s",' + ' "-315576000000s",' + ' "315576000000s"' + ']' + '}')) + parsed_message = json_format_proto3_pb2.TestDuration() + self.CheckParseBack(message, parsed_message) + + def testFieldMaskMessage(self): + message = json_format_proto3_pb2.TestFieldMask() + message.value.paths.append('foo.bar') + message.value.paths.append('bar') + self.assertEqual( + json_format.MessageToJson(message, True), + '{\n' + ' "value": "foo.bar,bar"\n' + '}') + parsed_message = json_format_proto3_pb2.TestFieldMask() + self.CheckParseBack(message, parsed_message) + + def testWrapperMessage(self): + message = json_format_proto3_pb2.TestWrapper() + message.bool_value.value = False + message.int32_value.value = 0 + message.string_value.value = '' + message.bytes_value.value = b'' + message.repeated_bool_value.add().value = True + message.repeated_bool_value.add().value = False + self.assertEqual( + json.loads(json_format.MessageToJson(message, True)), + json.loads('{\n' + ' "int32Value": 0,' + ' "boolValue": false,' + ' "stringValue": "",' + ' "bytesValue": "",' + ' "repeatedBoolValue": [true, false],' + ' "repeatedInt32Value": [],' + ' "repeatedUint32Value": [],' + ' "repeatedFloatValue": [],' + ' "repeatedDoubleValue": [],' + ' "repeatedBytesValue": [],' + ' "repeatedInt64Value": [],' + ' "repeatedUint64Value": [],' + ' "repeatedStringValue": []' + '}')) + parsed_message = json_format_proto3_pb2.TestWrapper() + self.CheckParseBack(message, parsed_message) + + def testParseNull(self): + message = json_format_proto3_pb2.TestMessage() + message.repeated_int32_value.append(1) + message.repeated_int32_value.append(2) + message.repeated_int32_value.append(3) + parsed_message = json_format_proto3_pb2.TestMessage() + self.FillAllFields(parsed_message) + json_format.Parse('{"int32Value": null, ' + '"int64Value": null, ' + '"uint32Value": null,' + '"uint64Value": null,' + '"floatValue": null,' + '"doubleValue": null,' + '"boolValue": null,' + '"stringValue": null,' + '"bytesValue": null,' + '"messageValue": null,' + '"enumValue": null,' + '"repeatedInt32Value": [1, 2, null, 3],' + '"repeatedInt64Value": null,' + '"repeatedUint32Value": null,' + '"repeatedUint64Value": null,' + '"repeatedFloatValue": null,' + '"repeatedDoubleValue": null,' + '"repeatedBoolValue": null,' + '"repeatedStringValue": null,' + '"repeatedBytesValue": null,' + '"repeatedMessageValue": null,' + '"repeatedEnumValue": null' + '}', + parsed_message) + self.assertEqual(message, parsed_message) + + def testNanFloat(self): + message = json_format_proto3_pb2.TestMessage() + message.float_value = float('nan') + text = '{\n "floatValue": "NaN"\n}' + self.assertEqual(json_format.MessageToJson(message), text) + parsed_message = json_format_proto3_pb2.TestMessage() + json_format.Parse(text, parsed_message) + self.assertTrue(math.isnan(parsed_message.float_value)) + + def testParseEmptyText(self): + self.CheckError('', + r'Failed to load JSON: (Expecting value)|(No JSON)') + + def testParseBadEnumValue(self): + self.CheckError( + '{"enumValue": 1}', + 'Enum value must be a string literal with double quotes. ' + 'Type "proto3.EnumType" has no value named 1.') + self.CheckError( + '{"enumValue": "baz"}', + 'Enum value must be a string literal with double quotes. ' + 'Type "proto3.EnumType" has no value named baz.') + + def testParseBadIdentifer(self): + self.CheckError('{int32Value: 1}', + (r'Failed to load JSON: Expecting property name enclosed ' + r'in double quotes: line 1')) + self.CheckError('{"unknownName": 1}', + 'Message type "proto3.TestMessage" has no field named ' + '"unknownName".') + + def testDuplicateField(self): + self.CheckError('{"int32Value": 1,\n"int32Value":2}', + 'Failed to load JSON: duplicate key int32Value') + + def testInvalidBoolValue(self): + self.CheckError('{"boolValue": 1}', + 'Failed to parse boolValue field: ' + 'Expected true or false without quotes.') + self.CheckError('{"boolValue": "true"}', + 'Failed to parse boolValue field: ' + 'Expected true or false without quotes.') + + def testInvalidIntegerValue(self): + message = json_format_proto3_pb2.TestMessage() + text = '{"int32Value": 0x12345}' + self.assertRaises(json_format.ParseError, + json_format.Parse, text, message) + self.CheckError('{"int32Value": 012345}', + (r'Failed to load JSON: Expecting \',\' delimiter: ' + r'line 1')) + self.CheckError('{"int32Value": 1.0}', + 'Failed to parse int32Value field: ' + 'Couldn\'t parse integer: 1.0') + self.CheckError('{"int32Value": " 1 "}', + 'Failed to parse int32Value field: ' + 'Couldn\'t parse integer: " 1 "') + self.CheckError('{"int32Value": 12345678901234567890}', + 'Failed to parse int32Value field: Value out of range: ' + '12345678901234567890') + self.CheckError('{"int32Value": 1e5}', + 'Failed to parse int32Value field: ' + 'Couldn\'t parse integer: 100000.0') + self.CheckError('{"uint32Value": -1}', + 'Failed to parse uint32Value field: Value out of range: -1') + + def testInvalidFloatValue(self): + self.CheckError('{"floatValue": "nan"}', + 'Failed to parse floatValue field: Couldn\'t ' + 'parse float "nan", use "NaN" instead') + + def testInvalidBytesValue(self): + self.CheckError('{"bytesValue": "AQI"}', + 'Failed to parse bytesValue field: Incorrect padding') + self.CheckError('{"bytesValue": "AQI*"}', + 'Failed to parse bytesValue field: Incorrect padding') + + def testInvalidMap(self): + message = json_format_proto3_pb2.TestMap() + text = '{"int32Map": {"null": 2, "2": 3}}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse int32Map field: Couldn\'t parse integer: "null"', + json_format.Parse, text, message) + text = '{"int32Map": {1: 2, "2": 3}}' + self.assertRaisesRegexp( + json_format.ParseError, + (r'Failed to load JSON: Expecting property name enclosed ' + r'in double quotes: line 1'), + json_format.Parse, text, message) + text = r'{"stringMap": {"a": 3, "\u0061": 2}}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to load JSON: duplicate key a', + json_format.Parse, text, message) + text = '{"boolMap": {"null": 1}}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse boolMap field: Expect "true" or "false", not null.', + json_format.Parse, text, message) + + def testInvalidTimestamp(self): + message = json_format_proto3_pb2.TestTimestamp() + text = '{"value": "10000-01-01T00:00:00.00Z"}' + self.assertRaisesRegexp( + json_format.ParseError, + 'time data \'10000-01-01T00:00:00\' does not match' + ' format \'%Y-%m-%dT%H:%M:%S\'', + json_format.Parse, text, message) + text = '{"value": "1970-01-01T00:00:00.0123456789012Z"}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse value field: Failed to parse Timestamp: ' + 'nanos 0123456789012 more than 9 fractional digits.', + json_format.Parse, text, message) + text = '{"value": "1972-01-01T01:00:00.01+08"}' + self.assertRaisesRegexp( + json_format.ParseError, + (r'Failed to parse value field: Invalid timezone offset value: \+08'), + json_format.Parse, text, message) + # Time smaller than minimum time. + text = '{"value": "0000-01-01T00:00:00Z"}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Failed to parse value field: year is out of range', + json_format.Parse, text, message) + # Time bigger than maxinum time. + message.value.seconds = 253402300800 + self.assertRaisesRegexp( + json_format.SerializeToJsonError, + 'Failed to serialize value field: year is out of range', + json_format.MessageToJson, message) + + def testInvalidOneof(self): + message = json_format_proto3_pb2.TestOneof() + text = '{"oneofInt32Value": 1, "oneofStringValue": "2"}' + self.assertRaisesRegexp( + json_format.ParseError, + 'Message type "proto3.TestOneof"' + ' should not have multiple "oneof_value" oneof fields.', + json_format.Parse, text, message) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 0d880a75..9ec54fff 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -34,10 +34,7 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 @@ -45,7 +42,6 @@ from google.protobuf import descriptor_database from google.protobuf import descriptor_pool from google.protobuf import message_factory - class MessageFactoryTest(unittest.TestCase): def setUp(self): diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index d99b89be..604c426a 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -43,22 +43,16 @@ abstract interface. __author__ = 'gps@google.com (Gregory P. Smith)' + import collections import copy import math import operator import pickle -import sys - import six +import sys -if six.PY3: - long = int - -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf.internal import _parameterized from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_pb2 @@ -68,6 +62,9 @@ from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf import message +if six.PY3: + long = int + # Python pre-2.6 does not have isinf() or isnan() functions, so we have # to provide our own. def isnan(val): @@ -442,7 +439,7 @@ class MessageTest(unittest.TestCase): message.repeated_nested_message.add().bb = 24 message.repeated_nested_message.add().bb = 10 message.repeated_nested_message.sort(key=lambda z: z.bb // 10) - self.assertEquals( + self.assertEqual( [13, 11, 10, 21, 20, 24, 33], [n.bb for n in message.repeated_nested_message]) @@ -451,7 +448,7 @@ class MessageTest(unittest.TestCase): pb = message.SerializeToString() message.Clear() message.MergeFromString(pb) - self.assertEquals( + self.assertEqual( [13, 11, 10, 21, 20, 24, 33], [n.bb for n in message.repeated_nested_message]) @@ -914,7 +911,6 @@ class MessageTest(unittest.TestCase): with self.assertRaises(pickle.PickleError) as _: pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) - def testSortEmptyRepeatedCompositeContainer(self, message_module): """Exercise a scenario that has led to segfaults in the past. """ diff --git a/python/google/protobuf/internal/missing_enum_values.proto b/python/google/protobuf/internal/missing_enum_values.proto index 161fc5e1..1850be5b 100644 --- a/python/google/protobuf/internal/missing_enum_values.proto +++ b/python/google/protobuf/internal/missing_enum_values.proto @@ -50,3 +50,7 @@ message TestMissingEnumValues { repeated NestedEnum repeated_nested_enum = 2; repeated NestedEnum packed_nested_enum = 3 [packed = true]; } + +message JustString { + required string dummy = 1; +} diff --git a/python/google/protobuf/internal/proto_builder_test.py b/python/google/protobuf/internal/proto_builder_test.py index 1eda10fb..e0467251 100644 --- a/python/google/protobuf/internal/proto_builder_test.py +++ b/python/google/protobuf/internal/proto_builder_test.py @@ -34,14 +34,10 @@ try: from collections import OrderedDict -except ImportError: - from ordereddict import OrderedDict #PY26 - -try: - import unittest2 as unittest #PY26 except ImportError: - import unittest - + from ordereddict import OrderedDict #PY26 +import collections +import unittest from google.protobuf import descriptor_pb2 from google.protobuf import descriptor_pool from google.protobuf import proto_builder diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 4e5032a7..2b87f704 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -28,8 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Copyright 2007 Google Inc. All Rights Reserved. -# # This code is meant to work on Python 2.4 and above only. # # TODO(robinson): Helpers for verbose, common checks like seeing if a diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 26611353..6815c238 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -38,14 +38,10 @@ pure-Python protocol compiler. import copy import gc import operator -import struct -try: - import unittest2 as unittest -except ImportError: - import unittest - import six +import struct +import unittest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -1627,7 +1623,7 @@ class ReflectionTest(unittest.TestCase): self.assertFalse(proto.IsInitialized(errors)) self.assertEqual(errors, ['a', 'b', 'c']) - @unittest.skipIf( + @test_util.skipIf( api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, 'Errors are only available from the most recent C++ implementation.') def testFileDescriptorErrors(self): @@ -1799,7 +1795,6 @@ class ReflectionTest(unittest.TestCase): # Just check the default value. self.assertEqual(57, msg.inner.value) - # Since we had so many tests for protocol buffer equality, we broke these out # into separate TestCase classes. @@ -2827,7 +2822,7 @@ class OptionsTest(unittest.TestCase): class ClassAPITest(unittest.TestCase): - @unittest.skipIf( + @test_util.skipIf( api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, 'C++ implementation requires a call to MakeDescriptor()') def testMakeClassWithNestedDescriptor(self): diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index 98614b77..564e2d1a 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -34,10 +34,7 @@ __author__ = 'petar@google.com (Petar Petrov)' -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf import unittest_pb2 from google.protobuf import service_reflection from google.protobuf import service diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py index 97442262..7fb4b56d 100644 --- a/python/google/protobuf/internal/symbol_database_test.py +++ b/python/google/protobuf/internal/symbol_database_test.py @@ -32,27 +32,29 @@ """Tests for google.protobuf.symbol_database.""" -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf import unittest_pb2 +from google.protobuf import descriptor from google.protobuf import symbol_database - class SymbolDatabaseTest(unittest.TestCase): def _Database(self): - db = symbol_database.SymbolDatabase() - # Register representative types from unittest_pb2. - db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) - db.RegisterMessage(unittest_pb2.TestAllTypes) - db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) - db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) - db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) - db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) - db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) - return db + # TODO(b/17734095): Remove this difference when the C++ implementation + # supports multiple databases. + if descriptor._USE_C_DESCRIPTORS: + return symbol_database.Default() + else: + db = symbol_database.SymbolDatabase() + # Register representative types from unittest_pb2. + db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) + db.RegisterMessage(unittest_pb2.TestAllTypes) + db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) + db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) + db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) + db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + return db def testGetPrototype(self): instance = self._Database().GetPrototype( diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index ac88fa81..539236b4 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -38,6 +38,13 @@ __author__ = 'robinson@google.com (Will Robinson)' import os.path +import sys +# PY2.6 compatible skipIf +if sys.version_info < (2, 7): + from unittest2 import skipIf +else: + from unittest import skipIf + from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py index 338a287b..9e7b9ce4 100755 --- a/python/google/protobuf/internal/text_encoding_test.py +++ b/python/google/protobuf/internal/text_encoding_test.py @@ -32,10 +32,7 @@ """Tests for google.protobuf.text_encoding.""" -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf import text_encoding TEST_VALUES = [ diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index d332b77d..fb4addeb 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -34,14 +34,12 @@ __author__ = 'kenton@google.com (Kenton Varda)' + import re import six import string -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf.internal import _parameterized from google.protobuf import map_unittest_pb2 @@ -389,7 +387,7 @@ class TextFormatTest(TextFormatBase): # Ideally the schemas would be made more similar so these tests could pass. class OnlyWorksWithProto2RightNowTests(TextFormatBase): - def testPrintAllFieldsPointy(self, message_module): + def testPrintAllFieldsPointy(self): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.CompareToGoldenFile( diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 8fa3d8c8..f30ca6a8 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -28,8 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Copyright 2008 Google Inc. All Rights Reserved. - """Provides type checking routines. This module defines type checking utilities in the forms of dictionaries: @@ -52,6 +50,7 @@ import six if six.PY3: long = int +from google.protobuf.internal import api_implementation from google.protobuf.internal import decoder from google.protobuf.internal import encoder from google.protobuf.internal import wire_format diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 011d3b55..25b447e1 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -35,11 +35,7 @@ __author__ = 'bohdank@google.com (Bohdan Koval)' -try: - import unittest2 as unittest -except ImportError: - import unittest - +import unittest from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 @@ -52,7 +48,7 @@ from google.protobuf.internal import type_checkers def SkipIfCppImplementation(func): - return unittest.skipIf( + return test_util.skipIf( api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, 'C++ implementation does not expose unknown fields to Python')(func) @@ -262,6 +258,19 @@ class UnknownEnumValuesTest(unittest.TestCase): decoder(value, 0, len(value), self.message, result_dict) return result_dict[field_descriptor] + def testUnknownParseMismatchEnumValue(self): + just_string = missing_enum_values_pb2.JustString() + just_string.dummy = 'blah' + + missing = missing_enum_values_pb2.TestEnumValues() + # The parse is invalid, storing the string proto into the set of + # unknown fields. + missing.ParseFromString(just_string.SerializeToString()) + + # Fetching the enum field shouldn't crash, instead returning the + # default value. + self.assertEqual(missing.optional_nested_enum, 0) + @SkipIfCppImplementation def testUnknownEnumValue(self): self.assertFalse(self.missing_message.HasField('optional_nested_enum')) diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py index f659d18e..78dc1167 100755 --- a/python/google/protobuf/internal/wire_format_test.py +++ b/python/google/protobuf/internal/wire_format_test.py @@ -34,10 +34,7 @@ __author__ = 'robinson@google.com (Will Robinson)' -try: - import unittest2 as unittest -except ImportError: - import unittest +import unittest from google.protobuf import message from google.protobuf.internal import wire_format diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py new file mode 100644 index 00000000..09110e04 --- /dev/null +++ b/python/google/protobuf/json_format.py @@ -0,0 +1,601 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains routines for printing protocol messages in JSON format.""" + +__author__ = 'jieluo@google.com (Jie Luo)' + +import base64 +from datetime import datetime +import json +import math +import re + +from google.protobuf import descriptor + +_TIMESTAMPFOMAT = '%Y-%m-%dT%H:%M:%S' +_NUMBER = re.compile(u'[0-9+-][0-9e.+-]*') +_INTEGER = re.compile(u'[0-9+-]') +_INT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT32, + descriptor.FieldDescriptor.CPPTYPE_UINT32, + descriptor.FieldDescriptor.CPPTYPE_INT64, + descriptor.FieldDescriptor.CPPTYPE_UINT64]) +_INT64_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_INT64, + descriptor.FieldDescriptor.CPPTYPE_UINT64]) +_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, + descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) +if str is bytes: + _UNICODETYPE = unicode +else: + _UNICODETYPE = str + + +class SerializeToJsonError(Exception): + """Thrown if serialization to JSON fails.""" + + +class ParseError(Exception): + """Thrown in case of parsing error.""" + + +def MessageToJson(message, including_default_value_fields=False): + """Converts protobuf message to JSON format. + + Args: + message: The protocol buffers message instance to serialize. + including_default_value_fields: If True, singular primitive fields, + repeated fields, and map fields will always be serialized. If + False, only serialize non-empty fields. Singular message fields + and oneof fields are not affected by this option. + + Returns: + A string containing the JSON formatted protocol buffer message. + """ + js = _MessageToJsonObject(message, including_default_value_fields) + return json.dumps(js, indent=2) + + +def _MessageToJsonObject(message, including_default_value_fields): + """Converts message to an object according to Proto3 JSON Specification.""" + message_descriptor = message.DESCRIPTOR + if _IsTimestampMessage(message_descriptor): + return _TimestampMessageToJsonObject(message) + if _IsDurationMessage(message_descriptor): + return _DurationMessageToJsonObject(message) + if _IsFieldMaskMessage(message_descriptor): + return _FieldMaskMessageToJsonObject(message) + if _IsWrapperMessage(message_descriptor): + return _WrapperMessageToJsonObject(message) + return _RegularMessageToJsonObject(message, including_default_value_fields) + + +def _IsMapEntry(field): + return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type.has_options and + field.message_type.GetOptions().map_entry) + + +def _RegularMessageToJsonObject(message, including_default_value_fields): + """Converts normal message according to Proto3 JSON Specification.""" + js = {} + fields = message.ListFields() + + try: + for field, value in fields: + name = field.camelcase_name + if _IsMapEntry(field): + # Convert a map field. + js_map = {} + for key in value: + js_map[key] = _ConvertFieldToJsonObject( + field.message_type.fields_by_name['value'], + value[key], including_default_value_fields) + js[name] = js_map + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + # Convert a repeated field. + repeated = [] + for element in value: + repeated.append(_ConvertFieldToJsonObject( + field, element, including_default_value_fields)) + js[name] = repeated + else: + js[name] = _ConvertFieldToJsonObject( + field, value, including_default_value_fields) + + # Serialize default value if including_default_value_fields is True. + if including_default_value_fields: + message_descriptor = message.DESCRIPTOR + for field in message_descriptor.fields: + # Singular message fields and oneof fields will not be affected. + if ((field.label != descriptor.FieldDescriptor.LABEL_REPEATED and + field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE) or + field.containing_oneof): + continue + name = field.camelcase_name + if name in js: + # Skip the field which has been serailized already. + continue + if _IsMapEntry(field): + js[name] = {} + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + js[name] = [] + else: + js[name] = _ConvertFieldToJsonObject(field, field.default_value) + + except ValueError as e: + raise SerializeToJsonError( + 'Failed to serialize {0} field: {1}'.format(field.name, e)) + + return js + + +def _ConvertFieldToJsonObject( + field, value, including_default_value_fields=False): + """Converts field value according to Proto3 JSON Specification.""" + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + return _MessageToJsonObject(value, including_default_value_fields) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + return enum_value.name + else: + raise SerializeToJsonError('Enum field contains an integer value ' + 'which can not mapped to an enum value.') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + # Use base64 Data encoding for bytes + return base64.b64encode(value).decode('utf-8') + else: + return value + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + if value: + return True + else: + return False + elif field.cpp_type in _INT64_TYPES: + return str(value) + elif field.cpp_type in _FLOAT_TYPES: + if math.isinf(value): + if value < 0.0: + return '-Infinity' + else: + return 'Infinity' + if math.isnan(value): + return 'NaN' + return value + + +def _IsTimestampMessage(message_descriptor): + return (message_descriptor.name == 'Timestamp' and + message_descriptor.file.name == 'google/protobuf/timestamp.proto') + + +def _TimestampMessageToJsonObject(message): + """Converts Timestamp message according to Proto3 JSON Specification.""" + nanos = message.nanos % 1e9 + dt = datetime.utcfromtimestamp( + message.seconds + (message.nanos - nanos) / 1e9) + result = dt.isoformat() + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return result + 'Z' + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return result + '.%03dZ' % (nanos / 1e6) + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. + return result + '.%06dZ' % (nanos / 1e3) + # Serialize 9 fractional digits. + return result + '.%09dZ' % nanos + + +def _IsDurationMessage(message_descriptor): + return (message_descriptor.name == 'Duration' and + message_descriptor.file.name == 'google/protobuf/duration.proto') + + +def _DurationMessageToJsonObject(message): + """Converts Duration message according to Proto3 JSON Specification.""" + if message.seconds < 0 or message.nanos < 0: + result = '-' + seconds = - message.seconds + int((0 - message.nanos) / 1e9) + nanos = (0 - message.nanos) % 1e9 + else: + result = '' + seconds = message.seconds + int(message.nanos / 1e9) + nanos = message.nanos % 1e9 + result += '%d' % seconds + if (nanos % 1e9) == 0: + # If there are 0 fractional digits, the fractional + # point '.' should be omitted when serializing. + return result + 's' + if (nanos % 1e6) == 0: + # Serialize 3 fractional digits. + return result + '.%03ds' % (nanos / 1e6) + if (nanos % 1e3) == 0: + # Serialize 6 fractional digits. + return result + '.%06ds' % (nanos / 1e3) + # Serialize 9 fractional digits. + return result + '.%09ds' % nanos + + +def _IsFieldMaskMessage(message_descriptor): + return (message_descriptor.name == 'FieldMask' and + message_descriptor.file.name == 'google/protobuf/field_mask.proto') + + +def _FieldMaskMessageToJsonObject(message): + """Converts FieldMask message according to Proto3 JSON Specification.""" + result = '' + first = True + for path in message.paths: + if not first: + result += ',' + result += path + first = False + return result + + +def _IsWrapperMessage(message_descriptor): + return message_descriptor.file.name == 'google/protobuf/wrappers.proto' + + +def _WrapperMessageToJsonObject(message): + return _ConvertFieldToJsonObject( + message.DESCRIPTOR.fields_by_name['value'], message.value) + + +def _DuplicateChecker(js): + result = {} + for name, value in js: + if name in result: + raise ParseError('Failed to load JSON: duplicate key ' + name) + result[name] = value + return result + + +def Parse(text, message): + """Parses a JSON representation of a protocol message into a message. + + Args: + text: Message JSON representation. + message: A protocol beffer message to merge into. + + Returns: + The same message passed as argument. + + Raises:: + ParseError: On JSON parsing problems. + """ + if not isinstance(text, _UNICODETYPE): text = text.decode('utf-8') + try: + js = json.loads(text, object_pairs_hook=_DuplicateChecker) + except ValueError as e: + raise ParseError('Failed to load JSON: ' + str(e)) + _ConvertFieldValuePair(js, message) + return message + + +def _ConvertFieldValuePair(js, message): + """Convert field value pairs into regular message. + + Args: + js: A JSON object to convert the field value pairs. + message: A regular protocol message to record the data. + + Raises: + ParseError: In case of problems converting. + """ + names = [] + message_descriptor = message.DESCRIPTOR + for name in js: + try: + field = message_descriptor.fields_by_camelcase_name.get(name, None) + if not field: + raise ParseError( + 'Message type "{0}" has no field named "{1}".'.format( + message_descriptor.full_name, name)) + if name in names: + raise ParseError( + 'Message type "{0}" should not have multiple "{1}" fields.'.format( + message.DESCRIPTOR.full_name, name)) + names.append(name) + # Check no other oneof field is parsed. + if field.containing_oneof is not None: + oneof_name = field.containing_oneof.name + if oneof_name in names: + raise ParseError('Message type "{0}" should not have multiple "{1}" ' + 'oneof fields.'.format( + message.DESCRIPTOR.full_name, oneof_name)) + names.append(oneof_name) + + value = js[name] + if value is None: + message.ClearField(field.name) + continue + + # Parse field value. + if _IsMapEntry(field): + message.ClearField(field.name) + _ConvertMapFieldValue(value, message, field) + elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + message.ClearField(field.name) + if not isinstance(value, list): + raise ParseError('repeated field {0} must be in [] which is ' + '{1}'.format(name, value)) + for item in value: + if item is None: + continue + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + sub_message = getattr(message, field.name).add() + _ConvertMessage(item, sub_message) + else: + getattr(message, field.name).append( + _ConvertScalarFieldValue(item, field)) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + sub_message = getattr(message, field.name) + _ConvertMessage(value, sub_message) + else: + setattr(message, field.name, _ConvertScalarFieldValue(value, field)) + except ParseError as e: + if field and field.containing_oneof is None: + raise ParseError('Failed to parse {0} field: {1}'.format(name, e)) + else: + raise ParseError(str(e)) + except ValueError as e: + raise ParseError('Failed to parse {0} field: {1}'.format(name, e)) + except TypeError as e: + raise ParseError('Failed to parse {0} field: {1}'.format(name, e)) + + +def _ConvertMessage(value, message): + """Convert a JSON object into a message. + + Args: + value: A JSON object. + message: A WKT or regular protocol message to record the data. + + Raises: + ParseError: In case of convert problems. + """ + message_descriptor = message.DESCRIPTOR + if _IsTimestampMessage(message_descriptor): + _ConvertTimestampMessage(value, message) + elif _IsDurationMessage(message_descriptor): + _ConvertDurationMessage(value, message) + elif _IsFieldMaskMessage(message_descriptor): + _ConvertFieldMaskMessage(value, message) + elif _IsWrapperMessage(message_descriptor): + _ConvertWrapperMessage(value, message) + else: + _ConvertFieldValuePair(value, message) + + +def _ConvertTimestampMessage(value, message): + """Convert a JSON representation into Timestamp message.""" + timezone_offset = value.find('Z') + if timezone_offset == -1: + timezone_offset = value.find('+') + if timezone_offset == -1: + timezone_offset = value.rfind('-') + if timezone_offset == -1: + raise ParseError( + 'Failed to parse timestamp: missing valid timezone offset.') + time_value = value[0:timezone_offset] + # Parse datetime and nanos + point_position = time_value.find('.') + if point_position == -1: + second_value = time_value + nano_value = '' + else: + second_value = time_value[:point_position] + nano_value = time_value[point_position + 1:] + date_object = datetime.strptime(second_value, _TIMESTAMPFOMAT) + seconds = (date_object - datetime(1970, 1, 1)).total_seconds() + if len(nano_value) > 9: + raise ParseError( + 'Failed to parse Timestamp: nanos {0} more than ' + '9 fractional digits.'.format(nano_value)) + if nano_value: + nanos = round(float('0.' + nano_value) * 1e9) + else: + nanos = 0 + # Parse timezone offsets + if value[timezone_offset] == 'Z': + if len(value) != timezone_offset + 1: + raise ParseError( + 'Failed to parse timestamp: invalid trailing data {0}.'.format(value)) + else: + timezone = value[timezone_offset:] + pos = timezone.find(':') + if pos == -1: + raise ParseError( + 'Invalid timezone offset value: ' + timezone) + if timezone[0] == '+': + seconds += (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + else: + seconds -= (int(timezone[1:pos])*60+int(timezone[pos+1:]))*60 + # Set seconds and nanos + message.seconds = int(seconds) + message.nanos = int(nanos) + + +def _ConvertDurationMessage(value, message): + """Convert a JSON representation into Duration message.""" + if value[-1] != 's': + raise ParseError( + 'Duration must end with letter "s": ' + value) + try: + duration = float(value[:-1]) + except ValueError: + raise ParseError( + 'Couldn\'t parse duration: ' + value) + message.seconds = int(duration) + message.nanos = int(round((duration - message.seconds) * 1e9)) + + +def _ConvertFieldMaskMessage(value, message): + """Convert a JSON representation into FieldMask message.""" + for path in value.split(','): + message.paths.append(path) + + +def _ConvertWrapperMessage(value, message): + """Convert a JSON representation into Wrapper message.""" + field = message.DESCRIPTOR.fields_by_name['value'] + setattr(message, 'value', _ConvertScalarFieldValue(value, field)) + + +def _ConvertMapFieldValue(value, message, field): + """Convert map field value for a message map field. + + Args: + value: A JSON object to convert the map field value. + message: A protocol message to record the converted data. + field: The descriptor of the map field to be converted. + + Raises: + ParseError: In case of convert problems. + """ + if not isinstance(value, dict): + raise ParseError( + 'Map fieled {0} must be in {} which is {1}.'.format(field.name, value)) + key_field = field.message_type.fields_by_name['key'] + value_field = field.message_type.fields_by_name['value'] + for key in value: + key_value = _ConvertScalarFieldValue(key, key_field, True) + if value_field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + _ConvertMessage(value[key], getattr(message, field.name)[key_value]) + else: + getattr(message, field.name)[key_value] = _ConvertScalarFieldValue( + value[key], value_field) + + +def _ConvertScalarFieldValue(value, field, require_quote=False): + """Convert a single scalar field value. + + Args: + value: A scalar value to convert the scalar field value. + field: The descriptor of the field to convert. + require_quote: If True, '"' is required for the field value. + + Returns: + The converted scalar field value + + Raises: + ParseError: In case of convert problems. + """ + if field.cpp_type in _INT_TYPES: + return _ConvertInteger(value) + elif field.cpp_type in _FLOAT_TYPES: + return _ConvertFloat(value) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + return _ConvertBool(value, require_quote) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + if field.type == descriptor.FieldDescriptor.TYPE_BYTES: + return base64.b64decode(value) + else: + return value + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + # Convert an enum value. + enum_value = field.enum_type.values_by_name.get(value, None) + if enum_value is None: + raise ParseError( + 'Enum value must be a string literal with double quotes. ' + 'Type "{0}" has no value named {1}.'.format( + field.enum_type.full_name, value)) + return enum_value.number + + +def _ConvertInteger(value): + """Convert an integer. + + Args: + value: A scalar value to convert. + + Returns: + The integer value. + + Raises: + ParseError: If an integer couldn't be consumed. + """ + if isinstance(value, float): + raise ParseError('Couldn\'t parse integer: {0}'.format(value)) + + if isinstance(value, _UNICODETYPE) and not _INTEGER.match(value): + raise ParseError('Couldn\'t parse integer: "{0}"'.format(value)) + + return int(value) + + +def _ConvertFloat(value): + """Convert an floating point number.""" + if value == 'nan': + raise ParseError('Couldn\'t parse float "nan", use "NaN" instead') + try: + # Assume Python compatible syntax. + return float(value) + except ValueError: + # Check alternative spellings. + if value == '-Infinity': + return float('-inf') + elif value == 'Infinity': + return float('inf') + elif value == 'NaN': + return float('nan') + else: + raise ParseError('Couldn\'t parse float: {0}'.format(value)) + + +def _ConvertBool(value, require_quote): + """Convert a boolean value. + + Args: + value: A scalar value to convert. + require_quote: If True, '"' is required for the boolean value. + + Returns: + The bool parsed. + + Raises: + ParseError: If a boolean value couldn't be consumed. + """ + if require_quote: + if value == 'true': + return True + elif value == 'false': + return False + else: + raise ParseError('Expect "true" or "false", not {0}.'.format(value)) + + if not isinstance(value, bool): + raise ParseError('Expected true or false without quotes.') + return value diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index 36062a56..9cd9c2a8 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -28,8 +28,6 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Copyright 2012 Google Inc. All Rights Reserved. - """Provides a factory class for generating dynamic messages. The easiest way to use this class is if you have access to the FileDescriptor diff --git a/python/google/protobuf/proto_builder.py b/python/google/protobuf/proto_builder.py index 700e3c25..736caed3 100644 --- a/python/google/protobuf/proto_builder.py +++ b/python/google/protobuf/proto_builder.py @@ -48,7 +48,7 @@ def _GetMessageFromFactory(factory, full_name): factory: a MessageFactory instance. full_name: str, the fully qualified name of the proto type. Returns: - a class, for the type identified by full_name. + A class, for the type identified by full_name. Raises: KeyError, if the proto is not found in the factory's descriptor pool. """ @@ -57,7 +57,7 @@ def _GetMessageFromFactory(factory, full_name): return proto_cls -def MakeSimpleProtoClass(fields, full_name, pool=None): +def MakeSimpleProtoClass(fields, full_name=None, pool=None): """Create a Protobuf class whose fields are basic types. Note: this doesn't validate field names! @@ -66,18 +66,20 @@ def MakeSimpleProtoClass(fields, full_name, pool=None): fields: dict of {name: field_type} mappings for each field in the proto. If this is an OrderedDict the order will be maintained, otherwise the fields will be sorted by name. - full_name: str, the fully-qualified name of the proto type. + full_name: optional str, the fully-qualified name of the proto type. pool: optional DescriptorPool instance. Returns: a class, the new protobuf class with a FileDescriptor. """ factory = message_factory.MessageFactory(pool=pool) - try: - proto_cls = _GetMessageFromFactory(factory, full_name) - return proto_cls - except KeyError: - # The factory's DescriptorPool doesn't know about this class yet. - pass + + if full_name is not None: + try: + proto_cls = _GetMessageFromFactory(factory, full_name) + return proto_cls + except KeyError: + # The factory's DescriptorPool doesn't know about this class yet. + pass # Get a list of (name, field_type) tuples from the fields dict. If fields was # an OrderedDict we keep the order, but otherwise we sort the field to ensure @@ -94,6 +96,25 @@ def MakeSimpleProtoClass(fields, full_name, pool=None): fields_hash.update(str(f_type).encode('utf-8')) proto_file_name = fields_hash.hexdigest() + '.proto' + # If the proto is anonymous, use the same hash to name it. + if full_name is None: + full_name = ('net.proto2.python.public.proto_builder.AnonymousProto_' + + fields_hash.hexdigest()) + try: + proto_cls = _GetMessageFromFactory(factory, full_name) + return proto_cls + except KeyError: + # The factory's DescriptorPool doesn't know about this class yet. + pass + + # This is the first time we see this proto: add a new descriptor to the pool. + factory.pool.Add( + _MakeFileDescriptorProto(proto_file_name, full_name, field_items)) + return _GetMessageFromFactory(factory, full_name) + + +def _MakeFileDescriptorProto(proto_file_name, full_name, field_items): + """Populate FileDescriptorProto for MessageFactory's DescriptorPool.""" package, name = full_name.rsplit('.', 1) file_proto = descriptor_pb2.FileDescriptorProto() file_proto.name = os.path.join(package.replace('.', '/'), proto_file_name) @@ -106,6 +127,4 @@ def MakeSimpleProtoClass(fields, full_name, pool=None): field_proto.number = f_number field_proto.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL field_proto.type = f_type - - factory.pool.Add(file_proto) - return _GetMessageFromFactory(factory, full_name) + return file_proto diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 3806643f..b238fd02 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -223,8 +223,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { options.SerializeToString(&serialized); io::CodedInputStream input( reinterpret_cast(serialized.c_str()), serialized.size()); - input.SetExtensionRegistry(pool->pool, - GetDescriptorPool()->message_factory); + input.SetExtensionRegistry(pool->pool, pool->message_factory); bool success = cmsg->message->MergePartialFromCodedStream(&input); if (!success) { PyErr_Format(PyExc_ValueError, "Error parsing Options message"); @@ -414,8 +413,14 @@ static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { } static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) { + // Retuns the canonical class for the given descriptor. + // This is the class that was registered with the primary descriptor pool + // which contains this descriptor. + // This might not be the one you expect! For example the returned object does + // not know about extensions defined in a custom pool. PyObject* concrete_class(cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), _GetDescriptor(self))); + GetDescriptorPool_FromPool(_GetDescriptor(self)->file()->pool()), + _GetDescriptor(self))); Py_XINCREF(concrete_class); return concrete_class; } @@ -424,6 +429,11 @@ static PyObject* GetFieldsByName(PyBaseDescriptor* self, void *closure) { return NewMessageFieldsByName(_GetDescriptor(self)); } +static PyObject* GetFieldsByCamelcaseName(PyBaseDescriptor* self, + void *closure) { + return NewMessageFieldsByCamelcaseName(_GetDescriptor(self)); +} + static PyObject* GetFieldsByNumber(PyBaseDescriptor* self, void *closure) { return NewMessageFieldsByNumber(_GetDescriptor(self)); } @@ -564,6 +574,8 @@ static PyGetSetDef Getters[] = { { "fields", (getter)GetFieldsSeq, NULL, "Fields sequence"}, { "fields_by_name", (getter)GetFieldsByName, NULL, "Fields by name"}, + { "fields_by_camelcase_name", (getter)GetFieldsByCamelcaseName, NULL, + "Fields by camelCase name"}, { "fields_by_number", (getter)GetFieldsByNumber, NULL, "Fields by number"}, { "nested_types", (getter)GetNestedTypesSeq, NULL, "Nested types sequence"}, { "nested_types_by_name", (getter)GetNestedTypesByName, NULL, @@ -662,6 +674,10 @@ static PyObject* GetName(PyBaseDescriptor *self, void *closure) { return PyString_FromCppString(_GetDescriptor(self)->name()); } +static PyObject* GetCamelcaseName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->camelcase_name()); +} + static PyObject* GetType(PyBaseDescriptor *self, void *closure) { return PyInt_FromLong(_GetDescriptor(self)->type()); } @@ -850,6 +866,7 @@ static int SetOptions(PyBaseDescriptor *self, PyObject *value, static PyGetSetDef Getters[] = { { "full_name", (getter)GetFullName, NULL, "Full name"}, { "name", (getter)GetName, NULL, "Unqualified name"}, + { "camelcase_name", (getter)GetCamelcaseName, NULL, "Camelcase name"}, { "type", (getter)GetType, NULL, "C++ Type"}, { "cpp_type", (getter)GetCppType, NULL, "C++ Type"}, { "label", (getter)GetLabel, NULL, "Label"}, @@ -1070,6 +1087,15 @@ PyObject* PyEnumDescriptor_FromDescriptor( &PyEnumDescriptor_Type, enum_descriptor, NULL); } +const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyEnumDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not an EnumDescriptor"); + return NULL; + } + return reinterpret_cast( + reinterpret_cast(obj)->descriptor); +} + namespace enumvalue_descriptor { // Unchecked accessor to the C++ pointer. @@ -1359,6 +1385,15 @@ PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb( return py_descriptor; } +const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyFileDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a FileDescriptor"); + return NULL; + } + return reinterpret_cast( + reinterpret_cast(obj)->descriptor); +} + namespace oneof_descriptor { // Unchecked accessor to the C++ pointer. diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h index b2550406..eb99df18 100644 --- a/python/google/protobuf/pyext/descriptor.h +++ b/python/google/protobuf/pyext/descriptor.h @@ -72,6 +72,8 @@ PyObject* PyFileDescriptor_FromDescriptorWithSerializedPb( // exception set. const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj); const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj); +const EnumDescriptor* PyEnumDescriptor_AsDescriptor(PyObject* obj); +const FileDescriptor* PyFileDescriptor_AsDescriptor(PyObject* obj); // Returns the raw C++ pointer. const void* PyDescriptor_AsVoidPtr(PyObject* obj); diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc index 92e11e31..b20f5e4f 100644 --- a/python/google/protobuf/pyext/descriptor_containers.cc +++ b/python/google/protobuf/pyext/descriptor_containers.cc @@ -79,9 +79,12 @@ struct PyContainer; typedef int (*CountMethod)(PyContainer* self); typedef const void* (*GetByIndexMethod)(PyContainer* self, int index); typedef const void* (*GetByNameMethod)(PyContainer* self, const string& name); +typedef const void* (*GetByCamelcaseNameMethod)(PyContainer* self, + const string& name); typedef const void* (*GetByNumberMethod)(PyContainer* self, int index); typedef PyObject* (*NewObjectFromItemMethod)(const void* descriptor); typedef const string& (*GetItemNameMethod)(const void* descriptor); +typedef const string& (*GetItemCamelcaseNameMethod)(const void* descriptor); typedef int (*GetItemNumberMethod)(const void* descriptor); typedef int (*GetItemIndexMethod)(const void* descriptor); @@ -95,6 +98,9 @@ struct DescriptorContainerDef { // Retrieve item by name (usually a call to some 'FindByName' method). // Used by "by_name" mappings. GetByNameMethod get_by_name_fn; + // Retrieve item by camelcase name (usually a call to some + // 'FindByCamelcaseName' method). Used by "by_camelcase_name" mappings. + GetByCamelcaseNameMethod get_by_camelcase_name_fn; // Retrieve item by declared number (field tag, or enum value). // Used by "by_number" mappings. GetByNumberMethod get_by_number_fn; @@ -102,6 +108,9 @@ struct DescriptorContainerDef { NewObjectFromItemMethod new_object_from_item_fn; // Retrieve the name of an item. Used by iterators on "by_name" mappings. GetItemNameMethod get_item_name_fn; + // Retrieve the camelcase name of an item. Used by iterators on + // "by_camelcase_name" mappings. + GetItemCamelcaseNameMethod get_item_camelcase_name_fn; // Retrieve the number of an item. Used by iterators on "by_number" mappings. GetItemNumberMethod get_item_number_fn; // Retrieve the index of an item for the container type. @@ -125,6 +134,7 @@ struct PyContainer { enum ContainerKind { KIND_SEQUENCE, KIND_BYNAME, + KIND_BYCAMELCASENAME, KIND_BYNUMBER, } kind; }; @@ -172,6 +182,23 @@ static bool _GetItemByKey(PyContainer* self, PyObject* key, const void** item) { self, string(name, name_size)); return true; } + case PyContainer::KIND_BYCAMELCASENAME: + { + char* camelcase_name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(key, &camelcase_name, &name_size) < 0) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + // Not a string, cannot be in the container. + PyErr_Clear(); + *item = NULL; + return true; + } + return false; + } + *item = self->container_def->get_by_camelcase_name_fn( + self, string(camelcase_name, name_size)); + return true; + } case PyContainer::KIND_BYNUMBER: { Py_ssize_t number = PyNumber_AsSsize_t(key, NULL); @@ -203,6 +230,12 @@ static PyObject* _NewKey_ByIndex(PyContainer* self, Py_ssize_t index) { const string& name(self->container_def->get_item_name_fn(item)); return PyString_FromStringAndSize(name.c_str(), name.size()); } + case PyContainer::KIND_BYCAMELCASENAME: + { + const string& name( + self->container_def->get_item_camelcase_name_fn(item)); + return PyString_FromStringAndSize(name.c_str(), name.size()); + } case PyContainer::KIND_BYNUMBER: { int value = self->container_def->get_item_number_fn(item); @@ -276,6 +309,9 @@ static PyObject* ContainerRepr(PyContainer* self) { case PyContainer::KIND_BYNAME: kind = "mapping by name"; break; + case PyContainer::KIND_BYCAMELCASENAME: + kind = "mapping by camelCase name"; + break; case PyContainer::KIND_BYNUMBER: kind = "mapping by number"; break; @@ -731,6 +767,18 @@ static PyObject* NewMappingByName( return reinterpret_cast(self); } +static PyObject* NewMappingByCamelcaseName( + DescriptorContainerDef* container_def, const void* descriptor) { + PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_BYCAMELCASENAME; + return reinterpret_cast(self); +} + static PyObject* NewMappingByNumber( DescriptorContainerDef* container_def, const void* descriptor) { if (container_def->get_by_number_fn == NULL || @@ -889,6 +937,11 @@ static ItemDescriptor GetByName(PyContainer* self, const string& name) { return GetDescriptor(self)->FindFieldByName(name); } +static ItemDescriptor GetByCamelcaseName(PyContainer* self, + const string& name) { + return GetDescriptor(self)->FindFieldByCamelcaseName(name); +} + static ItemDescriptor GetByNumber(PyContainer* self, int number) { return GetDescriptor(self)->FindFieldByNumber(number); } @@ -905,6 +958,10 @@ static const string& GetItemName(ItemDescriptor item) { return item->name(); } +static const string& GetItemCamelcaseName(ItemDescriptor item) { + return item->camelcase_name(); +} + static int GetItemNumber(ItemDescriptor item) { return item->number(); } @@ -918,9 +975,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)GetByCamelcaseName, (GetByNumberMethod)GetByNumber, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)GetItemCamelcaseName, (GetItemNumberMethod)GetItemNumber, (GetItemIndexMethod)GetItemIndex, }; @@ -931,6 +990,11 @@ PyObject* NewMessageFieldsByName(ParentDescriptor descriptor) { return descriptor::NewMappingByName(&fields::ContainerDef, descriptor); } +PyObject* NewMessageFieldsByCamelcaseName(ParentDescriptor descriptor) { + return descriptor::NewMappingByCamelcaseName(&fields::ContainerDef, + descriptor); +} + PyObject* NewMessageFieldsByNumber(ParentDescriptor descriptor) { return descriptor::NewMappingByNumber(&fields::ContainerDef, descriptor); } @@ -972,9 +1036,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)GetItemIndex, }; @@ -1022,9 +1088,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)GetItemIndex, }; @@ -1094,9 +1162,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)NULL, }; @@ -1140,9 +1210,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)GetItemIndex, }; @@ -1190,9 +1262,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)GetItemIndex, }; @@ -1258,9 +1332,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)GetByNumber, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)GetItemNumber, (GetItemIndexMethod)GetItemIndex, }; @@ -1314,9 +1390,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)NULL, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)NULL, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)GetItemIndex, }; @@ -1370,9 +1448,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)GetItemIndex, }; @@ -1416,9 +1496,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)GetItemIndex, }; @@ -1462,9 +1544,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)GetByName, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)GetItemName, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)GetItemIndex, }; @@ -1496,9 +1580,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)NULL, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)NULL, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)NULL, }; @@ -1530,9 +1616,11 @@ static DescriptorContainerDef ContainerDef = { (CountMethod)Count, (GetByIndexMethod)GetByIndex, (GetByNameMethod)NULL, + (GetByCamelcaseNameMethod)NULL, (GetByNumberMethod)NULL, (NewObjectFromItemMethod)NewObjectFromItem, (GetItemNameMethod)NULL, + (GetItemCamelcaseNameMethod)NULL, (GetItemNumberMethod)NULL, (GetItemIndexMethod)NULL, }; diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h index 8fbdaff9..ce40747d 100644 --- a/python/google/protobuf/pyext/descriptor_containers.h +++ b/python/google/protobuf/pyext/descriptor_containers.h @@ -54,6 +54,7 @@ bool InitDescriptorMappingTypes(); namespace message_descriptor { PyObject* NewMessageFieldsByName(const Descriptor* descriptor); +PyObject* NewMessageFieldsByCamelcaseName(const Descriptor* descriptor); PyObject* NewMessageFieldsByNumber(const Descriptor* descriptor); PyObject* NewMessageFieldsSeq(const Descriptor* descriptor); diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index 7aed651d..6443a7d5 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -108,6 +108,7 @@ static void Dealloc(PyDescriptorPool* self) { Py_DECREF(it->second); } delete self->descriptor_options; + delete self->pool; delete self->message_factory; Py_TYPE(self)->tp_free(reinterpret_cast(self)); } @@ -131,22 +132,9 @@ PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) { } // Add a message class to our database. -const Descriptor* RegisterMessageClass( - PyDescriptorPool* self, PyObject *message_class, PyObject* descriptor) { - ScopedPyObjectPtr full_message_name( - PyObject_GetAttrString(descriptor, "full_name")); - Py_ssize_t name_size; - char* name; - if (PyString_AsStringAndSize(full_message_name, &name, &name_size) < 0) { - return NULL; - } - const Descriptor *message_descriptor = - self->pool->FindMessageTypeByName(string(name, name_size)); - if (!message_descriptor) { - PyErr_Format(PyExc_TypeError, "Could not find C++ descriptor for '%s'", - name); - return NULL; - } +int RegisterMessageClass(PyDescriptorPool* self, + const Descriptor *message_descriptor, + PyObject *message_class) { Py_INCREF(message_class); typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; std::pair ret = self->classes_by_descriptor->insert( @@ -156,7 +144,7 @@ const Descriptor* RegisterMessageClass( Py_DECREF(ret.first->second); ret.first->second = message_class; } - return message_descriptor; + return 0; } // Retrieve the message class added to our database. @@ -260,6 +248,80 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) { return PyOneofDescriptor_FromDescriptor(oneof_descriptor); } +PyObject* FindFileContainingSymbol(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FileDescriptor* file_descriptor = + self->pool->FindFileContainingSymbol(string(name, name_size)); + if (file_descriptor == NULL) { + PyErr_Format(PyExc_KeyError, "Couldn't find symbol %.200s", name); + return NULL; + } + + return PyFileDescriptor_FromDescriptor(file_descriptor); +} + +// These functions should not exist -- the only valid way to create +// descriptors is to call Add() or AddSerializedFile(). +// But these AddDescriptor() functions were created in Python and some people +// call them, so we support them for now for compatibility. +// However we do check that the existing descriptor already exists in the pool, +// which appears to always be true for existing calls -- but then why do people +// call a function that will just be a no-op? +// TODO(amauryfa): Need to investigate further. + +PyObject* AddFileDescriptor(PyDescriptorPool* self, PyObject* descriptor) { + const FileDescriptor* file_descriptor = + PyFileDescriptor_AsDescriptor(descriptor); + if (!file_descriptor) { + return NULL; + } + if (file_descriptor != + self->pool->FindFileByName(file_descriptor->name())) { + PyErr_Format(PyExc_ValueError, + "The file descriptor %s does not belong to this pool", + file_descriptor->name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +PyObject* AddDescriptor(PyDescriptorPool* self, PyObject* descriptor) { + const Descriptor* message_descriptor = + PyMessageDescriptor_AsDescriptor(descriptor); + if (!message_descriptor) { + return NULL; + } + if (message_descriptor != + self->pool->FindMessageTypeByName(message_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The message descriptor %s does not belong to this pool", + message_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + +PyObject* AddEnumDescriptor(PyDescriptorPool* self, PyObject* descriptor) { + const EnumDescriptor* enum_descriptor = + PyEnumDescriptor_AsDescriptor(descriptor); + if (!enum_descriptor) { + return NULL; + } + if (enum_descriptor != + self->pool->FindEnumTypeByName(enum_descriptor->full_name())) { + PyErr_Format(PyExc_ValueError, + "The enum descriptor %s does not belong to this pool", + enum_descriptor->full_name().c_str()); + return NULL; + } + Py_RETURN_NONE; +} + // The code below loads new Descriptors from a serialized FileDescriptorProto. @@ -341,6 +403,15 @@ static PyMethodDef Methods[] = { { "AddSerializedFile", (PyCFunction)AddSerializedFile, METH_O, "Adds a serialized FileDescriptorProto to this pool." }, + // TODO(amauryfa): Understand why the Python implementation differs from + // this one, ask users to use another API and deprecate these functions. + { "AddFileDescriptor", (PyCFunction)AddFileDescriptor, METH_O, + "No-op. Add() must have been called before." }, + { "AddDescriptor", (PyCFunction)AddDescriptor, METH_O, + "No-op. Add() must have been called before." }, + { "AddEnumDescriptor", (PyCFunction)AddEnumDescriptor, METH_O, + "No-op. Add() must have been called before." }, + { "FindFileByName", (PyCFunction)FindFileByName, METH_O, "Searches for a file descriptor by its .proto name." }, { "FindMessageTypeByName", (PyCFunction)FindMessageByName, METH_O, @@ -353,6 +424,9 @@ static PyMethodDef Methods[] = { "Searches for enum type descriptor by full name." }, { "FindOneofByName", (PyCFunction)FindOneofByName, METH_O, "Searches for oneof descriptor by full name." }, + + { "FindFileContainingSymbol", (PyCFunction)FindFileContainingSymbol, METH_O, + "Gets the FileDescriptor containing the specified symbol." }, {NULL} }; @@ -420,7 +494,7 @@ bool InitDescriptorPool() { return true; } -PyDescriptorPool* GetDescriptorPool() { +PyDescriptorPool* GetDefaultDescriptorPool() { return python_generated_pool; } @@ -432,7 +506,7 @@ PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool) { } hash_map::iterator it = descriptor_pool_map.find(pool); - if (it != descriptor_pool_map.end()) { + if (it == descriptor_pool_map.end()) { PyErr_SetString(PyExc_KeyError, "Unknown descriptor pool"); return NULL; } diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h index 541d920b..eda73d38 100644 --- a/python/google/protobuf/pyext/descriptor_pool.h +++ b/python/google/protobuf/pyext/descriptor_pool.h @@ -89,12 +89,10 @@ const Descriptor* FindMessageTypeByName(PyDescriptorPool* self, const string& name); // Registers a new Python class for the given message descriptor. -// Returns the message Descriptor. -// On error, returns NULL with a Python exception set. -const Descriptor* RegisterMessageClass( - PyDescriptorPool* self, PyObject* message_class, PyObject* descriptor); - -// The function below are also exposed as methods of the DescriptorPool type. +// On error, returns -1 with a Python exception set. +int RegisterMessageClass(PyDescriptorPool* self, + const Descriptor* message_descriptor, + PyObject* message_class); // Retrieves the Python class registered with the given message descriptor. // @@ -103,6 +101,8 @@ const Descriptor* RegisterMessageClass( PyObject* GetMessageClass(PyDescriptorPool* self, const Descriptor* message_descriptor); +// The functions below are also exposed as methods of the DescriptorPool type. + // Looks up a message by name. Returns a PyMessageDescriptor corresponding to // the field on success, or NULL on failure. // @@ -136,8 +136,9 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg); } // namespace cdescriptor_pool // Retrieve the global descriptor pool owned by the _message module. +// This is the one used by pb2.py generated modules. // Returns a *borrowed* reference. -PyDescriptorPool* GetDescriptorPool(); +PyDescriptorPool* GetDefaultDescriptorPool(); // Retrieve the python descriptor pool owning a C++ descriptor pool. // Returns a *borrowed* reference. diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 8ebbb27c..9c9b4178 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -123,7 +123,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject *message_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), descriptor->message_type()); + cmessage::GetDescriptorPoolForMessage(self->parent), + descriptor->message_type()); if (message_class == NULL) { return NULL; } diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 62c7c478..63d53136 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -55,6 +55,7 @@ #include #include #include +#include #include #include #include @@ -107,8 +108,18 @@ struct PyMessageMeta { // C++ descriptor of this message. const Descriptor* message_descriptor; + // Owned reference, used to keep the pointer above alive. PyObject* py_message_descriptor; + + // The Python DescriptorPool used to create the class. It is needed to resolve + // fields descriptors, including extensions fields; its C++ MessageFactory is + // used to instantiate submessages. + // This can be different from DESCRIPTOR.file.pool, in the case of a custom + // DescriptorPool which defines new extensions. + // We own the reference, because it's important to keep the descriptors and + // factory alive. + PyDescriptorPool* py_descriptor_pool; }; namespace message_meta { @@ -139,18 +150,10 @@ static bool AddFieldNumberToClass( // Finalize the creation of the Message class. -// Called from its metaclass: GeneratedProtocolMessageType.__init__(). -static int AddDescriptors(PyObject* cls, PyObject* descriptor) { - const Descriptor* message_descriptor = - cdescriptor_pool::RegisterMessageClass( - GetDescriptorPool(), cls, descriptor); - if (message_descriptor == NULL) { - return -1; - } - +static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { // If there are extension_ranges, the message is "extendable", and extension // classes will register themselves in this class. - if (message_descriptor->extension_range_count() > 0) { + if (descriptor->extension_range_count() > 0) { ScopedPyObjectPtr by_name(PyDict_New()); if (PyObject_SetAttr(cls, k_extensions_by_name, by_name) < 0) { return -1; @@ -162,8 +165,8 @@ static int AddDescriptors(PyObject* cls, PyObject* descriptor) { } // For each field set: cls._FIELD_NUMBER = - for (int i = 0; i < message_descriptor->field_count(); ++i) { - if (!AddFieldNumberToClass(cls, message_descriptor->field(i))) { + for (int i = 0; i < descriptor->field_count(); ++i) { + if (!AddFieldNumberToClass(cls, descriptor->field(i))) { return -1; } } @@ -173,8 +176,8 @@ static int AddDescriptors(PyObject* cls, PyObject* descriptor) { // The enum descriptor we get from // .enum_types_by_name[name] // which was built previously. - for (int i = 0; i < message_descriptor->enum_type_count(); ++i) { - const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i); + for (int i = 0; i < descriptor->enum_type_count(); ++i) { + const EnumDescriptor* enum_descriptor = descriptor->enum_type(i); ScopedPyObjectPtr enum_type( PyEnumDescriptor_FromDescriptor(enum_descriptor)); if (enum_type == NULL) { @@ -212,8 +215,8 @@ static int AddDescriptors(PyObject* cls, PyObject* descriptor) { // Extension descriptors come from // .extensions_by_name[name] // which was defined previously. - for (int i = 0; i < message_descriptor->extension_count(); ++i) { - const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i); + for (int i = 0; i < descriptor->extension_count(); ++i) { + const google::protobuf::FieldDescriptor* field = descriptor->extension(i); ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); if (extension_field == NULL) { return -1; @@ -258,14 +261,14 @@ static PyObject* New(PyTypeObject* type, } // Check dict['DESCRIPTOR'] - PyObject* descriptor = PyDict_GetItem(dict, kDESCRIPTOR); - if (descriptor == NULL) { + PyObject* py_descriptor = PyDict_GetItem(dict, kDESCRIPTOR); + if (py_descriptor == NULL) { PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR"); return NULL; } - if (!PyObject_TypeCheck(descriptor, &PyMessageDescriptor_Type)) { + if (!PyObject_TypeCheck(py_descriptor, &PyMessageDescriptor_Type)) { PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s", - descriptor->ob_type->tp_name); + py_descriptor->ob_type->tp_name); return NULL; } @@ -291,14 +294,28 @@ static PyObject* New(PyTypeObject* type, } // Cache the descriptor, both as Python object and as C++ pointer. - const Descriptor* message_descriptor = - PyMessageDescriptor_AsDescriptor(descriptor); - if (message_descriptor == NULL) { + const Descriptor* descriptor = + PyMessageDescriptor_AsDescriptor(py_descriptor); + if (descriptor == NULL) { + return NULL; + } + Py_INCREF(py_descriptor); + newtype->py_message_descriptor = py_descriptor; + newtype->message_descriptor = descriptor; + // TODO(amauryfa): Don't always use the canonical pool of the descriptor, + // use the MessageFactory optionally passed in the class dict. + newtype->py_descriptor_pool = GetDescriptorPool_FromPool( + descriptor->file()->pool()); + if (newtype->py_descriptor_pool == NULL) { + return NULL; + } + Py_INCREF(newtype->py_descriptor_pool); + + // Add the message to the DescriptorPool. + if (cdescriptor_pool::RegisterMessageClass(newtype->py_descriptor_pool, + descriptor, result) < 0) { return NULL; } - Py_INCREF(descriptor); - newtype->py_message_descriptor = descriptor; - newtype->message_descriptor = message_descriptor; // Continue with type initialization: add other descriptors, enum values... if (AddDescriptors(result, descriptor) < 0) { @@ -309,6 +326,7 @@ static PyObject* New(PyTypeObject* type, static void Dealloc(PyMessageMeta *self) { Py_DECREF(self->py_message_descriptor); + Py_DECREF(self->py_descriptor_pool); Py_TYPE(self)->tp_free(reinterpret_cast(self)); } @@ -381,12 +399,20 @@ PyTypeObject PyMessageMeta_Type = { message_meta::New, // tp_new }; -static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { +static PyMessageMeta* CheckMessageClass(PyTypeObject* cls) { if (!PyObject_TypeCheck(cls, &PyMessageMeta_Type)) { PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); return NULL; } - return reinterpret_cast(cls)->message_descriptor; + return reinterpret_cast(cls); +} + +static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { + PyMessageMeta* type = CheckMessageClass(cls); + if (type == NULL) { + return NULL; + } + return type->message_descriptor; } // Forward declarations @@ -723,6 +749,17 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, namespace cmessage { +PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message) { + // No need to check the type: the type of instances of CMessage is always + // an instance of PyMessageMeta. Let's prove it with a debug-only check. + GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); + return reinterpret_cast(Py_TYPE(message))->py_descriptor_pool; +} + +MessageFactory* GetFactoryForMessage(CMessage* message) { + return GetDescriptorPoolForMessage(message)->message_factory; +} + static int MaybeReleaseOverlappingOneofField( CMessage* cmessage, const FieldDescriptor* field) { @@ -773,7 +810,7 @@ static Message* GetMutableMessage( return NULL; } return reflection->MutableMessage( - parent_message, parent_field, GetDescriptorPool()->message_factory); + parent_message, parent_field, GetFactoryForMessage(parent)); } struct FixupMessageReference : public ChildVisitor { @@ -814,10 +851,7 @@ int AssureWritable(CMessage* self) { // If parent is NULL but we are trying to modify a read-only message, this // is a reference to a constant default instance that needs to be replaced // with a mutable top-level message. - const Message* prototype = - GetDescriptorPool()->message_factory->GetPrototype( - self->message->GetDescriptor()); - self->message = prototype->New(); + self->message = self->message->New(); self->owner.reset(self->message); // Cascade the new owner to eventual children: even if this message is // empty, some submessages or repeated containers might exist already. @@ -1190,15 +1224,19 @@ CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { // The __new__ method of Message classes. // Creates a new C++ message and takes ownership. -static PyObject* New(PyTypeObject* type, +static PyObject* New(PyTypeObject* cls, PyObject* unused_args, PyObject* unused_kwargs) { + PyMessageMeta* type = CheckMessageClass(cls); + if (type == NULL) { + return NULL; + } // Retrieve the message descriptor and the default instance (=prototype). - const Descriptor* message_descriptor = GetMessageDescriptor(type); + const Descriptor* message_descriptor = type->message_descriptor; if (message_descriptor == NULL) { return NULL; } - const Message* default_message = - GetDescriptorPool()->message_factory->GetPrototype(message_descriptor); + const Message* default_message = type->py_descriptor_pool->message_factory + ->GetPrototype(message_descriptor); if (default_message == NULL) { PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); return NULL; @@ -1528,7 +1566,7 @@ int SetOwner(CMessage* self, const shared_ptr& new_owner) { Message* ReleaseMessage(CMessage* self, const Descriptor* descriptor, const FieldDescriptor* field_descriptor) { - MessageFactory* message_factory = GetDescriptorPool()->message_factory; + MessageFactory* message_factory = GetFactoryForMessage(self); Message* released_message = self->message->GetReflection()->ReleaseMessage( self->message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from @@ -1883,8 +1921,8 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { AssureWritable(self); io::CodedInputStream input( reinterpret_cast(data), data_length); - input.SetExtensionRegistry(GetDescriptorPool()->pool, - GetDescriptorPool()->message_factory); + PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); + input.SetExtensionRegistry(pool->pool, pool->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { return PyInt_FromLong(input.CurrentPosition()); @@ -1907,11 +1945,6 @@ static PyObject* ByteSize(CMessage* self, PyObject* args) { static PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { - ScopedPyObjectPtr message_descriptor(PyObject_GetAttr(cls, kDESCRIPTOR)); - if (message_descriptor == NULL) { - return NULL; - } - const FieldDescriptor* descriptor = GetExtensionDescriptor(extension_handle); if (descriptor == NULL) { @@ -1920,13 +1953,6 @@ static PyObject* RegisterExtension(PyObject* cls, const Descriptor* cmessage_descriptor = GetMessageDescriptor( reinterpret_cast(cls)); - if (cmessage_descriptor != descriptor->containing_type()) { - if (PyObject_SetAttrString(extension_handle, "containing_type", - message_descriptor) < 0) { - return NULL; - } - } - ScopedPyObjectPtr extensions_by_name( PyObject_GetAttr(cls, k_extensions_by_name)); if (extensions_by_name == NULL) { @@ -2050,7 +2076,8 @@ static PyObject* ListFields(CMessage* self) { // TODO(amauryfa): consider building the class on the fly! if (fields[i]->message_type() != NULL && cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), fields[i]->message_type()) == NULL) { + GetDescriptorPoolForMessage(self), + fields[i]->message_type()) == NULL) { PyErr_Clear(); continue; } @@ -2207,7 +2234,9 @@ PyObject* InternalGetScalar(const Message* message, message->GetReflection()->GetUnknownFields(*message); for (int i = 0; i < unknown_field_set.field_count(); ++i) { if (unknown_field_set.field(i).number() == - field_descriptor->number()) { + field_descriptor->number() && + unknown_field_set.field(i).type() == + google::protobuf::UnknownField::TYPE_VARINT) { result = PyInt_FromLong(unknown_field_set.field(i).varint()); break; } @@ -2233,11 +2262,12 @@ PyObject* InternalGetScalar(const Message* message, PyObject* InternalGetSubMessage( CMessage* self, const FieldDescriptor* field_descriptor) { const Reflection* reflection = self->message->GetReflection(); + PyDescriptorPool* pool = GetDescriptorPoolForMessage(self); const Message& sub_message = reflection->GetMessage( - *self->message, field_descriptor, GetDescriptorPool()->message_factory); + *self->message, field_descriptor, pool->message_factory); PyObject *message_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), field_descriptor->message_type()); + pool, field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -2560,7 +2590,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { const FieldDescriptor* value_type = entry_type->FindFieldByName("value"); if (value_type->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject* value_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), value_type->message_type()); + GetDescriptorPoolForMessage(self), value_type->message_type()); if (value_class == NULL) { return NULL; } @@ -2583,7 +2613,7 @@ PyObject* GetAttr(CMessage* self, PyObject* name) { PyObject* py_container = NULL; if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyObject *message_class = cdescriptor_pool::GetMessageClass( - GetDescriptorPool(), field_descriptor->message_type()); + GetDescriptorPoolForMessage(self), field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -2908,9 +2938,10 @@ bool InitProto2MessageModule(PyObject *m) { // Expose the DescriptorPool used to hold all descriptors added from generated // pb2.py files. - Py_INCREF(GetDescriptorPool()); // PyModule_AddObject steals a reference. - PyModule_AddObject( - m, "default_pool", reinterpret_cast(GetDescriptorPool())); + // PyModule_AddObject steals a reference. + Py_INCREF(GetDefaultDescriptorPool()); + PyModule_AddObject(m, "default_pool", + reinterpret_cast(GetDefaultDescriptorPool())); // This implementation provides full Descriptor types, we advertise it so that // descriptor.py can use them in replacement of the Python classes. diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index f147d433..1ff82e2f 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -49,12 +49,15 @@ class Message; class Reflection; class FieldDescriptor; class Descriptor; +class DescriptorPool; +class MessageFactory; using internal::shared_ptr; namespace python { struct ExtensionDict; +struct PyDescriptorPool; typedef struct CMessage { PyObject_HEAD; @@ -220,6 +223,16 @@ PyObject* FindInitializationErrors(CMessage* self); int SetOwner(CMessage* self, const shared_ptr& new_owner); int AssureWritable(CMessage* self); + +// Returns the "best" DescriptorPool for the given message. +// This is often equivalent to message.DESCRIPTOR.pool, but not always, when +// the message class was created from a MessageFactory using a custom pool which +// uses the generated pool as an underlay. +// +// The returned pool is suitable for finding fields and building submessages, +// even in the case of extensions. +PyDescriptorPool* GetDescriptorPoolForMessage(CMessage* message); + } // namespace cmessage diff --git a/python/google/protobuf/symbol_database.py b/python/google/protobuf/symbol_database.py index 4c70b393..b81ef4d7 100644 --- a/python/google/protobuf/symbol_database.py +++ b/python/google/protobuf/symbol_database.py @@ -60,6 +60,7 @@ Example usage: """ +from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool @@ -72,6 +73,31 @@ class SymbolDatabase(object): buffer types used within a program. """ + # pylint: disable=protected-access + if _descriptor._USE_C_DESCRIPTORS: + + def __new__(cls): + raise TypeError("Instances of SymbolDatabase cannot be created") + + @classmethod + def _CreateDefaultDatabase(cls): + self = object.__new__(cls) # Bypass the __new__ above. + # Don't call __init__() and initialize here. + self._symbols = {} + self._symbols_by_file = {} + # As of today all descriptors are registered and retrieved from + # _message.default_pool (see FileDescriptor.__new__), so it's not + # necessary to use another pool. + self.pool = _descriptor._message.default_pool + return self + # pylint: enable=protected-access + + else: + + @classmethod + def _CreateDefaultDatabase(cls): + return cls() + def __init__(self): """Constructor.""" @@ -177,7 +203,7 @@ class SymbolDatabase(object): result.update(self._symbols_by_file[f]) return result -_DEFAULT = SymbolDatabase() +_DEFAULT = SymbolDatabase._CreateDefaultDatabase() def Default(): diff --git a/python/google/protobuf/text_encoding.py b/python/google/protobuf/text_encoding.py index a0728e3c..98995638 100644 --- a/python/google/protobuf/text_encoding.py +++ b/python/google/protobuf/text_encoding.py @@ -27,6 +27,7 @@ # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + """Encoding related utilities.""" import re diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 1399223f..e4fadf09 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -28,9 +28,17 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# Copyright 2007 Google Inc. All Rights Reserved. +"""Contains routines for printing protocol messages in text format. -"""Contains routines for printing protocol messages in text format.""" +Simple usage example: + + # Create a proto object and serialize it to a text proto string. + message = my_proto_pb2.MyMessage(foo='bar') + text_proto = text_format.MessageToString(message) + + # Parse a text proto string. + message = text_format.Parse(text_proto, my_proto_pb2.MyMessage()) +""" __author__ = 'kenton@google.com (Kenton Varda)' diff --git a/python/google/protobuf/util/__init__.py b/python/google/protobuf/util/__init__.py new file mode 100644 index 00000000..e69de29b -- cgit v1.2.3