From 6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3 Mon Sep 17 00:00:00 2001 From: Feng Xiao Date: Wed, 8 Aug 2018 17:00:41 -0700 Subject: Down-integrate from google3. --- .../google/protobuf/internal/text_format_test.py | 9 +- python/google/protobuf/descriptor_database.py | 11 + python/google/protobuf/descriptor_pool.py | 196 ++++++- python/google/protobuf/internal/__init__.py | 30 + .../google/protobuf/internal/api_implementation.py | 26 - python/google/protobuf/internal/containers.py | 127 +++++ python/google/protobuf/internal/decoder.py | 192 ++++++- .../protobuf/internal/descriptor_database_test.py | 8 + .../protobuf/internal/descriptor_pool_test.py | 63 ++- python/google/protobuf/internal/descriptor_test.py | 11 + .../google/protobuf/internal/factory_test1.proto | 14 + .../protobuf/internal/message_factory_test.py | 4 +- python/google/protobuf/internal/message_test.py | 262 ++++++++- python/google/protobuf/internal/no_package.proto | 30 + python/google/protobuf/internal/python_message.py | 150 ++++- python/google/protobuf/internal/reflection_test.py | 53 +- .../google/protobuf/internal/text_format_test.py | 298 ++++++++-- python/google/protobuf/internal/type_checkers.py | 8 + .../protobuf/internal/unknown_fields_test.py | 165 ++++-- python/google/protobuf/json_format.py | 2 +- python/google/protobuf/message.py | 4 + python/google/protobuf/message_factory.py | 15 +- python/google/protobuf/proto_api.h | 11 +- python/google/protobuf/pyext/descriptor.cc | 51 +- python/google/protobuf/pyext/descriptor.h | 2 +- .../google/protobuf/pyext/descriptor_containers.cc | 12 +- .../google/protobuf/pyext/descriptor_containers.h | 2 +- .../google/protobuf/pyext/descriptor_database.cc | 39 +- python/google/protobuf/pyext/descriptor_database.h | 9 +- python/google/protobuf/pyext/descriptor_pool.cc | 47 +- python/google/protobuf/pyext/descriptor_pool.h | 6 +- python/google/protobuf/pyext/extension_dict.cc | 92 +-- python/google/protobuf/pyext/extension_dict.h | 20 +- python/google/protobuf/pyext/map_container.cc | 148 +++-- python/google/protobuf/pyext/map_container.h | 2 +- python/google/protobuf/pyext/message.cc | 614 +++++++++++++-------- python/google/protobuf/pyext/message.h | 40 +- python/google/protobuf/pyext/message_factory.cc | 6 +- python/google/protobuf/pyext/message_factory.h | 7 +- python/google/protobuf/pyext/message_module.cc | 89 +-- .../protobuf/pyext/repeated_composite_container.cc | 16 +- .../protobuf/pyext/repeated_composite_container.h | 2 +- .../protobuf/pyext/repeated_scalar_container.cc | 6 + .../protobuf/pyext/repeated_scalar_container.h | 2 +- python/google/protobuf/pyext/safe_numerics.h | 2 +- .../protobuf/pyext/thread_unsafe_shared_ptr.h | 2 +- python/google/protobuf/python_protobuf.h | 2 +- python/google/protobuf/reflection.py | 52 +- python/google/protobuf/text_encoding.py | 84 +-- python/google/protobuf/text_format.py | 134 +++-- 50 files changed, 2328 insertions(+), 849 deletions(-) (limited to 'python') diff --git a/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/text_format_test.py b/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/text_format_test.py index 8267cd2c..bc53e256 100755 --- a/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/text_format_test.py +++ b/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/text_format_test.py @@ -410,7 +410,8 @@ class TextFormatTest(unittest.TestCase): text = 'optional_nested_enum: BARR' self.assertRaisesWithMessage( text_format.ParseError, - ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' + ('1:23 : \'optional_nested_enum: BARR\': ' + 'Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value named BARR.'), text_format.Merge, text, message) @@ -418,7 +419,8 @@ class TextFormatTest(unittest.TestCase): text = 'optional_nested_enum: 100' self.assertRaisesWithMessage( text_format.ParseError, - ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' + ('1:23 : \'optional_nested_enum: 100\': ' + 'Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value with number 100.'), text_format.Merge, text, message) @@ -427,7 +429,8 @@ class TextFormatTest(unittest.TestCase): text = 'optional_int32: bork' self.assertRaisesWithMessage( text_format.ParseError, - ('1:17 : Couldn\'t parse integer: bork'), + ('1:17 : \'optional_int32: bork\': ' + 'Couldn\'t parse integer: bork'), text_format.Merge, text, message) def testMergeStringFieldUnescape(self): diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py index 8b7715cd..a7616cbc 100644 --- a/python/google/protobuf/descriptor_database.py +++ b/python/google/protobuf/descriptor_database.py @@ -76,6 +76,9 @@ class DescriptorDatabase(object): self._AddSymbol(name, file_desc_proto) for enum in file_desc_proto.enum_type: self._AddSymbol(('.'.join((package, enum.name))), file_desc_proto) + for enum_value in enum.value: + self._file_desc_protos_by_symbol[ + '.'.join((package, enum_value.name))] = file_desc_proto for extension in file_desc_proto.extension: self._AddSymbol(('.'.join((package, extension.name))), file_desc_proto) for service in file_desc_proto.service: @@ -133,6 +136,14 @@ class DescriptorDatabase(object): top_level, _, _ = symbol.rpartition('.') return self._file_desc_protos_by_symbol[top_level] + def FindFileContainingExtension(self, extendee_name, extension_number): + # TODO(jieluo): implement this API. + return None + + def FindAllExtensionNumbers(self, extendee_name): + # TODO(jieluo): implement this API. + return [] + def _AddSymbol(self, name, file_desc_proto): if name in self._file_desc_protos_by_symbol: warn_msg = ('Conflict register for file "' + file_desc_proto.name + diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index 8983f76f..42f7bcb5 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -131,33 +131,46 @@ class DescriptorPool(object): # TODO(jieluo): Remove _file_desc_by_toplevel_extension after # maybe year 2020 for compatibility issue (with 3.4.1 only). self._file_desc_by_toplevel_extension = {} + self._top_enum_values = {} # We store extensions in two two-level mappings: The first key is the # descriptor of the message being extended, the second key is the extension # full name or its tag number. self._extensions_by_name = collections.defaultdict(dict) self._extensions_by_number = collections.defaultdict(dict) - def _CheckConflictRegister(self, desc): + def _CheckConflictRegister(self, desc, desc_name, file_name): """Check if the descriptor name conflicts with another of the same name. Args: - desc: Descriptor of a message, enum, service or extension. + desc: Descriptor of a message, enum, service, extension or enum value. + desc_name: the full name of desc. + file_name: The file name of descriptor. """ - desc_name = desc.full_name for register, descriptor_type in [ (self._descriptors, descriptor.Descriptor), (self._enum_descriptors, descriptor.EnumDescriptor), (self._service_descriptors, descriptor.ServiceDescriptor), - (self._toplevel_extensions, descriptor.FieldDescriptor)]: + (self._toplevel_extensions, descriptor.FieldDescriptor), + (self._top_enum_values, descriptor.EnumValueDescriptor)]: if desc_name in register: - file_name = register[desc_name].file.name + old_desc = register[desc_name] + if isinstance(old_desc, descriptor.EnumValueDescriptor): + old_file = old_desc.type.file.name + else: + old_file = old_desc.file.name + if not isinstance(desc, descriptor_type) or ( - file_name != desc.file.name): - warn_msg = ('Conflict register for file "' + desc.file.name + + old_file != file_name): + warn_msg = ('Conflict register for file "' + file_name + '": ' + desc_name + ' is already defined in file "' + - file_name + '"') + old_file + '"') + if isinstance(desc, descriptor.EnumValueDescriptor): + warn_msg += ('\nNote: enum values appear as ' + 'siblings of the enum type instead of ' + 'children of it.') warnings.warn(warn_msg, RuntimeWarning) + return def Add(self, file_desc_proto): @@ -196,7 +209,7 @@ class DescriptorPool(object): if not isinstance(desc, descriptor.Descriptor): raise TypeError('Expected instance of descriptor.Descriptor.') - self._CheckConflictRegister(desc) + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) self._descriptors[desc.full_name] = desc self._AddFileDescriptor(desc.file) @@ -213,8 +226,26 @@ class DescriptorPool(object): if not isinstance(enum_desc, descriptor.EnumDescriptor): raise TypeError('Expected instance of descriptor.EnumDescriptor.') - self._CheckConflictRegister(enum_desc) + file_name = enum_desc.file.name + self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name) self._enum_descriptors[enum_desc.full_name] = enum_desc + + # Top enum values need to be indexed. + # Count the number of dots to see whether the enum is toplevel or nested + # in a message. We cannot use enum_desc.containing_type at this stage. + if enum_desc.file.package: + top_level = (enum_desc.full_name.count('.') + - enum_desc.file.package.count('.') == 1) + else: + top_level = enum_desc.full_name.count('.') == 0 + if top_level: + file_name = enum_desc.file.name + package = enum_desc.file.package + for enum_value in enum_desc.values: + full_name = _NormalizeFullyQualifiedName( + '.'.join((package, enum_value.name))) + self._CheckConflictRegister(enum_value, full_name, file_name) + self._top_enum_values[full_name] = enum_value self._AddFileDescriptor(enum_desc.file) def AddServiceDescriptor(self, service_desc): @@ -227,7 +258,8 @@ class DescriptorPool(object): if not isinstance(service_desc, descriptor.ServiceDescriptor): raise TypeError('Expected instance of descriptor.ServiceDescriptor.') - self._CheckConflictRegister(service_desc) + self._CheckConflictRegister(service_desc, service_desc.full_name, + service_desc.file.name) self._service_descriptors[service_desc.full_name] = service_desc def AddExtensionDescriptor(self, extension): @@ -247,7 +279,6 @@ class DescriptorPool(object): raise TypeError('Expected an extension descriptor.') if extension.extension_scope is None: - self._CheckConflictRegister(extension) self._toplevel_extensions[extension.full_name] = extension try: @@ -348,6 +379,30 @@ class DescriptorPool(object): """ symbol = _NormalizeFullyQualifiedName(symbol) + try: + return self._InternalFindFileContainingSymbol(symbol) + except KeyError: + pass + + try: + # Try fallback database. Build and find again if possible. + self._FindFileContainingSymbolInDb(symbol) + return self._InternalFindFileContainingSymbol(symbol) + except KeyError: + raise KeyError('Cannot find a file containing %s' % symbol) + + def _InternalFindFileContainingSymbol(self, symbol): + """Gets the already built FileDescriptor containing the specified symbol. + + Args: + symbol: The name of the symbol to search for. + + Returns: + A FileDescriptor that contains the specified symbol. + + Raises: + KeyError: if the file cannot be found in the pool. + """ try: return self._descriptors[symbol].file except KeyError: @@ -364,7 +419,7 @@ class DescriptorPool(object): pass try: - return self._FindFileContainingSymbolInDb(symbol) + return self._top_enum_values[symbol].type.file except KeyError: pass @@ -373,13 +428,15 @@ class DescriptorPool(object): except KeyError: pass - # Try nested extensions inside a message. - message_name, _, extension_name = symbol.rpartition('.') + # Try fields, enum values and nested extensions inside a message. + top_name, _, sub_name = symbol.rpartition('.') try: - message = self.FindMessageTypeByName(message_name) - assert message.extensions_by_name[extension_name] + message = self.FindMessageTypeByName(top_name) + assert (sub_name in message.extensions_by_name or + sub_name in message.fields_by_name or + sub_name in message.enum_values_by_name) return message.file - except KeyError: + except (KeyError, AssertionError): raise KeyError('Cannot find a file containing %s' % symbol) def FindMessageTypeByName(self, full_name): @@ -499,7 +556,11 @@ class DescriptorPool(object): KeyError: when no extension with the given number is known for the specified message. """ - return self._extensions_by_number[message_descriptor][number] + try: + return self._extensions_by_number[message_descriptor][number] + except KeyError: + self._TryLoadExtensionFromDB(message_descriptor, number) + return self._extensions_by_number[message_descriptor][number] def FindAllExtensions(self, message_descriptor): """Gets all the known extension of a given message. @@ -513,8 +574,57 @@ class DescriptorPool(object): Returns: A list of FieldDescriptor describing the extensions. """ + # Fallback to descriptor db if FindAllExtensionNumbers is provided. + if self._descriptor_db and hasattr( + self._descriptor_db, 'FindAllExtensionNumbers'): + full_name = message_descriptor.full_name + all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name) + for number in all_numbers: + if number in self._extensions_by_number[message_descriptor]: + continue + self._TryLoadExtensionFromDB(message_descriptor, number) + return list(self._extensions_by_number[message_descriptor].values()) + def _TryLoadExtensionFromDB(self, message_descriptor, number): + """Try to Load extensions from decriptor db. + + Args: + message_descriptor: descriptor of the extended message. + number: the extension number that needs to be loaded. + """ + if not self._descriptor_db: + return + # Only supported when FindFileContainingExtension is provided. + if not hasattr( + self._descriptor_db, 'FindFileContainingExtension'): + return + + full_name = message_descriptor.full_name + file_proto = self._descriptor_db.FindFileContainingExtension( + full_name, number) + + if file_proto is None: + return + + try: + file_desc = self._ConvertFileProtoToFileDescriptor(file_proto) + for extension in file_desc.extensions_by_name.values(): + self._extensions_by_number[extension.containing_type][ + extension.number] = extension + self._extensions_by_name[extension.containing_type][ + extension.full_name] = extension + for message_type in file_desc.message_types_by_name.values(): + for extension in message_type.extensions: + self._extensions_by_number[extension.containing_type][ + extension.number] = extension + self._extensions_by_name[extension.containing_type][ + extension.full_name] = extension + except: + warn_msg = ('Unable to load proto file %s for extension number %d.' % + (file_proto.name, number)) + warnings.warn(warn_msg, RuntimeWarning) + def FindServiceByName(self, full_name): """Loads the named service descriptor from the pool. @@ -532,6 +642,23 @@ class DescriptorPool(object): self._FindFileContainingSymbolInDb(full_name) return self._service_descriptors[full_name] + def FindMethodByName(self, full_name): + """Loads the named service method descriptor from the pool. + + Args: + full_name: The full name of the method descriptor to load. + + Returns: + The method descriptor for the service method. + + Raises: + KeyError: if the method cannot be found in the pool. + """ + full_name = _NormalizeFullyQualifiedName(full_name) + service_name, _, method_name = full_name.rpartition('.') + service_descriptor = self.FindServiceByName(service_name) + return service_descriptor.methods_by_name[method_name] + def _FindFileContainingSymbolInDb(self, symbol): """Finds the file in descriptor DB containing the specified symbol. @@ -567,7 +694,6 @@ class DescriptorPool(object): Returns: A FileDescriptor matching the passed in proto. """ - if file_proto.name not in self._file_descriptors: built_deps = list(self._GetDeps(file_proto.dependency)) direct_deps = [self.FindFileByName(n) for n in file_proto.dependency] @@ -604,7 +730,7 @@ class DescriptorPool(object): for enum_type in file_proto.enum_type: file_descriptor.enum_types_by_name[enum_type.name] = ( self._ConvertEnumDescriptor(enum_type, file_proto.package, - file_descriptor, None, scope)) + file_descriptor, None, scope, True)) for index, extension_proto in enumerate(file_proto.extension): extension_desc = self._MakeFieldDescriptor( @@ -616,6 +742,8 @@ class DescriptorPool(object): file_descriptor.package, scope) file_descriptor.extensions_by_name[extension_desc.name] = ( extension_desc) + self._file_desc_by_toplevel_extension[extension_desc.full_name] = ( + file_descriptor) for desc_proto in file_proto.message_type: self._SetAllFieldTypes(file_proto.package, desc_proto, scope) @@ -673,7 +801,8 @@ class DescriptorPool(object): nested, desc_name, file_desc, scope, syntax) for nested in desc_proto.nested_type] enums = [ - self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) + self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, + scope, False) for enum in desc_proto.enum_type] fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc) for index, field in enumerate(desc_proto.field)] @@ -718,12 +847,12 @@ class DescriptorPool(object): fields[field_index].containing_oneof = oneofs[oneof_index] scope[_PrefixWithDot(desc_name)] = desc - self._CheckConflictRegister(desc) + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) self._descriptors[desc_name] = desc return desc def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None, - containing_type=None, scope=None): + containing_type=None, scope=None, top_level=False): """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf. Args: @@ -732,6 +861,8 @@ class DescriptorPool(object): file_desc: The file containing the enum descriptor. containing_type: The type containing this enum. scope: Scope containing available types. + top_level: If True, the enum is a top level symbol. If False, the enum + is defined inside a message. Returns: The added descriptor @@ -757,8 +888,17 @@ class DescriptorPool(object): containing_type=containing_type, options=_OptionsOrNone(enum_proto)) scope['.%s' % enum_name] = desc - self._CheckConflictRegister(desc) + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) self._enum_descriptors[enum_name] = desc + + # Add top level enum values. + if top_level: + for value in values: + full_name = _NormalizeFullyQualifiedName( + '.'.join((package, value.name))) + self._CheckConflictRegister(value, full_name, file_name) + self._top_enum_values[full_name] = value + return desc def _MakeFieldDescriptor(self, field_proto, message_name, index, @@ -885,6 +1025,8 @@ class DescriptorPool(object): elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: field_desc.default_value = text_encoding.CUnescape( field_proto.default_value) + elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: + field_desc.default_value = None else: # All other types are of the "int" type. field_desc.default_value = int(field_proto.default_value) @@ -901,6 +1043,8 @@ class DescriptorPool(object): field_desc.default_value = field_desc.enum_type.values[0].number elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES: field_desc.default_value = b'' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE: + field_desc.default_value = None else: # All other types are of the "int" type. field_desc.default_value = 0 @@ -954,7 +1098,7 @@ class DescriptorPool(object): methods=methods, options=_OptionsOrNone(service_proto), file=file_desc) - self._CheckConflictRegister(desc) + self._CheckConflictRegister(desc, desc.full_name, desc.file.name) self._service_descriptors[service_name] = desc return desc diff --git a/python/google/protobuf/internal/__init__.py b/python/google/protobuf/internal/__init__.py index e69de29b..7d2e571a 100755 --- a/python/google/protobuf/internal/__init__.py +++ b/python/google/protobuf/internal/__init__.py @@ -0,0 +1,30 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index ab9e7812..23cc2c0a 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -145,29 +145,3 @@ def Version(): # For internal use only def IsPythonDefaultSerializationDeterministic(): return _python_deterministic_proto_serialization - -# DO NOT USE: For migration and testing only. Will be removed when Proto3 -# defaults to preserve unknowns. -if _implementation_type == 'cpp': - try: - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - - def GetPythonProto3PreserveUnknownsDefault(): - return _message.GetPythonProto3PreserveUnknownsDefault() - - def SetPythonProto3PreserveUnknownsDefault(preserve): - _message.SetPythonProto3PreserveUnknownsDefault(preserve) - except ImportError: - # Unrecognized cpp implementation. Skipping the unknown fields APIs. - pass -else: - _python_proto3_preserve_unknowns_default = True - - def GetPythonProto3PreserveUnknownsDefault(): - return _python_proto3_preserve_unknowns_default - - def SetPythonProto3PreserveUnknownsDefault(preserve): - global _python_proto3_preserve_unknowns_default - _python_proto3_preserve_unknowns_default = preserve - diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index c6a3692a..182cac99 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -628,3 +628,130 @@ class MessageMap(MutableMapping): def GetEntryClass(self): return self._entry_descriptor._concrete_class + + +class _UnknownField(object): + + """A parsed unknown field.""" + + # Disallows assignment to other attributes. + __slots__ = ['_field_number', '_wire_type', '_data'] + + def __init__(self, field_number, wire_type, data): + self._field_number = field_number + self._wire_type = wire_type + self._data = data + return + + def __lt__(self, other): + # pylint: disable=protected-access + return self._field_number < other._field_number + + def __eq__(self, other): + if self is other: + return True + # pylint: disable=protected-access + return (self._field_number == other._field_number and + self._wire_type == other._wire_type and + self._data == other._data) + + +class UnknownFieldRef(object): + + def __init__(self, parent, index): + self._parent = parent + self._index = index + return + + def _check_valid(self): + if not self._parent: + raise ValueError('UnknownField does not exist. ' + 'The parent message might be cleared.') + if self._index >= len(self._parent): + raise ValueError('UnknownField does not exist. ' + 'The parent message might be cleared.') + + @property + def field_number(self): + self._check_valid() + # pylint: disable=protected-access + return self._parent._internal_get(self._index)._field_number + + @property + def wire_type(self): + self._check_valid() + # pylint: disable=protected-access + return self._parent._internal_get(self._index)._wire_type + + @property + def data(self): + self._check_valid() + # pylint: disable=protected-access + return self._parent._internal_get(self._index)._data + + +class UnknownFieldSet(object): + + """UnknownField container""" + + # Disallows assignment to other attributes. + __slots__ = ['_values'] + + def __init__(self): + self._values = [] + + def __getitem__(self, index): + if self._values is None: + raise ValueError('UnknownFields does not exist. ' + 'The parent message might be cleared.') + size = len(self._values) + if index < 0: + index += size + if index < 0 or index >= size: + raise IndexError('index %d out of range'.index) + + return UnknownFieldRef(self, index) + + def _internal_get(self, index): + return self._values[index] + + def __len__(self): + if self._values is None: + raise ValueError('UnknownFields does not exist. ' + 'The parent message might be cleared.') + return len(self._values) + + def _add(self, field_number, wire_type, data): + unknown_field = _UnknownField(field_number, wire_type, data) + self._values.append(unknown_field) + return unknown_field + + def __iter__(self): + for i in range(len(self)): + yield UnknownFieldRef(self, i) + + def _extend(self, other): + if other is None: + return + # pylint: disable=protected-access + self._values.extend(other._values) + + def __eq__(self, other): + if self is other: + return True + # Sort unknown fields because their order shouldn't + # affect equality test. + values = list(self._values) + if other is None: + return not values + values.sort() + # pylint: disable=protected-access + other_values = sorted(other._values) + return values == other_values + + def _clear(self): + for value in self._values: + # pylint: disable=protected-access + if isinstance(value._data, UnknownFieldSet): + value._data._clear() # pylint: disable=protected-access + self._values = None diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 52b64915..938f6293 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -86,7 +86,11 @@ import six if six.PY3: long = int +else: + import re # pylint: disable=g-import-not-at-top + _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]')) +from google.protobuf.internal import containers from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message @@ -167,7 +171,7 @@ _DecodeSignedVarint32 = _SignedVarintDecoder(32, int) def ReadTag(buffer, pos): - """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple. + """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple. We return the raw bytes of the tag rather than decoding them. The raw bytes can then be used to look up the proper decoder. This effectively allows @@ -175,13 +179,21 @@ def ReadTag(buffer, pos): for work that is done in C (searching for a byte string in a hash table). In a low-level language it would be much cheaper to decode the varint and use that, but not in Python. - """ + Args: + buffer: memoryview object of the encoded bytes + pos: int of the current position to start from + + Returns: + Tuple[bytes, int] of the tag data and new position. + """ start = pos while six.indexbytes(buffer, pos) & 0x80: pos += 1 pos += 1 - return (six.binary_type(buffer[start:pos]), pos) + + tag_bytes = buffer[start:pos].tobytes() + return tag_bytes, pos # -------------------------------------------------------------------- @@ -295,10 +307,20 @@ def _FloatDecoder(): local_unpack = struct.unpack def InnerDecode(buffer, pos): + """Decode serialized float to a float and new position. + + Args: + buffer: memoryview of the serialized bytes + pos: int, position in the memory view to start at. + + Returns: + Tuple[float, int] of the deserialized float value and new position + in the serialized data. + """ # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. new_pos = pos + 4 - float_bytes = buffer[pos:new_pos] + float_bytes = buffer[pos:new_pos].tobytes() # If this value has all its exponent bits set, then it's non-finite. # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. @@ -329,10 +351,20 @@ def _DoubleDecoder(): local_unpack = struct.unpack def InnerDecode(buffer, pos): + """Decode serialized double to a double and new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + + Returns: + Tuple[float, int] of the decoded double value and new position + in the serialized data. + """ # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. new_pos = pos + 8 - double_bytes = buffer[pos:new_pos] + double_bytes = buffer[pos:new_pos].tobytes() # If this value has all its exponent bits set and at least one significand # bit set, it's not a number. In Python 2.4, struct.unpack will treat it @@ -355,6 +387,18 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): if is_packed: local_DecodeVarint = _DecodeVarint def DecodePackedField(buffer, pos, end, message, field_dict): + """Decode serialized packed enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -365,6 +409,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): while pos < endpoint: value_start_pos = pos (element, pos) = _DecodeSignedVarint32(buffer, pos) + # pylint: disable=protected-access if element in enum_type.values_by_number: value.append(element) else: @@ -372,8 +417,10 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): message._unknown_fields = [] tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( - (tag_bytes, buffer[value_start_pos:pos])) + (tag_bytes, buffer[value_start_pos:pos].tobytes())) + # pylint: enable=protected-access if pos > endpoint: if element in enum_type.values_by_number: del value[-1] # Discard corrupt value. @@ -386,18 +433,32 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) tag_len = len(tag_bytes) def DecodeRepeatedField(buffer, pos, end, message, field_dict): + """Decode serialized repeated enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) while 1: (element, new_pos) = _DecodeSignedVarint32(buffer, pos) + # pylint: disable=protected-access if element in enum_type.values_by_number: value.append(element) else: if not message._unknown_fields: message._unknown_fields = [] message._unknown_fields.append( - (tag_bytes, buffer[pos:new_pos])) + (tag_bytes, buffer[pos:new_pos].tobytes())) + # pylint: enable=protected-access # Predict that the next tag is another copy of the same repeated # field. pos = new_pos + tag_len @@ -409,10 +470,23 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): return DecodeRepeatedField else: def DecodeField(buffer, pos, end, message, field_dict): + """Decode serialized repeated enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ value_start_pos = pos (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) if pos > end: raise _DecodeError('Truncated message.') + # pylint: disable=protected-access if enum_value in enum_type.values_by_number: field_dict[key] = enum_value else: @@ -421,7 +495,8 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) message._unknown_fields.append( - (tag_bytes, buffer[value_start_pos:pos])) + (tag_bytes, buffer[value_start_pos:pos].tobytes())) + # pylint: enable=protected-access return pos return DecodeField @@ -458,20 +533,33 @@ BoolDecoder = _ModifiedDecoder( wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) -def StringDecoder(field_number, is_repeated, is_packed, key, new_default): +def StringDecoder(field_number, is_repeated, is_packed, key, new_default, + is_strict_utf8=False): """Returns a decoder for a string field.""" local_DecodeVarint = _DecodeVarint local_unicode = six.text_type - def _ConvertToUnicode(byte_str): + def _ConvertToUnicode(memview): + """Convert byte to unicode.""" + byte_str = memview.tobytes() try: - return local_unicode(byte_str, 'utf-8') + value = local_unicode(byte_str, 'utf-8') except UnicodeDecodeError as e: # add more information to the error message and re-raise it. e.reason = '%s in field: %s' % (e, key.full_name) raise + if is_strict_utf8 and six.PY2: + if _SURROGATE_PATTERN.search(value): + reason = ('String field %s contains invalid UTF-8 data when parsing' + 'a protocol buffer: surrogates not allowed. Use' + 'the bytes type if you intend to send raw bytes.') % ( + key.full_name) + raise message.DecodeError(reason) + + return value + assert not is_packed if is_repeated: tag_bytes = encoder.TagBytes(field_number, @@ -523,7 +611,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - value.append(buffer[pos:new_pos]) + value.append(buffer[pos:new_pos].tobytes()) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: @@ -536,7 +624,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - field_dict[key] = buffer[pos:new_pos] + field_dict[key] = buffer[pos:new_pos].tobytes() return new_pos return DecodeField @@ -665,6 +753,18 @@ def MessageSetItemDecoder(descriptor): local_SkipField = SkipField def DecodeItem(buffer, pos, end, message, field_dict): + """Decode serialized message set to its value and new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ message_set_item_start = pos type_id = -1 message_start = -1 @@ -695,6 +795,7 @@ def MessageSetItemDecoder(descriptor): raise _DecodeError('MessageSet item missing message.') extension = message.Extensions._FindExtensionByNumber(type_id) + # pylint: disable=protected-access if extension is not None: value = field_dict.get(extension) if value is None: @@ -707,8 +808,9 @@ def MessageSetItemDecoder(descriptor): else: if not message._unknown_fields: message._unknown_fields = [] - message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, - buffer[message_set_item_start:pos])) + message._unknown_fields.append( + (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes())) + # pylint: enable=protected-access return pos @@ -767,7 +869,7 @@ def _SkipVarint(buffer, pos, end): # Previously ord(buffer[pos]) raised IndexError when pos is out of range. # With this code, ord(b'') raises TypeError. Both are handled in # python_message.py to generate a 'Truncated message' error. - while ord(buffer[pos:pos+1]) & 0x80: + while ord(buffer[pos:pos+1].tobytes()) & 0x80: pos += 1 pos += 1 if pos > end: @@ -782,6 +884,13 @@ def _SkipFixed64(buffer, pos, end): raise _DecodeError('Truncated message.') return pos + +def _DecodeFixed64(buffer, pos): + """Decode a fixed64.""" + new_pos = pos + 8 + return (struct.unpack('".', + '5:1 : \'}\': Expected ">".', text_format.Parse, malformed, message, @@ -981,7 +1167,8 @@ class Proto2Tests(TextFormatBase): with self.assertRaises(text_format.ParseError) as e: text_format.Parse(text, message) self.assertEqual(str(e.exception), - '1:27 : Expected identifier or number, got "bb".') + '1:27 : \'optional_nested_message { "bb": 1 }\': ' + 'Expected identifier or number, got "bb".') def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() @@ -998,7 +1185,8 @@ class Proto2Tests(TextFormatBase): message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: 100' six.assertRaisesRegex(self, text_format.ParseError, - (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + (r'1:23 : \'optional_nested_enum: 100\': ' + r'Enum type "\w+.TestAllTypes.NestedEnum" ' r'has no value with number 100.'), text_format.Parse, text, message) @@ -1448,6 +1636,26 @@ class TokenizerTest(unittest.TestCase): self.assertEqual(0, text_format._ConsumeUint64(tokenizer)) self.assertTrue(tokenizer.AtEnd()) + def testConsumeOctalIntegers(self): + """Test support for C style octal integers.""" + text = '00 -00 04 0755 -010 007 -0033 08 -09 01' + tokenizer = text_format.Tokenizer(text.splitlines()) + self.assertEqual(0, tokenizer.ConsumeInteger()) + self.assertEqual(0, tokenizer.ConsumeInteger()) + self.assertEqual(4, tokenizer.ConsumeInteger()) + self.assertEqual(0o755, tokenizer.ConsumeInteger()) + self.assertEqual(-0o10, tokenizer.ConsumeInteger()) + self.assertEqual(7, tokenizer.ConsumeInteger()) + self.assertEqual(-0o033, tokenizer.ConsumeInteger()) + with self.assertRaises(text_format.ParseError): + tokenizer.ConsumeInteger() # 08 + tokenizer.NextToken() + with self.assertRaises(text_format.ParseError): + tokenizer.ConsumeInteger() # -09 + tokenizer.NextToken() + self.assertEqual(1, tokenizer.ConsumeInteger()) + self.assertTrue(tokenizer.AtEnd()) + def testConsumeByteString(self): text = '"string1\'' tokenizer = text_format.Tokenizer(text.splitlines()) @@ -1556,6 +1764,12 @@ class TokenizerTest(unittest.TestCase): tokenizer.ConsumeCommentOrTrailingComment()) self.assertTrue(tokenizer.AtEnd()) + def testHugeString(self): + # With pathologic backtracking, fails with Forge OOM. + text = '"' + 'a' * (10 * 1024 * 1024) + '"' + tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False) + tokenizer.ConsumeString() + # Tests for pretty printer functionality. @_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2)) diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 4a76cd4e..0807e7f7 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -185,6 +185,14 @@ class UnicodeValueChecker(object): 'encoding. Non-UTF-8 strings must be converted to ' 'unicode objects before being added.' % (proposed_value)) + else: + try: + proposed_value.encode('utf8') + except UnicodeEncodeError: + raise ValueError('%.1024r isn\'t a valid unicode string and ' + 'can\'t be encoded in UTF-8.'% + (proposed_value)) + return proposed_value def DefaultValue(self): diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 8b7de2e7..fceadf71 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -49,20 +49,12 @@ from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import testing_refleaks from google.protobuf.internal import type_checkers +from google.protobuf import descriptor BaseTestCase = testing_refleaks.BaseTestCase -# CheckUnknownField() cannot be used by the C++ implementation because -# some protect members are called. It is not a behavior difference -# for python and C++ implementation. -def SkipCheckUnknownFieldIfCppImplementation(func): - return unittest.skipIf( - api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'Addtional test for pure python involved protect members')(func) - - class UnknownFieldsTest(BaseTestCase): def setUp(self): @@ -80,23 +72,11 @@ class UnknownFieldsTest(BaseTestCase): # stdout. self.assertTrue(data == self.all_fields_data) - def expectSerializeProto3(self, preserve): + def testSerializeProto3(self): + # Verify proto3 unknown fields behavior. message = unittest_proto3_arena_pb2.TestEmptyMessage() message.ParseFromString(self.all_fields_data) - if preserve: - self.assertEqual(self.all_fields_data, message.SerializeToString()) - else: - self.assertEqual(0, len(message.SerializeToString())) - - def testSerializeProto3(self): - # Verify that proto3 unknown fields behavior. - default_preserve = (api_implementation - .GetPythonProto3PreserveUnknownsDefault()) - self.expectSerializeProto3(default_preserve) - api_implementation.SetPythonProto3PreserveUnknownsDefault( - not default_preserve) - self.expectSerializeProto3(not default_preserve) - api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve) + self.assertEqual(self.all_fields_data, message.SerializeToString()) def testByteSize(self): self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) @@ -169,13 +149,15 @@ class UnknownFieldsAccessorsTest(BaseTestCase): self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - # CheckUnknownField() is an additional Pure Python check which checks + # InternalCheckUnknownField() is an additional Pure Python check which checks # a detail of unknown fields. It cannot be used by the C++ # implementation because some protect members are called. # The test is added for historical reasons. It is not necessary as # serialized string is checked. - - def CheckUnknownField(self, name, expected_value): + # TODO(jieluo): Remove message._unknown_fields. + def InternalCheckUnknownField(self, name, expected_value): + if api_implementation.Type() == 'cpp': + return field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) @@ -183,36 +165,80 @@ class UnknownFieldsAccessorsTest(BaseTestCase): for tag_bytes, value in self.empty_message._unknown_fields: if tag_bytes == field_tag: decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] - decoder(value, 0, len(value), self.all_fields, result_dict) + decoder(memoryview(value), 0, len(value), self.all_fields, result_dict) self.assertEqual(expected_value, result_dict[field_descriptor]) - @SkipCheckUnknownFieldIfCppImplementation + def CheckUnknownField(self, name, unknown_fields, expected_value): + field_descriptor = self.descriptor.fields_by_name[name] + expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[ + field_descriptor.type] + for unknown_field in unknown_fields: + if unknown_field.field_number == field_descriptor.number: + self.assertEqual(expected_type, unknown_field.wire_type) + if expected_type == 3: + # Check group + self.assertEqual(expected_value[0], + unknown_field.data[0].field_number) + self.assertEqual(expected_value[1], unknown_field.data[0].wire_type) + self.assertEqual(expected_value[2], unknown_field.data[0].data) + continue + if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED: + self.assertIn(unknown_field.data, expected_value) + else: + self.assertEqual(expected_value, unknown_field.data) + def testCheckUnknownFieldValue(self): + unknown_fields = self.empty_message.UnknownFields() # Test enum. self.CheckUnknownField('optional_nested_enum', + unknown_fields, self.all_fields.optional_nested_enum) + self.InternalCheckUnknownField('optional_nested_enum', + self.all_fields.optional_nested_enum) + # Test repeated enum. self.CheckUnknownField('repeated_nested_enum', + unknown_fields, self.all_fields.repeated_nested_enum) + self.InternalCheckUnknownField('repeated_nested_enum', + self.all_fields.repeated_nested_enum) # Test varint. self.CheckUnknownField('optional_int32', + unknown_fields, self.all_fields.optional_int32) + self.InternalCheckUnknownField('optional_int32', + self.all_fields.optional_int32) + # Test fixed32. self.CheckUnknownField('optional_fixed32', + unknown_fields, self.all_fields.optional_fixed32) + self.InternalCheckUnknownField('optional_fixed32', + self.all_fields.optional_fixed32) # Test fixed64. self.CheckUnknownField('optional_fixed64', + unknown_fields, self.all_fields.optional_fixed64) + self.InternalCheckUnknownField('optional_fixed64', + self.all_fields.optional_fixed64) # Test lengthd elimited. self.CheckUnknownField('optional_string', - self.all_fields.optional_string) + unknown_fields, + self.all_fields.optional_string.encode('utf-8')) + self.InternalCheckUnknownField('optional_string', + self.all_fields.optional_string) # Test group. self.CheckUnknownField('optionalgroup', - self.all_fields.optionalgroup) + unknown_fields, + (17, 0, 117)) + self.InternalCheckUnknownField('optionalgroup', + self.all_fields.optionalgroup) + + self.assertEqual(97, len(unknown_fields)) def testCopyFrom(self): message = unittest_pb2.TestEmptyMessage() @@ -230,9 +256,18 @@ class UnknownFieldsAccessorsTest(BaseTestCase): message.optional_int64 = 3 message.optional_uint32 = 4 destination = unittest_pb2.TestEmptyMessage() + unknown_fields = destination.UnknownFields() + self.assertEqual(0, len(unknown_fields)) destination.ParseFromString(message.SerializeToString()) - + # ParseFromString clears the message thus unknown fields is invalid. + with self.assertRaises(ValueError) as context: + len(unknown_fields) + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + unknown_fields = destination.UnknownFields() + self.assertEqual(2, len(unknown_fields)) destination.MergeFrom(source) + self.assertEqual(4, len(unknown_fields)) # Check that the fields where correctly merged, even stored in the unknown # fields set. message.ParseFromString(destination.SerializeToString()) @@ -241,9 +276,58 @@ class UnknownFieldsAccessorsTest(BaseTestCase): self.assertEqual(message.optional_int64, 3) def testClear(self): + unknown_fields = self.empty_message.UnknownFields() self.empty_message.Clear() # All cleared, even unknown fields. self.assertEqual(self.empty_message.SerializeToString(), b'') + with self.assertRaises(ValueError) as context: + len(unknown_fields) + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + + def testSubUnknownFields(self): + message = unittest_pb2.TestAllTypes() + message.optionalgroup.a = 123 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + sub_unknown_fields = destination.UnknownFields()[0].data + self.assertEqual(1, len(sub_unknown_fields)) + self.assertEqual(sub_unknown_fields[0].data, 123) + destination.Clear() + with self.assertRaises(ValueError) as context: + len(sub_unknown_fields) + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + with self.assertRaises(ValueError) as context: + # pylint: disable=pointless-statement + sub_unknown_fields[0] + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + message.Clear() + message.optional_uint32 = 456 + nested_message = unittest_pb2.NestedTestAllTypes() + nested_message.payload.optional_nested_message.ParseFromString( + message.SerializeToString()) + unknown_fields = ( + nested_message.payload.optional_nested_message.UnknownFields()) + self.assertEqual(unknown_fields[0].data, 456) + nested_message.ClearField('payload') + self.assertEqual(unknown_fields[0].data, 456) + unknown_fields = ( + nested_message.payload.optional_nested_message.UnknownFields()) + self.assertEqual(0, len(unknown_fields)) + + def testUnknownField(self): + message = unittest_pb2.TestAllTypes() + message.optional_int32 = 123 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + unknown_field = destination.UnknownFields()[0] + destination.Clear() + with self.assertRaises(ValueError) as context: + unknown_field.data # pylint: disable=pointless-statement + self.assertIn('The parent message might be cleared.', + str(context.exception)) def testUnknownExtensions(self): message = unittest_pb2.TestEmptyMessageWithExtensions() @@ -280,15 +364,13 @@ class UnknownEnumValuesTest(BaseTestCase): def CheckUnknownField(self, name, expected_value): field_descriptor = self.descriptor.fields_by_name[name] - wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] - field_tag = encoder.TagBytes(field_descriptor.number, wire_type) - result_dict = {} - for tag_bytes, value in self.missing_message._unknown_fields: - if tag_bytes == field_tag: - decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ - tag_bytes][0] - decoder(value, 0, len(value), self.message, result_dict) - self.assertEqual(expected_value, result_dict[field_descriptor]) + unknown_fields = self.missing_message.UnknownFields() + for field in unknown_fields: + if field.field_number == field_descriptor.number: + if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED: + self.assertIn(field.data, expected_value) + else: + self.assertEqual(expected_value, field.data) def testUnknownParseMismatchEnumValue(self): just_string = missing_enum_values_pb2.JustString() @@ -317,7 +399,6 @@ class UnknownEnumValuesTest(BaseTestCase): def testUnknownPackedEnumValue(self): self.assertEqual([], self.missing_message.packed_nested_enum) - @SkipCheckUnknownFieldIfCppImplementation def testCheckUnknownFieldValueForEnum(self): self.CheckUnknownField('optional_nested_enum', self.message.optional_nested_enum) diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py index 58c94a47..ce1db7d7 100644 --- a/python/google/protobuf/json_format.py +++ b/python/google/protobuf/json_format.py @@ -482,7 +482,7 @@ class _Parser(object): ('Message type "{0}" has no field named "{1}".\n' ' Available Fields(except extensions): {2}').format( message_descriptor.full_name, name, - message_descriptor.fields)) + [f.json_name for f in message_descriptor.fields])) if name in names: raise ParseError('Message type "{0}" should not have multiple ' '"{1}" fields.'.format( diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index eeb0d576..eca2e0a9 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -268,6 +268,10 @@ class Message(object): def ClearExtension(self, extension_handle): raise NotImplementedError + def UnknownFields(self): + """Returns the UnknownFieldSet.""" + raise NotImplementedError + def DiscardUnknownFields(self): raise NotImplementedError diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py index e4fb065e..f3ab0a55 100644 --- a/python/google/protobuf/message_factory.py +++ b/python/google/protobuf/message_factory.py @@ -39,9 +39,18 @@ my_proto_instance = message_classes['some.proto.package.MessageName']() __author__ = 'matthewtoia@google.com (Matt Toia)' +from google.protobuf.internal import api_implementation from google.protobuf import descriptor_pool from google.protobuf import message -from google.protobuf import reflection + +if api_implementation.Type() == 'cpp': + from google.protobuf.pyext import cpp_message as message_impl +else: + from google.protobuf.internal import python_message as message_impl + + +# The type of all Message classes. +_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType class MessageFactory(object): @@ -70,11 +79,11 @@ class MessageFactory(object): descriptor_name = descriptor.name if str is bytes: # PY2 descriptor_name = descriptor.name.encode('ascii', 'ignore') - result_class = reflection.GeneratedProtocolMessageType( + result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE( descriptor_name, (message.Message,), {'DESCRIPTOR': descriptor, '__module__': None}) - # If module not set, it wrongly points to the reflection.py module. + # If module not set, it wrongly points to message_factory module. self._classes[descriptor] = result_class for field in descriptor.fields: if field.message_type: diff --git a/python/google/protobuf/proto_api.h b/python/google/protobuf/proto_api.h index 5c076d23..47edf0ea 100644 --- a/python/google/protobuf/proto_api.h +++ b/python/google/protobuf/proto_api.h @@ -42,16 +42,15 @@ // Then use the methods of the returned class: // py_proto_api->GetMessagePointer(...); -#ifndef PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__ -#define PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__ +#ifndef GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__ +#define GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__ #include +#include + namespace google { namespace protobuf { - -class Message; - namespace python { // Note on the implementation: @@ -89,4 +88,4 @@ inline const char* PyProtoAPICapsuleName() { } // namespace protobuf } // namespace google -#endif // PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__ +#endif // GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__ diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 19a1c38a..3cb16b74 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -32,8 +32,8 @@ #include #include -#include #include +#include #include #include @@ -44,6 +44,7 @@ #include #include #include +#include #if PY_MAJOR_VERSION >= 3 #define PyString_FromStringAndSize PyUnicode_FromStringAndSize @@ -54,10 +55,12 @@ #if PY_VERSION_HEX < 0x03030000 #error "Python 3.0 - 3.2 are not supported." #endif - #define PyString_AsStringAndSize(ob, charpp, sizep) \ - (PyUnicode_Check(ob)? \ - ((*(charpp) = const_cast(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \ - PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) #endif namespace google { @@ -70,7 +73,7 @@ namespace python { // released. // This is enough to support the "is" operator on live objects. // All descriptors are stored here. -hash_map interned_descriptors; +std::unordered_map* interned_descriptors; PyObject* PyString_FromCppString(const string& str) { return PyString_FromStringAndSize(str.c_str(), str.size()); @@ -119,8 +122,10 @@ bool _CalledFromGeneratedFile(int stacklevel) { PyErr_Clear(); return false; } - if ((filename_size < 3) || (strcmp(&filename[filename_size - 3], ".py") != 0)) { - // Cython's stack does not have .py file name and is not at global module scope. + if ((filename_size < 3) || + (strcmp(&filename[filename_size - 3], ".py") != 0)) { + // Cython's stack does not have .py file name and is not at global module + // scope. return true; } if (filename_size < 7) { @@ -131,7 +136,7 @@ bool _CalledFromGeneratedFile(int stacklevel) { // Filename is not ending with _pb2. return false; } - + if (frame->f_globals != frame->f_locals) { // Not at global module scope return false; @@ -197,7 +202,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { // First search in the cache. PyDescriptorPool* caching_pool = GetDescriptorPool_FromPool( GetFileDescriptor(descriptor)->pool()); - hash_map* descriptor_options = + std::unordered_map* descriptor_options = caching_pool->descriptor_options; if (descriptor_options->find(descriptor) != descriptor_options->end()) { PyObject *value = (*descriptor_options)[descriptor]; @@ -232,7 +237,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { if (value == NULL) { return NULL; } - if (!PyObject_TypeCheck(value.get(), &CMessage_Type)) { + if (!PyObject_TypeCheck(value.get(), CMessage_Type)) { PyErr_Format(PyExc_TypeError, "Invalid class for %s: %s", message_type->full_name().c_str(), Py_TYPE(value.get())->tp_name); @@ -275,7 +280,7 @@ static PyObject* CopyToPythonProto(const DescriptorClass *descriptor, const Descriptor* self_descriptor = DescriptorProtoClass::default_instance().GetDescriptor(); CMessage* message = reinterpret_cast(target); - if (!PyObject_TypeCheck(target, &CMessage_Type) || + if (!PyObject_TypeCheck(target, CMessage_Type) || message->message->GetDescriptor() != self_descriptor) { PyErr_Format(PyExc_TypeError, "Not a %s message", self_descriptor->full_name().c_str()); @@ -332,9 +337,9 @@ PyObject* NewInternedDescriptor(PyTypeObject* type, } // See if the object is in the map of interned descriptors - hash_map::iterator it = - interned_descriptors.find(descriptor); - if (it != interned_descriptors.end()) { + std::unordered_map::iterator it = + interned_descriptors->find(descriptor); + if (it != interned_descriptors->end()) { GOOGLE_DCHECK(Py_TYPE(it->second) == type); Py_INCREF(it->second); return it->second; @@ -348,7 +353,7 @@ PyObject* NewInternedDescriptor(PyTypeObject* type, py_descriptor->descriptor = descriptor; // and cache it. - interned_descriptors.insert( + interned_descriptors->insert( std::make_pair(descriptor, reinterpret_cast(py_descriptor))); // Ensures that the DescriptorPool stays alive. @@ -370,7 +375,7 @@ PyObject* NewInternedDescriptor(PyTypeObject* type, static void Dealloc(PyBaseDescriptor* self) { // Remove from interned dictionary - interned_descriptors.erase(self->descriptor); + interned_descriptors->erase(self->descriptor); Py_CLEAR(self->pool); Py_TYPE(self)->tp_free(reinterpret_cast(self)); } @@ -758,6 +763,11 @@ static PyObject* HasDefaultValue(PyBaseDescriptor *self, void *closure) { static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) { PyObject *result; + if (_GetDescriptor(self)->is_repeated()) { + return PyList_New(0); + } + + switch (_GetDescriptor(self)->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: { int32 value = _GetDescriptor(self)->default_value_int32(); @@ -805,6 +815,10 @@ static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) { result = PyInt_FromLong(value->number()); break; } + case FieldDescriptor::CPPTYPE_MESSAGE: { + Py_RETURN_NONE; + break; + } default: PyErr_Format(PyExc_NotImplementedError, "default value for %s", _GetDescriptor(self)->full_name().c_str()); @@ -1919,6 +1933,9 @@ bool InitDescriptor() { if (!InitDescriptorMappingTypes()) return false; + // Initialize globals defined in this file. + interned_descriptors = new std::unordered_map; + return true; } diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h index f081df84..c4dde9e7 100644 --- a/python/google/protobuf/pyext/descriptor.h +++ b/python/google/protobuf/pyext/descriptor.h @@ -100,6 +100,6 @@ bool InitDescriptor(); } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc index 0153664f..d5b5dc68 100644 --- a/python/google/protobuf/pyext/descriptor_containers.cc +++ b/python/google/protobuf/pyext/descriptor_containers.cc @@ -33,7 +33,7 @@ // // They avoid the allocation of a full dictionary or a full list: they simply // store a pointer to the parent descriptor, use the C++ Descriptor methods (see -// google/protobuf/descriptor.h) to retrieve other descriptors, and create +// net/proto2/public/descriptor.h) to retrieve other descriptors, and create // Python objects on the fly. // // The containers fully conform to abc.Mapping and abc.Sequence, and behave just @@ -64,10 +64,12 @@ #if PY_VERSION_HEX < 0x03030000 #error "Python 3.0 - 3.2 are not supported." #endif - #define PyString_AsStringAndSize(ob, charpp, sizep) \ - (PyUnicode_Check(ob)? \ - ((*(charpp) = const_cast(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \ - PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) #endif namespace google { diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h index 83de07b6..4e05c58e 100644 --- a/python/google/protobuf/pyext/descriptor_containers.h +++ b/python/google/protobuf/pyext/descriptor_containers.h @@ -104,6 +104,6 @@ PyObject* NewServiceMethodsByName(const ServiceDescriptor* descriptor); } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__ diff --git a/python/google/protobuf/pyext/descriptor_database.cc b/python/google/protobuf/pyext/descriptor_database.cc index daa40cc7..0514b35c 100644 --- a/python/google/protobuf/pyext/descriptor_database.cc +++ b/python/google/protobuf/pyext/descriptor_database.cc @@ -70,7 +70,7 @@ static bool GetFileDescriptorProto(PyObject* py_descriptor, const Descriptor* filedescriptor_descriptor = FileDescriptorProto::default_instance().GetDescriptor(); CMessage* message = reinterpret_cast(py_descriptor); - if (PyObject_TypeCheck(py_descriptor, &CMessage_Type) && + if (PyObject_TypeCheck(py_descriptor, CMessage_Type) && message->message->GetDescriptor() == filedescriptor_descriptor) { // Fast path: Just use the pointer. FileDescriptorProto* file_proto = @@ -143,6 +143,43 @@ bool PyDescriptorDatabase::FindFileContainingExtension( return GetFileDescriptorProto(py_descriptor.get(), output); } +// Finds the tag numbers used by all known extensions of +// containing_type, and appends them to output in an undefined +// order. +// Python DescriptorDatabases are not required to implement this method. +bool PyDescriptorDatabase::FindAllExtensionNumbers( + const string& containing_type, std::vector* output) { + ScopedPyObjectPtr py_method( + PyObject_GetAttrString(py_database_, "FindAllExtensionNumbers")); + if (py_method == NULL) { + // This method is not implemented, returns without error. + PyErr_Clear(); + return false; + } + ScopedPyObjectPtr py_list( + PyObject_CallFunction(py_method.get(), "s#", containing_type.c_str(), + containing_type.size())); + if (py_list == NULL) { + PyErr_Print(); + return false; + } + Py_ssize_t size = PyList_Size(py_list.get()); + int64 item_value; + for (Py_ssize_t i = 0 ; i < size; ++i) { + ScopedPyObjectPtr item(PySequence_GetItem(py_list.get(), i)); + item_value = PyLong_AsLong(item.get()); + if (item_value < 0) { + GOOGLE_LOG(ERROR) + << "FindAllExtensionNumbers method did not return " + << "valid extension numbers."; + PyErr_Print(); + return false; + } + output->push_back(item_value); + } + return true; +} + } // namespace python } // namespace protobuf } // namespace google diff --git a/python/google/protobuf/pyext/descriptor_database.h b/python/google/protobuf/pyext/descriptor_database.h index fc71c4bc..daf25e0b 100644 --- a/python/google/protobuf/pyext/descriptor_database.h +++ b/python/google/protobuf/pyext/descriptor_database.h @@ -63,6 +63,13 @@ class PyDescriptorDatabase : public DescriptorDatabase { int field_number, FileDescriptorProto* output); + // Finds the tag numbers used by all known extensions of + // containing_type, and appends them to output in an undefined + // order. + // Python objects are not required to implement this method. + bool FindAllExtensionNumbers(const string& containing_type, + std::vector* output); + private: // The python object that implements the database. The reference is owned. PyObject* py_database_; @@ -70,6 +77,6 @@ class PyDescriptorDatabase : public DescriptorDatabase { } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__ diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc index 962accc6..d0038b10 100644 --- a/python/google/protobuf/pyext/descriptor_pool.cc +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -30,6 +30,8 @@ // Implements the DescriptorPool, which collects all descriptors. +#include + #include #include @@ -46,10 +48,12 @@ #if PY_VERSION_HEX < 0x03030000 #error "Python 3.0 - 3.2 are not supported." #endif - #define PyString_AsStringAndSize(ob, charpp, sizep) \ - (PyUnicode_Check(ob)? \ - ((*(charpp) = const_cast(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \ - PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) #endif namespace google { @@ -58,7 +62,8 @@ namespace python { // A map to cache Python Pools per C++ pointer. // Pointers are not owned here, and belong to the PyDescriptorPool. -static hash_map descriptor_pool_map; +static std::unordered_map* + descriptor_pool_map; namespace cdescriptor_pool { @@ -74,8 +79,7 @@ static PyDescriptorPool* _CreateDescriptorPool() { cpool->underlay = NULL; cpool->database = NULL; - cpool->descriptor_options = - new hash_map(); + cpool->descriptor_options = new std::unordered_map(); cpool->py_message_factory = message_factory::NewMessageFactory( &PyMessageFactory_Type, cpool); @@ -101,7 +105,7 @@ static PyDescriptorPool* PyDescriptorPool_NewWithUnderlay( cpool->pool = new DescriptorPool(underlay); cpool->underlay = underlay; - if (!descriptor_pool_map.insert( + if (!descriptor_pool_map->insert( std::make_pair(cpool->pool, cpool)).second) { // Should never happen -- would indicate an internal error / bug. PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered"); @@ -124,7 +128,7 @@ static PyDescriptorPool* PyDescriptorPool_NewWithDatabase( cpool->pool = new DescriptorPool(); } - if (!descriptor_pool_map.insert(std::make_pair(cpool->pool, cpool)).second) { + if (!descriptor_pool_map->insert(std::make_pair(cpool->pool, cpool)).second) { // Should never happen -- would indicate an internal error / bug. PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered"); return NULL; @@ -151,9 +155,9 @@ static PyObject* New(PyTypeObject* type, static void Dealloc(PyObject* pself) { PyDescriptorPool* self = reinterpret_cast(pself); - descriptor_pool_map.erase(self->pool); + descriptor_pool_map->erase(self->pool); Py_CLEAR(self->py_message_factory); - for (hash_map::iterator it = + for (std::unordered_map::iterator it = self->descriptor_options->begin(); it != self->descriptor_options->end(); ++it) { Py_DECREF(it->second); @@ -180,6 +184,7 @@ static PyObject* FindMessageByName(PyObject* self, PyObject* arg) { return NULL; } + return PyMessageDescriptor_FromDescriptor(message_descriptor); } @@ -218,6 +223,7 @@ PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) { return NULL; } + return PyFieldDescriptor_FromDescriptor(field_descriptor); } @@ -239,6 +245,7 @@ PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) { return NULL; } + return PyFieldDescriptor_FromDescriptor(field_descriptor); } @@ -260,6 +267,7 @@ PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) { return NULL; } + return PyEnumDescriptor_FromDescriptor(enum_descriptor); } @@ -281,6 +289,7 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) { return NULL; } + return PyOneofDescriptor_FromDescriptor(oneof_descriptor); } @@ -303,6 +312,7 @@ static PyObject* FindServiceByName(PyObject* self, PyObject* arg) { return NULL; } + return PyServiceDescriptor_FromDescriptor(service_descriptor); } @@ -321,6 +331,7 @@ static PyObject* FindMethodByName(PyObject* self, PyObject* arg) { return NULL; } + return PyMethodDescriptor_FromDescriptor(method_descriptor); } @@ -339,6 +350,7 @@ static PyObject* FindFileContainingSymbol(PyObject* self, PyObject* arg) { return NULL; } + return PyFileDescriptor_FromDescriptor(file_descriptor); } @@ -362,6 +374,7 @@ static PyObject* FindExtensionByNumber(PyObject* self, PyObject* args) { return NULL; } + return PyFieldDescriptor_FromDescriptor(extension_descriptor); } @@ -668,13 +681,17 @@ bool InitDescriptorPool() { // The Pool of messages declared in Python libraries. // generated_pool() contains all messages already linked in C++ libraries, and // is used as underlay. + descriptor_pool_map = + new std::unordered_map; python_generated_pool = cdescriptor_pool::PyDescriptorPool_NewWithUnderlay( DescriptorPool::generated_pool()); if (python_generated_pool == NULL) { + delete descriptor_pool_map; return false; } + // Register this pool to be found for C++-generated descriptors. - descriptor_pool_map.insert( + descriptor_pool_map->insert( std::make_pair(DescriptorPool::generated_pool(), python_generated_pool)); @@ -695,9 +712,9 @@ PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool) { pool == DescriptorPool::generated_pool()) { return python_generated_pool; } - hash_map::iterator it = - descriptor_pool_map.find(pool); - if (it == descriptor_pool_map.end()) { + std::unordered_map::iterator it = + descriptor_pool_map->find(pool); + if (it == descriptor_pool_map->end()) { PyErr_SetString(PyExc_KeyError, "Unknown descriptor pool"); return NULL; } diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h index 53ee53dc..8289daea 100644 --- a/python/google/protobuf/pyext/descriptor_pool.h +++ b/python/google/protobuf/pyext/descriptor_pool.h @@ -33,7 +33,7 @@ #include -#include +#include #include namespace google { @@ -77,7 +77,7 @@ typedef struct PyDescriptorPool { // Cache the options for any kind of descriptor. // Descriptor pointers are owned by the DescriptorPool above. // Python objects are owned by the map. - hash_map* descriptor_options; + std::unordered_map* descriptor_options; } PyDescriptorPool; @@ -140,6 +140,6 @@ bool InitDescriptorPool(); } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index 174c5470..b73368eb 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -51,10 +51,12 @@ #if PY_VERSION_HEX < 0x03030000 #error "Python 3.0 - 3.2 are not supported." #endif - #define PyString_AsStringAndSize(ob, charpp, sizep) \ - (PyUnicode_Check(ob)? \ - ((*(charpp) = const_cast(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \ - PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) #endif namespace google { @@ -63,40 +65,25 @@ namespace python { namespace extension_dict { -PyObject* len(ExtensionDict* self) { -#if PY_MAJOR_VERSION >= 3 - return PyLong_FromLong(PyDict_Size(self->values)); -#else - return PyInt_FromLong(PyDict_Size(self->values)); -#endif -} - PyObject* subscript(ExtensionDict* self, PyObject* key) { const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); if (descriptor == NULL) { return NULL; } - if (!CheckFieldBelongsToMessage(descriptor, self->message)) { + if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { return NULL; } if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { - return cmessage::InternalGetScalar(self->message, descriptor); + return cmessage::InternalGetScalar(self->parent->message, descriptor); } - PyObject* value = PyDict_GetItem(self->values, key); - if (value != NULL) { - Py_INCREF(value); - return value; - } - - if (self->parent == NULL) { - // We are in "detached" state. Don't allow further modifications. - // TODO(amauryfa): Support adding non-scalars to a detached extension dict. - // This probably requires to store the type of the main message. - PyErr_SetObject(PyExc_KeyError, key); - return NULL; + CMessage::CompositeFieldsMap::iterator iterator = + self->parent->composite_fields->find(descriptor); + if (iterator != self->parent->composite_fields->end()) { + Py_INCREF(iterator->second); + return iterator->second; } if (descriptor->label() != FieldDescriptor::LABEL_REPEATED && @@ -107,7 +94,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (sub_message == NULL) { return NULL; } - PyDict_SetItem(self->values, key, sub_message); + Py_INCREF(sub_message); + (*self->parent->composite_fields)[descriptor] = sub_message; return sub_message; } @@ -136,7 +124,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (py_container == NULL) { return NULL; } - PyDict_SetItem(self->values, key, py_container); + Py_INCREF(py_container); + (*self->parent->composite_fields)[descriptor] = py_container; return py_container; } else { PyObject* py_container = repeated_scalar_container::NewContainer( @@ -144,7 +133,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { if (py_container == NULL) { return NULL; } - PyDict_SetItem(self->values, key, py_container); + Py_INCREF(py_container); + (*self->parent->composite_fields)[descriptor] = py_container; return py_container; } } @@ -157,7 +147,7 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { if (descriptor == NULL) { return -1; } - if (!CheckFieldBelongsToMessage(descriptor, self->message)) { + if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) { return -1; } @@ -167,14 +157,10 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { "type"); return -1; } - if (self->parent) { - cmessage::AssureWritable(self->parent); - if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { - return -1; - } + cmessage::AssureWritable(self->parent); + if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) { + return -1; } - // TODO(tibell): We shouldn't write scalars to the cache. - PyDict_SetItem(self->values, key, value); return 0; } @@ -232,22 +218,36 @@ ExtensionDict* NewExtensionDict(CMessage *parent) { return NULL; } - self->parent = parent; // Store a borrowed reference. - self->message = parent->message; - self->owner = parent->owner; - self->values = PyDict_New(); + Py_INCREF(parent); + self->parent = parent; return self; } void dealloc(ExtensionDict* self) { - Py_CLEAR(self->values); - self->owner.reset(); + Py_CLEAR(self->parent); Py_TYPE(self)->tp_free(reinterpret_cast(self)); } +static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) { + // Only equality comparisons are implemented. + if (opid != Py_EQ && opid != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + bool equals = false; + if (PyObject_TypeCheck(other, &ExtensionDict_Type)) { + equals = self->parent == reinterpret_cast(other)->parent;; + } + if (equals ^ (opid == Py_EQ)) { + Py_RETURN_FALSE; + } else { + Py_RETURN_TRUE; + } +} + static PyMappingMethods MpMethods = { - (lenfunc)len, /* mp_length */ - (binaryfunc)subscript, /* mp_subscript */ + (lenfunc)NULL, /* mp_length */ + (binaryfunc)subscript, /* mp_subscript */ (objobjargproc)ass_subscript,/* mp_ass_subscript */ }; @@ -286,7 +286,7 @@ PyTypeObject ExtensionDict_Type = { "An extension dict", // tp_doc 0, // tp_traverse 0, // tp_clear - 0, // tp_richcompare + (richcmpfunc)extension_dict::RichCompare, // tp_richcompare 0, // tp_weaklistoffset 0, // tp_iter 0, // tp_iternext diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h index 0de2c4ee..d800d479 100644 --- a/python/google/protobuf/pyext/extension_dict.h +++ b/python/google/protobuf/pyext/extension_dict.h @@ -37,6 +37,7 @@ #include #include +#include #include @@ -51,23 +52,8 @@ namespace python { typedef struct ExtensionDict { PyObject_HEAD; - // This is the top-level C++ Message object that owns the whole - // proto tree. Every Python container class holds a - // reference to it in order to keep it alive as long as there's a - // Python object that references any part of the tree. - CMessage::OwnerRef owner; - - // Weak reference to parent message. Used to make sure - // the parent is writable when an extension field is modified. + // Strong, owned reference to the parent message. Never NULL. CMessage* parent; - - // Pointer to the C++ Message that this ExtensionDict extends. - // Not owned by us. - Message* message; - - // A dict of child messages, indexed by Extension descriptors. - // Similar to CMessage::composite_fields. - PyObject* values; } ExtensionDict; extern PyTypeObject ExtensionDict_Type; @@ -80,6 +66,6 @@ ExtensionDict* NewExtensionDict(CMessage *parent); } // namespace extension_dict } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__ diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc index 6d7ee285..3eec49c7 100644 --- a/python/google/protobuf/pyext/map_container.cc +++ b/python/google/protobuf/pyext/map_container.cc @@ -68,12 +68,14 @@ class MapReflectionFriend { static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key); static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v); static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v); + static PyObject* ScalarMapToStr(PyObject* _self); + static PyObject* MessageMapToStr(PyObject* _self); }; struct MapIterator { PyObject_HEAD; - std::unique_ptr<::google::protobuf::MapIterator> iter; + std::unique_ptr<::proto2::MapIterator> iter; // A pointer back to the container, so we can notice changes to the version. // We own a ref on this. @@ -199,26 +201,26 @@ static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor, // This is only used for ScalarMap, so we don't need to handle the // CPPTYPE_MESSAGE case. PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor, - MapValueRef* value) { + const MapValueRef& value) { switch (field_descriptor->cpp_type()) { case FieldDescriptor::CPPTYPE_INT32: - return PyInt_FromLong(value->GetInt32Value()); + return PyInt_FromLong(value.GetInt32Value()); case FieldDescriptor::CPPTYPE_INT64: - return PyLong_FromLongLong(value->GetInt64Value()); + return PyLong_FromLongLong(value.GetInt64Value()); case FieldDescriptor::CPPTYPE_UINT32: - return PyInt_FromSize_t(value->GetUInt32Value()); + return PyInt_FromSize_t(value.GetUInt32Value()); case FieldDescriptor::CPPTYPE_UINT64: - return PyLong_FromUnsignedLongLong(value->GetUInt64Value()); + return PyLong_FromUnsignedLongLong(value.GetUInt64Value()); case FieldDescriptor::CPPTYPE_FLOAT: - return PyFloat_FromDouble(value->GetFloatValue()); + return PyFloat_FromDouble(value.GetFloatValue()); case FieldDescriptor::CPPTYPE_DOUBLE: - return PyFloat_FromDouble(value->GetDoubleValue()); + return PyFloat_FromDouble(value.GetDoubleValue()); case FieldDescriptor::CPPTYPE_BOOL: - return PyBool_FromLong(value->GetBoolValue()); + return PyBool_FromLong(value.GetBoolValue()); case FieldDescriptor::CPPTYPE_STRING: - return ToStringObject(field_descriptor, value->GetStringValue()); + return ToStringObject(field_descriptor, value.GetStringValue()); case FieldDescriptor::CPPTYPE_ENUM: - return PyInt_FromLong(value->GetEnumValue()); + return PyInt_FromLong(value.GetEnumValue()); default: PyErr_Format( PyExc_SystemError, "Couldn't convert type %d to value", @@ -312,7 +314,7 @@ static MapContainer* GetMap(PyObject* obj) { Py_ssize_t MapReflectionFriend::Length(PyObject* _self) { MapContainer* self = GetMap(_self); - const google::protobuf::Message* message = self->message; + const proto2::Message* message = self->message; return message->GetReflection()->MapSize(*message, self->parent_field_descriptor); } @@ -421,7 +423,7 @@ int MapContainer::Release() { // ScalarMap /////////////////////////////////////////////////////////////////// PyObject *NewScalarMapContainer( - CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) { + CMessage* parent, const proto2::FieldDescriptor* parent_field_descriptor) { if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { return NULL; } @@ -472,7 +474,7 @@ PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self, self->version++; } - return MapValueRefToPython(self->value_field_descriptor, &value); + return MapValueRefToPython(self->value_field_descriptor, value); } int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key, @@ -535,10 +537,47 @@ static PyObject* ScalarMapGet(PyObject* self, PyObject* args) { } } +PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) { + ScopedPyObjectPtr dict(PyDict_New()); + if (dict == NULL) { + return NULL; + } + ScopedPyObjectPtr key; + ScopedPyObjectPtr value; + + MapContainer* self = GetMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + for (proto2::MapIterator it = reflection->MapBegin( + message, self->parent_field_descriptor); + it != reflection->MapEnd(message, self->parent_field_descriptor); + ++it) { + key.reset(MapKeyToPython(self->key_field_descriptor, + it.GetKey())); + if (key == NULL) { + return NULL; + } + value.reset(MapValueRefToPython(self->value_field_descriptor, + it.GetValueRef())); + if (value == NULL) { + return NULL; + } + if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) { + return NULL; + } + } + return PyObject_Repr(dict.get()); +} + static void ScalarMapDealloc(PyObject* _self) { MapContainer* self = GetMap(_self); self->owner.reset(); - Py_TYPE(_self)->tp_free(_self); + PyTypeObject *type = Py_TYPE(_self); + type->tp_free(_self); + if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) { + // With Python3, the Map class is not static, and must be managed. + Py_DECREF(type); + } } static PyMethodDef ScalarMapMethods[] = { @@ -570,6 +609,7 @@ PyTypeObject *ScalarMapContainer_Type; {Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem}, {Py_tp_methods, (void *)ScalarMapMethods}, {Py_tp_iter, (void *)MapReflectionFriend::GetIterator}, + {Py_tp_repr, (void *)MapReflectionFriend::ScalarMapToStr}, {0, 0}, }; @@ -597,7 +637,7 @@ PyTypeObject *ScalarMapContainer_Type; 0, // tp_getattr 0, // tp_setattr 0, // tp_compare - 0, // tp_repr + MapReflectionFriend::ScalarMapToStr, // tp_repr 0, // tp_as_number 0, // tp_as_sequence &ScalarMapMappingMethods, // tp_as_mapping @@ -634,7 +674,8 @@ static MessageMapContainer* GetMessageMap(PyObject* obj) { return reinterpret_cast(obj); } -static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { +static PyObject* GetCMessage(MessageMapContainer* self, Message* message, + bool insert_message_dict) { // Get or create the CMessage object corresponding to this message. ScopedPyObjectPtr key(PyLong_FromVoidPtr(message)); PyObject* ret = PyDict_GetItem(self->message_dict, key.get()); @@ -649,10 +690,11 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { cmsg->owner = self->owner; cmsg->message = message; cmsg->parent = self->parent; - - if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) { - Py_DECREF(ret); - return NULL; + if (insert_message_dict) { + if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) { + Py_DECREF(ret); + return NULL; + } } } else { Py_INCREF(ret); @@ -662,7 +704,7 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) { } PyObject* NewMessageMapContainer( - CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor, + CMessage* parent, const proto2::FieldDescriptor* parent_field_descriptor, CMessageClass* message_class) { if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { return NULL; @@ -781,7 +823,41 @@ PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self, self->version++; } - return GetCMessage(self, value.MutableMessageValue()); + return GetCMessage(self, value.MutableMessageValue(), true); +} + +PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) { + ScopedPyObjectPtr dict(PyDict_New()); + if (dict == NULL) { + return NULL; + } + ScopedPyObjectPtr key; + ScopedPyObjectPtr value; + + MessageMapContainer* self = GetMessageMap(_self); + Message* message = self->GetMutableMessage(); + const Reflection* reflection = message->GetReflection(); + for (proto2::MapIterator it = reflection->MapBegin( + message, self->parent_field_descriptor); + it != reflection->MapEnd(message, self->parent_field_descriptor); + ++it) { + key.reset(MapKeyToPython(self->key_field_descriptor, + it.GetKey())); + if (key == NULL) { + return NULL; + } + // Do not insert the cmessage to self->message_dict because + // the returned CMessage will not escape this function. + value.reset(GetCMessage( + self, it.MutableValueRef()->MutableMessageValue(), false)); + if (value == NULL) { + return NULL; + } + if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) { + return NULL; + } + } + return PyObject_Repr(dict.get()); } PyObject* MessageMapGet(PyObject* self, PyObject* args) { @@ -813,7 +889,12 @@ static void MessageMapDealloc(PyObject* _self) { self->owner.reset(); Py_DECREF(self->message_dict); Py_DECREF(self->message_class); - Py_TYPE(_self)->tp_free(_self); + PyTypeObject *type = Py_TYPE(_self); + type->tp_free(_self); + if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) { + // With Python3, the Map class is not static, and must be managed. + Py_DECREF(type); + } } static PyMethodDef MessageMapMethods[] = { @@ -847,6 +928,7 @@ PyTypeObject *MessageMapContainer_Type; {Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem}, {Py_tp_methods, (void *)MessageMapMethods}, {Py_tp_iter, (void *)MapReflectionFriend::GetIterator}, + {Py_tp_repr, (void *)MapReflectionFriend::MessageMapToStr}, {0, 0} }; @@ -874,7 +956,7 @@ PyTypeObject *MessageMapContainer_Type; 0, // tp_getattr 0, // tp_setattr 0, // tp_compare - 0, // tp_repr + MapReflectionFriend::MessageMapToStr, // tp_repr 0, // tp_as_number 0, // tp_as_sequence &MessageMapMappingMethods, // tp_as_mapping @@ -929,7 +1011,7 @@ PyObject* MapReflectionFriend::GetIterator(PyObject *_self) { Message* message = self->GetMutableMessage(); const Reflection* reflection = message->GetReflection(); - iter->iter.reset(new ::google::protobuf::MapIterator( + iter->iter.reset(new ::proto2::MapIterator( reflection->MapBegin(message, self->parent_field_descriptor))); } @@ -1027,17 +1109,15 @@ bool InitMapContainers() { return false; } - if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) { - return false; - } - Py_INCREF(mutable_mapping.get()); #if PY_MAJOR_VERSION >= 3 - PyObject* bases = PyTuple_New(1); - PyTuple_SET_ITEM(bases, 0, mutable_mapping.get()); + ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get())); + if (bases == NULL) { + return false; + } ScalarMapContainer_Type = reinterpret_cast( - PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases)); + PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get())); #else _ScalarMapContainer_Type.tp_base = reinterpret_cast(mutable_mapping.get()); @@ -1055,7 +1135,7 @@ bool InitMapContainers() { #if PY_MAJOR_VERSION >= 3 MessageMapContainer_Type = reinterpret_cast( - PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases)); + PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get())); #else Py_INCREF(mutable_mapping.get()); _MessageMapContainer_Type.tp_base = diff --git a/python/google/protobuf/pyext/map_container.h b/python/google/protobuf/pyext/map_container.h index 111fafbf..7e77b027 100644 --- a/python/google/protobuf/pyext/map_container.h +++ b/python/google/protobuf/pyext/map_container.h @@ -120,6 +120,6 @@ extern PyObject* NewMessageMapContainer( } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index b2984509..5d0e37fa 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -45,12 +45,11 @@ #ifndef Py_TYPE #define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) #endif -#include #include #include #include #include -#include +#include #include #include #include @@ -58,12 +57,16 @@ #include #include #include -#include -#include +#include #include #include +#include +#include +#include #include #include +#include +#include #if PY_MAJOR_VERSION >= 3 #define PyInt_AsLong PyLong_AsLong @@ -72,16 +75,19 @@ #define PyString_Check PyUnicode_Check #define PyString_FromString PyUnicode_FromString #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyString_FromFormat PyUnicode_FromFormat #if PY_VERSION_HEX < 0x03030000 #error "Python 3.0 - 3.2 are not supported." #else #define PyString_AsString(ob) \ (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob)) - #define PyString_AsStringAndSize(ob, charpp, sizep) \ - (PyUnicode_Check(ob)? \ - ((*(charpp) = const_cast(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \ - PyBytes_AsStringAndSize(ob, (charpp), (sizep))) - #endif +#define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob) ? ((*(charpp) = const_cast( \ + PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \ + ? -1 \ + : 0) \ + : PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif #endif namespace google { @@ -99,44 +105,27 @@ namespace message_meta { static int InsertEmptyWeakref(PyTypeObject* base); namespace { -// Copied oveer from internal 'google/protobuf/stubs/strutil.h'. -inline void UpperString(string * s) { +// Copied over from internal 'google/protobuf/stubs/strutil.h'. +inline void LowerString(string * s) { string::iterator end = s->end(); for (string::iterator i = s->begin(); i != end; ++i) { - // toupper() changes based on locale. We don't want this! - if ('a' <= *i && *i <= 'z') *i += 'A' - 'a'; + // tolower() changes based on locale. We don't want this! + if ('A' <= *i && *i <= 'Z') *i += 'a' - 'A'; } } } -// Add the number of a field descriptor to the containing message class. -// Equivalent to: -// _cls._FIELD_NUMBER = -static bool AddFieldNumberToClass( - PyObject* cls, const FieldDescriptor* field_descriptor) { - string constant_name = field_descriptor->name() + "_FIELD_NUMBER"; - UpperString(&constant_name); - ScopedPyObjectPtr attr_name(PyString_FromStringAndSize( - constant_name.c_str(), constant_name.size())); - if (attr_name == NULL) { - return false; - } - ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number())); - if (number == NULL) { - return false; - } - if (PyObject_SetAttr(cls, attr_name.get(), number.get()) == -1) { - return false; - } - return true; -} - - // Finalize the creation of the Message class. static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { // For each field set: cls._FIELD_NUMBER = for (int i = 0; i < descriptor->field_count(); ++i) { - if (!AddFieldNumberToClass(cls, descriptor->field(i))) { + const FieldDescriptor* field_descriptor = descriptor->field(i); + ScopedPyObjectPtr property(NewFieldProperty(field_descriptor)); + if (property == NULL) { + return -1; + } + if (PyObject_SetAttrString(cls, field_descriptor->name().c_str(), + property.get()) < 0) { return -1; } } @@ -182,7 +171,7 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { // .extensions_by_name[name] // which was defined previously. for (int i = 0; i < descriptor->extension_count(); ++i) { - const google::protobuf::FieldDescriptor* field = descriptor->extension(i); + const proto2::FieldDescriptor* field = descriptor->extension(i); ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field)); if (extension_field == NULL) { return -1; @@ -193,11 +182,6 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) { cls, field->name().c_str(), extension_field.get()) == -1) { return -1; } - - // For each extension set cls._FIELD_NUMBER = . - if (!AddFieldNumberToClass(cls, field)) { - return -1; - } } return 0; @@ -265,10 +249,10 @@ static PyObject* New(PyTypeObject* type, PyObject* well_known_class = PyDict_GetItemString( WKT_classes, message_descriptor->full_name().c_str()); if (well_known_class == NULL) { - new_args.reset(Py_BuildValue("s(OO)O", name, &CMessage_Type, + new_args.reset(Py_BuildValue("s(OO)O", name, CMessage_Type, PythonMessage_class, dict)); } else { - new_args.reset(Py_BuildValue("s(OOO)O", name, &CMessage_Type, + new_args.reset(Py_BuildValue("s(OOO)O", name, CMessage_Type, PythonMessage_class, well_known_class, dict)); } @@ -285,7 +269,7 @@ static PyObject* New(PyTypeObject* type, // Insert the empty weakref into the base classes. if (InsertEmptyWeakref( reinterpret_cast(PythonMessage_class)) < 0 || - InsertEmptyWeakref(&CMessage_Type) < 0) { + InsertEmptyWeakref(CMessage_Type) < 0) { return NULL; } @@ -353,6 +337,13 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) { // The _extensions_by_name dictionary is built on every access. // TODO(amauryfa): Migrate all users to pool.FindAllExtensions() static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { + if (self->message_descriptor == NULL) { + // This is the base Message object, simply raise AttributeError. + PyErr_SetString(PyExc_AttributeError, + "Base Message class has no DESCRIPTOR"); + return NULL; + } + const PyDescriptorPool* pool = self->py_message_factory->pool; std::vector extensions; @@ -376,6 +367,13 @@ static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) { // The _extensions_by_number dictionary is built on every access. // TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber() static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) { + if (self->message_descriptor == NULL) { + // This is the base Message object, simply raise AttributeError. + PyErr_SetString(PyExc_AttributeError, + "Base Message class has no DESCRIPTOR"); + return NULL; + } + const PyDescriptorPool* pool = self->py_message_factory->pool; std::vector extensions; @@ -405,9 +403,51 @@ static PyGetSetDef Getters[] = { {NULL} }; +// Compute some class attributes on the fly: +// - All the _FIELD_NUMBER attributes, for all fields and nested extensions. +// Returns a new reference, or NULL with an exception set. +static PyObject* GetClassAttribute(CMessageClass *self, PyObject* name) { + char* attr; + Py_ssize_t attr_size; + static const char kSuffix[] = "_FIELD_NUMBER"; + if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 && + strings::EndsWith(StringPiece(attr, attr_size), kSuffix)) { + string field_name(attr, attr_size - sizeof(kSuffix) + 1); + LowerString(&field_name); + + // Try to find a field with the given name, without the suffix. + const FieldDescriptor* field = + self->message_descriptor->FindFieldByLowercaseName(field_name); + if (!field) { + // Search nested extensions as well. + field = + self->message_descriptor->FindExtensionByLowercaseName(field_name); + } + if (field) { + return PyInt_FromLong(field->number()); + } + } + PyErr_SetObject(PyExc_AttributeError, name); + return NULL; +} + +static PyObject* GetAttr(CMessageClass* self, PyObject* name) { + PyObject* result = CMessageClass_Type->tp_base->tp_getattro( + reinterpret_cast(self), name); + if (result != NULL) { + return result; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return NULL; + } + + PyErr_Clear(); + return GetClassAttribute(self, name); +} + } // namespace message_meta -PyTypeObject CMessageClass_Type = { +static PyTypeObject _CMessageClass_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) FULL_MODULE_NAME ".MessageMeta", // tp_name sizeof(CMessageClass), // tp_basicsize @@ -424,7 +464,7 @@ PyTypeObject CMessageClass_Type = { 0, // tp_hash 0, // tp_call 0, // tp_str - 0, // tp_getattro + (getattrofunc)message_meta::GetAttr, // tp_getattro 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags @@ -447,9 +487,10 @@ PyTypeObject CMessageClass_Type = { 0, // tp_alloc message_meta::New, // tp_new }; +PyTypeObject* CMessageClass_Type = &_CMessageClass_Type; static CMessageClass* CheckMessageClass(PyTypeObject* cls) { - if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { + if (!PyObject_TypeCheck(cls, CMessageClass_Type)) { PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name); return NULL; } @@ -486,11 +527,21 @@ struct ChildVisitor { return 0; } + // Returns 0 on success, -1 on failure. + int VisitMapContainer(MapContainer* container) { + return 0; + } + // Returns 0 on success, -1 on failure. int VisitCMessage(CMessage* cmessage, const FieldDescriptor* field_descriptor) { return 0; } + + // Returns 0 on success, -1 on failure. + int VisitUnknownFieldSet(PyUnknownFields* unknown_field_set) { + return 0; + } }; // Apply a function to a composite field. Does nothing if child is of @@ -538,34 +589,19 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { // Visit normal fields. if (self->composite_fields) { - // Never use self->message in this function, it may be already freed. - const Descriptor* message_descriptor = - GetMessageDescriptor(Py_TYPE(self)); - while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { - Py_ssize_t key_str_size; - char *key_str_data; - if (PyString_AsStringAndSize(key, &key_str_data, &key_str_size) != 0) - return -1; - const string key_str(key_str_data, key_str_size); - const FieldDescriptor* descriptor = - message_descriptor->FindFieldByName(key_str); - if (descriptor != NULL) { - if (VisitCompositeField(descriptor, field, visitor) == -1) - return -1; - } + for (CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->begin(); + it != self->composite_fields->end(); it++) { + const FieldDescriptor* descriptor = it->first; + PyObject* field = it->second; + if (VisitCompositeField(descriptor, field, visitor) == -1) return -1; } } - // Visit extension fields. - if (self->extensions != NULL) { - pos = 0; - while (PyDict_Next(self->extensions->values, &pos, &key, &field)) { - const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); - if (descriptor == NULL) - return -1; - if (VisitCompositeField(descriptor, field, visitor) == -1) - return -1; - } + if (self->unknown_field_set) { + PyUnknownFields* unknown_field_set = + reinterpret_cast(self->unknown_field_set); + visitor.VisitUnknownFieldSet(unknown_field_set); } return 0; @@ -577,8 +613,12 @@ PyObject* EncodeError_class; PyObject* DecodeError_class; PyObject* PickleError_class; -/* Is 64bit */ +// Format an error message for unexpected types. +// Always return with an exception set. void FormatTypeError(PyObject* arg, char* expected_types) { + // This function is often called with an exception set. + // Clear it to call PyObject_Repr() in good conditions. + PyErr_Clear(); PyObject* repr = PyObject_Repr(arg); if (repr) { PyErr_Format(PyExc_TypeError, @@ -859,7 +899,7 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, namespace cmessage { PyMessageFactory* GetFactoryForMessage(CMessage* message) { - GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type)); + GOOGLE_DCHECK(PyObject_TypeCheck(message, CMessage_Type)); return reinterpret_cast(Py_TYPE(message))->py_message_factory; } @@ -883,22 +923,20 @@ static int MaybeReleaseOverlappingOneofField( // Non-message fields don't need to be released. return 0; } - const char* field_name = existing_field->name().c_str(); - PyObject* child_message = cmessage->composite_fields ? - PyDict_GetItemString(cmessage->composite_fields, field_name) : NULL; - if (child_message == NULL) { - // No python reference to this field so no need to release. - return 0; - } - - if (InternalReleaseFieldByDescriptor( - cmessage, existing_field, child_message) < 0) { - return -1; + if (cmessage->composite_fields) { + CMessage::CompositeFieldsMap::iterator iterator = + cmessage->composite_fields->find(existing_field); + if (iterator != cmessage->composite_fields->end()) { + if (InternalReleaseFieldByDescriptor(cmessage, existing_field, + iterator->second) < 0) { + return -1; + } + Py_DECREF(iterator->second); + cmessage->composite_fields->erase(iterator); + } } - return PyDict_DelItemString(cmessage->composite_fields, field_name); -#else - return 0; #endif + return 0; } // --------------------------------------------------------------------- @@ -937,10 +975,49 @@ struct FixupMessageReference : public ChildVisitor { return 0; } + int VisitUnknownFieldSet(PyUnknownFields* unknown_field_set) { + const Reflection* reflection = message_->GetReflection(); + unknown_field_set->fields = &reflection->GetUnknownFields(*message_); + return 0; + } + private: Message* message_; }; +// After a Merge, visit every sub-message that was read-only, and +// eventually update their pointer if the Merge operation modified them. +struct FixupMessageAfterMerge : public FixupMessageReference { + explicit FixupMessageAfterMerge(CMessage* parent) : + FixupMessageReference(parent->message), + parent_cmessage(parent), message(parent->message) {} + + int VisitCMessage(CMessage* cmessage, + const FieldDescriptor* field_descriptor) { + if (cmessage->read_only == false) { + return 0; + } + if (message->GetReflection()->HasField(*message, field_descriptor)) { + Message* mutable_message = GetMutableMessage( + parent_cmessage, field_descriptor); + if (mutable_message == NULL) { + return -1; + } + cmessage->message = mutable_message; + cmessage->read_only = false; + if (ForEachCompositeField( + cmessage, FixupMessageAfterMerge(cmessage)) == -1) { + return -1; + } + } + return 0; + } + + private: + CMessage* parent_cmessage; + Message* message; +}; + int AssureWritable(CMessage* self) { if (self == NULL || !self->read_only) { return 0; @@ -974,10 +1051,8 @@ int AssureWritable(CMessage* self) { // When a CMessage is made writable its Message pointer is updated // to point to a new mutable Message. When that happens we need to // update any references to the old, read-only CMessage. There are - // four places such references occur: RepeatedScalarContainer, - // RepeatedCompositeContainer, MapContainer, and ExtensionDict. - if (self->extensions != NULL) - self->extensions->message = self->message; + // three places such references occur: RepeatedScalarContainer, + // RepeatedCompositeContainer, and MapContainer. if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1) return -1; @@ -986,27 +1061,6 @@ int AssureWritable(CMessage* self) { // --- Globals: -// Retrieve a C++ FieldDescriptor for a message attribute. -// The C++ message must be valid. -// TODO(amauryfa): This function should stay internal, because exception -// handling is not consistent. -static const FieldDescriptor* GetFieldDescriptor( - CMessage* self, PyObject* name) { - const Descriptor *message_descriptor = self->message->GetDescriptor(); - char* field_name; - Py_ssize_t size; - if (PyString_AsStringAndSize(name, &field_name, &size) < 0) { - return NULL; - } - const FieldDescriptor *field_descriptor = - message_descriptor->FindFieldByName(string(field_name, size)); - if (field_descriptor == NULL) { - // Note: No exception is set! - return NULL; - } - return field_descriptor; -} - // Retrieve a C++ FieldDescriptor for an extension handle. const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) { ScopedPyObjectPtr cdescriptor; @@ -1038,7 +1092,7 @@ static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor, const EnumValueDescriptor* enum_value_descriptor = enum_descriptor->FindValueByName(string(enum_label, size)); if (enum_value_descriptor == NULL) { - PyErr_SetString(PyExc_ValueError, "unknown enum label"); + PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label); return NULL; } return PyInt_FromLong(enum_value_descriptor->number()); @@ -1164,19 +1218,24 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { PyErr_SetString(PyExc_ValueError, "Field name must be a string"); return -1; } - const FieldDescriptor* descriptor = GetFieldDescriptor(self, name); - if (descriptor == NULL) { + ScopedPyObjectPtr property( + PyObject_GetAttr(reinterpret_cast(Py_TYPE(self)), name)); + if (property == NULL || + !PyObject_TypeCheck(property.get(), CFieldProperty_Type)) { PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.", self->message->GetDescriptor()->name().c_str(), PyString_AsString(name)); return -1; } + const FieldDescriptor* descriptor = + reinterpret_cast(property.get()) + ->field_descriptor; if (value == Py_None) { // field=None is the same as no field at all. continue; } if (descriptor->is_map()) { - ScopedPyObjectPtr map(GetAttr(reinterpret_cast(self), name)); + ScopedPyObjectPtr map(GetFieldValue(self, descriptor)); const FieldDescriptor* value_descriptor = descriptor->message_type()->FindFieldByName("value"); if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { @@ -1204,8 +1263,7 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { } } } else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { - ScopedPyObjectPtr container( - GetAttr(reinterpret_cast(self), name)); + ScopedPyObjectPtr container(GetFieldValue(self, descriptor)); if (container == NULL) { return -1; } @@ -1272,8 +1330,7 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { } } } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - ScopedPyObjectPtr message( - GetAttr(reinterpret_cast(self), name)); + ScopedPyObjectPtr message(GetFieldValue(self, descriptor)); if (message == NULL) { return -1; } @@ -1297,9 +1354,9 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) { if (new_val == NULL) { return -1; } + value = new_val.get(); } - if (SetAttr(reinterpret_cast(self), name, - (new_val.get() == NULL) ? value : new_val.get()) < 0) { + if (SetFieldValue(self, descriptor, value) < 0) { return -1; } } @@ -1322,10 +1379,11 @@ CMessage* NewEmptyMessage(CMessageClass* type) { self->parent = NULL; self->parent_field_descriptor = NULL; self->read_only = false; - self->extensions = NULL; self->composite_fields = NULL; + self->unknown_field_set = NULL; + return self; } @@ -1408,12 +1466,20 @@ static void Dealloc(CMessage* self) { } // Null out all weak references from children to this message. GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences())); - if (self->extensions) { - self->extensions->parent = NULL; - } - Py_CLEAR(self->extensions); - Py_CLEAR(self->composite_fields); + if (self->composite_fields) { + for (CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->begin(); + it != self->composite_fields->end(); it++) { + Py_DECREF(it->second); + } + delete self->composite_fields; + } + if (self->unknown_field_set) { + unknown_fields::Clear( + reinterpret_cast(self->unknown_field_set)); + Py_CLEAR(self->unknown_field_set); + } self->owner.~ThreadUnsafeSharedPtr(); Py_TYPE(self)->tp_free(reinterpret_cast(self)); } @@ -1564,13 +1630,16 @@ PyObject* ClearExtension(CMessage* self, PyObject* extension) { if (descriptor == NULL) { return NULL; } - if (self->extensions != NULL) { - PyObject* value = PyDict_GetItem(self->extensions->values, extension); - if (value != NULL) { - if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) { + if (self->composite_fields != NULL) { + CMessage::CompositeFieldsMap::iterator iterator = + self->composite_fields->find(descriptor); + if (iterator != self->composite_fields->end()) { + if (InternalReleaseFieldByDescriptor(self, descriptor, + iterator->second) < 0) { return NULL; } - PyDict_DelItem(self->extensions->values, extension); + Py_DECREF(iterator->second); + self->composite_fields->erase(iterator); } } return ClearFieldByDescriptor(self, descriptor); @@ -1770,14 +1839,16 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { arg = arg_in_oneof.get(); } - // Release the field if it exists in the dict of composite fields. if (self->composite_fields) { - PyObject* value = PyDict_GetItem(self->composite_fields, arg); - if (value != NULL) { - if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) { + CMessage::CompositeFieldsMap::iterator iterator = + self->composite_fields->find(field_descriptor); + if (iterator != self->composite_fields->end()) { + if (InternalReleaseFieldByDescriptor(self, field_descriptor, + iterator->second) < 0) { return NULL; } - PyDict_DelItem(self->composite_fields, arg); + Py_DECREF(iterator->second); + self->composite_fields->erase(iterator); } } return ClearFieldByDescriptor(self, field_descriptor); @@ -1787,9 +1858,18 @@ PyObject* Clear(CMessage* self) { AssureWritable(self); if (ForEachCompositeField(self, ReleaseChild(self)) == -1) return NULL; - Py_CLEAR(self->extensions); if (self->composite_fields) { - PyDict_Clear(self->composite_fields); + for (CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->begin(); + it != self->composite_fields->end(); it++) { + Py_DECREF(it->second); + } + self->composite_fields->clear(); + } + if (self->unknown_field_set) { + unknown_fields::Clear( + reinterpret_cast(self->unknown_field_set)); + Py_CLEAR(self->unknown_field_set); } self->message->Clear(); Py_RETURN_NONE; @@ -1946,7 +2026,7 @@ static PyObject* ToStr(CMessage* self) { PyObject* MergeFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + if (!PyObject_TypeCheck(arg, CMessage_Type)) { PyErr_Format(PyExc_TypeError, "Parameter to MergeFrom() must be instance of same class: " "expected %s got %s.", @@ -1967,18 +2047,19 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) { } AssureWritable(self); - // TODO(tibell): Message::MergeFrom might turn some child Messages - // into mutable messages, invalidating the message field in the - // corresponding CMessages. We should run a FixupMessageReferences - // pass here. - self->message->MergeFrom(*other_message->message); + // Child message might be lazily created before MergeFrom. Make sure they + // are mutable at this point if child messages are really created. + if (ForEachCompositeField(self, FixupMessageAfterMerge(self)) == -1) { + return NULL; + } + Py_RETURN_NONE; } static PyObject* CopyFrom(CMessage* self, PyObject* arg) { CMessage* other_message; - if (!PyObject_TypeCheck(arg, &CMessage_Type)) { + if (!PyObject_TypeCheck(arg, CMessage_Type)) { PyErr_Format(PyExc_TypeError, "Parameter to CopyFrom() must be instance of same class: " "expected %s got %s.", @@ -2050,6 +2131,7 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { } AssureWritable(self); + io::CodedInputStream input( reinterpret_cast(data), data_length); if (allow_oversize_protos) { @@ -2058,6 +2140,12 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { PyMessageFactory* factory = GetFactoryForMessage(self); input.SetExtensionRegistry(factory->pool->pool, factory->message_factory); bool success = self->message->MergePartialFromCodedStream(&input); + // Child message might be lazily created before MergeFrom. Make sure they + // are mutable at this point if child messages are really created. + if (ForEachCompositeField(self, FixupMessageAfterMerge(self)) == -1) { + return NULL; + } + if (success) { if (!input.ConsumedEntireMessage()) { // TODO(jieluo): Raise error and return NULL instead. @@ -2088,7 +2176,7 @@ PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) { if (descriptor == NULL) { return NULL; } - if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) { + if (!PyObject_TypeCheck(cls, CMessageClass_Type)) { PyErr_Format(PyExc_TypeError, "Expected a message class, got %s", cls->ob_type->tp_name); return NULL; @@ -2192,23 +2280,15 @@ static PyObject* ListFields(CMessage* self) { PyTuple_SET_ITEM(t.get(), 1, extension); } else { // Normal field - const string& field_name = fields[i]->name(); - ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize( - field_name.c_str(), field_name.length())); - if (py_field_name == NULL) { - PyErr_SetString(PyExc_ValueError, "bad string"); - return NULL; - } ScopedPyObjectPtr field_descriptor( PyFieldDescriptor_FromDescriptor(fields[i])); if (field_descriptor == NULL) { return NULL; } - PyObject* field_value = - GetAttr(reinterpret_cast(self), py_field_name.get()); + PyObject* field_value = GetFieldValue(self, fields[i]); if (field_value == NULL) { - PyErr_SetObject(PyExc_ValueError, py_field_name.get()); + PyErr_SetString(PyExc_ValueError, fields[i]->name().c_str()); return NULL; } PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release()); @@ -2261,10 +2341,10 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { } bool equals = true; // If other is not a message, it cannot be equal. - if (!PyObject_TypeCheck(other, &CMessage_Type)) { + if (!PyObject_TypeCheck(other, CMessage_Type)) { equals = false; } - const google::protobuf::Message* other_message = + const proto2::Message* other_message = reinterpret_cast(other)->message; // If messages don't have the same descriptors, they are not equal. if (equals && @@ -2272,11 +2352,12 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { equals = false; } // Check the message contents. - if (equals && !google::protobuf::util::MessageDifferencer::Equals( + if (equals && !proto2::util::MessageDifferencer::Equals( *self->message, *reinterpret_cast(other)->message)) { equals = false; } + if (equals ^ (opid == Py_EQ)) { Py_RETURN_FALSE; } else { @@ -2498,7 +2579,7 @@ PyObject* DeepCopy(CMessage* self, PyObject* arg) { if (clone == NULL) { return NULL; } - if (!PyObject_TypeCheck(clone, &CMessage_Type)) { + if (!PyObject_TypeCheck(clone, CMessage_Type)) { Py_DECREF(clone); return NULL; } @@ -2592,26 +2673,29 @@ PyObject* _CheckCalledFromGeneratedFile(PyObject* unused, } static PyObject* GetExtensionDict(CMessage* self, void *closure) { - if (self->extensions) { - Py_INCREF(self->extensions); - return reinterpret_cast(self->extensions); - } - // If there are extension_ranges, the message is "extendable". Allocate a // dictionary to store the extension fields. const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self)); - if (descriptor->extension_range_count() > 0) { - ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); - if (extension_dict == NULL) { - return NULL; - } - self->extensions = extension_dict; - Py_INCREF(self->extensions); - return reinterpret_cast(self->extensions); + if (!descriptor->extension_range_count()) { + PyErr_SetNone(PyExc_AttributeError); + return NULL; + } + if (!self->composite_fields) { + self->composite_fields = new CMessage::CompositeFieldsMap(); } + if (!self->composite_fields) { + return NULL; + } + ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self); + return reinterpret_cast(extension_dict); +} - PyErr_SetNone(PyExc_AttributeError); - return NULL; +static PyObject* UnknownFieldSet(CMessage* self) { + if (self->unknown_field_set == NULL) { + self->unknown_field_set = unknown_fields::NewPyUnknownFields(self); + } + Py_INCREF(self->unknown_field_set); + return self->unknown_field_set; } static PyObject* GetExtensionsByName(CMessage *self, void *closure) { @@ -2682,6 +2766,8 @@ static PyMethodDef Methods[] = { "Serializes the message to a string, only for initialized messages." }, { "SetInParent", (PyCFunction)SetInParent, METH_NOARGS, "Sets the has bit of the given field in its parent message." }, + { "UnknownFields", (PyCFunction)UnknownFieldSet, METH_NOARGS, + "Parse unknown field set"}, { "WhichOneof", (PyCFunction)WhichOneof, METH_O, "Returns the name of the field set inside a oneof, " "or None if no field is set." }, @@ -2693,30 +2779,53 @@ static PyMethodDef Methods[] = { { NULL, NULL} }; -static bool SetCompositeField( - CMessage* self, PyObject* name, PyObject* value) { +static bool SetCompositeField(CMessage* self, const FieldDescriptor* field, + PyObject* value) { if (self->composite_fields == NULL) { - self->composite_fields = PyDict_New(); - if (self->composite_fields == NULL) { - return false; - } + self->composite_fields = new CMessage::CompositeFieldsMap(); } - return PyDict_SetItem(self->composite_fields, name, value) == 0; + Py_INCREF(value); + Py_XDECREF((*self->composite_fields)[field]); + (*self->composite_fields)[field] = value; + return true; } PyObject* GetAttr(PyObject* pself, PyObject* name) { CMessage* self = reinterpret_cast(pself); - PyObject* value = self->composite_fields ? - PyDict_GetItem(self->composite_fields, name) : NULL; - if (value != NULL) { - Py_INCREF(value); - return value; + PyObject* result = PyObject_GenericGetAttr( + reinterpret_cast(self), name); + if (result != NULL) { + return result; + } + if (!PyErr_ExceptionMatches(PyExc_AttributeError)) { + return NULL; } - const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name); - if (field_descriptor == NULL) { - return CMessage_Type.tp_base->tp_getattro( - reinterpret_cast(self), name); + PyErr_Clear(); + return message_meta::GetClassAttribute( + CheckMessageClass(Py_TYPE(self)), name); +} + +PyObject* GetFieldValue(CMessage* self, + const FieldDescriptor* field_descriptor) { + if (self->composite_fields) { + CMessage::CompositeFieldsMap::iterator it = + self->composite_fields->find(field_descriptor); + if (it != self->composite_fields->end()) { + PyObject* value = it->second; + Py_INCREF(value); + return value; + } + } + + const Descriptor* message_descriptor = + (reinterpret_cast(Py_TYPE(self)))->message_descriptor; + if (self->message->GetDescriptor() != field_descriptor->containing_type()) { + PyErr_Format(PyExc_TypeError, + "descriptor to field '%s' doesn't apply to '%s' object", + field_descriptor->full_name().c_str(), + Py_TYPE(self)->tp_name); + return NULL; } if (field_descriptor->is_map()) { @@ -2737,7 +2846,7 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) { if (py_container == NULL) { return NULL; } - if (!SetCompositeField(self, name, py_container)) { + if (!SetCompositeField(self, field_descriptor, py_container)) { Py_DECREF(py_container); return NULL; } @@ -2761,7 +2870,7 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) { if (py_container == NULL) { return NULL; } - if (!SetCompositeField(self, name, py_container)) { + if (!SetCompositeField(self, field_descriptor, py_container)) { Py_DECREF(py_container); return NULL; } @@ -2773,7 +2882,7 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) { if (sub_message == NULL) { return NULL; } - if (!SetCompositeField(self, name, sub_message)) { + if (!SetCompositeField(self, field_descriptor, sub_message)) { Py_DECREF(sub_message); return NULL; } @@ -2783,44 +2892,35 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) { return InternalGetScalar(self->message, field_descriptor); } -int SetAttr(PyObject* pself, PyObject* name, PyObject* value) { - CMessage* self = reinterpret_cast(pself); - if (self->composite_fields && PyDict_Contains(self->composite_fields, name)) { - PyErr_SetString(PyExc_TypeError, "Can't set composite field"); +int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor, + PyObject* value) { + if (self->message->GetDescriptor() != field_descriptor->containing_type()) { + PyErr_Format(PyExc_TypeError, + "descriptor to field '%s' doesn't apply to '%s' object", + field_descriptor->full_name().c_str(), + Py_TYPE(self)->tp_name); return -1; - } - - const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name); - if (field_descriptor != NULL) { + } else if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed to repeated " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_Format(PyExc_AttributeError, + "Assignment not allowed to " + "field \"%s\" in protocol message object.", + field_descriptor->name().c_str()); + return -1; + } else { AssureWritable(self); - if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { - PyErr_Format(PyExc_AttributeError, "Assignment not allowed to repeated " - "field \"%s\" in protocol message object.", - field_descriptor->name().c_str()); - return -1; - } else { - if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { - PyErr_Format(PyExc_AttributeError, "Assignment not allowed to " - "field \"%s\" in protocol message object.", - field_descriptor->name().c_str()); - return -1; - } else { - return InternalSetScalar(self, field_descriptor, value); - } - } + return InternalSetScalar(self, field_descriptor, value); } - - PyErr_Format(PyExc_AttributeError, - "Assignment not allowed " - "(no field \"%s\" in protocol message object).", - PyString_AsString(name)); - return -1; } - } // namespace cmessage -PyTypeObject CMessage_Type = { - PyVarObject_HEAD_INIT(&CMessageClass_Type, 0) +static CMessageClass _CMessage_Type = { { { + PyVarObject_HEAD_INIT(&_CMessageClass_Type, 0) FULL_MODULE_NAME ".CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize @@ -2837,9 +2937,10 @@ PyTypeObject CMessage_Type = { 0, // tp_call (reprfunc)cmessage::ToStr, // tp_str cmessage::GetAttr, // tp_getattro - cmessage::SetAttr, // tp_setattro + 0, // tp_setattro 0, // tp_as_buffer - Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags + Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE + | Py_TPFLAGS_HAVE_VERSION_TAG, // tp_flags "A ProtocolMessage", // tp_doc 0, // tp_traverse 0, // tp_clear @@ -2858,7 +2959,8 @@ PyTypeObject CMessage_Type = { (initproc)cmessage::Init, // tp_init 0, // tp_alloc cmessage::New, // tp_new -}; +} } }; +PyTypeObject* CMessage_Type = &_CMessage_Type.super.ht_type; // --- Exposing the C proto living inside Python proto to C code: @@ -2884,7 +2986,7 @@ static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { } const Message* PyMessage_GetMessagePointer(PyObject* msg) { - if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + if (!PyObject_TypeCheck(msg, CMessage_Type)) { PyErr_SetString(PyExc_TypeError, "Not a Message instance"); return NULL; } @@ -2893,15 +2995,14 @@ const Message* PyMessage_GetMessagePointer(PyObject* msg) { } Message* PyMessage_GetMutableMessagePointer(PyObject* msg) { - if (!PyObject_TypeCheck(msg, &CMessage_Type)) { + if (!PyObject_TypeCheck(msg, CMessage_Type)) { PyErr_SetString(PyExc_TypeError, "Not a Message instance"); return NULL; } + CMessage* cmsg = reinterpret_cast(msg); - if ((cmsg->composite_fields && PyDict_Size(cmsg->composite_fields) != 0) || - (cmsg->extensions != NULL && - PyDict_Size(cmsg->extensions->values) != 0)) { + if (cmsg->composite_fields && !cmsg->composite_fields->empty()) { // There is currently no way of accurately syncing arbitrary changes to // the underlying C++ message back to the CMessage (e.g. removed repeated // composite containers). We only allow direct mutation of the underlying @@ -2945,22 +3046,29 @@ bool InitProto2MessageModule(PyObject *m) { // Initialize constants defined in this file. InitGlobals(); - CMessageClass_Type.tp_base = &PyType_Type; - if (PyType_Ready(&CMessageClass_Type) < 0) { + CMessageClass_Type->tp_base = &PyType_Type; + if (PyType_Ready(CMessageClass_Type) < 0) { return false; } PyModule_AddObject(m, "MessageMeta", - reinterpret_cast(&CMessageClass_Type)); + reinterpret_cast(CMessageClass_Type)); - if (PyType_Ready(&CMessage_Type) < 0) { + if (PyType_Ready(CMessage_Type) < 0) { + return false; + } + if (PyType_Ready(CFieldProperty_Type) < 0) { return false; } // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set // it here as well to document that subclasses need to set it. - PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); + PyDict_SetItem(CMessage_Type->tp_dict, kDESCRIPTOR, Py_None); + // Invalidate any cached data for the CMessage type. + // This call is necessary to correctly support Py_TPFLAGS_HAVE_VERSION_TAG, + // after we have modified CMessage_Type.tp_dict. + PyType_Modified(CMessage_Type); - PyModule_AddObject(m, "Message", reinterpret_cast(&CMessage_Type)); + PyModule_AddObject(m, "Message", reinterpret_cast(CMessage_Type)); // Initialize Repeated container types. { @@ -3003,6 +3111,22 @@ bool InitProto2MessageModule(PyObject *m) { } } + if (PyType_Ready(&PyUnknownFields_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "UnknownFieldSet", + reinterpret_cast( + &PyUnknownFields_Type)); + + if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) { + return false; + } + + PyModule_AddObject(m, "UnknownField", + reinterpret_cast( + &PyUnknownFieldRef_Type)); + // Initialize Map container types. if (!InitMapContainers()) { return false; diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index d754e62a..e729e448 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -38,6 +38,7 @@ #include #include +#include #include #include @@ -96,26 +97,25 @@ typedef struct CMessage { // made writable, at which point this field is set to false. bool read_only; - // A reference to a Python dictionary containing CMessage, + // A mapping indexed by field, containing CMessage, // RepeatedCompositeContainer, and RepeatedScalarContainer // objects. Used as a cache to make sure we don't have to make a // Python wrapper for the C++ Message objects on every access, or // deal with the synchronization nightmare that could create. - PyObject* composite_fields; + // Also cache extension fields. + // The FieldDescriptor is owned by the message's pool; PyObject references + // are owned. + typedef __gnu_cxx::hash_map + CompositeFieldsMap; + CompositeFieldsMap* composite_fields; - // A reference to the dictionary containing the message's extensions. - // Similar to composite_fields, acting as a cache, but also contains the - // required extension dict logic. - ExtensionDict* extensions; + // A reference to PyUnknownFields. + PyObject* unknown_field_set; // Implements the "weakref" protocol for this object. PyObject* weakreflist; } CMessage; -extern PyTypeObject CMessageClass_Type; -extern PyTypeObject CMessage_Type; - - // The (meta) type of all Messages classes. // It allows us to cache some C++ pointers in the class object itself, they are // faster to extract than from the type's dictionary. @@ -142,6 +142,8 @@ struct CMessageClass { } }; +extern PyTypeObject* CMessageClass_Type; +extern PyTypeObject* CMessage_Type; namespace cmessage { @@ -235,15 +237,13 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg); // has been registered with the same field number on this class. PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle); -// Retrieves an attribute named 'name' from 'self', which is interpreted as a -// CMessage. Returns the attribute value on success, or null on failure. -// -// Returns a new reference. -PyObject* GetAttr(PyObject* self, PyObject* name); - -// Set the value of the attribute named 'name', for 'self', which is interpreted -// as a CMessage, to the value 'value'. Returns -1 on failure. -int SetAttr(PyObject* self, PyObject* name, PyObject* value); +// Get a field from a message. +PyObject* GetFieldValue(CMessage* self, + const FieldDescriptor* field_descriptor); +// Sets the value of a scalar field in a message. +// On error, return -1 with an extension set. +int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor, + PyObject* value); PyObject* FindInitializationErrors(CMessage* self); @@ -357,6 +357,6 @@ extern template bool CheckAndGetInteger(PyObject*, uint64*); } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__ diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc index bacc76a6..efaa2617 100644 --- a/python/google/protobuf/pyext/message_factory.cc +++ b/python/google/protobuf/pyext/message_factory.cc @@ -28,6 +28,8 @@ // (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE // OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#include + #include #include @@ -137,7 +139,7 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, // This is the same implementation as MessageFactory.GetPrototype(). // Do not create a MessageClass that already exists. - hash_map::iterator it = + std::unordered_map::iterator it = self->classes_by_descriptor->find(descriptor); if (it != self->classes_by_descriptor->end()) { Py_INCREF(it->second); @@ -158,7 +160,7 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self, return NULL; } ScopedPyObjectPtr message_class(PyObject_CallObject( - reinterpret_cast(&CMessageClass_Type), args.get())); + reinterpret_cast(CMessageClass_Type), args.get())); if (message_class == NULL) { return NULL; } diff --git a/python/google/protobuf/pyext/message_factory.h b/python/google/protobuf/pyext/message_factory.h index 36092f7e..06444b0a 100644 --- a/python/google/protobuf/pyext/message_factory.h +++ b/python/google/protobuf/pyext/message_factory.h @@ -33,7 +33,7 @@ #include -#include +#include #include #include @@ -66,7 +66,8 @@ struct PyMessageFactory { // // Descriptor pointers stored here are owned by the DescriptorPool above. // Python references to classes are owned by this PyDescriptorPool. - typedef hash_map ClassesByMessageMap; + typedef std::unordered_map + ClassesByMessageMap; ClassesByMessageMap* classes_by_descriptor; }; @@ -98,6 +99,6 @@ bool InitMessageFactory(); } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__ diff --git a/python/google/protobuf/pyext/message_module.cc b/python/google/protobuf/pyext/message_module.cc index 8c933866..29d56702 100644 --- a/python/google/protobuf/pyext/message_module.cc +++ b/python/google/protobuf/pyext/message_module.cc @@ -31,47 +31,24 @@ #include #include -#include +#include #include namespace { // C++ API. Clients get at this via proto_api.h -struct ApiImplementation : google::protobuf::python::PyProto_API { - const google::protobuf::Message* - GetMessagePointer(PyObject* msg) const override { - return google::protobuf::python::PyMessage_GetMessagePointer(msg); +struct ApiImplementation : proto2::python::PyProto_API { + const proto2::Message* GetMessagePointer(PyObject* msg) const override { + return proto2::python::PyMessage_GetMessagePointer(msg); } - google::protobuf::Message* - GetMutableMessagePointer(PyObject* msg) const override { - return google::protobuf::python::PyMessage_GetMutableMessagePointer(msg); + proto2::Message* GetMutableMessagePointer(PyObject* msg) const override { + return proto2::python::PyMessage_GetMutableMessagePointer(msg); } }; } // namespace -static PyObject* GetPythonProto3PreserveUnknownsDefault( - PyObject* /*m*/, PyObject* /*args*/) { - if (google::protobuf::internal::GetProto3PreserveUnknownsDefault()) { - Py_RETURN_TRUE; - } else { - Py_RETURN_FALSE; - } -} - -static PyObject* SetPythonProto3PreserveUnknownsDefault( - PyObject* /*m*/, PyObject* arg) { - if (!arg || !PyBool_Check(arg)) { - PyErr_SetString( - PyExc_TypeError, - "Argument to SetPythonProto3PreserveUnknownsDefault must be boolean"); - return NULL; - } - google::protobuf::internal::SetProto3PreserveUnknownsDefault(PyObject_IsTrue(arg)); - Py_RETURN_NONE; -} - static const char module_docstring[] = "python-proto2 is a module that can be used to enhance proto2 Python API\n" "performance.\n" @@ -81,16 +58,9 @@ static const char module_docstring[] = static PyMethodDef ModuleMethods[] = { {"SetAllowOversizeProtos", - (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos, + (PyCFunction)proto2::python::cmessage::SetAllowOversizeProtos, METH_O, "Enable/disable oversize proto parsing."}, // DO NOT USE: For migration and testing only. - {"GetPythonProto3PreserveUnknownsDefault", - (PyCFunction)GetPythonProto3PreserveUnknownsDefault, - METH_NOARGS, "Get Proto3 preserve unknowns default."}, - // DO NOT USE: For migration and testing only. - {"SetPythonProto3PreserveUnknownsDefault", - (PyCFunction)SetPythonProto3PreserveUnknownsDefault, - METH_O, "Enable/disable proto3 unknowns preservation."}, { NULL, NULL} }; @@ -113,35 +83,32 @@ static struct PyModuleDef _module = { #define INITFUNC_ERRORVAL #endif -extern "C" { - PyMODINIT_FUNC INITFUNC(void) { - PyObject* m; +PyMODINIT_FUNC INITFUNC() { + PyObject* m; #if PY_MAJOR_VERSION >= 3 - m = PyModule_Create(&_module); + m = PyModule_Create(&_module); #else - m = Py_InitModule3("_message", ModuleMethods, - module_docstring); + m = Py_InitModule3("_message", ModuleMethods, module_docstring); #endif - if (m == NULL) { - return INITFUNC_ERRORVAL; - } + if (m == NULL) { + return INITFUNC_ERRORVAL; + } - if (!google::protobuf::python::InitProto2MessageModule(m)) { - Py_DECREF(m); - return INITFUNC_ERRORVAL; - } - - // Adds the C++ API - if (PyObject* api = - PyCapsule_New(new ApiImplementation(), - google::protobuf::python::PyProtoAPICapsuleName(), NULL)) { - PyModule_AddObject(m, "proto_API", api); - } else { - return INITFUNC_ERRORVAL; - } + if (!proto2::python::InitProto2MessageModule(m)) { + Py_DECREF(m); + return INITFUNC_ERRORVAL; + } + + // Adds the C++ API + if (PyObject* api = + PyCapsule_New(new ApiImplementation(), + proto2::python::PyProtoAPICapsuleName(), NULL)) { + PyModule_AddObject(m, "proto_API", api); + } else { + return INITFUNC_ERRORVAL; + } #if PY_MAJOR_VERSION >= 3 - return m; + return m; #endif - } } diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index 5874d5de..d6bc3d7b 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -61,9 +61,9 @@ namespace repeated_composite_container { // TODO(tibell): We might also want to check: // GOOGLE_CHECK_NOTNULL((self)->owner.get()); -#define GOOGLE_CHECK_ATTACHED(self) \ - do { \ - GOOGLE_CHECK_NOTNULL((self)->message); \ +#define GOOGLE_CHECK_ATTACHED(self) \ + do { \ + GOOGLE_CHECK_NOTNULL((self)->message); \ GOOGLE_CHECK_NOTNULL((self)->parent_field_descriptor); \ } while (0); @@ -152,6 +152,8 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self, cmsg->message = sub_message; cmsg->parent = self->parent; if (cmessage::InitAttributes(cmsg, args, kwargs) < 0) { + message->GetReflection()->RemoveLast( + message, self->parent_field_descriptor); Py_DECREF(cmsg); return NULL; } @@ -210,7 +212,7 @@ PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) { } ScopedPyObjectPtr next; while ((next.reset(PyIter_Next(iter.get()))) != NULL) { - if (!PyObject_TypeCheck(next.get(), &CMessage_Type)) { + if (!PyObject_TypeCheck(next.get(), CMessage_Type)) { PyErr_SetString(PyExc_TypeError, "Not a cmessage"); return NULL; } @@ -487,9 +489,9 @@ static PyObject* Pop(PyObject* pself, PyObject* args) { void ReleaseLastTo(CMessage* parent, const FieldDescriptor* field, CMessage* target) { - GOOGLE_CHECK_NOTNULL(parent); - GOOGLE_CHECK_NOTNULL(field); - GOOGLE_CHECK_NOTNULL(target); + GOOGLE_CHECK(parent != nullptr); + GOOGLE_CHECK(field != nullptr); + GOOGLE_CHECK(target != nullptr); CMessage::OwnerRef released_message( parent->message->GetReflection()->ReleaseLast(parent->message, field)); diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h index e5e946aa..464699aa 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.h +++ b/python/google/protobuf/pyext/repeated_composite_container.h @@ -161,6 +161,6 @@ void ReleaseLastTo(CMessage* parent, } // namespace repeated_composite_container } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc index de3b6e14..cdb64269 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.cc +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -663,6 +663,10 @@ static PyObject* ToStr(PyObject* pself) { return PyObject_Repr(list.get()); } +static PyObject* MergeFrom(PyObject* pself, PyObject* arg) { + return Extend(reinterpret_cast(pself), arg); +} + // The private constructor of RepeatedScalarContainer objects. PyObject *NewContainer( CMessage* parent, const FieldDescriptor* parent_field_descriptor) { @@ -776,6 +780,8 @@ static PyMethodDef Methods[] = { "Removes an object from the repeated container." }, { "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS, "Sorts the repeated container."}, + { "MergeFrom", (PyCFunction)MergeFrom, METH_O, + "Merges a repeated container into the current container." }, { NULL, NULL } }; diff --git a/python/google/protobuf/pyext/repeated_scalar_container.h b/python/google/protobuf/pyext/repeated_scalar_container.h index 559dec98..4dcecbac 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.h +++ b/python/google/protobuf/pyext/repeated_scalar_container.h @@ -104,6 +104,6 @@ void SetOwner(RepeatedScalarContainer* self, } // namespace repeated_scalar_container } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__ diff --git a/python/google/protobuf/pyext/safe_numerics.h b/python/google/protobuf/pyext/safe_numerics.h index 639ba2c8..60112cfa 100644 --- a/python/google/protobuf/pyext/safe_numerics.h +++ b/python/google/protobuf/pyext/safe_numerics.h @@ -159,6 +159,6 @@ inline Dest checked_numeric_cast(Source source) { } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__ diff --git a/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h b/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h index ad804b5f..79fa9e3d 100644 --- a/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h +++ b/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h @@ -99,6 +99,6 @@ class ThreadUnsafeSharedPtr { } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_CPP_THREAD_UNSAFE_SHARED_PTR_H__ diff --git a/python/google/protobuf/python_protobuf.h b/python/google/protobuf/python_protobuf.h index beb6e460..8db1ffb7 100644 --- a/python/google/protobuf/python_protobuf.h +++ b/python/google/protobuf/python_protobuf.h @@ -52,6 +52,6 @@ Message* MutableCProtoInsidePyProto(PyObject* msg); } // namespace python } // namespace protobuf - } // namespace google + #endif // GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__ diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index f4ce8caf..81e18859 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -48,25 +48,23 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' -from google.protobuf.internal import api_implementation -from google.protobuf import message - - -if api_implementation.Type() == 'cpp': - from google.protobuf.pyext import cpp_message as message_impl -else: - from google.protobuf.internal import python_message as message_impl +from google.protobuf import message_factory +from google.protobuf import symbol_database # The type of all Message classes. # Part of the public interface, but normally only used by message factories. -GeneratedProtocolMessageType = message_impl.GeneratedProtocolMessageType +GeneratedProtocolMessageType = message_factory._GENERATED_PROTOCOL_MESSAGE_TYPE MESSAGE_CLASS_CACHE = {} +# Deprecated. Please NEVER use reflection.ParseMessage(). def ParseMessage(descriptor, byte_str): """Generate a new Message instance from this Descriptor and a byte string. + DEPRECATED: ParseMessage is deprecated because it is using MakeClass(). + Please use MessageFactory.GetPrototype() instead. + Args: descriptor: Protobuf Descriptor object byte_str: Serialized protocol buffer byte string @@ -80,42 +78,18 @@ def ParseMessage(descriptor, byte_str): return new_msg +# Deprecated. Please NEVER use reflection.MakeClass(). def MakeClass(descriptor): """Construct a class object for a protobuf described by descriptor. - Composite descriptors are handled by defining the new class as a member of the - parent class, recursing as deep as necessary. - This is the dynamic equivalent to: - - class Parent(message.Message): - __metaclass__ = GeneratedProtocolMessageType - DESCRIPTOR = descriptor - class Child(message.Message): - __metaclass__ = GeneratedProtocolMessageType - DESCRIPTOR = descriptor.nested_types[0] - - Sample usage: - file_descriptor = descriptor_pb2.FileDescriptorProto() - file_descriptor.ParseFromString(proto2_string) - msg_descriptor = descriptor.MakeDescriptor(file_descriptor.message_type[0]) - msg_class = reflection.MakeClass(msg_descriptor) - msg = msg_class() + DEPRECATED: use MessageFactory.GetPrototype() instead. Args: descriptor: A descriptor.Descriptor object describing the protobuf. Returns: The Message class object described by the descriptor. """ - if descriptor in MESSAGE_CLASS_CACHE: - return MESSAGE_CLASS_CACHE[descriptor] - - attributes = {} - for name, nested_type in descriptor.nested_types_by_name.items(): - attributes[name] = MakeClass(nested_type) - - attributes[GeneratedProtocolMessageType._DESCRIPTOR_KEY] = descriptor - - result = GeneratedProtocolMessageType( - str(descriptor.name), (message.Message,), attributes) - MESSAGE_CLASS_CACHE[descriptor] = result - return result + # Original implementation leads to duplicate message classes, which won't play + # well with extensions. Message factory info is also missing. + # Redirect to message_factory. + return symbol_database.Default().GetPrototype(descriptor) diff --git a/python/google/protobuf/text_encoding.py b/python/google/protobuf/text_encoding.py index 98995638..39898765 100644 --- a/python/google/protobuf/text_encoding.py +++ b/python/google/protobuf/text_encoding.py @@ -33,59 +33,70 @@ import re import six -# Lookup table for utf8 -_cescape_utf8_to_str = [chr(i) for i in range(0, 256)] -_cescape_utf8_to_str[9] = r'\t' # optional escape -_cescape_utf8_to_str[10] = r'\n' # optional escape -_cescape_utf8_to_str[13] = r'\r' # optional escape -_cescape_utf8_to_str[39] = r"\'" # optional escape - -_cescape_utf8_to_str[34] = r'\"' # necessary escape -_cescape_utf8_to_str[92] = r'\\' # necessary escape +_cescape_chr_to_symbol_map = {} +_cescape_chr_to_symbol_map[9] = r'\t' # optional escape +_cescape_chr_to_symbol_map[10] = r'\n' # optional escape +_cescape_chr_to_symbol_map[13] = r'\r' # optional escape +_cescape_chr_to_symbol_map[34] = r'\"' # necessary escape +_cescape_chr_to_symbol_map[39] = r"\'" # optional escape +_cescape_chr_to_symbol_map[92] = r'\\' # necessary escape + +# Lookup table for unicode +_cescape_unicode_to_str = [chr(i) for i in range(0, 256)] +for byte, string in _cescape_chr_to_symbol_map.items(): + _cescape_unicode_to_str[byte] = string # Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32) _cescape_byte_to_str = ([r'\%03o' % i for i in range(0, 32)] + [chr(i) for i in range(32, 127)] + [r'\%03o' % i for i in range(127, 256)]) -_cescape_byte_to_str[9] = r'\t' # optional escape -_cescape_byte_to_str[10] = r'\n' # optional escape -_cescape_byte_to_str[13] = r'\r' # optional escape -_cescape_byte_to_str[39] = r"\'" # optional escape - -_cescape_byte_to_str[34] = r'\"' # necessary escape -_cescape_byte_to_str[92] = r'\\' # necessary escape +for byte, string in _cescape_chr_to_symbol_map.items(): + _cescape_byte_to_str[byte] = string +del byte, string def CEscape(text, as_utf8): - """Escape a bytes string for use in an ascii protocol buffer. - - text.encode('string_escape') does not seem to satisfy our needs as it - encodes unprintable characters using two-digit hex escapes whereas our - C++ unescaping function allows hex escapes to be any length. So, - "\0011".encode('string_escape') ends up being "\\x011", which will be - decoded in C++ as a single-character string with char code 0x11. + # type: (...) -> str + """Escape a bytes string for use in an text protocol buffer. Args: - text: A byte string to be escaped - as_utf8: Specifies if result should be returned in UTF-8 encoding + text: A byte string to be escaped. + as_utf8: Specifies if result may contain non-ASCII characters. + In Python 3 this allows unescaped non-ASCII Unicode characters. + In Python 2 the return value will be valid UTF-8 rather than only ASCII. Returns: - Escaped string + Escaped string (str). """ - # PY3 hack: make Ord work for str and bytes: - # //platforms/networking/data uses unicode here, hence basestring. - Ord = ord if isinstance(text, six.string_types) else lambda x: x + # Python's text.encode() 'string_escape' or 'unicode_escape' codecs do not + # satisfy our needs; they encodes unprintable characters using two-digit hex + # escapes whereas our C++ unescaping function allows hex escapes to be any + # length. So, "\0011".encode('string_escape') ends up being "\\x011", which + # will be decoded in C++ as a single-character string with char code 0x11. + if six.PY3: + text_is_unicode = isinstance(text, str) + if as_utf8 and text_is_unicode: + # We're already unicode, no processing beyond control char escapes. + return text.translate(_cescape_chr_to_symbol_map) + ord_ = ord if text_is_unicode else lambda x: x # bytes iterate as ints. + else: + ord_ = ord # PY2 if as_utf8: - return ''.join(_cescape_utf8_to_str[Ord(c)] for c in text) - return ''.join(_cescape_byte_to_str[Ord(c)] for c in text) + return ''.join(_cescape_unicode_to_str[ord_(c)] for c in text) + return ''.join(_cescape_byte_to_str[ord_(c)] for c in text) _CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])') -_cescape_highbit_to_str = ([chr(i) for i in range(0, 127)] + - [r'\%03o' % i for i in range(127, 256)]) def CUnescape(text): - """Unescape a text string with C-style escape sequences to UTF-8 bytes.""" + # type: (str) -> bytes + """Unescape a text string with C-style escape sequences to UTF-8 bytes. + + Args: + text: The data to parse in a str. + Returns: + A byte string. + """ def ReplaceHex(m): # Only replace the match if the number of leading back slashes is odd. i.e. @@ -98,10 +109,9 @@ def CUnescape(text): # allow single-digit hex escapes (like '\xf'). result = _CUNESCAPE_HEX.sub(ReplaceHex, text) - if str is bytes: # PY2 + if six.PY2: return result.decode('string_escape') - result = ''.join(_cescape_highbit_to_str[ord(c)] for c in result) - return (result.encode('ascii') # Make it bytes to allow decode. + return (result.encode('utf-8') # PY3: Make it bytes to allow decode. .decode('unicode_escape') # Make it bytes again to return the proper type. .encode('raw_unicode_escape')) diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index 2cbd21bc..5dd41830 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -55,15 +55,15 @@ from google.protobuf.internal import type_checkers from google.protobuf import descriptor from google.protobuf import text_encoding -__all__ = ['MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue', - 'Merge'] +__all__ = ['MessageToString', 'Parse', 'PrintMessage', 'PrintField', + 'PrintFieldValue', 'Merge', 'MessageToBytes'] _INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(), type_checkers.Int32ValueChecker(), type_checkers.Uint64ValueChecker(), type_checkers.Int64ValueChecker()) -_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE) -_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE) +_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?$', re.IGNORECASE) +_FLOAT_NAN = re.compile('nanf?$', re.IGNORECASE) _FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT, descriptor.FieldDescriptor.CPPTYPE_DOUBLE]) _QUOTES = frozenset(("'", '"')) @@ -121,6 +121,7 @@ class TextWriter(object): def MessageToString(message, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, @@ -128,6 +129,7 @@ def MessageToString(message, descriptor_pool=None, indent=0, message_formatter=None): + # type: (...) -> str """Convert protobuf message to text format. Floating point values can be formatted compactly with 15 digits of @@ -137,8 +139,11 @@ def MessageToString(message, Args: message: The protocol buffers message. - as_utf8: Produce text output in UTF8 format. + as_utf8: Return unescaped Unicode for non-ASCII characters. + In Python 3 actual Unicode characters may appear as is in strings. + In Python 2 the return value will be valid UTF-8 rather than only ASCII. as_one_line: Don't introduce newlines between fields. + use_short_repeated_primitives: Use short repeated format for primitives. pointy_brackets: If True, use angle brackets instead of curly braces for nesting. use_index_order: If True, fields of a proto message will be printed using @@ -159,7 +164,8 @@ def MessageToString(message, A string of the text formatted protocol buffer message. """ out = TextWriter(as_utf8) - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, use_field_number, descriptor_pool, message_formatter) printer.PrintMessage(message) @@ -170,6 +176,16 @@ def MessageToString(message, return result +def MessageToBytes(message, **kwargs): + # type: (...) -> bytes + """Convert protobuf message to encoded text format. See MessageToString.""" + text = MessageToString(message, **kwargs) + if isinstance(text, bytes): + return text + codec = 'utf-8' if kwargs.get('as_utf8') else 'ascii' + return text.encode(codec) + + def _IsMapEntry(field): return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and field.message_type.has_options and @@ -181,13 +197,15 @@ def PrintMessage(message, indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, use_field_number=False, descriptor_pool=None, message_formatter=None): - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, use_field_number, descriptor_pool, message_formatter) printer.PrintMessage(message) @@ -199,12 +217,14 @@ def PrintField(field, indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, message_formatter=None): """Print a single field name/value pair.""" - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, message_formatter) printer.PrintField(field, value) @@ -215,12 +235,14 @@ def PrintFieldValue(field, indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, message_formatter=None): """Print a single field value (not including name).""" - printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets, + printer = _Printer(out, indent, as_utf8, as_one_line, + use_short_repeated_primitives, pointy_brackets, use_index_order, float_format, message_formatter) printer.PrintFieldValue(field, value) @@ -258,6 +280,7 @@ class _Printer(object): indent=0, as_utf8=False, as_one_line=False, + use_short_repeated_primitives=False, pointy_brackets=False, use_index_order=False, float_format=None, @@ -274,8 +297,11 @@ class _Printer(object): Args: out: To record the text format result. indent: The indent level for pretty print. - as_utf8: Produce text output in UTF8 format. + as_utf8: Return unescaped Unicode for non-ASCII characters. + In Python 3 actual Unicode characters may appear as is in strings. + In Python 2 the return value will be valid UTF-8 rather than ASCII. as_one_line: Don't introduce newlines between fields. + use_short_repeated_primitives: Use short repeated format for primitives. pointy_brackets: If True, use angle brackets instead of curly braces for nesting. use_index_order: If True, print fields of a proto message using the order @@ -294,6 +320,7 @@ class _Printer(object): self.indent = indent self.as_utf8 = as_utf8 self.as_one_line = as_one_line + self.use_short_repeated_primitives = use_short_repeated_primitives self.pointy_brackets = pointy_brackets self.use_index_order = use_index_order self.float_format = float_format @@ -351,13 +378,18 @@ class _Printer(object): entry_submsg = value.GetEntryClass()(key=key, value=value[key]) self.PrintField(field, entry_submsg) elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED: - for element in value: - self.PrintField(field, element) + if (self.use_short_repeated_primitives + and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE + and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_STRING): + self._PrintShortRepeatedPrimitivesValue(field, value) + else: + for element in value: + self.PrintField(field, element) else: self.PrintField(field, value) - def PrintField(self, field, value): - """Print a single field name/value pair.""" + def _PrintFieldName(self, field): + """Print field name.""" out = self.out out.write(' ' * self.indent) if self.use_field_number: @@ -383,11 +415,22 @@ class _Printer(object): # don't include it. out.write(': ') + def PrintField(self, field, value): + """Print a single field name/value pair.""" + self._PrintFieldName(field) self.PrintFieldValue(field, value) - if self.as_one_line: - out.write(' ') - else: - out.write('\n') + self.out.write(' ' if self.as_one_line else '\n') + + def _PrintShortRepeatedPrimitivesValue(self, field, value): + # Note: this is called only when value has at least one element. + self._PrintFieldName(field) + self.out.write('[') + for i in xrange(len(value) - 1): + self.PrintFieldValue(field, value[i]) + self.out.write(', ') + self.PrintFieldValue(field, value[-1]) + self.out.write(']') + self.out.write(' ' if self.as_one_line else '\n') def _PrintMessageFieldValue(self, value): if self.pointy_brackets: @@ -428,12 +471,12 @@ class _Printer(object): out.write(str(value)) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: out.write('\"') - if isinstance(value, six.text_type): + if isinstance(value, six.text_type) and (six.PY2 or not self.as_utf8): out_value = value.encode('utf-8') else: out_value = value if field.type == descriptor.FieldDescriptor.TYPE_BYTES: - # We need to escape non-UTF8 chars in TYPE_BYTES field. + # We always need to escape all binary data in TYPE_BYTES fields. out_as_utf8 = False else: out_as_utf8 = self.as_utf8 @@ -487,12 +530,7 @@ def Parse(text, Raises: ParseError: On text parsing problems. """ - if not isinstance(text, str): - if six.PY3: - text = text.decode('utf-8') - else: - text = text.encode('utf-8') - return ParseLines(text.split('\n'), + return ParseLines(text.split(b'\n' if isinstance(text, bytes) else u'\n'), message, allow_unknown_extension, allow_field_number, @@ -523,13 +561,8 @@ def Merge(text, Raises: ParseError: On text parsing problems. """ - if not isinstance(text, str): - if six.PY3: - text = text.decode('utf-8') - else: - text = text.encode('utf-8') return MergeLines( - text.split('\n'), + text.split(b'\n' if isinstance(text, bytes) else u'\n'), message, allow_unknown_extension, allow_field_number, @@ -570,6 +603,9 @@ def MergeLines(lines, descriptor_pool=None): """Parses a text representation of a protocol message into a message. + Like ParseLines(), but allows repeated values for a non-repeated field, and + uses the last one. + Args: lines: An iterable of lines of a message's text representation. message: A protocol buffer message to merge into. @@ -601,22 +637,12 @@ class _Parser(object): self.allow_field_number = allow_field_number self.descriptor_pool = descriptor_pool - def ParseFromString(self, text, message): - """Parses a text representation of a protocol message into a message.""" - if not isinstance(text, str): - text = text.decode('utf-8') - return self.ParseLines(text.split('\n'), message) - def ParseLines(self, lines, message): """Parses a text representation of a protocol message into a message.""" self._allow_multiple_scalars = False self._ParseOrMerge(lines, message) return message - def MergeFromString(self, text, message): - """Merges a text representation of a protocol message into a message.""" - return self._MergeLines(text.split('\n'), message) - def MergeLines(self, lines, message): """Merges a text representation of a protocol message into a message.""" self._allow_multiple_scalars = True @@ -633,7 +659,14 @@ class _Parser(object): Raises: ParseError: On text parsing problems. """ - tokenizer = Tokenizer(lines) + # Tokenize expects native str lines. + if six.PY2: + str_lines = (line if isinstance(line, str) else line.encode('utf-8') + for line in lines) + else: + str_lines = (line if isinstance(line, str) else line.decode('utf-8') + for line in lines) + tokenizer = Tokenizer(str_lines) while not tokenizer.AtEnd(): self._MergeField(tokenizer, message) @@ -1019,7 +1052,9 @@ class Tokenizer(object): r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number ] + [ # quoted str for each quote mark - r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES + # Avoid backtracking! https://stackoverflow.com/a/844267 + r'{qt}[^{qt}\n\\]*((\\.)+[^{qt}\n\\]*)*({qt}|\\?$)'.format(qt=mark) + for mark in _QUOTES ])) _IDENTIFIER = re.compile(r'[^\d\W]\w*') @@ -1316,7 +1351,8 @@ class Tokenizer(object): def ParseError(self, message): """Creates and *returns* a ParseError for the current token.""" - return ParseError(message, self._line + 1, self._column + 1) + return ParseError('\'' + self._current_line + '\': ' + message, + self._line + 1, self._column + 1) def _StringParseError(self, e): return self.ParseError('Couldn\'t parse string: ' + str(e)) @@ -1490,6 +1526,12 @@ def _ParseAbstractInteger(text, is_long=False): ValueError: Thrown Iff the text is not a valid integer. """ # Do the actual parsing. Exception handling is propagated to caller. + orig_text = text + c_octal_match = re.match(r'(-?)0(\d+)$', text) + if c_octal_match: + # Python 3 no longer supports 0755 octal syntax without the 'o', so + # we always use the '0o' prefix for multi-digit numbers starting with 0. + text = c_octal_match.group(1) + '0o' + c_octal_match.group(2) try: # We force 32-bit values to int and 64-bit values to long to make # alternate implementations where the distinction is more significant @@ -1499,7 +1541,7 @@ def _ParseAbstractInteger(text, is_long=False): else: return int(text, 0) except ValueError: - raise ValueError('Couldn\'t parse integer: %s' % text) + raise ValueError('Couldn\'t parse integer: %s' % orig_text) def ParseFloat(text): -- cgit v1.2.3