author     Feng Xiao <xfxyjwf@gmail.com>  2018-08-08 17:00:41 -0700
committer  Feng Xiao <xfxyjwf@gmail.com>  2018-08-08 17:00:41 -0700
commit     6bbe197e9c1b6fc38cbdc45e3bf83fa7ced792a3 (patch)
tree       e575738adf52d24b883cca5e8928a5ded31caba1 /python
parent     e7746f487cb9cca685ffb1b3d7dccc5554b618a4 (diff)
Down-integrate from google3.
Diffstat (limited to 'python')
-rwxr-xr-x  python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/text_format_test.py | 9
-rw-r--r--  python/google/protobuf/descriptor_database.py | 11
-rw-r--r--  python/google/protobuf/descriptor_pool.py | 196
-rwxr-xr-x  python/google/protobuf/internal/__init__.py | 30
-rwxr-xr-x  python/google/protobuf/internal/api_implementation.py | 26
-rwxr-xr-x  python/google/protobuf/internal/containers.py | 127
-rwxr-xr-x  python/google/protobuf/internal/decoder.py | 192
-rw-r--r--  python/google/protobuf/internal/descriptor_database_test.py | 8
-rw-r--r--  python/google/protobuf/internal/descriptor_pool_test.py | 63
-rwxr-xr-x  python/google/protobuf/internal/descriptor_test.py | 11
-rw-r--r--  python/google/protobuf/internal/factory_test1.proto | 14
-rw-r--r--  python/google/protobuf/internal/message_factory_test.py | 4
-rwxr-xr-x  python/google/protobuf/internal/message_test.py | 262
-rw-r--r--  python/google/protobuf/internal/no_package.proto | 30
-rwxr-xr-x  python/google/protobuf/internal/python_message.py | 150
-rwxr-xr-x  python/google/protobuf/internal/reflection_test.py | 53
-rwxr-xr-x  python/google/protobuf/internal/text_format_test.py | 298
-rwxr-xr-x  python/google/protobuf/internal/type_checkers.py | 8
-rwxr-xr-x  python/google/protobuf/internal/unknown_fields_test.py | 165
-rw-r--r--  python/google/protobuf/json_format.py | 2
-rwxr-xr-x  python/google/protobuf/message.py | 4
-rw-r--r--  python/google/protobuf/message_factory.py | 15
-rw-r--r--  python/google/protobuf/proto_api.h | 11
-rw-r--r--  python/google/protobuf/pyext/descriptor.cc | 51
-rw-r--r--  python/google/protobuf/pyext/descriptor.h | 2
-rw-r--r--  python/google/protobuf/pyext/descriptor_containers.cc | 12
-rw-r--r--  python/google/protobuf/pyext/descriptor_containers.h | 2
-rw-r--r--  python/google/protobuf/pyext/descriptor_database.cc | 39
-rw-r--r--  python/google/protobuf/pyext/descriptor_database.h | 9
-rw-r--r--  python/google/protobuf/pyext/descriptor_pool.cc | 47
-rw-r--r--  python/google/protobuf/pyext/descriptor_pool.h | 6
-rw-r--r--  python/google/protobuf/pyext/extension_dict.cc | 92
-rw-r--r--  python/google/protobuf/pyext/extension_dict.h | 20
-rw-r--r--  python/google/protobuf/pyext/map_container.cc | 148
-rw-r--r--  python/google/protobuf/pyext/map_container.h | 2
-rw-r--r--  python/google/protobuf/pyext/message.cc | 614
-rw-r--r--  python/google/protobuf/pyext/message.h | 40
-rw-r--r--  python/google/protobuf/pyext/message_factory.cc | 6
-rw-r--r--  python/google/protobuf/pyext/message_factory.h | 7
-rw-r--r--  python/google/protobuf/pyext/message_module.cc | 89
-rw-r--r--  python/google/protobuf/pyext/repeated_composite_container.cc | 16
-rw-r--r--  python/google/protobuf/pyext/repeated_composite_container.h | 2
-rw-r--r--  python/google/protobuf/pyext/repeated_scalar_container.cc | 6
-rw-r--r--  python/google/protobuf/pyext/repeated_scalar_container.h | 2
-rw-r--r--  python/google/protobuf/pyext/safe_numerics.h | 2
-rw-r--r--  python/google/protobuf/pyext/thread_unsafe_shared_ptr.h | 2
-rw-r--r--  python/google/protobuf/python_protobuf.h | 2
-rwxr-xr-x  python/google/protobuf/reflection.py | 52
-rw-r--r--  python/google/protobuf/text_encoding.py | 84
-rwxr-xr-x  python/google/protobuf/text_format.py | 134
50 files changed, 2328 insertions, 849 deletions
diff --git a/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/text_format_test.py b/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/text_format_test.py
index 8267cd2c..bc53e256 100755
--- a/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/text_format_test.py
+++ b/python/compatibility_tests/v2.5.0/tests/google/protobuf/internal/text_format_test.py
@@ -410,7 +410,8 @@ class TextFormatTest(unittest.TestCase):
text = 'optional_nested_enum: BARR'
self.assertRaisesWithMessage(
text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
+ ('1:23 : \'optional_nested_enum: BARR\': '
+ 'Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
'has no value named BARR.'),
text_format.Merge, text, message)
@@ -418,7 +419,8 @@ class TextFormatTest(unittest.TestCase):
text = 'optional_nested_enum: 100'
self.assertRaisesWithMessage(
text_format.ParseError,
- ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
+ ('1:23 : \'optional_nested_enum: 100\': '
+ 'Enum type "protobuf_unittest.TestAllTypes.NestedEnum" '
'has no value with number 100.'),
text_format.Merge, text, message)
@@ -427,7 +429,8 @@ class TextFormatTest(unittest.TestCase):
text = 'optional_int32: bork'
self.assertRaisesWithMessage(
text_format.ParseError,
- ('1:17 : Couldn\'t parse integer: bork'),
+ ('1:17 : \'optional_int32: bork\': '
+ 'Couldn\'t parse integer: bork'),
text_format.Merge, text, message)
def testMergeStringFieldUnescape(self):
diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py
index 8b7715cd..a7616cbc 100644
--- a/python/google/protobuf/descriptor_database.py
+++ b/python/google/protobuf/descriptor_database.py
@@ -76,6 +76,9 @@ class DescriptorDatabase(object):
self._AddSymbol(name, file_desc_proto)
for enum in file_desc_proto.enum_type:
self._AddSymbol(('.'.join((package, enum.name))), file_desc_proto)
+ for enum_value in enum.value:
+ self._file_desc_protos_by_symbol[
+ '.'.join((package, enum_value.name))] = file_desc_proto
for extension in file_desc_proto.extension:
self._AddSymbol(('.'.join((package, extension.name))), file_desc_proto)
for service in file_desc_proto.service:
@@ -133,6 +136,14 @@ class DescriptorDatabase(object):
top_level, _, _ = symbol.rpartition('.')
return self._file_desc_protos_by_symbol[top_level]
+ def FindFileContainingExtension(self, extendee_name, extension_number):
+ # TODO(jieluo): implement this API.
+ return None
+
+ def FindAllExtensionNumbers(self, extendee_name):
+ # TODO(jieluo): implement this API.
+ return []
+
def _AddSymbol(self, name, file_desc_proto):
if name in self._file_desc_protos_by_symbol:
warn_msg = ('Conflict register for file "' + file_desc_proto.name +
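
The descriptor_database.py change above indexes each top-level enum value as a package-level symbol (a sibling of its enum type) and adds FindFileContainingExtension/FindAllExtensionNumbers stubs. A minimal sketch of the resulting lookup behavior, assuming the generated factory_test2_pb2 test module is importable; the assertions mirror the updated descriptor_database_test.py further down:

from google.protobuf import descriptor_database
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2

db = descriptor_database.DescriptorDatabase()
file_proto = descriptor_pb2.FileDescriptorProto.FromString(
    factory_test2_pb2.DESCRIPTOR.serialized_pb)
db.Add(file_proto)

# Both the fully qualified enum value and its new package-level alias
# resolve to the file that defines them.
assert db.FindFileContainingSymbol(
    'google.protobuf.python.internal.Factory2Enum.FACTORY_2_VALUE_0') == file_proto
assert db.FindFileContainingSymbol(
    'google.protobuf.python.internal.FACTORY_2_VALUE_0') == file_proto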
diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py
index 8983f76f..42f7bcb5 100644
--- a/python/google/protobuf/descriptor_pool.py
+++ b/python/google/protobuf/descriptor_pool.py
@@ -131,33 +131,46 @@ class DescriptorPool(object):
# TODO(jieluo): Remove _file_desc_by_toplevel_extension after
# maybe year 2020 for compatibility issue (with 3.4.1 only).
self._file_desc_by_toplevel_extension = {}
+ self._top_enum_values = {}
# We store extensions in two two-level mappings: The first key is the
# descriptor of the message being extended, the second key is the extension
# full name or its tag number.
self._extensions_by_name = collections.defaultdict(dict)
self._extensions_by_number = collections.defaultdict(dict)
- def _CheckConflictRegister(self, desc):
+ def _CheckConflictRegister(self, desc, desc_name, file_name):
"""Check if the descriptor name conflicts with another of the same name.
Args:
- desc: Descriptor of a message, enum, service or extension.
+ desc: Descriptor of a message, enum, service, extension or enum value.
+ desc_name: the full name of desc.
+ file_name: The file name of descriptor.
"""
- desc_name = desc.full_name
for register, descriptor_type in [
(self._descriptors, descriptor.Descriptor),
(self._enum_descriptors, descriptor.EnumDescriptor),
(self._service_descriptors, descriptor.ServiceDescriptor),
- (self._toplevel_extensions, descriptor.FieldDescriptor)]:
+ (self._toplevel_extensions, descriptor.FieldDescriptor),
+ (self._top_enum_values, descriptor.EnumValueDescriptor)]:
if desc_name in register:
- file_name = register[desc_name].file.name
+ old_desc = register[desc_name]
+ if isinstance(old_desc, descriptor.EnumValueDescriptor):
+ old_file = old_desc.type.file.name
+ else:
+ old_file = old_desc.file.name
+
if not isinstance(desc, descriptor_type) or (
- file_name != desc.file.name):
- warn_msg = ('Conflict register for file "' + desc.file.name +
+ old_file != file_name):
+ warn_msg = ('Conflict register for file "' + file_name +
'": ' + desc_name +
' is already defined in file "' +
- file_name + '"')
+ old_file + '"')
+ if isinstance(desc, descriptor.EnumValueDescriptor):
+ warn_msg += ('\nNote: enum values appear as '
+ 'siblings of the enum type instead of '
+ 'children of it.')
warnings.warn(warn_msg, RuntimeWarning)
+
return
def Add(self, file_desc_proto):
@@ -196,7 +209,7 @@ class DescriptorPool(object):
if not isinstance(desc, descriptor.Descriptor):
raise TypeError('Expected instance of descriptor.Descriptor.')
- self._CheckConflictRegister(desc)
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._descriptors[desc.full_name] = desc
self._AddFileDescriptor(desc.file)
@@ -213,8 +226,26 @@ class DescriptorPool(object):
if not isinstance(enum_desc, descriptor.EnumDescriptor):
raise TypeError('Expected instance of descriptor.EnumDescriptor.')
- self._CheckConflictRegister(enum_desc)
+ file_name = enum_desc.file.name
+ self._CheckConflictRegister(enum_desc, enum_desc.full_name, file_name)
self._enum_descriptors[enum_desc.full_name] = enum_desc
+
+ # Top enum values need to be indexed.
+ # Count the number of dots to see whether the enum is toplevel or nested
+ # in a message. We cannot use enum_desc.containing_type at this stage.
+ if enum_desc.file.package:
+ top_level = (enum_desc.full_name.count('.')
+ - enum_desc.file.package.count('.') == 1)
+ else:
+ top_level = enum_desc.full_name.count('.') == 0
+ if top_level:
+ file_name = enum_desc.file.name
+ package = enum_desc.file.package
+ for enum_value in enum_desc.values:
+ full_name = _NormalizeFullyQualifiedName(
+ '.'.join((package, enum_value.name)))
+ self._CheckConflictRegister(enum_value, full_name, file_name)
+ self._top_enum_values[full_name] = enum_value
self._AddFileDescriptor(enum_desc.file)
def AddServiceDescriptor(self, service_desc):
@@ -227,7 +258,8 @@ class DescriptorPool(object):
if not isinstance(service_desc, descriptor.ServiceDescriptor):
raise TypeError('Expected instance of descriptor.ServiceDescriptor.')
- self._CheckConflictRegister(service_desc)
+ self._CheckConflictRegister(service_desc, service_desc.full_name,
+ service_desc.file.name)
self._service_descriptors[service_desc.full_name] = service_desc
def AddExtensionDescriptor(self, extension):
@@ -247,7 +279,6 @@ class DescriptorPool(object):
raise TypeError('Expected an extension descriptor.')
if extension.extension_scope is None:
- self._CheckConflictRegister(extension)
self._toplevel_extensions[extension.full_name] = extension
try:
@@ -349,6 +380,30 @@ class DescriptorPool(object):
symbol = _NormalizeFullyQualifiedName(symbol)
try:
+ return self._InternalFindFileContainingSymbol(symbol)
+ except KeyError:
+ pass
+
+ try:
+ # Try fallback database. Build and find again if possible.
+ self._FindFileContainingSymbolInDb(symbol)
+ return self._InternalFindFileContainingSymbol(symbol)
+ except KeyError:
+ raise KeyError('Cannot find a file containing %s' % symbol)
+
+ def _InternalFindFileContainingSymbol(self, symbol):
+ """Gets the already built FileDescriptor containing the specified symbol.
+
+ Args:
+ symbol: The name of the symbol to search for.
+
+ Returns:
+ A FileDescriptor that contains the specified symbol.
+
+ Raises:
+ KeyError: if the file cannot be found in the pool.
+ """
+ try:
return self._descriptors[symbol].file
except KeyError:
pass
@@ -364,7 +419,7 @@ class DescriptorPool(object):
pass
try:
- return self._FindFileContainingSymbolInDb(symbol)
+ return self._top_enum_values[symbol].type.file
except KeyError:
pass
@@ -373,13 +428,15 @@ class DescriptorPool(object):
except KeyError:
pass
- # Try nested extensions inside a message.
- message_name, _, extension_name = symbol.rpartition('.')
+ # Try fields, enum values and nested extensions inside a message.
+ top_name, _, sub_name = symbol.rpartition('.')
try:
- message = self.FindMessageTypeByName(message_name)
- assert message.extensions_by_name[extension_name]
+ message = self.FindMessageTypeByName(top_name)
+ assert (sub_name in message.extensions_by_name or
+ sub_name in message.fields_by_name or
+ sub_name in message.enum_values_by_name)
return message.file
- except KeyError:
+ except (KeyError, AssertionError):
raise KeyError('Cannot find a file containing %s' % symbol)
def FindMessageTypeByName(self, full_name):
@@ -499,7 +556,11 @@ class DescriptorPool(object):
KeyError: when no extension with the given number is known for the
specified message.
"""
- return self._extensions_by_number[message_descriptor][number]
+ try:
+ return self._extensions_by_number[message_descriptor][number]
+ except KeyError:
+ self._TryLoadExtensionFromDB(message_descriptor, number)
+ return self._extensions_by_number[message_descriptor][number]
def FindAllExtensions(self, message_descriptor):
"""Gets all the known extension of a given message.
@@ -513,8 +574,57 @@ class DescriptorPool(object):
Returns:
A list of FieldDescriptor describing the extensions.
"""
+ # Fallback to descriptor db if FindAllExtensionNumbers is provided.
+ if self._descriptor_db and hasattr(
+ self._descriptor_db, 'FindAllExtensionNumbers'):
+ full_name = message_descriptor.full_name
+ all_numbers = self._descriptor_db.FindAllExtensionNumbers(full_name)
+ for number in all_numbers:
+ if number in self._extensions_by_number[message_descriptor]:
+ continue
+ self._TryLoadExtensionFromDB(message_descriptor, number)
+
return list(self._extensions_by_number[message_descriptor].values())
+ def _TryLoadExtensionFromDB(self, message_descriptor, number):
+ """Try to Load extensions from decriptor db.
+
+ Args:
+ message_descriptor: descriptor of the extended message.
+ number: the extension number that needs to be loaded.
+ """
+ if not self._descriptor_db:
+ return
+ # Only supported when FindFileContainingExtension is provided.
+ if not hasattr(
+ self._descriptor_db, 'FindFileContainingExtension'):
+ return
+
+ full_name = message_descriptor.full_name
+ file_proto = self._descriptor_db.FindFileContainingExtension(
+ full_name, number)
+
+ if file_proto is None:
+ return
+
+ try:
+ file_desc = self._ConvertFileProtoToFileDescriptor(file_proto)
+ for extension in file_desc.extensions_by_name.values():
+ self._extensions_by_number[extension.containing_type][
+ extension.number] = extension
+ self._extensions_by_name[extension.containing_type][
+ extension.full_name] = extension
+ for message_type in file_desc.message_types_by_name.values():
+ for extension in message_type.extensions:
+ self._extensions_by_number[extension.containing_type][
+ extension.number] = extension
+ self._extensions_by_name[extension.containing_type][
+ extension.full_name] = extension
+ except:
+ warn_msg = ('Unable to load proto file %s for extension number %d.' %
+ (file_proto.name, number))
+ warnings.warn(warn_msg, RuntimeWarning)
+
def FindServiceByName(self, full_name):
"""Loads the named service descriptor from the pool.
@@ -532,6 +642,23 @@ class DescriptorPool(object):
self._FindFileContainingSymbolInDb(full_name)
return self._service_descriptors[full_name]
+ def FindMethodByName(self, full_name):
+ """Loads the named service method descriptor from the pool.
+
+ Args:
+ full_name: The full name of the method descriptor to load.
+
+ Returns:
+ The method descriptor for the service method.
+
+ Raises:
+ KeyError: if the method cannot be found in the pool.
+ """
+ full_name = _NormalizeFullyQualifiedName(full_name)
+ service_name, _, method_name = full_name.rpartition('.')
+ service_descriptor = self.FindServiceByName(service_name)
+ return service_descriptor.methods_by_name[method_name]
+
def _FindFileContainingSymbolInDb(self, symbol):
"""Finds the file in descriptor DB containing the specified symbol.
@@ -567,7 +694,6 @@ class DescriptorPool(object):
Returns:
A FileDescriptor matching the passed in proto.
"""
-
if file_proto.name not in self._file_descriptors:
built_deps = list(self._GetDeps(file_proto.dependency))
direct_deps = [self.FindFileByName(n) for n in file_proto.dependency]
@@ -604,7 +730,7 @@ class DescriptorPool(object):
for enum_type in file_proto.enum_type:
file_descriptor.enum_types_by_name[enum_type.name] = (
self._ConvertEnumDescriptor(enum_type, file_proto.package,
- file_descriptor, None, scope))
+ file_descriptor, None, scope, True))
for index, extension_proto in enumerate(file_proto.extension):
extension_desc = self._MakeFieldDescriptor(
@@ -616,6 +742,8 @@ class DescriptorPool(object):
file_descriptor.package, scope)
file_descriptor.extensions_by_name[extension_desc.name] = (
extension_desc)
+ self._file_desc_by_toplevel_extension[extension_desc.full_name] = (
+ file_descriptor)
for desc_proto in file_proto.message_type:
self._SetAllFieldTypes(file_proto.package, desc_proto, scope)
@@ -673,7 +801,8 @@ class DescriptorPool(object):
nested, desc_name, file_desc, scope, syntax)
for nested in desc_proto.nested_type]
enums = [
- self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope)
+ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None,
+ scope, False)
for enum in desc_proto.enum_type]
fields = [self._MakeFieldDescriptor(field, desc_name, index, file_desc)
for index, field in enumerate(desc_proto.field)]
@@ -718,12 +847,12 @@ class DescriptorPool(object):
fields[field_index].containing_oneof = oneofs[oneof_index]
scope[_PrefixWithDot(desc_name)] = desc
- self._CheckConflictRegister(desc)
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._descriptors[desc_name] = desc
return desc
def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None,
- containing_type=None, scope=None):
+ containing_type=None, scope=None, top_level=False):
"""Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf.
Args:
@@ -732,6 +861,8 @@ class DescriptorPool(object):
file_desc: The file containing the enum descriptor.
containing_type: The type containing this enum.
scope: Scope containing available types.
+ top_level: If True, the enum is a top level symbol. If False, the enum
+ is defined inside a message.
Returns:
The added descriptor
@@ -757,8 +888,17 @@ class DescriptorPool(object):
containing_type=containing_type,
options=_OptionsOrNone(enum_proto))
scope['.%s' % enum_name] = desc
- self._CheckConflictRegister(desc)
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._enum_descriptors[enum_name] = desc
+
+ # Add top level enum values.
+ if top_level:
+ for value in values:
+ full_name = _NormalizeFullyQualifiedName(
+ '.'.join((package, value.name)))
+ self._CheckConflictRegister(value, full_name, file_name)
+ self._top_enum_values[full_name] = value
+
return desc
def _MakeFieldDescriptor(self, field_proto, message_name, index,
@@ -885,6 +1025,8 @@ class DescriptorPool(object):
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
field_desc.default_value = text_encoding.CUnescape(
field_proto.default_value)
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
+ field_desc.default_value = None
else:
# All other types are of the "int" type.
field_desc.default_value = int(field_proto.default_value)
@@ -901,6 +1043,8 @@ class DescriptorPool(object):
field_desc.default_value = field_desc.enum_type.values[0].number
elif field_proto.type == descriptor.FieldDescriptor.TYPE_BYTES:
field_desc.default_value = b''
+ elif field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE:
+ field_desc.default_value = None
else:
# All other types are of the "int" type.
field_desc.default_value = 0
@@ -954,7 +1098,7 @@ class DescriptorPool(object):
methods=methods,
options=_OptionsOrNone(service_proto),
file=file_desc)
- self._CheckConflictRegister(desc)
+ self._CheckConflictRegister(desc, desc.full_name, desc.file.name)
self._service_descriptors[service_name] = desc
return desc
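
The descriptor_pool.py changes above add top-level enum value indexing, a descriptor-database fallback for extension lookup, and a new FindMethodByName() API that splits the fully qualified name and delegates to FindServiceByName(). A minimal usage sketch for FindMethodByName(), assuming the generated unittest_pb2 module (which defines protobuf_unittest.TestService with a Foo method) is importable; it mirrors the test added to descriptor_pool_test.py below:

from google.protobuf import descriptor_pool
from google.protobuf import unittest_pb2  # registers protobuf_unittest.TestService

pool = descriptor_pool.Default()
service = pool.FindServiceByName('protobuf_unittest.TestService')
method = pool.FindMethodByName('protobuf_unittest.TestService.Foo')
assert method.containing_service is service
assert method.name == 'Foo'

# Unknown methods surface as KeyError from the methods_by_name lookup.
try:
    pool.FindMethodByName('protobuf_unittest.TestService.DoesNotExist')
except KeyError:
    pass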
diff --git a/python/google/protobuf/internal/__init__.py b/python/google/protobuf/internal/__init__.py
index e69de29b..7d2e571a 100755
--- a/python/google/protobuf/internal/__init__.py
+++ b/python/google/protobuf/internal/__init__.py
@@ -0,0 +1,30 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py
index ab9e7812..23cc2c0a 100755
--- a/python/google/protobuf/internal/api_implementation.py
+++ b/python/google/protobuf/internal/api_implementation.py
@@ -145,29 +145,3 @@ def Version():
# For internal use only
def IsPythonDefaultSerializationDeterministic():
return _python_deterministic_proto_serialization
-
-# DO NOT USE: For migration and testing only. Will be removed when Proto3
-# defaults to preserve unknowns.
-if _implementation_type == 'cpp':
- try:
- # pylint: disable=g-import-not-at-top
- from google.protobuf.pyext import _message
-
- def GetPythonProto3PreserveUnknownsDefault():
- return _message.GetPythonProto3PreserveUnknownsDefault()
-
- def SetPythonProto3PreserveUnknownsDefault(preserve):
- _message.SetPythonProto3PreserveUnknownsDefault(preserve)
- except ImportError:
- # Unrecognized cpp implementation. Skipping the unknown fields APIs.
- pass
-else:
- _python_proto3_preserve_unknowns_default = True
-
- def GetPythonProto3PreserveUnknownsDefault():
- return _python_proto3_preserve_unknowns_default
-
- def SetPythonProto3PreserveUnknownsDefault(preserve):
- global _python_proto3_preserve_unknowns_default
- _python_proto3_preserve_unknowns_default = preserve
-
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index c6a3692a..182cac99 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -628,3 +628,130 @@ class MessageMap(MutableMapping):
def GetEntryClass(self):
return self._entry_descriptor._concrete_class
+
+
+class _UnknownField(object):
+
+ """A parsed unknown field."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_field_number', '_wire_type', '_data']
+
+ def __init__(self, field_number, wire_type, data):
+ self._field_number = field_number
+ self._wire_type = wire_type
+ self._data = data
+ return
+
+ def __lt__(self, other):
+ # pylint: disable=protected-access
+ return self._field_number < other._field_number
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ # pylint: disable=protected-access
+ return (self._field_number == other._field_number and
+ self._wire_type == other._wire_type and
+ self._data == other._data)
+
+
+class UnknownFieldRef(object):
+
+ def __init__(self, parent, index):
+ self._parent = parent
+ self._index = index
+ return
+
+ def _check_valid(self):
+ if not self._parent:
+ raise ValueError('UnknownField does not exist. '
+ 'The parent message might be cleared.')
+ if self._index >= len(self._parent):
+ raise ValueError('UnknownField does not exist. '
+ 'The parent message might be cleared.')
+
+ @property
+ def field_number(self):
+ self._check_valid()
+ # pylint: disable=protected-access
+ return self._parent._internal_get(self._index)._field_number
+
+ @property
+ def wire_type(self):
+ self._check_valid()
+ # pylint: disable=protected-access
+ return self._parent._internal_get(self._index)._wire_type
+
+ @property
+ def data(self):
+ self._check_valid()
+ # pylint: disable=protected-access
+ return self._parent._internal_get(self._index)._data
+
+
+class UnknownFieldSet(object):
+
+ """UnknownField container"""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_values']
+
+ def __init__(self):
+ self._values = []
+
+ def __getitem__(self, index):
+ if self._values is None:
+ raise ValueError('UnknownFields does not exist. '
+ 'The parent message might be cleared.')
+ size = len(self._values)
+ if index < 0:
+ index += size
+ if index < 0 or index >= size:
+ raise IndexError('index %d out of range' % index)
+
+ return UnknownFieldRef(self, index)
+
+ def _internal_get(self, index):
+ return self._values[index]
+
+ def __len__(self):
+ if self._values is None:
+ raise ValueError('UnknownFields does not exist. '
+ 'The parent message might be cleared.')
+ return len(self._values)
+
+ def _add(self, field_number, wire_type, data):
+ unknown_field = _UnknownField(field_number, wire_type, data)
+ self._values.append(unknown_field)
+ return unknown_field
+
+ def __iter__(self):
+ for i in range(len(self)):
+ yield UnknownFieldRef(self, i)
+
+ def _extend(self, other):
+ if other is None:
+ return
+ # pylint: disable=protected-access
+ self._values.extend(other._values)
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ # Sort unknown fields because their order shouldn't
+ # affect equality test.
+ values = list(self._values)
+ if other is None:
+ return not values
+ values.sort()
+ # pylint: disable=protected-access
+ other_values = sorted(other._values)
+ return values == other_values
+
+ def _clear(self):
+ for value in self._values:
+ # pylint: disable=protected-access
+ if isinstance(value._data, UnknownFieldSet):
+ value._data._clear() # pylint: disable=protected-access
+ self._values = None
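
The containers.py addition above introduces an internal UnknownFieldSet that stores parsed unknown fields and hands out UnknownFieldRef views, which become invalid once the set is cleared. A small sketch of those semantics, calling the leading-underscore methods directly for illustration only (they are internal; the decoder populates them during parsing):

from google.protobuf.internal import containers
from google.protobuf.internal import wire_format

unknown = containers.UnknownFieldSet()
unknown._add(1, wire_format.WIRETYPE_VARINT, 150)
unknown._add(2, wire_format.WIRETYPE_LENGTH_DELIMITED, b'abc')

ref = unknown[0]                      # UnknownFieldRef, a view into the set
assert ref.field_number == 1 and ref.data == 150
assert len(unknown) == 2

unknown._clear()                      # clearing invalidates existing refs
try:
    ref.data
except ValueError:
    pass  # 'UnknownField does not exist. The parent message might be cleared.'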
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 52b64915..938f6293 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -86,7 +86,11 @@ import six
if six.PY3:
long = int
+else:
+ import re # pylint: disable=g-import-not-at-top
+ _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]'))
+from google.protobuf.internal import containers
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message
@@ -167,7 +171,7 @@ _DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
def ReadTag(buffer, pos):
- """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.
+ """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.
We return the raw bytes of the tag rather than decoding them. The raw
bytes can then be used to look up the proper decoder. This effectively allows
@@ -175,13 +179,21 @@ def ReadTag(buffer, pos):
for work that is done in C (searching for a byte string in a hash table).
In a low-level language it would be much cheaper to decode the varint and
use that, but not in Python.
- """
+ Args:
+ buffer: memoryview object of the encoded bytes
+ pos: int of the current position to start from
+
+ Returns:
+ Tuple[bytes, int] of the tag data and new position.
+ """
start = pos
while six.indexbytes(buffer, pos) & 0x80:
pos += 1
pos += 1
- return (six.binary_type(buffer[start:pos]), pos)
+
+ tag_bytes = buffer[start:pos].tobytes()
+ return tag_bytes, pos
# --------------------------------------------------------------------
@@ -295,10 +307,20 @@ def _FloatDecoder():
local_unpack = struct.unpack
def InnerDecode(buffer, pos):
+ """Decode serialized float to a float and new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes
+ pos: int, position in the memory view to start at.
+
+ Returns:
+ Tuple[float, int] of the deserialized float value and new position
+ in the serialized data.
+ """
# We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
# bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
new_pos = pos + 4
- float_bytes = buffer[pos:new_pos]
+ float_bytes = buffer[pos:new_pos].tobytes()
# If this value has all its exponent bits set, then it's non-finite.
# In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
@@ -329,10 +351,20 @@ def _DoubleDecoder():
local_unpack = struct.unpack
def InnerDecode(buffer, pos):
+ """Decode serialized double to a double and new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+
+ Returns:
+ Tuple[float, int] of the decoded double value and new position
+ in the serialized data.
+ """
# We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
# bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
new_pos = pos + 8
- double_bytes = buffer[pos:new_pos]
+ double_bytes = buffer[pos:new_pos].tobytes()
# If this value has all its exponent bits set and at least one significand
# bit set, it's not a number. In Python 2.4, struct.unpack will treat it
@@ -355,6 +387,18 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
if is_packed:
local_DecodeVarint = _DecodeVarint
def DecodePackedField(buffer, pos, end, message, field_dict):
+ """Decode serialized packed enum to its value and a new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -365,6 +409,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
while pos < endpoint:
value_start_pos = pos
(element, pos) = _DecodeSignedVarint32(buffer, pos)
+ # pylint: disable=protected-access
if element in enum_type.values_by_number:
value.append(element)
else:
@@ -372,8 +417,10 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
message._unknown_fields = []
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_VARINT)
+
message._unknown_fields.append(
- (tag_bytes, buffer[value_start_pos:pos]))
+ (tag_bytes, buffer[value_start_pos:pos].tobytes()))
+ # pylint: enable=protected-access
if pos > endpoint:
if element in enum_type.values_by_number:
del value[-1] # Discard corrupt value.
@@ -386,18 +433,32 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
tag_len = len(tag_bytes)
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ """Decode serialized repeated enum to its value and a new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
(element, new_pos) = _DecodeSignedVarint32(buffer, pos)
+ # pylint: disable=protected-access
if element in enum_type.values_by_number:
value.append(element)
else:
if not message._unknown_fields:
message._unknown_fields = []
message._unknown_fields.append(
- (tag_bytes, buffer[pos:new_pos]))
+ (tag_bytes, buffer[pos:new_pos].tobytes()))
+ # pylint: enable=protected-access
# Predict that the next tag is another copy of the same repeated
# field.
pos = new_pos + tag_len
@@ -409,10 +470,23 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
return DecodeRepeatedField
else:
def DecodeField(buffer, pos, end, message, field_dict):
+ """Decode serialized repeated enum to its value and a new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
value_start_pos = pos
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
if pos > end:
raise _DecodeError('Truncated message.')
+ # pylint: disable=protected-access
if enum_value in enum_type.values_by_number:
field_dict[key] = enum_value
else:
@@ -421,7 +495,8 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_VARINT)
message._unknown_fields.append(
- (tag_bytes, buffer[value_start_pos:pos]))
+ (tag_bytes, buffer[value_start_pos:pos].tobytes()))
+ # pylint: enable=protected-access
return pos
return DecodeField
@@ -458,20 +533,33 @@ BoolDecoder = _ModifiedDecoder(
wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
-def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
+def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
+ is_strict_utf8=False):
"""Returns a decoder for a string field."""
local_DecodeVarint = _DecodeVarint
local_unicode = six.text_type
- def _ConvertToUnicode(byte_str):
+ def _ConvertToUnicode(memview):
+ """Convert byte to unicode."""
+ byte_str = memview.tobytes()
try:
- return local_unicode(byte_str, 'utf-8')
+ value = local_unicode(byte_str, 'utf-8')
except UnicodeDecodeError as e:
# add more information to the error message and re-raise it.
e.reason = '%s in field: %s' % (e, key.full_name)
raise
+ if is_strict_utf8 and six.PY2:
+ if _SURROGATE_PATTERN.search(value):
+ reason = ('String field %s contains invalid UTF-8 data when parsing '
+ 'a protocol buffer: surrogates not allowed. Use '
+ 'the bytes type if you intend to send raw bytes.') % (
+ key.full_name)
+ raise message.DecodeError(reason)
+
+ return value
+
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
@@ -523,7 +611,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- value.append(buffer[pos:new_pos])
+ value.append(buffer[pos:new_pos].tobytes())
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -536,7 +624,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- field_dict[key] = buffer[pos:new_pos]
+ field_dict[key] = buffer[pos:new_pos].tobytes()
return new_pos
return DecodeField
@@ -665,6 +753,18 @@ def MessageSetItemDecoder(descriptor):
local_SkipField = SkipField
def DecodeItem(buffer, pos, end, message, field_dict):
+ """Decode serialized message set to its value and new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
message_set_item_start = pos
type_id = -1
message_start = -1
@@ -695,6 +795,7 @@ def MessageSetItemDecoder(descriptor):
raise _DecodeError('MessageSet item missing message.')
extension = message.Extensions._FindExtensionByNumber(type_id)
+ # pylint: disable=protected-access
if extension is not None:
value = field_dict.get(extension)
if value is None:
@@ -707,8 +808,9 @@ def MessageSetItemDecoder(descriptor):
else:
if not message._unknown_fields:
message._unknown_fields = []
- message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
- buffer[message_set_item_start:pos]))
+ message._unknown_fields.append(
+ (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
+ # pylint: enable=protected-access
return pos
@@ -767,7 +869,7 @@ def _SkipVarint(buffer, pos, end):
# Previously ord(buffer[pos]) raised IndexError when pos is out of range.
# With this code, ord(b'') raises TypeError. Both are handled in
# python_message.py to generate a 'Truncated message' error.
- while ord(buffer[pos:pos+1]) & 0x80:
+ while ord(buffer[pos:pos+1].tobytes()) & 0x80:
pos += 1
pos += 1
if pos > end:
@@ -782,6 +884,13 @@ def _SkipFixed64(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+
+def _DecodeFixed64(buffer, pos):
+ """Decode a fixed64."""
+ new_pos = pos + 8
+ return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
+
+
def _SkipLengthDelimited(buffer, pos, end):
"""Skip a length-delimited value. Returns the new position."""
@@ -791,6 +900,7 @@ def _SkipLengthDelimited(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+
def _SkipGroup(buffer, pos, end):
"""Skip sub-group. Returns the new position."""
@@ -801,11 +911,53 @@ def _SkipGroup(buffer, pos, end):
return pos
pos = new_pos
+
+def _DecodeGroup(buffer, pos):
+ """Decode group. Returns the UnknownFieldSet and new position."""
+
+ unknown_field_set = containers.UnknownFieldSet()
+ while 1:
+ (tag_bytes, pos) = ReadTag(buffer, pos)
+ (tag, _) = _DecodeVarint(tag_bytes, 0)
+ field_number, wire_type = wire_format.UnpackTag(tag)
+ if wire_type == wire_format.WIRETYPE_END_GROUP:
+ break
+ (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
+ # pylint: disable=protected-access
+ unknown_field_set._add(field_number, wire_type, data)
+
+ return (unknown_field_set, pos)
+
+
+def _DecodeUnknownField(buffer, pos, wire_type):
+ """Decode a unknown field. Returns the UnknownField and new position."""
+
+ if wire_type == wire_format.WIRETYPE_VARINT:
+ (data, pos) = _DecodeVarint(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_FIXED64:
+ (data, pos) = _DecodeFixed64(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_FIXED32:
+ (data, pos) = _DecodeFixed32(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
+ (size, pos) = _DecodeVarint(buffer, pos)
+ data = buffer[pos:pos+size]
+ pos += size
+ elif wire_type == wire_format.WIRETYPE_START_GROUP:
+ (data, pos) = _DecodeGroup(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_END_GROUP:
+ return (0, -1)
+ else:
+ raise _DecodeError('Wrong wire type in tag.')
+
+ return (data, pos)
+
+
def _EndGroup(buffer, pos, end):
"""Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
return -1
+
def _SkipFixed32(buffer, pos, end):
"""Skip a fixed32 value. Returns the new position."""
@@ -814,6 +966,14 @@ def _SkipFixed32(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+
+def _DecodeFixed32(buffer, pos):
+ """Decode a fixed32."""
+
+ new_pos = pos + 4
+ return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
+
+
def _RaiseInvalidWireType(buffer, pos, end):
"""Skip function for unknown wire types. Raises an exception."""
diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py
index f97477b3..97e5315a 100644
--- a/python/google/protobuf/internal/descriptor_database_test.py
+++ b/python/google/protobuf/internal/descriptor_database_test.py
@@ -43,6 +43,7 @@ import warnings
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2
+from google.protobuf.internal import no_package_pb2
from google.protobuf import descriptor_database
@@ -52,7 +53,10 @@ class DescriptorDatabaseTest(unittest.TestCase):
db = descriptor_database.DescriptorDatabase()
file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
factory_test2_pb2.DESCRIPTOR.serialized_pb)
+ file_desc_proto2 = descriptor_pb2.FileDescriptorProto.FromString(
+ no_package_pb2.DESCRIPTOR.serialized_pb)
db.Add(file_desc_proto)
+ db.Add(file_desc_proto2)
self.assertEqual(file_desc_proto, db.FindFileByName(
'google/protobuf/internal/factory_test2.proto'))
@@ -76,6 +80,10 @@ class DescriptorDatabaseTest(unittest.TestCase):
# Can find enum value.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Enum.FACTORY_2_VALUE_0'))
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.FACTORY_2_VALUE_0'))
+ self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
+ '.NO_PACKAGE_VALUE_0'))
# Can find top level extension.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.another_field'))
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index 2cbf7813..1b72b0b9 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -36,7 +36,6 @@ __author__ = 'matthewtoia@google.com (Matt Toia)'
import copy
import os
-import sys
import warnings
try:
@@ -55,6 +54,7 @@ from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf.internal import file_options_test_pb2
from google.protobuf.internal import more_messages_pb2
+from google.protobuf.internal import no_package_pb2
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
@@ -120,7 +120,6 @@ class DescriptorPoolTestBase(object):
self.assertIsInstance(file_desc5, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/unittest.proto',
file_desc5.name)
-
# Tests the generated pool.
assert descriptor_pool.Default().FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.one_more_field')
@@ -129,6 +128,32 @@ class DescriptorPoolTestBase(object):
assert descriptor_pool.Default().FindFileContainingSymbol(
'protobuf_unittest.TestService')
+ # Can find field.
+ file_desc6 = self.pool.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory1Message.list_value')
+ self.assertIsInstance(file_desc6, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/internal/factory_test1.proto',
+ file_desc6.name)
+
+ # Can find top level Enum value.
+ file_desc7 = self.pool.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.FACTORY_1_VALUE_0')
+ self.assertIsInstance(file_desc7, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/internal/factory_test1.proto',
+ file_desc7.name)
+
+ # Can find nested Enum value.
+ file_desc8 = self.pool.FindFileContainingSymbol(
+ 'protobuf_unittest.TestAllTypes.FOO')
+ self.assertIsInstance(file_desc8, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/unittest.proto',
+ file_desc8.name)
+
+ # TODO(jieluo): Add tests for no package when b/13860351 is fixed.
+
+ self.assertRaises(KeyError, self.pool.FindFileContainingSymbol,
+ 'google.protobuf.python.internal.Factory1Message.none_field')
+
def testFindFileContainingSymbolFailure(self):
with self.assertRaises(KeyError):
self.pool.FindFileContainingSymbol('Does not exist')
@@ -217,11 +242,10 @@ class DescriptorPoolTestBase(object):
def testFindTypeErrors(self):
self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '')
+ self.assertRaises(KeyError, self.pool.FindMethodByName, '')
# TODO(jieluo): Fix python to raise correct errors.
if api_implementation.Type() == 'cpp':
- self.assertRaises(TypeError, self.pool.FindMethodByName, 0)
- self.assertRaises(KeyError, self.pool.FindMethodByName, '')
error_type = TypeError
else:
error_type = AttributeError
@@ -231,6 +255,7 @@ class DescriptorPoolTestBase(object):
self.assertRaises(error_type, self.pool.FindEnumTypeByName, 0)
self.assertRaises(error_type, self.pool.FindOneofByName, 0)
self.assertRaises(error_type, self.pool.FindServiceByName, 0)
+ self.assertRaises(error_type, self.pool.FindMethodByName, 0)
self.assertRaises(error_type, self.pool.FindFileContainingSymbol, 0)
if api_implementation.Type() == 'python':
error_type = KeyError
@@ -275,11 +300,6 @@ class DescriptorPoolTestBase(object):
self.pool.FindEnumTypeByName('Does not exist')
def testFindFieldByName(self):
- if isinstance(self, SecondaryDescriptorFromDescriptorDB):
- if api_implementation.Type() == 'cpp':
- # TODO(jieluo): Fix cpp extension to find field correctly
- # when descriptor pool is using an underlying database.
- return
field = self.pool.FindFieldByName(
'google.protobuf.python.internal.Factory1Message.list_value')
self.assertEqual(field.name, 'list_value')
@@ -290,11 +310,6 @@ class DescriptorPoolTestBase(object):
self.pool.FindFieldByName('Does not exist')
def testFindOneofByName(self):
- if isinstance(self, SecondaryDescriptorFromDescriptorDB):
- if api_implementation.Type() == 'cpp':
- # TODO(jieluo): Fix cpp extension to find oneof correctly
- # when descriptor pool is using an underlying database.
- return
oneof = self.pool.FindOneofByName(
'google.protobuf.python.internal.Factory2Message.oneof_field')
self.assertEqual(oneof.name, 'oneof_field')
@@ -302,11 +317,6 @@ class DescriptorPoolTestBase(object):
self.pool.FindOneofByName('Does not exist')
def testFindExtensionByName(self):
- if isinstance(self, SecondaryDescriptorFromDescriptorDB):
- if api_implementation.Type() == 'cpp':
- # TODO(jieluo): Fix cpp extension to find extension correctly
- # when descriptor pool is using an underlying database.
- return
# An extension defined in a message.
extension = self.pool.FindExtensionByName(
'google.protobuf.python.internal.Factory2Message.one_more_field')
@@ -382,6 +392,11 @@ class DescriptorPoolTestBase(object):
with self.assertRaises(KeyError):
self.pool.FindServiceByName('Does not exist')
+ method = self.pool.FindMethodByName('protobuf_unittest.TestService.Foo')
+ self.assertIs(method.containing_service, service)
+ with self.assertRaises(KeyError):
+ self.pool.FindMethodByName('protobuf_unittest.TestService.Doesnotexist')
+
def testUserDefinedDB(self):
db = descriptor_database.DescriptorDatabase()
self.pool = descriptor_pool.DescriptorPool(db)
@@ -601,6 +616,8 @@ class CreateDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase):
unittest_import_pb2.DESCRIPTOR.serialized_pb))
self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
unittest_pb2.DESCRIPTOR.serialized_pb))
+ self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ no_package_pb2.DESCRIPTOR.serialized_pb))
class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase,
@@ -620,6 +637,8 @@ class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase,
unittest_import_pb2.DESCRIPTOR.serialized_pb))
db.Add(descriptor_pb2.FileDescriptorProto.FromString(
unittest_pb2.DESCRIPTOR.serialized_pb))
+ db.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ no_package_pb2.DESCRIPTOR.serialized_pb))
self.pool = descriptor_pool.DescriptorPool(descriptor_db=db)
@@ -746,11 +765,7 @@ class MessageField(object):
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(field_type_desc, field_desc.message_type)
test.assertEqual(file_desc, field_desc.file)
- # TODO(jieluo): Fix python and cpp extension diff for message field
- # default value.
- if api_implementation.Type() == 'cpp':
- test.assertRaises(
- NotImplementedError, getattr, field_desc, 'default_value')
+ test.assertEqual(field_desc.default_value, None)
class StringField(object):
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index 02a43d15..af6bece1 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -452,6 +452,17 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual('attribute is not writable: has_options',
str(e.exception))
+ def testDefault(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ field = message_descriptor.fields_by_name['repeated_int32']
+ self.assertEqual(field.default_value, [])
+ field = message_descriptor.fields_by_name['repeated_nested_message']
+ self.assertEqual(field.default_value, [])
+ field = message_descriptor.fields_by_name['optionalgroup']
+ self.assertEqual(field.default_value, None)
+ field = message_descriptor.fields_by_name['optional_nested_message']
+ self.assertEqual(field.default_value, None)
+
class NewDescriptorTest(DescriptorTest):
"""Redo the same tests as above, but with a separate DescriptorPool."""
diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto
index d2fbbeec..f5bd0383 100644
--- a/python/google/protobuf/internal/factory_test1.proto
+++ b/python/google/protobuf/internal/factory_test1.proto
@@ -56,3 +56,17 @@ message Factory1Message {
extensions 1000 to max;
}
+
+message Factory1MethodRequest {
+ optional string argument = 1;
+}
+
+message Factory1MethodResponse {
+ optional string result = 1;
+}
+
+service Factory1Service {
+ // Dummy method for this dummy service.
+ rpc Factory1Method(Factory1MethodRequest) returns (Factory1MethodResponse) {
+ }
+}
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index 6df52ed2..b97e3f65 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -142,10 +142,8 @@ class MessageFactoryTest(unittest.TestCase):
self.assertEqual('test2', msg1.Extensions[ext2])
self.assertEqual(None,
msg1.Extensions._FindExtensionByNumber(12321))
+ self.assertRaises(TypeError, len, msg1.Extensions)
if api_implementation.Type() == 'cpp':
- # TODO(jieluo): Fix len to return the correct value.
- # self.assertEqual(2, len(msg1.Extensions))
- self.assertEqual(len(msg1.Extensions), len(msg1.Extensions))
self.assertRaises(TypeError,
msg1.Extensions._FindExtensionByName, 0)
self.assertRaises(TypeError,
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 61a56a67..1a865398 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -1,4 +1,5 @@
#! /usr/bin/env python
+# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -49,6 +50,7 @@ import copy
import math
import operator
import pickle
+import pydoc
import six
import sys
import warnings
@@ -72,6 +74,7 @@ from google.protobuf import message_factory
from google.protobuf import text_format
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
+from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
@@ -415,6 +418,37 @@ class MessageTest(BaseTestCase):
empty.ParseFromString(populated.SerializeToString())
self.assertEqual(str(empty), '')
+ def testMergeFromRepeatedField(self, message_module):
+ msg = message_module.TestAllTypes()
+ msg.repeated_int32.append(1)
+ msg.repeated_int32.append(3)
+ msg.repeated_nested_message.add(bb=1)
+ msg.repeated_nested_message.add(bb=2)
+ other_msg = message_module.TestAllTypes()
+ other_msg.repeated_nested_message.add(bb=3)
+ other_msg.repeated_nested_message.add(bb=4)
+ other_msg.repeated_int32.append(5)
+ other_msg.repeated_int32.append(7)
+
+ msg.repeated_int32.MergeFrom(other_msg.repeated_int32)
+ self.assertEqual(4, len(msg.repeated_int32))
+
+ msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
+ self.assertEqual([1, 2, 3, 4],
+ [m.bb for m in msg.repeated_nested_message])
+
+ def testAddWrongRepeatedNestedField(self, message_module):
+ msg = message_module.TestAllTypes()
+ try:
+ msg.repeated_nested_message.add('wrong')
+ except TypeError:
+ pass
+ try:
+ msg.repeated_nested_message.add(value_field='wrong')
+ except ValueError:
+ pass
+ self.assertEqual(len(msg.repeated_nested_message), 0)
+
def testRepeatedNestedFieldIteration(self, message_module):
msg = message_module.TestAllTypes()
msg.repeated_nested_message.add(bb=1)
@@ -645,6 +679,82 @@ class MessageTest(BaseTestCase):
m.payload.repeated_int32.extend([])
self.assertTrue(m.HasField('payload'))
+ def testMergeFrom(self, message_module):
+ m1 = message_module.TestAllTypes()
+ m2 = message_module.TestAllTypes()
+ # Cpp extension will lazily create a sub message which is immutable.
+ self.assertEqual(0, m1.optional_nested_message.bb)
+ m2.optional_nested_message.bb = 1
+ # Make sure cmessage points to a mutable message after merge instead of
+ # the lazily created message.
+ m1.MergeFrom(m2)
+ self.assertEqual(1, m1.optional_nested_message.bb)
+
+ # Test more nested sub message.
+ msg1 = message_module.NestedTestAllTypes()
+ msg2 = message_module.NestedTestAllTypes()
+ self.assertEqual(0, msg1.child.payload.optional_nested_message.bb)
+ msg2.child.payload.optional_nested_message.bb = 1
+ msg1.MergeFrom(msg2)
+ self.assertEqual(1, msg1.child.payload.optional_nested_message.bb)
+
+ # Test repeated field.
+ self.assertEqual(msg1.payload.repeated_nested_message,
+ msg1.payload.repeated_nested_message)
+ msg2.payload.repeated_nested_message.add().bb = 1
+ msg1.MergeFrom(msg2)
+ self.assertEqual(1, len(msg1.payload.repeated_nested_message))
+ self.assertEqual(1, msg1.payload.repeated_nested_message[0].bb)
+
+ def testMergeFromString(self, message_module):
+ m1 = message_module.TestAllTypes()
+ m2 = message_module.TestAllTypes()
+ # Cpp extension will lazily create a sub message which is immutable.
+ self.assertEqual(0, m1.optional_nested_message.bb)
+ m2.optional_nested_message.bb = 1
+ # Make sure cmessage points to a mutable message after merge instead of
+ # the lazily created message.
+ m1.MergeFromString(m2.SerializeToString())
+ self.assertEqual(1, m1.optional_nested_message.bb)
+
+ @unittest.skipIf(six.PY2, 'memoryview objects are not supported on py2')
+ def testMergeFromStringUsingMemoryViewWorksInPy3(self, message_module):
+ m2 = message_module.TestAllTypes()
+ m2.optional_string = 'scalar string'
+ m2.repeated_string.append('repeated string')
+ m2.optional_bytes = b'scalar bytes'
+ m2.repeated_bytes.append(b'repeated bytes')
+
+ serialized = m2.SerializeToString()
+ memview = memoryview(serialized)
+ m1 = message_module.TestAllTypes.FromString(memview)
+
+ self.assertEqual(m1.optional_bytes, b'scalar bytes')
+ self.assertEqual(m1.repeated_bytes, [b'repeated bytes'])
+ self.assertEqual(m1.optional_string, 'scalar string')
+ self.assertEqual(m1.repeated_string, ['repeated string'])
+ # Make sure that the memoryview was correctly converted to bytes, and
+ # that a sub-sliced memoryview is not being used.
+ self.assertIsInstance(m1.optional_bytes, bytes)
+ self.assertIsInstance(m1.repeated_bytes[0], bytes)
+ self.assertIsInstance(m1.optional_string, six.text_type)
+ self.assertIsInstance(m1.repeated_string[0], six.text_type)
+
+ @unittest.skipIf(six.PY3, 'memoryview is supported by py3')
+ def testMergeFromStringUsingMemoryViewIsPy2Error(self, message_module):
+ memview = memoryview(b'')
+ with self.assertRaises(TypeError):
+ message_module.TestAllTypes.FromString(memview)
+
+ def testMergeFromEmpty(self, message_module):
+ m1 = message_module.TestAllTypes()
+ # Cpp extension will lazily create a sub message which is immutable.
+ self.assertEqual(0, m1.optional_nested_message.bb)
+ self.assertFalse(m1.HasField('optional_nested_message'))
+ # Make sure the sub message is still immutable after merge from empty.
+ m1.MergeFromString(b'') # field state should not change
+ self.assertFalse(m1.HasField('optional_nested_message'))
+
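The MergeFrom/MergeFromString tests above all rely on the same lazy sub-message contract: reading a singular sub-message yields a default, read-only view without setting the field, and only a write (or a merge that brings data in) makes the field present and mutable. A minimal sketch of that contract, assuming the generated unittest_pb2 test module:

from google.protobuf import unittest_pb2

msg = unittest_pb2.TestAllTypes()
# Reading only returns a default view; the field itself stays unset.
assert msg.optional_nested_message.bb == 0
assert not msg.HasField('optional_nested_message')
# Writing through the view sets the field for real.
msg.optional_nested_message.bb = 7
assert msg.HasField('optional_nested_message')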
def ensureNestedMessageExists(self, msg, attribute):
"""Make sure that a nested message object exists.
@@ -1067,14 +1177,8 @@ class MessageTest(BaseTestCase):
with self.assertRaises(AttributeError):
m.repeated_int32 = []
m.repeated_int32.append(1)
- if api_implementation.Type() == 'cpp':
- # For test coverage: cpp has a different path if composite
- # field is in cache
- with self.assertRaises(TypeError):
- m.repeated_int32 = []
- else:
- with self.assertRaises(AttributeError):
- m.repeated_int32 = []
+ with self.assertRaises(AttributeError):
+ m.repeated_int32 = []
# Class to test proto2-only features (required, extensions, etc.)
@@ -1169,6 +1273,21 @@ class Proto2Test(BaseTestCase):
msg = unittest_pb2.TestAllTypes()
self.assertRaises(AttributeError, getattr, msg, 'Extensions')
+ def testMergeFromExtensions(self):
+ msg1 = more_extensions_pb2.TopLevelMessage()
+ msg2 = more_extensions_pb2.TopLevelMessage()
+ # The cpp extension lazily creates an immutable sub-message on first read.
+ self.assertEqual(0, msg1.submessage.Extensions[
+ more_extensions_pb2.optional_int_extension])
+ self.assertFalse(msg1.HasField('submessage'))
+ msg2.submessage.Extensions[
+ more_extensions_pb2.optional_int_extension] = 123
+ # Make sure cmessage and extensions point to a mutable message
+ # after the merge instead of the lazily created message.
+ msg1.MergeFrom(msg2)
+ self.assertEqual(123, msg1.submessage.Extensions[
+ more_extensions_pb2.optional_int_extension])
+
def testGoldenExtensions(self):
golden_data = test_util.GoldenFileData('golden_message')
golden_message = unittest_pb2.TestAllExtensions()
@@ -1316,6 +1435,15 @@ class Proto2Test(BaseTestCase):
unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
+ def test_documentation(self):
+ # Also used by the interactive help() function.
+ doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
+ self.assertIn('class TestAllTypes', doc)
+ self.assertIn('SerializePartialToString', doc)
+ self.assertIn('repeated_float', doc)
+ base = unittest_pb2.TestAllTypes.__bases__[0]
+ self.assertRaises(AttributeError, getattr, base, '_extensions_by_name')
+
# Class to test proto3-only features/behavior (updated field presence & enums)
class Proto3Test(BaseTestCase):
@@ -1539,10 +1667,8 @@ class Proto3Test(BaseTestCase):
self.assertEqual(True, msg2.map_bool_bool[True])
self.assertEqual(2, msg2.map_int32_enum[888])
self.assertEqual(456, msg2.map_int32_enum[123])
- # TODO(jieluo): Add cpp extension support.
- if api_implementation.Type() == 'python':
- self.assertEqual('{-123: -456}',
- str(msg2.map_int32_int32))
+ self.assertEqual('{-123: -456}',
+ str(msg2.map_int32_int32))
def testMapEntryAlwaysSerialized(self):
msg = map_unittest_pb2.TestMap()
@@ -1603,11 +1729,10 @@ class Proto3Test(BaseTestCase):
self.assertIn(123, msg2.map_int32_foreign_message)
self.assertIn(-456, msg2.map_int32_foreign_message)
self.assertEqual(2, len(msg2.map_int32_foreign_message))
+ msg2.map_int32_foreign_message[123].c = 1
# TODO(jieluo): Fix text format for message map.
- # TODO(jieluo): Add cpp extension support.
- if api_implementation.Type() == 'python':
- self.assertEqual(15,
- len(str(msg2.map_int32_foreign_message)))
+ self.assertIn(str(msg2.map_int32_foreign_message),
+ ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }'))
def testNestedMessageMapItemDelete(self):
msg = map_unittest_pb2.TestMap()
@@ -1721,6 +1846,15 @@ class Proto3Test(BaseTestCase):
self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
+ # Test the case where the cpp extension caches a map.
+ m1 = map_unittest_pb2.TestMap()
+ m2 = map_unittest_pb2.TestMap()
+ self.assertEqual(m1.map_int32_foreign_message,
+ m1.map_int32_foreign_message)
+ m2.map_int32_foreign_message[123].c = 10
+ m1.MergeFrom(m2)
+ self.assertEqual(10, m1.map_int32_foreign_message[123].c)
+
def testMergeFromBadType(self):
msg = map_unittest_pb2.TestMap()
with self.assertRaisesRegexp(
@@ -1972,7 +2106,7 @@ class Proto3Test(BaseTestCase):
def testMapValidAfterFieldCleared(self):
# Map needs to work even if field is cleared.
# For the C++ implementation this tests the correctness of
- # ScalarMapContainer::Release()
+ # MapContainer::Release()
msg = map_unittest_pb2.TestMap()
int32_map = msg.map_int32_int32
@@ -1988,7 +2122,7 @@ class Proto3Test(BaseTestCase):
def testMessageMapValidAfterFieldCleared(self):
# Map needs to work even if field is cleared.
# For the C++ implementation this tests the correctness of
- # ScalarMapContainer::Release()
+ # MapContainer::Release()
msg = map_unittest_pb2.TestMap()
int32_foreign_message = msg.map_int32_foreign_message
@@ -1998,6 +2132,24 @@ class Proto3Test(BaseTestCase):
self.assertEqual(b'', msg.SerializeToString())
self.assertTrue(2 in int32_foreign_message.keys())
+ def testMessageMapItemValidAfterTopMessageCleared(self):
+ # A message map item needs to stay usable even after the top message is cleared.
+ # For the C++ implementation this tests the correctness of
+ # MapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_all_types[2].optional_string = 'bar'
+
+ if api_implementation.Type() == 'cpp':
+ # Need to keep the map reference because of b/27942626.
+ # TODO(jieluo): Remove it.
+ unused_map = msg.map_int32_all_types # pylint: disable=unused-variable
+ msg_value = msg.map_int32_all_types[2]
+ msg.Clear()
+
+ # Reset to trigger sync between repeated field and map in c++.
+ msg.map_int32_all_types[3].optional_string = 'foo'
+ self.assertEqual(msg_value.optional_string, 'bar')
+
def testMapIterInvalidatedByClearField(self):
# Map iterator is invalidated when the field is cleared.
# But this case must not crash the interpreter.
@@ -2058,6 +2210,80 @@ class Proto3Test(BaseTestCase):
msg.map_string_foreign_message['foo'].c = 5
self.assertEqual(0, len(msg.FindInitializationErrors()))
+ def testStrictUtf8Check(self):
+ # Test that u'\ud801' is rejected by the parser in both python2 and python3.
+ serialized = (b'r\x03\xed\xa0\x81')
+ msg = unittest_proto3_arena_pb2.TestAllTypes()
+ with self.assertRaises(Exception) as context:
+ msg.MergeFromString(serialized)
+ if api_implementation.Type() == 'python':
+ self.assertIn('optional_string', str(context.exception))
+ else:
+ self.assertIn('Error parsing message', str(context.exception))
+
+ # Test that optional_string=u'😍' is accepted.
+ serialized = unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'😍').SerializeToString()
+ msg2 = unittest_proto3_arena_pb2.TestAllTypes()
+ msg2.MergeFromString(serialized)
+ self.assertEqual(msg2.optional_string, u'😍')
+
+ msg = unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud001')
+ self.assertEqual(msg.optional_string, u'\ud001')
+
+ @unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2')
+ def testSurrogatesInPython3(self):
+ # A surrogate such as U+D83D is an invalid unicode character; it is
+ # supported by Python2 only because some builds use 2-byte code units
+ # for unicode strings. Since Python 3.3, this is no longer a problem.
+ #
+ # Surrogates are utf16 code units, and in a unicode string they are
+ # invalid characters even when they appear in pairs like u'\ud801\udc01'.
+ # Protobuf on Python3 rejects such values at setters and parsers. Python2
+ # accepts them to stay consistent with the language itself. Unpaired
+ # surrogates like u'\ud801' are rejected at parsers when the strict utf8
+ # check is enabled in proto3, to match the behavior of the c extension.
+
+ # Surrogates are rejected at setters in Python3.
+ with self.assertRaises(ValueError):
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801\udc01')
+ with self.assertRaises(ValueError):
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=b'\xed\xa0\x81')
+ with self.assertRaises(ValueError):
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801')
+ with self.assertRaises(ValueError):
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801\ud801')
+
+ @unittest.skipIf(six.PY3, 'Surrogates are rejected at setters in Python3')
+ def testSurrogatesInPython2(self):
+ # Test optional_string=u'\ud801\udc01'.
+ # A surrogate pair is acceptable in python2.
+ msg = unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801\udc01')
+ # TODO(jieluo): Change pure python to have the same behavior as the c
+ # extension. Some builds of python2 consider u'\ud801\udc01' and
+ # u'\U00010401' equal, while others do not.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(msg.optional_string, u'\ud801\udc01')
+ else:
+ self.assertEqual(msg.optional_string, u'\U00010401')
+ serialized = msg.SerializeToString()
+ msg2 = unittest_proto3_arena_pb2.TestAllTypes()
+ msg2.MergeFromString(serialized)
+ self.assertEqual(msg2.optional_string, u'\U00010401')
+
+ # Python2 does not reject surrogates at setters.
+ msg = unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=b'\xed\xa0\x81')
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801')
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801\ud801')
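A condensed sketch of the strict utf8 behavior exercised by the tests above, assuming the same unittest_proto3_arena_pb2 module: a well-formed code point round-trips, while a lone surrogate is rejected at the setter on Python 3.

import sys
from google.protobuf import unittest_proto3_arena_pb2

msg = unittest_proto3_arena_pb2.TestAllTypes()
msg.optional_string = u'\U0001F34C'  # a valid code point is accepted
roundtrip = unittest_proto3_arena_pb2.TestAllTypes()
roundtrip.MergeFromString(msg.SerializeToString())
assert roundtrip.optional_string == u'\U0001F34C'

if sys.version_info[0] >= 3:
    try:
        msg.optional_string = u'\ud801'  # lone surrogate
    except ValueError:
        pass  # rejected at the setter on Python 3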
class ValidTypeNamesTest(BaseTestCase):
diff --git a/python/google/protobuf/internal/no_package.proto b/python/google/protobuf/internal/no_package.proto
index 3546dcc3..49eda959 100644
--- a/python/google/protobuf/internal/no_package.proto
+++ b/python/google/protobuf/internal/no_package.proto
@@ -1,3 +1,33 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
syntax = "proto2";
enum NoPackageEnum {
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 975e3b4d..ab5d160f 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -56,6 +56,7 @@ import sys
import weakref
import six
+from six.moves import range
# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import api_implementation
@@ -124,6 +125,21 @@ class GeneratedProtocolMessageType(type):
Newly-allocated class.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+
+ # If a concrete class already exists for this descriptor, don't try to
+ # create another. Doing so would break any messages already instantiated
+ # from the existing class.
+ #
+ # The C++ implementation appears to have its own internal `PyMessageFactory`
+ # to achieve similar results.
+ #
+ # This most commonly happens in `text_format.py` when using descriptors from
+ # a custom pool; it calls symbol_database.Global().GetPrototype() on a
+ # descriptor which already has an existing concrete class.
+ new_class = getattr(descriptor, '_concrete_class', None)
+ if new_class:
+ return new_class
+
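The invariant this early return preserves can be stated as a short sketch (assuming the generated unittest_pb2 module and the public symbol_database API): asking for a prototype of a descriptor that already has a concrete class must return that same class, not a duplicate.

from google.protobuf import symbol_database
from google.protobuf import unittest_pb2

cls = symbol_database.Default().GetPrototype(
    unittest_pb2.TestAllTypes.DESCRIPTOR)
# The existing generated class is reused rather than rebuilt.
assert cls is unittest_pb2.TestAllTypes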
if descriptor.full_name in well_known_types.WKTBASES:
bases += (well_known_types.WKTBASES[descriptor.full_name],)
_AddClassAttributesForNestedExtensions(descriptor, dictionary)
@@ -151,6 +167,16 @@ class GeneratedProtocolMessageType(type):
type.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+
+ # If this is an _existing_ class looked up via `_concrete_class` in the
+ # __new__ method above, then we don't need to re-initialize anything.
+ existing_class = getattr(descriptor, '_concrete_class', None)
+ if existing_class:
+ assert existing_class is cls, (
+ 'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
+ % (descriptor.full_name))
+ return
+
cls._decoders_by_tag = {}
if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format):
@@ -245,6 +271,7 @@ def _AddSlots(message_descriptor, dictionary):
'_cached_byte_size_dirty',
'_fields',
'_unknown_fields',
+ '_unknown_field_set',
'_is_present_in_parent',
'_listener',
'_listener_for_children',
@@ -271,6 +298,13 @@ def _IsMessageMapField(field):
return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
+def _IsStrictUtf8Check(field):
+ if field.containing_type.syntax != 'proto3':
+ return False
+ enforce_utf8 = True
+ return enforce_utf8
+
+
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
is_packable = (is_repeated and
@@ -322,10 +356,16 @@ def _AttachFieldHelpers(cls, field_descriptor):
field_decoder = decoder.MapDecoder(
field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
is_message_map)
+ elif decode_type == _FieldDescriptor.TYPE_STRING:
+ is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor)
+ field_decoder = decoder.StringDecoder(
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor,
+ is_strict_utf8_check)
else:
field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor)
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
@@ -422,6 +462,9 @@ def _DefaultValueConstructorForField(field):
# _concrete_class may not yet be initialized.
message_type = field.message_type
def MakeSubMessageDefault(message):
+ assert getattr(message_type, '_concrete_class', None), (
+ 'Uninitialized concrete class found for field %r (message type %r)'
+ % (field.full_name, message_type.full_name))
result = message_type._concrete_class()
result._SetListener(
_OneofListener(message, field)
@@ -477,6 +520,9 @@ def _AddInitMethod(message_descriptor, cls):
# _unknown_fields is () when empty for efficiency, and will be turned into
# a list if fields are added.
self._unknown_fields = ()
+ # _unknown_field_set is None when empty for efficiency, and will be
+ # turned into an UnknownFieldSet if fields are added.
+ self._unknown_field_set = None # pylint: disable=protected-access
self._is_present_in_parent = False
self._listener = message_listener_mod.NullMessageListener()
self._listener_for_children = _Listener(self)
@@ -584,6 +630,14 @@ def _AddPropertiesForField(field, cls):
_AddPropertiesForNonRepeatedScalarField(field, cls)
+class _FieldProperty(property):
+ __slots__ = ('DESCRIPTOR',)
+
+ def __init__(self, descriptor, getter, setter, doc):
+ property.__init__(self, getter, setter, doc=doc)
+ self.DESCRIPTOR = descriptor
+
+
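Because each generated property is now a _FieldProperty, the FieldDescriptor is reachable directly from the class attribute. A small usage sketch, assuming the generated unittest_pb2 module (the same pattern is covered by testFieldProperties in reflection_test.py below):

from google.protobuf import unittest_pb2

field = unittest_pb2.TestAllTypes.optional_int32.DESCRIPTOR
assert field is unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32']
assert field.number == unittest_pb2.TestAllTypes.OPTIONAL_INT32_FIELD_NUMBER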
def _AddPropertiesForRepeatedField(field, cls):
"""Adds a public property for a "repeated" protocol message field. Clients
can use this property to get the value of the field, which will be either a
@@ -625,7 +679,7 @@ def _AddPropertiesForRepeatedField(field, cls):
'"%s" in protocol message object.' % proto_field_name)
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
+ setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
@@ -681,7 +735,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
# Add a property to encapsulate the getter/setter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
+ setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
@@ -725,7 +779,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
# Add a property to encapsulate the getter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
+ setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
@@ -949,13 +1003,8 @@ def _AddEqualsMethod(message_descriptor, cls):
if not self.ListFields() == other.ListFields():
return False
- # Sort unknown fields because their order shouldn't affect equality test.
- unknown_fields = list(self._unknown_fields)
- unknown_fields.sort()
- other_unknown_fields = list(other._unknown_fields)
- other_unknown_fields.sort()
-
- return unknown_fields == other_unknown_fields
+ # pylint: disable=protected-access
+ return self._unknown_field_set == other._unknown_field_set
cls.__eq__ = __eq__
@@ -1078,6 +1127,13 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
def _AddMergeFromStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def MergeFromString(self, serialized):
+ if isinstance(serialized, memoryview) and six.PY2:
+ raise TypeError(
+ 'memoryview not supported in Python 2 with the pure Python proto '
+ 'implementation: this is to maintain compatibility with the C++ '
+ 'implementation')
+
+ serialized = memoryview(serialized)
length = len(serialized)
try:
if self._InternalParse(serialized, 0, length) != length:
@@ -1095,26 +1151,54 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
- is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
+ """Create a message from serialized bytes.
+
+ Args:
+ self: Message, instance of the proto message object.
+ buffer: memoryview of the serialized data.
+ pos: int, position to start in the serialized data.
+ end: int, end position of the serialized data.
+
+ Returns:
+ int, the new position in the serialized data after parsing.
+ """
+ # Guard against internal misuse, since this function is called internally
+ # quite extensively, and its easy to accidentally pass bytes.
+ assert isinstance(buffer, memoryview)
self._Modified()
field_dict = self._fields
- unknown_field_list = self._unknown_fields
+ # pylint: disable=protected-access
+ unknown_field_set = self._unknown_field_set
while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
if field_decoder is None:
- value_start_pos = new_pos
- new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
+ if not self._unknown_fields: # pylint: disable=protected-access
+ self._unknown_fields = [] # pylint: disable=protected-access
+ if unknown_field_set is None:
+ # pylint: disable=protected-access
+ self._unknown_field_set = containers.UnknownFieldSet()
+ # pylint: disable=protected-access
+ unknown_field_set = self._unknown_field_set
+ # pylint: disable=protected-access
+ (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
+ field_number, wire_type = wire_format.UnpackTag(tag)
+ # TODO(jieluo): remove old_pos.
+ old_pos = new_pos
+ (data, new_pos) = decoder._DecodeUnknownField(
+ buffer, new_pos, wire_type) # pylint: disable=protected-access
if new_pos == -1:
return pos
- if (not is_proto3 or
- api_implementation.GetPythonProto3PreserveUnknownsDefault()):
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append(
- (tag_bytes, buffer[value_start_pos:new_pos]))
+ # pylint: disable=protected-access
+ unknown_field_set._add(field_number, wire_type, data)
+ # TODO(jieluo): remove _unknown_fields.
+ new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
+ if new_pos == -1:
+ return pos
+ self._unknown_fields.append(
+ (tag_bytes, buffer[old_pos:new_pos].tobytes()))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
@@ -1259,6 +1343,10 @@ def _AddMergeFromMethod(cls):
if not self._unknown_fields:
self._unknown_fields = []
self._unknown_fields.extend(msg._unknown_fields)
+ # pylint: disable=protected-access
+ if self._unknown_field_set is None:
+ self._unknown_field_set = containers.UnknownFieldSet()
+ self._unknown_field_set._extend(msg._unknown_field_set)
cls.MergeFrom = MergeFrom
@@ -1291,12 +1379,25 @@ def _Clear(self):
# Clear fields.
self._fields = {}
self._unknown_fields = ()
+ # pylint: disable=protected-access
+ if self._unknown_field_set is not None:
+ self._unknown_field_set._clear()
+ self._unknown_field_set = None
+
self._oneofs = {}
self._Modified()
+def _UnknownFields(self):
+ if self._unknown_field_set is None: # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ self._unknown_field_set = containers.UnknownFieldSet()
+ return self._unknown_field_set # pylint: disable=protected-access
+
+
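A minimal sketch of the accessor added here, assuming the generated unittest_pb2 messages: unknown fields collected during parsing can be inspected by field number, wire type, and decoded data (the full behavior is covered in unknown_fields_test.py below).

from google.protobuf import unittest_pb2

full = unittest_pb2.TestAllTypes(optional_int32=123)
empty = unittest_pb2.TestEmptyMessage()
empty.ParseFromString(full.SerializeToString())

unknown = empty.UnknownFields()
assert len(unknown) == 1
assert unknown[0].field_number == (
    unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name['optional_int32'].number)
assert unknown[0].data == 123  # varint payloads are decoded to plain ints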
def _DiscardUnknownFields(self):
self._unknown_fields = []
+ self._unknown_field_set = None # pylint: disable=protected-access
for field, value in self.ListFields():
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
@@ -1335,6 +1436,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddReduceMethod(cls)
# Adds methods which do not depend on cls.
cls.Clear = _Clear
+ cls.UnknownFields = _UnknownFields
cls.DiscardUnknownFields = _DiscardUnknownFields
cls._SetListener = _SetListener
@@ -1471,6 +1573,10 @@ class _ExtensionDict(object):
if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
result = extension_handle._default_constructor(self._extended_message)
elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ assert getattr(extension_handle.message_type, '_concrete_class', None), (
+ 'Uninitialized concrete class found for field %r (message type %r)'
+ % (extension_handle.full_name,
+ extension_handle.message_type.full_name))
result = extension_handle.message_type._concrete_class()
try:
result._SetListener(self._extended_message._listener_for_children)
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 0306ff46..31ceda24 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -64,6 +64,10 @@ from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import decoder
+if six.PY3:
+ long = int # pylint: disable=redefined-builtin,invalid-name
+
+
BaseTestCase = testing_refleaks.BaseTestCase
@@ -647,10 +651,7 @@ class ReflectionTest(BaseTestCase):
TestGetAndDeserialize('optional_int32', 1, int)
TestGetAndDeserialize('optional_int32', 1 << 30, int)
TestGetAndDeserialize('optional_uint32', 1 << 30, int)
- try:
- integer_64 = long
- except NameError: # Python3
- integer_64 = int
+ integer_64 = long
if struct.calcsize('L') == 4:
# Python only has signed ints, so 32-bit python can't fit an uint32
# in an int.
@@ -1103,6 +1104,7 @@ class ReflectionTest(BaseTestCase):
self.assertEqual(23, myproto_instance.foo_field)
self.assertTrue(myproto_instance.HasField('foo_field'))
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
def testDescriptorProtoSupport(self):
# Hand written descriptors/reflection are only supported by the pure-Python
# implementation of the API.
@@ -1141,7 +1143,8 @@ class ReflectionTest(BaseTestCase):
self.assertTrue('price' in desc.fields_by_name)
self.assertTrue('owners' in desc.fields_by_name)
- class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
+ class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType,
+ message.Message)):
DESCRIPTOR = desc
prius = CarMessage()
@@ -2435,7 +2438,7 @@ class SerializationTest(BaseTestCase):
first_proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(first_proto)
- serialized = first_proto.SerializeToString()
+ serialized = memoryview(first_proto.SerializeToString())
for truncation_point in range(len(serialized) + 1):
try:
@@ -2857,6 +2860,38 @@ class SerializationTest(BaseTestCase):
self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
51)
+ def testFieldProperties(self):
+ cls = unittest_pb2.TestAllTypes
+ self.assertIs(cls.optional_int32.DESCRIPTOR,
+ cls.DESCRIPTOR.fields_by_name['optional_int32'])
+ self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
+ cls.optional_int32.DESCRIPTOR.number)
+ self.assertIs(cls.optional_nested_message.DESCRIPTOR,
+ cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
+ self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
+ cls.optional_nested_message.DESCRIPTOR.number)
+ self.assertIs(cls.repeated_int32.DESCRIPTOR,
+ cls.DESCRIPTOR.fields_by_name['repeated_int32'])
+ self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
+ cls.repeated_int32.DESCRIPTOR.number)
+
+ def testFieldDataDescriptor(self):
+ msg = unittest_pb2.TestAllTypes()
+ msg.optional_int32 = 42
+ self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
+ unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
+ self.assertEqual(msg.optional_int32, 25)
+ with self.assertRaises(AttributeError):
+ del msg.optional_int32
+ try:
+ unittest_pb2.ForeignMessage.c.__get__(msg)
+ except TypeError:
+ pass # The cpp implementation cannot mix fields from other messages.
+ # This test exercises a specific check that avoids a crash.
+ else:
+ pass # The python implementation allows fields from other messages.
+ # This is useless, but works.
+
def testInitKwargs(self):
proto = unittest_pb2.TestAllTypes(
optional_int32=1,
@@ -2963,6 +2998,7 @@ class ClassAPITest(BaseTestCase):
@unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
'C++ implementation requires a call to MakeDescriptor()')
+ @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
def testMakeClassWithNestedDescriptor(self):
leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
containing_type=None, fields=[],
@@ -2980,10 +3016,7 @@ class ClassAPITest(BaseTestCase):
containing_type=None, fields=[],
nested_types=[child_desc, sibling_desc],
enum_types=[], extensions=[])
- message_class = reflection.MakeClass(parent_desc)
- self.assertIn('child', message_class.__dict__)
- self.assertIn('sibling', message_class.__dict__)
- self.assertIn('leaf', message_class.child.__dict__)
+ reflection.MakeClass(parent_desc)
def _GetSerializedFileDescriptor(self, name):
"""Get a serialized representation of a test FileDescriptorProto.
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index 237a2d50..c68f42d2 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -33,20 +33,19 @@
"""Test for google.protobuf.text_format."""
-__author__ = 'kenton@google.com (Kenton Varda)'
-
-
+import io
import math
import re
-import six
import string
+import textwrap
+import six
+
+# pylint: disable=g-import-not-at-top
try:
- import unittest2 as unittest # PY26, pylint: disable=g-import-not-at-top
+ import unittest2 as unittest # PY26
except ImportError:
- import unittest # pylint: disable=g-import-not-at-top
-
-from google.protobuf.internal import _parameterized
+ import unittest
from google.protobuf import any_pb2
from google.protobuf import any_test_pb2
@@ -54,12 +53,13 @@ from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
-from google.protobuf.internal import api_implementation
from google.protobuf.internal import any_test_pb2 as test_extend_any
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import test_util
from google.protobuf import descriptor_pool
from google.protobuf import text_format
+from google.protobuf.internal import _parameterized
+# pylint: enable=g-import-not-at-top
# Low-level nuts-n-bolts tests.
@@ -100,8 +100,8 @@ class TextFormatBase(unittest.TestCase):
return text
-@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2))
-class TextFormatTest(TextFormatBase):
+@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
+class TextFormatMessageToStringTests(TextFormatBase):
def testPrintExotic(self, message_module):
message = message_module.TestAllTypes()
@@ -154,6 +154,40 @@ class TextFormatTest(TextFormatBase):
'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 '
'repeated_string: "Google" repeated_string: "Zurich"')
+ def VerifyPrintShortFormatRepeatedFields(self, message_module, as_one_line):
+ message = message_module.TestAllTypes()
+ message.repeated_int32.append(1)
+ message.repeated_string.append('Google')
+ message.repeated_string.append('Hello,World')
+ message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_FOO)
+ message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
+ message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
+ message.optional_nested_message.bb = 3
+ for i in (21, 32):
+ msg = message.repeated_nested_message.add()
+ msg.bb = i
+ expected_ascii = (
+ 'optional_nested_message {\n bb: 3\n}\n'
+ 'repeated_int32: [1]\n'
+ 'repeated_string: "Google"\n'
+ 'repeated_string: "Hello,World"\n'
+ 'repeated_nested_message {\n bb: 21\n}\n'
+ 'repeated_nested_message {\n bb: 32\n}\n'
+ 'repeated_foreign_enum: [FOREIGN_FOO, FOREIGN_BAR, FOREIGN_BAZ]\n')
+ if as_one_line:
+ expected_ascii = expected_ascii.replace('\n ', '').replace('\n', '')
+ actual_ascii = text_format.MessageToString(
+ message, use_short_repeated_primitives=True,
+ as_one_line=as_one_line)
+ self.CompareToGoldenText(actual_ascii, expected_ascii)
+ parsed_message = message_module.TestAllTypes()
+ text_format.Parse(actual_ascii, parsed_message)
+ self.assertEqual(parsed_message, message)
+
+ def testPrintShortFormatRepeatedFields(self, message_module):
+ self.VerifyPrintShortFormatRepeatedFields(message_module, False)
+ self.VerifyPrintShortFormatRepeatedFields(message_module, True)
+
def testPrintNestedNewLineInStringAsOneLine(self, message_module):
message = message_module.TestAllTypes()
message.optional_string = 'a\nnew\nline'
@@ -213,13 +247,18 @@ class TextFormatTest(TextFormatBase):
def testPrintRawUtf8String(self, message_module):
message = message_module.TestAllTypes()
- message.repeated_string.append(u'\u00fc\ua71f')
+ message.repeated_string.append(u'\u00fc\t\ua71f')
text = text_format.MessageToString(message, as_utf8=True)
- self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n')
+ golden_unicode = u'repeated_string: "\u00fc\\t\ua71f"\n'
+ golden_text = golden_unicode if six.PY3 else golden_unicode.encode('utf-8')
+ # MessageToString always returns a native str.
+ self.CompareToGoldenText(text, golden_text)
parsed_message = message_module.TestAllTypes()
text_format.Parse(text, parsed_message)
- self.assertEqual(message, parsed_message,
- '\n%s != %s' % (message, parsed_message))
+ self.assertEqual(
+ message, parsed_message, '\n%s != %s (%s != %s)' %
+ (message, parsed_message, message.repeated_string[0],
+ parsed_message.repeated_string[0]))
def testPrintFloatFormat(self, message_module):
# Check that float_format argument is passed to sub-message formatting.
@@ -259,6 +298,36 @@ class TextFormatTest(TextFormatBase):
message.c = 123
self.assertEqual('c: 123\n', str(message))
+ def testMessageToStringUnicode(self, message_module):
+ golden_unicode = u'Á short desçription and a 🍌.'
+ golden_bytes = golden_unicode.encode('utf-8')
+ message = message_module.TestAllTypes()
+ message.optional_string = golden_unicode
+ message.optional_bytes = golden_bytes
+ text = text_format.MessageToString(message, as_utf8=True)
+ golden_message = textwrap.dedent(
+ 'optional_string: "Á short desçription and a 🍌."\n'
+ 'optional_bytes: '
+ r'"\303\201 short des\303\247ription and a \360\237\215\214."'
+ '\n')
+ self.CompareToGoldenText(text, golden_message)
+
+ def testMessageToStringASCII(self, message_module):
+ golden_unicode = u'Á short desçription and a 🍌.'
+ golden_bytes = golden_unicode.encode('utf-8')
+ message = message_module.TestAllTypes()
+ message.optional_string = golden_unicode
+ message.optional_bytes = golden_bytes
+ text = text_format.MessageToString(message, as_utf8=False) # ASCII
+ golden_message = (
+ 'optional_string: '
+ r'"\303\201 short des\303\247ription and a \360\237\215\214."'
+ '\n'
+ 'optional_bytes: '
+ r'"\303\201 short des\303\247ription and a \360\237\215\214."'
+ '\n')
+ self.CompareToGoldenText(text, golden_message)
+
def testPrintField(self, message_module):
message = message_module.TestAllTypes()
field = message.DESCRIPTOR.fields_by_name['optional_float']
@@ -289,6 +358,45 @@ class TextFormatTest(TextFormatBase):
self.assertEqual('0.0', out.getvalue())
out.close()
+
+@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
+class TextFormatMessageToTextBytesTests(TextFormatBase):
+
+ def testMessageToBytes(self, message_module):
+ message = message_module.ForeignMessage()
+ message.c = 123
+ self.assertEqual(b'c: 123\n', text_format.MessageToBytes(message))
+
+ def testRawUtf8RoundTrip(self, message_module):
+ message = message_module.TestAllTypes()
+ message.repeated_string.append(u'\u00fc\t\ua71f')
+ utf8_text = text_format.MessageToBytes(message, as_utf8=True)
+ golden_bytes = b'repeated_string: "\xc3\xbc\\t\xea\x9c\x9f"\n'
+ self.CompareToGoldenText(utf8_text, golden_bytes)
+ parsed_message = message_module.TestAllTypes()
+ text_format.Parse(utf8_text, parsed_message)
+ self.assertEqual(
+ message, parsed_message, '\n%s != %s (%s != %s)' %
+ (message, parsed_message, message.repeated_string[0],
+ parsed_message.repeated_string[0]))
+
+ def testEscapedUtf8ASCIIRoundTrip(self, message_module):
+ message = message_module.TestAllTypes()
+ message.repeated_string.append(u'\u00fc\t\ua71f')
+ ascii_text = text_format.MessageToBytes(message) # as_utf8=False default
+ golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n'
+ self.CompareToGoldenText(ascii_text, golden_bytes)
+ parsed_message = message_module.TestAllTypes()
+ text_format.Parse(ascii_text, parsed_message)
+ self.assertEqual(
+ message, parsed_message, '\n%s != %s (%s != %s)' %
+ (message, parsed_message, message.repeated_string[0],
+ parsed_message.repeated_string[0]))
+
+
+@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
+class TextFormatParserTests(TextFormatBase):
+
def testParseAllFields(self, message_module):
message = message_module.TestAllTypes()
test_util.SetAllFields(message)
@@ -318,14 +426,14 @@ class TextFormatTest(TextFormatBase):
if message_module is unittest_pb2:
test_util.ExpectAllFieldsSet(self, message)
- if six.PY2:
- msg2 = message_module.TestAllTypes()
- text = (u'optional_string: "café"')
- text_format.Merge(text, msg2)
- self.assertEqual(msg2.optional_string, u'café')
- msg2.Clear()
- text_format.Parse(text, msg2)
- self.assertEqual(msg2.optional_string, u'café')
+ msg2 = message_module.TestAllTypes()
+ text = (u'optional_string: "café"')
+ text_format.Merge(text, msg2)
+ self.assertEqual(msg2.optional_string, u'café')
+ msg2.Clear()
+ self.assertEqual(msg2.optional_string, u'')
+ text_format.Parse(text, msg2)
+ self.assertEqual(msg2.optional_string, u'café')
def testParseExotic(self, message_module):
message = message_module.TestAllTypes()
@@ -425,7 +533,8 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
six.assertRaisesRegex(self, text_format.ParseError,
- (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ (r'1:23 : \'optional_nested_enum: BARR\': '
+ r'Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value named BARR.'), text_format.Parse,
text, message)
@@ -433,7 +542,8 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes()
text = 'optional_int32: bork'
six.assertRaisesRegex(self, text_format.ParseError,
- ('1:17 : Couldn\'t parse integer: bork'),
+ ('1:17 : \'optional_int32: bork\': '
+ 'Couldn\'t parse integer: bork'),
text_format.Parse, text, message)
def testParseStringFieldUnescape(self, message_module):
@@ -457,6 +567,96 @@ class TextFormatTest(TextFormatBase):
message.repeated_string[4])
self.assertEqual(SLASH + 'x20', message.repeated_string[5])
+ def testParseOneof(self, message_module):
+ m = message_module.TestAllTypes()
+ m.oneof_uint32 = 11
+ m2 = message_module.TestAllTypes()
+ text_format.Parse(text_format.MessageToString(m), m2)
+ self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+
+ def testParseMultipleOneof(self, message_module):
+ m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
+ m2 = message_module.TestAllTypes()
+ with six.assertRaisesRegex(self, text_format.ParseError,
+ ' is specified along with field '):
+ text_format.Parse(m_string, m2)
+
+ # This example contains non-ASCII codepoint unicode data as literals
+ # which should come through as utf-8 for bytes, and as the unicode
+ # itself for string fields. It also demonstrates escaped binary data.
+ # The ur"" string prefix is unfortunately missing from Python 3
+ # so we resort to double escaping our \s so that they come through.
+ _UNICODE_SAMPLE = u"""
+ optional_bytes: 'Á short desçription'
+ optional_string: 'Á short desçription'
+ repeated_bytes: '\\303\\201 short des\\303\\247ription'
+ repeated_bytes: '\\x12\\x34\\x56\\x78\\x90\\xab\\xcd\\xef'
+ repeated_string: '\\xd0\\x9f\\xd1\\x80\\xd0\\xb8\\xd0\\xb2\\xd0\\xb5\\xd1\\x82'
+ """
+ _BYTES_SAMPLE = _UNICODE_SAMPLE.encode('utf-8')
+ _GOLDEN_UNICODE = u'Á short desçription'
+ _GOLDEN_BYTES = _GOLDEN_UNICODE.encode('utf-8')
+ _GOLDEN_BYTES_1 = b'\x12\x34\x56\x78\x90\xab\xcd\xef'
+ _GOLDEN_STR_0 = u'Привет'
+
+ def testParseUnicode(self, message_module):
+ m = message_module.TestAllTypes()
+ text_format.Parse(self._UNICODE_SAMPLE, m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+ # repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data.
+ self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1)
+ # repeated_string[0] contained \ escaped data representing the UTF-8
+ # representation of _GOLDEN_STR_0 - it needs to decode as such.
+ self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0)
+
+ def testParseBytes(self, message_module):
+ m = message_module.TestAllTypes()
+ text_format.Parse(self._BYTES_SAMPLE, m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+ # repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data.
+ self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1)
+ # repeated_string[0] contained \ escaped data representing the UTF-8
+ # representation of _GOLDEN_STR_0 - it needs to decode as such.
+ self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0)
+
+ def testFromBytesFile(self, message_module):
+ m = message_module.TestAllTypes()
+ f = io.BytesIO(self._BYTES_SAMPLE)
+ text_format.ParseLines(f, m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+
+ def testFromUnicodeFile(self, message_module):
+ m = message_module.TestAllTypes()
+ f = io.StringIO(self._UNICODE_SAMPLE)
+ text_format.ParseLines(f, m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+
+ def testFromBytesLines(self, message_module):
+ m = message_module.TestAllTypes()
+ text_format.ParseLines(self._BYTES_SAMPLE.split(b'\n'), m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+
+ def testFromUnicodeLines(self, message_module):
+ m = message_module.TestAllTypes()
+ text_format.ParseLines(self._UNICODE_SAMPLE.split(u'\n'), m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+
+
+@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
+class TextFormatMergeTests(TextFormatBase):
+
def testMergeDuplicateScalars(self, message_module):
message = message_module.TestAllTypes()
text = ('optional_int32: 42 ' 'optional_int32: 67')
@@ -472,26 +672,12 @@ class TextFormatTest(TextFormatBase):
self.assertTrue(r is message)
self.assertEqual(2, message.optional_nested_message.bb)
- def testParseOneof(self, message_module):
- m = message_module.TestAllTypes()
- m.oneof_uint32 = 11
- m2 = message_module.TestAllTypes()
- text_format.Parse(text_format.MessageToString(m), m2)
- self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
-
def testMergeMultipleOneof(self, message_module):
m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
m2 = message_module.TestAllTypes()
text_format.Merge(m_string, m2)
self.assertEqual('oneof_string', m2.WhichOneof('oneof_field'))
- def testParseMultipleOneof(self, message_module):
- m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
- m2 = message_module.TestAllTypes()
- with self.assertRaisesRegexp(text_format.ParseError,
- ' is specified along with field '):
- text_format.Parse(m_string, m2)
-
# These are tests that aren't fundamentally specific to proto2, but are at
# the moment because of differences between the proto2 and proto3 test schemas.
@@ -938,7 +1124,7 @@ class Proto2Tests(TextFormatBase):
'}\n')
six.assertRaisesRegex(self,
text_format.ParseError,
- '5:1 : Expected ">".',
+ '5:1 : \'}\': Expected ">".',
text_format.Parse,
malformed,
message,
@@ -981,7 +1167,8 @@ class Proto2Tests(TextFormatBase):
with self.assertRaises(text_format.ParseError) as e:
text_format.Parse(text, message)
self.assertEqual(str(e.exception),
- '1:27 : Expected identifier or number, got "bb".')
+ '1:27 : \'optional_nested_message { "bb": 1 }\': '
+ 'Expected identifier or number, got "bb".')
def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions()
@@ -998,7 +1185,8 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllTypes()
text = 'optional_nested_enum: 100'
six.assertRaisesRegex(self, text_format.ParseError,
- (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ (r'1:23 : \'optional_nested_enum: 100\': '
+ r'Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value with number 100.'), text_format.Parse,
text, message)
@@ -1448,6 +1636,26 @@ class TokenizerTest(unittest.TestCase):
self.assertEqual(0, text_format._ConsumeUint64(tokenizer))
self.assertTrue(tokenizer.AtEnd())
+ def testConsumeOctalIntegers(self):
+ """Test support for C style octal integers."""
+ text = '00 -00 04 0755 -010 007 -0033 08 -09 01'
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertEqual(0, tokenizer.ConsumeInteger())
+ self.assertEqual(0, tokenizer.ConsumeInteger())
+ self.assertEqual(4, tokenizer.ConsumeInteger())
+ self.assertEqual(0o755, tokenizer.ConsumeInteger())
+ self.assertEqual(-0o10, tokenizer.ConsumeInteger())
+ self.assertEqual(7, tokenizer.ConsumeInteger())
+ self.assertEqual(-0o033, tokenizer.ConsumeInteger())
+ with self.assertRaises(text_format.ParseError):
+ tokenizer.ConsumeInteger() # 08
+ tokenizer.NextToken()
+ with self.assertRaises(text_format.ParseError):
+ tokenizer.ConsumeInteger() # -09
+ tokenizer.NextToken()
+ self.assertEqual(1, tokenizer.ConsumeInteger())
+ self.assertTrue(tokenizer.AtEnd())
+
def testConsumeByteString(self):
text = '"string1\''
tokenizer = text_format.Tokenizer(text.splitlines())
@@ -1556,6 +1764,12 @@ class TokenizerTest(unittest.TestCase):
tokenizer.ConsumeCommentOrTrailingComment())
self.assertTrue(tokenizer.AtEnd())
+ def testHugeString(self):
+ # With pathological backtracking, this fails with a Forge OOM.
+ text = '"' + 'a' * (10 * 1024 * 1024) + '"'
+ tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
+ tokenizer.ConsumeString()
+
# Tests for pretty printer functionality.
@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2))
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index 4a76cd4e..0807e7f7 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -185,6 +185,14 @@ class UnicodeValueChecker(object):
'encoding. Non-UTF-8 strings must be converted to '
'unicode objects before being added.' %
(proposed_value))
+ else:
+ try:
+ proposed_value.encode('utf8')
+ except UnicodeEncodeError:
+ raise ValueError('%.1024r isn\'t a valid unicode string and '
+ 'can\'t be encoded in UTF-8.' %
+ (proposed_value))
+
return proposed_value
def DefaultValue(self):
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index 8b7de2e7..fceadf71 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -49,20 +49,12 @@ from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
+from google.protobuf import descriptor
BaseTestCase = testing_refleaks.BaseTestCase
-# CheckUnknownField() cannot be used by the C++ implementation because
-# some protect members are called. It is not a behavior difference
-# for python and C++ implementation.
-def SkipCheckUnknownFieldIfCppImplementation(func):
- return unittest.skipIf(
- api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
- 'Addtional test for pure python involved protect members')(func)
-
-
class UnknownFieldsTest(BaseTestCase):
def setUp(self):
@@ -80,23 +72,11 @@ class UnknownFieldsTest(BaseTestCase):
# stdout.
self.assertTrue(data == self.all_fields_data)
- def expectSerializeProto3(self, preserve):
+ def testSerializeProto3(self):
+ # Verify proto3 unknown fields behavior.
message = unittest_proto3_arena_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
- if preserve:
- self.assertEqual(self.all_fields_data, message.SerializeToString())
- else:
- self.assertEqual(0, len(message.SerializeToString()))
-
- def testSerializeProto3(self):
- # Verify that proto3 unknown fields behavior.
- default_preserve = (api_implementation
- .GetPythonProto3PreserveUnknownsDefault())
- self.expectSerializeProto3(default_preserve)
- api_implementation.SetPythonProto3PreserveUnknownsDefault(
- not default_preserve)
- self.expectSerializeProto3(not default_preserve)
- api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve)
+ self.assertEqual(self.all_fields_data, message.SerializeToString())
def testByteSize(self):
self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
@@ -169,13 +149,15 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
- # CheckUnknownField() is an additional Pure Python check which checks
+ # InternalCheckUnknownField() is an additional Pure Python check which checks
# a detail of unknown fields. It cannot be used by the C++
# implementation because some protected members are called.
# The test is added for historical reasons. It is not necessary as
# serialized string is checked.
-
- def CheckUnknownField(self, name, expected_value):
+ # TODO(jieluo): Remove message._unknown_fields.
+ def InternalCheckUnknownField(self, name, expected_value):
+ if api_implementation.Type() == 'cpp':
+ return
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
@@ -183,36 +165,80 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
for tag_bytes, value in self.empty_message._unknown_fields:
if tag_bytes == field_tag:
decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
- decoder(value, 0, len(value), self.all_fields, result_dict)
+ decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
self.assertEqual(expected_value, result_dict[field_descriptor])
- @SkipCheckUnknownFieldIfCppImplementation
+ def CheckUnknownField(self, name, unknown_fields, expected_value):
+ field_descriptor = self.descriptor.fields_by_name[name]
+ expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
+ field_descriptor.type]
+ for unknown_field in unknown_fields:
+ if unknown_field.field_number == field_descriptor.number:
+ self.assertEqual(expected_type, unknown_field.wire_type)
+ if expected_type == 3:
+ # Check group
+ self.assertEqual(expected_value[0],
+ unknown_field.data[0].field_number)
+ self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
+ self.assertEqual(expected_value[2], unknown_field.data[0].data)
+ continue
+ if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ self.assertIn(unknown_field.data, expected_value)
+ else:
+ self.assertEqual(expected_value, unknown_field.data)
+
def testCheckUnknownFieldValue(self):
+ unknown_fields = self.empty_message.UnknownFields()
# Test enum.
self.CheckUnknownField('optional_nested_enum',
+ unknown_fields,
self.all_fields.optional_nested_enum)
+ self.InternalCheckUnknownField('optional_nested_enum',
+ self.all_fields.optional_nested_enum)
+
# Test repeated enum.
self.CheckUnknownField('repeated_nested_enum',
+ unknown_fields,
self.all_fields.repeated_nested_enum)
+ self.InternalCheckUnknownField('repeated_nested_enum',
+ self.all_fields.repeated_nested_enum)
# Test varint.
self.CheckUnknownField('optional_int32',
+ unknown_fields,
self.all_fields.optional_int32)
+ self.InternalCheckUnknownField('optional_int32',
+ self.all_fields.optional_int32)
+
# Test fixed32.
self.CheckUnknownField('optional_fixed32',
+ unknown_fields,
self.all_fields.optional_fixed32)
+ self.InternalCheckUnknownField('optional_fixed32',
+ self.all_fields.optional_fixed32)
# Test fixed64.
self.CheckUnknownField('optional_fixed64',
+ unknown_fields,
self.all_fields.optional_fixed64)
+ self.InternalCheckUnknownField('optional_fixed64',
+ self.all_fields.optional_fixed64)
# Test length delimited.
self.CheckUnknownField('optional_string',
- self.all_fields.optional_string)
+ unknown_fields,
+ self.all_fields.optional_string.encode('utf-8'))
+ self.InternalCheckUnknownField('optional_string',
+ self.all_fields.optional_string)
# Test group.
self.CheckUnknownField('optionalgroup',
- self.all_fields.optionalgroup)
+ unknown_fields,
+ (17, 0, 117))
+ self.InternalCheckUnknownField('optionalgroup',
+ self.all_fields.optionalgroup)
+
+ self.assertEqual(97, len(unknown_fields))
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
@@ -230,9 +256,18 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
message.optional_int64 = 3
message.optional_uint32 = 4
destination = unittest_pb2.TestEmptyMessage()
+ unknown_fields = destination.UnknownFields()
+ self.assertEqual(0, len(unknown_fields))
destination.ParseFromString(message.SerializeToString())
-
+ # ParseFromString clears the message, so the unknown fields view is invalidated.
+ with self.assertRaises(ValueError) as context:
+ len(unknown_fields)
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+ unknown_fields = destination.UnknownFields()
+ self.assertEqual(2, len(unknown_fields))
destination.MergeFrom(source)
+ self.assertEqual(4, len(unknown_fields))
# Check that the fields were correctly merged, even though they are stored
# in the unknown fields set.
message.ParseFromString(destination.SerializeToString())
@@ -241,9 +276,58 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
self.assertEqual(message.optional_int64, 3)
def testClear(self):
+ unknown_fields = self.empty_message.UnknownFields()
self.empty_message.Clear()
# All cleared, even unknown fields.
self.assertEqual(self.empty_message.SerializeToString(), b'')
+ with self.assertRaises(ValueError) as context:
+ len(unknown_fields)
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+
+ def testSubUnknownFields(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optionalgroup.a = 123
+ destination = unittest_pb2.TestEmptyMessage()
+ destination.ParseFromString(message.SerializeToString())
+ sub_unknown_fields = destination.UnknownFields()[0].data
+ self.assertEqual(1, len(sub_unknown_fields))
+ self.assertEqual(sub_unknown_fields[0].data, 123)
+ destination.Clear()
+ with self.assertRaises(ValueError) as context:
+ len(sub_unknown_fields)
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+ with self.assertRaises(ValueError) as context:
+ # pylint: disable=pointless-statement
+ sub_unknown_fields[0]
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+ message.Clear()
+ message.optional_uint32 = 456
+ nested_message = unittest_pb2.NestedTestAllTypes()
+ nested_message.payload.optional_nested_message.ParseFromString(
+ message.SerializeToString())
+ unknown_fields = (
+ nested_message.payload.optional_nested_message.UnknownFields())
+ self.assertEqual(unknown_fields[0].data, 456)
+ nested_message.ClearField('payload')
+ self.assertEqual(unknown_fields[0].data, 456)
+ unknown_fields = (
+ nested_message.payload.optional_nested_message.UnknownFields())
+ self.assertEqual(0, len(unknown_fields))
+
+ def testUnknownField(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optional_int32 = 123
+ destination = unittest_pb2.TestEmptyMessage()
+ destination.ParseFromString(message.SerializeToString())
+ unknown_field = destination.UnknownFields()[0]
+ destination.Clear()
+ with self.assertRaises(ValueError) as context:
+ unknown_field.data # pylint: disable=pointless-statement
+ self.assertIn('The parent message might be cleared.',
+ str(context.exception))
def testUnknownExtensions(self):
message = unittest_pb2.TestEmptyMessageWithExtensions()
@@ -280,15 +364,13 @@ class UnknownEnumValuesTest(BaseTestCase):
def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
- wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
- field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
- result_dict = {}
- for tag_bytes, value in self.missing_message._unknown_fields:
- if tag_bytes == field_tag:
- decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
- tag_bytes][0]
- decoder(value, 0, len(value), self.message, result_dict)
- self.assertEqual(expected_value, result_dict[field_descriptor])
+ unknown_fields = self.missing_message.UnknownFields()
+ for field in unknown_fields:
+ if field.field_number == field_descriptor.number:
+ if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ self.assertIn(field.data, expected_value)
+ else:
+ self.assertEqual(expected_value, field.data)
def testUnknownParseMismatchEnumValue(self):
just_string = missing_enum_values_pb2.JustString()
@@ -317,7 +399,6 @@ class UnknownEnumValuesTest(BaseTestCase):
def testUnknownPackedEnumValue(self):
self.assertEqual([], self.missing_message.packed_nested_enum)
- @SkipCheckUnknownFieldIfCppImplementation
def testCheckUnknownFieldValueForEnum(self):
self.CheckUnknownField('optional_nested_enum',
self.message.optional_nested_enum)
diff --git a/python/google/protobuf/json_format.py b/python/google/protobuf/json_format.py
index 58c94a47..ce1db7d7 100644
--- a/python/google/protobuf/json_format.py
+++ b/python/google/protobuf/json_format.py
@@ -482,7 +482,7 @@ class _Parser(object):
('Message type "{0}" has no field named "{1}".\n'
' Available Fields(except extensions): {2}').format(
message_descriptor.full_name, name,
- message_descriptor.fields))
+ [f.json_name for f in message_descriptor.fields]))
if name in names:
raise ParseError('Message type "{0}" should not have multiple '
'"{1}" fields.'.format(
diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py
index eeb0d576..eca2e0a9 100755
--- a/python/google/protobuf/message.py
+++ b/python/google/protobuf/message.py
@@ -268,6 +268,10 @@ class Message(object):
def ClearExtension(self, extension_handle):
raise NotImplementedError
+ def UnknownFields(self):
+ """Returns the UnknownFieldSet."""
+ raise NotImplementedError
+
def DiscardUnknownFields(self):
raise NotImplementedError

diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py
index e4fb065e..f3ab0a55 100644
--- a/python/google/protobuf/message_factory.py
+++ b/python/google/protobuf/message_factory.py
@@ -39,9 +39,18 @@ my_proto_instance = message_classes['some.proto.package.MessageName']()
__author__ = 'matthewtoia@google.com (Matt Toia)'
+from google.protobuf.internal import api_implementation
from google.protobuf import descriptor_pool
from google.protobuf import message
-from google.protobuf import reflection
+
+if api_implementation.Type() == 'cpp':
+ from google.protobuf.pyext import cpp_message as message_impl
+else:
+ from google.protobuf.internal import python_message as message_impl
+
+
+# The type of all Message classes.
+_GENERATED_PROTOCOL_MESSAGE_TYPE = message_impl.GeneratedProtocolMessageType
class MessageFactory(object):
@@ -70,11 +79,11 @@ class MessageFactory(object):
descriptor_name = descriptor.name
if str is bytes: # PY2
descriptor_name = descriptor.name.encode('ascii', 'ignore')
- result_class = reflection.GeneratedProtocolMessageType(
+ result_class = _GENERATED_PROTOCOL_MESSAGE_TYPE(
descriptor_name,
(message.Message,),
{'DESCRIPTOR': descriptor, '__module__': None})
- # If module not set, it wrongly points to the reflection.py module.
+    # If module is not set, it wrongly points to the message_factory module.
self._classes[descriptor] = result_class
for field in descriptor.fields:
if field.message_type:
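A sketch of the selection logic above: the factory now builds classes with the metaclass of the active implementation rather than going through reflection. The names below are standard protobuf APIs; the printed implementation type depends on how the package was built:

    from google.protobuf.internal import api_implementation
    from google.protobuf import descriptor_pb2  # registers descriptor.proto in the default pool
    from google.protobuf import descriptor_pool
    from google.protobuf import message_factory

    print(api_implementation.Type())  # 'cpp' or 'python'
    factory = message_factory.MessageFactory(descriptor_pool.Default())
    desc = descriptor_pool.Default().FindMessageTypeByName(
        'google.protobuf.FileDescriptorProto')
    cls = factory.GetPrototype(desc)  # built with the implementation's metaclass
    print(cls().DESCRIPTOR.full_name)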
diff --git a/python/google/protobuf/proto_api.h b/python/google/protobuf/proto_api.h
index 5c076d23..47edf0ea 100644
--- a/python/google/protobuf/proto_api.h
+++ b/python/google/protobuf/proto_api.h
@@ -42,16 +42,15 @@
// Then use the methods of the returned class:
// py_proto_api->GetMessagePointer(...);
-#ifndef PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__
-#define PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__
+#ifndef GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
+#define GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
#include <Python.h>
+#include <google/protobuf/message.h>
+
namespace google {
namespace protobuf {
-
-class Message;
-
namespace python {
// Note on the implementation:
@@ -89,4 +88,4 @@ inline const char* PyProtoAPICapsuleName() {
} // namespace protobuf
} // namespace google
-#endif // PYTHON_GOOGLE_PROTOBUF_PROTO_API_H__
+#endif // GOOGLE_PROTOBUF_PYTHON_PROTO_API_H__
diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc
index 19a1c38a..3cb16b74 100644
--- a/python/google/protobuf/pyext/descriptor.cc
+++ b/python/google/protobuf/pyext/descriptor.cc
@@ -32,8 +32,8 @@
#include <Python.h>
#include <frameobject.h>
-#include <google/protobuf/stubs/hash.h>
#include <string>
+#include <unordered_map>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/descriptor.pb.h>
@@ -44,6 +44,7 @@
#include <google/protobuf/pyext/message.h>
#include <google/protobuf/pyext/message_factory.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+#include <google/protobuf/stubs/hash.h>
#if PY_MAJOR_VERSION >= 3
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
@@ -54,10 +55,12 @@
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
- #define PyString_AsStringAndSize(ob, charpp, sizep) \
- (PyUnicode_Check(ob)? \
- ((*(charpp) = const_cast<char*>(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \
- PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
+ PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
+ ? -1 \
+ : 0) \
+ : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
namespace google {
@@ -70,7 +73,7 @@ namespace python {
// released.
// This is enough to support the "is" operator on live objects.
// All descriptors are stored here.
-hash_map<const void*, PyObject*> interned_descriptors;
+std::unordered_map<const void*, PyObject*>* interned_descriptors;
PyObject* PyString_FromCppString(const string& str) {
return PyString_FromStringAndSize(str.c_str(), str.size());
@@ -119,8 +122,10 @@ bool _CalledFromGeneratedFile(int stacklevel) {
PyErr_Clear();
return false;
}
- if ((filename_size < 3) || (strcmp(&filename[filename_size - 3], ".py") != 0)) {
- // Cython's stack does not have .py file name and is not at global module scope.
+ if ((filename_size < 3) ||
+ (strcmp(&filename[filename_size - 3], ".py") != 0)) {
+ // Cython's stack does not have .py file name and is not at global module
+ // scope.
return true;
}
if (filename_size < 7) {
@@ -131,7 +136,7 @@ bool _CalledFromGeneratedFile(int stacklevel) {
// Filename is not ending with _pb2.
return false;
}
-
+
if (frame->f_globals != frame->f_locals) {
// Not at global module scope
return false;
@@ -197,7 +202,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
// First search in the cache.
PyDescriptorPool* caching_pool = GetDescriptorPool_FromPool(
GetFileDescriptor(descriptor)->pool());
- hash_map<const void*, PyObject*>* descriptor_options =
+ std::unordered_map<const void*, PyObject*>* descriptor_options =
caching_pool->descriptor_options;
if (descriptor_options->find(descriptor) != descriptor_options->end()) {
PyObject *value = (*descriptor_options)[descriptor];
@@ -232,7 +237,7 @@ static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) {
if (value == NULL) {
return NULL;
}
- if (!PyObject_TypeCheck(value.get(), &CMessage_Type)) {
+ if (!PyObject_TypeCheck(value.get(), CMessage_Type)) {
PyErr_Format(PyExc_TypeError, "Invalid class for %s: %s",
message_type->full_name().c_str(),
Py_TYPE(value.get())->tp_name);
@@ -275,7 +280,7 @@ static PyObject* CopyToPythonProto(const DescriptorClass *descriptor,
const Descriptor* self_descriptor =
DescriptorProtoClass::default_instance().GetDescriptor();
CMessage* message = reinterpret_cast<CMessage*>(target);
- if (!PyObject_TypeCheck(target, &CMessage_Type) ||
+ if (!PyObject_TypeCheck(target, CMessage_Type) ||
message->message->GetDescriptor() != self_descriptor) {
PyErr_Format(PyExc_TypeError, "Not a %s message",
self_descriptor->full_name().c_str());
@@ -332,9 +337,9 @@ PyObject* NewInternedDescriptor(PyTypeObject* type,
}
// See if the object is in the map of interned descriptors
- hash_map<const void*, PyObject*>::iterator it =
- interned_descriptors.find(descriptor);
- if (it != interned_descriptors.end()) {
+ std::unordered_map<const void*, PyObject*>::iterator it =
+ interned_descriptors->find(descriptor);
+ if (it != interned_descriptors->end()) {
GOOGLE_DCHECK(Py_TYPE(it->second) == type);
Py_INCREF(it->second);
return it->second;
@@ -348,7 +353,7 @@ PyObject* NewInternedDescriptor(PyTypeObject* type,
py_descriptor->descriptor = descriptor;
// and cache it.
- interned_descriptors.insert(
+ interned_descriptors->insert(
std::make_pair(descriptor, reinterpret_cast<PyObject*>(py_descriptor)));
// Ensures that the DescriptorPool stays alive.
@@ -370,7 +375,7 @@ PyObject* NewInternedDescriptor(PyTypeObject* type,
static void Dealloc(PyBaseDescriptor* self) {
// Remove from interned dictionary
- interned_descriptors.erase(self->descriptor);
+ interned_descriptors->erase(self->descriptor);
Py_CLEAR(self->pool);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -758,6 +763,11 @@ static PyObject* HasDefaultValue(PyBaseDescriptor *self, void *closure) {
static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) {
PyObject *result;
+ if (_GetDescriptor(self)->is_repeated()) {
+ return PyList_New(0);
+ }
+
+
switch (_GetDescriptor(self)->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32: {
int32 value = _GetDescriptor(self)->default_value_int32();
@@ -805,6 +815,10 @@ static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) {
result = PyInt_FromLong(value->number());
break;
}
+ case FieldDescriptor::CPPTYPE_MESSAGE: {
+ Py_RETURN_NONE;
+ break;
+ }
default:
PyErr_Format(PyExc_NotImplementedError, "default value for %s",
_GetDescriptor(self)->full_name().c_str());
@@ -1919,6 +1933,9 @@ bool InitDescriptor() {
if (!InitDescriptorMappingTypes())
return false;
+ // Initialize globals defined in this file.
+ interned_descriptors = new std::unordered_map<const void*, PyObject*>;
+
return true;
}
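The GetDefaultValue() change above is observable from Python: repeated fields report an empty list and message-typed fields report None instead of raising. A small sketch, assuming unittest_pb2 from the test suite:

    from google.protobuf import unittest_pb2

    desc = unittest_pb2.TestAllTypes.DESCRIPTOR
    print(desc.fields_by_name['repeated_int32'].default_value)           # []
    print(desc.fields_by_name['optional_nested_message'].default_value)  # None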
diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h
index f081df84..c4dde9e7 100644
--- a/python/google/protobuf/pyext/descriptor.h
+++ b/python/google/protobuf/pyext/descriptor.h
@@ -100,6 +100,6 @@ bool InitDescriptor();
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__
diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc
index 0153664f..d5b5dc68 100644
--- a/python/google/protobuf/pyext/descriptor_containers.cc
+++ b/python/google/protobuf/pyext/descriptor_containers.cc
@@ -33,7 +33,7 @@
//
// They avoid the allocation of a full dictionary or a full list: they simply
// store a pointer to the parent descriptor, use the C++ Descriptor methods (see
-// google/protobuf/descriptor.h) to retrieve other descriptors, and create
+// net/proto2/public/descriptor.h) to retrieve other descriptors, and create
// Python objects on the fly.
//
// The containers fully conform to abc.Mapping and abc.Sequence, and behave just
@@ -64,10 +64,12 @@
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
- #define PyString_AsStringAndSize(ob, charpp, sizep) \
- (PyUnicode_Check(ob)? \
- ((*(charpp) = const_cast<char*>(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \
- PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
+ PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
+ ? -1 \
+ : 0) \
+ : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
namespace google {
diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h
index 83de07b6..4e05c58e 100644
--- a/python/google/protobuf/pyext/descriptor_containers.h
+++ b/python/google/protobuf/pyext/descriptor_containers.h
@@ -104,6 +104,6 @@ PyObject* NewServiceMethodsByName(const ServiceDescriptor* descriptor);
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_CONTAINERS_H__
diff --git a/python/google/protobuf/pyext/descriptor_database.cc b/python/google/protobuf/pyext/descriptor_database.cc
index daa40cc7..0514b35c 100644
--- a/python/google/protobuf/pyext/descriptor_database.cc
+++ b/python/google/protobuf/pyext/descriptor_database.cc
@@ -70,7 +70,7 @@ static bool GetFileDescriptorProto(PyObject* py_descriptor,
const Descriptor* filedescriptor_descriptor =
FileDescriptorProto::default_instance().GetDescriptor();
CMessage* message = reinterpret_cast<CMessage*>(py_descriptor);
- if (PyObject_TypeCheck(py_descriptor, &CMessage_Type) &&
+ if (PyObject_TypeCheck(py_descriptor, CMessage_Type) &&
message->message->GetDescriptor() == filedescriptor_descriptor) {
// Fast path: Just use the pointer.
FileDescriptorProto* file_proto =
@@ -143,6 +143,43 @@ bool PyDescriptorDatabase::FindFileContainingExtension(
return GetFileDescriptorProto(py_descriptor.get(), output);
}
+// Finds the tag numbers used by all known extensions of
+// containing_type, and appends them to output in an undefined
+// order.
+// Python DescriptorDatabases are not required to implement this method.
+bool PyDescriptorDatabase::FindAllExtensionNumbers(
+ const string& containing_type, std::vector<int>* output) {
+ ScopedPyObjectPtr py_method(
+ PyObject_GetAttrString(py_database_, "FindAllExtensionNumbers"));
+ if (py_method == NULL) {
+    // This method is not implemented; return without error.
+ PyErr_Clear();
+ return false;
+ }
+ ScopedPyObjectPtr py_list(
+ PyObject_CallFunction(py_method.get(), "s#", containing_type.c_str(),
+ containing_type.size()));
+ if (py_list == NULL) {
+ PyErr_Print();
+ return false;
+ }
+ Py_ssize_t size = PyList_Size(py_list.get());
+ int64 item_value;
+  for (Py_ssize_t i = 0; i < size; ++i) {
+ ScopedPyObjectPtr item(PySequence_GetItem(py_list.get(), i));
+ item_value = PyLong_AsLong(item.get());
+ if (item_value < 0) {
+ GOOGLE_LOG(ERROR)
+ << "FindAllExtensionNumbers method did not return "
+ << "valid extension numbers.";
+ PyErr_Print();
+ return false;
+ }
+ output->push_back(item_value);
+ }
+ return true;
+}
+
} // namespace python
} // namespace protobuf
} // namespace google
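The wrapper only probes for FindAllExtensionNumbers and silently skips databases that do not define it. A hypothetical Python database supporting the hook could look like the sketch below (class name and data are invented for illustration):

    class InMemoryExtensionDatabase(object):
        """Hypothetical database exposing the optional extension-number hook."""

        def __init__(self):
            self._ext_numbers = {'my.pkg.Container': [100, 101]}  # made-up data

        def FindAllExtensionNumbers(self, containing_type):
            # Must return an iterable of non-negative ints; the C++ caller
            # treats a negative value as an error.
            return self._ext_numbers.get(containing_type, [])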
diff --git a/python/google/protobuf/pyext/descriptor_database.h b/python/google/protobuf/pyext/descriptor_database.h
index fc71c4bc..daf25e0b 100644
--- a/python/google/protobuf/pyext/descriptor_database.h
+++ b/python/google/protobuf/pyext/descriptor_database.h
@@ -63,6 +63,13 @@ class PyDescriptorDatabase : public DescriptorDatabase {
int field_number,
FileDescriptorProto* output);
+ // Finds the tag numbers used by all known extensions of
+ // containing_type, and appends them to output in an undefined
+ // order.
+ // Python objects are not required to implement this method.
+ bool FindAllExtensionNumbers(const string& containing_type,
+ std::vector<int>* output);
+
private:
// The python object that implements the database. The reference is owned.
PyObject* py_database_;
@@ -70,6 +77,6 @@ class PyDescriptorDatabase : public DescriptorDatabase {
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_DATABASE_H__
diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc
index 962accc6..d0038b10 100644
--- a/python/google/protobuf/pyext/descriptor_pool.cc
+++ b/python/google/protobuf/pyext/descriptor_pool.cc
@@ -30,6 +30,8 @@
// Implements the DescriptorPool, which collects all descriptors.
+#include <unordered_map>
+
#include <Python.h>
#include <google/protobuf/descriptor.pb.h>
@@ -46,10 +48,12 @@
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
- #define PyString_AsStringAndSize(ob, charpp, sizep) \
- (PyUnicode_Check(ob)? \
- ((*(charpp) = const_cast<char*>(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \
- PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
+ PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
+ ? -1 \
+ : 0) \
+ : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
namespace google {
@@ -58,7 +62,8 @@ namespace python {
// A map to cache Python Pools per C++ pointer.
// Pointers are not owned here, and belong to the PyDescriptorPool.
-static hash_map<const DescriptorPool*, PyDescriptorPool*> descriptor_pool_map;
+static std::unordered_map<const DescriptorPool*, PyDescriptorPool*>*
+ descriptor_pool_map;
namespace cdescriptor_pool {
@@ -74,8 +79,7 @@ static PyDescriptorPool* _CreateDescriptorPool() {
cpool->underlay = NULL;
cpool->database = NULL;
- cpool->descriptor_options =
- new hash_map<const void*, PyObject *>();
+ cpool->descriptor_options = new std::unordered_map<const void*, PyObject*>();
cpool->py_message_factory = message_factory::NewMessageFactory(
&PyMessageFactory_Type, cpool);
@@ -101,7 +105,7 @@ static PyDescriptorPool* PyDescriptorPool_NewWithUnderlay(
cpool->pool = new DescriptorPool(underlay);
cpool->underlay = underlay;
- if (!descriptor_pool_map.insert(
+ if (!descriptor_pool_map->insert(
std::make_pair(cpool->pool, cpool)).second) {
// Should never happen -- would indicate an internal error / bug.
PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered");
@@ -124,7 +128,7 @@ static PyDescriptorPool* PyDescriptorPool_NewWithDatabase(
cpool->pool = new DescriptorPool();
}
- if (!descriptor_pool_map.insert(std::make_pair(cpool->pool, cpool)).second) {
+ if (!descriptor_pool_map->insert(std::make_pair(cpool->pool, cpool)).second) {
// Should never happen -- would indicate an internal error / bug.
PyErr_SetString(PyExc_ValueError, "DescriptorPool already registered");
return NULL;
@@ -151,9 +155,9 @@ static PyObject* New(PyTypeObject* type,
static void Dealloc(PyObject* pself) {
PyDescriptorPool* self = reinterpret_cast<PyDescriptorPool*>(pself);
- descriptor_pool_map.erase(self->pool);
+ descriptor_pool_map->erase(self->pool);
Py_CLEAR(self->py_message_factory);
- for (hash_map<const void*, PyObject*>::iterator it =
+ for (std::unordered_map<const void*, PyObject*>::iterator it =
self->descriptor_options->begin();
it != self->descriptor_options->end(); ++it) {
Py_DECREF(it->second);
@@ -180,6 +184,7 @@ static PyObject* FindMessageByName(PyObject* self, PyObject* arg) {
return NULL;
}
+
return PyMessageDescriptor_FromDescriptor(message_descriptor);
}
@@ -218,6 +223,7 @@ PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) {
return NULL;
}
+
return PyFieldDescriptor_FromDescriptor(field_descriptor);
}
@@ -239,6 +245,7 @@ PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) {
return NULL;
}
+
return PyFieldDescriptor_FromDescriptor(field_descriptor);
}
@@ -260,6 +267,7 @@ PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) {
return NULL;
}
+
return PyEnumDescriptor_FromDescriptor(enum_descriptor);
}
@@ -281,6 +289,7 @@ PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) {
return NULL;
}
+
return PyOneofDescriptor_FromDescriptor(oneof_descriptor);
}
@@ -303,6 +312,7 @@ static PyObject* FindServiceByName(PyObject* self, PyObject* arg) {
return NULL;
}
+
return PyServiceDescriptor_FromDescriptor(service_descriptor);
}
@@ -321,6 +331,7 @@ static PyObject* FindMethodByName(PyObject* self, PyObject* arg) {
return NULL;
}
+
return PyMethodDescriptor_FromDescriptor(method_descriptor);
}
@@ -339,6 +350,7 @@ static PyObject* FindFileContainingSymbol(PyObject* self, PyObject* arg) {
return NULL;
}
+
return PyFileDescriptor_FromDescriptor(file_descriptor);
}
@@ -362,6 +374,7 @@ static PyObject* FindExtensionByNumber(PyObject* self, PyObject* args) {
return NULL;
}
+
return PyFieldDescriptor_FromDescriptor(extension_descriptor);
}
@@ -668,13 +681,17 @@ bool InitDescriptorPool() {
// The Pool of messages declared in Python libraries.
// generated_pool() contains all messages already linked in C++ libraries, and
// is used as underlay.
+ descriptor_pool_map =
+ new std::unordered_map<const DescriptorPool*, PyDescriptorPool*>;
python_generated_pool = cdescriptor_pool::PyDescriptorPool_NewWithUnderlay(
DescriptorPool::generated_pool());
if (python_generated_pool == NULL) {
+ delete descriptor_pool_map;
return false;
}
+
// Register this pool to be found for C++-generated descriptors.
- descriptor_pool_map.insert(
+ descriptor_pool_map->insert(
std::make_pair(DescriptorPool::generated_pool(),
python_generated_pool));
@@ -695,9 +712,9 @@ PyDescriptorPool* GetDescriptorPool_FromPool(const DescriptorPool* pool) {
pool == DescriptorPool::generated_pool()) {
return python_generated_pool;
}
- hash_map<const DescriptorPool*, PyDescriptorPool*>::iterator it =
- descriptor_pool_map.find(pool);
- if (it == descriptor_pool_map.end()) {
+ std::unordered_map<const DescriptorPool*, PyDescriptorPool*>::iterator it =
+ descriptor_pool_map->find(pool);
+ if (it == descriptor_pool_map->end()) {
PyErr_SetString(PyExc_KeyError, "Unknown descriptor pool");
return NULL;
}
diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h
index 53ee53dc..8289daea 100644
--- a/python/google/protobuf/pyext/descriptor_pool.h
+++ b/python/google/protobuf/pyext/descriptor_pool.h
@@ -33,7 +33,7 @@
#include <Python.h>
-#include <google/protobuf/stubs/hash.h>
+#include <unordered_map>
#include <google/protobuf/descriptor.h>
namespace google {
@@ -77,7 +77,7 @@ typedef struct PyDescriptorPool {
// Cache the options for any kind of descriptor.
// Descriptor pointers are owned by the DescriptorPool above.
// Python objects are owned by the map.
- hash_map<const void*, PyObject*>* descriptor_options;
+ std::unordered_map<const void*, PyObject*>* descriptor_options;
} PyDescriptorPool;
@@ -140,6 +140,6 @@ bool InitDescriptorPool();
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__
diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc
index 174c5470..b73368eb 100644
--- a/python/google/protobuf/pyext/extension_dict.cc
+++ b/python/google/protobuf/pyext/extension_dict.cc
@@ -51,10 +51,12 @@
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#endif
- #define PyString_AsStringAndSize(ob, charpp, sizep) \
- (PyUnicode_Check(ob)? \
- ((*(charpp) = const_cast<char*>(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \
- PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
+ PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
+ ? -1 \
+ : 0) \
+ : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
#endif
namespace google {
@@ -63,40 +65,25 @@ namespace python {
namespace extension_dict {
-PyObject* len(ExtensionDict* self) {
-#if PY_MAJOR_VERSION >= 3
- return PyLong_FromLong(PyDict_Size(self->values));
-#else
- return PyInt_FromLong(PyDict_Size(self->values));
-#endif
-}
-
PyObject* subscript(ExtensionDict* self, PyObject* key) {
const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
if (descriptor == NULL) {
return NULL;
}
- if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
+ if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
return NULL;
}
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) {
- return cmessage::InternalGetScalar(self->message, descriptor);
+ return cmessage::InternalGetScalar(self->parent->message, descriptor);
}
- PyObject* value = PyDict_GetItem(self->values, key);
- if (value != NULL) {
- Py_INCREF(value);
- return value;
- }
-
- if (self->parent == NULL) {
- // We are in "detached" state. Don't allow further modifications.
- // TODO(amauryfa): Support adding non-scalars to a detached extension dict.
- // This probably requires to store the type of the main message.
- PyErr_SetObject(PyExc_KeyError, key);
- return NULL;
+ CMessage::CompositeFieldsMap::iterator iterator =
+ self->parent->composite_fields->find(descriptor);
+ if (iterator != self->parent->composite_fields->end()) {
+ Py_INCREF(iterator->second);
+ return iterator->second;
}
if (descriptor->label() != FieldDescriptor::LABEL_REPEATED &&
@@ -107,7 +94,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (sub_message == NULL) {
return NULL;
}
- PyDict_SetItem(self->values, key, sub_message);
+ Py_INCREF(sub_message);
+ (*self->parent->composite_fields)[descriptor] = sub_message;
return sub_message;
}
@@ -136,7 +124,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (py_container == NULL) {
return NULL;
}
- PyDict_SetItem(self->values, key, py_container);
+ Py_INCREF(py_container);
+ (*self->parent->composite_fields)[descriptor] = py_container;
return py_container;
} else {
PyObject* py_container = repeated_scalar_container::NewContainer(
@@ -144,7 +133,8 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) {
if (py_container == NULL) {
return NULL;
}
- PyDict_SetItem(self->values, key, py_container);
+ Py_INCREF(py_container);
+ (*self->parent->composite_fields)[descriptor] = py_container;
return py_container;
}
}
@@ -157,7 +147,7 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
if (descriptor == NULL) {
return -1;
}
- if (!CheckFieldBelongsToMessage(descriptor, self->message)) {
+ if (!CheckFieldBelongsToMessage(descriptor, self->parent->message)) {
return -1;
}
@@ -167,14 +157,10 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) {
"type");
return -1;
}
- if (self->parent) {
- cmessage::AssureWritable(self->parent);
- if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
- return -1;
- }
+ cmessage::AssureWritable(self->parent);
+ if (cmessage::InternalSetScalar(self->parent, descriptor, value) < 0) {
+ return -1;
}
- // TODO(tibell): We shouldn't write scalars to the cache.
- PyDict_SetItem(self->values, key, value);
return 0;
}
@@ -232,22 +218,36 @@ ExtensionDict* NewExtensionDict(CMessage *parent) {
return NULL;
}
- self->parent = parent; // Store a borrowed reference.
- self->message = parent->message;
- self->owner = parent->owner;
- self->values = PyDict_New();
+ Py_INCREF(parent);
+ self->parent = parent;
return self;
}
void dealloc(ExtensionDict* self) {
- Py_CLEAR(self->values);
- self->owner.reset();
+ Py_CLEAR(self->parent);
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
+static PyObject* RichCompare(ExtensionDict* self, PyObject* other, int opid) {
+ // Only equality comparisons are implemented.
+ if (opid != Py_EQ && opid != Py_NE) {
+ Py_INCREF(Py_NotImplemented);
+ return Py_NotImplemented;
+ }
+ bool equals = false;
+ if (PyObject_TypeCheck(other, &ExtensionDict_Type)) {
+    equals = self->parent == reinterpret_cast<ExtensionDict*>(other)->parent;
+ }
+ if (equals ^ (opid == Py_EQ)) {
+ Py_RETURN_FALSE;
+ } else {
+ Py_RETURN_TRUE;
+ }
+}
+
static PyMappingMethods MpMethods = {
- (lenfunc)len, /* mp_length */
- (binaryfunc)subscript, /* mp_subscript */
+ (lenfunc)NULL, /* mp_length */
+ (binaryfunc)subscript, /* mp_subscript */
(objobjargproc)ass_subscript,/* mp_ass_subscript */
};
@@ -286,7 +286,7 @@ PyTypeObject ExtensionDict_Type = {
"An extension dict", // tp_doc
0, // tp_traverse
0, // tp_clear
- 0, // tp_richcompare
+ (richcmpfunc)extension_dict::RichCompare, // tp_richcompare
0, // tp_weaklistoffset
0, // tp_iter
0, // tp_iternext
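Behavioural sketch of the rewritten ExtensionDict under the C++ implementation: cached composite values now live in the parent's composite_fields map, and equality compares the parent message (unittest_pb2 assumed from the test suite):

    from google.protobuf import unittest_pb2

    msg = unittest_pb2.TestAllExtensions()
    ext = unittest_pb2.optional_int32_extension
    msg.Extensions[ext] = 5
    print(msg.Extensions[ext])               # 5
    print(msg.Extensions == msg.Extensions)  # True: both views wrap the same message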
diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h
index 0de2c4ee..d800d479 100644
--- a/python/google/protobuf/pyext/extension_dict.h
+++ b/python/google/protobuf/pyext/extension_dict.h
@@ -37,6 +37,7 @@
#include <Python.h>
#include <memory>
+#include <hash_map>
#include <google/protobuf/pyext/message.h>
@@ -51,23 +52,8 @@ namespace python {
typedef struct ExtensionDict {
PyObject_HEAD;
- // This is the top-level C++ Message object that owns the whole
- // proto tree. Every Python container class holds a
- // reference to it in order to keep it alive as long as there's a
- // Python object that references any part of the tree.
- CMessage::OwnerRef owner;
-
- // Weak reference to parent message. Used to make sure
- // the parent is writable when an extension field is modified.
+ // Strong, owned reference to the parent message. Never NULL.
CMessage* parent;
-
- // Pointer to the C++ Message that this ExtensionDict extends.
- // Not owned by us.
- Message* message;
-
- // A dict of child messages, indexed by Extension descriptors.
- // Similar to CMessage::composite_fields.
- PyObject* values;
} ExtensionDict;
extern PyTypeObject ExtensionDict_Type;
@@ -80,6 +66,6 @@ ExtensionDict* NewExtensionDict(CMessage *parent);
} // namespace extension_dict
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_EXTENSION_DICT_H__
diff --git a/python/google/protobuf/pyext/map_container.cc b/python/google/protobuf/pyext/map_container.cc
index 6d7ee285..3eec49c7 100644
--- a/python/google/protobuf/pyext/map_container.cc
+++ b/python/google/protobuf/pyext/map_container.cc
@@ -68,12 +68,14 @@ class MapReflectionFriend {
static PyObject* MessageMapGetItem(PyObject* _self, PyObject* key);
static int ScalarMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
static int MessageMapSetItem(PyObject* _self, PyObject* key, PyObject* v);
+ static PyObject* ScalarMapToStr(PyObject* _self);
+ static PyObject* MessageMapToStr(PyObject* _self);
};
struct MapIterator {
PyObject_HEAD;
- std::unique_ptr<::google::protobuf::MapIterator> iter;
+ std::unique_ptr<::proto2::MapIterator> iter;
// A pointer back to the container, so we can notice changes to the version.
// We own a ref on this.
@@ -199,26 +201,26 @@ static PyObject* MapKeyToPython(const FieldDescriptor* field_descriptor,
// This is only used for ScalarMap, so we don't need to handle the
// CPPTYPE_MESSAGE case.
PyObject* MapValueRefToPython(const FieldDescriptor* field_descriptor,
- MapValueRef* value) {
+ const MapValueRef& value) {
switch (field_descriptor->cpp_type()) {
case FieldDescriptor::CPPTYPE_INT32:
- return PyInt_FromLong(value->GetInt32Value());
+ return PyInt_FromLong(value.GetInt32Value());
case FieldDescriptor::CPPTYPE_INT64:
- return PyLong_FromLongLong(value->GetInt64Value());
+ return PyLong_FromLongLong(value.GetInt64Value());
case FieldDescriptor::CPPTYPE_UINT32:
- return PyInt_FromSize_t(value->GetUInt32Value());
+ return PyInt_FromSize_t(value.GetUInt32Value());
case FieldDescriptor::CPPTYPE_UINT64:
- return PyLong_FromUnsignedLongLong(value->GetUInt64Value());
+ return PyLong_FromUnsignedLongLong(value.GetUInt64Value());
case FieldDescriptor::CPPTYPE_FLOAT:
- return PyFloat_FromDouble(value->GetFloatValue());
+ return PyFloat_FromDouble(value.GetFloatValue());
case FieldDescriptor::CPPTYPE_DOUBLE:
- return PyFloat_FromDouble(value->GetDoubleValue());
+ return PyFloat_FromDouble(value.GetDoubleValue());
case FieldDescriptor::CPPTYPE_BOOL:
- return PyBool_FromLong(value->GetBoolValue());
+ return PyBool_FromLong(value.GetBoolValue());
case FieldDescriptor::CPPTYPE_STRING:
- return ToStringObject(field_descriptor, value->GetStringValue());
+ return ToStringObject(field_descriptor, value.GetStringValue());
case FieldDescriptor::CPPTYPE_ENUM:
- return PyInt_FromLong(value->GetEnumValue());
+ return PyInt_FromLong(value.GetEnumValue());
default:
PyErr_Format(
PyExc_SystemError, "Couldn't convert type %d to value",
@@ -312,7 +314,7 @@ static MapContainer* GetMap(PyObject* obj) {
Py_ssize_t MapReflectionFriend::Length(PyObject* _self) {
MapContainer* self = GetMap(_self);
- const google::protobuf::Message* message = self->message;
+ const proto2::Message* message = self->message;
return message->GetReflection()->MapSize(*message,
self->parent_field_descriptor);
}
@@ -421,7 +423,7 @@ int MapContainer::Release() {
// ScalarMap ///////////////////////////////////////////////////////////////////
PyObject *NewScalarMapContainer(
- CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) {
+ CMessage* parent, const proto2::FieldDescriptor* parent_field_descriptor) {
if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
return NULL;
}
@@ -472,7 +474,7 @@ PyObject* MapReflectionFriend::ScalarMapGetItem(PyObject* _self,
self->version++;
}
- return MapValueRefToPython(self->value_field_descriptor, &value);
+ return MapValueRefToPython(self->value_field_descriptor, value);
}
int MapReflectionFriend::ScalarMapSetItem(PyObject* _self, PyObject* key,
@@ -535,10 +537,47 @@ static PyObject* ScalarMapGet(PyObject* self, PyObject* args) {
}
}
+PyObject* MapReflectionFriend::ScalarMapToStr(PyObject* _self) {
+ ScopedPyObjectPtr dict(PyDict_New());
+ if (dict == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr key;
+ ScopedPyObjectPtr value;
+
+ MapContainer* self = GetMap(_self);
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ for (proto2::MapIterator it = reflection->MapBegin(
+ message, self->parent_field_descriptor);
+ it != reflection->MapEnd(message, self->parent_field_descriptor);
+ ++it) {
+ key.reset(MapKeyToPython(self->key_field_descriptor,
+ it.GetKey()));
+ if (key == NULL) {
+ return NULL;
+ }
+ value.reset(MapValueRefToPython(self->value_field_descriptor,
+ it.GetValueRef()));
+ if (value == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
+ return NULL;
+ }
+ }
+ return PyObject_Repr(dict.get());
+}
+
static void ScalarMapDealloc(PyObject* _self) {
MapContainer* self = GetMap(_self);
self->owner.reset();
- Py_TYPE(_self)->tp_free(_self);
+ PyTypeObject *type = Py_TYPE(_self);
+ type->tp_free(_self);
+ if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
+    // With Python 3, the Map class is not static and must be managed.
+ Py_DECREF(type);
+ }
}
static PyMethodDef ScalarMapMethods[] = {
@@ -570,6 +609,7 @@ PyTypeObject *ScalarMapContainer_Type;
{Py_mp_ass_subscript, (void *)MapReflectionFriend::ScalarMapSetItem},
{Py_tp_methods, (void *)ScalarMapMethods},
{Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
+ {Py_tp_repr, (void *)MapReflectionFriend::ScalarMapToStr},
{0, 0},
};
@@ -597,7 +637,7 @@ PyTypeObject *ScalarMapContainer_Type;
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
- 0, // tp_repr
+ MapReflectionFriend::ScalarMapToStr, // tp_repr
0, // tp_as_number
0, // tp_as_sequence
&ScalarMapMappingMethods, // tp_as_mapping
@@ -634,7 +674,8 @@ static MessageMapContainer* GetMessageMap(PyObject* obj) {
return reinterpret_cast<MessageMapContainer*>(obj);
}
-static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
+static PyObject* GetCMessage(MessageMapContainer* self, Message* message,
+ bool insert_message_dict) {
// Get or create the CMessage object corresponding to this message.
ScopedPyObjectPtr key(PyLong_FromVoidPtr(message));
PyObject* ret = PyDict_GetItem(self->message_dict, key.get());
@@ -649,10 +690,11 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
cmsg->owner = self->owner;
cmsg->message = message;
cmsg->parent = self->parent;
-
- if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) {
- Py_DECREF(ret);
- return NULL;
+ if (insert_message_dict) {
+ if (PyDict_SetItem(self->message_dict, key.get(), ret) < 0) {
+ Py_DECREF(ret);
+ return NULL;
+ }
}
} else {
Py_INCREF(ret);
@@ -662,7 +704,7 @@ static PyObject* GetCMessage(MessageMapContainer* self, Message* message) {
}
PyObject* NewMessageMapContainer(
- CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor,
+ CMessage* parent, const proto2::FieldDescriptor* parent_field_descriptor,
CMessageClass* message_class) {
if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) {
return NULL;
@@ -781,7 +823,41 @@ PyObject* MapReflectionFriend::MessageMapGetItem(PyObject* _self,
self->version++;
}
- return GetCMessage(self, value.MutableMessageValue());
+ return GetCMessage(self, value.MutableMessageValue(), true);
+}
+
+PyObject* MapReflectionFriend::MessageMapToStr(PyObject* _self) {
+ ScopedPyObjectPtr dict(PyDict_New());
+ if (dict == NULL) {
+ return NULL;
+ }
+ ScopedPyObjectPtr key;
+ ScopedPyObjectPtr value;
+
+ MessageMapContainer* self = GetMessageMap(_self);
+ Message* message = self->GetMutableMessage();
+ const Reflection* reflection = message->GetReflection();
+ for (proto2::MapIterator it = reflection->MapBegin(
+ message, self->parent_field_descriptor);
+ it != reflection->MapEnd(message, self->parent_field_descriptor);
+ ++it) {
+ key.reset(MapKeyToPython(self->key_field_descriptor,
+ it.GetKey()));
+ if (key == NULL) {
+ return NULL;
+ }
+  // Do not insert the cmessage into self->message_dict because
+ // the returned CMessage will not escape this function.
+ value.reset(GetCMessage(
+ self, it.MutableValueRef()->MutableMessageValue(), false));
+ if (value == NULL) {
+ return NULL;
+ }
+ if (PyDict_SetItem(dict.get(), key.get(), value.get()) < 0) {
+ return NULL;
+ }
+ }
+ return PyObject_Repr(dict.get());
}
PyObject* MessageMapGet(PyObject* self, PyObject* args) {
@@ -813,7 +889,12 @@ static void MessageMapDealloc(PyObject* _self) {
self->owner.reset();
Py_DECREF(self->message_dict);
Py_DECREF(self->message_class);
- Py_TYPE(_self)->tp_free(_self);
+ PyTypeObject *type = Py_TYPE(_self);
+ type->tp_free(_self);
+ if (type->tp_flags & Py_TPFLAGS_HEAPTYPE) {
+    // With Python 3, the Map class is not static and must be managed.
+ Py_DECREF(type);
+ }
}
static PyMethodDef MessageMapMethods[] = {
@@ -847,6 +928,7 @@ PyTypeObject *MessageMapContainer_Type;
{Py_mp_ass_subscript, (void *)MapReflectionFriend::MessageMapSetItem},
{Py_tp_methods, (void *)MessageMapMethods},
{Py_tp_iter, (void *)MapReflectionFriend::GetIterator},
+ {Py_tp_repr, (void *)MapReflectionFriend::MessageMapToStr},
{0, 0}
};
@@ -874,7 +956,7 @@ PyTypeObject *MessageMapContainer_Type;
0, // tp_getattr
0, // tp_setattr
0, // tp_compare
- 0, // tp_repr
+ MapReflectionFriend::MessageMapToStr, // tp_repr
0, // tp_as_number
0, // tp_as_sequence
&MessageMapMappingMethods, // tp_as_mapping
@@ -929,7 +1011,7 @@ PyObject* MapReflectionFriend::GetIterator(PyObject *_self) {
Message* message = self->GetMutableMessage();
const Reflection* reflection = message->GetReflection();
- iter->iter.reset(new ::google::protobuf::MapIterator(
+ iter->iter.reset(new ::proto2::MapIterator(
reflection->MapBegin(message, self->parent_field_descriptor)));
}
@@ -1027,17 +1109,15 @@ bool InitMapContainers() {
return false;
}
- if (!PyObject_TypeCheck(mutable_mapping.get(), &PyType_Type)) {
- return false;
- }
-
Py_INCREF(mutable_mapping.get());
#if PY_MAJOR_VERSION >= 3
- PyObject* bases = PyTuple_New(1);
- PyTuple_SET_ITEM(bases, 0, mutable_mapping.get());
+ ScopedPyObjectPtr bases(PyTuple_Pack(1, mutable_mapping.get()));
+ if (bases == NULL) {
+ return false;
+ }
ScalarMapContainer_Type = reinterpret_cast<PyTypeObject*>(
- PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases));
+ PyType_FromSpecWithBases(&ScalarMapContainer_Type_spec, bases.get()));
#else
_ScalarMapContainer_Type.tp_base =
reinterpret_cast<PyTypeObject*>(mutable_mapping.get());
@@ -1055,7 +1135,7 @@ bool InitMapContainers() {
#if PY_MAJOR_VERSION >= 3
MessageMapContainer_Type = reinterpret_cast<PyTypeObject*>(
- PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases));
+ PyType_FromSpecWithBases(&MessageMapContainer_Type_spec, bases.get()));
#else
Py_INCREF(mutable_mapping.get());
_MessageMapContainer_Type.tp_base =
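With the new tp_repr slots, scalar and message map fields print like the equivalent dict. A short sketch, assuming map_unittest_pb2 from the protobuf test suite (exact quoting depends on the Python version):

    from google.protobuf import map_unittest_pb2

    msg = map_unittest_pb2.TestMap()
    msg.map_int32_int32[2] = 4
    msg.map_string_string['a'] = 'b'
    print(repr(msg.map_int32_int32))    # e.g. {2: 4}
    print(repr(msg.map_string_string))  # e.g. {'a': 'b'}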
diff --git a/python/google/protobuf/pyext/map_container.h b/python/google/protobuf/pyext/map_container.h
index 111fafbf..7e77b027 100644
--- a/python/google/protobuf/pyext/map_container.h
+++ b/python/google/protobuf/pyext/map_container.h
@@ -120,6 +120,6 @@ extern PyObject* NewMessageMapContainer(
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MAP_CONTAINER_H__
diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc
index b2984509..5d0e37fa 100644
--- a/python/google/protobuf/pyext/message.cc
+++ b/python/google/protobuf/pyext/message.cc
@@ -45,12 +45,11 @@
#ifndef Py_TYPE
#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type)
#endif
-#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/stubs/logging.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
-#include <google/protobuf/util/message_differencer.h>
+#include <google/protobuf/descriptor.pb.h>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
@@ -58,12 +57,16 @@
#include <google/protobuf/pyext/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
#include <google/protobuf/pyext/extension_dict.h>
-#include <google/protobuf/pyext/repeated_composite_container.h>
-#include <google/protobuf/pyext/repeated_scalar_container.h>
+#include <google/protobuf/pyext/field.h>
#include <google/protobuf/pyext/map_container.h>
#include <google/protobuf/pyext/message_factory.h>
+#include <google/protobuf/pyext/repeated_composite_container.h>
+#include <google/protobuf/pyext/repeated_scalar_container.h>
+#include <google/protobuf/pyext/unknown_fields.h>
#include <google/protobuf/pyext/safe_numerics.h>
#include <google/protobuf/pyext/scoped_pyobject_ptr.h>
+#include <google/protobuf/util/message_differencer.h>
+#include <google/protobuf/stubs/strutil.h>
#if PY_MAJOR_VERSION >= 3
#define PyInt_AsLong PyLong_AsLong
@@ -72,16 +75,19 @@
#define PyString_Check PyUnicode_Check
#define PyString_FromString PyUnicode_FromString
#define PyString_FromStringAndSize PyUnicode_FromStringAndSize
+ #define PyString_FromFormat PyUnicode_FromFormat
#if PY_VERSION_HEX < 0x03030000
#error "Python 3.0 - 3.2 are not supported."
#else
#define PyString_AsString(ob) \
(PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob))
- #define PyString_AsStringAndSize(ob, charpp, sizep) \
- (PyUnicode_Check(ob)? \
- ((*(charpp) = const_cast<char*>(PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL? -1: 0): \
- PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
- #endif
+#define PyString_AsStringAndSize(ob, charpp, sizep) \
+ (PyUnicode_Check(ob) ? ((*(charpp) = const_cast<char*>( \
+ PyUnicode_AsUTF8AndSize(ob, (sizep)))) == NULL \
+ ? -1 \
+ : 0) \
+ : PyBytes_AsStringAndSize(ob, (charpp), (sizep)))
+#endif
#endif
namespace google {
@@ -99,44 +105,27 @@ namespace message_meta {
static int InsertEmptyWeakref(PyTypeObject* base);
namespace {
-// Copied oveer from internal 'google/protobuf/stubs/strutil.h'.
-inline void UpperString(string * s) {
+// Copied over from internal 'google/protobuf/stubs/strutil.h'.
+inline void LowerString(string * s) {
string::iterator end = s->end();
for (string::iterator i = s->begin(); i != end; ++i) {
- // toupper() changes based on locale. We don't want this!
- if ('a' <= *i && *i <= 'z') *i += 'A' - 'a';
+ // tolower() changes based on locale. We don't want this!
+ if ('A' <= *i && *i <= 'Z') *i += 'a' - 'A';
}
}
}
-// Add the number of a field descriptor to the containing message class.
-// Equivalent to:
-// _cls.<field>_FIELD_NUMBER = <number>
-static bool AddFieldNumberToClass(
- PyObject* cls, const FieldDescriptor* field_descriptor) {
- string constant_name = field_descriptor->name() + "_FIELD_NUMBER";
- UpperString(&constant_name);
- ScopedPyObjectPtr attr_name(PyString_FromStringAndSize(
- constant_name.c_str(), constant_name.size()));
- if (attr_name == NULL) {
- return false;
- }
- ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number()));
- if (number == NULL) {
- return false;
- }
- if (PyObject_SetAttr(cls, attr_name.get(), number.get()) == -1) {
- return false;
- }
- return true;
-}
-
-
// Finalize the creation of the Message class.
static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
// For each field set: cls.<field>_FIELD_NUMBER = <number>
for (int i = 0; i < descriptor->field_count(); ++i) {
- if (!AddFieldNumberToClass(cls, descriptor->field(i))) {
+ const FieldDescriptor* field_descriptor = descriptor->field(i);
+ ScopedPyObjectPtr property(NewFieldProperty(field_descriptor));
+ if (property == NULL) {
+ return -1;
+ }
+ if (PyObject_SetAttrString(cls, field_descriptor->name().c_str(),
+ property.get()) < 0) {
return -1;
}
}
@@ -182,7 +171,7 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
// <message descriptor>.extensions_by_name[name]
// which was defined previously.
for (int i = 0; i < descriptor->extension_count(); ++i) {
- const google::protobuf::FieldDescriptor* field = descriptor->extension(i);
+ const proto2::FieldDescriptor* field = descriptor->extension(i);
ScopedPyObjectPtr extension_field(PyFieldDescriptor_FromDescriptor(field));
if (extension_field == NULL) {
return -1;
@@ -193,11 +182,6 @@ static int AddDescriptors(PyObject* cls, const Descriptor* descriptor) {
cls, field->name().c_str(), extension_field.get()) == -1) {
return -1;
}
-
- // For each extension set cls.<extension name>_FIELD_NUMBER = <number>.
- if (!AddFieldNumberToClass(cls, field)) {
- return -1;
- }
}
return 0;
@@ -265,10 +249,10 @@ static PyObject* New(PyTypeObject* type,
PyObject* well_known_class = PyDict_GetItemString(
WKT_classes, message_descriptor->full_name().c_str());
if (well_known_class == NULL) {
- new_args.reset(Py_BuildValue("s(OO)O", name, &CMessage_Type,
+ new_args.reset(Py_BuildValue("s(OO)O", name, CMessage_Type,
PythonMessage_class, dict));
} else {
- new_args.reset(Py_BuildValue("s(OOO)O", name, &CMessage_Type,
+ new_args.reset(Py_BuildValue("s(OOO)O", name, CMessage_Type,
PythonMessage_class, well_known_class, dict));
}
@@ -285,7 +269,7 @@ static PyObject* New(PyTypeObject* type,
// Insert the empty weakref into the base classes.
if (InsertEmptyWeakref(
reinterpret_cast<PyTypeObject*>(PythonMessage_class)) < 0 ||
- InsertEmptyWeakref(&CMessage_Type) < 0) {
+ InsertEmptyWeakref(CMessage_Type) < 0) {
return NULL;
}
@@ -353,6 +337,13 @@ static int InsertEmptyWeakref(PyTypeObject *base_type) {
// The _extensions_by_name dictionary is built on every access.
// TODO(amauryfa): Migrate all users to pool.FindAllExtensions()
static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
+ if (self->message_descriptor == NULL) {
+ // This is the base Message object, simply raise AttributeError.
+ PyErr_SetString(PyExc_AttributeError,
+ "Base Message class has no DESCRIPTOR");
+ return NULL;
+ }
+
const PyDescriptorPool* pool = self->py_message_factory->pool;
std::vector<const FieldDescriptor*> extensions;
@@ -376,6 +367,13 @@ static PyObject* GetExtensionsByName(CMessageClass *self, void *closure) {
// The _extensions_by_number dictionary is built on every access.
// TODO(amauryfa): Migrate all users to pool.FindExtensionByNumber()
static PyObject* GetExtensionsByNumber(CMessageClass *self, void *closure) {
+ if (self->message_descriptor == NULL) {
+ // This is the base Message object, simply raise AttributeError.
+ PyErr_SetString(PyExc_AttributeError,
+ "Base Message class has no DESCRIPTOR");
+ return NULL;
+ }
+
const PyDescriptorPool* pool = self->py_message_factory->pool;
std::vector<const FieldDescriptor*> extensions;
@@ -405,9 +403,51 @@ static PyGetSetDef Getters[] = {
{NULL}
};
+// Compute some class attributes on the fly:
+// - All the _FIELD_NUMBER attributes, for all fields and nested extensions.
+// Returns a new reference, or NULL with an exception set.
+static PyObject* GetClassAttribute(CMessageClass *self, PyObject* name) {
+ char* attr;
+ Py_ssize_t attr_size;
+ static const char kSuffix[] = "_FIELD_NUMBER";
+ if (PyString_AsStringAndSize(name, &attr, &attr_size) >= 0 &&
+ strings::EndsWith(StringPiece(attr, attr_size), kSuffix)) {
+ string field_name(attr, attr_size - sizeof(kSuffix) + 1);
+ LowerString(&field_name);
+
+ // Try to find a field with the given name, without the suffix.
+ const FieldDescriptor* field =
+ self->message_descriptor->FindFieldByLowercaseName(field_name);
+ if (!field) {
+ // Search nested extensions as well.
+ field =
+ self->message_descriptor->FindExtensionByLowercaseName(field_name);
+ }
+ if (field) {
+ return PyInt_FromLong(field->number());
+ }
+ }
+ PyErr_SetObject(PyExc_AttributeError, name);
+ return NULL;
+}
+
+static PyObject* GetAttr(CMessageClass* self, PyObject* name) {
+ PyObject* result = CMessageClass_Type->tp_base->tp_getattro(
+ reinterpret_cast<PyObject*>(self), name);
+ if (result != NULL) {
+ return result;
+ }
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+ return NULL;
+ }
+
+ PyErr_Clear();
+ return GetClassAttribute(self, name);
+}
+
} // namespace message_meta
-PyTypeObject CMessageClass_Type = {
+static PyTypeObject _CMessageClass_Type = {
PyVarObject_HEAD_INIT(&PyType_Type, 0)
FULL_MODULE_NAME ".MessageMeta", // tp_name
sizeof(CMessageClass), // tp_basicsize
@@ -424,7 +464,7 @@ PyTypeObject CMessageClass_Type = {
0, // tp_hash
0, // tp_call
0, // tp_str
- 0, // tp_getattro
+ (getattrofunc)message_meta::GetAttr, // tp_getattro
0, // tp_setattro
0, // tp_as_buffer
Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
@@ -447,9 +487,10 @@ PyTypeObject CMessageClass_Type = {
0, // tp_alloc
message_meta::New, // tp_new
};
+PyTypeObject* CMessageClass_Type = &_CMessageClass_Type;
static CMessageClass* CheckMessageClass(PyTypeObject* cls) {
- if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) {
+ if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
PyErr_Format(PyExc_TypeError, "Class %s is not a Message", cls->tp_name);
return NULL;
}
@@ -487,10 +528,20 @@ struct ChildVisitor {
}
// Returns 0 on success, -1 on failure.
+ int VisitMapContainer(MapContainer* container) {
+ return 0;
+ }
+
+ // Returns 0 on success, -1 on failure.
int VisitCMessage(CMessage* cmessage,
const FieldDescriptor* field_descriptor) {
return 0;
}
+
+ // Returns 0 on success, -1 on failure.
+ int VisitUnknownFieldSet(PyUnknownFields* unknown_field_set) {
+ return 0;
+ }
};
// Apply a function to a composite field. Does nothing if child is of
@@ -538,34 +589,19 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) {
// Visit normal fields.
if (self->composite_fields) {
- // Never use self->message in this function, it may be already freed.
- const Descriptor* message_descriptor =
- GetMessageDescriptor(Py_TYPE(self));
- while (PyDict_Next(self->composite_fields, &pos, &key, &field)) {
- Py_ssize_t key_str_size;
- char *key_str_data;
- if (PyString_AsStringAndSize(key, &key_str_data, &key_str_size) != 0)
- return -1;
- const string key_str(key_str_data, key_str_size);
- const FieldDescriptor* descriptor =
- message_descriptor->FindFieldByName(key_str);
- if (descriptor != NULL) {
- if (VisitCompositeField(descriptor, field, visitor) == -1)
- return -1;
- }
+ for (CMessage::CompositeFieldsMap::iterator it =
+ self->composite_fields->begin();
+ it != self->composite_fields->end(); it++) {
+ const FieldDescriptor* descriptor = it->first;
+ PyObject* field = it->second;
+ if (VisitCompositeField(descriptor, field, visitor) == -1) return -1;
}
}
- // Visit extension fields.
- if (self->extensions != NULL) {
- pos = 0;
- while (PyDict_Next(self->extensions->values, &pos, &key, &field)) {
- const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key);
- if (descriptor == NULL)
- return -1;
- if (VisitCompositeField(descriptor, field, visitor) == -1)
- return -1;
- }
+ if (self->unknown_field_set) {
+ PyUnknownFields* unknown_field_set =
+ reinterpret_cast<PyUnknownFields*>(self->unknown_field_set);
+ visitor.VisitUnknownFieldSet(unknown_field_set);
}
return 0;
@@ -577,8 +613,12 @@ PyObject* EncodeError_class;
PyObject* DecodeError_class;
PyObject* PickleError_class;
-/* Is 64bit */
+// Format an error message for unexpected types.
+// Always return with an exception set.
void FormatTypeError(PyObject* arg, char* expected_types) {
+ // This function is often called with an exception set.
+ // Clear it to call PyObject_Repr() in good conditions.
+ PyErr_Clear();
PyObject* repr = PyObject_Repr(arg);
if (repr) {
PyErr_Format(PyExc_TypeError,
@@ -859,7 +899,7 @@ bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor,
namespace cmessage {
PyMessageFactory* GetFactoryForMessage(CMessage* message) {
- GOOGLE_DCHECK(PyObject_TypeCheck(message, &CMessage_Type));
+ GOOGLE_DCHECK(PyObject_TypeCheck(message, CMessage_Type));
return reinterpret_cast<CMessageClass*>(Py_TYPE(message))->py_message_factory;
}
@@ -883,22 +923,20 @@ static int MaybeReleaseOverlappingOneofField(
// Non-message fields don't need to be released.
return 0;
}
- const char* field_name = existing_field->name().c_str();
- PyObject* child_message = cmessage->composite_fields ?
- PyDict_GetItemString(cmessage->composite_fields, field_name) : NULL;
- if (child_message == NULL) {
- // No python reference to this field so no need to release.
- return 0;
- }
-
- if (InternalReleaseFieldByDescriptor(
- cmessage, existing_field, child_message) < 0) {
- return -1;
+ if (cmessage->composite_fields) {
+ CMessage::CompositeFieldsMap::iterator iterator =
+ cmessage->composite_fields->find(existing_field);
+ if (iterator != cmessage->composite_fields->end()) {
+ if (InternalReleaseFieldByDescriptor(cmessage, existing_field,
+ iterator->second) < 0) {
+ return -1;
+ }
+ Py_DECREF(iterator->second);
+ cmessage->composite_fields->erase(iterator);
+ }
}
- return PyDict_DelItemString(cmessage->composite_fields, field_name);
-#else
- return 0;
#endif
+ return 0;
}
// ---------------------------------------------------------------------
@@ -937,10 +975,49 @@ struct FixupMessageReference : public ChildVisitor {
return 0;
}
+ int VisitUnknownFieldSet(PyUnknownFields* unknown_field_set) {
+ const Reflection* reflection = message_->GetReflection();
+ unknown_field_set->fields = &reflection->GetUnknownFields(*message_);
+ return 0;
+ }
+
private:
Message* message_;
};
+// After a Merge, visit every sub-message that was read-only, and
+// eventually update their pointer if the Merge operation modified them.
+struct FixupMessageAfterMerge : public FixupMessageReference {
+ explicit FixupMessageAfterMerge(CMessage* parent) :
+ FixupMessageReference(parent->message),
+ parent_cmessage(parent), message(parent->message) {}
+
+ int VisitCMessage(CMessage* cmessage,
+ const FieldDescriptor* field_descriptor) {
+ if (cmessage->read_only == false) {
+ return 0;
+ }
+ if (message->GetReflection()->HasField(*message, field_descriptor)) {
+ Message* mutable_message = GetMutableMessage(
+ parent_cmessage, field_descriptor);
+ if (mutable_message == NULL) {
+ return -1;
+ }
+ cmessage->message = mutable_message;
+ cmessage->read_only = false;
+ if (ForEachCompositeField(
+ cmessage, FixupMessageAfterMerge(cmessage)) == -1) {
+ return -1;
+ }
+ }
+ return 0;
+ }
+
+ private:
+ CMessage* parent_cmessage;
+ Message* message;
+};
+
int AssureWritable(CMessage* self) {
if (self == NULL || !self->read_only) {
return 0;
@@ -974,10 +1051,8 @@ int AssureWritable(CMessage* self) {
// When a CMessage is made writable its Message pointer is updated
// to point to a new mutable Message. When that happens we need to
// update any references to the old, read-only CMessage. There are
- // four places such references occur: RepeatedScalarContainer,
- // RepeatedCompositeContainer, MapContainer, and ExtensionDict.
- if (self->extensions != NULL)
- self->extensions->message = self->message;
+ // three places such references occur: RepeatedScalarContainer,
+ // RepeatedCompositeContainer, and MapContainer.
if (ForEachCompositeField(self, FixupMessageReference(self->message)) == -1)
return -1;
@@ -986,27 +1061,6 @@ int AssureWritable(CMessage* self) {
// --- Globals:
-// Retrieve a C++ FieldDescriptor for a message attribute.
-// The C++ message must be valid.
-// TODO(amauryfa): This function should stay internal, because exception
-// handling is not consistent.
-static const FieldDescriptor* GetFieldDescriptor(
- CMessage* self, PyObject* name) {
- const Descriptor *message_descriptor = self->message->GetDescriptor();
- char* field_name;
- Py_ssize_t size;
- if (PyString_AsStringAndSize(name, &field_name, &size) < 0) {
- return NULL;
- }
- const FieldDescriptor *field_descriptor =
- message_descriptor->FindFieldByName(string(field_name, size));
- if (field_descriptor == NULL) {
- // Note: No exception is set!
- return NULL;
- }
- return field_descriptor;
-}
-
// Retrieve a C++ FieldDescriptor for an extension handle.
const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) {
ScopedPyObjectPtr cdescriptor;
@@ -1038,7 +1092,7 @@ static PyObject* GetIntegerEnumValue(const FieldDescriptor& descriptor,
const EnumValueDescriptor* enum_value_descriptor =
enum_descriptor->FindValueByName(string(enum_label, size));
if (enum_value_descriptor == NULL) {
- PyErr_SetString(PyExc_ValueError, "unknown enum label");
+ PyErr_Format(PyExc_ValueError, "unknown enum label \"%s\"", enum_label);
return NULL;
}
return PyInt_FromLong(enum_value_descriptor->number());
@@ -1164,19 +1218,24 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
PyErr_SetString(PyExc_ValueError, "Field name must be a string");
return -1;
}
- const FieldDescriptor* descriptor = GetFieldDescriptor(self, name);
- if (descriptor == NULL) {
+ ScopedPyObjectPtr property(
+ PyObject_GetAttr(reinterpret_cast<PyObject*>(Py_TYPE(self)), name));
+ if (property == NULL ||
+ !PyObject_TypeCheck(property.get(), CFieldProperty_Type)) {
PyErr_Format(PyExc_ValueError, "Protocol message %s has no \"%s\" field.",
self->message->GetDescriptor()->name().c_str(),
PyString_AsString(name));
return -1;
}
+ const FieldDescriptor* descriptor =
+ reinterpret_cast<PyMessageFieldProperty*>(property.get())
+ ->field_descriptor;
if (value == Py_None) {
// field=None is the same as no field at all.
continue;
}
if (descriptor->is_map()) {
- ScopedPyObjectPtr map(GetAttr(reinterpret_cast<PyObject*>(self), name));
+ ScopedPyObjectPtr map(GetFieldValue(self, descriptor));
const FieldDescriptor* value_descriptor =
descriptor->message_type()->FindFieldByName("value");
if (value_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
@@ -1204,8 +1263,7 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
}
}
} else if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
- ScopedPyObjectPtr container(
- GetAttr(reinterpret_cast<PyObject*>(self), name));
+ ScopedPyObjectPtr container(GetFieldValue(self, descriptor));
if (container == NULL) {
return -1;
}
@@ -1272,8 +1330,7 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
}
}
} else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- ScopedPyObjectPtr message(
- GetAttr(reinterpret_cast<PyObject*>(self), name));
+ ScopedPyObjectPtr message(GetFieldValue(self, descriptor));
if (message == NULL) {
return -1;
}
@@ -1297,9 +1354,9 @@ int InitAttributes(CMessage* self, PyObject* args, PyObject* kwargs) {
if (new_val == NULL) {
return -1;
}
+ value = new_val.get();
}
- if (SetAttr(reinterpret_cast<PyObject*>(self), name,
- (new_val.get() == NULL) ? value : new_val.get()) < 0) {
+ if (SetFieldValue(self, descriptor, value) < 0) {
return -1;
}
}
@@ -1322,10 +1379,11 @@ CMessage* NewEmptyMessage(CMessageClass* type) {
self->parent = NULL;
self->parent_field_descriptor = NULL;
self->read_only = false;
- self->extensions = NULL;
self->composite_fields = NULL;
+ self->unknown_field_set = NULL;
+
return self;
}
@@ -1408,12 +1466,20 @@ static void Dealloc(CMessage* self) {
}
// Null out all weak references from children to this message.
GOOGLE_CHECK_EQ(0, ForEachCompositeField(self, ClearWeakReferences()));
- if (self->extensions) {
- self->extensions->parent = NULL;
- }
- Py_CLEAR(self->extensions);
- Py_CLEAR(self->composite_fields);
+ if (self->composite_fields) {
+ for (CMessage::CompositeFieldsMap::iterator it =
+ self->composite_fields->begin();
+ it != self->composite_fields->end(); it++) {
+ Py_DECREF(it->second);
+ }
+ delete self->composite_fields;
+ }
+ if (self->unknown_field_set) {
+ unknown_fields::Clear(
+ reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
+ Py_CLEAR(self->unknown_field_set);
+ }
self->owner.~ThreadUnsafeSharedPtr<Message>();
Py_TYPE(self)->tp_free(reinterpret_cast<PyObject*>(self));
}
@@ -1564,13 +1630,16 @@ PyObject* ClearExtension(CMessage* self, PyObject* extension) {
if (descriptor == NULL) {
return NULL;
}
- if (self->extensions != NULL) {
- PyObject* value = PyDict_GetItem(self->extensions->values, extension);
- if (value != NULL) {
- if (InternalReleaseFieldByDescriptor(self, descriptor, value) < 0) {
+ if (self->composite_fields != NULL) {
+ CMessage::CompositeFieldsMap::iterator iterator =
+ self->composite_fields->find(descriptor);
+ if (iterator != self->composite_fields->end()) {
+ if (InternalReleaseFieldByDescriptor(self, descriptor,
+ iterator->second) < 0) {
return NULL;
}
- PyDict_DelItem(self->extensions->values, extension);
+ Py_DECREF(iterator->second);
+ self->composite_fields->erase(iterator);
}
}
return ClearFieldByDescriptor(self, descriptor);
@@ -1770,14 +1839,16 @@ PyObject* ClearField(CMessage* self, PyObject* arg) {
arg = arg_in_oneof.get();
}
- // Release the field if it exists in the dict of composite fields.
if (self->composite_fields) {
- PyObject* value = PyDict_GetItem(self->composite_fields, arg);
- if (value != NULL) {
- if (InternalReleaseFieldByDescriptor(self, field_descriptor, value) < 0) {
+ CMessage::CompositeFieldsMap::iterator iterator =
+ self->composite_fields->find(field_descriptor);
+ if (iterator != self->composite_fields->end()) {
+ if (InternalReleaseFieldByDescriptor(self, field_descriptor,
+ iterator->second) < 0) {
return NULL;
}
- PyDict_DelItem(self->composite_fields, arg);
+ Py_DECREF(iterator->second);
+ self->composite_fields->erase(iterator);
}
}
return ClearFieldByDescriptor(self, field_descriptor);
@@ -1787,9 +1858,18 @@ PyObject* Clear(CMessage* self) {
AssureWritable(self);
if (ForEachCompositeField(self, ReleaseChild(self)) == -1)
return NULL;
- Py_CLEAR(self->extensions);
if (self->composite_fields) {
- PyDict_Clear(self->composite_fields);
+ for (CMessage::CompositeFieldsMap::iterator it =
+ self->composite_fields->begin();
+ it != self->composite_fields->end(); it++) {
+ Py_DECREF(it->second);
+ }
+ self->composite_fields->clear();
+ }
+ if (self->unknown_field_set) {
+ unknown_fields::Clear(
+ reinterpret_cast<PyUnknownFields*>(self->unknown_field_set));
+ Py_CLEAR(self->unknown_field_set);
}
self->message->Clear();
Py_RETURN_NONE;
@@ -1946,7 +2026,7 @@ static PyObject* ToStr(CMessage* self) {
PyObject* MergeFrom(CMessage* self, PyObject* arg) {
CMessage* other_message;
- if (!PyObject_TypeCheck(arg, &CMessage_Type)) {
+ if (!PyObject_TypeCheck(arg, CMessage_Type)) {
PyErr_Format(PyExc_TypeError,
"Parameter to MergeFrom() must be instance of same class: "
"expected %s got %s.",
@@ -1967,18 +2047,19 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg) {
}
AssureWritable(self);
- // TODO(tibell): Message::MergeFrom might turn some child Messages
- // into mutable messages, invalidating the message field in the
- // corresponding CMessages. We should run a FixupMessageReferences
- // pass here.
-
self->message->MergeFrom(*other_message->message);
+ // Child messages might have been lazily created before MergeFrom. Make sure
+ // they are mutable at this point if they were actually created.
+ if (ForEachCompositeField(self, FixupMessageAfterMerge(self)) == -1) {
+ return NULL;
+ }
+
Py_RETURN_NONE;
}
static PyObject* CopyFrom(CMessage* self, PyObject* arg) {
CMessage* other_message;
- if (!PyObject_TypeCheck(arg, &CMessage_Type)) {
+ if (!PyObject_TypeCheck(arg, CMessage_Type)) {
PyErr_Format(PyExc_TypeError,
"Parameter to CopyFrom() must be instance of same class: "
"expected %s got %s.",
@@ -2050,6 +2131,7 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
}
AssureWritable(self);
+
io::CodedInputStream input(
reinterpret_cast<const uint8*>(data), data_length);
if (allow_oversize_protos) {
@@ -2058,6 +2140,12 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) {
PyMessageFactory* factory = GetFactoryForMessage(self);
input.SetExtensionRegistry(factory->pool->pool, factory->message_factory);
bool success = self->message->MergePartialFromCodedStream(&input);
+ // Child messages might have been lazily created before MergeFrom. Make sure
+ // they are mutable at this point if they were actually created.
+ if (ForEachCompositeField(self, FixupMessageAfterMerge(self)) == -1) {
+ return NULL;
+ }
+
if (success) {
if (!input.ConsumedEntireMessage()) {
// TODO(jieluo): Raise error and return NULL instead.
@@ -2088,7 +2176,7 @@ PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle) {
if (descriptor == NULL) {
return NULL;
}
- if (!PyObject_TypeCheck(cls, &CMessageClass_Type)) {
+ if (!PyObject_TypeCheck(cls, CMessageClass_Type)) {
PyErr_Format(PyExc_TypeError, "Expected a message class, got %s",
cls->ob_type->tp_name);
return NULL;
@@ -2192,23 +2280,15 @@ static PyObject* ListFields(CMessage* self) {
PyTuple_SET_ITEM(t.get(), 1, extension);
} else {
// Normal field
- const string& field_name = fields[i]->name();
- ScopedPyObjectPtr py_field_name(PyString_FromStringAndSize(
- field_name.c_str(), field_name.length()));
- if (py_field_name == NULL) {
- PyErr_SetString(PyExc_ValueError, "bad string");
- return NULL;
- }
ScopedPyObjectPtr field_descriptor(
PyFieldDescriptor_FromDescriptor(fields[i]));
if (field_descriptor == NULL) {
return NULL;
}
- PyObject* field_value =
- GetAttr(reinterpret_cast<PyObject*>(self), py_field_name.get());
+ PyObject* field_value = GetFieldValue(self, fields[i]);
if (field_value == NULL) {
- PyErr_SetObject(PyExc_ValueError, py_field_name.get());
+ PyErr_SetString(PyExc_ValueError, fields[i]->name().c_str());
return NULL;
}
PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release());
@@ -2261,10 +2341,10 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
}
bool equals = true;
// If other is not a message, it cannot be equal.
- if (!PyObject_TypeCheck(other, &CMessage_Type)) {
+ if (!PyObject_TypeCheck(other, CMessage_Type)) {
equals = false;
}
- const google::protobuf::Message* other_message =
+ const proto2::Message* other_message =
reinterpret_cast<CMessage*>(other)->message;
// If messages don't have the same descriptors, they are not equal.
if (equals &&
@@ -2272,11 +2352,12 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) {
equals = false;
}
// Check the message contents.
- if (equals && !google::protobuf::util::MessageDifferencer::Equals(
+ if (equals && !proto2::util::MessageDifferencer::Equals(
*self->message,
*reinterpret_cast<CMessage*>(other)->message)) {
equals = false;
}
+
if (equals ^ (opid == Py_EQ)) {
Py_RETURN_FALSE;
} else {
@@ -2498,7 +2579,7 @@ PyObject* DeepCopy(CMessage* self, PyObject* arg) {
if (clone == NULL) {
return NULL;
}
- if (!PyObject_TypeCheck(clone, &CMessage_Type)) {
+ if (!PyObject_TypeCheck(clone, CMessage_Type)) {
Py_DECREF(clone);
return NULL;
}
@@ -2592,26 +2673,29 @@ PyObject* _CheckCalledFromGeneratedFile(PyObject* unused,
}
static PyObject* GetExtensionDict(CMessage* self, void *closure) {
- if (self->extensions) {
- Py_INCREF(self->extensions);
- return reinterpret_cast<PyObject*>(self->extensions);
- }
-
// If there are extension_ranges, the message is "extendable". Allocate a
// dictionary to store the extension fields.
const Descriptor* descriptor = GetMessageDescriptor(Py_TYPE(self));
- if (descriptor->extension_range_count() > 0) {
- ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
- if (extension_dict == NULL) {
- return NULL;
- }
- self->extensions = extension_dict;
- Py_INCREF(self->extensions);
- return reinterpret_cast<PyObject*>(self->extensions);
+ if (!descriptor->extension_range_count()) {
+ PyErr_SetNone(PyExc_AttributeError);
+ return NULL;
+ }
+ if (!self->composite_fields) {
+ self->composite_fields = new CMessage::CompositeFieldsMap();
}
+ if (!self->composite_fields) {
+ return NULL;
+ }
+ ExtensionDict* extension_dict = extension_dict::NewExtensionDict(self);
+ return reinterpret_cast<PyObject*>(extension_dict);
+}
- PyErr_SetNone(PyExc_AttributeError);
- return NULL;
+static PyObject* UnknownFieldSet(CMessage* self) {
+ if (self->unknown_field_set == NULL) {
+ self->unknown_field_set = unknown_fields::NewPyUnknownFields(self);
+ }
+ Py_INCREF(self->unknown_field_set);
+ return self->unknown_field_set;
}
static PyObject* GetExtensionsByName(CMessage *self, void *closure) {
@@ -2682,6 +2766,8 @@ static PyMethodDef Methods[] = {
"Serializes the message to a string, only for initialized messages." },
{ "SetInParent", (PyCFunction)SetInParent, METH_NOARGS,
"Sets the has bit of the given field in its parent message." },
+ { "UnknownFields", (PyCFunction)UnknownFieldSet, METH_NOARGS,
+ "Parse unknown field set"},
{ "WhichOneof", (PyCFunction)WhichOneof, METH_O,
"Returns the name of the field set inside a oneof, "
"or None if no field is set." },
@@ -2693,30 +2779,53 @@ static PyMethodDef Methods[] = {
{ NULL, NULL}
};
-static bool SetCompositeField(
- CMessage* self, PyObject* name, PyObject* value) {
+static bool SetCompositeField(CMessage* self, const FieldDescriptor* field,
+ PyObject* value) {
if (self->composite_fields == NULL) {
- self->composite_fields = PyDict_New();
- if (self->composite_fields == NULL) {
- return false;
- }
+ self->composite_fields = new CMessage::CompositeFieldsMap();
}
- return PyDict_SetItem(self->composite_fields, name, value) == 0;
+ Py_INCREF(value);
+ Py_XDECREF((*self->composite_fields)[field]);
+ (*self->composite_fields)[field] = value;
+ return true;
}
PyObject* GetAttr(PyObject* pself, PyObject* name) {
CMessage* self = reinterpret_cast<CMessage*>(pself);
- PyObject* value = self->composite_fields ?
- PyDict_GetItem(self->composite_fields, name) : NULL;
- if (value != NULL) {
- Py_INCREF(value);
- return value;
+ PyObject* result = PyObject_GenericGetAttr(
+ reinterpret_cast<PyObject*>(self), name);
+ if (result != NULL) {
+ return result;
+ }
+ if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
+ return NULL;
}
- const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name);
- if (field_descriptor == NULL) {
- return CMessage_Type.tp_base->tp_getattro(
- reinterpret_cast<PyObject*>(self), name);
+ PyErr_Clear();
+ return message_meta::GetClassAttribute(
+ CheckMessageClass(Py_TYPE(self)), name);
+}
+
+PyObject* GetFieldValue(CMessage* self,
+ const FieldDescriptor* field_descriptor) {
+ if (self->composite_fields) {
+ CMessage::CompositeFieldsMap::iterator it =
+ self->composite_fields->find(field_descriptor);
+ if (it != self->composite_fields->end()) {
+ PyObject* value = it->second;
+ Py_INCREF(value);
+ return value;
+ }
+ }
+
+ const Descriptor* message_descriptor =
+ (reinterpret_cast<CMessageClass*>(Py_TYPE(self)))->message_descriptor;
+ if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
+ PyErr_Format(PyExc_TypeError,
+ "descriptor to field '%s' doesn't apply to '%s' object",
+ field_descriptor->full_name().c_str(),
+ Py_TYPE(self)->tp_name);
+ return NULL;
}
if (field_descriptor->is_map()) {
@@ -2737,7 +2846,7 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) {
if (py_container == NULL) {
return NULL;
}
- if (!SetCompositeField(self, name, py_container)) {
+ if (!SetCompositeField(self, field_descriptor, py_container)) {
Py_DECREF(py_container);
return NULL;
}
@@ -2761,7 +2870,7 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) {
if (py_container == NULL) {
return NULL;
}
- if (!SetCompositeField(self, name, py_container)) {
+ if (!SetCompositeField(self, field_descriptor, py_container)) {
Py_DECREF(py_container);
return NULL;
}
@@ -2773,7 +2882,7 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) {
if (sub_message == NULL) {
return NULL;
}
- if (!SetCompositeField(self, name, sub_message)) {
+ if (!SetCompositeField(self, field_descriptor, sub_message)) {
Py_DECREF(sub_message);
return NULL;
}
@@ -2783,44 +2892,35 @@ PyObject* GetAttr(PyObject* pself, PyObject* name) {
return InternalGetScalar(self->message, field_descriptor);
}
-int SetAttr(PyObject* pself, PyObject* name, PyObject* value) {
- CMessage* self = reinterpret_cast<CMessage*>(pself);
- if (self->composite_fields && PyDict_Contains(self->composite_fields, name)) {
- PyErr_SetString(PyExc_TypeError, "Can't set composite field");
+int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
+ PyObject* value) {
+ if (self->message->GetDescriptor() != field_descriptor->containing_type()) {
+ PyErr_Format(PyExc_TypeError,
+ "descriptor to field '%s' doesn't apply to '%s' object",
+ field_descriptor->full_name().c_str(),
+ Py_TYPE(self)->tp_name);
return -1;
- }
-
- const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name);
- if (field_descriptor != NULL) {
+ } else if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
+ PyErr_Format(PyExc_AttributeError,
+ "Assignment not allowed to repeated "
+ "field \"%s\" in protocol message object.",
+ field_descriptor->name().c_str());
+ return -1;
+ } else if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
+ PyErr_Format(PyExc_AttributeError,
+ "Assignment not allowed to "
+ "field \"%s\" in protocol message object.",
+ field_descriptor->name().c_str());
+ return -1;
+ } else {
AssureWritable(self);
- if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) {
- PyErr_Format(PyExc_AttributeError, "Assignment not allowed to repeated "
- "field \"%s\" in protocol message object.",
- field_descriptor->name().c_str());
- return -1;
- } else {
- if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) {
- PyErr_Format(PyExc_AttributeError, "Assignment not allowed to "
- "field \"%s\" in protocol message object.",
- field_descriptor->name().c_str());
- return -1;
- } else {
- return InternalSetScalar(self, field_descriptor, value);
- }
- }
+ return InternalSetScalar(self, field_descriptor, value);
}
-
- PyErr_Format(PyExc_AttributeError,
- "Assignment not allowed "
- "(no field \"%s\" in protocol message object).",
- PyString_AsString(name));
- return -1;
}
-
} // namespace cmessage
-PyTypeObject CMessage_Type = {
- PyVarObject_HEAD_INIT(&CMessageClass_Type, 0)
+static CMessageClass _CMessage_Type = { { {
+ PyVarObject_HEAD_INIT(&_CMessageClass_Type, 0)
FULL_MODULE_NAME ".CMessage", // tp_name
sizeof(CMessage), // tp_basicsize
0, // tp_itemsize
@@ -2837,9 +2937,10 @@ PyTypeObject CMessage_Type = {
0, // tp_call
(reprfunc)cmessage::ToStr, // tp_str
cmessage::GetAttr, // tp_getattro
- cmessage::SetAttr, // tp_setattro
+ 0, // tp_setattro
0, // tp_as_buffer
- Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE, // tp_flags
+ Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE
+ | Py_TPFLAGS_HAVE_VERSION_TAG, // tp_flags
"A ProtocolMessage", // tp_doc
0, // tp_traverse
0, // tp_clear
@@ -2858,7 +2959,8 @@ PyTypeObject CMessage_Type = {
(initproc)cmessage::Init, // tp_init
0, // tp_alloc
cmessage::New, // tp_new
-};
+} } };
+PyTypeObject* CMessage_Type = &_CMessage_Type.super.ht_type;
// --- Exposing the C proto living inside Python proto to C code:
@@ -2884,7 +2986,7 @@ static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) {
}
const Message* PyMessage_GetMessagePointer(PyObject* msg) {
- if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
+ if (!PyObject_TypeCheck(msg, CMessage_Type)) {
PyErr_SetString(PyExc_TypeError, "Not a Message instance");
return NULL;
}
@@ -2893,15 +2995,14 @@ const Message* PyMessage_GetMessagePointer(PyObject* msg) {
}
Message* PyMessage_GetMutableMessagePointer(PyObject* msg) {
- if (!PyObject_TypeCheck(msg, &CMessage_Type)) {
+ if (!PyObject_TypeCheck(msg, CMessage_Type)) {
PyErr_SetString(PyExc_TypeError, "Not a Message instance");
return NULL;
}
+
CMessage* cmsg = reinterpret_cast<CMessage*>(msg);
- if ((cmsg->composite_fields && PyDict_Size(cmsg->composite_fields) != 0) ||
- (cmsg->extensions != NULL &&
- PyDict_Size(cmsg->extensions->values) != 0)) {
+ if (cmsg->composite_fields && !cmsg->composite_fields->empty()) {
// There is currently no way of accurately syncing arbitrary changes to
// the underlying C++ message back to the CMessage (e.g. removed repeated
// composite containers). We only allow direct mutation of the underlying
@@ -2945,22 +3046,29 @@ bool InitProto2MessageModule(PyObject *m) {
// Initialize constants defined in this file.
InitGlobals();
- CMessageClass_Type.tp_base = &PyType_Type;
- if (PyType_Ready(&CMessageClass_Type) < 0) {
+ CMessageClass_Type->tp_base = &PyType_Type;
+ if (PyType_Ready(CMessageClass_Type) < 0) {
return false;
}
PyModule_AddObject(m, "MessageMeta",
- reinterpret_cast<PyObject*>(&CMessageClass_Type));
+ reinterpret_cast<PyObject*>(CMessageClass_Type));
- if (PyType_Ready(&CMessage_Type) < 0) {
+ if (PyType_Ready(CMessage_Type) < 0) {
+ return false;
+ }
+ if (PyType_Ready(CFieldProperty_Type) < 0) {
return false;
}
// DESCRIPTOR is set on each protocol buffer message class elsewhere, but set
// it here as well to document that subclasses need to set it.
- PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None);
+ PyDict_SetItem(CMessage_Type->tp_dict, kDESCRIPTOR, Py_None);
+ // Invalidate any cached data for the CMessage type.
+ // This call is necessary to correctly support Py_TPFLAGS_HAVE_VERSION_TAG,
+ // after we have modified CMessage_Type.tp_dict.
+ PyType_Modified(CMessage_Type);
- PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(&CMessage_Type));
+ PyModule_AddObject(m, "Message", reinterpret_cast<PyObject*>(CMessage_Type));
// Initialize Repeated container types.
{
@@ -3003,6 +3111,22 @@ bool InitProto2MessageModule(PyObject *m) {
}
}
+ if (PyType_Ready(&PyUnknownFields_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "UnknownFieldSet",
+ reinterpret_cast<PyObject*>(
+ &PyUnknownFields_Type));
+
+ if (PyType_Ready(&PyUnknownFieldRef_Type) < 0) {
+ return false;
+ }
+
+ PyModule_AddObject(m, "UnknownField",
+ reinterpret_cast<PyObject*>(
+ &PyUnknownFieldRef_Type));
+
// Initialize Map container types.
if (!InitMapContainers()) {
return false;
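The newly registered UnknownFields() method exposes fields that arrived on the wire but are not defined in the message's descriptor. A minimal usage sketch, not taken from this commit (MyMessage and serialized_bytes are placeholders; the attribute names follow the PyUnknownFields/PyUnknownFieldRef types registered above):

    msg = MyMessage()
    msg.ParseFromString(serialized_bytes)   # payload carrying tags MyMessage does not define
    for field in msg.UnknownFields():
        # Each entry keeps the raw field number, wire type and payload.
        print(field.field_number, field.wire_type, field.data)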
diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h
index d754e62a..e729e448 100644
--- a/python/google/protobuf/pyext/message.h
+++ b/python/google/protobuf/pyext/message.h
@@ -38,6 +38,7 @@
#include <memory>
#include <string>
+#include <unordered_map>
#include <google/protobuf/stubs/common.h>
#include <google/protobuf/pyext/thread_unsafe_shared_ptr.h>
@@ -96,26 +97,25 @@ typedef struct CMessage {
// made writable, at which point this field is set to false.
bool read_only;
- // A reference to a Python dictionary containing CMessage,
+ // A mapping indexed by field, containing CMessage,
// RepeatedCompositeContainer, and RepeatedScalarContainer
// objects. Used as a cache to make sure we don't have to make a
// Python wrapper for the C++ Message objects on every access, or
// deal with the synchronization nightmare that could create.
- PyObject* composite_fields;
+ // Also cache extension fields.
+ // The FieldDescriptor is owned by the message's pool; PyObject references
+ // are owned.
+ typedef std::unordered_map<const FieldDescriptor*, PyObject*>
+ CompositeFieldsMap;
+ CompositeFieldsMap* composite_fields;
- // A reference to the dictionary containing the message's extensions.
- // Similar to composite_fields, acting as a cache, but also contains the
- // required extension dict logic.
- ExtensionDict* extensions;
+ // A reference to PyUnknownFields.
+ PyObject* unknown_field_set;
// Implements the "weakref" protocol for this object.
PyObject* weakreflist;
} CMessage;
-extern PyTypeObject CMessageClass_Type;
-extern PyTypeObject CMessage_Type;
-
-
// The (meta) type of all Messages classes.
// It allows us to cache some C++ pointers in the class object itself, they are
// faster to extract than from the type's dictionary.
@@ -142,6 +142,8 @@ struct CMessageClass {
}
};
+extern PyTypeObject* CMessageClass_Type;
+extern PyTypeObject* CMessage_Type;
namespace cmessage {
@@ -235,15 +237,13 @@ PyObject* MergeFrom(CMessage* self, PyObject* arg);
// has been registered with the same field number on this class.
PyObject* RegisterExtension(PyObject* cls, PyObject* extension_handle);
-// Retrieves an attribute named 'name' from 'self', which is interpreted as a
-// CMessage. Returns the attribute value on success, or null on failure.
-//
-// Returns a new reference.
-PyObject* GetAttr(PyObject* self, PyObject* name);
-
-// Set the value of the attribute named 'name', for 'self', which is interpreted
-// as a CMessage, to the value 'value'. Returns -1 on failure.
-int SetAttr(PyObject* self, PyObject* name, PyObject* value);
+// Get a field from a message.
+PyObject* GetFieldValue(CMessage* self,
+ const FieldDescriptor* field_descriptor);
+// Sets the value of a scalar field in a message.
+// On error, return -1 with an exception set.
+int SetFieldValue(CMessage* self, const FieldDescriptor* field_descriptor,
+ PyObject* value);
PyObject* FindInitializationErrors(CMessage* self);
@@ -357,6 +357,6 @@ extern template bool CheckAndGetInteger<uint64>(PyObject*, uint64*);
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_H__
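SetFieldValue() keeps the long-standing rule that only singular scalar fields may be set through plain attribute assignment; repeated and message fields must be mutated in place. Roughly, from the Python side (the message type and field names are illustrative, not from this commit):

    msg = MyMessage()
    msg.count = 5                    # singular scalar field: plain assignment works
    msg.values.extend([1, 2, 3])     # repeated field: mutate in place
    msg.child.CopyFrom(other_child)  # message field: mutate in place
    # Direct assignment to 'values' or 'child' raises AttributeError.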
diff --git a/python/google/protobuf/pyext/message_factory.cc b/python/google/protobuf/pyext/message_factory.cc
index bacc76a6..efaa2617 100644
--- a/python/google/protobuf/pyext/message_factory.cc
+++ b/python/google/protobuf/pyext/message_factory.cc
@@ -28,6 +28,8 @@
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#include <unordered_map>
+
#include <Python.h>
#include <google/protobuf/dynamic_message.h>
@@ -137,7 +139,7 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
// This is the same implementation as MessageFactory.GetPrototype().
// Do not create a MessageClass that already exists.
- hash_map<const Descriptor*, CMessageClass*>::iterator it =
+ std::unordered_map<const Descriptor*, CMessageClass*>::iterator it =
self->classes_by_descriptor->find(descriptor);
if (it != self->classes_by_descriptor->end()) {
Py_INCREF(it->second);
@@ -158,7 +160,7 @@ CMessageClass* GetOrCreateMessageClass(PyMessageFactory* self,
return NULL;
}
ScopedPyObjectPtr message_class(PyObject_CallObject(
- reinterpret_cast<PyObject*>(&CMessageClass_Type), args.get()));
+ reinterpret_cast<PyObject*>(CMessageClass_Type), args.get()));
if (message_class == NULL) {
return NULL;
}
diff --git a/python/google/protobuf/pyext/message_factory.h b/python/google/protobuf/pyext/message_factory.h
index 36092f7e..06444b0a 100644
--- a/python/google/protobuf/pyext/message_factory.h
+++ b/python/google/protobuf/pyext/message_factory.h
@@ -33,7 +33,7 @@
#include <Python.h>
-#include <google/protobuf/stubs/hash.h>
+#include <unordered_map>
#include <google/protobuf/descriptor.h>
#include <google/protobuf/pyext/descriptor_pool.h>
@@ -66,7 +66,8 @@ struct PyMessageFactory {
//
// Descriptor pointers stored here are owned by the DescriptorPool above.
// Python references to classes are owned by this PyDescriptorPool.
- typedef hash_map<const Descriptor*, CMessageClass*> ClassesByMessageMap;
+ typedef std::unordered_map<const Descriptor*, CMessageClass*>
+ ClassesByMessageMap;
ClassesByMessageMap* classes_by_descriptor;
};
@@ -98,6 +99,6 @@ bool InitMessageFactory();
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_MESSAGE_FACTORY_H__
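The unordered_map keyed by Descriptor means GetOrCreateMessageClass() hands back one cached class per message type, which is also the contract of the Python-level factory. A small sketch under that assumption (msg_descriptor stands for any descriptor.Descriptor obtained elsewhere):

    from google.protobuf import message_factory

    factory = message_factory.MessageFactory()
    cls_a = factory.GetPrototype(msg_descriptor)
    cls_b = factory.GetPrototype(msg_descriptor)
    assert cls_a is cls_b   # the class is created once and cached by descriptor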
diff --git a/python/google/protobuf/pyext/message_module.cc b/python/google/protobuf/pyext/message_module.cc
index 8c933866..29d56702 100644
--- a/python/google/protobuf/pyext/message_module.cc
+++ b/python/google/protobuf/pyext/message_module.cc
@@ -31,47 +31,24 @@
#include <Python.h>
#include <google/protobuf/pyext/message.h>
-#include <google/protobuf/proto_api.h>
+#include <google/protobuf/python/proto_api.h>
#include <google/protobuf/message_lite.h>
namespace {
// C++ API. Clients get at this via proto_api.h
-struct ApiImplementation : google::protobuf::python::PyProto_API {
- const google::protobuf::Message*
- GetMessagePointer(PyObject* msg) const override {
- return google::protobuf::python::PyMessage_GetMessagePointer(msg);
+struct ApiImplementation : proto2::python::PyProto_API {
+ const proto2::Message* GetMessagePointer(PyObject* msg) const override {
+ return proto2::python::PyMessage_GetMessagePointer(msg);
}
- google::protobuf::Message*
- GetMutableMessagePointer(PyObject* msg) const override {
- return google::protobuf::python::PyMessage_GetMutableMessagePointer(msg);
+ proto2::Message* GetMutableMessagePointer(PyObject* msg) const override {
+ return proto2::python::PyMessage_GetMutableMessagePointer(msg);
}
};
} // namespace
-static PyObject* GetPythonProto3PreserveUnknownsDefault(
- PyObject* /*m*/, PyObject* /*args*/) {
- if (google::protobuf::internal::GetProto3PreserveUnknownsDefault()) {
- Py_RETURN_TRUE;
- } else {
- Py_RETURN_FALSE;
- }
-}
-
-static PyObject* SetPythonProto3PreserveUnknownsDefault(
- PyObject* /*m*/, PyObject* arg) {
- if (!arg || !PyBool_Check(arg)) {
- PyErr_SetString(
- PyExc_TypeError,
- "Argument to SetPythonProto3PreserveUnknownsDefault must be boolean");
- return NULL;
- }
- google::protobuf::internal::SetProto3PreserveUnknownsDefault(PyObject_IsTrue(arg));
- Py_RETURN_NONE;
-}
-
static const char module_docstring[] =
"python-proto2 is a module that can be used to enhance proto2 Python API\n"
"performance.\n"
@@ -81,16 +58,9 @@ static const char module_docstring[] =
static PyMethodDef ModuleMethods[] = {
{"SetAllowOversizeProtos",
- (PyCFunction)google::protobuf::python::cmessage::SetAllowOversizeProtos,
+ (PyCFunction)proto2::python::cmessage::SetAllowOversizeProtos,
METH_O, "Enable/disable oversize proto parsing."},
// DO NOT USE: For migration and testing only.
- {"GetPythonProto3PreserveUnknownsDefault",
- (PyCFunction)GetPythonProto3PreserveUnknownsDefault,
- METH_NOARGS, "Get Proto3 preserve unknowns default."},
- // DO NOT USE: For migration and testing only.
- {"SetPythonProto3PreserveUnknownsDefault",
- (PyCFunction)SetPythonProto3PreserveUnknownsDefault,
- METH_O, "Enable/disable proto3 unknowns preservation."},
{ NULL, NULL}
};
@@ -113,35 +83,32 @@ static struct PyModuleDef _module = {
#define INITFUNC_ERRORVAL
#endif
-extern "C" {
- PyMODINIT_FUNC INITFUNC(void) {
- PyObject* m;
+PyMODINIT_FUNC INITFUNC() {
+ PyObject* m;
#if PY_MAJOR_VERSION >= 3
- m = PyModule_Create(&_module);
+ m = PyModule_Create(&_module);
#else
- m = Py_InitModule3("_message", ModuleMethods,
- module_docstring);
+ m = Py_InitModule3("_message", ModuleMethods, module_docstring);
#endif
- if (m == NULL) {
- return INITFUNC_ERRORVAL;
- }
+ if (m == NULL) {
+ return INITFUNC_ERRORVAL;
+ }
- if (!google::protobuf::python::InitProto2MessageModule(m)) {
- Py_DECREF(m);
- return INITFUNC_ERRORVAL;
- }
-
- // Adds the C++ API
- if (PyObject* api =
- PyCapsule_New(new ApiImplementation(),
- google::protobuf::python::PyProtoAPICapsuleName(), NULL)) {
- PyModule_AddObject(m, "proto_API", api);
- } else {
- return INITFUNC_ERRORVAL;
- }
+ if (!proto2::python::InitProto2MessageModule(m)) {
+ Py_DECREF(m);
+ return INITFUNC_ERRORVAL;
+ }
+
+ // Adds the C++ API
+ if (PyObject* api =
+ PyCapsule_New(new ApiImplementation(),
+ proto2::python::PyProtoAPICapsuleName(), NULL)) {
+ PyModule_AddObject(m, "proto_API", api);
+ } else {
+ return INITFUNC_ERRORVAL;
+ }
#if PY_MAJOR_VERSION >= 3
- return m;
+ return m;
#endif
- }
}
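With the proto3 preserve-unknowns toggles removed, the only module-level switch left here is SetAllowOversizeProtos, which lifts the C++ parser's default 64MB total-bytes limit. Illustrative use, only meaningful when the C++ implementation is loaded:

    from google.protobuf.pyext import _message

    # Allow parsing serialized messages larger than the default 64MB limit.
    _message.SetAllowOversizeProtos(True)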
diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc
index 5874d5de..d6bc3d7b 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.cc
+++ b/python/google/protobuf/pyext/repeated_composite_container.cc
@@ -61,9 +61,9 @@ namespace repeated_composite_container {
// TODO(tibell): We might also want to check:
// GOOGLE_CHECK_NOTNULL((self)->owner.get());
-#define GOOGLE_CHECK_ATTACHED(self) \
- do { \
- GOOGLE_CHECK_NOTNULL((self)->message); \
+#define GOOGLE_CHECK_ATTACHED(self) \
+ do { \
+ GOOGLE_CHECK_NOTNULL((self)->message); \
GOOGLE_CHECK_NOTNULL((self)->parent_field_descriptor); \
} while (0);
@@ -152,6 +152,8 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self,
cmsg->message = sub_message;
cmsg->parent = self->parent;
if (cmessage::InitAttributes(cmsg, args, kwargs) < 0) {
+ message->GetReflection()->RemoveLast(
+ message, self->parent_field_descriptor);
Py_DECREF(cmsg);
return NULL;
}
@@ -210,7 +212,7 @@ PyObject* Extend(RepeatedCompositeContainer* self, PyObject* value) {
}
ScopedPyObjectPtr next;
while ((next.reset(PyIter_Next(iter.get()))) != NULL) {
- if (!PyObject_TypeCheck(next.get(), &CMessage_Type)) {
+ if (!PyObject_TypeCheck(next.get(), CMessage_Type)) {
PyErr_SetString(PyExc_TypeError, "Not a cmessage");
return NULL;
}
@@ -487,9 +489,9 @@ static PyObject* Pop(PyObject* pself, PyObject* args) {
void ReleaseLastTo(CMessage* parent,
const FieldDescriptor* field,
CMessage* target) {
- GOOGLE_CHECK_NOTNULL(parent);
- GOOGLE_CHECK_NOTNULL(field);
- GOOGLE_CHECK_NOTNULL(target);
+ GOOGLE_CHECK(parent != nullptr);
+ GOOGLE_CHECK(field != nullptr);
+ GOOGLE_CHECK(target != nullptr);
CMessage::OwnerRef released_message(
parent->message->GetReflection()->ReleaseLast(parent->message, field));
diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h
index e5e946aa..464699aa 100644
--- a/python/google/protobuf/pyext/repeated_composite_container.h
+++ b/python/google/protobuf/pyext/repeated_composite_container.h
@@ -161,6 +161,6 @@ void ReleaseLastTo(CMessage* parent,
} // namespace repeated_composite_container
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_COMPOSITE_CONTAINER_H__
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc
index de3b6e14..cdb64269 100644
--- a/python/google/protobuf/pyext/repeated_scalar_container.cc
+++ b/python/google/protobuf/pyext/repeated_scalar_container.cc
@@ -663,6 +663,10 @@ static PyObject* ToStr(PyObject* pself) {
return PyObject_Repr(list.get());
}
+static PyObject* MergeFrom(PyObject* pself, PyObject* arg) {
+ return Extend(reinterpret_cast<RepeatedScalarContainer*>(pself), arg);
+}
+
// The private constructor of RepeatedScalarContainer objects.
PyObject *NewContainer(
CMessage* parent, const FieldDescriptor* parent_field_descriptor) {
@@ -776,6 +780,8 @@ static PyMethodDef Methods[] = {
"Removes an object from the repeated container." },
{ "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS,
"Sorts the repeated container."},
+ { "MergeFrom", (PyCFunction)MergeFrom, METH_O,
+ "Merges a repeated container into the current container." },
{ NULL, NULL }
};
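Because the new MergeFrom on repeated scalar containers simply forwards to Extend, merging one repeated scalar field into another appends its elements. For example (MyMessage and its 'values' field are hypothetical):

    a = MyMessage(values=[1, 2])
    b = MyMessage(values=[3])
    a.values.MergeFrom(b.values)
    assert list(a.values) == [1, 2, 3]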
diff --git a/python/google/protobuf/pyext/repeated_scalar_container.h b/python/google/protobuf/pyext/repeated_scalar_container.h
index 559dec98..4dcecbac 100644
--- a/python/google/protobuf/pyext/repeated_scalar_container.h
+++ b/python/google/protobuf/pyext/repeated_scalar_container.h
@@ -104,6 +104,6 @@ void SetOwner(RepeatedScalarContainer* self,
} // namespace repeated_scalar_container
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_REPEATED_SCALAR_CONTAINER_H__
diff --git a/python/google/protobuf/pyext/safe_numerics.h b/python/google/protobuf/pyext/safe_numerics.h
index 639ba2c8..60112cfa 100644
--- a/python/google/protobuf/pyext/safe_numerics.h
+++ b/python/google/protobuf/pyext/safe_numerics.h
@@ -159,6 +159,6 @@ inline Dest checked_numeric_cast(Source source) {
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_SAFE_NUMERICS_H__
diff --git a/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h b/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h
index ad804b5f..79fa9e3d 100644
--- a/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h
+++ b/python/google/protobuf/pyext/thread_unsafe_shared_ptr.h
@@ -99,6 +99,6 @@ class ThreadUnsafeSharedPtr {
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_CPP_THREAD_UNSAFE_SHARED_PTR_H__
diff --git a/python/google/protobuf/python_protobuf.h b/python/google/protobuf/python_protobuf.h
index beb6e460..8db1ffb7 100644
--- a/python/google/protobuf/python_protobuf.h
+++ b/python/google/protobuf/python_protobuf.h
@@ -52,6 +52,6 @@ Message* MutableCProtoInsidePyProto(PyObject* msg);
} // namespace python
} // namespace protobuf
-
} // namespace google
+
#endif // GOOGLE_PROTOBUF_PYTHON_PYTHON_PROTOBUF_H__
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py
index f4ce8caf..81e18859 100755
--- a/python/google/protobuf/reflection.py
+++ b/python/google/protobuf/reflection.py
@@ -48,25 +48,23 @@ this file*.
__author__ = 'robinson@google.com (Will Robinson)'
-from google.protobuf.internal import api_implementation
-from google.protobuf import message
-
-
-if api_implementation.Type() == 'cpp':
- from google.protobuf.pyext import cpp_message as message_impl
-else:
- from google.protobuf.internal import python_message as message_impl
+from google.protobuf import message_factory
+from google.protobuf import symbol_database
# The type of all Message classes.
# Part of the public interface, but normally only used by message factories.
-GeneratedProtocolMessageType = message_impl.GeneratedProtocolMessageType
+GeneratedProtocolMessageType = message_factory._GENERATED_PROTOCOL_MESSAGE_TYPE
MESSAGE_CLASS_CACHE = {}
+# Deprecated. Please NEVER use reflection.ParseMessage().
def ParseMessage(descriptor, byte_str):
"""Generate a new Message instance from this Descriptor and a byte string.
+ DEPRECATED: ParseMessage is deprecated because it is using MakeClass().
+ Please use MessageFactory.GetPrototype() instead.
+
Args:
descriptor: Protobuf Descriptor object
byte_str: Serialized protocol buffer byte string
@@ -80,42 +78,18 @@ def ParseMessage(descriptor, byte_str):
return new_msg
+# Deprecated. Please NEVER use reflection.MakeClass().
def MakeClass(descriptor):
"""Construct a class object for a protobuf described by descriptor.
- Composite descriptors are handled by defining the new class as a member of the
- parent class, recursing as deep as necessary.
- This is the dynamic equivalent to:
-
- class Parent(message.Message):
- __metaclass__ = GeneratedProtocolMessageType
- DESCRIPTOR = descriptor
- class Child(message.Message):
- __metaclass__ = GeneratedProtocolMessageType
- DESCRIPTOR = descriptor.nested_types[0]
-
- Sample usage:
- file_descriptor = descriptor_pb2.FileDescriptorProto()
- file_descriptor.ParseFromString(proto2_string)
- msg_descriptor = descriptor.MakeDescriptor(file_descriptor.message_type[0])
- msg_class = reflection.MakeClass(msg_descriptor)
- msg = msg_class()
+ DEPRECATED: use MessageFactory.GetPrototype() instead.
Args:
descriptor: A descriptor.Descriptor object describing the protobuf.
Returns:
The Message class object described by the descriptor.
"""
- if descriptor in MESSAGE_CLASS_CACHE:
- return MESSAGE_CLASS_CACHE[descriptor]
-
- attributes = {}
- for name, nested_type in descriptor.nested_types_by_name.items():
- attributes[name] = MakeClass(nested_type)
-
- attributes[GeneratedProtocolMessageType._DESCRIPTOR_KEY] = descriptor
-
- result = GeneratedProtocolMessageType(
- str(descriptor.name), (message.Message,), attributes)
- MESSAGE_CLASS_CACHE[descriptor] = result
- return result
+ # Original implementation leads to duplicate message classes, which won't play
+ # well with extensions. Message factory info is also missing.
+ # Redirect to message_factory.
+ return symbol_database.Default().GetPrototype(descriptor)
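Since MakeClass() is now just a redirect, new code can go straight to the symbol database (or a MessageFactory) instead of the deprecated helpers. A sketch of the suggested replacement, where msg_descriptor is a descriptor.Descriptor obtained elsewhere:

    from google.protobuf import symbol_database

    msg_class = symbol_database.Default().GetPrototype(msg_descriptor)
    msg = msg_class()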
diff --git a/python/google/protobuf/text_encoding.py b/python/google/protobuf/text_encoding.py
index 98995638..39898765 100644
--- a/python/google/protobuf/text_encoding.py
+++ b/python/google/protobuf/text_encoding.py
@@ -33,59 +33,70 @@ import re
import six
-# Lookup table for utf8
-_cescape_utf8_to_str = [chr(i) for i in range(0, 256)]
-_cescape_utf8_to_str[9] = r'\t' # optional escape
-_cescape_utf8_to_str[10] = r'\n' # optional escape
-_cescape_utf8_to_str[13] = r'\r' # optional escape
-_cescape_utf8_to_str[39] = r"\'" # optional escape
-
-_cescape_utf8_to_str[34] = r'\"' # necessary escape
-_cescape_utf8_to_str[92] = r'\\' # necessary escape
+_cescape_chr_to_symbol_map = {}
+_cescape_chr_to_symbol_map[9] = r'\t' # optional escape
+_cescape_chr_to_symbol_map[10] = r'\n' # optional escape
+_cescape_chr_to_symbol_map[13] = r'\r' # optional escape
+_cescape_chr_to_symbol_map[34] = r'\"' # necessary escape
+_cescape_chr_to_symbol_map[39] = r"\'" # optional escape
+_cescape_chr_to_symbol_map[92] = r'\\' # necessary escape
+
+# Lookup table for unicode
+_cescape_unicode_to_str = [chr(i) for i in range(0, 256)]
+for byte, string in _cescape_chr_to_symbol_map.items():
+ _cescape_unicode_to_str[byte] = string
# Lookup table for non-utf8, with necessary escapes at (o >= 127 or o < 32)
_cescape_byte_to_str = ([r'\%03o' % i for i in range(0, 32)] +
[chr(i) for i in range(32, 127)] +
[r'\%03o' % i for i in range(127, 256)])
-_cescape_byte_to_str[9] = r'\t' # optional escape
-_cescape_byte_to_str[10] = r'\n' # optional escape
-_cescape_byte_to_str[13] = r'\r' # optional escape
-_cescape_byte_to_str[39] = r"\'" # optional escape
-
-_cescape_byte_to_str[34] = r'\"' # necessary escape
-_cescape_byte_to_str[92] = r'\\' # necessary escape
+for byte, string in _cescape_chr_to_symbol_map.items():
+ _cescape_byte_to_str[byte] = string
+del byte, string
def CEscape(text, as_utf8):
- """Escape a bytes string for use in an ascii protocol buffer.
-
- text.encode('string_escape') does not seem to satisfy our needs as it
- encodes unprintable characters using two-digit hex escapes whereas our
- C++ unescaping function allows hex escapes to be any length. So,
- "\0011".encode('string_escape') ends up being "\\x011", which will be
- decoded in C++ as a single-character string with char code 0x11.
+ # type: (...) -> str
+ """Escape a bytes string for use in an text protocol buffer.
Args:
- text: A byte string to be escaped
- as_utf8: Specifies if result should be returned in UTF-8 encoding
+ text: A byte string to be escaped.
+ as_utf8: Specifies if result may contain non-ASCII characters.
+ In Python 3 this allows unescaped non-ASCII Unicode characters.
+ In Python 2 the return value will be valid UTF-8 rather than only ASCII.
Returns:
- Escaped string
+ Escaped string (str).
"""
- # PY3 hack: make Ord work for str and bytes:
- # //platforms/networking/data uses unicode here, hence basestring.
- Ord = ord if isinstance(text, six.string_types) else lambda x: x
+ # Python's text.encode() 'string_escape' or 'unicode_escape' codecs do not
+ # satisfy our needs; they encode unprintable characters using two-digit hex
+ # escapes whereas our C++ unescaping function allows hex escapes to be any
+ # length. So, "\0011".encode('string_escape') ends up being "\\x011", which
+ # will be decoded in C++ as a single-character string with char code 0x11.
+ if six.PY3:
+ text_is_unicode = isinstance(text, str)
+ if as_utf8 and text_is_unicode:
+ # We're already unicode, no processing beyond control char escapes.
+ return text.translate(_cescape_chr_to_symbol_map)
+ ord_ = ord if text_is_unicode else lambda x: x # bytes iterate as ints.
+ else:
+ ord_ = ord # PY2
if as_utf8:
- return ''.join(_cescape_utf8_to_str[Ord(c)] for c in text)
- return ''.join(_cescape_byte_to_str[Ord(c)] for c in text)
+ return ''.join(_cescape_unicode_to_str[ord_(c)] for c in text)
+ return ''.join(_cescape_byte_to_str[ord_(c)] for c in text)
_CUNESCAPE_HEX = re.compile(r'(\\+)x([0-9a-fA-F])(?![0-9a-fA-F])')
-_cescape_highbit_to_str = ([chr(i) for i in range(0, 127)] +
- [r'\%03o' % i for i in range(127, 256)])
def CUnescape(text):
- """Unescape a text string with C-style escape sequences to UTF-8 bytes."""
+ # type: (str) -> bytes
+ """Unescape a text string with C-style escape sequences to UTF-8 bytes.
+
+ Args:
+ text: The data to parse in a str.
+ Returns:
+ A byte string.
+ """
def ReplaceHex(m):
# Only replace the match if the number of leading back slashes is odd. i.e.
@@ -98,10 +109,9 @@ def CUnescape(text):
# allow single-digit hex escapes (like '\xf').
result = _CUNESCAPE_HEX.sub(ReplaceHex, text)
- if str is bytes: # PY2
+ if six.PY2:
return result.decode('string_escape')
- result = ''.join(_cescape_highbit_to_str[ord(c)] for c in result)
- return (result.encode('ascii') # Make it bytes to allow decode.
+ return (result.encode('utf-8') # PY3: Make it bytes to allow decode.
.decode('unicode_escape')
# Make it bytes again to return the proper type.
.encode('raw_unicode_escape'))
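The rewritten tables keep CEscape/CUnescape round-trippable: control characters, quotes and backslashes are always escaped, and as_utf8 only decides whether non-ASCII bytes are left unescaped. A quick Python 3 illustration (the sample bytes are arbitrary):

    from google.protobuf import text_encoding

    escaped = text_encoding.CEscape(b'\x00caf\xc3\xa9"', as_utf8=False)
    # Non-printable and non-ASCII bytes come back as octal escapes: \000caf\303\251\"
    assert text_encoding.CUnescape(escaped) == b'\x00caf\xc3\xa9"'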
diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py
index 2cbd21bc..5dd41830 100755
--- a/python/google/protobuf/text_format.py
+++ b/python/google/protobuf/text_format.py
@@ -55,15 +55,15 @@ from google.protobuf.internal import type_checkers
from google.protobuf import descriptor
from google.protobuf import text_encoding
-__all__ = ['MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue',
- 'Merge']
+__all__ = ['MessageToString', 'Parse', 'PrintMessage', 'PrintField',
+ 'PrintFieldValue', 'Merge', 'MessageToBytes']
_INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(),
type_checkers.Int32ValueChecker(),
type_checkers.Uint64ValueChecker(),
type_checkers.Int64ValueChecker())
-_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE)
-_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE)
+_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?$', re.IGNORECASE)
+_FLOAT_NAN = re.compile('nanf?$', re.IGNORECASE)
_FLOAT_TYPES = frozenset([descriptor.FieldDescriptor.CPPTYPE_FLOAT,
descriptor.FieldDescriptor.CPPTYPE_DOUBLE])
_QUOTES = frozenset(("'", '"'))
@@ -121,6 +121,7 @@ class TextWriter(object):
def MessageToString(message,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
@@ -128,6 +129,7 @@ def MessageToString(message,
descriptor_pool=None,
indent=0,
message_formatter=None):
+ # type: (...) -> str
"""Convert protobuf message to text format.
Floating point values can be formatted compactly with 15 digits of
@@ -137,8 +139,11 @@ def MessageToString(message,
Args:
message: The protocol buffers message.
- as_utf8: Produce text output in UTF8 format.
+ as_utf8: Return unescaped Unicode for non-ASCII characters.
+ In Python 3 actual Unicode characters may appear as is in strings.
+ In Python 2 the return value will be valid UTF-8 rather than only ASCII.
as_one_line: Don't introduce newlines between fields.
+ use_short_repeated_primitives: Use short repeated format for primitives.
pointy_brackets: If True, use angle brackets instead of curly braces for
nesting.
use_index_order: If True, fields of a proto message will be printed using
@@ -159,7 +164,8 @@ def MessageToString(message,
A string of the text formatted protocol buffer message.
"""
out = TextWriter(as_utf8)
- printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ use_short_repeated_primitives, pointy_brackets,
use_index_order, float_format, use_field_number,
descriptor_pool, message_formatter)
printer.PrintMessage(message)
@@ -170,6 +176,16 @@ def MessageToString(message,
return result
+def MessageToBytes(message, **kwargs):
+ # type: (...) -> bytes
+ """Convert protobuf message to encoded text format. See MessageToString."""
+ text = MessageToString(message, **kwargs)
+ if isinstance(text, bytes):
+ return text
+ codec = 'utf-8' if kwargs.get('as_utf8') else 'ascii'
+ return text.encode(codec)
+
+
def _IsMapEntry(field):
return (field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and
field.message_type.has_options and
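MessageToBytes() mirrors MessageToString() but returns encoded bytes, picking UTF-8 when as_utf8 is set and ASCII otherwise. For instance (msg stands for any initialized message):

    from google.protobuf import text_format

    data = text_format.MessageToBytes(msg, as_utf8=True)   # bytes, UTF-8 encoded
    text = text_format.MessageToString(msg)                # str, non-ASCII escaped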
@@ -181,13 +197,15 @@ def PrintMessage(message,
indent=0,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
use_field_number=False,
descriptor_pool=None,
message_formatter=None):
- printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ use_short_repeated_primitives, pointy_brackets,
use_index_order, float_format, use_field_number,
descriptor_pool, message_formatter)
printer.PrintMessage(message)
@@ -199,12 +217,14 @@ def PrintField(field,
indent=0,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
message_formatter=None):
"""Print a single field name/value pair."""
- printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ use_short_repeated_primitives, pointy_brackets,
use_index_order, float_format, message_formatter)
printer.PrintField(field, value)
@@ -215,12 +235,14 @@ def PrintFieldValue(field,
indent=0,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
message_formatter=None):
"""Print a single field value (not including name)."""
- printer = _Printer(out, indent, as_utf8, as_one_line, pointy_brackets,
+ printer = _Printer(out, indent, as_utf8, as_one_line,
+ use_short_repeated_primitives, pointy_brackets,
use_index_order, float_format, message_formatter)
printer.PrintFieldValue(field, value)
@@ -258,6 +280,7 @@ class _Printer(object):
indent=0,
as_utf8=False,
as_one_line=False,
+ use_short_repeated_primitives=False,
pointy_brackets=False,
use_index_order=False,
float_format=None,
@@ -274,8 +297,11 @@ class _Printer(object):
Args:
out: To record the text format result.
indent: The indent level for pretty print.
- as_utf8: Produce text output in UTF8 format.
+ as_utf8: Return unescaped Unicode for non-ASCII characters.
+ In Python 3 actual Unicode characters may appear as is in strings.
+ In Python 2 the return value will be valid UTF-8 rather than ASCII.
as_one_line: Don't introduce newlines between fields.
+ use_short_repeated_primitives: Use short repeated format for primitives.
pointy_brackets: If True, use angle brackets instead of curly braces for
nesting.
use_index_order: If True, print fields of a proto message using the order
@@ -294,6 +320,7 @@ class _Printer(object):
self.indent = indent
self.as_utf8 = as_utf8
self.as_one_line = as_one_line
+ self.use_short_repeated_primitives = use_short_repeated_primitives
self.pointy_brackets = pointy_brackets
self.use_index_order = use_index_order
self.float_format = float_format
@@ -351,13 +378,18 @@ class _Printer(object):
entry_submsg = value.GetEntryClass()(key=key, value=value[key])
self.PrintField(field, entry_submsg)
elif field.label == descriptor.FieldDescriptor.LABEL_REPEATED:
- for element in value:
- self.PrintField(field, element)
+ if (self.use_short_repeated_primitives
+ and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE
+ and field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_STRING):
+ self._PrintShortRepeatedPrimitivesValue(field, value)
+ else:
+ for element in value:
+ self.PrintField(field, element)
else:
self.PrintField(field, value)
- def PrintField(self, field, value):
- """Print a single field name/value pair."""
+ def _PrintFieldName(self, field):
+ """Print field name."""
out = self.out
out.write(' ' * self.indent)
if self.use_field_number:
@@ -383,11 +415,22 @@ class _Printer(object):
# don't include it.
out.write(': ')
+ def PrintField(self, field, value):
+ """Print a single field name/value pair."""
+ self._PrintFieldName(field)
self.PrintFieldValue(field, value)
- if self.as_one_line:
- out.write(' ')
- else:
- out.write('\n')
+ self.out.write(' ' if self.as_one_line else '\n')
+
+ def _PrintShortRepeatedPrimitivesValue(self, field, value):
+ # Note: this is called only when value has at least one element.
+ self._PrintFieldName(field)
+ self.out.write('[')
+ for i in six.moves.range(len(value) - 1):
+ self.PrintFieldValue(field, value[i])
+ self.out.write(', ')
+ self.PrintFieldValue(field, value[-1])
+ self.out.write(']')
+ self.out.write(' ' if self.as_one_line else '\n')
def _PrintMessageFieldValue(self, value):
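With use_short_repeated_primitives enabled, a repeated numeric field is emitted as one bracketed list instead of one name/value pair per element. Roughly (the 'values' field and msg are illustrative):

    from google.protobuf import text_format

    # msg.values == [1, 2, 3] for some repeated int32 field named 'values'
    print(text_format.MessageToString(msg, use_short_repeated_primitives=True))
    # values: [1, 2, 3]
    print(text_format.MessageToString(msg))
    # values: 1
    # values: 2
    # values: 3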
if self.pointy_brackets:
@@ -428,12 +471,12 @@ class _Printer(object):
out.write(str(value))
elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING:
out.write('\"')
- if isinstance(value, six.text_type):
+ if isinstance(value, six.text_type) and (six.PY2 or not self.as_utf8):
out_value = value.encode('utf-8')
else:
out_value = value
if field.type == descriptor.FieldDescriptor.TYPE_BYTES:
- # We need to escape non-UTF8 chars in TYPE_BYTES field.
+ # We always need to escape all binary data in TYPE_BYTES fields.
out_as_utf8 = False
else:
out_as_utf8 = self.as_utf8
@@ -487,12 +530,7 @@ def Parse(text,
Raises:
ParseError: On text parsing problems.
"""
- if not isinstance(text, str):
- if six.PY3:
- text = text.decode('utf-8')
- else:
- text = text.encode('utf-8')
- return ParseLines(text.split('\n'),
+ return ParseLines(text.split(b'\n' if isinstance(text, bytes) else u'\n'),
message,
allow_unknown_extension,
allow_field_number,
@@ -523,13 +561,8 @@ def Merge(text,
Raises:
ParseError: On text parsing problems.
"""
- if not isinstance(text, str):
- if six.PY3:
- text = text.decode('utf-8')
- else:
- text = text.encode('utf-8')
return MergeLines(
- text.split('\n'),
+ text.split(b'\n' if isinstance(text, bytes) else u'\n'),
message,
allow_unknown_extension,
allow_field_number,
@@ -570,6 +603,9 @@ def MergeLines(lines,
descriptor_pool=None):
"""Parses a text representation of a protocol message into a message.
+ Like ParseLines(), but allows repeated values for a non-repeated field, and
+ uses the last one.
+
Args:
lines: An iterable of lines of a message's text representation.
message: A protocol buffer message to merge into.
@@ -601,22 +637,12 @@ class _Parser(object):
self.allow_field_number = allow_field_number
self.descriptor_pool = descriptor_pool
- def ParseFromString(self, text, message):
- """Parses a text representation of a protocol message into a message."""
- if not isinstance(text, str):
- text = text.decode('utf-8')
- return self.ParseLines(text.split('\n'), message)
-
def ParseLines(self, lines, message):
"""Parses a text representation of a protocol message into a message."""
self._allow_multiple_scalars = False
self._ParseOrMerge(lines, message)
return message
- def MergeFromString(self, text, message):
- """Merges a text representation of a protocol message into a message."""
- return self._MergeLines(text.split('\n'), message)
-
def MergeLines(self, lines, message):
"""Merges a text representation of a protocol message into a message."""
self._allow_multiple_scalars = True
@@ -633,7 +659,14 @@ class _Parser(object):
Raises:
ParseError: On text parsing problems.
"""
- tokenizer = Tokenizer(lines)
+ # Tokenize expects native str lines.
+ if six.PY2:
+ str_lines = (line if isinstance(line, str) else line.encode('utf-8')
+ for line in lines)
+ else:
+ str_lines = (line if isinstance(line, str) else line.decode('utf-8')
+ for line in lines)
+ tokenizer = Tokenizer(str_lines)
while not tokenizer.AtEnd():
self._MergeField(tokenizer, message)
@@ -1019,7 +1052,9 @@ class Tokenizer(object):
r'[a-zA-Z_][0-9a-zA-Z_+-]*', # an identifier
r'([0-9+-]|(\.[0-9]))[0-9a-zA-Z_.+-]*', # a number
] + [ # quoted str for each quote mark
- r'{qt}([^{qt}\n\\]|\\.)*({qt}|\\?$)'.format(qt=mark) for mark in _QUOTES
+ # Avoid backtracking! https://stackoverflow.com/a/844267
+ r'{qt}[^{qt}\n\\]*((\\.)+[^{qt}\n\\]*)*({qt}|\\?$)'.format(qt=mark)
+ for mark in _QUOTES
]))
_IDENTIFIER = re.compile(r'[^\d\W]\w*')
@@ -1316,7 +1351,8 @@ class Tokenizer(object):
def ParseError(self, message):
"""Creates and *returns* a ParseError for the current token."""
- return ParseError(message, self._line + 1, self._column + 1)
+ return ParseError('\'' + self._current_line + '\': ' + message,
+ self._line + 1, self._column + 1)
def _StringParseError(self, e):
return self.ParseError('Couldn\'t parse string: ' + str(e))
@@ -1490,6 +1526,12 @@ def _ParseAbstractInteger(text, is_long=False):
ValueError: Thrown Iff the text is not a valid integer.
"""
# Do the actual parsing. Exception handling is propagated to caller.
+ orig_text = text
+ c_octal_match = re.match(r'(-?)0(\d+)$', text)
+ if c_octal_match:
+ # Python 3 no longer supports 0755 octal syntax without the 'o', so
+ # we always use the '0o' prefix for multi-digit numbers starting with 0.
+ text = c_octal_match.group(1) + '0o' + c_octal_match.group(2)
try:
# We force 32-bit values to int and 64-bit values to long to make
# alternate implementations where the distinction is more significant
@@ -1499,7 +1541,7 @@ def _ParseAbstractInteger(text, is_long=False):
else:
return int(text, 0)
except ValueError:
- raise ValueError('Couldn\'t parse integer: %s' % text)
+ raise ValueError('Couldn\'t parse integer: %s' % orig_text)
def ParseFloat(text):