From ada65567852b96fdb4d070c0c3f86ca7b77824f9 Mon Sep 17 00:00:00 2001 From: Jisi Liu Date: Wed, 25 Feb 2015 16:39:11 -0800 Subject: Down integrate from Google internal. Change-Id: I34d301133eea9c6f3a822c47d1f91e136fd33145 --- python/README.txt | 5 +- python/google/protobuf/descriptor.py | 158 +- python/google/protobuf/descriptor_pool.py | 114 +- .../google/protobuf/internal/api_implementation.cc | 14 +- .../google/protobuf/internal/api_implementation.py | 38 +- .../internal/api_implementation_default_test.py | 63 - python/google/protobuf/internal/containers.py | 37 +- python/google/protobuf/internal/decoder.py | 3 - python/google/protobuf/internal/descriptor_test.py | 176 ++- python/google/protobuf/internal/message_test.py | 782 +++++++--- python/google/protobuf/internal/python_message.py | 81 +- python/google/protobuf/internal/reflection_test.py | 21 + python/google/protobuf/internal/test_util.py | 170 ++- .../google/protobuf/internal/text_format_test.py | 557 +++---- python/google/protobuf/internal/type_checkers.py | 8 +- .../protobuf/internal/unknown_fields_test.py | 120 +- python/google/protobuf/pyext/descriptor.cc | 1538 +++++++++++++++---- python/google/protobuf/pyext/descriptor.h | 122 +- .../google/protobuf/pyext/descriptor_containers.cc | 1564 ++++++++++++++++++++ .../google/protobuf/pyext/descriptor_containers.h | 95 ++ .../google/protobuf/pyext/descriptor_cpp2_test.py | 58 - python/google/protobuf/pyext/descriptor_pool.cc | 370 +++++ python/google/protobuf/pyext/descriptor_pool.h | 152 ++ python/google/protobuf/pyext/extension_dict.cc | 25 +- python/google/protobuf/pyext/extension_dict.h | 3 +- python/google/protobuf/pyext/message.cc | 1044 ++++++------- python/google/protobuf/pyext/message.h | 50 +- .../protobuf/pyext/message_factory_cpp2_test.py | 56 - .../pyext/reflection_cpp2_generated_test.py | 94 -- .../protobuf/pyext/repeated_composite_container.cc | 53 +- .../protobuf/pyext/repeated_composite_container.h | 8 +- .../protobuf/pyext/repeated_scalar_container.cc | 183 ++- .../protobuf/pyext/repeated_scalar_container.h | 4 +- python/google/protobuf/reflection.py | 18 +- python/google/protobuf/text_format.py | 5 + python/setup.py | 22 +- .../protobuf/compiler/python/python_generator.cc | 25 +- src/google/protobuf/testdata/golden_message_proto3 | Bin 0 -> 398 bytes src/google/protobuf/unittest_proto3_arena.proto | 54 +- 39 files changed, 5781 insertions(+), 2109 deletions(-) delete mode 100644 python/google/protobuf/internal/api_implementation_default_test.py create mode 100644 python/google/protobuf/pyext/descriptor_containers.cc create mode 100644 python/google/protobuf/pyext/descriptor_containers.h delete mode 100644 python/google/protobuf/pyext/descriptor_cpp2_test.py create mode 100644 python/google/protobuf/pyext/descriptor_pool.cc create mode 100644 python/google/protobuf/pyext/descriptor_pool.h delete mode 100644 python/google/protobuf/pyext/message_factory_cpp2_test.py delete mode 100755 python/google/protobuf/pyext/reflection_cpp2_generated_test.py create mode 100644 src/google/protobuf/testdata/golden_message_proto3 diff --git a/python/README.txt b/python/README.txt index 7d852458..04cb1767 100644 --- a/python/README.txt +++ b/python/README.txt @@ -48,8 +48,9 @@ Installation $ python setup.py build $ python setup.py google_test - If you want to test c++ implementation, run: - $ python setup.py test --cpp_implementation + If you want to build/test c++ implementation, run: + $ python setup.py build --cpp_implementation + $ python setup.py google_test --cpp_implementation If some tests fail, this library may not work correctly on your system. Continue at your own risk. diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index af571b7c..f7a58ca0 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -41,15 +41,13 @@ __author__ = 'robinson@google.com (Will Robinson)' from google.protobuf.internal import api_implementation +_USE_C_DESCRIPTORS = False if api_implementation.Type() == 'cpp': # Used by MakeDescriptor in cpp mode import os import uuid - - if api_implementation.Version() == 2: - from google.protobuf.pyext import _message - else: - from google.protobuf.internal import cpp_message + from google.protobuf.pyext import _message + _USE_C_DESCRIPTORS = getattr(_message, '_USE_C_DESCRIPTORS', False) class Error(Exception): @@ -60,12 +58,29 @@ class TypeTransformationError(Error): """Error transforming between python proto type and corresponding C++ type.""" +if _USE_C_DESCRIPTORS: + # This metaclass allows to override the behavior of code like + # isinstance(my_descriptor, FieldDescriptor) + # and make it return True when the descriptor is an instance of the extension + # type written in C++. + class DescriptorMetaclass(type): + def __instancecheck__(cls, obj): + if super(DescriptorMetaclass, cls).__instancecheck__(obj): + return True + if isinstance(obj, cls._C_DESCRIPTOR_CLASS): + return True + return False +else: + # The standard metaclass; nothing changes. + DescriptorMetaclass = type + + class DescriptorBase(object): """Descriptors base class. This class is the base of all descriptor classes. It provides common options - related functionaility. + related functionality. Attributes: has_options: True if the descriptor has non-default options. Usually it @@ -75,6 +90,12 @@ class DescriptorBase(object): avoid some bootstrapping issues. """ + __metaclass__ = DescriptorMetaclass + if _USE_C_DESCRIPTORS: + # The class, or tuple of classes, that are considered as "virtual + # subclasses" of this descriptor class. + _C_DESCRIPTOR_CLASS = () + def __init__(self, options, options_class_name): """Initialize the descriptor given its options message and the name of the class of the options message. The name of the class is required in case @@ -235,13 +256,25 @@ class Descriptor(_NestedDescriptorBase): file: (FileDescriptor) Reference to file descriptor. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.Descriptor + + def __new__(cls, name, full_name, filename, containing_type, fields, + nested_types, enum_types, extensions, options=None, + is_extendable=True, extension_ranges=None, oneofs=None, + file=None, serialized_start=None, serialized_end=None, + syntax=None): + _message.Message._CheckCalledFromGeneratedFile() + return _message.Message._GetMessageDescriptor(full_name) + # NOTE(tmarek): The file argument redefining a builtin is nothing we can # fix right now since we don't know how many clients already rely on the # name of the argument. def __init__(self, name, full_name, filename, containing_type, fields, nested_types, enum_types, extensions, options=None, is_extendable=True, extension_ranges=None, oneofs=None, - file=None, serialized_start=None, serialized_end=None): # pylint:disable=redefined-builtin + file=None, serialized_start=None, serialized_end=None, + syntax=None): # pylint:disable=redefined-builtin """Arguments to __init__() are as described in the description of Descriptor fields above. @@ -286,6 +319,7 @@ class Descriptor(_NestedDescriptorBase): self.oneofs_by_name = dict((o.name, o) for o in self.oneofs) for oneof in self.oneofs: oneof.containing_type = self + self.syntax = syntax or "proto2" def EnumValueName(self, enum, value): """Returns the string name of an enum value. @@ -452,6 +486,19 @@ class FieldDescriptor(DescriptorBase): FIRST_RESERVED_FIELD_NUMBER = 19000 LAST_RESERVED_FIELD_NUMBER = 19999 + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.FieldDescriptor + + def __new__(cls, name, full_name, index, number, type, cpp_type, label, + default_value, message_type, enum_type, containing_type, + is_extension, extension_scope, options=None, + has_default_value=True, containing_oneof=None): + _message.Message._CheckCalledFromGeneratedFile() + if is_extension: + return _message.Message._GetExtensionDescriptor(full_name) + else: + return _message.Message._GetFieldDescriptor(full_name) + def __init__(self, name, full_name, index, number, type, cpp_type, label, default_value, message_type, enum_type, containing_type, is_extension, extension_scope, options=None, @@ -481,20 +528,14 @@ class FieldDescriptor(DescriptorBase): self.containing_oneof = containing_oneof if api_implementation.Type() == 'cpp': if is_extension: - if api_implementation.Version() == 2: - # pylint: disable=protected-access - self._cdescriptor = ( - _message.Message._GetExtensionDescriptor(full_name)) - # pylint: enable=protected-access - else: - self._cdescriptor = cpp_message.GetExtensionDescriptor(full_name) + # pylint: disable=protected-access + self._cdescriptor = ( + _message.Message._GetExtensionDescriptor(full_name)) + # pylint: enable=protected-access else: - if api_implementation.Version() == 2: - # pylint: disable=protected-access - self._cdescriptor = _message.Message._GetFieldDescriptor(full_name) - # pylint: enable=protected-access - else: - self._cdescriptor = cpp_message.GetFieldDescriptor(full_name) + # pylint: disable=protected-access + self._cdescriptor = _message.Message._GetFieldDescriptor(full_name) + # pylint: enable=protected-access else: self._cdescriptor = None @@ -544,6 +585,15 @@ class EnumDescriptor(_NestedDescriptorBase): None to use default enum options. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.EnumDescriptor + + def __new__(cls, name, full_name, filename, values, + containing_type=None, options=None, file=None, + serialized_start=None, serialized_end=None): + _message.Message._CheckCalledFromGeneratedFile() + return _message.Message._GetEnumDescriptor(full_name) + def __init__(self, name, full_name, filename, values, containing_type=None, options=None, file=None, serialized_start=None, serialized_end=None): @@ -588,6 +638,17 @@ class EnumValueDescriptor(DescriptorBase): None to use default enum value options options. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.EnumValueDescriptor + + def __new__(cls, name, index, number, type=None, options=None): + _message.Message._CheckCalledFromGeneratedFile() + # There is no way we can build a complete EnumValueDescriptor with the + # given parameters (the name of the Enum is not known, for example). + # Fortunately generated files just pass it to the EnumDescriptor() + # constructor, which will ignore it, so returning None is good enough. + return None + def __init__(self, name, index, number, type=None, options=None): """Arguments are as described in the attribute description above.""" super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions') @@ -611,6 +672,13 @@ class OneofDescriptor(object): oneof can contain. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.OneofDescriptor + + def __new__(cls, name, full_name, index, containing_type, fields): + _message.Message._CheckCalledFromGeneratedFile() + return _message.Message._GetOneofDescriptor(full_name) + def __init__(self, name, full_name, index, containing_type, fields): """Arguments are as described in the attribute description above.""" self.name = name @@ -704,6 +772,7 @@ class FileDescriptor(DescriptorBase): name: name of file, relative to root of source tree. package: name of the package + syntax: string indicating syntax of the file (can be "proto2" or "proto3") serialized_pb: (str) Byte string of serialized descriptor_pb2.FileDescriptorProto. dependencies: List of other FileDescriptors this FileDescriptor depends on. @@ -712,14 +781,31 @@ class FileDescriptor(DescriptorBase): extensions_by_name: Dict of extension names and their descriptors. """ + if _USE_C_DESCRIPTORS: + _C_DESCRIPTOR_CLASS = _message.FileDescriptor + + def __new__(cls, name, package, options=None, serialized_pb=None, + dependencies=None, syntax=None): + # FileDescriptor() is called from various places, not only from generated + # files, to register dynamic proto files and messages. + # TODO(amauryfa): Expose BuildFile() as a public function and make this + # constructor an implementation detail. + if serialized_pb: + # pylint: disable=protected-access2 + return _message.Message._BuildFile(serialized_pb) + # pylint: enable=protected-access + else: + return super(FileDescriptor, cls).__new__(cls) + def __init__(self, name, package, options=None, serialized_pb=None, - dependencies=None): + dependencies=None, syntax=None): """Constructor.""" super(FileDescriptor, self).__init__(options, 'FileOptions') self.message_types_by_name = {} self.name = name self.package = package + self.syntax = syntax or "proto2" self.serialized_pb = serialized_pb self.enum_types_by_name = {} @@ -728,12 +814,9 @@ class FileDescriptor(DescriptorBase): if (api_implementation.Type() == 'cpp' and self.serialized_pb is not None): - if api_implementation.Version() == 2: - # pylint: disable=protected-access - _message.Message._BuildFile(self.serialized_pb) - # pylint: enable=protected-access - else: - cpp_message.BuildFile(self.serialized_pb) + # pylint: disable=protected-access + _message.Message._BuildFile(self.serialized_pb) + # pylint: enable=protected-access def CopyToProto(self, proto): """Copies this to a descriptor_pb2.FileDescriptorProto. @@ -754,7 +837,8 @@ def _ParseOptions(message, string): return message -def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): +def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True, + syntax=None): """Make a protobuf Descriptor given a DescriptorProto protobuf. Handles nested descriptors. Note that this is limited to the scope of defining @@ -766,6 +850,8 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): package: Optional package name for the new message Descriptor (string). build_file_if_cpp: Update the C++ descriptor pool if api matches. Set to False on recursion, so no duplicates are created. + syntax: The syntax/semantics that should be used. Set to "proto3" to get + proto3 field presence semantics. Returns: A Descriptor for protobuf messages. """ @@ -791,12 +877,13 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): else: file_descriptor_proto.name = proto_name + '.proto' - if api_implementation.Version() == 2: - # pylint: disable=protected-access - _message.Message._BuildFile(file_descriptor_proto.SerializeToString()) - # pylint: enable=protected-access - else: - cpp_message.BuildFile(file_descriptor_proto.SerializeToString()) + # pylint: disable=protected-access + result = _message.Message._BuildFile( + file_descriptor_proto.SerializeToString()) + # pylint: enable=protected-access + + if _USE_C_DESCRIPTORS: + return result.message_types_by_name[desc_proto.name] full_message_name = [desc_proto.name] if package: full_message_name.insert(0, package) @@ -819,7 +906,8 @@ def MakeDescriptor(desc_proto, package='', build_file_if_cpp=True): # used by fields in the message, so no loops are possible here. nested_desc = MakeDescriptor(nested_proto, package='.'.join(full_message_name), - build_file_if_cpp=False) + build_file_if_cpp=False, + syntax=syntax) nested_types[full_name] = nested_desc fields = [] diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py index bcac513a..7e7701f8 100644 --- a/python/google/protobuf/descriptor_pool.py +++ b/python/google/protobuf/descriptor_pool.py @@ -64,6 +64,9 @@ from google.protobuf import descriptor_database from google.protobuf import text_encoding +_USE_C_DESCRIPTORS = descriptor._USE_C_DESCRIPTORS + + def _NormalizeFullyQualifiedName(name): """Remove leading period from fully-qualified type name. @@ -271,58 +274,81 @@ class DescriptorPool(object): file_descriptor = descriptor.FileDescriptor( name=file_proto.name, package=file_proto.package, + syntax=file_proto.syntax, options=file_proto.options, serialized_pb=file_proto.SerializeToString(), dependencies=direct_deps) - scope = {} - - # This loop extracts all the message and enum types from all the - # dependencoes of the file_proto. This is necessary to create the - # scope of available message types when defining the passed in - # file proto. - for dependency in built_deps: - scope.update(self._ExtractSymbols( - dependency.message_types_by_name.values())) - scope.update((_PrefixWithDot(enum.full_name), enum) - for enum in dependency.enum_types_by_name.values()) - - for message_type in file_proto.message_type: - message_desc = self._ConvertMessageDescriptor( - message_type, file_proto.package, file_descriptor, scope) - file_descriptor.message_types_by_name[message_desc.name] = message_desc - - for enum_type in file_proto.enum_type: - file_descriptor.enum_types_by_name[enum_type.name] = ( - self._ConvertEnumDescriptor(enum_type, file_proto.package, - file_descriptor, None, scope)) - - for index, extension_proto in enumerate(file_proto.extension): - extension_desc = self.MakeFieldDescriptor( - extension_proto, file_proto.package, index, is_extension=True) - extension_desc.containing_type = self._GetTypeFromScope( - file_descriptor.package, extension_proto.extendee, scope) - self.SetFieldType(extension_proto, extension_desc, - file_descriptor.package, scope) - file_descriptor.extensions_by_name[extension_desc.name] = extension_desc - - for desc_proto in file_proto.message_type: - self.SetAllFieldTypes(file_proto.package, desc_proto, scope) - - if file_proto.package: - desc_proto_prefix = _PrefixWithDot(file_proto.package) + if _USE_C_DESCRIPTORS: + # When using C++ descriptors, all objects defined in the file were added + # to the C++ database when the FileDescriptor was built above. + # Just add them to this descriptor pool. + def _AddMessageDescriptor(message_desc): + self._descriptors[message_desc.full_name] = message_desc + for nested in message_desc.nested_types: + _AddMessageDescriptor(nested) + for enum_type in message_desc.enum_types: + _AddEnumDescriptor(enum_type) + def _AddEnumDescriptor(enum_desc): + self._enum_descriptors[enum_desc.full_name] = enum_desc + for message_type in file_descriptor.message_types_by_name.values(): + _AddMessageDescriptor(message_type) + for enum_type in file_descriptor.enum_types_by_name.values(): + _AddEnumDescriptor(enum_type) else: - desc_proto_prefix = '' + scope = {} + + # This loop extracts all the message and enum types from all the + # dependencies of the file_proto. This is necessary to create the + # scope of available message types when defining the passed in + # file proto. + for dependency in built_deps: + scope.update(self._ExtractSymbols( + dependency.message_types_by_name.values())) + scope.update((_PrefixWithDot(enum.full_name), enum) + for enum in dependency.enum_types_by_name.values()) + + for message_type in file_proto.message_type: + message_desc = self._ConvertMessageDescriptor( + message_type, file_proto.package, file_descriptor, scope, + file_proto.syntax) + file_descriptor.message_types_by_name[message_desc.name] = ( + message_desc) + + for enum_type in file_proto.enum_type: + file_descriptor.enum_types_by_name[enum_type.name] = ( + self._ConvertEnumDescriptor(enum_type, file_proto.package, + file_descriptor, None, scope)) + + for index, extension_proto in enumerate(file_proto.extension): + extension_desc = self.MakeFieldDescriptor( + extension_proto, file_proto.package, index, is_extension=True) + extension_desc.containing_type = self._GetTypeFromScope( + file_descriptor.package, extension_proto.extendee, scope) + self.SetFieldType(extension_proto, extension_desc, + file_descriptor.package, scope) + file_descriptor.extensions_by_name[extension_desc.name] = ( + extension_desc) + + for desc_proto in file_proto.message_type: + self.SetAllFieldTypes(file_proto.package, desc_proto, scope) + + if file_proto.package: + desc_proto_prefix = _PrefixWithDot(file_proto.package) + else: + desc_proto_prefix = '' + + for desc_proto in file_proto.message_type: + desc = self._GetTypeFromScope( + desc_proto_prefix, desc_proto.name, scope) + file_descriptor.message_types_by_name[desc_proto.name] = desc - for desc_proto in file_proto.message_type: - desc = self._GetTypeFromScope(desc_proto_prefix, desc_proto.name, scope) - file_descriptor.message_types_by_name[desc_proto.name] = desc self.Add(file_proto) self._file_descriptors[file_proto.name] = file_descriptor return self._file_descriptors[file_proto.name] def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None, - scope=None): + scope=None, syntax=None): """Adds the proto to the pool in the specified package. Args: @@ -349,7 +375,8 @@ class DescriptorPool(object): scope = {} nested = [ - self._ConvertMessageDescriptor(nested, desc_name, file_desc, scope) + self._ConvertMessageDescriptor( + nested, desc_name, file_desc, scope, syntax) for nested in desc_proto.nested_type] enums = [ self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) @@ -383,7 +410,8 @@ class DescriptorPool(object): extension_ranges=extension_ranges, file=file_desc, serialized_start=None, - serialized_end=None) + serialized_end=None, + syntax=syntax) for nested in desc.nested_types: nested.containing_type = desc for enum in desc.enum_types: diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc index 83db40b1..6db12e8d 100644 --- a/python/google/protobuf/internal/api_implementation.cc +++ b/python/google/protobuf/internal/api_implementation.cc @@ -50,10 +50,7 @@ namespace python { // and // PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 #ifdef PYTHON_PROTO2_CPP_IMPL_V1 -#if PY_MAJOR_VERSION >= 3 -#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3." -#endif -static int kImplVersion = 1; +#error "PYTHON_PROTO2_CPP_IMPL_V1 is no longer supported." #else #ifdef PYTHON_PROTO2_CPP_IMPL_V2 static int kImplVersion = 2; @@ -62,14 +59,7 @@ static int kImplVersion = 2; static int kImplVersion = 0; #else -// The defaults are set here. Python 3 uses the fast C++ APIv2 by default. -// Python 2 still uses the Python version by default until some compatibility -// issues can be worked around. -#if PY_MAJOR_VERSION >= 3 -static int kImplVersion = 2; -#else -static int kImplVersion = 0; -#endif +static int kImplVersion = -1; // -1 means "Unspecified by compiler flags". #endif // PYTHON_PROTO2_PYTHON_IMPL #endif // PYTHON_PROTO2_CPP_IMPL_V2 diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index f7926c16..8ba4357c 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -40,14 +40,33 @@ try: # The compile-time constants in the _api_implementation module can be used to # switch to a certain implementation of the Python API at build time. _api_version = _api_implementation.api_version - del _api_implementation + _proto_extension_modules_exist_in_build = True except ImportError: - _api_version = 0 + _api_version = -1 # Unspecified by compiler flags. + _proto_extension_modules_exist_in_build = False + +if _api_version == 1: + raise ValueError('api_version=1 is no longer supported.') +if _api_version < 0: # Still unspecified? + try: + # The presence of this module in a build allows the proto implementation to + # be upgraded merely via build deps rather than a compiler flag or the + # runtime environment variable. + # pylint: disable=g-import-not-at-top + from google.protobuf import _use_fast_cpp_protos + # Work around a known issue in the classic bootstrap .par import hook. + if not _use_fast_cpp_protos: + raise ImportError('_use_fast_cpp_protos import succeeded but was None') + del _use_fast_cpp_protos + _api_version = 2 + except ImportError: + if _proto_extension_modules_exist_in_build: + if sys.version_info[0] >= 3: # Python 3 defaults to C++ impl v2. + _api_version = 2 + # TODO(b/17427486): Make Python 2 default to C++ impl v2. _default_implementation_type = ( - 'python' if _api_version == 0 else 'cpp') -_default_version_str = ( - '1' if _api_version <= 1 else '2') + 'python' if _api_version <= 0 else 'cpp') # This environment variable can be used to switch to a certain implementation # of the Python API, overriding the compile-time constants in the @@ -64,13 +83,12 @@ if _implementation_type != 'python': # _api_implementation module. Right now only 1 and 2 are valid values. Any other # value will be ignored. _implementation_version_str = os.getenv( - 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', - _default_version_str) + 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', '2') -if _implementation_version_str not in ('1', '2'): +if _implementation_version_str != '2': raise ValueError( - "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" + - _implementation_version_str + "' (supported versions: 1, 2)" + 'unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: "' + + _implementation_version_str + '" (supported versions: 2)' ) _implementation_version = int(_implementation_version_str) diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py deleted file mode 100644 index 78d5cf23..00000000 --- a/python/google/protobuf/internal/api_implementation_default_test.py +++ /dev/null @@ -1,63 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Test that the api_implementation defaults are what we expect.""" - -import os -import sys -# Clear environment implementation settings before the google3 imports. -os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None) -os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None) - -# pylint: disable=g-import-not-at-top -from google.apputils import basetest -from google.protobuf.internal import api_implementation - - -class ApiImplementationDefaultTest(basetest.TestCase): - - if sys.version_info.major <= 2: - - def testThatPythonIsTheDefault(self): - """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" - self.assertEqual('python', api_implementation.Type()) - - else: - - def testThatCppApiV2IsTheDefault(self): - """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" - self.assertEqual('cpp', api_implementation.Type()) - self.assertEqual(2, api_implementation.Version()) - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 20bfa857..d976f9e1 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -41,7 +41,6 @@ are: __author__ = 'petar@google.com (Petar Petrov)' - class BaseContainer(object): """Base container class.""" @@ -119,15 +118,23 @@ class RepeatedScalarFieldContainer(BaseContainer): self._message_listener.Modified() def extend(self, elem_seq): - """Extends by appending the given sequence. Similar to list.extend().""" - if not elem_seq: - return + """Extends by appending the given iterable. Similar to list.extend().""" - new_values = [] - for elem in elem_seq: - new_values.append(self._type_checker.CheckValue(elem)) - self._values.extend(new_values) - self._message_listener.Modified() + if elem_seq is None: + return + try: + elem_seq_iter = iter(elem_seq) + except TypeError: + if not elem_seq: + # silently ignore falsy inputs :-/. + # TODO(ptucker): Deprecate this behavior. b/18413862 + return + raise + + new_values = [self._type_checker.CheckValue(elem) for elem in elem_seq_iter] + if new_values: + self._values.extend(new_values) + self._message_listener.Modified() def MergeFrom(self, other): """Appends the contents of another repeated field of the same type to this @@ -141,6 +148,12 @@ class RepeatedScalarFieldContainer(BaseContainer): self._values.remove(elem) self._message_listener.Modified() + def pop(self, key=-1): + """Removes and returns an item at a given index. Similar to list.pop().""" + value = self._values[key] + self.__delitem__(key) + return value + def __setitem__(self, key, value): """Sets the item on the specified position.""" if isinstance(key, slice): # PY3 @@ -245,6 +258,12 @@ class RepeatedCompositeFieldContainer(BaseContainer): self._values.remove(elem) self._message_listener.Modified() + def pop(self, key=-1): + """Removes and returns an item at a given index. Similar to list.pop().""" + value = self._values[key] + self.__delitem__(key) + return value + def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" return self._values[start:stop] diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index a4b90608..0f500606 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -621,9 +621,6 @@ def MessageDecoder(field_number, is_repeated, is_packed, key, new_default): if value is None: value = field_dict.setdefault(key, new_default(message)) while 1: - value = field_dict.get(key) - if value is None: - value = field_dict.setdefault(key, new_default(message)) # Read length. (size, pos) = local_DecodeVarint(buffer, pos) new_pos = pos + size diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 3924f21a..50c4dbba 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -34,12 +34,16 @@ __author__ = 'robinson@google.com (Will Robinson)' +import sys + from google.apputils import basetest from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 +from google.protobuf.internal import api_implementation from google.protobuf import descriptor +from google.protobuf import symbol_database from google.protobuf import text_format @@ -51,41 +55,28 @@ name: 'TestEmptyMessage' class DescriptorTest(basetest.TestCase): def setUp(self): - self.my_file = descriptor.FileDescriptor( + file_proto = descriptor_pb2.FileDescriptorProto( name='some/filename/some.proto', - package='protobuf_unittest' - ) - self.my_enum = descriptor.EnumDescriptor( - name='ForeignEnum', - full_name='protobuf_unittest.ForeignEnum', - filename=None, - file=self.my_file, - values=[ - descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4), - descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5), - descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6), - ]) - self.my_message = descriptor.Descriptor( - name='NestedMessage', - full_name='protobuf_unittest.TestAllTypes.NestedMessage', - filename=None, - file=self.my_file, - containing_type=None, - fields=[ - descriptor.FieldDescriptor( - name='bb', - full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb', - index=0, number=1, - type=5, cpp_type=1, label=1, - has_default_value=False, default_value=0, - message_type=None, enum_type=None, containing_type=None, - is_extension=False, extension_scope=None), - ], - nested_types=[], - enum_types=[ - self.my_enum, - ], - extensions=[]) + package='protobuf_unittest') + message_proto = file_proto.message_type.add( + name='NestedMessage') + message_proto.field.add( + name='bb', + number=1, + type=descriptor_pb2.FieldDescriptorProto.TYPE_INT32, + label=descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL) + enum_proto = message_proto.enum_type.add( + name='ForeignEnum') + enum_proto.value.add(name='FOREIGN_FOO', number=4) + enum_proto.value.add(name='FOREIGN_BAR', number=5) + enum_proto.value.add(name='FOREIGN_BAZ', number=6) + + descriptor_pool = symbol_database.Default().pool + descriptor_pool.Add(file_proto) + self.my_file = descriptor_pool.FindFileByName(file_proto.name) + self.my_message = self.my_file.message_types_by_name[message_proto.name] + self.my_enum = self.my_message.enum_types_by_name[enum_proto.name] + self.my_method = descriptor.MethodDescriptor( name='Bar', full_name='protobuf_unittest.TestService.Bar', @@ -173,6 +164,11 @@ class DescriptorTest(basetest.TestCase): self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2, method_options.Extensions[method_opt1]) + message_descriptor = ( + unittest_custom_options_pb2.DummyMessageContainingEnum.DESCRIPTOR) + self.assertTrue(file_descriptor.has_options) + self.assertFalse(message_descriptor.has_options) + def testDifferentCustomOptionTypes(self): kint32min = -2**31 kint64min = -2**63 @@ -394,6 +390,108 @@ class DescriptorTest(basetest.TestCase): self.assertEqual(self.my_file.name, 'some/filename/some.proto') self.assertEqual(self.my_file.package, 'protobuf_unittest') + @basetest.unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'Immutability of descriptors is only enforced in v2 implementation') + def testImmutableCppDescriptor(self): + message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + with self.assertRaises(AttributeError): + message_descriptor.fields_by_name = None + with self.assertRaises(TypeError): + message_descriptor.fields_by_name['Another'] = None + with self.assertRaises(TypeError): + message_descriptor.fields.append(None) + + +class GeneratedDescriptorTest(basetest.TestCase): + """Tests for the properties of descriptors in generated code.""" + + def CheckMessageDescriptor(self, message_descriptor): + # Basic properties + self.assertEqual(message_descriptor.name, 'TestAllTypes') + self.assertEqual(message_descriptor.full_name, + 'protobuf_unittest.TestAllTypes') + # Test equality and hashability + self.assertEqual(message_descriptor, message_descriptor) + self.assertEqual(message_descriptor.fields[0].containing_type, + message_descriptor) + self.assertIn(message_descriptor, [message_descriptor]) + self.assertIn(message_descriptor, {message_descriptor: None}) + # Test field containers + self.CheckDescriptorSequence(message_descriptor.fields) + self.CheckDescriptorMapping(message_descriptor.fields_by_name) + self.CheckDescriptorMapping(message_descriptor.fields_by_number) + + def CheckFieldDescriptor(self, field_descriptor): + # Basic properties + self.assertEqual(field_descriptor.name, 'optional_int32') + self.assertEqual(field_descriptor.full_name, + 'protobuf_unittest.TestAllTypes.optional_int32') + self.assertEqual(field_descriptor.containing_type.name, 'TestAllTypes') + # Test equality and hashability + self.assertEqual(field_descriptor, field_descriptor) + self.assertEqual( + field_descriptor.containing_type.fields_by_name['optional_int32'], + field_descriptor) + self.assertIn(field_descriptor, [field_descriptor]) + self.assertIn(field_descriptor, {field_descriptor: None}) + + def CheckDescriptorSequence(self, sequence): + # Verifies that a property like 'messageDescriptor.fields' has all the + # properties of an immutable abc.Sequence. + self.assertGreater(len(sequence), 0) # Sized + self.assertEqual(len(sequence), len(list(sequence))) # Iterable + item = sequence[0] + self.assertEqual(item, sequence[0]) + self.assertIn(item, sequence) # Container + self.assertEqual(sequence.index(item), 0) + self.assertEqual(sequence.count(item), 1) + reversed_iterator = reversed(sequence) + self.assertEqual(list(reversed_iterator), list(sequence)[::-1]) + self.assertRaises(StopIteration, next, reversed_iterator) + + def CheckDescriptorMapping(self, mapping): + # Verifies that a property like 'messageDescriptor.fields' has all the + # properties of an immutable abc.Mapping. + self.assertGreater(len(mapping), 0) # Sized + self.assertEqual(len(mapping), len(list(mapping))) # Iterable + if sys.version_info.major >= 3: + key, item = next(iter(mapping.items())) + else: + key, item = mapping.items()[0] + self.assertIn(key, mapping) # Container + self.assertEqual(mapping.get(key), item) + # keys(), iterkeys() &co + item = (next(iter(mapping.keys())), next(iter(mapping.values()))) + self.assertEqual(item, next(iter(mapping.items()))) + if sys.version_info.major < 3: + def CheckItems(seq, iterator): + self.assertEqual(next(iterator), seq[0]) + self.assertEqual(list(iterator), seq[1:]) + CheckItems(mapping.keys(), mapping.iterkeys()) + CheckItems(mapping.values(), mapping.itervalues()) + CheckItems(mapping.items(), mapping.iteritems()) + + def testDescriptor(self): + message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + self.CheckMessageDescriptor(message_descriptor) + field_descriptor = message_descriptor.fields_by_name['optional_int32'] + self.CheckFieldDescriptor(field_descriptor) + + def testCppDescriptorContainer(self): + # Check that the collection is still valid even if the parent disappeared. + enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum'] + values = enum.values + del enum + self.assertEqual('FOO', values[0].name) + + def testCppDescriptorContainer_Iterator(self): + # Same test with the iterator + enum = unittest_pb2.TestAllTypes.DESCRIPTOR.enum_types_by_name['NestedEnum'] + values_iter = iter(enum.values) + del enum + self.assertEqual('FOO', next(values_iter).name) + class DescriptorCopyToProtoTest(basetest.TestCase): """Tests for CopyTo functions of Descriptor.""" @@ -588,10 +686,12 @@ class DescriptorCopyToProtoTest(basetest.TestCase): output_type: '.protobuf_unittest.BarResponse' > """ - self._InternalTestCopyToProto( - unittest_pb2.TestService.DESCRIPTOR, - descriptor_pb2.ServiceDescriptorProto, - TEST_SERVICE_ASCII) + # TODO(rocking): enable this test after the proto descriptor change is + # checked in. + #self._InternalTestCopyToProto( + # unittest_pb2.TestService.DESCRIPTOR, + # descriptor_pb2.ServiceDescriptorProto, + # TEST_SERVICE_ASCII) class MakeDescriptorTest(basetest.TestCase): diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 42e2ad7e..7ab814cf 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -48,9 +48,12 @@ import math import operator import pickle import sys +import unittest from google.apputils import basetest +from google.apputils.pybase import parameterized from google.protobuf import unittest_pb2 +from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util from google.protobuf import message @@ -69,88 +72,65 @@ def IsNegInf(val): return isinf(val) and (val < 0) +@parameterized.Parameters( + (unittest_pb2), + (unittest_proto3_arena_pb2)) class MessageTest(basetest.TestCase): - def testBadUtf8String(self): + def testBadUtf8String(self, message_module): if api_implementation.Type() != 'python': self.skipTest("Skipping testBadUtf8String, currently only the python " "api implementation raises UnicodeDecodeError when a " "string field contains bad utf-8.") bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') with self.assertRaises(UnicodeDecodeError) as context: - unittest_pb2.TestAllTypes.FromString(bad_utf8_data) - self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string', - str(context.exception)) - - def testGoldenMessage(self): - golden_data = test_util.GoldenFileData( - 'golden_message_oneof_implemented') - golden_message = unittest_pb2.TestAllTypes() - golden_message.ParseFromString(golden_data) - test_util.ExpectAllFieldsSet(self, golden_message) - self.assertEqual(golden_data, golden_message.SerializeToString()) - golden_copy = copy.deepcopy(golden_message) - self.assertEqual(golden_data, golden_copy.SerializeToString()) + message_module.TestAllTypes.FromString(bad_utf8_data) + self.assertIn('TestAllTypes.optional_string', str(context.exception)) + + def testGoldenMessage(self, message_module): + # Proto3 doesn't have the "default_foo" members or foreign enums, + # and doesn't preserve unknown fields, so for proto3 we use a golden + # message that doesn't have these fields set. + if message_module is unittest_pb2: + golden_data = test_util.GoldenFileData( + 'golden_message_oneof_implemented') + else: + golden_data = test_util.GoldenFileData('golden_message_proto3') - def testGoldenExtensions(self): - golden_data = test_util.GoldenFileData('golden_message') - golden_message = unittest_pb2.TestAllExtensions() + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) - all_set = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(all_set) - self.assertEquals(all_set, golden_message) + if message_module is unittest_pb2: + test_util.ExpectAllFieldsSet(self, golden_message) self.assertEqual(golden_data, golden_message.SerializeToString()) golden_copy = copy.deepcopy(golden_message) self.assertEqual(golden_data, golden_copy.SerializeToString()) - def testGoldenPackedMessage(self): + def testGoldenPackedMessage(self, message_module): golden_data = test_util.GoldenFileData('golden_packed_fields_message') - golden_message = unittest_pb2.TestPackedTypes() + golden_message = message_module.TestPackedTypes() golden_message.ParseFromString(golden_data) - all_set = unittest_pb2.TestPackedTypes() + all_set = message_module.TestPackedTypes() test_util.SetAllPackedFields(all_set) - self.assertEquals(all_set, golden_message) + self.assertEqual(all_set, golden_message) self.assertEqual(golden_data, all_set.SerializeToString()) golden_copy = copy.deepcopy(golden_message) self.assertEqual(golden_data, golden_copy.SerializeToString()) - def testGoldenPackedExtensions(self): - golden_data = test_util.GoldenFileData('golden_packed_fields_message') - golden_message = unittest_pb2.TestPackedExtensions() - golden_message.ParseFromString(golden_data) - all_set = unittest_pb2.TestPackedExtensions() - test_util.SetAllPackedExtensions(all_set) - self.assertEquals(all_set, golden_message) - self.assertEqual(golden_data, all_set.SerializeToString()) - golden_copy = copy.deepcopy(golden_message) - self.assertEqual(golden_data, golden_copy.SerializeToString()) - - def testPickleSupport(self): + def testPickleSupport(self, message_module): golden_data = test_util.GoldenFileData('golden_message') - golden_message = unittest_pb2.TestAllTypes() + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) pickled_message = pickle.dumps(golden_message) unpickled_message = pickle.loads(pickled_message) - self.assertEquals(unpickled_message, golden_message) - - - def testPickleIncompleteProto(self): - golden_message = unittest_pb2.TestRequired(a=1) - pickled_message = pickle.dumps(golden_message) - - unpickled_message = pickle.loads(pickled_message) - self.assertEquals(unpickled_message, golden_message) - self.assertEquals(unpickled_message.a, 1) - # This is still an incomplete proto - so serializing should fail - self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) + self.assertEqual(unpickled_message, golden_message) - def testPositiveInfinity(self): + def testPositiveInfinity(self, message_module): golden_data = (b'\x5D\x00\x00\x80\x7F' b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' b'\xCD\x02\x00\x00\x80\x7F' b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') - golden_message = unittest_pb2.TestAllTypes() + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.optional_float)) self.assertTrue(IsPosInf(golden_message.optional_double)) @@ -158,12 +138,12 @@ class MessageTest(basetest.TestCase): self.assertTrue(IsPosInf(golden_message.repeated_double[0])) self.assertEqual(golden_data, golden_message.SerializeToString()) - def testNegativeInfinity(self): + def testNegativeInfinity(self, message_module): golden_data = (b'\x5D\x00\x00\x80\xFF' b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' b'\xCD\x02\x00\x00\x80\xFF' b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') - golden_message = unittest_pb2.TestAllTypes() + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.optional_float)) self.assertTrue(IsNegInf(golden_message.optional_double)) @@ -171,12 +151,12 @@ class MessageTest(basetest.TestCase): self.assertTrue(IsNegInf(golden_message.repeated_double[0])) self.assertEqual(golden_data, golden_message.SerializeToString()) - def testNotANumber(self): + def testNotANumber(self, message_module): golden_data = (b'\x5D\x00\x00\xC0\x7F' b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' b'\xCD\x02\x00\x00\xC0\x7F' b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') - golden_message = unittest_pb2.TestAllTypes() + golden_message = message_module.TestAllTypes() golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.optional_float)) self.assertTrue(isnan(golden_message.optional_double)) @@ -188,47 +168,47 @@ class MessageTest(basetest.TestCase): # verify the serialized string can be converted into a correctly # behaving protocol buffer. serialized = golden_message.SerializeToString() - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() message.ParseFromString(serialized) self.assertTrue(isnan(message.optional_float)) self.assertTrue(isnan(message.optional_double)) self.assertTrue(isnan(message.repeated_float[0])) self.assertTrue(isnan(message.repeated_double[0])) - def testPositiveInfinityPacked(self): + def testPositiveInfinityPacked(self, message_module): golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') - golden_message = unittest_pb2.TestPackedTypes() + golden_message = message_module.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.packed_float[0])) self.assertTrue(IsPosInf(golden_message.packed_double[0])) self.assertEqual(golden_data, golden_message.SerializeToString()) - def testNegativeInfinityPacked(self): + def testNegativeInfinityPacked(self, message_module): golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') - golden_message = unittest_pb2.TestPackedTypes() + golden_message = message_module.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.packed_float[0])) self.assertTrue(IsNegInf(golden_message.packed_double[0])) self.assertEqual(golden_data, golden_message.SerializeToString()) - def testNotANumberPacked(self): + def testNotANumberPacked(self, message_module): golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') - golden_message = unittest_pb2.TestPackedTypes() + golden_message = message_module.TestPackedTypes() golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.packed_float[0])) self.assertTrue(isnan(golden_message.packed_double[0])) serialized = golden_message.SerializeToString() - message = unittest_pb2.TestPackedTypes() + message = message_module.TestPackedTypes() message.ParseFromString(serialized) self.assertTrue(isnan(message.packed_float[0])) self.assertTrue(isnan(message.packed_double[0])) - def testExtremeFloatValues(self): - message = unittest_pb2.TestAllTypes() + def testExtremeFloatValues(self, message_module): + message = message_module.TestAllTypes() # Most positive exponent, no significand bits set. kMostPosExponentNoSigBits = math.pow(2, 127) @@ -272,8 +252,8 @@ class MessageTest(basetest.TestCase): message.ParseFromString(message.SerializeToString()) self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) - def testExtremeDoubleValues(self): - message = unittest_pb2.TestAllTypes() + def testExtremeDoubleValues(self, message_module): + message = message_module.TestAllTypes() # Most positive exponent, no significand bits set. kMostPosExponentNoSigBits = math.pow(2, 1023) @@ -317,43 +297,43 @@ class MessageTest(basetest.TestCase): message.ParseFromString(message.SerializeToString()) self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) - def testFloatPrinting(self): - message = unittest_pb2.TestAllTypes() + def testFloatPrinting(self, message_module): + message = message_module.TestAllTypes() message.optional_float = 2.0 self.assertEqual(str(message), 'optional_float: 2.0\n') - def testHighPrecisionFloatPrinting(self): - message = unittest_pb2.TestAllTypes() + def testHighPrecisionFloatPrinting(self, message_module): + message = message_module.TestAllTypes() message.optional_double = 0.12345678912345678 if sys.version_info.major >= 3: self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') else: self.assertEqual(str(message), 'optional_double: 0.123456789123\n') - def testUnknownFieldPrinting(self): - populated = unittest_pb2.TestAllTypes() + def testUnknownFieldPrinting(self, message_module): + populated = message_module.TestAllTypes() test_util.SetAllNonLazyFields(populated) - empty = unittest_pb2.TestEmptyMessage() + empty = message_module.TestEmptyMessage() empty.ParseFromString(populated.SerializeToString()) self.assertEqual(str(empty), '') - def testRepeatedNestedFieldIteration(self): - msg = unittest_pb2.TestAllTypes() + def testRepeatedNestedFieldIteration(self, message_module): + msg = message_module.TestAllTypes() msg.repeated_nested_message.add(bb=1) msg.repeated_nested_message.add(bb=2) msg.repeated_nested_message.add(bb=3) msg.repeated_nested_message.add(bb=4) - self.assertEquals([1, 2, 3, 4], - [m.bb for m in msg.repeated_nested_message]) - self.assertEquals([4, 3, 2, 1], - [m.bb for m in reversed(msg.repeated_nested_message)]) - self.assertEquals([4, 3, 2, 1], - [m.bb for m in msg.repeated_nested_message[::-1]]) + self.assertEqual([1, 2, 3, 4], + [m.bb for m in msg.repeated_nested_message]) + self.assertEqual([4, 3, 2, 1], + [m.bb for m in reversed(msg.repeated_nested_message)]) + self.assertEqual([4, 3, 2, 1], + [m.bb for m in msg.repeated_nested_message[::-1]]) - def testSortingRepeatedScalarFieldsDefaultComparator(self): + def testSortingRepeatedScalarFieldsDefaultComparator(self, message_module): """Check some different types with the default comparator.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() # TODO(mattp): would testing more scalar types strengthen test? message.repeated_int32.append(1) @@ -388,9 +368,9 @@ class MessageTest(basetest.TestCase): self.assertEqual(message.repeated_bytes[1], b'b') self.assertEqual(message.repeated_bytes[2], b'c') - def testSortingRepeatedScalarFieldsCustomComparator(self): + def testSortingRepeatedScalarFieldsCustomComparator(self, message_module): """Check some different types with custom comparator.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() message.repeated_int32.append(-3) message.repeated_int32.append(-2) @@ -408,9 +388,9 @@ class MessageTest(basetest.TestCase): self.assertEqual(message.repeated_string[1], 'bb') self.assertEqual(message.repeated_string[2], 'aaa') - def testSortingRepeatedCompositeFieldsCustomComparator(self): + def testSortingRepeatedCompositeFieldsCustomComparator(self, message_module): """Check passing a custom comparator to sort a repeated composite field.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() message.repeated_nested_message.add().bb = 1 message.repeated_nested_message.add().bb = 3 @@ -426,9 +406,9 @@ class MessageTest(basetest.TestCase): self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[5].bb, 6) - def testRepeatedCompositeFieldSortArguments(self): + def testRepeatedCompositeFieldSortArguments(self, message_module): """Check sorting a repeated composite field using list.sort() arguments.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() get_bb = operator.attrgetter('bb') cmp_bb = lambda a, b: cmp(a.bb, b.bb) @@ -452,9 +432,9 @@ class MessageTest(basetest.TestCase): self.assertEqual([k.bb for k in message.repeated_nested_message], [6, 5, 4, 3, 2, 1]) - def testRepeatedScalarFieldSortArguments(self): + def testRepeatedScalarFieldSortArguments(self, message_module): """Check sorting a scalar field using list.sort() arguments.""" - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() message.repeated_int32.append(-3) message.repeated_int32.append(-2) @@ -484,9 +464,9 @@ class MessageTest(basetest.TestCase): message.repeated_string.sort(cmp=len_cmp, reverse=True) self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) - def testRepeatedFieldsComparable(self): - m1 = unittest_pb2.TestAllTypes() - m2 = unittest_pb2.TestAllTypes() + def testRepeatedFieldsComparable(self, message_module): + m1 = message_module.TestAllTypes() + m2 = message_module.TestAllTypes() m1.repeated_int32.append(0) m1.repeated_int32.append(1) m1.repeated_int32.append(2) @@ -519,55 +499,6 @@ class MessageTest(basetest.TestCase): # TODO(anuraag): Implement extensiondict comparison in C++ and then add test - def testParsingMerge(self): - """Check the merge behavior when a required or optional field appears - multiple times in the input.""" - messages = [ - unittest_pb2.TestAllTypes(), - unittest_pb2.TestAllTypes(), - unittest_pb2.TestAllTypes() ] - messages[0].optional_int32 = 1 - messages[1].optional_int64 = 2 - messages[2].optional_int32 = 3 - messages[2].optional_string = 'hello' - - merged_message = unittest_pb2.TestAllTypes() - merged_message.optional_int32 = 3 - merged_message.optional_int64 = 2 - merged_message.optional_string = 'hello' - - generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() - generator.field1.extend(messages) - generator.field2.extend(messages) - generator.field3.extend(messages) - generator.ext1.extend(messages) - generator.ext2.extend(messages) - generator.group1.add().field1.MergeFrom(messages[0]) - generator.group1.add().field1.MergeFrom(messages[1]) - generator.group1.add().field1.MergeFrom(messages[2]) - generator.group2.add().field1.MergeFrom(messages[0]) - generator.group2.add().field1.MergeFrom(messages[1]) - generator.group2.add().field1.MergeFrom(messages[2]) - - data = generator.SerializeToString() - parsing_merge = unittest_pb2.TestParsingMerge() - parsing_merge.ParseFromString(data) - - # Required and optional fields should be merged. - self.assertEqual(parsing_merge.required_all_types, merged_message) - self.assertEqual(parsing_merge.optional_all_types, merged_message) - self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, - merged_message) - self.assertEqual(parsing_merge.Extensions[ - unittest_pb2.TestParsingMerge.optional_ext], - merged_message) - - # Repeated fields should not be merged. - self.assertEqual(len(parsing_merge.repeated_all_types), 3) - self.assertEqual(len(parsing_merge.repeatedgroup), 3) - self.assertEqual(len(parsing_merge.Extensions[ - unittest_pb2.TestParsingMerge.repeated_ext]), 3) - def ensureNestedMessageExists(self, msg, attribute): """Make sure that a nested message object exists. @@ -577,12 +508,28 @@ class MessageTest(basetest.TestCase): getattr(msg, attribute) self.assertFalse(msg.HasField(attribute)) - def testOneofGetCaseNonexistingField(self): - m = unittest_pb2.TestAllTypes() + def testOneofGetCaseNonexistingField(self, message_module): + m = message_module.TestAllTypes() self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') - def testOneofSemantics(self): - m = unittest_pb2.TestAllTypes() + def testOneofDefaultValues(self, message_module): + m = message_module.TestAllTypes() + self.assertIs(None, m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + + # Oneof is set even when setting it to a default value. + m.oneof_uint32 = 0 + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + self.assertFalse(m.HasField('oneof_string')) + + m.oneof_string = "" + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_string')) + self.assertFalse(m.HasField('oneof_uint32')) + + def testOneofSemantics(self, message_module): + m = message_module.TestAllTypes() self.assertIs(None, m.WhichOneof('oneof_field')) m.oneof_uint32 = 11 @@ -604,96 +551,569 @@ class MessageTest(basetest.TestCase): self.assertFalse(m.HasField('oneof_nested_message')) self.assertTrue(m.HasField('oneof_bytes')) - def testOneofCompositeFieldReadAccess(self): - m = unittest_pb2.TestAllTypes() + def testOneofCompositeFieldReadAccess(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 self.ensureNestedMessageExists(m, 'oneof_nested_message') self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) self.assertEqual(11, m.oneof_uint32) - def testOneofHasField(self): - m = unittest_pb2.TestAllTypes() - self.assertFalse(m.HasField('oneof_field')) + def testOneofWhichOneof(self, message_module): + m = message_module.TestAllTypes() + self.assertIs(None, m.WhichOneof('oneof_field')) + if message_module is unittest_pb2: + self.assertFalse(m.HasField('oneof_field')) + m.oneof_uint32 = 11 - self.assertTrue(m.HasField('oneof_field')) + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + if message_module is unittest_pb2: + self.assertTrue(m.HasField('oneof_field')) + m.oneof_bytes = b'bb' - self.assertTrue(m.HasField('oneof_field')) + self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) + m.ClearField('oneof_bytes') - self.assertFalse(m.HasField('oneof_field')) + self.assertIs(None, m.WhichOneof('oneof_field')) + if message_module is unittest_pb2: + self.assertFalse(m.HasField('oneof_field')) - def testOneofClearField(self): - m = unittest_pb2.TestAllTypes() + def testOneofClearField(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 m.ClearField('oneof_field') - self.assertFalse(m.HasField('oneof_field')) + if message_module is unittest_pb2: + self.assertFalse(m.HasField('oneof_field')) self.assertFalse(m.HasField('oneof_uint32')) self.assertIs(None, m.WhichOneof('oneof_field')) - def testOneofClearSetField(self): - m = unittest_pb2.TestAllTypes() + def testOneofClearSetField(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 m.ClearField('oneof_uint32') - self.assertFalse(m.HasField('oneof_field')) + if message_module is unittest_pb2: + self.assertFalse(m.HasField('oneof_field')) self.assertFalse(m.HasField('oneof_uint32')) self.assertIs(None, m.WhichOneof('oneof_field')) - def testOneofClearUnsetField(self): - m = unittest_pb2.TestAllTypes() + def testOneofClearUnsetField(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 self.ensureNestedMessageExists(m, 'oneof_nested_message') m.ClearField('oneof_nested_message') self.assertEqual(11, m.oneof_uint32) - self.assertTrue(m.HasField('oneof_field')) + if message_module is unittest_pb2: + self.assertTrue(m.HasField('oneof_field')) self.assertTrue(m.HasField('oneof_uint32')) self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) - def testOneofDeserialize(self): - m = unittest_pb2.TestAllTypes() + def testOneofDeserialize(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 - m2 = unittest_pb2.TestAllTypes() + m2 = message_module.TestAllTypes() m2.ParseFromString(m.SerializeToString()) self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) - def testOneofCopyFrom(self): - m = unittest_pb2.TestAllTypes() + def testOneofCopyFrom(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 - m2 = unittest_pb2.TestAllTypes() + m2 = message_module.TestAllTypes() m2.CopyFrom(m) self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) - def testOneofNestedMergeFrom(self): - m = unittest_pb2.NestedTestAllTypes() + def testOneofNestedMergeFrom(self, message_module): + m = message_module.NestedTestAllTypes() m.payload.oneof_uint32 = 11 - m2 = unittest_pb2.NestedTestAllTypes() + m2 = message_module.NestedTestAllTypes() m2.payload.oneof_bytes = b'bb' m2.child.payload.oneof_bytes = b'bb' m2.MergeFrom(m) self.assertEqual('oneof_uint32', m2.payload.WhichOneof('oneof_field')) self.assertEqual('oneof_bytes', m2.child.payload.WhichOneof('oneof_field')) - def testOneofClear(self): - m = unittest_pb2.TestAllTypes() + def testOneofMessageMergeFrom(self, message_module): + m = message_module.NestedTestAllTypes() + m.payload.oneof_nested_message.bb = 11 + m.child.payload.oneof_nested_message.bb = 12 + m2 = message_module.NestedTestAllTypes() + m2.payload.oneof_uint32 = 13 + m2.MergeFrom(m) + self.assertEqual('oneof_nested_message', + m2.payload.WhichOneof('oneof_field')) + self.assertEqual('oneof_nested_message', + m2.child.payload.WhichOneof('oneof_field')) + + def testOneofNestedMessageInit(self, message_module): + m = message_module.TestAllTypes( + oneof_nested_message=message_module.TestAllTypes.NestedMessage()) + self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) + + def testOneofClear(self, message_module): + m = message_module.TestAllTypes() m.oneof_uint32 = 11 m.Clear() self.assertIsNone(m.WhichOneof('oneof_field')) m.oneof_bytes = b'bb' - self.assertTrue(m.HasField('oneof_field')) + self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) + + def testAssignByteStringToUnicodeField(self, message_module): + """Assigning a byte string to a string field should result + in the value being converted to a Unicode string.""" + m = message_module.TestAllTypes() + m.optional_string = str('') + self.assertTrue(isinstance(m.optional_string, unicode)) +# TODO(haberman): why are these tests Google-internal only? - def testSortEmptyRepeatedCompositeContainer(self): + def testLongValuedSlice(self, message_module): + """It should be possible to use long-valued indicies in slices + + This didn't used to work in the v2 C++ implementation. + """ + m = message_module.TestAllTypes() + + # Repeated scalar + m.repeated_int32.append(1) + sl = m.repeated_int32[long(0):long(len(m.repeated_int32))] + self.assertEqual(len(m.repeated_int32), len(sl)) + + # Repeated composite + m.repeated_nested_message.add().bb = 3 + sl = m.repeated_nested_message[long(0):long(len(m.repeated_nested_message))] + self.assertEqual(len(m.repeated_nested_message), len(sl)) + + def testExtendShouldNotSwallowExceptions(self, message_module): + """This didn't use to work in the v2 C++ implementation.""" + m = message_module.TestAllTypes() + with self.assertRaises(NameError) as _: + m.repeated_int32.extend(a for i in range(10)) # pylint: disable=undefined-variable + with self.assertRaises(NameError) as _: + m.repeated_nested_enum.extend( + a for i in range(10)) # pylint: disable=undefined-variable + + FALSY_VALUES = [None, False, 0, 0.0, b'', u'', bytearray(), [], {}, set()] + + def testExtendInt32WithNothing(self, message_module): + """Test no-ops extending repeated int32 fields.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_int32) + + # TODO(ptucker): Deprecate this behavior. b/18413862 + for falsy_value in MessageTest.FALSY_VALUES: + m.repeated_int32.extend(falsy_value) + self.assertSequenceEqual([], m.repeated_int32) + + m.repeated_int32.extend([]) + self.assertSequenceEqual([], m.repeated_int32) + + def testExtendFloatWithNothing(self, message_module): + """Test no-ops extending repeated float fields.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_float) + + # TODO(ptucker): Deprecate this behavior. b/18413862 + for falsy_value in MessageTest.FALSY_VALUES: + m.repeated_float.extend(falsy_value) + self.assertSequenceEqual([], m.repeated_float) + + m.repeated_float.extend([]) + self.assertSequenceEqual([], m.repeated_float) + + def testExtendStringWithNothing(self, message_module): + """Test no-ops extending repeated string fields.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_string) + + # TODO(ptucker): Deprecate this behavior. b/18413862 + for falsy_value in MessageTest.FALSY_VALUES: + m.repeated_string.extend(falsy_value) + self.assertSequenceEqual([], m.repeated_string) + + m.repeated_string.extend([]) + self.assertSequenceEqual([], m.repeated_string) + + def testExtendInt32WithPythonList(self, message_module): + """Test extending repeated int32 fields with python lists.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_int32) + m.repeated_int32.extend([0]) + self.assertSequenceEqual([0], m.repeated_int32) + m.repeated_int32.extend([1, 2]) + self.assertSequenceEqual([0, 1, 2], m.repeated_int32) + m.repeated_int32.extend([3, 4]) + self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) + + def testExtendFloatWithPythonList(self, message_module): + """Test extending repeated float fields with python lists.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_float) + m.repeated_float.extend([0.0]) + self.assertSequenceEqual([0.0], m.repeated_float) + m.repeated_float.extend([1.0, 2.0]) + self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) + m.repeated_float.extend([3.0, 4.0]) + self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) + + def testExtendStringWithPythonList(self, message_module): + """Test extending repeated string fields with python lists.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_string) + m.repeated_string.extend(['']) + self.assertSequenceEqual([''], m.repeated_string) + m.repeated_string.extend(['11', '22']) + self.assertSequenceEqual(['', '11', '22'], m.repeated_string) + m.repeated_string.extend(['33', '44']) + self.assertSequenceEqual(['', '11', '22', '33', '44'], m.repeated_string) + + def testExtendStringWithString(self, message_module): + """Test extending repeated string fields with characters from a string.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_string) + m.repeated_string.extend('abc') + self.assertSequenceEqual(['a', 'b', 'c'], m.repeated_string) + + class TestIterable(object): + """This iterable object mimics the behavior of numpy.array. + + __nonzero__ fails for length > 1, and returns bool(item[0]) for length == 1. + + """ + + def __init__(self, values=None): + self._list = values or [] + + def __nonzero__(self): + size = len(self._list) + if size == 0: + return False + if size == 1: + return bool(self._list[0]) + raise ValueError('Truth value is ambiguous.') + + def __len__(self): + return len(self._list) + + def __iter__(self): + return self._list.__iter__() + + def testExtendInt32WithIterable(self, message_module): + """Test extending repeated int32 fields with iterable.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_int32) + m.repeated_int32.extend(MessageTest.TestIterable([])) + self.assertSequenceEqual([], m.repeated_int32) + m.repeated_int32.extend(MessageTest.TestIterable([0])) + self.assertSequenceEqual([0], m.repeated_int32) + m.repeated_int32.extend(MessageTest.TestIterable([1, 2])) + self.assertSequenceEqual([0, 1, 2], m.repeated_int32) + m.repeated_int32.extend(MessageTest.TestIterable([3, 4])) + self.assertSequenceEqual([0, 1, 2, 3, 4], m.repeated_int32) + + def testExtendFloatWithIterable(self, message_module): + """Test extending repeated float fields with iterable.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_float) + m.repeated_float.extend(MessageTest.TestIterable([])) + self.assertSequenceEqual([], m.repeated_float) + m.repeated_float.extend(MessageTest.TestIterable([0.0])) + self.assertSequenceEqual([0.0], m.repeated_float) + m.repeated_float.extend(MessageTest.TestIterable([1.0, 2.0])) + self.assertSequenceEqual([0.0, 1.0, 2.0], m.repeated_float) + m.repeated_float.extend(MessageTest.TestIterable([3.0, 4.0])) + self.assertSequenceEqual([0.0, 1.0, 2.0, 3.0, 4.0], m.repeated_float) + + def testExtendStringWithIterable(self, message_module): + """Test extending repeated string fields with iterable.""" + m = message_module.TestAllTypes() + self.assertSequenceEqual([], m.repeated_string) + m.repeated_string.extend(MessageTest.TestIterable([])) + self.assertSequenceEqual([], m.repeated_string) + m.repeated_string.extend(MessageTest.TestIterable([''])) + self.assertSequenceEqual([''], m.repeated_string) + m.repeated_string.extend(MessageTest.TestIterable(['1', '2'])) + self.assertSequenceEqual(['', '1', '2'], m.repeated_string) + m.repeated_string.extend(MessageTest.TestIterable(['3', '4'])) + self.assertSequenceEqual(['', '1', '2', '3', '4'], m.repeated_string) + + def testPickleRepeatedScalarContainer(self, message_module): + # TODO(tibell): The pure-Python implementation support pickling of + # scalar containers in *some* cases. For now the cpp2 version + # throws an exception to avoid a segfault. Investigate if we + # want to support pickling of these fields. + # + # For more information see: https://b2.corp.google.com/u/0/issues/18677897 + if (api_implementation.Type() != 'cpp' or + api_implementation.Version() == 2): + return + m = message_module.TestAllTypes() + with self.assertRaises(pickle.PickleError) as _: + pickle.dumps(m.repeated_int32, pickle.HIGHEST_PROTOCOL) + + + def testSortEmptyRepeatedCompositeContainer(self, message_module): """Exercise a scenario that has led to segfaults in the past. """ - m = unittest_pb2.TestAllTypes() + m = message_module.TestAllTypes() m.repeated_nested_message.sort() - def testHasFieldOnRepeatedField(self): + def testHasFieldOnRepeatedField(self, message_module): """Using HasField on a repeated field should raise an exception. """ - m = unittest_pb2.TestAllTypes() + m = message_module.TestAllTypes() with self.assertRaises(ValueError) as _: m.HasField('repeated_int32') + def testRepeatedScalarFieldPop(self, message_module): + m = message_module.TestAllTypes() + with self.assertRaises(IndexError) as _: + m.repeated_int32.pop() + m.repeated_int32.extend(range(5)) + self.assertEqual(4, m.repeated_int32.pop()) + self.assertEqual(0, m.repeated_int32.pop(0)) + self.assertEqual(2, m.repeated_int32.pop(1)) + self.assertEqual([1, 3], m.repeated_int32) + + def testRepeatedCompositeFieldPop(self, message_module): + m = message_module.TestAllTypes() + with self.assertRaises(IndexError) as _: + m.repeated_nested_message.pop() + for i in range(5): + n = m.repeated_nested_message.add() + n.bb = i + self.assertEqual(4, m.repeated_nested_message.pop().bb) + self.assertEqual(0, m.repeated_nested_message.pop(0).bb) + self.assertEqual(2, m.repeated_nested_message.pop(1).bb) + self.assertEqual([1, 3], [n.bb for n in m.repeated_nested_message]) + + +# Class to test proto2-only features (required, extensions, etc.) +class Proto2Test(basetest.TestCase): + + def testFieldPresence(self): + message = unittest_pb2.TestAllTypes() + + self.assertFalse(message.HasField("optional_int32")) + self.assertFalse(message.HasField("optional_bool")) + self.assertFalse(message.HasField("optional_nested_message")) + + with self.assertRaises(ValueError): + message.HasField("field_doesnt_exist") + + with self.assertRaises(ValueError): + message.HasField("repeated_int32") + with self.assertRaises(ValueError): + message.HasField("repeated_nested_message") + + self.assertEqual(0, message.optional_int32) + self.assertEqual(False, message.optional_bool) + self.assertEqual(0, message.optional_nested_message.bb) + + # Fields are set even when setting the values to default values. + message.optional_int32 = 0 + message.optional_bool = False + message.optional_nested_message.bb = 0 + self.assertTrue(message.HasField("optional_int32")) + self.assertTrue(message.HasField("optional_bool")) + self.assertTrue(message.HasField("optional_nested_message")) + + # Set the fields to non-default values. + message.optional_int32 = 5 + message.optional_bool = True + message.optional_nested_message.bb = 15 + + self.assertTrue(message.HasField("optional_int32")) + self.assertTrue(message.HasField("optional_bool")) + self.assertTrue(message.HasField("optional_nested_message")) + + # Clearing the fields unsets them and resets their value to default. + message.ClearField("optional_int32") + message.ClearField("optional_bool") + message.ClearField("optional_nested_message") + + self.assertFalse(message.HasField("optional_int32")) + self.assertFalse(message.HasField("optional_bool")) + self.assertFalse(message.HasField("optional_nested_message")) + self.assertEqual(0, message.optional_int32) + self.assertEqual(False, message.optional_bool) + self.assertEqual(0, message.optional_nested_message.bb) + + # TODO(tibell): The C++ implementations actually allows assignment + # of unknown enum values to *scalar* fields (but not repeated + # fields). Once checked enum fields becomes the default in the + # Python implementation, the C++ implementation should follow suit. + def testAssignInvalidEnum(self): + """It should not be possible to assign an invalid enum number to an + enum field.""" + m = unittest_pb2.TestAllTypes() + + with self.assertRaises(ValueError) as _: + m.optional_nested_enum = 1234567 + self.assertRaises(ValueError, m.repeated_nested_enum.append, 1234567) + + def testGoldenExtensions(self): + golden_data = test_util.GoldenFileData('golden_message') + golden_message = unittest_pb2.TestAllExtensions() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(all_set) + self.assertEqual(all_set, golden_message) + self.assertEqual(golden_data, golden_message.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) + + def testGoldenPackedExtensions(self): + golden_data = test_util.GoldenFileData('golden_packed_fields_message') + golden_message = unittest_pb2.TestPackedExtensions() + golden_message.ParseFromString(golden_data) + all_set = unittest_pb2.TestPackedExtensions() + test_util.SetAllPackedExtensions(all_set) + self.assertEqual(all_set, golden_message) + self.assertEqual(golden_data, all_set.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) + + def testPickleIncompleteProto(self): + golden_message = unittest_pb2.TestRequired(a=1) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEqual(unpickled_message, golden_message) + self.assertEqual(unpickled_message.a, 1) + # This is still an incomplete proto - so serializing should fail + self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) + + + # TODO(haberman): this isn't really a proto2-specific test except that this + # message has a required field in it. Should probably be factored out so + # that we can test the other parts with proto3. + def testParsingMerge(self): + """Check the merge behavior when a required or optional field appears + multiple times in the input.""" + messages = [ + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes() ] + messages[0].optional_int32 = 1 + messages[1].optional_int64 = 2 + messages[2].optional_int32 = 3 + messages[2].optional_string = 'hello' + + merged_message = unittest_pb2.TestAllTypes() + merged_message.optional_int32 = 3 + merged_message.optional_int64 = 2 + merged_message.optional_string = 'hello' + + generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() + generator.field1.extend(messages) + generator.field2.extend(messages) + generator.field3.extend(messages) + generator.ext1.extend(messages) + generator.ext2.extend(messages) + generator.group1.add().field1.MergeFrom(messages[0]) + generator.group1.add().field1.MergeFrom(messages[1]) + generator.group1.add().field1.MergeFrom(messages[2]) + generator.group2.add().field1.MergeFrom(messages[0]) + generator.group2.add().field1.MergeFrom(messages[1]) + generator.group2.add().field1.MergeFrom(messages[2]) + + data = generator.SerializeToString() + parsing_merge = unittest_pb2.TestParsingMerge() + parsing_merge.ParseFromString(data) + + # Required and optional fields should be merged. + self.assertEqual(parsing_merge.required_all_types, merged_message) + self.assertEqual(parsing_merge.optional_all_types, merged_message) + self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, + merged_message) + self.assertEqual(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.optional_ext], + merged_message) + + # Repeated fields should not be merged. + self.assertEqual(len(parsing_merge.repeated_all_types), 3) + self.assertEqual(len(parsing_merge.repeatedgroup), 3) + self.assertEqual(len(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.repeated_ext]), 3) + + +# Class to test proto3-only features/behavior (updated field presence & enums) +class Proto3Test(basetest.TestCase): + + def testFieldPresence(self): + message = unittest_proto3_arena_pb2.TestAllTypes() + + # We can't test presence of non-repeated, non-submessage fields. + with self.assertRaises(ValueError): + message.HasField("optional_int32") + with self.assertRaises(ValueError): + message.HasField("optional_float") + with self.assertRaises(ValueError): + message.HasField("optional_string") + with self.assertRaises(ValueError): + message.HasField("optional_bool") + + # But we can still test presence of submessage fields. + self.assertFalse(message.HasField("optional_nested_message")) + + # As with proto2, we can't test presence of fields that don't exist, or + # repeated fields. + with self.assertRaises(ValueError): + message.HasField("field_doesnt_exist") + + with self.assertRaises(ValueError): + message.HasField("repeated_int32") + with self.assertRaises(ValueError): + message.HasField("repeated_nested_message") + + # Fields should default to their type-specific default. + self.assertEqual(0, message.optional_int32) + self.assertEqual(0, message.optional_float) + self.assertEqual("", message.optional_string) + self.assertEqual(False, message.optional_bool) + self.assertEqual(0, message.optional_nested_message.bb) + + # Setting a submessage should still return proper presence information. + message.optional_nested_message.bb = 0 + self.assertTrue(message.HasField("optional_nested_message")) + + # Set the fields to non-default values. + message.optional_int32 = 5 + message.optional_float = 1.1 + message.optional_string = "abc" + message.optional_bool = True + message.optional_nested_message.bb = 15 + + # Clearing the fields unsets them and resets their value to default. + message.ClearField("optional_int32") + message.ClearField("optional_float") + message.ClearField("optional_string") + message.ClearField("optional_bool") + message.ClearField("optional_nested_message") + + self.assertEqual(0, message.optional_int32) + self.assertEqual(0, message.optional_float) + self.assertEqual("", message.optional_string) + self.assertEqual(False, message.optional_bool) + self.assertEqual(0, message.optional_nested_message.bb) + + def testAssignUnknownEnum(self): + """Assigning an unknown enum value is allowed and preserves the value.""" + m = unittest_proto3_arena_pb2.TestAllTypes() + + m.optional_nested_enum = 1234567 + self.assertEqual(1234567, m.optional_nested_enum) + m.repeated_nested_enum.append(22334455) + self.assertEqual(22334455, m.repeated_nested_enum[0]) + # Assignment is a different code path than append for the C++ impl. + m.repeated_nested_enum[0] = 7654321 + self.assertEqual(7654321, m.repeated_nested_enum[0]) + serialized = m.SerializeToString() + + m2 = unittest_proto3_arena_pb2.TestAllTypes() + m2.ParseFromString(serialized) + self.assertEqual(1234567, m2.optional_nested_enum) + self.assertEqual(7654321, m2.repeated_nested_enum[0]) + class ValidTypeNamesTest(basetest.TestCase): diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 6fda6ae0..6ad0f90d 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -219,12 +219,20 @@ def _AttachFieldHelpers(cls, field_descriptor): def AddDecoder(wiretype, is_packed): tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) - cls._decoders_by_tag[tag_bytes] = ( - type_checkers.TYPE_TO_DECODER[field_descriptor.type]( - field_descriptor.number, is_repeated, is_packed, - field_descriptor, field_descriptor._default_constructor), - field_descriptor if field_descriptor.containing_oneof is not None - else None) + decode_type = field_descriptor.type + if (decode_type == _FieldDescriptor.TYPE_ENUM and + type_checkers.SupportsOpenEnums(field_descriptor)): + decode_type = _FieldDescriptor.TYPE_INT32 + + oneof_descriptor = None + if field_descriptor.containing_oneof is not None: + oneof_descriptor = field_descriptor + + field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor) + + cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], False) @@ -296,6 +304,8 @@ def _DefaultValueConstructorForField(field): def MakeSubMessageDefault(message): result = message_type._concrete_class() result._SetListener(message._listener_for_children) + if field.containing_oneof: + message._UpdateOneofState(field) return result return MakeSubMessageDefault @@ -476,6 +486,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): type_checker = type_checkers.GetTypeChecker(field) default_value = field.default_value valid_values = set() + is_proto3 = field.containing_type.syntax == "proto3" def getter(self): # TODO(protobuf-team): This may be broken since there may not be @@ -483,15 +494,24 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name + + clear_when_set_to_default = is_proto3 and not field.containing_oneof + def field_setter(self, new_value): # pylint: disable=protected-access - self._fields[field] = type_checker.CheckValue(new_value) + # Testing the value for truthiness captures all of the proto3 defaults + # (0, 0.0, enum 0, and False). + new_value = type_checker.CheckValue(new_value) + if clear_when_set_to_default and not new_value: + self._fields.pop(field, None) + else: + self._fields[field] = new_value # Check _cached_byte_size_dirty inline to improve performance, since scalar # setters are called frequently. if not self._cached_byte_size_dirty: self._Modified() - if field.containing_oneof is not None: + if field.containing_oneof: def setter(self, new_value): field_setter(self, new_value) self._UpdateOneofState(field) @@ -624,24 +644,35 @@ def _AddListFieldsMethod(message_descriptor, cls): cls.ListFields = ListFields +_Proto3HasError = 'Protocol message has no non-repeated submessage field "%s"' +_Proto2HasError = 'Protocol message has no non-repeated field "%s"' def _AddHasFieldMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - singular_fields = {} + is_proto3 = (message_descriptor.syntax == "proto3") + error_msg = _Proto3HasError if is_proto3 else _Proto2HasError + + hassable_fields = {} for field in message_descriptor.fields: - if field.label != _FieldDescriptor.LABEL_REPEATED: - singular_fields[field.name] = field - # Fields inside oneofs are never repeated (enforced by the compiler). - for field in message_descriptor.oneofs: - singular_fields[field.name] = field + if field.label == _FieldDescriptor.LABEL_REPEATED: + continue + # For proto3, only submessages and fields inside a oneof have presence. + if (is_proto3 and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE and + not field.containing_oneof): + continue + hassable_fields[field.name] = field + + if not is_proto3: + # Fields inside oneofs are never repeated (enforced by the compiler). + for oneof in message_descriptor.oneofs: + hassable_fields[oneof.name] = oneof def HasField(self, field_name): try: - field = singular_fields[field_name] + field = hassable_fields[field_name] except KeyError: - raise ValueError( - 'Protocol message has no singular "%s" field.' % field_name) + raise ValueError(error_msg % field_name) if isinstance(field, descriptor_mod.OneofDescriptor): try: @@ -871,6 +902,7 @@ def _AddMergeFromStringMethod(message_descriptor, cls): local_ReadTag = decoder.ReadTag local_SkipField = decoder.SkipField decoders_by_tag = cls._decoders_by_tag + is_proto3 = message_descriptor.syntax == "proto3" def InternalParse(self, buffer, pos, end): self._Modified() @@ -884,9 +916,11 @@ def _AddMergeFromStringMethod(message_descriptor, cls): new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) if new_pos == -1: return pos - if not unknown_field_list: - unknown_field_list = self._unknown_fields = [] - unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) + if not is_proto3: + if not unknown_field_list: + unknown_field_list = self._unknown_fields = [] + unknown_field_list.append( + (tag_bytes, buffer[value_start_pos:new_pos])) pos = new_pos else: pos = field_decoder(buffer, new_pos, end, self, field_dict) @@ -1008,6 +1042,8 @@ def _AddMergeFromMethod(cls): # Construct a new object to represent this field. field_value = field._default_constructor(self) fields[field] = field_value + if field.containing_oneof: + self._UpdateOneofState(field) field_value.MergeFrom(value) else: self._fields[field] = value @@ -1252,11 +1288,10 @@ class _ExtensionDict(object): # It's slightly wasteful to lookup the type checker each time, # but we expect this to be a vanishingly uncommon case anyway. - type_checker = type_checkers.GetTypeChecker( - extension_handle) + type_checker = type_checkers.GetTypeChecker(extension_handle) # pylint: disable=protected-access self._extended_message._fields[extension_handle] = ( - type_checker.CheckValue(value)) + type_checker.CheckValue(value)) self._extended_message._Modified() def _FindExtensionByName(self, name): diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 6b24b092..8ac76c63 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -1792,6 +1792,27 @@ class ReflectionTest(basetest.TestCase): # Just check the default value. self.assertEqual(57, msg.inner.value) + @basetest.unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'CPPv2-specific test') + def testBadArguments(self): + # Some of these assertions used to segfault. + from google.protobuf.pyext import _message + self.assertRaises(TypeError, _message.Message._GetFieldDescriptor, 3) + self.assertRaises(TypeError, _message.Message._GetExtensionDescriptor, 42) + self.assertRaises(TypeError, + unittest_pb2.TestAllTypes().__getattribute__, 42) + + @basetest.unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'CPPv2-specific test') + def testRosyHack(self): + from google.protobuf.pyext import _message + from google3.gdata.rosy.proto import core_api2_pb2 + from google3.gdata.rosy.proto import core_pb2 + self.assertEqual(_message.Message, core_pb2.PageSelection.__base__) + self.assertEqual(_message.Message, core_api2_pb2.PageSelection.__base__) + # Since we had so many tests for protocol buffer equality, we broke these out # into separate TestCase classes. diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 787f4650..d84e3836 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -40,13 +40,19 @@ import os.path from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 +from google.protobuf import descriptor_pb2 +# Tests whether the given TestAllTypes message is proto2 or not. +# This is used to gate several fields/features that only exist +# for the proto2 version of the message. +def IsProto2(message): + return message.DESCRIPTOR.syntax == "proto2" def SetAllNonLazyFields(message): """Sets every non-lazy field in the message to a unique value. Args: - message: A unittest_pb2.TestAllTypes instance. + message: A TestAllTypes instance. """ # @@ -77,7 +83,8 @@ def SetAllNonLazyFields(message): message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ - message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ + if IsProto2(message): + message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ message.optional_string_piece = u'124' message.optional_cord = u'125' @@ -110,7 +117,8 @@ def SetAllNonLazyFields(message): message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) - message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) + if IsProto2(message): + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) message.repeated_string_piece.append(u'224') message.repeated_cord.append(u'225') @@ -140,7 +148,8 @@ def SetAllNonLazyFields(message): message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) - message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) + if IsProto2(message): + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) message.repeated_string_piece.append(u'324') message.repeated_cord.append(u'325') @@ -149,28 +158,29 @@ def SetAllNonLazyFields(message): # Fields that have defaults. # - message.default_int32 = 401 - message.default_int64 = 402 - message.default_uint32 = 403 - message.default_uint64 = 404 - message.default_sint32 = 405 - message.default_sint64 = 406 - message.default_fixed32 = 407 - message.default_fixed64 = 408 - message.default_sfixed32 = 409 - message.default_sfixed64 = 410 - message.default_float = 411 - message.default_double = 412 - message.default_bool = False - message.default_string = '415' - message.default_bytes = b'416' - - message.default_nested_enum = unittest_pb2.TestAllTypes.FOO - message.default_foreign_enum = unittest_pb2.FOREIGN_FOO - message.default_import_enum = unittest_import_pb2.IMPORT_FOO - - message.default_string_piece = '424' - message.default_cord = '425' + if IsProto2(message): + message.default_int32 = 401 + message.default_int64 = 402 + message.default_uint32 = 403 + message.default_uint64 = 404 + message.default_sint32 = 405 + message.default_sint64 = 406 + message.default_fixed32 = 407 + message.default_fixed64 = 408 + message.default_sfixed32 = 409 + message.default_sfixed64 = 410 + message.default_float = 411 + message.default_double = 412 + message.default_bool = False + message.default_string = '415' + message.default_bytes = b'416' + + message.default_nested_enum = unittest_pb2.TestAllTypes.FOO + message.default_foreign_enum = unittest_pb2.FOREIGN_FOO + message.default_import_enum = unittest_import_pb2.IMPORT_FOO + + message.default_string_piece = '424' + message.default_cord = '425' message.oneof_uint32 = 601 message.oneof_nested_message.bb = 602 @@ -398,7 +408,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertTrue(message.HasField('optional_nested_enum')) test_case.assertTrue(message.HasField('optional_foreign_enum')) - test_case.assertTrue(message.HasField('optional_import_enum')) + if IsProto2(message): + test_case.assertTrue(message.HasField('optional_import_enum')) test_case.assertTrue(message.HasField('optional_string_piece')) test_case.assertTrue(message.HasField('optional_cord')) @@ -430,8 +441,9 @@ def ExpectAllFieldsSet(test_case, message): message.optional_nested_enum) test_case.assertEqual(unittest_pb2.FOREIGN_BAZ, message.optional_foreign_enum) - test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, - message.optional_import_enum) + if IsProto2(message): + test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.optional_import_enum) # ----------------------------------------------------------------- @@ -457,7 +469,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(2, len(message.repeated_import_message)) test_case.assertEqual(2, len(message.repeated_nested_enum)) test_case.assertEqual(2, len(message.repeated_foreign_enum)) - test_case.assertEqual(2, len(message.repeated_import_enum)) + if IsProto2(message): + test_case.assertEqual(2, len(message.repeated_import_enum)) test_case.assertEqual(2, len(message.repeated_string_piece)) test_case.assertEqual(2, len(message.repeated_cord)) @@ -488,8 +501,9 @@ def ExpectAllFieldsSet(test_case, message): message.repeated_nested_enum[0]) test_case.assertEqual(unittest_pb2.FOREIGN_BAR, message.repeated_foreign_enum[0]) - test_case.assertEqual(unittest_import_pb2.IMPORT_BAR, - message.repeated_import_enum[0]) + if IsProto2(message): + test_case.assertEqual(unittest_import_pb2.IMPORT_BAR, + message.repeated_import_enum[0]) test_case.assertEqual(301, message.repeated_int32[1]) test_case.assertEqual(302, message.repeated_int64[1]) @@ -517,53 +531,55 @@ def ExpectAllFieldsSet(test_case, message): message.repeated_nested_enum[1]) test_case.assertEqual(unittest_pb2.FOREIGN_BAZ, message.repeated_foreign_enum[1]) - test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, - message.repeated_import_enum[1]) + if IsProto2(message): + test_case.assertEqual(unittest_import_pb2.IMPORT_BAZ, + message.repeated_import_enum[1]) # ----------------------------------------------------------------- - test_case.assertTrue(message.HasField('default_int32')) - test_case.assertTrue(message.HasField('default_int64')) - test_case.assertTrue(message.HasField('default_uint32')) - test_case.assertTrue(message.HasField('default_uint64')) - test_case.assertTrue(message.HasField('default_sint32')) - test_case.assertTrue(message.HasField('default_sint64')) - test_case.assertTrue(message.HasField('default_fixed32')) - test_case.assertTrue(message.HasField('default_fixed64')) - test_case.assertTrue(message.HasField('default_sfixed32')) - test_case.assertTrue(message.HasField('default_sfixed64')) - test_case.assertTrue(message.HasField('default_float')) - test_case.assertTrue(message.HasField('default_double')) - test_case.assertTrue(message.HasField('default_bool')) - test_case.assertTrue(message.HasField('default_string')) - test_case.assertTrue(message.HasField('default_bytes')) - - test_case.assertTrue(message.HasField('default_nested_enum')) - test_case.assertTrue(message.HasField('default_foreign_enum')) - test_case.assertTrue(message.HasField('default_import_enum')) - - test_case.assertEqual(401, message.default_int32) - test_case.assertEqual(402, message.default_int64) - test_case.assertEqual(403, message.default_uint32) - test_case.assertEqual(404, message.default_uint64) - test_case.assertEqual(405, message.default_sint32) - test_case.assertEqual(406, message.default_sint64) - test_case.assertEqual(407, message.default_fixed32) - test_case.assertEqual(408, message.default_fixed64) - test_case.assertEqual(409, message.default_sfixed32) - test_case.assertEqual(410, message.default_sfixed64) - test_case.assertEqual(411, message.default_float) - test_case.assertEqual(412, message.default_double) - test_case.assertEqual(False, message.default_bool) - test_case.assertEqual('415', message.default_string) - test_case.assertEqual(b'416', message.default_bytes) - - test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, - message.default_nested_enum) - test_case.assertEqual(unittest_pb2.FOREIGN_FOO, - message.default_foreign_enum) - test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, - message.default_import_enum) + if IsProto2(message): + test_case.assertTrue(message.HasField('default_int32')) + test_case.assertTrue(message.HasField('default_int64')) + test_case.assertTrue(message.HasField('default_uint32')) + test_case.assertTrue(message.HasField('default_uint64')) + test_case.assertTrue(message.HasField('default_sint32')) + test_case.assertTrue(message.HasField('default_sint64')) + test_case.assertTrue(message.HasField('default_fixed32')) + test_case.assertTrue(message.HasField('default_fixed64')) + test_case.assertTrue(message.HasField('default_sfixed32')) + test_case.assertTrue(message.HasField('default_sfixed64')) + test_case.assertTrue(message.HasField('default_float')) + test_case.assertTrue(message.HasField('default_double')) + test_case.assertTrue(message.HasField('default_bool')) + test_case.assertTrue(message.HasField('default_string')) + test_case.assertTrue(message.HasField('default_bytes')) + + test_case.assertTrue(message.HasField('default_nested_enum')) + test_case.assertTrue(message.HasField('default_foreign_enum')) + test_case.assertTrue(message.HasField('default_import_enum')) + + test_case.assertEqual(401, message.default_int32) + test_case.assertEqual(402, message.default_int64) + test_case.assertEqual(403, message.default_uint32) + test_case.assertEqual(404, message.default_uint64) + test_case.assertEqual(405, message.default_sint32) + test_case.assertEqual(406, message.default_sint64) + test_case.assertEqual(407, message.default_fixed32) + test_case.assertEqual(408, message.default_fixed64) + test_case.assertEqual(409, message.default_sfixed32) + test_case.assertEqual(410, message.default_sfixed64) + test_case.assertEqual(411, message.default_float) + test_case.assertEqual(412, message.default_double) + test_case.assertEqual(False, message.default_bool) + test_case.assertEqual('415', message.default_string) + test_case.assertEqual(b'416', message.default_bytes) + + test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, + message.default_nested_enum) + test_case.assertEqual(unittest_pb2.FOREIGN_FOO, + message.default_foreign_enum) + test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, + message.default_import_enum) def GoldenFile(filename): @@ -594,7 +610,7 @@ def SetAllPackedFields(message): """Sets every field in the message to a unique value. Args: - message: A unittest_pb2.TestPackedTypes instance. + message: A TestPackedTypes instance. """ message.packed_int32.extend([601, 701]) message.packed_int64.extend([602, 702]) diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 55e3c2c8..68ab9659 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -37,13 +37,17 @@ __author__ = 'kenton@google.com (Kenton Varda)' import re from google.apputils import basetest -from google.protobuf import text_format +from google.apputils.pybase import parameterized + +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util -from google.protobuf import unittest_pb2 -from google.protobuf import unittest_mset_pb2 +from google.protobuf import text_format -class TextFormatTest(basetest.TestCase): +# Base class with some common functionality. +class TextFormatBase(basetest.TestCase): def ReadGolden(self, golden_filename): with test_util.GoldenFile(golden_filename) as f: @@ -57,73 +61,24 @@ class TextFormatTest(basetest.TestCase): def CompareToGoldenText(self, text, golden_text): self.assertMultiLineEqual(text, golden_text) - def testPrintAllFields(self): - message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_data_oneof_implemented.txt') - - def testPrintInIndexOrder(self): - message = unittest_pb2.TestFieldOrderings() - message.my_string = '115' - message.my_int = 101 - message.my_float = 111 - message.optional_nested_message.oo = 0 - message.optional_nested_message.bb = 1 - self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString( - message, use_index_order=True)), - 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n' - 'optional_nested_message {\n oo: 0\n bb: 1\n}\n') - self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString( - message)), - 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n' - 'optional_nested_message {\n bb: 1\n oo: 0\n}\n') - - def testPrintAllExtensions(self): - message = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(message) - self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_extensions_data.txt') - - def testPrintAllFieldsPointy(self): - message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.CompareToGoldenFile( - self.RemoveRedundantZeros( - text_format.MessageToString(message, pointy_brackets=True)), - 'text_format_unittest_data_pointy_oneof.txt') + def RemoveRedundantZeros(self, text): + # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove + # these zeros in order to match the golden file. + text = text.replace('e+0','e+').replace('e+0','e+') \ + .replace('e-0','e-').replace('e-0','e-') + # Floating point fields are printed with .0 suffix even if they are + # actualy integer numbers. + text = re.compile('\.0$', re.MULTILINE).sub('', text) + return text - def testPrintAllExtensionsPointy(self): - message = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(message) - self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString( - message, pointy_brackets=True)), - 'text_format_unittest_extensions_data_pointy.txt') - def testPrintMessageSet(self): - message = unittest_mset_pb2.TestMessageSetContainer() - ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension - ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension - message.message_set.Extensions[ext1].i = 23 - message.message_set.Extensions[ext2].str = 'foo' - self.CompareToGoldenText( - text_format.MessageToString(message), - 'message_set {\n' - ' [protobuf_unittest.TestMessageSetExtension1] {\n' - ' i: 23\n' - ' }\n' - ' [protobuf_unittest.TestMessageSetExtension2] {\n' - ' str: \"foo\"\n' - ' }\n' - '}\n') +@parameterized.Parameters( + (unittest_pb2), + (unittest_proto3_arena_pb2)) +class TextFormatTest(TextFormatBase): - def testPrintExotic(self): - message = unittest_pb2.TestAllTypes() + def testPrintExotic(self, message_module): + message = message_module.TestAllTypes() message.repeated_int64.append(-9223372036854775808) message.repeated_uint64.append(18446744073709551615) message.repeated_double.append(123.456) @@ -142,61 +97,44 @@ class TextFormatTest(basetest.TestCase): ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' 'repeated_string: "\\303\\274\\352\\234\\237"\n') - def testPrintExoticUnicodeSubclass(self): + def testPrintExoticUnicodeSubclass(self, message_module): class UnicodeSub(unicode): pass - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) self.CompareToGoldenText( text_format.MessageToString(message), 'repeated_string: "\\303\\274\\352\\234\\237"\n') - def testPrintNestedMessageAsOneLine(self): - message = unittest_pb2.TestAllTypes() + def testPrintNestedMessageAsOneLine(self, message_module): + message = message_module.TestAllTypes() msg = message.repeated_nested_message.add() msg.bb = 42 self.CompareToGoldenText( text_format.MessageToString(message, as_one_line=True), 'repeated_nested_message { bb: 42 }') - def testPrintRepeatedFieldsAsOneLine(self): - message = unittest_pb2.TestAllTypes() + def testPrintRepeatedFieldsAsOneLine(self, message_module): + message = message_module.TestAllTypes() message.repeated_int32.append(1) message.repeated_int32.append(1) message.repeated_int32.append(3) - message.repeated_string.append("Google") - message.repeated_string.append("Zurich") + message.repeated_string.append('Google') + message.repeated_string.append('Zurich') self.CompareToGoldenText( text_format.MessageToString(message, as_one_line=True), 'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 ' 'repeated_string: "Google" repeated_string: "Zurich"') - def testPrintNestedNewLineInStringAsOneLine(self): - message = unittest_pb2.TestAllTypes() - message.optional_string = "a\nnew\nline" + def testPrintNestedNewLineInStringAsOneLine(self, message_module): + message = message_module.TestAllTypes() + message.optional_string = 'a\nnew\nline' self.CompareToGoldenText( text_format.MessageToString(message, as_one_line=True), 'optional_string: "a\\nnew\\nline"') - def testPrintMessageSetAsOneLine(self): - message = unittest_mset_pb2.TestMessageSetContainer() - ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension - ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension - message.message_set.Extensions[ext1].i = 23 - message.message_set.Extensions[ext2].str = 'foo' - self.CompareToGoldenText( - text_format.MessageToString(message, as_one_line=True), - 'message_set {' - ' [protobuf_unittest.TestMessageSetExtension1] {' - ' i: 23' - ' }' - ' [protobuf_unittest.TestMessageSetExtension2] {' - ' str: \"foo\"' - ' }' - ' }') - - def testPrintExoticAsOneLine(self): - message = unittest_pb2.TestAllTypes() + def testPrintExoticAsOneLine(self, message_module): + message = message_module.TestAllTypes() message.repeated_int64.append(-9223372036854775808) message.repeated_uint64.append(18446744073709551615) message.repeated_double.append(123.456) @@ -216,8 +154,8 @@ class TextFormatTest(basetest.TestCase): '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' ' repeated_string: "\\303\\274\\352\\234\\237"') - def testRoundTripExoticAsOneLine(self): - message = unittest_pb2.TestAllTypes() + def testRoundTripExoticAsOneLine(self, message_module): + message = message_module.TestAllTypes() message.repeated_int64.append(-9223372036854775808) message.repeated_uint64.append(18446744073709551615) message.repeated_double.append(123.456) @@ -229,7 +167,7 @@ class TextFormatTest(basetest.TestCase): # Test as_utf8 = False. wire_text = text_format.MessageToString( message, as_one_line=True, as_utf8=False) - parsed_message = unittest_pb2.TestAllTypes() + parsed_message = message_module.TestAllTypes() r = text_format.Parse(wire_text, parsed_message) self.assertIs(r, parsed_message) self.assertEquals(message, parsed_message) @@ -237,25 +175,25 @@ class TextFormatTest(basetest.TestCase): # Test as_utf8 = True. wire_text = text_format.MessageToString( message, as_one_line=True, as_utf8=True) - parsed_message = unittest_pb2.TestAllTypes() + parsed_message = message_module.TestAllTypes() r = text_format.Parse(wire_text, parsed_message) self.assertIs(r, parsed_message) self.assertEquals(message, parsed_message, '\n%s != %s' % (message, parsed_message)) - def testPrintRawUtf8String(self): - message = unittest_pb2.TestAllTypes() + def testPrintRawUtf8String(self, message_module): + message = message_module.TestAllTypes() message.repeated_string.append(u'\u00fc\ua71f') text = text_format.MessageToString(message, as_utf8=True) self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n') - parsed_message = unittest_pb2.TestAllTypes() + parsed_message = message_module.TestAllTypes() text_format.Parse(text, parsed_message) self.assertEquals(message, parsed_message, '\n%s != %s' % (message, parsed_message)) - def testPrintFloatFormat(self): + def testPrintFloatFormat(self, message_module): # Check that float_format argument is passed to sub-message formatting. - message = unittest_pb2.NestedTestAllTypes() + message = message_module.NestedTestAllTypes() # We use 1.25 as it is a round number in binary. The proto 32-bit float # will not gain additional imprecise digits as a 64-bit Python float and # show up in its str. 32-bit 1.2 is noisy when extended to 64-bit: @@ -285,85 +223,24 @@ class TextFormatTest(basetest.TestCase): self.RemoveRedundantZeros(text_message), 'payload {{ {} {} {} {} }}'.format(*formatted_fields)) - def testMessageToString(self): - message = unittest_pb2.ForeignMessage() + def testMessageToString(self, message_module): + message = message_module.ForeignMessage() message.c = 123 self.assertEqual('c: 123\n', str(message)) - def RemoveRedundantZeros(self, text): - # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove - # these zeros in order to match the golden file. - text = text.replace('e+0','e+').replace('e+0','e+') \ - .replace('e-0','e-').replace('e-0','e-') - # Floating point fields are printed with .0 suffix even if they are - # actualy integer numbers. - text = re.compile('\.0$', re.MULTILINE).sub('', text) - return text - - def testParseGolden(self): - golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) - parsed_message = unittest_pb2.TestAllTypes() - r = text_format.Parse(golden_text, parsed_message) - self.assertIs(r, parsed_message) - - message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.assertEquals(message, parsed_message) - - def testParseGoldenExtensions(self): - golden_text = '\n'.join(self.ReadGolden( - 'text_format_unittest_extensions_data.txt')) - parsed_message = unittest_pb2.TestAllExtensions() - text_format.Parse(golden_text, parsed_message) - - message = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(message) - self.assertEquals(message, parsed_message) - - def testParseAllFields(self): - message = unittest_pb2.TestAllTypes() + def testParseAllFields(self, message_module): + message = message_module.TestAllTypes() test_util.SetAllFields(message) ascii_text = text_format.MessageToString(message) - parsed_message = unittest_pb2.TestAllTypes() - text_format.Parse(ascii_text, parsed_message) - self.assertEqual(message, parsed_message) - test_util.ExpectAllFieldsSet(self, message) - - def testParseAllExtensions(self): - message = unittest_pb2.TestAllExtensions() - test_util.SetAllExtensions(message) - ascii_text = text_format.MessageToString(message) - - parsed_message = unittest_pb2.TestAllExtensions() + parsed_message = message_module.TestAllTypes() text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) + if message_module is unittest_pb2: + test_util.ExpectAllFieldsSet(self, message) - def testParseMessageSet(self): - message = unittest_pb2.TestAllTypes() - text = ('repeated_uint64: 1\n' - 'repeated_uint64: 2\n') - text_format.Parse(text, message) - self.assertEqual(1, message.repeated_uint64[0]) - self.assertEqual(2, message.repeated_uint64[1]) - - message = unittest_mset_pb2.TestMessageSetContainer() - text = ('message_set {\n' - ' [protobuf_unittest.TestMessageSetExtension1] {\n' - ' i: 23\n' - ' }\n' - ' [protobuf_unittest.TestMessageSetExtension2] {\n' - ' str: \"foo\"\n' - ' }\n' - '}\n') - text_format.Parse(text, message) - ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension - ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension - self.assertEquals(23, message.message_set.Extensions[ext1].i) - self.assertEquals('foo', message.message_set.Extensions[ext2].str) - - def testParseExotic(self): - message = unittest_pb2.TestAllTypes() + def testParseExotic(self, message_module): + message = message_module.TestAllTypes() text = ('repeated_int64: -9223372036854775808\n' 'repeated_uint64: 18446744073709551615\n' 'repeated_double: 123.456\n' @@ -388,8 +265,8 @@ class TextFormatTest(basetest.TestCase): self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2]) self.assertEqual(u'\u00fc', message.repeated_string[3]) - def testParseTrailingCommas(self): - message = unittest_pb2.TestAllTypes() + def testParseTrailingCommas(self, message_module): + message = message_module.TestAllTypes() text = ('repeated_int64: 100;\n' 'repeated_int64: 200;\n' 'repeated_int64: 300,\n' @@ -403,51 +280,37 @@ class TextFormatTest(basetest.TestCase): self.assertEqual(u'one', message.repeated_string[0]) self.assertEqual(u'two', message.repeated_string[1]) - def testParseEmptyText(self): - message = unittest_pb2.TestAllTypes() + def testParseEmptyText(self, message_module): + message = message_module.TestAllTypes() text = '' text_format.Parse(text, message) - self.assertEquals(unittest_pb2.TestAllTypes(), message) + self.assertEquals(message_module.TestAllTypes(), message) - def testParseInvalidUtf8(self): - message = unittest_pb2.TestAllTypes() + def testParseInvalidUtf8(self, message_module): + message = message_module.TestAllTypes() text = 'repeated_string: "\\xc3\\xc3"' self.assertRaises(text_format.ParseError, text_format.Parse, text, message) - def testParseSingleWord(self): - message = unittest_pb2.TestAllTypes() + def testParseSingleWord(self, message_module): + message = message_module.TestAllTypes() text = 'foo' - self.assertRaisesWithLiteralMatch( + self.assertRaisesRegexp( text_format.ParseError, - ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' - '"foo".'), + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"foo".'), text_format.Parse, text, message) - def testParseUnknownField(self): - message = unittest_pb2.TestAllTypes() + def testParseUnknownField(self, message_module): + message = message_module.TestAllTypes() text = 'unknown_field: 8\n' - self.assertRaisesWithLiteralMatch( + self.assertRaisesRegexp( text_format.ParseError, - ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' - '"unknown_field".'), + (r'1:1 : Message type "\w+.TestAllTypes" has no field named ' + r'"unknown_field".'), text_format.Parse, text, message) - def testParseBadExtension(self): - message = unittest_pb2.TestAllExtensions() - text = '[unknown_extension]: 8\n' - self.assertRaisesWithLiteralMatch( - text_format.ParseError, - '1:2 : Extension "unknown_extension" not registered.', - text_format.Parse, text, message) - message = unittest_pb2.TestAllTypes() - self.assertRaisesWithLiteralMatch( - text_format.ParseError, - ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' - 'extensions.'), - text_format.Parse, text, message) - - def testParseGroupNotClosed(self): - message = unittest_pb2.TestAllTypes() + def testParseGroupNotClosed(self, message_module): + message = message_module.TestAllTypes() text = 'RepeatedGroup: <' self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:16 : Expected ">".', @@ -458,46 +321,46 @@ class TextFormatTest(basetest.TestCase): text_format.ParseError, '1:16 : Expected "}".', text_format.Parse, text, message) - def testParseEmptyGroup(self): - message = unittest_pb2.TestAllTypes() + def testParseEmptyGroup(self, message_module): + message = message_module.TestAllTypes() text = 'OptionalGroup: {}' text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) message.Clear() - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() text = 'OptionalGroup: <>' text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) - def testParseBadEnumValue(self): - message = unittest_pb2.TestAllTypes() + def testParseBadEnumValue(self, message_module): + message = message_module.TestAllTypes() text = 'optional_nested_enum: BARR' - self.assertRaisesWithLiteralMatch( + self.assertRaisesRegexp( text_format.ParseError, - ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' - 'has no value named BARR.'), + (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + r'has no value named BARR.'), text_format.Parse, text, message) - message = unittest_pb2.TestAllTypes() + message = message_module.TestAllTypes() text = 'optional_nested_enum: 100' - self.assertRaisesWithLiteralMatch( + self.assertRaisesRegexp( text_format.ParseError, - ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' - 'has no value with number 100.'), + (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + r'has no value with number 100.'), text_format.Parse, text, message) - def testParseBadIntValue(self): - message = unittest_pb2.TestAllTypes() + def testParseBadIntValue(self, message_module): + message = message_module.TestAllTypes() text = 'optional_int32: bork' self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:17 : Couldn\'t parse integer: bork'), text_format.Parse, text, message) - def testParseStringFieldUnescape(self): - message = unittest_pb2.TestAllTypes() + def testParseStringFieldUnescape(self, message_module): + message = message_module.TestAllTypes() text = r'''repeated_string: "\xf\x62" repeated_string: "\\xf\\x62" repeated_string: "\\\xf\\\x62" @@ -516,40 +379,205 @@ class TextFormatTest(basetest.TestCase): message.repeated_string[4]) self.assertEqual(SLASH + 'x20', message.repeated_string[5]) - def testMergeDuplicateScalars(self): - message = unittest_pb2.TestAllTypes() + def testMergeDuplicateScalars(self, message_module): + message = message_module.TestAllTypes() text = ('optional_int32: 42 ' 'optional_int32: 67') r = text_format.Merge(text, message) self.assertIs(r, message) self.assertEqual(67, message.optional_int32) - def testParseDuplicateScalars(self): - message = unittest_pb2.TestAllTypes() - text = ('optional_int32: 42 ' - 'optional_int32: 67') - self.assertRaisesWithLiteralMatch( - text_format.ParseError, - ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' - 'have multiple "optional_int32" fields.'), - text_format.Parse, text, message) - - def testMergeDuplicateNestedMessageScalars(self): - message = unittest_pb2.TestAllTypes() + def testMergeDuplicateNestedMessageScalars(self, message_module): + message = message_module.TestAllTypes() text = ('optional_nested_message { bb: 1 } ' 'optional_nested_message { bb: 2 }') r = text_format.Merge(text, message) self.assertTrue(r is message) self.assertEqual(2, message.optional_nested_message.bb) - def testParseDuplicateNestedMessageScalars(self): + def testParseOneof(self, message_module): + m = message_module.TestAllTypes() + m.oneof_uint32 = 11 + m2 = message_module.TestAllTypes() + text_format.Parse(text_format.MessageToString(m), m2) + self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + + +# These are tests that aren't fundamentally specific to proto2, but are at +# the moment because of differences between the proto2 and proto3 test schemas. +# Ideally the schemas would be made more similar so these tests could pass. +class OnlyWorksWithProto2RightNowTests(TextFormatBase): + + def testParseGolden(self): + golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.Parse(golden_text, parsed_message) + self.assertIs(r, parsed_message) + message = unittest_pb2.TestAllTypes() - text = ('optional_nested_message { bb: 1 } ' - 'optional_nested_message { bb: 2 }') + test_util.SetAllFields(message) + self.assertEquals(message, parsed_message) + + def testPrintAllFields(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_data_oneof_implemented.txt') + + def testPrintAllFieldsPointy(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros( + text_format.MessageToString(message, pointy_brackets=True)), + 'text_format_unittest_data_pointy_oneof.txt') + + def testPrintInIndexOrder(self): + message = unittest_pb2.TestFieldOrderings() + message.my_string = '115' + message.my_int = 101 + message.my_float = 111 + message.optional_nested_message.oo = 0 + message.optional_nested_message.bb = 1 + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message, use_index_order=True)), + 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n' + 'optional_nested_message {\n oo: 0\n bb: 1\n}\n') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message)), + 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n' + 'optional_nested_message {\n bb: 1\n oo: 0\n}\n') + + def testMergeLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.MergeLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEqual(message, parsed_message) + + def testParseLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.ParseLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEquals(message, parsed_message) + + +# Tests of proto2-only features (MessageSet and extensions). +class Proto2Tests(TextFormatBase): + + def testPrintMessageSet(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + self.CompareToGoldenText( + text_format.MessageToString(message), + 'message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + + def testPrintMessageSetAsOneLine(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + self.CompareToGoldenText( + text_format.MessageToString(message, as_one_line=True), + 'message_set {' + ' [protobuf_unittest.TestMessageSetExtension1] {' + ' i: 23' + ' }' + ' [protobuf_unittest.TestMessageSetExtension2] {' + ' str: \"foo\"' + ' }' + ' }') + + def testParseMessageSet(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_uint64: 1\n' + 'repeated_uint64: 2\n') + text_format.Parse(text, message) + self.assertEqual(1, message.repeated_uint64[0]) + self.assertEqual(2, message.repeated_uint64[1]) + + message = unittest_mset_pb2.TestMessageSetContainer() + text = ('message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + text_format.Parse(text, message) + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + self.assertEquals(23, message.message_set.Extensions[ext1].i) + self.assertEquals('foo', message.message_set.Extensions[ext2].str) + + def testPrintAllExtensions(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_extensions_data.txt') + + def testPrintAllExtensionsPointy(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString( + message, pointy_brackets=True)), + 'text_format_unittest_extensions_data_pointy.txt') + + def testParseGoldenExtensions(self): + golden_text = '\n'.join(self.ReadGolden( + 'text_format_unittest_extensions_data.txt')) + parsed_message = unittest_pb2.TestAllExtensions() + text_format.Parse(golden_text, parsed_message) + + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.assertEquals(message, parsed_message) + + def testParseAllExtensions(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + ascii_text = text_format.MessageToString(message) + + parsed_message = unittest_pb2.TestAllExtensions() + text_format.Parse(ascii_text, parsed_message) + self.assertEqual(message, parsed_message) + + def testParseBadExtension(self): + message = unittest_pb2.TestAllExtensions() + text = '[unknown_extension]: 8\n' self.assertRaisesWithLiteralMatch( text_format.ParseError, - ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' - 'should not have multiple "bb" fields.'), + '1:2 : Extension "unknown_extension" not registered.', + text_format.Parse, text, message) + message = unittest_pb2.TestAllTypes() + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' + 'extensions.'), text_format.Parse, text, message) def testMergeDuplicateExtensionScalars(self): @@ -572,32 +600,25 @@ class TextFormatTest(basetest.TestCase): '"protobuf_unittest.optional_int32_extension" extensions.'), text_format.Parse, text, message) - def testParseLinesGolden(self): - opened = self.ReadGolden('text_format_unittest_data.txt') - parsed_message = unittest_pb2.TestAllTypes() - r = text_format.ParseLines(opened, parsed_message) - self.assertIs(r, parsed_message) - + def testParseDuplicateNestedMessageScalars(self): message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.assertEquals(message, parsed_message) - - def testMergeLinesGolden(self): - opened = self.ReadGolden('text_format_unittest_data.txt') - parsed_message = unittest_pb2.TestAllTypes() - r = text_format.MergeLines(opened, parsed_message) - self.assertIs(r, parsed_message) + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' + 'should not have multiple "bb" fields.'), + text_format.Parse, text, message) + def testParseDuplicateScalars(self): message = unittest_pb2.TestAllTypes() - test_util.SetAllFields(message) - self.assertEqual(message, parsed_message) - - def testParseOneof(self): - m = unittest_pb2.TestAllTypes() - m.oneof_uint32 = 11 - m2 = unittest_pb2.TestAllTypes() - text_format.Parse(text_format.MessageToString(m), m2) - self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + text = ('optional_int32: 42 ' + 'optional_int32: 67') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' + 'have multiple "optional_int32" fields.'), + text_format.Parse, text, message) class TokenizerTest(basetest.TestCase): diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 118725da..76c056c4 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -59,6 +59,8 @@ from google.protobuf import descriptor _FieldDescriptor = descriptor.FieldDescriptor +def SupportsOpenEnums(field_descriptor): + return field_descriptor.containing_type.syntax == "proto3" def GetTypeChecker(field): """Returns a type checker for a message field of the specified types. @@ -74,7 +76,11 @@ def GetTypeChecker(field): field.type == _FieldDescriptor.TYPE_STRING): return UnicodeValueChecker() if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: - return EnumValueChecker(field.enum_type) + if SupportsOpenEnums(field): + # When open enums are supported, any int32 can be assigned. + return _VALUE_CHECKERS[_FieldDescriptor.CPPTYPE_INT32] + else: + return EnumValueChecker(field.enum_type) return _VALUE_CHECKERS[field.cpp_type] diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index a4dc1f7c..59f9ae4c 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -38,6 +38,7 @@ __author__ = 'bohdank@google.com (Bohdan Koval)' from google.apputils import basetest from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 +from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation from google.protobuf.internal import encoder from google.protobuf.internal import missing_enum_values_pb2 @@ -45,10 +46,81 @@ from google.protobuf.internal import test_util from google.protobuf.internal import type_checkers +class UnknownFieldsTest(basetest.TestCase): + + def setUp(self): + self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + self.all_fields = unittest_pb2.TestAllTypes() + test_util.SetAllFields(self.all_fields) + self.all_fields_data = self.all_fields.SerializeToString() + self.empty_message = unittest_pb2.TestEmptyMessage() + self.empty_message.ParseFromString(self.all_fields_data) + + def testSerialize(self): + data = self.empty_message.SerializeToString() + + # Don't use assertEqual because we don't want to dump raw binary data to + # stdout. + self.assertTrue(data == self.all_fields_data) + + def testSerializeProto3(self): + # Verify that proto3 doesn't preserve unknown fields. + message = unittest_proto3_arena_pb2.TestEmptyMessage() + message.ParseFromString(self.all_fields_data) + self.assertEqual(0, len(message.SerializeToString())) + + def testByteSize(self): + self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) + + def testListFields(self): + # Make sure ListFields doesn't return unknown fields. + self.assertEqual(0, len(self.empty_message.ListFields())) + + def testSerializeMessageSetWireFormatUnknownExtension(self): + # Create a message using the message set wire format with an unknown + # message. + raw = unittest_mset_pb2.RawMessageSet() + + # Add an unknown extension. + item = raw.item.add() + item.type_id = 1545009 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12345 + item.message = message1.SerializeToString() + + serialized = raw.SerializeToString() + + # Parse message using the message set wire format. + proto = unittest_mset_pb2.TestMessageSet() + proto.MergeFromString(serialized) + + # Verify that the unknown extension is serialized unchanged + reserialized = proto.SerializeToString() + new_raw = unittest_mset_pb2.RawMessageSet() + new_raw.MergeFromString(reserialized) + self.assertEqual(raw, new_raw) + + # C++ implementation for proto2 does not currently take into account unknown + # fields when checking equality. + # + # TODO(haberman): fix this. + @basetest.unittest.skipIf( + api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, + 'C++ implementation does not expose unknown fields to Python') + def testEquals(self): + message = unittest_pb2.TestEmptyMessage() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message, message) + + self.all_fields.ClearField('optional_string') + message.ParseFromString(self.all_fields.SerializeToString()) + self.assertNotEqual(self.empty_message, message) + + @basetest.unittest.skipIf( api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, 'C++ implementation does not expose unknown fields to Python') -class UnknownFieldsTest(basetest.TestCase): +class UnknownFieldsAccessorsTest(basetest.TestCase): def setUp(self): self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -98,13 +170,6 @@ class UnknownFieldsTest(basetest.TestCase): value = self.GetField('optionalgroup') self.assertEqual(self.all_fields.optionalgroup, value) - def testSerialize(self): - data = self.empty_message.SerializeToString() - - # Don't use assertEqual because we don't want to dump raw binary data to - # stdout. - self.assertTrue(data == self.all_fields_data) - def testCopyFrom(self): message = unittest_pb2.TestEmptyMessage() message.CopyFrom(self.empty_message) @@ -132,51 +197,12 @@ class UnknownFieldsTest(basetest.TestCase): self.empty_message.Clear() self.assertEqual(0, len(self.empty_message._unknown_fields)) - def testByteSize(self): - self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) - def testUnknownExtensions(self): message = unittest_pb2.TestEmptyMessageWithExtensions() message.ParseFromString(self.all_fields_data) self.assertEqual(self.empty_message._unknown_fields, message._unknown_fields) - def testListFields(self): - # Make sure ListFields doesn't return unknown fields. - self.assertEqual(0, len(self.empty_message.ListFields())) - - def testSerializeMessageSetWireFormatUnknownExtension(self): - # Create a message using the message set wire format with an unknown - # message. - raw = unittest_mset_pb2.RawMessageSet() - - # Add an unknown extension. - item = raw.item.add() - item.type_id = 1545009 - message1 = unittest_mset_pb2.TestMessageSetExtension1() - message1.i = 12345 - item.message = message1.SerializeToString() - - serialized = raw.SerializeToString() - - # Parse message using the message set wire format. - proto = unittest_mset_pb2.TestMessageSet() - proto.MergeFromString(serialized) - - # Verify that the unknown extension is serialized unchanged - reserialized = proto.SerializeToString() - new_raw = unittest_mset_pb2.RawMessageSet() - new_raw.MergeFromString(reserialized) - self.assertEqual(raw, new_raw) - - def testEquals(self): - message = unittest_pb2.TestEmptyMessage() - message.ParseFromString(self.all_fields_data) - self.assertEqual(self.empty_message, message) - - self.all_fields.ClearField('optional_string') - message.ParseFromString(self.all_fields.SerializeToString()) - self.assertNotEqual(self.empty_message, message) @basetest.unittest.skipIf( diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index 7343c0b7..d9b90ddb 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -31,10 +31,15 @@ // Author: petar@google.com (Petar Petrov) #include +#include #include +#include #include +#include #include +#include +#include #include #include @@ -42,64 +47,928 @@ #if PY_MAJOR_VERSION >= 3 #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyString_Check PyUnicode_Check + #define PyString_InternFromString PyUnicode_InternFromString #define PyInt_FromLong PyLong_FromLong + #define PyInt_FromSize_t PyLong_FromSize_t #if PY_VERSION_HEX < 0x03030000 #error "Python 3.0 - 3.2 are not supported." - #else - #define PyString_AsString(ob) \ - (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob)) #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) #endif namespace google { namespace protobuf { namespace python { +PyObject* PyString_FromCppString(const string& str) { + return PyString_FromStringAndSize(str.c_str(), str.size()); +} + +// Check that the calling Python code is the global scope of a _pb2.py module. +// This function is used to support the current code generated by the proto +// compiler, which creates descriptors, then update some properties. +// For example: +// message_descriptor = Descriptor( +// name='Message', +// fields = [FieldDescriptor(name='field')] +// message_descriptor.fields[0].containing_type = message_descriptor +// +// This code is still executed, but the descriptors now have no other storage +// than the (const) C++ pointer, and are immutable. +// So we let this code pass, by simply ignoring the new value. +// +// From user code, descriptors still look immutable. +// +// TODO(amauryfa): Change the proto2 compiler to remove the assignments, and +// remove this hack. +bool _CalledFromGeneratedFile(int stacklevel) { + PyThreadState *state = PyThreadState_GET(); + if (state == NULL) { + return false; + } + PyFrameObject* frame = state->frame; + if (frame == NULL) { + return false; + } + while (stacklevel-- > 0) { + frame = frame->f_back; + if (frame == NULL) { + return false; + } + } + if (frame->f_globals != frame->f_locals) { + // Not at global module scope + return false; + } + + if (frame->f_code->co_filename == NULL) { + return false; + } + char* filename; + Py_ssize_t filename_size; + if (PyString_AsStringAndSize(frame->f_code->co_filename, + &filename, &filename_size) < 0) { + // filename is not a string. + PyErr_Clear(); + return false; + } + if (filename_size < 7) { + // filename is too short. + return false; + } + if (strcmp(&filename[filename_size - 7], "_pb2.py") != 0) { + // Filename is not ending with _pb2. + return false; + } + return true; +} + +// If the calling code is not a _pb2.py file, raise AttributeError. +// To be used in attribute setters. +static int CheckCalledFromGeneratedFile(const char* attr_name) { + if (_CalledFromGeneratedFile(0)) { + return 0; + } + PyErr_Format(PyExc_AttributeError, + "attribute is not writable: %s", attr_name); + return -1; +} + + +#ifndef PyVarObject_HEAD_INIT +#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, +#endif +#ifndef Py_TYPE +#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) +#endif + + +// Helper functions for descriptor objects. + +// Converts options into a Python protobuf, and cache the result. +// +// This is a bit tricky because options can contain extension fields defined in +// the same proto file. In this case the options parsed from the serialized_pb +// have unkown fields, and we need to parse them again. +// +// Always returns a new reference. +template +static PyObject* GetOrBuildOptions(const DescriptorClass *descriptor) { + hash_map* descriptor_options = + GetDescriptorPool()->descriptor_options; + // First search in the cache. + if (descriptor_options->find(descriptor) != descriptor_options->end()) { + PyObject *value = (*descriptor_options)[descriptor]; + Py_INCREF(value); + return value; + } + + // Build the Options object: get its Python class, and make a copy of the C++ + // read-only instance. + const Message& options(descriptor->options()); + const Descriptor *message_type = options.GetDescriptor(); + PyObject* message_class(cdescriptor_pool::GetMessageClass( + GetDescriptorPool(), message_type)); + if (message_class == NULL) { + PyErr_Format(PyExc_TypeError, "Could not retrieve class for Options: %s", + message_type->full_name().c_str()); + return NULL; + } + ScopedPyObjectPtr value(PyEval_CallObject(message_class, NULL)); + if (value == NULL) { + return NULL; + } + CMessage* cmsg = reinterpret_cast(value.get()); + + const Reflection* reflection = options.GetReflection(); + const UnknownFieldSet& unknown_fields(reflection->GetUnknownFields(options)); + if (unknown_fields.empty()) { + cmsg->message->CopyFrom(options); + } else { + // Reparse options string! XXX call cmessage::MergeFromString + string serialized; + options.SerializeToString(&serialized); + io::CodedInputStream input( + reinterpret_cast(serialized.c_str()), serialized.size()); + input.SetExtensionRegistry(GetDescriptorPool()->pool, + cmessage::GetMessageFactory()); + bool success = cmsg->message->MergePartialFromCodedStream(&input); + if (!success) { + PyErr_Format(PyExc_ValueError, "Error parsing Options message"); + return NULL; + } + } + + // Cache the result. + Py_INCREF(value); + (*GetDescriptorPool()->descriptor_options)[descriptor] = value.get(); + + return value.release(); +} + +// Copy the C++ descriptor to a Python message. +// The Python message is an instance of descriptor_pb2.DescriptorProto +// or similar. +template +static PyObject* CopyToPythonProto(const DescriptorClass *descriptor, + PyObject *target) { + const Descriptor* self_descriptor = + DescriptorProtoClass::default_instance().GetDescriptor(); + CMessage* message = reinterpret_cast(target); + if (!PyObject_TypeCheck(target, &CMessage_Type) || + message->message->GetDescriptor() != self_descriptor) { + PyErr_Format(PyExc_TypeError, "Not a %s message", + self_descriptor->full_name().c_str()); + return NULL; + } + cmessage::AssureWritable(message); + DescriptorProtoClass* descriptor_message = + static_cast(message->message); + descriptor->CopyTo(descriptor_message); + Py_RETURN_NONE; +} + +// All Descriptors classes share the same memory layout. +typedef struct PyBaseDescriptor { + PyObject_HEAD + + // Pointer to the C++ proto2 descriptor. + // Like all descriptors, it is owned by the global DescriptorPool. + const void* descriptor; +} PyBaseDescriptor; + + +// FileDescriptor structure "inherits" from the base descriptor. +typedef struct PyFileDescriptor { + PyBaseDescriptor base; + + // The cached version of serialized pb. Either NULL, or a Bytes string. + // We own the reference. + PyObject *serialized_pb; +} PyFileDescriptor; + + +namespace descriptor { + +// Creates or retrieve a Python descriptor of the specified type. +// Objects are interned: the same descriptor will return the same object if it +// was kept alive. +// Always return a new reference. +PyObject* NewInternedDescriptor(PyTypeObject* type, const void* descriptor) { + if (descriptor == NULL) { + PyErr_BadInternalCall(); + return NULL; + } + + // See if the object is in the map of interned descriptors + hash_map::iterator it = + GetDescriptorPool()->interned_descriptors->find(descriptor); + if (it != GetDescriptorPool()->interned_descriptors->end()) { + GOOGLE_DCHECK(Py_TYPE(it->second) == type); + Py_INCREF(it->second); + return it->second; + } + // Create a new descriptor object + PyBaseDescriptor* py_descriptor = PyObject_New( + PyBaseDescriptor, type); + if (py_descriptor == NULL) { + return NULL; + } + py_descriptor->descriptor = descriptor; + // and cache it. + GetDescriptorPool()->interned_descriptors->insert( + std::make_pair(descriptor, reinterpret_cast(py_descriptor))); + + return reinterpret_cast(py_descriptor); +} + +static void Dealloc(PyBaseDescriptor* self) { + // Remove from interned dictionary + GetDescriptorPool()->interned_descriptors->erase(self->descriptor); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +static PyGetSetDef Getters[] = { + {NULL} +}; + +PyTypeObject PyBaseDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "google.protobuf.internal." + "_message.DescriptorBase", // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + (destructor)Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + "Descriptors base class", // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + Getters, // tp_getset +}; + +} // namespace descriptor + +const void* PyDescriptor_AsVoidPtr(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &descriptor::PyBaseDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a BaseDescriptor"); + return NULL; + } + return reinterpret_cast(obj)->descriptor; +} + +namespace message_descriptor { + +// Unchecked accessor to the C++ pointer. +static const Descriptor* _GetDescriptor(PyBaseDescriptor* self) { + return reinterpret_cast(self->descriptor); +} + +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_New(_GetDescriptor(self)->file()); +} + +static PyObject* GetConcreteClass(PyBaseDescriptor* self, void *closure) { + PyObject* concrete_class(cdescriptor_pool::GetMessageClass( + GetDescriptorPool(), _GetDescriptor(self))); + Py_XINCREF(concrete_class); + return concrete_class; +} + +static int SetConcreteClass(PyBaseDescriptor *self, PyObject *value, + void *closure) { + // This attribute is also set from reflection.py. Check that it's actually a + // no-op. + if (value != cdescriptor_pool::GetMessageClass( + GetDescriptorPool(), _GetDescriptor(self))) { + PyErr_SetString(PyExc_AttributeError, "Cannot change _concrete_class"); + } + return 0; +} + +static PyObject* GetFieldsByName(PyBaseDescriptor* self, void *closure) { + return NewMessageFieldsByName(_GetDescriptor(self)); +} + +static PyObject* GetFieldsByNumber(PyBaseDescriptor* self, void *closure) { + return NewMessageFieldsByNumber(_GetDescriptor(self)); +} + +static PyObject* GetFieldsSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageFieldsSeq(_GetDescriptor(self)); +} + +static PyObject* GetNestedTypesByName(PyBaseDescriptor* self, void *closure) { + return NewMessageNestedTypesByName(_GetDescriptor(self)); +} + +static PyObject* GetNestedTypesSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageNestedTypesSeq(_GetDescriptor(self)); +} + +static PyObject* GetExtensionsByName(PyBaseDescriptor* self, void *closure) { + return NewMessageExtensionsByName(_GetDescriptor(self)); +} + +static PyObject* GetExtensions(PyBaseDescriptor* self, void *closure) { + return NewMessageExtensionsSeq(_GetDescriptor(self)); +} + +static PyObject* GetEnumsSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageEnumsSeq(_GetDescriptor(self)); +} + +static PyObject* GetEnumTypesByName(PyBaseDescriptor* self, void *closure) { + return NewMessageEnumsByName(_GetDescriptor(self)); +} + +static PyObject* GetEnumValuesByName(PyBaseDescriptor* self, void *closure) { + return NewMessageEnumValuesByName(_GetDescriptor(self)); +} + +static PyObject* GetOneofsByName(PyBaseDescriptor* self, void *closure) { + return NewMessageOneofsByName(_GetDescriptor(self)); +} + +static PyObject* GetOneofsSeq(PyBaseDescriptor* self, void *closure) { + return NewMessageOneofsSeq(_GetDescriptor(self)); +} + +static PyObject* IsExtendable(PyBaseDescriptor *self, void *closure) { + if (_GetDescriptor(self)->extension_range_count() > 0) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_New(containing_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_type"); +} + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const MessageOptions& options(_GetDescriptor(self)->options()); + if (&options != &MessageOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto(_GetDescriptor(self), target); +} + +static PyObject* EnumValueName(PyBaseDescriptor *self, PyObject *args) { + const char *enum_name; + int number; + if (!PyArg_ParseTuple(args, "si", &enum_name, &number)) + return NULL; + const EnumDescriptor *enum_type = + _GetDescriptor(self)->FindEnumTypeByName(enum_name); + if (enum_type == NULL) { + PyErr_SetString(PyExc_KeyError, enum_name); + return NULL; + } + const EnumValueDescriptor *enum_value = + enum_type->FindValueByNumber(number); + if (enum_value == NULL) { + PyErr_Format(PyExc_KeyError, "%d", number); + return NULL; + } + return PyString_FromCppString(enum_value->name()); +} + +static PyObject* GetSyntax(PyBaseDescriptor *self, void *closure) { + return PyString_InternFromString( + FileDescriptor::SyntaxName(_GetDescriptor(self)->file()->syntax())); +} + +static PyGetSetDef Getters[] = { + { C("name"), (getter)GetName, NULL, "Last name", NULL}, + { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, + { C("_concrete_class"), (getter)GetConcreteClass, (setter)SetConcreteClass, "concrete class", NULL}, + { C("file"), (getter)GetFile, NULL, "File descriptor", NULL}, + + { C("fields"), (getter)GetFieldsSeq, NULL, "Fields sequence", NULL}, + { C("fields_by_name"), (getter)GetFieldsByName, NULL, "Fields by name", NULL}, + { C("fields_by_number"), (getter)GetFieldsByNumber, NULL, "Fields by number", NULL}, + { C("nested_types"), (getter)GetNestedTypesSeq, NULL, "Nested types sequence", NULL}, + { C("nested_types_by_name"), (getter)GetNestedTypesByName, NULL, "Nested types by name", NULL}, + { C("extensions"), (getter)GetExtensions, NULL, "Extensions Sequence", NULL}, + { C("extensions_by_name"), (getter)GetExtensionsByName, NULL, "Extensions by name", NULL}, + { C("enum_types"), (getter)GetEnumsSeq, NULL, "Enum sequence", NULL}, + { C("enum_types_by_name"), (getter)GetEnumTypesByName, NULL, "Enum types by name", NULL}, + { C("enum_values_by_name"), (getter)GetEnumValuesByName, NULL, "Enum values by name", NULL}, + { C("oneofs_by_name"), (getter)GetOneofsByName, NULL, "Oneofs by name", NULL}, + { C("oneofs"), (getter)GetOneofsSeq, NULL, "Oneofs by name", NULL}, + { C("containing_type"), (getter)GetContainingType, (setter)SetContainingType, "Containing type", NULL}, + { C("is_extendable"), (getter)IsExtendable, (setter)NULL, NULL, NULL}, + { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, + { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, + { C("syntax"), (getter)GetSyntax, (setter)NULL, "Syntax", NULL}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + { "EnumValueName", (PyCFunction)EnumValueName, METH_VARARGS, }, + {NULL} +}; + +} // namespace message_descriptor + +PyTypeObject PyMessageDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + C("google.protobuf.internal." + "_message.MessageDescriptor"), // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + C("A Message Descriptor"), // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + message_descriptor::Methods, // tp_methods + 0, // tp_members + message_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; + +PyObject* PyMessageDescriptor_New( + const Descriptor* message_descriptor) { + return descriptor::NewInternedDescriptor( + &PyMessageDescriptor_Type, message_descriptor); +} + +const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyMessageDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a MessageDescriptor"); + return NULL; + } + return reinterpret_cast( + reinterpret_cast(obj)->descriptor); +} + +namespace field_descriptor { + +// Unchecked accessor to the C++ pointer. +static const FieldDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast(self->descriptor); +} + +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} + +static PyObject* GetName(PyBaseDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetType(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->type()); +} + +static PyObject* GetCppType(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->cpp_type()); +} + +static PyObject* GetLabel(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->label()); +} + +static PyObject* GetNumber(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->number()); +} + +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); +} + +static PyObject* GetID(PyBaseDescriptor *self, void *closure) { + return PyLong_FromVoidPtr(self); +} + +static PyObject* IsExtension(PyBaseDescriptor *self, void *closure) { + return PyBool_FromLong(_GetDescriptor(self)->is_extension()); +} + +static PyObject* HasDefaultValue(PyBaseDescriptor *self, void *closure) { + return PyBool_FromLong(_GetDescriptor(self)->has_default_value()); +} + +static PyObject* GetDefaultValue(PyBaseDescriptor *self, void *closure) { + PyObject *result; + + switch (_GetDescriptor(self)->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: { + int32 value = _GetDescriptor(self)->default_value_int32(); + result = PyInt_FromLong(value); + break; + } + case FieldDescriptor::CPPTYPE_INT64: { + int64 value = _GetDescriptor(self)->default_value_int64(); + result = PyLong_FromLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT32: { + uint32 value = _GetDescriptor(self)->default_value_uint32(); + result = PyInt_FromSize_t(value); + break; + } + case FieldDescriptor::CPPTYPE_UINT64: { + uint64 value = _GetDescriptor(self)->default_value_uint64(); + result = PyLong_FromUnsignedLongLong(value); + break; + } + case FieldDescriptor::CPPTYPE_FLOAT: { + float value = _GetDescriptor(self)->default_value_float(); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_DOUBLE: { + double value = _GetDescriptor(self)->default_value_double(); + result = PyFloat_FromDouble(value); + break; + } + case FieldDescriptor::CPPTYPE_BOOL: { + bool value = _GetDescriptor(self)->default_value_bool(); + result = PyBool_FromLong(value); + break; + } + case FieldDescriptor::CPPTYPE_STRING: { + string value = _GetDescriptor(self)->default_value_string(); + result = ToStringObject(_GetDescriptor(self), value); + break; + } + case FieldDescriptor::CPPTYPE_ENUM: { + const EnumValueDescriptor* value = + _GetDescriptor(self)->default_value_enum(); + result = PyInt_FromLong(value->number()); + break; + } + default: + PyErr_Format(PyExc_NotImplementedError, "default value for %s", + _GetDescriptor(self)->full_name().c_str()); + return NULL; + } + return result; +} + +static PyObject* GetCDescriptor(PyObject *self, void *closure) { + Py_INCREF(self); + return self; +} + +static PyObject *GetEnumType(PyBaseDescriptor *self, void *closure) { + const EnumDescriptor* enum_type = _GetDescriptor(self)->enum_type(); + if (enum_type) { + return PyEnumDescriptor_New(enum_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetEnumType(PyBaseDescriptor *self, PyObject *value, void *closure) { + return CheckCalledFromGeneratedFile("enum_type"); +} + +static PyObject *GetMessageType(PyBaseDescriptor *self, void *closure) { + const Descriptor* message_type = _GetDescriptor(self)->message_type(); + if (message_type) { + return PyMessageDescriptor_New(message_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetMessageType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("message_type"); +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_New(containing_type); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_type"); +} + +static PyObject* GetExtensionScope(PyBaseDescriptor *self, void *closure) { + const Descriptor* extension_scope = + _GetDescriptor(self)->extension_scope(); + if (extension_scope) { + return PyMessageDescriptor_New(extension_scope); + } else { + Py_RETURN_NONE; + } +} + +static PyObject* GetContainingOneof(PyBaseDescriptor *self, void *closure) { + const OneofDescriptor* containing_oneof = + _GetDescriptor(self)->containing_oneof(); + if (containing_oneof) { + return PyOneofDescriptor_New(containing_oneof); + } else { + Py_RETURN_NONE; + } +} + +static int SetContainingOneof(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_oneof"); +} + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const FieldOptions& options(_GetDescriptor(self)->options()); + if (&options != &FieldOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} + + +static PyGetSetDef Getters[] = { + { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, + { C("name"), (getter)GetName, NULL, "Unqualified name", NULL}, + { C("type"), (getter)GetType, NULL, "C++ Type", NULL}, + { C("cpp_type"), (getter)GetCppType, NULL, "C++ Type", NULL}, + { C("label"), (getter)GetLabel, NULL, "Label", NULL}, + { C("number"), (getter)GetNumber, NULL, "Number", NULL}, + { C("index"), (getter)GetIndex, NULL, "Index", NULL}, + { C("default_value"), (getter)GetDefaultValue, NULL, "Default Value", NULL}, + { C("has_default_value"), (getter)HasDefaultValue, NULL, NULL, NULL}, + { C("is_extension"), (getter)IsExtension, NULL, "ID", NULL}, + { C("id"), (getter)GetID, NULL, "ID", NULL}, + { C("_cdescriptor"), (getter)GetCDescriptor, NULL, "HAACK REMOVE ME", NULL}, + + { C("message_type"), (getter)GetMessageType, (setter)SetMessageType, "Message type", NULL}, + { C("enum_type"), (getter)GetEnumType, (setter)SetEnumType, "Enum type", NULL}, + { C("containing_type"), (getter)GetContainingType, (setter)SetContainingType, "Containing type", NULL}, + { C("extension_scope"), (getter)GetExtensionScope, (setter)NULL, "Extension scope", NULL}, + { C("containing_oneof"), (getter)GetContainingOneof, (setter)SetContainingOneof, "Containing oneof", NULL}, + { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, + { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + {NULL} +}; + +} // namespace field_descriptor + +PyTypeObject PyFieldDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + C("google.protobuf.internal." + "_message.FieldDescriptor"), // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + C("A Field Descriptor"), // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + field_descriptor::Methods, // tp_methods + 0, // tp_members + field_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base +}; -#ifndef PyVarObject_HEAD_INIT -#define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, -#endif -#ifndef Py_TYPE -#define Py_TYPE(ob) (((PyObject*)(ob))->ob_type) -#endif +PyObject* PyFieldDescriptor_New( + const FieldDescriptor* field_descriptor) { + return descriptor::NewInternedDescriptor( + &PyFieldDescriptor_Type, field_descriptor); +} + +const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj) { + if (!PyObject_TypeCheck(obj, &PyFieldDescriptor_Type)) { + PyErr_SetString(PyExc_TypeError, "Not a FieldDescriptor"); + return NULL; + } + return reinterpret_cast( + reinterpret_cast(obj)->descriptor); +} +namespace enum_descriptor { -static google::protobuf::DescriptorPool* g_descriptor_pool = NULL; +// Unchecked accessor to the C++ pointer. +static const EnumDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast(self->descriptor); +} -namespace cmessage_descriptor { +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} -static void Dealloc(CMessageDescriptor* self) { - Py_TYPE(self)->tp_free(reinterpret_cast(self)); +static PyObject* GetName(PyBaseDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} + +static PyObject* GetFile(PyBaseDescriptor *self, void *closure) { + return PyFileDescriptor_New(_GetDescriptor(self)->file()); +} + +static PyObject* GetEnumvaluesByName(PyBaseDescriptor* self, void *closure) { + return NewEnumValuesByName(_GetDescriptor(self)); +} + +static PyObject* GetEnumvaluesByNumber(PyBaseDescriptor* self, void *closure) { + return NewEnumValuesByNumber(_GetDescriptor(self)); +} + +static PyObject* GetEnumvaluesSeq(PyBaseDescriptor* self, void *closure) { + return NewEnumValuesSeq(_GetDescriptor(self)); +} + +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_New(containing_type); + } else { + Py_RETURN_NONE; + } } -static PyObject* GetFullName(CMessageDescriptor* self, void *closure) { - return PyString_FromStringAndSize( - self->descriptor->full_name().c_str(), - self->descriptor->full_name().size()); +static int SetContainingType(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("containing_type"); } -static PyObject* GetName(CMessageDescriptor *self, void *closure) { - return PyString_FromStringAndSize( - self->descriptor->name().c_str(), - self->descriptor->name().size()); + +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const EnumOptions& options(_GetDescriptor(self)->options()); + if (&options != &EnumOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); } +static PyObject* CopyToProto(PyBaseDescriptor *self, PyObject *target) { + return CopyToPythonProto(_GetDescriptor(self), target); +} + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + {NULL} +}; + static PyGetSetDef Getters[] = { { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, - { C("name"), (getter)GetName, NULL, "Unqualified name", NULL}, + { C("name"), (getter)GetName, NULL, "last name", NULL}, + { C("file"), (getter)GetFile, NULL, "File descriptor", NULL}, + { C("values"), (getter)GetEnumvaluesSeq, NULL, "values", NULL}, + { C("values_by_name"), (getter)GetEnumvaluesByName, NULL, "Enumvalues by name", NULL}, + { C("values_by_number"), (getter)GetEnumvaluesByNumber, NULL, "Enumvalues by number", NULL}, + + { C("containing_type"), (getter)GetContainingType, (setter)SetContainingType, "Containing type", NULL}, + { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, + { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, {NULL} }; -} // namespace cmessage_descriptor +} // namespace enum_descriptor -PyTypeObject CMessageDescriptor_Type = { +PyTypeObject PyEnumDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) C("google.protobuf.internal." - "_net_proto2___python." - "CMessageDescriptor"), // tp_name - sizeof(CMessageDescriptor), // tp_basicsize + "_message.EnumDescriptor"), // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize 0, // tp_itemsize - (destructor)cmessage_descriptor::Dealloc, // tp_dealloc + 0, // tp_dealloc 0, // tp_print 0, // tp_getattr 0, // tp_setattr @@ -115,77 +984,97 @@ PyTypeObject CMessageDescriptor_Type = { 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT, // tp_flags - C("A Message Descriptor"), // tp_doc + C("A Enum Descriptor"), // tp_doc 0, // tp_traverse 0, // tp_clear 0, // tp_richcompare 0, // tp_weaklistoffset 0, // tp_iter 0, // tp_iternext - 0, // tp_methods + enum_descriptor::Methods, // tp_getset 0, // tp_members - cmessage_descriptor::Getters, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init - PyType_GenericAlloc, // tp_alloc - PyType_GenericNew, // tp_new - PyObject_Del, // tp_free + enum_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base }; +PyObject* PyEnumDescriptor_New( + const EnumDescriptor* enum_descriptor) { + return descriptor::NewInternedDescriptor( + &PyEnumDescriptor_Type, enum_descriptor); +} -namespace cfield_descriptor { +namespace enumvalue_descriptor { -static void Dealloc(CFieldDescriptor* self) { - Py_TYPE(self)->tp_free(reinterpret_cast(self)); +// Unchecked accessor to the C++ pointer. +static const EnumValueDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast(self->descriptor); } -static PyObject* GetFullName(CFieldDescriptor* self, void *closure) { - return PyString_FromStringAndSize( - self->descriptor->full_name().c_str(), - self->descriptor->full_name().size()); +static PyObject* GetName(PyBaseDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); } -static PyObject* GetName(CFieldDescriptor *self, void *closure) { - return PyString_FromStringAndSize( - self->descriptor->name().c_str(), - self->descriptor->name().size()); +static PyObject* GetNumber(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->number()); } -static PyObject* GetCppType(CFieldDescriptor *self, void *closure) { - return PyInt_FromLong(self->descriptor->cpp_type()); +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); } -static PyObject* GetLabel(CFieldDescriptor *self, void *closure) { - return PyInt_FromLong(self->descriptor->label()); +static PyObject* GetType(PyBaseDescriptor *self, void *closure) { + return PyEnumDescriptor_New(_GetDescriptor(self)->type()); } -static PyObject* GetID(CFieldDescriptor *self, void *closure) { - return PyLong_FromVoidPtr(self); +static PyObject* GetHasOptions(PyBaseDescriptor *self, void *closure) { + const EnumValueOptions& options(_GetDescriptor(self)->options()); + if (&options != &EnumValueOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} +static int SetHasOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} + +static PyObject* GetOptions(PyBaseDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); +} + +static int SetOptions(PyBaseDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); } + static PyGetSetDef Getters[] = { - { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, - { C("name"), (getter)GetName, NULL, "Unqualified name", NULL}, - { C("cpp_type"), (getter)GetCppType, NULL, "C++ Type", NULL}, - { C("label"), (getter)GetLabel, NULL, "Label", NULL}, - { C("id"), (getter)GetID, NULL, "ID", NULL}, + { C("name"), (getter)GetName, NULL, "name", NULL}, + { C("number"), (getter)GetNumber, NULL, "number", NULL}, + { C("index"), (getter)GetIndex, NULL, "index", NULL}, + { C("type"), (getter)GetType, NULL, "index", NULL}, + + { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, + { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, {NULL} }; -} // namespace cfield_descriptor +} // namespace enumvalue_descriptor -PyTypeObject CFieldDescriptor_Type = { +PyTypeObject PyEnumValueDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) C("google.protobuf.internal." - "_net_proto2___python." - "CFieldDescriptor"), // tp_name - sizeof(CFieldDescriptor), // tp_basicsize + "_message.EnumValueDescriptor"), // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize 0, // tp_itemsize - (destructor)cfield_descriptor::Dealloc, // tp_dealloc + 0, // tp_dealloc 0, // tp_print 0, // tp_getattr 0, // tp_setattr @@ -201,304 +1090,359 @@ PyTypeObject CFieldDescriptor_Type = { 0, // tp_setattro 0, // tp_as_buffer Py_TPFLAGS_DEFAULT, // tp_flags - C("A Field Descriptor"), // tp_doc + C("A EnumValue Descriptor"), // tp_doc 0, // tp_traverse 0, // tp_clear 0, // tp_richcompare 0, // tp_weaklistoffset 0, // tp_iter 0, // tp_iternext - 0, // tp_methods + enumvalue_descriptor::Methods, // tp_methods 0, // tp_members - cfield_descriptor::Getters, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init - PyType_GenericAlloc, // tp_alloc - PyType_GenericNew, // tp_new - PyObject_Del, // tp_free + enumvalue_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base }; +PyObject* PyEnumValueDescriptor_New( + const EnumValueDescriptor* enumvalue_descriptor) { + return descriptor::NewInternedDescriptor( + &PyEnumValueDescriptor_Type, enumvalue_descriptor); +} + +namespace file_descriptor { + +// Unchecked accessor to the C++ pointer. +static const FileDescriptor* _GetDescriptor(PyFileDescriptor *self) { + return reinterpret_cast(self->base.descriptor); +} + +static void Dealloc(PyFileDescriptor* self) { + Py_XDECREF(self->serialized_pb); + descriptor::Dealloc(&self->base); +} + +static PyObject* GetName(PyFileDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} -namespace cdescriptor_pool { +static PyObject* GetPackage(PyFileDescriptor *self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->package()); +} -PyDescriptorPool* NewDescriptorPool() { - PyDescriptorPool* cdescriptor_pool = PyObject_New( - PyDescriptorPool, &PyDescriptorPool_Type); - if (cdescriptor_pool == NULL) { +static PyObject* GetSerializedPb(PyFileDescriptor *self, void *closure) { + PyObject *serialized_pb = self->serialized_pb; + if (serialized_pb != NULL) { + Py_INCREF(serialized_pb); + return serialized_pb; + } + FileDescriptorProto file_proto; + _GetDescriptor(self)->CopyTo(&file_proto); + string contents; + file_proto.SerializePartialToString(&contents); + self->serialized_pb = PyBytes_FromStringAndSize( + contents.c_str(), contents.size()); + if (self->serialized_pb == NULL) { return NULL; } + Py_INCREF(self->serialized_pb); + return self->serialized_pb; +} - // Build a DescriptorPool for messages only declared in Python libraries. - // generated_pool() contains all messages linked in C++ libraries, and is used - // as underlay. - cdescriptor_pool->pool = new google::protobuf::DescriptorPool( - google::protobuf::DescriptorPool::generated_pool()); +static PyObject* GetMessageTypesByName(PyFileDescriptor* self, void *closure) { + return NewFileMessageTypesByName(_GetDescriptor(self)); +} - // TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same - // storage. - cdescriptor_pool->classes_by_descriptor = - new PyDescriptorPool::ClassesByMessageMap(); +static PyObject* GetEnumTypesByName(PyFileDescriptor* self, void *closure) { + return NewFileEnumTypesByName(_GetDescriptor(self)); +} - return cdescriptor_pool; +static PyObject* GetExtensionsByName(PyFileDescriptor* self, void *closure) { + return NewFileExtensionsByName(_GetDescriptor(self)); } -static void Dealloc(PyDescriptorPool* self) { - typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; - for (iterator it = self->classes_by_descriptor->begin(); - it != self->classes_by_descriptor->end(); ++it) { - Py_DECREF(it->second); - } - delete self->classes_by_descriptor; - Py_TYPE(self)->tp_free(reinterpret_cast(self)); +static PyObject* GetDependencies(PyFileDescriptor* self, void *closure) { + return NewFileDependencies(_GetDescriptor(self)); } -const google::protobuf::Descriptor* FindMessageTypeByName(PyDescriptorPool* self, - const string& name) { - return self->pool->FindMessageTypeByName(name); +static PyObject* GetPublicDependencies(PyFileDescriptor* self, void *closure) { + return NewFilePublicDependencies(_GetDescriptor(self)); } -static PyObject* NewCMessageDescriptor( - const google::protobuf::Descriptor* message_descriptor) { - CMessageDescriptor* cmessage_descriptor = PyObject_New( - CMessageDescriptor, &CMessageDescriptor_Type); - if (cmessage_descriptor == NULL) { - return NULL; +static PyObject* GetHasOptions(PyFileDescriptor *self, void *closure) { + const FileOptions& options(_GetDescriptor(self)->options()); + if (&options != &FileOptions::default_instance()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; } - cmessage_descriptor->descriptor = message_descriptor; +} +static int SetHasOptions(PyFileDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("has_options"); +} - return reinterpret_cast(cmessage_descriptor); +static PyObject* GetOptions(PyFileDescriptor *self) { + return GetOrBuildOptions(_GetDescriptor(self)); } -static PyObject* NewCFieldDescriptor( - const google::protobuf::FieldDescriptor* field_descriptor) { - CFieldDescriptor* cfield_descriptor = PyObject_New( - CFieldDescriptor, &CFieldDescriptor_Type); - if (cfield_descriptor == NULL) { - return NULL; - } - cfield_descriptor->descriptor = field_descriptor; +static int SetOptions(PyFileDescriptor *self, PyObject *value, + void *closure) { + return CheckCalledFromGeneratedFile("_options"); +} - return reinterpret_cast(cfield_descriptor); +static PyObject* GetSyntax(PyFileDescriptor *self, void *closure) { + return PyString_InternFromString( + FileDescriptor::SyntaxName(_GetDescriptor(self)->syntax())); } -// Add a message class to our database. -const google::protobuf::Descriptor* RegisterMessageClass( - PyDescriptorPool* self, PyObject *message_class, PyObject* descriptor) { - ScopedPyObjectPtr full_message_name( - PyObject_GetAttrString(descriptor, "full_name")); - const char* full_name = PyString_AsString(full_message_name); - if (full_name == NULL) { - return NULL; - } - const Descriptor *message_descriptor = - self->pool->FindMessageTypeByName(full_name); - if (!message_descriptor) { - PyErr_Format(PyExc_TypeError, "Could not find C++ descriptor for '%s'", - full_name); - return NULL; - } - Py_INCREF(message_class); - typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; - std::pair ret = self->classes_by_descriptor->insert( - make_pair(message_descriptor, message_class)); - if (!ret.second) { - // Update case: DECREF the previous value. - Py_DECREF(ret.first->second); - ret.first->second = message_class; - } +static PyObject* CopyToProto(PyFileDescriptor *self, PyObject *target) { + return CopyToPythonProto(_GetDescriptor(self), target); +} - // Also add the C++ descriptor to the Python descriptor class. - ScopedPyObjectPtr cdescriptor(NewCMessageDescriptor(message_descriptor)); - if (cdescriptor == NULL) { - return NULL; - } - if (PyObject_SetAttrString( - descriptor, "_cdescriptor", cdescriptor) < 0) { - return NULL; - } - return message_descriptor; +static PyGetSetDef Getters[] = { + { C("name"), (getter)GetName, NULL, "name", NULL}, + { C("package"), (getter)GetPackage, NULL, "package", NULL}, + { C("serialized_pb"), (getter)GetSerializedPb, NULL, NULL, NULL}, + { C("message_types_by_name"), (getter)GetMessageTypesByName, NULL, "Messages by name", NULL}, + { C("enum_types_by_name"), (getter)GetEnumTypesByName, NULL, "Enums by name", NULL}, + { C("extensions_by_name"), (getter)GetExtensionsByName, NULL, "Extensions by name", NULL}, + { C("dependencies"), (getter)GetDependencies, NULL, "Dependencies", NULL}, + { C("public_dependencies"), (getter)GetPublicDependencies, NULL, "Dependencies", NULL}, + + { C("has_options"), (getter)GetHasOptions, (setter)SetHasOptions, "Has Options", NULL}, + { C("_options"), (getter)NULL, (setter)SetOptions, "Options", NULL}, + { C("syntax"), (getter)GetSyntax, (setter)NULL, "Syntax", NULL}, + {NULL} +}; + +static PyMethodDef Methods[] = { + { "GetOptions", (PyCFunction)GetOptions, METH_NOARGS, }, + { "CopyToProto", (PyCFunction)CopyToProto, METH_O, }, + {NULL} +}; + +} // namespace file_descriptor + +PyTypeObject PyFileDescriptor_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + C("google.protobuf.internal." + "_message.FileDescriptor"), // tp_name + sizeof(PyFileDescriptor), // tp_basicsize + 0, // tp_itemsize + (destructor)file_descriptor::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + C("A File Descriptor"), // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + file_descriptor::Methods, // tp_methods + 0, // tp_members + file_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + PyType_GenericAlloc, // tp_alloc + PyType_GenericNew, // tp_new + PyObject_Del, // tp_free +}; + +PyObject* PyFileDescriptor_New(const FileDescriptor* file_descriptor) { + return descriptor::NewInternedDescriptor( + &PyFileDescriptor_Type, file_descriptor); } -// Retrieve the message class added to our database. -PyObject *GetMessageClass(PyDescriptorPool* self, - const Descriptor *message_descriptor) { - typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; - iterator ret = self->classes_by_descriptor->find(message_descriptor); - if (ret == self->classes_by_descriptor->end()) { - PyErr_Format(PyExc_TypeError, "No message class registered for '%s'", - message_descriptor->full_name().c_str()); +PyObject* PyFileDescriptor_NewWithPb( + const FileDescriptor* file_descriptor, PyObject *serialized_pb) { + PyObject* py_descriptor = PyFileDescriptor_New(file_descriptor); + if (py_descriptor == NULL) { return NULL; - } else { - return ret->second; } + if (serialized_pb != NULL) { + PyFileDescriptor* cfile_descriptor = + reinterpret_cast(py_descriptor); + Py_XINCREF(serialized_pb); + cfile_descriptor->serialized_pb = serialized_pb; + } + + return py_descriptor; } -PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* name) { - const char* full_field_name = PyString_AsString(name); - if (full_field_name == NULL) { - return NULL; - } +namespace oneof_descriptor { - const google::protobuf::FieldDescriptor* field_descriptor = NULL; +// Unchecked accessor to the C++ pointer. +static const OneofDescriptor* _GetDescriptor( + PyBaseDescriptor *self) { + return reinterpret_cast(self->descriptor); +} - field_descriptor = self->pool->FindFieldByName(full_field_name); +static PyObject* GetName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->name()); +} - if (field_descriptor == NULL) { - PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", - full_field_name); - return NULL; - } +static PyObject* GetFullName(PyBaseDescriptor* self, void *closure) { + return PyString_FromCppString(_GetDescriptor(self)->full_name()); +} - return NewCFieldDescriptor(field_descriptor); +static PyObject* GetIndex(PyBaseDescriptor *self, void *closure) { + return PyInt_FromLong(_GetDescriptor(self)->index()); } -PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) { - const char* full_field_name = PyString_AsString(arg); - if (full_field_name == NULL) { - return NULL; - } +static PyObject* GetFields(PyBaseDescriptor* self, void *closure) { + return NewOneofFieldsSeq(_GetDescriptor(self)); +} - const google::protobuf::FieldDescriptor* field_descriptor = - self->pool->FindExtensionByName(full_field_name); - if (field_descriptor == NULL) { - PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", - full_field_name); - return NULL; +static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { + const Descriptor* containing_type = + _GetDescriptor(self)->containing_type(); + if (containing_type) { + return PyMessageDescriptor_New(containing_type); + } else { + Py_RETURN_NONE; } - - return NewCFieldDescriptor(field_descriptor); } -static PyMethodDef Methods[] = { - { C("FindFieldByName"), - (PyCFunction)FindFieldByName, - METH_O, - C("Searches for a field descriptor by full name.") }, - { C("FindExtensionByName"), - (PyCFunction)FindExtensionByName, - METH_O, - C("Searches for extension descriptor by full name.") }, +static PyGetSetDef Getters[] = { + { C("name"), (getter)GetName, NULL, "Name", NULL}, + { C("full_name"), (getter)GetFullName, NULL, "Full name", NULL}, + { C("index"), (getter)GetIndex, NULL, "Index", NULL}, + + { C("containing_type"), (getter)GetContainingType, NULL, "Containing type", NULL}, + { C("fields"), (getter)GetFields, NULL, "Fields", NULL}, {NULL} }; -} // namespace cdescriptor_pool +} // namespace oneof_descriptor -PyTypeObject PyDescriptorPool_Type = { +PyTypeObject PyOneofDescriptor_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) C("google.protobuf.internal." - "_net_proto2___python." - "CFieldDescriptor"), // tp_name - sizeof(PyDescriptorPool), // tp_basicsize - 0, // tp_itemsize - (destructor)cdescriptor_pool::Dealloc, // tp_dealloc - 0, // tp_print - 0, // tp_getattr - 0, // tp_setattr - 0, // tp_compare - 0, // tp_repr - 0, // tp_as_number - 0, // tp_as_sequence - 0, // tp_as_mapping - 0, // tp_hash - 0, // tp_call - 0, // tp_str - 0, // tp_getattro - 0, // tp_setattro - 0, // tp_as_buffer - Py_TPFLAGS_DEFAULT, // tp_flags - C("A Descriptor Pool"), // tp_doc - 0, // tp_traverse - 0, // tp_clear - 0, // tp_richcompare - 0, // tp_weaklistoffset - 0, // tp_iter - 0, // tp_iternext - cdescriptor_pool::Methods, // tp_methods - 0, // tp_members - 0, // tp_getset - 0, // tp_base - 0, // tp_dict - 0, // tp_descr_get - 0, // tp_descr_set - 0, // tp_dictoffset - 0, // tp_init - 0, // tp_alloc - 0, // tp_new - PyObject_Del, // tp_free + "_message.OneofDescriptor"), // tp_name + sizeof(PyBaseDescriptor), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + C("A Oneof Descriptor"), // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + 0, // tp_methods + 0, // tp_members + oneof_descriptor::Getters, // tp_getset + &descriptor::PyBaseDescriptor_Type, // tp_base }; +PyObject* PyOneofDescriptor_New( + const OneofDescriptor* oneof_descriptor) { + return descriptor::NewInternedDescriptor( + &PyOneofDescriptor_Type, oneof_descriptor); +} -// Collects errors that occur during proto file building to allow them to be -// propagated in the python exception instead of only living in ERROR logs. -class BuildFileErrorCollector : public google::protobuf::DescriptorPool::ErrorCollector { - public: - BuildFileErrorCollector() : error_message(""), had_errors(false) {} - - void AddError(const string& filename, const string& element_name, - const Message* descriptor, ErrorLocation location, - const string& message) { - // Replicates the logging behavior that happens in the C++ implementation - // when an error collector is not passed in. - if (!had_errors) { - error_message += - ("Invalid proto descriptor for file \"" + filename + "\":\n"); +// Add a enum values to a type dictionary. +static bool AddEnumValues(PyTypeObject *type, + const EnumDescriptor* enum_descriptor) { + for (int i = 0; i < enum_descriptor->value_count(); ++i) { + const EnumValueDescriptor* value = enum_descriptor->value(i); + ScopedPyObjectPtr obj(PyInt_FromLong(value->number())); + if (obj == NULL) { + return false; + } + if (PyDict_SetItemString(type->tp_dict, value->name().c_str(), obj) < 0) { + return false; } - // As this only happens on failure and will result in the program not - // running at all, no effort is made to optimize this string manipulation. - error_message += (" " + element_name + ": " + message + "\n"); } + return true; +} - string error_message; - bool had_errors; -}; +static bool AddIntConstant(PyTypeObject *type, const char* name, int value) { + ScopedPyObjectPtr obj(PyInt_FromLong(value)); + if (PyDict_SetItemString(type->tp_dict, name, obj) < 0) { + return false; + } + return true; +} -PyObject* Python_BuildFile(PyObject* ignored, PyObject* arg) { - char* message_type; - Py_ssize_t message_len; - if (PyBytes_AsStringAndSize(arg, &message_type, &message_len) < 0) { - return NULL; - } +bool InitDescriptor() { + if (PyType_Ready(&PyMessageDescriptor_Type) < 0) + return false; - google::protobuf::FileDescriptorProto file_proto; - if (!file_proto.ParseFromArray(message_type, message_len)) { - PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!"); - return NULL; - } + if (PyType_Ready(&PyFieldDescriptor_Type) < 0) + return false; - // If the file was already part of a C++ library, all its descriptors are in - // the underlying pool. No need to do anything else. - if (google::protobuf::DescriptorPool::generated_pool()->FindFileByName( - file_proto.name()) != NULL) { - Py_RETURN_NONE; + if (!AddEnumValues(&PyFieldDescriptor_Type, + FieldDescriptorProto::Label_descriptor())) { + return false; } - - BuildFileErrorCollector error_collector; - const google::protobuf::FileDescriptor* descriptor = - GetDescriptorPool()->pool->BuildFileCollectingErrors(file_proto, - &error_collector); - if (descriptor == NULL) { - PyErr_Format(PyExc_TypeError, - "Couldn't build proto file into descriptor pool!\n%s", - error_collector.error_message.c_str()); - return NULL; + if (!AddEnumValues(&PyFieldDescriptor_Type, + FieldDescriptorProto::Type_descriptor())) { + return false; + } +#define ADD_FIELDDESC_CONSTANT(NAME) AddIntConstant( \ + &PyFieldDescriptor_Type, #NAME, FieldDescriptor::NAME) + if (!ADD_FIELDDESC_CONSTANT(CPPTYPE_INT32) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_INT64) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_UINT32) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_UINT64) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_DOUBLE) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_FLOAT) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_BOOL) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_ENUM) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_STRING) || + !ADD_FIELDDESC_CONSTANT(CPPTYPE_MESSAGE)) { + return false; } +#undef ADD_FIELDDESC_CONSTANT - Py_RETURN_NONE; -} + if (PyType_Ready(&PyEnumDescriptor_Type) < 0) + return false; -bool InitDescriptor() { - if (PyType_Ready(&CMessageDescriptor_Type) < 0) + if (PyType_Ready(&PyEnumValueDescriptor_Type) < 0) + return false; + + if (PyType_Ready(&PyFileDescriptor_Type) < 0) return false; - if (PyType_Ready(&CFieldDescriptor_Type) < 0) + + if (PyType_Ready(&PyOneofDescriptor_Type) < 0) return false; - PyDescriptorPool_Type.tp_new = PyType_GenericNew; - if (PyType_Ready(&PyDescriptorPool_Type) < 0) + if (!InitDescriptorMappingTypes()) return false; return true; diff --git a/python/google/protobuf/pyext/descriptor.h b/python/google/protobuf/pyext/descriptor.h index 9e5957b5..ba6e7298 100644 --- a/python/google/protobuf/pyext/descriptor.h +++ b/python/google/protobuf/pyext/descriptor.h @@ -34,105 +34,55 @@ #define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_H__ #include -#include - -#include #include -#if PY_VERSION_HEX < 0x02050000 && !defined(PY_SSIZE_T_MIN) -typedef int Py_ssize_t; -#define PY_SSIZE_T_MAX INT_MAX -#define PY_SSIZE_T_MIN INT_MIN -#endif - namespace google { namespace protobuf { namespace python { -typedef struct CMessageDescriptor { - PyObject_HEAD - - // The proto2 descriptor that this object represents. - const google::protobuf::Descriptor* descriptor; -} CMessageDescriptor; - - -typedef struct CFieldDescriptor { - PyObject_HEAD - - // The proto2 descriptor that this object represents. - const google::protobuf::FieldDescriptor* descriptor; -} CFieldDescriptor; - - -// Wraps operations to the global DescriptorPool which contains information -// about all messages and fields. -// -// There is normally one pool per process. We make it a Python object only -// because it contains many Python references. -// TODO(amauryfa): See whether such objects can appear in reference cycles, and -// consider adding support for the cyclic GC. -// -// "Methods" that interacts with this DescriptorPool are in the cdescriptor_pool -// namespace. -typedef struct PyDescriptorPool { - PyObject_HEAD - - google::protobuf::DescriptorPool* pool; - - // Make our own mapping to retrieve Python classes from C++ descriptors. - // - // Descriptor pointers stored here are owned by the DescriptorPool above. - // Python references to classes are owned by this PyDescriptorPool. - typedef hash_map ClassesByMessageMap; - ClassesByMessageMap *classes_by_descriptor; -} PyDescriptorPool; - - -extern PyTypeObject CMessageDescriptor_Type; -extern PyTypeObject CFieldDescriptor_Type; - -extern PyTypeObject PyDescriptorPool_Type; - - -namespace cdescriptor_pool { - -// Builds a new DescriptorPool. Normally called only once per process. -PyDescriptorPool* NewDescriptorPool(); - -// Looks up a message by name. -// Returns a message Descriptor, or NULL if not found. -const google::protobuf::Descriptor* FindMessageTypeByName(PyDescriptorPool* self, - const string& name); - -// Registers a new Python class for the given message descriptor. -// Returns the message Descriptor. -// On error, returns NULL with a Python exception set. -const google::protobuf::Descriptor* RegisterMessageClass( - PyDescriptorPool* self, PyObject *message_class, PyObject *descriptor); +extern PyTypeObject PyMessageDescriptor_Type; +extern PyTypeObject PyFieldDescriptor_Type; +extern PyTypeObject PyEnumDescriptor_Type; +extern PyTypeObject PyEnumValueDescriptor_Type; +extern PyTypeObject PyFileDescriptor_Type; +extern PyTypeObject PyOneofDescriptor_Type; + +// Return a new reference to a Descriptor object. +// The C++ pointer is usually borrowed from the global DescriptorPool. +// In any case, it must stay alive as long as the Python object. +PyObject* PyMessageDescriptor_New(const Descriptor* descriptor); +PyObject* PyFieldDescriptor_New(const FieldDescriptor* descriptor); +PyObject* PyEnumDescriptor_New(const EnumDescriptor* descriptor); +PyObject* PyEnumValueDescriptor_New(const EnumValueDescriptor* descriptor); +PyObject* PyOneofDescriptor_New(const OneofDescriptor* descriptor); +PyObject* PyFileDescriptor_New(const FileDescriptor* file_descriptor); + +// Alternate constructor of PyFileDescriptor, used when we already have a +// serialized FileDescriptorProto that can be cached. +// Returns a new reference. +PyObject* PyFileDescriptor_NewWithPb(const FileDescriptor* file_descriptor, + PyObject* serialized_pb); -// Retrieves the Python class registered with the given message descriptor. -// -// Returns a *borrowed* reference if found, otherwise returns NULL with an +// Return the C++ descriptor pointer. +// This function checks the parameter type; on error, return NULL with a Python // exception set. -PyObject *GetMessageClass(PyDescriptorPool* self, - const Descriptor *message_descriptor); +const Descriptor* PyMessageDescriptor_AsDescriptor(PyObject* obj); +const FieldDescriptor* PyFieldDescriptor_AsDescriptor(PyObject* obj); -// Looks up a field by name. Returns a CDescriptor corresponding to -// the field on success, or NULL on failure. -// -// Returns a new reference. -PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* name); +// Returns the raw C++ pointer. +const void* PyDescriptor_AsVoidPtr(PyObject* obj); -// Looks up an extension by name. Returns a CDescriptor corresponding -// to the field on success, or NULL on failure. +// Check that the calling Python code is the global scope of a _pb2.py module. +// This function is used to support the current code generated by the proto +// compiler, which insists on modifying descriptors after they have been +// created. // -// Returns a new reference. -PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg); -} // namespace cdescriptor_pool +// stacklevel indicates which Python frame should be the _pb2.py module. +// +// Don't use this function outside descriptor classes. +bool _CalledFromGeneratedFile(int stacklevel); -PyObject* Python_BuildFile(PyObject* ignored, PyObject* args); bool InitDescriptor(); } // namespace python diff --git a/python/google/protobuf/pyext/descriptor_containers.cc b/python/google/protobuf/pyext/descriptor_containers.cc new file mode 100644 index 00000000..06edebf8 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_containers.cc @@ -0,0 +1,1564 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Mappings and Sequences of descriptors. +// Used by Descriptor.fields_by_name, EnumDescriptor.values... +// +// They avoid the allocation of a full dictionary or a full list: they simply +// store a pointer to the parent descriptor, use the C++ Descriptor methods (see +// google/protobuf/descriptor.h) to retrieve other descriptors, and create +// Python objects on the fly. +// +// The containers fully conform to abc.Mapping and abc.Sequence, and behave just +// like read-only dictionaries and lists. +// +// Because the interface of C++ Descriptors is quite regular, this file actually +// defines only three types, the exact behavior of a container is controlled by +// a DescriptorContainerDef structure, which contains functions that uses the +// public Descriptor API. +// +// Note: This DescriptorContainerDef is similar to the "virtual methods table" +// that a C++ compiler generates for a class. We have to make it explicit +// because the Python API is based on C, and does not play well with C++ +// inheritance. + +#include + +#include +#include +#include +#include +#include + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #define PyString_FromFormat PyUnicode_FromFormat + #define PyInt_FromLong PyLong_FromLong + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +struct PyContainer; + +typedef int (*CountMethod)(PyContainer* self); +typedef const void* (*GetByIndexMethod)(PyContainer* self, int index); +typedef const void* (*GetByNameMethod)(PyContainer* self, const string& name); +typedef const void* (*GetByNumberMethod)(PyContainer* self, int index); +typedef PyObject* (*NewObjectFromItemMethod)(const void* descriptor); +typedef const string& (*GetItemNameMethod)(const void* descriptor); +typedef int (*GetItemNumberMethod)(const void* descriptor); +typedef int (*GetItemIndexMethod)(const void* descriptor); + +struct DescriptorContainerDef { + const char* mapping_name; + // Returns the number of items in the container. + CountMethod count_fn; + // Retrieve item by index (usually the order of declaration in the proto file) + // Used by sequences, but also iterators. 0 <= index < Count(). + GetByIndexMethod get_by_index_fn; + // Retrieve item by name (usually a call to some 'FindByName' method). + // Used by "by_name" mappings. + GetByNameMethod get_by_name_fn; + // Retrieve item by declared number (field tag, or enum value). + // Used by "by_number" mappings. + GetByNumberMethod get_by_number_fn; + // Converts a item C++ descriptor to a Python object. Returns a new reference. + NewObjectFromItemMethod new_object_from_item_fn; + // Retrieve the name of an item. Used by iterators on "by_name" mappings. + GetItemNameMethod get_item_name_fn; + // Retrieve the number of an item. Used by iterators on "by_number" mappings. + GetItemNumberMethod get_item_number_fn; + // Retrieve the index of an item for the container type. + // Used by "__contains__". + // If not set, "x in sequence" will do a linear search. + GetItemIndexMethod get_item_index_fn; +}; + +struct PyContainer { + PyObject_HEAD + + // The proto2 descriptor this container belongs to the global DescriptorPool. + const void* descriptor; + + // A pointer to a static structure with function pointers that control the + // behavior of the container. Very similar to the table of virtual functions + // of a C++ class. + const DescriptorContainerDef* container_def; + + // The kind of container: list, or dict by name or value. + enum ContainerKind { + KIND_SEQUENCE, + KIND_BYNAME, + KIND_BYNUMBER, + } kind; +}; + +struct PyContainerIterator { + PyObject_HEAD + + // The container we are iterating over. Own a reference. + PyContainer* container; + + // The current index in the iterator. + int index; + + // The kind of container: list, or dict by name or value. + enum IterKind { + KIND_ITERKEY, + KIND_ITERVALUE, + KIND_ITERITEM, + KIND_ITERVALUE_REVERSED, // For sequences + } kind; +}; + +namespace descriptor { + +// Returns the C++ item descriptor for a given Python key. +// When the descriptor is found, return true and set *item. +// When the descriptor is not found, return true, but set *item to NULL. +// On error, returns false with an exception set. +static bool _GetItemByKey(PyContainer* self, PyObject* key, const void** item) { + switch (self->kind) { + case PyContainer::KIND_BYNAME: + { + char* name; + Py_ssize_t name_size; + if (PyString_AsStringAndSize(key, &name, &name_size) < 0) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + // Not a string, cannot be in the container. + PyErr_Clear(); + *item = NULL; + return true; + } + return false; + } + *item = self->container_def->get_by_name_fn( + self, string(name, name_size)); + return true; + } + case PyContainer::KIND_BYNUMBER: + { + Py_ssize_t number = PyNumber_AsSsize_t(key, NULL); + if (number == -1 && PyErr_Occurred()) { + if (PyErr_ExceptionMatches(PyExc_TypeError)) { + // Not a number, cannot be in the container. + PyErr_Clear(); + *item = NULL; + return true; + } + return false; + } + *item = self->container_def->get_by_number_fn(self, number); + return true; + } + default: + PyErr_SetNone(PyExc_NotImplementedError); + return false; + } +} + +// Returns the key of the object at the given index. +// Used when iterating over mappings. +static PyObject* _NewKey_ByIndex(PyContainer* self, Py_ssize_t index) { + const void* item = self->container_def->get_by_index_fn(self, index); + switch (self->kind) { + case PyContainer::KIND_BYNAME: + { + const string& name(self->container_def->get_item_name_fn(item)); + return PyString_FromStringAndSize(name.c_str(), name.size()); + } + case PyContainer::KIND_BYNUMBER: + { + int value = self->container_def->get_item_number_fn(item); + return PyInt_FromLong(value); + } + default: + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; + } +} + +// Returns the object at the given index. +// Also used when iterating over mappings. +static PyObject* _NewObj_ByIndex(PyContainer* self, Py_ssize_t index) { + return self->container_def->new_object_from_item_fn( + self->container_def->get_by_index_fn(self, index)); +} + +static Py_ssize_t Length(PyContainer* self) { + return self->container_def->count_fn(self); +} + +// The DescriptorMapping type. + +static PyObject* Subscript(PyContainer* self, PyObject* key) { + const void* item = NULL; + if (!_GetItemByKey(self, key, &item)) { + return NULL; + } + if (!item) { + PyErr_SetObject(PyExc_KeyError, key); + return NULL; + } + return self->container_def->new_object_from_item_fn(item); +} + +static int AssSubscript(PyContainer* self, PyObject* key, PyObject* value) { + if (_CalledFromGeneratedFile(0)) { + return 0; + } + PyErr_Format(PyExc_TypeError, + "'%.200s' object does not support item assignment", + Py_TYPE(self)->tp_name); + return -1; +} + +static PyMappingMethods MappingMappingMethods = { + (lenfunc)Length, // mp_length + (binaryfunc)Subscript, // mp_subscript + (objobjargproc)AssSubscript, // mp_ass_subscript +}; + +static int Contains(PyContainer* self, PyObject* key) { + const void* item = NULL; + if (!_GetItemByKey(self, key, &item)) { + return -1; + } + if (item) { + return 1; + } else { + return 0; + } +} + +static PyObject* ContainerRepr(PyContainer* self) { + const char* kind = ""; + switch (self->kind) { + case PyContainer::KIND_SEQUENCE: + kind = "sequence"; + break; + case PyContainer::KIND_BYNAME: + kind = "mapping by name"; + break; + case PyContainer::KIND_BYNUMBER: + kind = "mapping by number"; + break; + } + return PyString_FromFormat( + "<%s %s>", self->container_def->mapping_name, kind); +} + +extern PyTypeObject DescriptorMapping_Type; +extern PyTypeObject DescriptorSequence_Type; + +// A sequence container can only be equal to another sequence container, or (for +// backward compatibility) to a list containing the same items. +// Returns 1 if equal, 0 if unequal, -1 on error. +static int DescriptorSequence_Equal(PyContainer* self, PyObject* other) { + // Check the identity of C++ pointers. + if (PyObject_TypeCheck(other, &DescriptorSequence_Type)) { + PyContainer* other_container = reinterpret_cast(other); + if (self->descriptor == other_container->descriptor && + self->container_def == other_container->container_def && + self->kind == other_container->kind) { + return 1; + } else { + return 0; + } + } + + // If other is a list + if (PyList_Check(other)) { + // return list(self) == other + int size = Length(self); + if (size != PyList_Size(other)) { + return false; + } + for (int index = 0; index < size; index++) { + ScopedPyObjectPtr value1(_NewObj_ByIndex(self, index)); + if (value1 == NULL) { + return -1; + } + PyObject* value2 = PyList_GetItem(other, index); + if (value2 == NULL) { + return -1; + } + int cmp = PyObject_RichCompareBool(value1, value2, Py_EQ); + if (cmp != 1) // error or not equal + return cmp; + } + // All items were found and equal + return 1; + } + + // Any other object is different. + return 0; +} + +// A mapping container can only be equal to another mapping container, or (for +// backward compatibility) to a dict containing the same items. +// Returns 1 if equal, 0 if unequal, -1 on error. +static int DescriptorMapping_Equal(PyContainer* self, PyObject* other) { + // Check the identity of C++ pointers. + if (PyObject_TypeCheck(other, &DescriptorMapping_Type)) { + PyContainer* other_container = reinterpret_cast(other); + if (self->descriptor == other_container->descriptor && + self->container_def == other_container->container_def && + self->kind == other_container->kind) { + return 1; + } else { + return 0; + } + } + + // If other is a dict + if (PyDict_Check(other)) { + // equivalent to dict(self.items()) == other + int size = Length(self); + if (size != PyDict_Size(other)) { + return false; + } + for (int index = 0; index < size; index++) { + ScopedPyObjectPtr key(_NewKey_ByIndex(self, index)); + if (key == NULL) { + return -1; + } + ScopedPyObjectPtr value1(_NewObj_ByIndex(self, index)); + if (value1 == NULL) { + return -1; + } + PyObject* value2 = PyDict_GetItem(other, key); + if (value2 == NULL) { + // Not found in the other dictionary + return 0; + } + int cmp = PyObject_RichCompareBool(value1, value2, Py_EQ); + if (cmp != 1) // error or not equal + return cmp; + } + // All items were found and equal + return 1; + } + + // Any other object is different. + return 0; +} + +static PyObject* RichCompare(PyContainer* self, PyObject* other, int opid) { + if (opid != Py_EQ && opid != Py_NE) { + Py_INCREF(Py_NotImplemented); + return Py_NotImplemented; + } + + int result; + + if (self->kind == PyContainer::KIND_SEQUENCE) { + result = DescriptorSequence_Equal(self, other); + } else { + result = DescriptorMapping_Equal(self, other); + } + if (result < 0) { + return NULL; + } + if (result ^ (opid == Py_NE)) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +static PySequenceMethods MappingSequenceMethods = { + 0, // sq_length + 0, // sq_concat + 0, // sq_repeat + 0, // sq_item + 0, // sq_slice + 0, // sq_ass_item + 0, // sq_ass_slice + (objobjproc)Contains, // sq_contains +}; + +static PyObject* Get(PyContainer* self, PyObject* args) { + PyObject* key; + PyObject* default_value = Py_None; + if (!PyArg_UnpackTuple(args, "get", 1, 2, &key, &default_value)) { + return NULL; + } + + const void* item; + if (!_GetItemByKey(self, key, &item)) { + return NULL; + } + if (item == NULL) { + Py_INCREF(default_value); + return default_value; + } + return self->container_def->new_object_from_item_fn(item); +} + +static PyObject* Keys(PyContainer* self, PyObject* args) { + Py_ssize_t count = Length(self); + ScopedPyObjectPtr list(PyList_New(count)); + if (list == NULL) { + return NULL; + } + for (Py_ssize_t index = 0; index < count; ++index) { + PyObject* key = _NewKey_ByIndex(self, index); + if (key == NULL) { + return NULL; + } + PyList_SET_ITEM(list.get(), index, key); + } + return list.release(); +} + +static PyObject* Values(PyContainer* self, PyObject* args) { + Py_ssize_t count = Length(self); + ScopedPyObjectPtr list(PyList_New(count)); + if (list == NULL) { + return NULL; + } + for (Py_ssize_t index = 0; index < count; ++index) { + PyObject* value = _NewObj_ByIndex(self, index); + if (value == NULL) { + return NULL; + } + PyList_SET_ITEM(list.get(), index, value); + } + return list.release(); +} + +static PyObject* Items(PyContainer* self, PyObject* args) { + Py_ssize_t count = Length(self); + ScopedPyObjectPtr list(PyList_New(count)); + if (list == NULL) { + return NULL; + } + for (Py_ssize_t index = 0; index < count; ++index) { + ScopedPyObjectPtr obj(PyTuple_New(2)); + if (obj == NULL) { + return NULL; + } + PyObject* key = _NewKey_ByIndex(self, index); + if (key == NULL) { + return NULL; + } + PyTuple_SET_ITEM(obj.get(), 0, key); + PyObject* value = _NewObj_ByIndex(self, index); + if (value == NULL) { + return NULL; + } + PyTuple_SET_ITEM(obj.get(), 1, value); + PyList_SET_ITEM(list.get(), index, obj.release()); + } + return list.release(); +} + +static PyObject* NewContainerIterator(PyContainer* mapping, + PyContainerIterator::IterKind kind); + +static PyObject* Iter(PyContainer* self) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERKEY); +} +static PyObject* IterKeys(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERKEY); +} +static PyObject* IterValues(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERVALUE); +} +static PyObject* IterItems(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, PyContainerIterator::KIND_ITERITEM); +} + +static PyMethodDef MappingMethods[] = { + { "get", (PyCFunction)Get, METH_VARARGS, }, + { "keys", (PyCFunction)Keys, METH_NOARGS, }, + { "values", (PyCFunction)Values, METH_NOARGS, }, + { "items", (PyCFunction)Items, METH_NOARGS, }, + { "iterkeys", (PyCFunction)IterKeys, METH_NOARGS, }, + { "itervalues", (PyCFunction)IterValues, METH_NOARGS, }, + { "iteritems", (PyCFunction)IterItems, METH_NOARGS, }, + {NULL} +}; + +PyTypeObject DescriptorMapping_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "DescriptorMapping", // tp_name + sizeof(PyContainer), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + (reprfunc)ContainerRepr, // tp_repr + 0, // tp_as_number + &MappingSequenceMethods, // tp_as_sequence + &MappingMappingMethods, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + 0, // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + (getiterfunc)Iter, // tp_iter + 0, // tp_iternext + MappingMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + 0, // tp_free +}; + +// The DescriptorSequence type. + +static PyObject* GetItem(PyContainer* self, Py_ssize_t index) { + if (index < 0) { + index += Length(self); + } + if (index < 0 || index >= Length(self)) { + PyErr_SetString(PyExc_IndexError, "index out of range"); + return NULL; + } + return _NewObj_ByIndex(self, index); +} + +// Returns the position of the item in the sequence, of -1 if not found. +// This function never fails. +int Find(PyContainer* self, PyObject* item) { + // The item can only be in one position: item.index. + // Check that self[item.index] == item, it's faster than a linear search. + // + // This assumes that sequences are only defined by syntax of the .proto file: + // a specific item belongs to only one sequence, depending on its position in + // the .proto file definition. + const void* descriptor_ptr = PyDescriptor_AsVoidPtr(item); + if (descriptor_ptr == NULL) { + // Not a descriptor, it cannot be in the list. + return -1; + } + if (self->container_def->get_item_index_fn) { + int index = self->container_def->get_item_index_fn(descriptor_ptr); + if (index < 0 || index >= Length(self)) { + // This index is not from this collection. + return -1; + } + if (self->container_def->get_by_index_fn(self, index) != descriptor_ptr) { + // The descriptor at this index is not the same. + return -1; + } + // self[item.index] == item, so return the index. + return index; + } else { + // Fall back to linear search. + int length = Length(self); + for (int index=0; index < length; index++) { + if (self->container_def->get_by_index_fn(self, index) == descriptor_ptr) { + return index; + } + } + // Not found + return -1; + } +} + +// Implements list.index(): the position of the item is in the sequence. +static PyObject* Index(PyContainer* self, PyObject* item) { + int position = Find(self, item); + if (position < 0) { + // Not found + PyErr_SetNone(PyExc_ValueError); + return NULL; + } else { + return PyInt_FromLong(position); + } +} +// Implements "list.__contains__()": is the object in the sequence. +static int SeqContains(PyContainer* self, PyObject* item) { + int position = Find(self, item); + if (position < 0) { + return 0; + } else { + return 1; + } +} + +// Implements list.count(): number of occurrences of the item in the sequence. +// An item can only appear once in a sequence. If it exists, return 1. +static PyObject* Count(PyContainer* self, PyObject* item) { + int position = Find(self, item); + if (position < 0) { + return PyInt_FromLong(0); + } else { + return PyInt_FromLong(1); + } +} + +static PyObject* Append(PyContainer* self, PyObject* args) { + if (_CalledFromGeneratedFile(0)) { + Py_RETURN_NONE; + } + PyErr_Format(PyExc_TypeError, + "'%.200s' object is not a mutable sequence", + Py_TYPE(self)->tp_name); + return NULL; +} + +static PyObject* Reversed(PyContainer* self, PyObject* args) { + return NewContainerIterator(self, + PyContainerIterator::KIND_ITERVALUE_REVERSED); +} + +static PyMethodDef SeqMethods[] = { + { "index", (PyCFunction)Index, METH_O, }, + { "count", (PyCFunction)Count, METH_O, }, + { "append", (PyCFunction)Append, METH_O, }, + { "__reversed__", (PyCFunction)Reversed, METH_NOARGS, }, + {NULL} +}; + +static PySequenceMethods SeqSequenceMethods = { + (lenfunc)Length, // sq_length + 0, // sq_concat + 0, // sq_repeat + (ssizeargfunc)GetItem, // sq_item + 0, // sq_slice + 0, // sq_ass_item + 0, // sq_ass_slice + (objobjproc)SeqContains, // sq_contains +}; + +PyTypeObject DescriptorSequence_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "DescriptorSequence", // tp_name + sizeof(PyContainer), // tp_basicsize + 0, // tp_itemsize + 0, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + (reprfunc)ContainerRepr, // tp_repr + 0, // tp_as_number + &SeqSequenceMethods, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + 0, // tp_doc + 0, // tp_traverse + 0, // tp_clear + (richcmpfunc)RichCompare, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + SeqMethods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + 0, // tp_free +}; + +static PyObject* NewMappingByName( + DescriptorContainerDef* container_def, const void* descriptor) { + PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_BYNAME; + return reinterpret_cast(self); +} + +static PyObject* NewMappingByNumber( + DescriptorContainerDef* container_def, const void* descriptor) { + if (container_def->get_by_number_fn == NULL || + container_def->get_item_number_fn == NULL) { + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; + } + PyContainer* self = PyObject_New(PyContainer, &DescriptorMapping_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_BYNUMBER; + return reinterpret_cast(self); +} + +static PyObject* NewSequence( + DescriptorContainerDef* container_def, const void* descriptor) { + PyContainer* self = PyObject_New(PyContainer, &DescriptorSequence_Type); + if (self == NULL) { + return NULL; + } + self->descriptor = descriptor; + self->container_def = container_def; + self->kind = PyContainer::KIND_SEQUENCE; + return reinterpret_cast(self); +} + +// Implement iterators over PyContainers. + +static void Iterator_Dealloc(PyContainerIterator* self) { + Py_CLEAR(self->container); + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +static PyObject* Iterator_Next(PyContainerIterator* self) { + int count = self->container->container_def->count_fn(self->container); + if (self->index >= count) { + // Return NULL with no exception to indicate the end. + return NULL; + } + int index = self->index; + self->index += 1; + switch (self->kind) { + case PyContainerIterator::KIND_ITERKEY: + return _NewKey_ByIndex(self->container, index); + case PyContainerIterator::KIND_ITERVALUE: + return _NewObj_ByIndex(self->container, index); + case PyContainerIterator::KIND_ITERVALUE_REVERSED: + return _NewObj_ByIndex(self->container, count - index - 1); + case PyContainerIterator::KIND_ITERITEM: + { + PyObject* obj = PyTuple_New(2); + if (obj == NULL) { + return NULL; + } + PyObject* key = _NewKey_ByIndex(self->container, index); + if (key == NULL) { + Py_DECREF(obj); + return NULL; + } + PyTuple_SET_ITEM(obj, 0, key); + PyObject* value = _NewObj_ByIndex(self->container, index); + if (value == NULL) { + Py_DECREF(obj); + return NULL; + } + PyTuple_SET_ITEM(obj, 1, value); + return obj; + } + default: + PyErr_SetNone(PyExc_NotImplementedError); + return NULL; + } +} + +static PyTypeObject ContainerIterator_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + "DescriptorContainerIterator", // tp_name + sizeof(PyContainerIterator), // tp_basicsize + 0, // tp_itemsize + (destructor)Iterator_Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + 0, // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + PyObject_SelfIter, // tp_iter + (iternextfunc)Iterator_Next, // tp_iternext + 0, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + 0, // tp_free +}; + +static PyObject* NewContainerIterator(PyContainer* container, + PyContainerIterator::IterKind kind) { + PyContainerIterator* self = PyObject_New(PyContainerIterator, + &ContainerIterator_Type); + if (self == NULL) { + return NULL; + } + Py_INCREF(container); + self->container = container; + self->kind = kind; + self->index = 0; + + return reinterpret_cast(self); +} + +} // namespace descriptor + +// Now define the real collections! + +namespace message_descriptor { + +typedef const Descriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast(self->descriptor); +} + +namespace fields { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->field_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindFieldByName(name); +} + +static ItemDescriptor GetByNumber(PyContainer* self, int number) { + return GetDescriptor(self)->FindFieldByNumber(number); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->field(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFieldDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemNumber(ItemDescriptor item) { + return item->number(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageFields", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)GetByNumber, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)GetItemNumber, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace fields + +PyObject* NewMessageFieldsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&fields::ContainerDef, descriptor); +} + +PyObject* NewMessageFieldsByNumber(ParentDescriptor descriptor) { + return descriptor::NewMappingByNumber(&fields::ContainerDef, descriptor); +} + +PyObject* NewMessageFieldsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&fields::ContainerDef, descriptor); +} + +namespace nested_types { + +typedef const Descriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->nested_type_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindNestedTypeByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->nested_type(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyMessageDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageNestedTypes", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace nested_types + +PyObject* NewMessageNestedTypesSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&nested_types::ContainerDef, descriptor); +} + +PyObject* NewMessageNestedTypesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&nested_types::ContainerDef, descriptor); +} + +namespace enums { + +typedef const EnumDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->enum_type_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindEnumTypeByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->enum_type(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyEnumDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageNestedEnums", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace enums + +PyObject* NewMessageEnumsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enums::ContainerDef, descriptor); +} + +PyObject* NewMessageEnumsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&enums::ContainerDef, descriptor); +} + +namespace enumvalues { + +// This is the "enum_values_by_name" mapping, which collects values from all +// enum types in a message. +// +// Note that the behavior of the C++ descriptor is different: it will search and +// return the first value that matches the name, whereas the Python +// implementation retrieves the last one. + +typedef const EnumValueDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + int count = 0; + for (int i = 0; i < GetDescriptor(self)->enum_type_count(); ++i) { + count += GetDescriptor(self)->enum_type(i)->value_count(); + } + return count; +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindEnumValueByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + // This is not optimal, but the number of enums *types* in a given message + // is small. This function is only used when iterating over the mapping. + const EnumDescriptor* enum_type = NULL; + int enum_type_count = GetDescriptor(self)->enum_type_count(); + for (int i = 0; i < enum_type_count; ++i) { + enum_type = GetDescriptor(self)->enum_type(i); + int enum_value_count = enum_type->value_count(); + if (index < enum_value_count) { + // Found it! + break; + } + index -= enum_value_count; + } + // The next statement cannot overflow, because this function is only called by + // internal iterators which ensure that 0 <= index < Count(). + return enum_type->value(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyEnumValueDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageEnumValues", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)NULL, +}; + +} // namespace enumvalues + +PyObject* NewMessageEnumValuesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enumvalues::ContainerDef, descriptor); +} + +namespace extensions { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->extension_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindExtensionByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->extension(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFieldDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageExtensions", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace extensions + +PyObject* NewMessageExtensionsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor); +} + +PyObject* NewMessageExtensionsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&extensions::ContainerDef, descriptor); +} + +namespace oneofs { + +typedef const OneofDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->oneof_decl_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindOneofByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->oneof_decl(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyOneofDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "MessageOneofs", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace oneofs + +PyObject* NewMessageOneofsByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&oneofs::ContainerDef, descriptor); +} + +PyObject* NewMessageOneofsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&oneofs::ContainerDef, descriptor); +} + +} // namespace message_descriptor + +namespace enum_descriptor { + +typedef const EnumDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast(self->descriptor); +} + +namespace enumvalues { + +typedef const EnumValueDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->value_count(); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->value(index); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindValueByName(name); +} + +static ItemDescriptor GetByNumber(PyContainer* self, int number) { + return GetDescriptor(self)->FindValueByNumber(number); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyEnumValueDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemNumber(ItemDescriptor item) { + return item->number(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "EnumValues", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)GetByNumber, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)GetItemNumber, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace enumvalues + +PyObject* NewEnumValuesByName(ParentDescriptor descriptor) { + return descriptor::NewMappingByName(&enumvalues::ContainerDef, descriptor); +} + +PyObject* NewEnumValuesByNumber(ParentDescriptor descriptor) { + return descriptor::NewMappingByNumber(&enumvalues::ContainerDef, descriptor); +} + +PyObject* NewEnumValuesSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&enumvalues::ContainerDef, descriptor); +} + +} // namespace enum_descriptor + +namespace oneof_descriptor { + +typedef const OneofDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast(self->descriptor); +} + +namespace fields { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->field_count(); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->field(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFieldDescriptor_New(item); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index_in_oneof(); +} + +static DescriptorContainerDef ContainerDef = { + "OneofFields", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace fields + +PyObject* NewOneofFieldsSeq(ParentDescriptor descriptor) { + return descriptor::NewSequence(&fields::ContainerDef, descriptor); +} + +} // namespace oneof_descriptor + +namespace file_descriptor { + +typedef const FileDescriptor* ParentDescriptor; + +static ParentDescriptor GetDescriptor(PyContainer* self) { + return reinterpret_cast(self->descriptor); +} + +namespace messages { + +typedef const Descriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->message_type_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindMessageTypeByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->message_type(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyMessageDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileMessages", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace messages + +PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor) { + return descriptor::NewMappingByName(&messages::ContainerDef, descriptor); +} + +namespace enums { + +typedef const EnumDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->enum_type_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindEnumTypeByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->enum_type(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyEnumDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileEnums", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace enums + +PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor) { + return descriptor::NewMappingByName(&enums::ContainerDef, descriptor); +} + +namespace extensions { + +typedef const FieldDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->extension_count(); +} + +static ItemDescriptor GetByName(PyContainer* self, const string& name) { + return GetDescriptor(self)->FindExtensionByName(name); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->extension(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFieldDescriptor_New(item); +} + +static const string& GetItemName(ItemDescriptor item) { + return item->name(); +} + +static int GetItemIndex(ItemDescriptor item) { + return item->index(); +} + +static DescriptorContainerDef ContainerDef = { + "FileExtensions", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)GetByName, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)GetItemName, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)GetItemIndex, +}; + +} // namespace extensions + +PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor) { + return descriptor::NewMappingByName(&extensions::ContainerDef, descriptor); +} + +namespace dependencies { + +typedef const FileDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->dependency_count(); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->dependency(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFileDescriptor_New(item); +} + +static DescriptorContainerDef ContainerDef = { + "FileDependencies", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)NULL, +}; + +} // namespace dependencies + +PyObject* NewFileDependencies(const FileDescriptor* descriptor) { + return descriptor::NewSequence(&dependencies::ContainerDef, descriptor); +} + +namespace public_dependencies { + +typedef const FileDescriptor* ItemDescriptor; + +static int Count(PyContainer* self) { + return GetDescriptor(self)->public_dependency_count(); +} + +static ItemDescriptor GetByIndex(PyContainer* self, int index) { + return GetDescriptor(self)->public_dependency(index); +} + +static PyObject* NewObjectFromItem(ItemDescriptor item) { + return PyFileDescriptor_New(item); +} + +static DescriptorContainerDef ContainerDef = { + "FilePublicDependencies", + (CountMethod)Count, + (GetByIndexMethod)GetByIndex, + (GetByNameMethod)NULL, + (GetByNumberMethod)NULL, + (NewObjectFromItemMethod)NewObjectFromItem, + (GetItemNameMethod)NULL, + (GetItemNumberMethod)NULL, + (GetItemIndexMethod)NULL, +}; + +} // namespace public_dependencies + +PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor) { + return descriptor::NewSequence(&public_dependencies::ContainerDef, + descriptor); +} + +} // namespace file_descriptor + + +// Register all implementations + +bool InitDescriptorMappingTypes() { + if (PyType_Ready(&descriptor::DescriptorMapping_Type) < 0) + return false; + if (PyType_Ready(&descriptor::DescriptorSequence_Type) < 0) + return false; + if (PyType_Ready(&descriptor::ContainerIterator_Type) < 0) + return false; + return true; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/descriptor_containers.h b/python/google/protobuf/pyext/descriptor_containers.h new file mode 100644 index 00000000..d81537de --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_containers.h @@ -0,0 +1,95 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Mappings and Sequences of descriptors. +// They implement containers like fields_by_name, EnumDescriptor.values... +// See descriptor_containers.cc for more description. +#include + +namespace google { +namespace protobuf { + +class Descriptor; +class FileDescriptor; +class EnumDescriptor; +class OneofDescriptor; + +namespace python { + +// Initialize the various types and objects. +bool InitDescriptorMappingTypes(); + +// Each function below returns a Mapping, or a Sequence of descriptors. +// They all return a new reference. + +namespace message_descriptor { +PyObject* NewMessageFieldsByName(const Descriptor* descriptor); +PyObject* NewMessageFieldsByNumber(const Descriptor* descriptor); +PyObject* NewMessageFieldsSeq(const Descriptor* descriptor); + +PyObject* NewMessageNestedTypesSeq(const Descriptor* descriptor); +PyObject* NewMessageNestedTypesByName(const Descriptor* descriptor); + +PyObject* NewMessageEnumsByName(const Descriptor* descriptor); +PyObject* NewMessageEnumsSeq(const Descriptor* descriptor); +PyObject* NewMessageEnumValuesByName(const Descriptor* descriptor); + +PyObject* NewMessageExtensionsByName(const Descriptor* descriptor); +PyObject* NewMessageExtensionsSeq(const Descriptor* descriptor); + +PyObject* NewMessageOneofsByName(const Descriptor* descriptor); +PyObject* NewMessageOneofsSeq(const Descriptor* descriptor); +} // namespace message_descriptor + +namespace enum_descriptor { +PyObject* NewEnumValuesByName(const EnumDescriptor* descriptor); +PyObject* NewEnumValuesByNumber(const EnumDescriptor* descriptor); +PyObject* NewEnumValuesSeq(const EnumDescriptor* descriptor); +} // namespace enum_descriptor + +namespace oneof_descriptor { +PyObject* NewOneofFieldsSeq(const OneofDescriptor* descriptor); +} // namespace oneof_descriptor + +namespace file_descriptor { +PyObject* NewFileMessageTypesByName(const FileDescriptor* descriptor); + +PyObject* NewFileEnumTypesByName(const FileDescriptor* descriptor); + +PyObject* NewFileExtensionsByName(const FileDescriptor* descriptor); + +PyObject* NewFileDependencies(const FileDescriptor* descriptor); +PyObject* NewFilePublicDependencies(const FileDescriptor* descriptor); +} // namespace file_descriptor + + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/descriptor_cpp2_test.py b/python/google/protobuf/pyext/descriptor_cpp2_test.py deleted file mode 100644 index 3cf45a22..00000000 --- a/python/google/protobuf/pyext/descriptor_cpp2_test.py +++ /dev/null @@ -1,58 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Tests for google.protobuf.pyext behavior.""" - -__author__ = 'anuraag@google.com (Anuraag Agrawal)' - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' - -# We must set the implementation version above before the google3 imports. -# pylint: disable=g-import-not-at-top -from google.apputils import basetest -from google.protobuf.internal import api_implementation -# Run all tests from the original module by putting them in our namespace. -# pylint: disable=wildcard-import -from google.protobuf.internal.descriptor_test import * - - -class ConfirmCppApi2Test(basetest.TestCase): - - def testImplementationSetting(self): - self.assertEqual('cpp', api_implementation.Type()) - self.assertEqual(2, api_implementation.Version()) - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/pyext/descriptor_pool.cc b/python/google/protobuf/pyext/descriptor_pool.cc new file mode 100644 index 00000000..bc3077bc --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_pool.cc @@ -0,0 +1,370 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Implements the DescriptorPool, which collects all descriptors. + +#include + +#include +#include +#include +#include + +#define C(str) const_cast(str) + +#if PY_MAJOR_VERSION >= 3 + #define PyString_FromStringAndSize PyUnicode_FromStringAndSize + #if PY_VERSION_HEX < 0x03030000 + #error "Python 3.0 - 3.2 are not supported." + #endif + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) +#endif + +namespace google { +namespace protobuf { +namespace python { + +namespace cdescriptor_pool { + +PyDescriptorPool* NewDescriptorPool() { + PyDescriptorPool* cdescriptor_pool = PyObject_New( + PyDescriptorPool, &PyDescriptorPool_Type); + if (cdescriptor_pool == NULL) { + return NULL; + } + + // Build a DescriptorPool for messages only declared in Python libraries. + // generated_pool() contains all messages linked in C++ libraries, and is used + // as underlay. + cdescriptor_pool->pool = new DescriptorPool(DescriptorPool::generated_pool()); + + // TODO(amauryfa): Rewrite the SymbolDatabase in C so that it uses the same + // storage. + cdescriptor_pool->classes_by_descriptor = + new PyDescriptorPool::ClassesByMessageMap(); + cdescriptor_pool->interned_descriptors = + new hash_map(); + cdescriptor_pool->descriptor_options = + new hash_map(); + + return cdescriptor_pool; +} + +static void Dealloc(PyDescriptorPool* self) { + typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; + for (iterator it = self->classes_by_descriptor->begin(); + it != self->classes_by_descriptor->end(); ++it) { + Py_DECREF(it->second); + } + delete self->classes_by_descriptor; + delete self->interned_descriptors; // its references were borrowed. + for (hash_map::iterator it = + self->descriptor_options->begin(); + it != self->descriptor_options->end(); ++it) { + Py_DECREF(it->second); + } + delete self->descriptor_options; + Py_TYPE(self)->tp_free(reinterpret_cast(self)); +} + +PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const Descriptor* message_descriptor = + self->pool->FindMessageTypeByName(string(name, name_size)); + + if (message_descriptor == NULL) { + PyErr_Format(PyExc_TypeError, "Couldn't find message %.200s", name); + return NULL; + } + + return PyMessageDescriptor_New(message_descriptor); +} + +// Add a message class to our database. +const Descriptor* RegisterMessageClass( + PyDescriptorPool* self, PyObject *message_class, PyObject* descriptor) { + ScopedPyObjectPtr full_message_name( + PyObject_GetAttrString(descriptor, "full_name")); + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(full_message_name, &name, &name_size) < 0) { + return NULL; + } + const Descriptor *message_descriptor = + self->pool->FindMessageTypeByName(string(name, name_size)); + if (!message_descriptor) { + PyErr_Format(PyExc_TypeError, "Could not find C++ descriptor for '%s'", + name); + return NULL; + } + Py_INCREF(message_class); + typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; + std::pair ret = self->classes_by_descriptor->insert( + std::make_pair(message_descriptor, message_class)); + if (!ret.second) { + // Update case: DECREF the previous value. + Py_DECREF(ret.first->second); + ret.first->second = message_class; + } + return message_descriptor; +} + +// Retrieve the message class added to our database. +PyObject *GetMessageClass(PyDescriptorPool* self, + const Descriptor *message_descriptor) { + typedef PyDescriptorPool::ClassesByMessageMap::iterator iterator; + iterator ret = self->classes_by_descriptor->find(message_descriptor); + if (ret == self->classes_by_descriptor->end()) { + PyErr_Format(PyExc_TypeError, "No message class registered for '%s'", + message_descriptor->full_name().c_str()); + return NULL; + } else { + return ret->second; + } +} + +PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FieldDescriptor* field_descriptor = + self->pool->FindFieldByName(string(name, name_size)); + if (field_descriptor == NULL) { + PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", + name); + return NULL; + } + + return PyFieldDescriptor_New(field_descriptor); +} + +PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const FieldDescriptor* field_descriptor = + self->pool->FindExtensionByName(string(name, name_size)); + if (field_descriptor == NULL) { + PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", name); + return NULL; + } + + return PyFieldDescriptor_New(field_descriptor); +} + +PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const EnumDescriptor* enum_descriptor = + self->pool->FindEnumTypeByName(string(name, name_size)); + if (enum_descriptor == NULL) { + PyErr_Format(PyExc_TypeError, "Couldn't find enum %.200s", name); + return NULL; + } + + return PyEnumDescriptor_New(enum_descriptor); +} + +PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg) { + Py_ssize_t name_size; + char* name; + if (PyString_AsStringAndSize(arg, &name, &name_size) < 0) { + return NULL; + } + + const OneofDescriptor* oneof_descriptor = + self->pool->FindOneofByName(string(name, name_size)); + if (oneof_descriptor == NULL) { + PyErr_Format(PyExc_TypeError, "Couldn't find oneof %.200s", name); + return NULL; + } + + return PyOneofDescriptor_New(oneof_descriptor); +} + +static PyMethodDef Methods[] = { + { C("FindFieldByName"), + (PyCFunction)FindFieldByName, + METH_O, + C("Searches for a field descriptor by full name.") }, + { C("FindExtensionByName"), + (PyCFunction)FindExtensionByName, + METH_O, + C("Searches for extension descriptor by full name.") }, + {NULL} +}; + +} // namespace cdescriptor_pool + +PyTypeObject PyDescriptorPool_Type = { + PyVarObject_HEAD_INIT(&PyType_Type, 0) + C("google.protobuf.internal." + "_message.DescriptorPool"), // tp_name + sizeof(PyDescriptorPool), // tp_basicsize + 0, // tp_itemsize + (destructor)cdescriptor_pool::Dealloc, // tp_dealloc + 0, // tp_print + 0, // tp_getattr + 0, // tp_setattr + 0, // tp_compare + 0, // tp_repr + 0, // tp_as_number + 0, // tp_as_sequence + 0, // tp_as_mapping + 0, // tp_hash + 0, // tp_call + 0, // tp_str + 0, // tp_getattro + 0, // tp_setattro + 0, // tp_as_buffer + Py_TPFLAGS_DEFAULT, // tp_flags + C("A Descriptor Pool"), // tp_doc + 0, // tp_traverse + 0, // tp_clear + 0, // tp_richcompare + 0, // tp_weaklistoffset + 0, // tp_iter + 0, // tp_iternext + cdescriptor_pool::Methods, // tp_methods + 0, // tp_members + 0, // tp_getset + 0, // tp_base + 0, // tp_dict + 0, // tp_descr_get + 0, // tp_descr_set + 0, // tp_dictoffset + 0, // tp_init + 0, // tp_alloc + 0, // tp_new + PyObject_Del, // tp_free +}; + +// The code below loads new Descriptors from a serialized FileDescriptorProto. + + +// Collects errors that occur during proto file building to allow them to be +// propagated in the python exception instead of only living in ERROR logs. +class BuildFileErrorCollector : public DescriptorPool::ErrorCollector { + public: + BuildFileErrorCollector() : error_message(""), had_errors(false) {} + + void AddError(const string& filename, const string& element_name, + const Message* descriptor, ErrorLocation location, + const string& message) { + // Replicates the logging behavior that happens in the C++ implementation + // when an error collector is not passed in. + if (!had_errors) { + error_message += + ("Invalid proto descriptor for file \"" + filename + "\":\n"); + } + // As this only happens on failure and will result in the program not + // running at all, no effort is made to optimize this string manipulation. + error_message += (" " + element_name + ": " + message + "\n"); + } + + string error_message; + bool had_errors; +}; + +PyObject* Python_BuildFile(PyObject* ignored, PyObject* serialized_pb) { + char* message_type; + Py_ssize_t message_len; + + if (PyBytes_AsStringAndSize(serialized_pb, &message_type, &message_len) < 0) { + return NULL; + } + + FileDescriptorProto file_proto; + if (!file_proto.ParseFromArray(message_type, message_len)) { + PyErr_SetString(PyExc_TypeError, "Couldn't parse file content!"); + return NULL; + } + + // If the file was already part of a C++ library, all its descriptors are in + // the underlying pool. No need to do anything else. + const FileDescriptor* generated_file = + DescriptorPool::generated_pool()->FindFileByName(file_proto.name()); + if (generated_file != NULL) { + return PyFileDescriptor_NewWithPb(generated_file, serialized_pb); + } + + BuildFileErrorCollector error_collector; + const FileDescriptor* descriptor = + GetDescriptorPool()->pool->BuildFileCollectingErrors(file_proto, + &error_collector); + if (descriptor == NULL) { + PyErr_Format(PyExc_TypeError, + "Couldn't build proto file into descriptor pool!\n%s", + error_collector.error_message.c_str()); + return NULL; + } + + return PyFileDescriptor_NewWithPb(descriptor, serialized_pb); +} + +static PyDescriptorPool* global_cdescriptor_pool = NULL; + +bool InitDescriptorPool() { + if (PyType_Ready(&PyDescriptorPool_Type) < 0) + return false; + + global_cdescriptor_pool = cdescriptor_pool::NewDescriptorPool(); + if (global_cdescriptor_pool == NULL) { + return false; + } + + return true; +} + +PyDescriptorPool* GetDescriptorPool() { + return global_cdescriptor_pool; +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/pyext/descriptor_pool.h b/python/google/protobuf/pyext/descriptor_pool.h new file mode 100644 index 00000000..4e494b89 --- /dev/null +++ b/python/google/protobuf/pyext/descriptor_pool.h @@ -0,0 +1,152 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#ifndef GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ +#define GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ + +#include + +#include +#include + +namespace google { +namespace protobuf { +namespace python { + +// Wraps operations to the global DescriptorPool which contains information +// about all messages and fields. +// +// There is normally one pool per process. We make it a Python object only +// because it contains many Python references. +// TODO(amauryfa): See whether such objects can appear in reference cycles, and +// consider adding support for the cyclic GC. +// +// "Methods" that interacts with this DescriptorPool are in the cdescriptor_pool +// namespace. +typedef struct PyDescriptorPool { + PyObject_HEAD + + DescriptorPool* pool; + + // Make our own mapping to retrieve Python classes from C++ descriptors. + // + // Descriptor pointers stored here are owned by the DescriptorPool above. + // Python references to classes are owned by this PyDescriptorPool. + typedef hash_map ClassesByMessageMap; + ClassesByMessageMap* classes_by_descriptor; + + // Store interned descriptors, so that the same C++ descriptor yields the same + // Python object. Objects are not immortal: this map does not own the + // references, and items are deleted when the last reference to the object is + // released. + // This is enough to support the "is" operator on live objects. + // All descriptors are stored here. + hash_map* interned_descriptors; + + // Cache the options for any kind of descriptor. + // Descriptor pointers are owned by the DescriptorPool above. + // Python objects are owned by the map. + hash_map* descriptor_options; +} PyDescriptorPool; + + +extern PyTypeObject PyDescriptorPool_Type; + +namespace cdescriptor_pool { + +// Builds a new DescriptorPool. Normally called only once per process. +PyDescriptorPool* NewDescriptorPool(); + +// Looks up a message by name. +// Returns a message Descriptor, or NULL if not found. +const Descriptor* FindMessageTypeByName(PyDescriptorPool* self, + const string& name); + +// Registers a new Python class for the given message descriptor. +// Returns the message Descriptor. +// On error, returns NULL with a Python exception set. +const Descriptor* RegisterMessageClass( + PyDescriptorPool* self, PyObject* message_class, PyObject* descriptor); + +// Retrieves the Python class registered with the given message descriptor. +// +// Returns a *borrowed* reference if found, otherwise returns NULL with an +// exception set. +PyObject* GetMessageClass(PyDescriptorPool* self, + const Descriptor* message_descriptor); + +// Looks up a message by name. Returns a PyMessageDescriptor corresponding to +// the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindMessageByName(PyDescriptorPool* self, PyObject* name); + +// Looks up a field by name. Returns a PyFieldDescriptor corresponding to +// the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindFieldByName(PyDescriptorPool* self, PyObject* name); + +// Looks up an extension by name. Returns a PyFieldDescriptor corresponding +// to the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindExtensionByName(PyDescriptorPool* self, PyObject* arg); + +// Looks up an enum type by name. Returns a PyEnumDescriptor corresponding +// to the field on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindEnumTypeByName(PyDescriptorPool* self, PyObject* arg); + +// Looks up a oneof by name. Returns a COneofDescriptor corresponding +// to the oneof on success, or NULL on failure. +// +// Returns a new reference. +PyObject* FindOneofByName(PyDescriptorPool* self, PyObject* arg); + +} // namespace cdescriptor_pool + +// Implement the Python "_BuildFile" method, it takes a serialized +// FileDescriptorProto, and adds it to the C++ DescriptorPool. +// It returns a new FileDescriptor object, or NULL when an exception is raised. +PyObject* Python_BuildFile(PyObject* ignored, PyObject* args); + +// Retrieve the global descriptor pool owned by the _message module. +PyDescriptorPool* GetDescriptorPool(); + +// Initialize objects used by this module. +bool InitDescriptorPool(); + +} // namespace python +} // namespace protobuf + +} // namespace google +#endif // GOOGLE_PROTOBUF_PYTHON_CPP_DESCRIPTOR_POOL_H__ diff --git a/python/google/protobuf/pyext/extension_dict.cc b/python/google/protobuf/pyext/extension_dict.cc index d83b57d5..8e38fc42 100644 --- a/python/google/protobuf/pyext/extension_dict.cc +++ b/python/google/protobuf/pyext/extension_dict.cc @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include @@ -48,13 +49,11 @@ namespace google { namespace protobuf { namespace python { -extern google::protobuf::DynamicMessageFactory* global_message_factory; - namespace extension_dict { // TODO(tibell): Always use self->message for clarity, just like in // RepeatedCompositeContainer. -static google::protobuf::Message* GetMessage(ExtensionDict* self) { +static Message* GetMessage(ExtensionDict* self) { if (self->parent != NULL) { return self->parent->message; } else { @@ -73,10 +72,9 @@ PyObject* len(ExtensionDict* self) { // TODO(tibell): Use VisitCompositeField. int ReleaseExtension(ExtensionDict* self, PyObject* extension, - const google::protobuf::FieldDescriptor* descriptor) { - if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { - if (descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + const FieldDescriptor* descriptor) { + if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (repeated_composite_container::Release( reinterpret_cast( extension)) < 0) { @@ -89,8 +87,7 @@ int ReleaseExtension(ExtensionDict* self, return -1; } } - } else if (descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (cmessage::ReleaseSubMessage( GetMessage(self), descriptor, reinterpret_cast(extension)) < 0) { @@ -102,8 +99,7 @@ int ReleaseExtension(ExtensionDict* self, } PyObject* subscript(ExtensionDict* self, PyObject* key) { - const google::protobuf::FieldDescriptor* descriptor = - cmessage::GetExtensionDescriptor(key); + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); if (descriptor == NULL) { return NULL; } @@ -162,8 +158,7 @@ PyObject* subscript(ExtensionDict* self, PyObject* key) { } int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { - const google::protobuf::FieldDescriptor* descriptor = - cmessage::GetExtensionDescriptor(key); + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); if (descriptor == NULL) { return -1; } @@ -187,7 +182,7 @@ int ass_subscript(ExtensionDict* self, PyObject* key, PyObject* value) { } PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { - const google::protobuf::FieldDescriptor* descriptor = + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(extension); if (descriptor == NULL) { return NULL; @@ -208,7 +203,7 @@ PyObject* ClearExtension(ExtensionDict* self, PyObject* extension) { } PyObject* HasExtension(ExtensionDict* self, PyObject* extension) { - const google::protobuf::FieldDescriptor* descriptor = + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(extension); if (descriptor == NULL) { return NULL; diff --git a/python/google/protobuf/pyext/extension_dict.h b/python/google/protobuf/pyext/extension_dict.h index 47625e23..7e1049f1 100644 --- a/python/google/protobuf/pyext/extension_dict.h +++ b/python/google/protobuf/pyext/extension_dict.h @@ -41,7 +41,6 @@ #include #endif - namespace google { namespace protobuf { @@ -94,7 +93,7 @@ PyObject* len(ExtensionDict* self); // Returns 0 on success, -1 on failure. int ReleaseExtension(ExtensionDict* self, PyObject* extension, - const google::protobuf::FieldDescriptor* descriptor); + const FieldDescriptor* descriptor); // Gets an extension from the dict for the given extension descriptor. // diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index cd956e0e..137f5d5f 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -39,6 +39,7 @@ #endif #include #include +#include // A Python header file. #ifndef PyVarObject_HEAD_INIT #define PyVarObject_HEAD_INIT(type, size) PyObject_HEAD_INIT(type) size, @@ -54,10 +55,12 @@ #include #include #include +#include #include #include #include #include +#include #if PY_MAJOR_VERSION >= 3 #define PyInt_Check PyLong_Check @@ -72,6 +75,10 @@ #else #define PyString_AsString(ob) \ (PyUnicode_Check(ob)? PyUnicode_AsUTF8(ob): PyBytes_AsString(ob)) + #define PyString_AsStringAndSize(ob, charpp, sizep) \ + (PyUnicode_Check(ob)? \ + ((*(charpp) = PyUnicode_AsUTF8AndSize(ob, (sizep))) == NULL? -1: 0): \ + PyBytes_AsStringAndSize(ob, (charpp), (sizep))) #endif #endif @@ -81,14 +88,14 @@ namespace python { // Forward declarations namespace cmessage { -static const google::protobuf::FieldDescriptor* GetFieldDescriptor( +static const FieldDescriptor* GetFieldDescriptor( CMessage* self, PyObject* name); -static const google::protobuf::Descriptor* GetMessageDescriptor(PyTypeObject* cls); +static const Descriptor* GetMessageDescriptor(PyTypeObject* cls); static string GetMessageName(CMessage* self); int InternalReleaseFieldByDescriptor( - const google::protobuf::FieldDescriptor* field_descriptor, + const FieldDescriptor* field_descriptor, PyObject* composite_field, - google::protobuf::Message* parent_message); + Message* parent_message); } // namespace cmessage // --------------------------------------------------------------------- @@ -107,7 +114,7 @@ struct ChildVisitor { // Returns 0 on success, -1 on failure. int VisitCMessage(CMessage* cmessage, - const google::protobuf::FieldDescriptor* field_descriptor) { + const FieldDescriptor* field_descriptor) { return 0; } }; @@ -118,8 +125,8 @@ template static int VisitCompositeField(const FieldDescriptor* descriptor, PyObject* child, Visitor visitor) { - if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { - if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { RepeatedCompositeContainer* container = reinterpret_cast(child); if (visitor.VisitRepeatedCompositeContainer(container) == -1) @@ -130,8 +137,7 @@ static int VisitCompositeField(const FieldDescriptor* descriptor, if (visitor.VisitRepeatedScalarContainer(container) == -1) return -1; } - } else if (descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { CMessage* cmsg = reinterpret_cast(child); if (visitor.VisitCMessage(cmsg, descriptor) == -1) return -1; @@ -149,25 +155,31 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { PyObject* key; PyObject* field; - // Never use self->message in this function, it may be already freed. - const google::protobuf::Descriptor* message_descriptor = - cmessage::GetMessageDescriptor(Py_TYPE(self)); - // Visit normal fields. - while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { - const google::protobuf::FieldDescriptor* descriptor = - message_descriptor->FindFieldByName(PyString_AsString(key)); - if (descriptor != NULL) { - if (VisitCompositeField(descriptor, field, visitor) == -1) + if (self->composite_fields) { + // Never use self->message in this function, it may be already freed. + const Descriptor* message_descriptor = + cmessage::GetMessageDescriptor(Py_TYPE(self)); + while (PyDict_Next(self->composite_fields, &pos, &key, &field)) { + Py_ssize_t key_str_size; + char *key_str_data; + if (PyString_AsStringAndSize(key, &key_str_data, &key_str_size) != 0) return -1; + const string key_str(key_str_data, key_str_size); + const FieldDescriptor* descriptor = + message_descriptor->FindFieldByName(key_str); + if (descriptor != NULL) { + if (VisitCompositeField(descriptor, field, visitor) == -1) + return -1; + } } } // Visit extension fields. if (self->extensions != NULL) { + pos = 0; while (PyDict_Next(self->extensions->values, &pos, &key, &field)) { - const google::protobuf::FieldDescriptor* descriptor = - cmessage::GetExtensionDescriptor(key); + const FieldDescriptor* descriptor = cmessage::GetExtensionDescriptor(key); if (descriptor == NULL) return -1; if (VisitCompositeField(descriptor, field, visitor) == -1) @@ -180,6 +192,8 @@ int ForEachCompositeField(CMessage* self, Visitor visitor) { // --------------------------------------------------------------------- +static DynamicMessageFactory* message_factory; + // Constants used for integer type range checking. PyObject* kPythonZero; PyObject* kint32min_py; @@ -198,17 +212,8 @@ PyObject* PickleError_class; static PyObject* kDESCRIPTOR; static PyObject* k_cdescriptor; static PyObject* kfull_name; -static PyObject* kname; -static PyObject* kextensions_by_name; static PyObject* k_extensions_by_name; static PyObject* k_extensions_by_number; -static PyObject* kfields_by_name; - -static PyDescriptorPool* descriptor_pool; - -PyDescriptorPool* GetDescriptorPool() { - return descriptor_pool; -} /* Is 64bit */ void FormatTypeError(PyObject* arg, char* expected_types) { @@ -305,14 +310,14 @@ bool CheckAndGetBool(PyObject* arg, bool* value) { } bool CheckAndSetString( - PyObject* arg, google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* descriptor, - const google::protobuf::Reflection* reflection, + PyObject* arg, Message* message, + const FieldDescriptor* descriptor, + const Reflection* reflection, bool append, int index) { - GOOGLE_DCHECK(descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING || - descriptor->type() == google::protobuf::FieldDescriptor::TYPE_BYTES); - if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { + GOOGLE_DCHECK(descriptor->type() == FieldDescriptor::TYPE_STRING || + descriptor->type() == FieldDescriptor::TYPE_BYTES); + if (descriptor->type() == FieldDescriptor::TYPE_STRING) { if (!PyBytes_Check(arg) && !PyUnicode_Check(arg)) { FormatTypeError(arg, "bytes, unicode"); return false; @@ -339,7 +344,7 @@ bool CheckAndSetString( } PyObject* encoded_string = NULL; - if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { + if (descriptor->type() == FieldDescriptor::TYPE_STRING) { if (PyBytes_Check(arg)) { // The bytes were already validated as correctly encoded UTF-8 above. encoded_string = arg; // Already encoded. @@ -376,9 +381,8 @@ bool CheckAndSetString( return true; } -PyObject* ToStringObject( - const google::protobuf::FieldDescriptor* descriptor, string value) { - if (descriptor->type() != google::protobuf::FieldDescriptor::TYPE_STRING) { +PyObject* ToStringObject(const FieldDescriptor* descriptor, string value) { + if (descriptor->type() != FieldDescriptor::TYPE_STRING) { return PyBytes_FromStringAndSize(value.c_str(), value.length()); } @@ -394,8 +398,8 @@ PyObject* ToStringObject( return result; } -bool CheckFieldBelongsToMessage(const google::protobuf::FieldDescriptor* field_descriptor, - const google::protobuf::Message* message) { +bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, + const Message* message) { if (message->GetDescriptor() == field_descriptor->containing_type()) { return true; } @@ -405,16 +409,18 @@ bool CheckFieldBelongsToMessage(const google::protobuf::FieldDescriptor* field_d return false; } -google::protobuf::DynamicMessageFactory* global_message_factory; - namespace cmessage { +DynamicMessageFactory* GetMessageFactory() { + return message_factory; +} + static int MaybeReleaseOverlappingOneofField( CMessage* cmessage, - const google::protobuf::FieldDescriptor* field) { + const FieldDescriptor* field) { #ifdef GOOGLE_PROTOBUF_HAS_ONEOF - google::protobuf::Message* message = cmessage->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); + Message* message = cmessage->message; + const Reflection* reflection = message->GetReflection(); if (!field->containing_oneof() || !reflection->HasOneof(*message, field->containing_oneof()) || reflection->HasField(*message, field)) { @@ -425,13 +431,13 @@ static int MaybeReleaseOverlappingOneofField( const OneofDescriptor* oneof = field->containing_oneof(); const FieldDescriptor* existing_field = reflection->GetOneofFieldDescriptor(*message, oneof); - if (existing_field->cpp_type() != google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (existing_field->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { // Non-message fields don't need to be released. return 0; } const char* field_name = existing_field->name().c_str(); - PyObject* child_message = PyDict_GetItemString( - cmessage->composite_fields, field_name); + PyObject* child_message = cmessage->composite_fields ? + PyDict_GetItemString(cmessage->composite_fields, field_name) : NULL; if (child_message == NULL) { // No python reference to this field so no need to release. return 0; @@ -450,21 +456,21 @@ static int MaybeReleaseOverlappingOneofField( // --------------------------------------------------------------------- // Making a message writable -static google::protobuf::Message* GetMutableMessage( +static Message* GetMutableMessage( CMessage* parent, - const google::protobuf::FieldDescriptor* parent_field) { - google::protobuf::Message* parent_message = parent->message; - const google::protobuf::Reflection* reflection = parent_message->GetReflection(); + const FieldDescriptor* parent_field) { + Message* parent_message = parent->message; + const Reflection* reflection = parent_message->GetReflection(); if (MaybeReleaseOverlappingOneofField(parent, parent_field) < 0) { return NULL; } return reflection->MutableMessage( - parent_message, parent_field, global_message_factory); + parent_message, parent_field, message_factory); } struct FixupMessageReference : public ChildVisitor { // message must outlive this object. - explicit FixupMessageReference(google::protobuf::Message* message) : + explicit FixupMessageReference(Message* message) : message_(message) {} int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { @@ -478,7 +484,7 @@ struct FixupMessageReference : public ChildVisitor { } private: - google::protobuf::Message* message_; + Message* message_; }; int AssureWritable(CMessage* self) { @@ -490,7 +496,7 @@ int AssureWritable(CMessage* self) { // If parent is NULL but we are trying to modify a read-only message, this // is a reference to a constant default instance that needs to be replaced // with a mutable top-level message. - const Message* prototype = global_message_factory->GetPrototype( + const Message* prototype = message_factory->GetPrototype( self->message->GetDescriptor()); self->message = prototype->New(); self->owner.reset(self->message); @@ -500,8 +506,8 @@ int AssureWritable(CMessage* self) { return -1; // Make self->message writable. - google::protobuf::Message* parent_message = self->parent->message; - google::protobuf::Message* mutable_message = GetMutableMessage( + Message* parent_message = self->parent->message; + Message* mutable_message = GetMutableMessage( self->parent, self->parent_field_descriptor); if (mutable_message == NULL) { @@ -528,38 +534,35 @@ int AssureWritable(CMessage* self) { // Retrieve the C++ Descriptor of a message class. // On error, returns NULL with an exception set. -static const google::protobuf::Descriptor* GetMessageDescriptor(PyTypeObject* cls) { +static const Descriptor* GetMessageDescriptor(PyTypeObject* cls) { ScopedPyObjectPtr descriptor(PyObject_GetAttr( reinterpret_cast(cls), kDESCRIPTOR)); if (descriptor == NULL) { PyErr_SetString(PyExc_TypeError, "Message class has no DESCRIPTOR"); return NULL; } - ScopedPyObjectPtr cdescriptor(PyObject_GetAttr(descriptor, k_cdescriptor)); - if (cdescriptor == NULL) { - PyErr_SetString(PyExc_TypeError, "Unregistered message."); + if (!PyObject_TypeCheck(descriptor, &PyMessageDescriptor_Type)) { + PyErr_Format(PyExc_TypeError, "Expected a message Descriptor, got %s", + descriptor->ob_type->tp_name); return NULL; } - if (!PyObject_TypeCheck(cdescriptor, &CMessageDescriptor_Type)) { - PyErr_SetString(PyExc_TypeError, "Not a CMessageDescriptor"); - return NULL; - } - return reinterpret_cast(cdescriptor.get())->descriptor; + return PyMessageDescriptor_AsDescriptor(descriptor); } // Retrieve a C++ FieldDescriptor for a message attribute. // The C++ message must be valid. // TODO(amauryfa): This function should stay internal, because exception // handling is not consistent. -static const google::protobuf::FieldDescriptor* GetFieldDescriptor( +static const FieldDescriptor* GetFieldDescriptor( CMessage* self, PyObject* name) { - const google::protobuf::Descriptor *message_descriptor = self->message->GetDescriptor(); - const char* field_name = PyString_AsString(name); - if (field_name == NULL) { + const Descriptor *message_descriptor = self->message->GetDescriptor(); + char* field_name; + Py_ssize_t size; + if (PyString_AsStringAndSize(name, &field_name, &size) < 0) { return NULL; } - const google::protobuf::FieldDescriptor *field_descriptor = - message_descriptor->FindFieldByName(field_name); + const FieldDescriptor *field_descriptor = + message_descriptor->FindFieldByName(string(field_name, size)); if (field_descriptor == NULL) { // Note: No exception is set! return NULL; @@ -568,19 +571,16 @@ static const google::protobuf::FieldDescriptor* GetFieldDescriptor( } // Retrieve a C++ FieldDescriptor for an extension handle. -const google::protobuf::FieldDescriptor* GetExtensionDescriptor(PyObject* extension) { - ScopedPyObjectPtr cdescriptor( - PyObject_GetAttrString(extension, "_cdescriptor")); - if (cdescriptor == NULL) { - PyErr_SetString(PyExc_KeyError, "Unregistered extension."); +const FieldDescriptor* GetExtensionDescriptor(PyObject* extension) { + ScopedPyObjectPtr cdescriptor; + if (!PyObject_TypeCheck(extension, &PyFieldDescriptor_Type)) { + // Most callers consider extensions as a plain dictionary. We should + // allow input which is not a field descriptor, and simply pretend it does + // not exist. + PyErr_SetObject(PyExc_KeyError, extension); return NULL; } - if (!PyObject_TypeCheck(cdescriptor, &CFieldDescriptor_Type)) { - PyErr_SetString(PyExc_TypeError, "Not a CFieldDescriptor"); - Py_DECREF(cdescriptor); - return NULL; - } - return reinterpret_cast(cdescriptor.get())->descriptor; + return PyFieldDescriptor_AsDescriptor(extension); } // If cmessage_list is not NULL, this function releases values into the @@ -588,12 +588,12 @@ const google::protobuf::FieldDescriptor* GetExtensionDescriptor(PyObject* extens // needs to do this to make sure CMessages stay alive if they're still // referenced after deletion. Repeated scalar container doesn't need to worry. int InternalDeleteRepeatedField( - google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* field_descriptor, + Message* message, + const FieldDescriptor* field_descriptor, PyObject* slice, PyObject* cmessage_list) { Py_ssize_t length, from, to, step, slice_length; - const google::protobuf::Reflection* reflection = message->GetReflection(); + const Reflection* reflection = message->GetReflection(); int min, max; length = reflection->FieldSize(*message, field_descriptor); @@ -690,18 +690,18 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { PyErr_SetString(PyExc_ValueError, "Field name must be a string"); return -1; } - const google::protobuf::FieldDescriptor* descriptor = GetFieldDescriptor(self, name); + const FieldDescriptor* descriptor = GetFieldDescriptor(self, name); if (descriptor == NULL) { PyErr_Format(PyExc_ValueError, "Protocol message has no \"%s\" field.", PyString_AsString(name)); return -1; } - if (descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (descriptor->label() == FieldDescriptor::LABEL_REPEATED) { ScopedPyObjectPtr container(GetAttr(self, name)); if (container == NULL) { return -1; } - if (descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { if (repeated_composite_container::Extend( reinterpret_cast(container.get()), value) @@ -716,8 +716,7 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { return -1; } } - } else if (descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + } else if (descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { ScopedPyObjectPtr message(GetAttr(self, name)); if (message == NULL) { return -1; @@ -737,8 +736,7 @@ int InitAttributes(CMessage* self, PyObject* kwargs) { // Allocates an incomplete Python Message: the caller must fill self->message, // self->owner and eventually self->parent. -CMessage* NewEmptyMessage(PyObject* type, - const google::protobuf::Descriptor *descriptor) { +CMessage* NewEmptyMessage(PyObject* type, const Descriptor *descriptor) { CMessage* self = reinterpret_cast( PyType_GenericAlloc(reinterpret_cast(type), 0)); if (self == NULL) { @@ -751,10 +749,7 @@ CMessage* NewEmptyMessage(PyObject* type, self->read_only = false; self->extensions = NULL; - self->composite_fields = PyDict_New(); - if (self->composite_fields == NULL) { - return NULL; - } + self->composite_fields = NULL; // If there are extension_ranges, the message is "extendable". Allocate a // dictionary to store the extension fields. @@ -776,12 +771,12 @@ CMessage* NewEmptyMessage(PyObject* type, static PyObject* New(PyTypeObject* type, PyObject* unused_args, PyObject* unused_kwargs) { // Retrieve the message descriptor and the default instance (=prototype). - const google::protobuf::Descriptor* message_descriptor = GetMessageDescriptor(type); + const Descriptor* message_descriptor = GetMessageDescriptor(type); if (message_descriptor == NULL) { return NULL; } - const google::protobuf::Message* default_message = - global_message_factory->GetPrototype(message_descriptor); + const Message* default_message = + message_factory->GetPrototype(message_descriptor); if (default_message == NULL) { PyErr_SetString(PyExc_TypeError, message_descriptor->full_name().c_str()); return NULL; @@ -836,7 +831,7 @@ struct ClearWeakReferences : public ChildVisitor { } int VisitCMessage(CMessage* cmessage, - const google::protobuf::FieldDescriptor* field_descriptor) { + const FieldDescriptor* field_descriptor) { cmessage->parent = NULL; return 0; } @@ -886,12 +881,12 @@ PyObject* IsInitialized(CMessage* self, PyObject* args) { } PyObject* HasFieldByDescriptor( - CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor) { - google::protobuf::Message* message = self->message; + CMessage* self, const FieldDescriptor* field_descriptor) { + Message* message = self->message; if (!CheckFieldBelongsToMessage(field_descriptor, message)) { return NULL; } - if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyErr_SetString(PyExc_KeyError, "Field is repeated. A singular method is required."); return NULL; @@ -901,42 +896,78 @@ PyObject* HasFieldByDescriptor( return PyBool_FromLong(has_field ? 1 : 0); } -const google::protobuf::FieldDescriptor* FindFieldWithOneofs( - const google::protobuf::Message* message, const char* field_name, bool* in_oneof) { - const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); - const google::protobuf::FieldDescriptor* field_descriptor = +const FieldDescriptor* FindFieldWithOneofs( + const Message* message, const string& field_name, bool* in_oneof) { + *in_oneof = false; + const Descriptor* descriptor = message->GetDescriptor(); + const FieldDescriptor* field_descriptor = descriptor->FindFieldByName(field_name); - if (field_descriptor == NULL) { - const google::protobuf::OneofDescriptor* oneof_desc = - message->GetDescriptor()->FindOneofByName(field_name); - if (oneof_desc == NULL) { - *in_oneof = false; - return NULL; - } else { - *in_oneof = true; - return message->GetReflection()->GetOneofFieldDescriptor( - *message, oneof_desc); + if (field_descriptor != NULL) { + return field_descriptor; + } + const OneofDescriptor* oneof_desc = + descriptor->FindOneofByName(field_name); + if (oneof_desc != NULL) { + *in_oneof = true; + return message->GetReflection()->GetOneofFieldDescriptor(*message, + oneof_desc); + } + return NULL; +} + +bool CheckHasPresence(const FieldDescriptor* field_descriptor, bool in_oneof) { + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + PyErr_Format(PyExc_ValueError, + "Protocol message has no singular \"%s\" field.", + field_descriptor->name().c_str()); + return false; + } + + if (field_descriptor->file()->syntax() == FileDescriptor::SYNTAX_PROTO3) { + // HasField() for a oneof *itself* isn't supported. + if (in_oneof) { + PyErr_Format(PyExc_ValueError, + "Can't test oneof field \"%s\" for presence in proto3, use " + "WhichOneof instead.", + field_descriptor->containing_oneof()->name().c_str()); + return false; + } + + // ...but HasField() for fields *in* a oneof is supported. + if (field_descriptor->containing_oneof() != NULL) { + return true; + } + + if (field_descriptor->cpp_type() != FieldDescriptor::CPPTYPE_MESSAGE) { + PyErr_Format( + PyExc_ValueError, + "Can't test non-submessage field \"%s\" for presence in proto3.", + field_descriptor->name().c_str()); + return false; } } - return field_descriptor; + + return true; } PyObject* HasField(CMessage* self, PyObject* arg) { -#if PY_MAJOR_VERSION < 3 char* field_name; - if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) { + Py_ssize_t size; +#if PY_MAJOR_VERSION < 3 + if (PyString_AsStringAndSize(arg, &field_name, &size) < 0) { + return NULL; + } #else - char* field_name = PyUnicode_AsUTF8(arg); + field_name = PyUnicode_AsUTF8AndSize(arg, &size); if (!field_name) { -#endif return NULL; } +#endif - google::protobuf::Message* message = self->message; - const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + Message* message = self->message; bool is_in_oneof; - const google::protobuf::FieldDescriptor* field_descriptor = - FindFieldWithOneofs(message, field_name, &is_in_oneof); + const FieldDescriptor* field_descriptor = + FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof); if (field_descriptor == NULL) { if (!is_in_oneof) { PyErr_Format(PyExc_ValueError, "Unknown field %s.", field_name); @@ -946,28 +977,28 @@ PyObject* HasField(CMessage* self, PyObject* arg) { } } - if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { - PyErr_Format(PyExc_ValueError, - "Protocol message has no singular \"%s\" field.", field_name); + if (!CheckHasPresence(field_descriptor, is_in_oneof)) { return NULL; } - bool has_field = - message->GetReflection()->HasField(*message, field_descriptor); - if (!has_field && field_descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_ENUM) { - // We may have an invalid enum value stored in the UnknownFieldSet and need - // to check presence in there as well. - const google::protobuf::UnknownFieldSet& unknown_field_set = + if (message->GetReflection()->HasField(*message, field_descriptor)) { + Py_RETURN_TRUE; + } + if (!message->GetReflection()->SupportsUnknownEnumValues() && + field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + // Special case: Python HasField() differs in semantics from C++ + // slightly: we return HasField('enum_field') == true if there is + // an unknown enum value present. To implement this we have to + // look in the UnknownFieldSet. + const UnknownFieldSet& unknown_field_set = message->GetReflection()->GetUnknownFields(*message); for (int i = 0; i < unknown_field_set.field_count(); ++i) { if (unknown_field_set.field(i).number() == field_descriptor->number()) { Py_RETURN_TRUE; } } - Py_RETURN_FALSE; } - return PyBool_FromLong(has_field ? 1 : 0); + Py_RETURN_FALSE; } PyObject* ClearExtension(CMessage* self, PyObject* arg) { @@ -1034,7 +1065,7 @@ struct SetOwnerVisitor : public ChildVisitor { } int VisitCMessage(CMessage* cmessage, - const google::protobuf::FieldDescriptor* field_descriptor) { + const FieldDescriptor* field_descriptor) { return SetOwner(cmessage, new_owner_); } @@ -1053,18 +1084,17 @@ int SetOwner(CMessage* self, const shared_ptr& new_owner) { // Releases the message specified by 'field' and returns the // pointer. If the field does not exist a new message is created using // 'descriptor'. The caller takes ownership of the returned pointer. -Message* ReleaseMessage(google::protobuf::Message* message, - const google::protobuf::Descriptor* descriptor, - const google::protobuf::FieldDescriptor* field_descriptor) { +Message* ReleaseMessage(Message* message, + const Descriptor* descriptor, + const FieldDescriptor* field_descriptor) { Message* released_message = message->GetReflection()->ReleaseMessage( - message, field_descriptor, global_message_factory); + message, field_descriptor, message_factory); // ReleaseMessage will return NULL which differs from // child_cmessage->message, if the field does not exist. In this case, // the latter points to the default instance via a const_cast<>, so we // have to reset it to a new mutable object since we are taking ownership. if (released_message == NULL) { - const Message* prototype = global_message_factory->GetPrototype( - descriptor); + const Message* prototype = message_factory->GetPrototype(descriptor); GOOGLE_DCHECK(prototype != NULL); released_message = prototype->New(); } @@ -1072,8 +1102,8 @@ Message* ReleaseMessage(google::protobuf::Message* message, return released_message; } -int ReleaseSubMessage(google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* field_descriptor, +int ReleaseSubMessage(Message* message, + const FieldDescriptor* field_descriptor, CMessage* child_cmessage) { // Release the Message shared_ptr released_message(ReleaseMessage( @@ -1089,7 +1119,7 @@ int ReleaseSubMessage(google::protobuf::Message* message, struct ReleaseChild : public ChildVisitor { // message must outlive this object. - explicit ReleaseChild(google::protobuf::Message* parent_message) : + explicit ReleaseChild(Message* parent_message) : parent_message_(parent_message) {} int VisitRepeatedCompositeContainer(RepeatedCompositeContainer* container) { @@ -1103,38 +1133,27 @@ struct ReleaseChild : public ChildVisitor { } int VisitCMessage(CMessage* cmessage, - const google::protobuf::FieldDescriptor* field_descriptor) { + const FieldDescriptor* field_descriptor) { return ReleaseSubMessage(parent_message_, field_descriptor, reinterpret_cast(cmessage)); } - google::protobuf::Message* parent_message_; + Message* parent_message_; }; int InternalReleaseFieldByDescriptor( - const google::protobuf::FieldDescriptor* field_descriptor, + const FieldDescriptor* field_descriptor, PyObject* composite_field, - google::protobuf::Message* parent_message) { + Message* parent_message) { return VisitCompositeField( field_descriptor, composite_field, ReleaseChild(parent_message)); } -int InternalReleaseField(CMessage* self, PyObject* composite_field, - PyObject* name) { - const google::protobuf::FieldDescriptor* descriptor = GetFieldDescriptor(self, name); - if (descriptor != NULL) { - return InternalReleaseFieldByDescriptor( - descriptor, composite_field, self->message); - } - - return 0; -} - PyObject* ClearFieldByDescriptor( CMessage* self, - const google::protobuf::FieldDescriptor* descriptor) { + const FieldDescriptor* descriptor) { if (!CheckFieldBelongsToMessage(descriptor, self->message)) { return NULL; } @@ -1144,25 +1163,23 @@ PyObject* ClearFieldByDescriptor( } PyObject* ClearField(CMessage* self, PyObject* arg) { - char* field_name; if (!PyString_Check(arg)) { PyErr_SetString(PyExc_TypeError, "field name must be a string"); return NULL; } #if PY_MAJOR_VERSION < 3 - if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) { - return NULL; - } + const char* field_name = PyString_AS_STRING(arg); + Py_ssize_t size = PyString_GET_SIZE(arg); #else - field_name = PyUnicode_AsUTF8(arg); + Py_ssize_t size; + const char* field_name = PyUnicode_AsUTF8AndSize(arg, &size); #endif AssureWritable(self); - google::protobuf::Message* message = self->message; - const google::protobuf::Descriptor* descriptor = message->GetDescriptor(); + Message* message = self->message; ScopedPyObjectPtr arg_in_oneof; bool is_in_oneof; - const google::protobuf::FieldDescriptor* field_descriptor = - FindFieldWithOneofs(message, field_name, &is_in_oneof); + const FieldDescriptor* field_descriptor = + FindFieldWithOneofs(message, string(field_name, size), &is_in_oneof); if (field_descriptor == NULL) { if (!is_in_oneof) { PyErr_Format(PyExc_ValueError, @@ -1172,24 +1189,27 @@ PyObject* ClearField(CMessage* self, PyObject* arg) { Py_RETURN_NONE; } } else if (is_in_oneof) { - arg_in_oneof.reset(PyString_FromString(field_descriptor->name().c_str())); + const string& name = field_descriptor->name(); + arg_in_oneof.reset(PyString_FromStringAndSize(name.c_str(), name.size())); arg = arg_in_oneof.get(); } - PyObject* composite_field = PyDict_GetItem(self->composite_fields, - arg); + PyObject* composite_field = self->composite_fields ? + PyDict_GetItem(self->composite_fields, arg) : NULL; // Only release the field if there's a possibility that there are // references to it. if (composite_field != NULL) { - if (InternalReleaseField(self, composite_field, arg) < 0) { + if (InternalReleaseFieldByDescriptor(field_descriptor, + composite_field, message) < 0) { return NULL; } PyDict_DelItem(self->composite_fields, arg); } message->GetReflection()->ClearField(message, field_descriptor); - if (field_descriptor->cpp_type() == google::protobuf::FieldDescriptor::CPPTYPE_ENUM) { - google::protobuf::UnknownFieldSet* unknown_field_set = + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_ENUM && + !message->GetReflection()->SupportsUnknownEnumValues()) { + UnknownFieldSet* unknown_field_set = message->GetReflection()->MutableUnknownFields(message); unknown_field_set->DeleteByNumber(field_descriptor->number()); } @@ -1212,7 +1232,9 @@ PyObject* Clear(CMessage* self) { } self->extensions = extension_dict; } - PyDict_Clear(self->composite_fields); + if (self->composite_fields) { + PyDict_Clear(self->composite_fields); + } self->message->Clear(); Py_RETURN_NONE; } @@ -1268,7 +1290,7 @@ static PyObject* SerializePartialToString(CMessage* self) { // Formats proto fields for ascii dumps using python formatting functions where // appropriate. -class PythonFieldValuePrinter : public google::protobuf::TextFormat::FieldValuePrinter { +class PythonFieldValuePrinter : public TextFormat::FieldValuePrinter { public: PythonFieldValuePrinter() : float_holder_(PyFloat_FromDouble(0)) {} @@ -1301,7 +1323,7 @@ class PythonFieldValuePrinter : public google::protobuf::TextFormat::FieldValueP }; static PyObject* ToStr(CMessage* self) { - google::protobuf::TextFormat::Printer printer; + TextFormat::Printer printer; // Passes ownership printer.SetDefaultFieldValuePrinter(new PythonFieldValuePrinter()); printer.SetHideUnknownFields(true); @@ -1383,9 +1405,9 @@ static PyObject* MergeFromString(CMessage* self, PyObject* arg) { } AssureWritable(self); - google::protobuf::io::CodedInputStream input( + io::CodedInputStream input( reinterpret_cast(data), data_length); - input.SetExtensionRegistry(descriptor_pool->pool, global_message_factory); + input.SetExtensionRegistry(GetDescriptorPool()->pool, message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { return PyInt_FromLong(input.CurrentPosition()); @@ -1412,10 +1434,22 @@ static PyObject* RegisterExtension(PyObject* cls, if (message_descriptor == NULL) { return NULL; } - if (PyObject_SetAttrString(extension_handle, "containing_type", - message_descriptor) < 0) { + + const FieldDescriptor* descriptor = + GetExtensionDescriptor(extension_handle); + if (descriptor == NULL) { return NULL; } + const Descriptor* cmessage_descriptor = GetMessageDescriptor( + reinterpret_cast(cls)); + + if (cmessage_descriptor != descriptor->containing_type()) { + if (PyObject_SetAttrString(extension_handle, "containing_type", + message_descriptor) < 0) { + return NULL; + } + } + ScopedPyObjectPtr extensions_by_name( PyObject_GetAttr(cls, k_extensions_by_name)); if (extensions_by_name == NULL) { @@ -1426,6 +1460,20 @@ static PyObject* RegisterExtension(PyObject* cls, if (full_name == NULL) { return NULL; } + + // If the extension was already registered, check that it is the same. + PyObject* existing_extension = PyDict_GetItem(extensions_by_name, full_name); + if (existing_extension != NULL) { + const FieldDescriptor* existing_extension_descriptor = + GetExtensionDescriptor(existing_extension); + if (existing_extension_descriptor != descriptor) { + PyErr_SetString(PyExc_ValueError, "Double registration of Extensions"); + return NULL; + } + // Nothing else to do. + Py_RETURN_NONE; + } + if (PyDict_SetItem(extensions_by_name, full_name, extension_handle) < 0) { return NULL; } @@ -1445,17 +1493,12 @@ static PyObject* RegisterExtension(PyObject* cls, return NULL; } - const google::protobuf::FieldDescriptor* descriptor = - GetExtensionDescriptor(extension_handle); - if (descriptor == NULL) { - return NULL; - } // Check if it's a message set if (descriptor->is_extension() && descriptor->containing_type()->options().message_set_wire_format() && - descriptor->type() == google::protobuf::FieldDescriptor::TYPE_MESSAGE && + descriptor->type() == FieldDescriptor::TYPE_MESSAGE && descriptor->message_type() == descriptor->extension_scope() && - descriptor->label() == google::protobuf::FieldDescriptor::LABEL_OPTIONAL) { + descriptor->label() == FieldDescriptor::LABEL_OPTIONAL) { ScopedPyObjectPtr message_name(PyString_FromStringAndSize( descriptor->message_type()->full_name().c_str(), descriptor->message_type()->full_name().size())); @@ -1474,53 +1517,36 @@ static PyObject* SetInParent(CMessage* self, PyObject* args) { } static PyObject* WhichOneof(CMessage* self, PyObject* arg) { - char* oneof_name; - if (!PyString_Check(arg)) { - PyErr_SetString(PyExc_TypeError, "field name must be a string"); + Py_ssize_t name_size; + char *name_data; + if (PyString_AsStringAndSize(arg, &name_data, &name_size) < 0) return NULL; - } - oneof_name = PyString_AsString(arg); - if (oneof_name == NULL) { - return NULL; - } - const google::protobuf::OneofDescriptor* oneof_desc = + string oneof_name = string(name_data, name_size); + const OneofDescriptor* oneof_desc = self->message->GetDescriptor()->FindOneofByName(oneof_name); if (oneof_desc == NULL) { PyErr_Format(PyExc_ValueError, - "Protocol message has no oneof \"%s\" field.", oneof_name); + "Protocol message has no oneof \"%s\" field.", + oneof_name.c_str()); return NULL; } - const google::protobuf::FieldDescriptor* field_in_oneof = + const FieldDescriptor* field_in_oneof = self->message->GetReflection()->GetOneofFieldDescriptor( *self->message, oneof_desc); if (field_in_oneof == NULL) { Py_RETURN_NONE; } else { - return PyString_FromString(field_in_oneof->name().c_str()); + const string& name = field_in_oneof->name(); + return PyString_FromStringAndSize(name.c_str(), name.size()); } } static PyObject* ListFields(CMessage* self) { - vector fields; + vector fields; self->message->GetReflection()->ListFields(*self->message, &fields); - PyObject* descriptor = PyDict_GetItem(Py_TYPE(self)->tp_dict, kDESCRIPTOR); - if (descriptor == NULL) { - return NULL; - } - ScopedPyObjectPtr fields_by_name( - PyObject_GetAttr(descriptor, kfields_by_name)); - if (fields_by_name == NULL) { - return NULL; - } - ScopedPyObjectPtr extensions_by_name(PyObject_GetAttr( - reinterpret_cast(Py_TYPE(self)), k_extensions_by_name)); - if (extensions_by_name == NULL) { - PyErr_SetString(PyExc_ValueError, "no extensionsbyname"); - return NULL; - } // Normally, the list will be exactly the size of the fields. - PyObject* all_fields = PyList_New(fields.size()); + ScopedPyObjectPtr all_fields(PyList_New(fields.size())); if (all_fields == NULL) { return NULL; } @@ -1532,35 +1558,34 @@ static PyObject* ListFields(CMessage* self) { for (Py_ssize_t i = 0; i < fields.size(); ++i) { ScopedPyObjectPtr t(PyTuple_New(2)); if (t == NULL) { - Py_DECREF(all_fields); return NULL; } if (fields[i]->is_extension()) { - const string& field_name = fields[i]->full_name(); - PyObject* extension_field = PyDict_GetItemString(extensions_by_name, - field_name.c_str()); + ScopedPyObjectPtr extension_field(PyFieldDescriptor_New(fields[i])); if (extension_field == NULL) { - // If we couldn't fetch extension_field, it means the module that - // defines this extension has not been explicitly imported in Python - // code, and the extension hasn't been registered. There's nothing much - // we can do about this, so just skip it in the output to match the - // behavior of the python implementation. + return NULL; + } + // With C++ descriptors, the field can always be retrieved, but for + // unknown extensions which have not been imported in Python code, there + // is no message class and we cannot retrieve the value. + // TODO(amauryfa): consider building the class on the fly! + if (fields[i]->message_type() != NULL && + cdescriptor_pool::GetMessageClass( + GetDescriptorPool(), fields[i]->message_type()) == NULL) { + PyErr_Clear(); continue; } PyObject* extensions = reinterpret_cast(self->extensions); if (extensions == NULL) { - Py_DECREF(all_fields); return NULL; } // 'extension' reference later stolen by PyTuple_SET_ITEM. PyObject* extension = PyObject_GetItem(extensions, extension_field); if (extension == NULL) { - Py_DECREF(all_fields); return NULL; } - Py_INCREF(extension_field); - PyTuple_SET_ITEM(t.get(), 0, extension_field); + PyTuple_SET_ITEM(t.get(), 0, extension_field.release()); // Steals reference to 'extension' PyTuple_SET_ITEM(t.get(), 1, extension); } else { @@ -1569,35 +1594,30 @@ static PyObject* ListFields(CMessage* self) { field_name.c_str(), field_name.length())); if (py_field_name == NULL) { PyErr_SetString(PyExc_ValueError, "bad string"); - Py_DECREF(all_fields); return NULL; } - PyObject* field_descriptor = - PyDict_GetItem(fields_by_name, py_field_name); + ScopedPyObjectPtr field_descriptor(PyFieldDescriptor_New(fields[i])); if (field_descriptor == NULL) { - Py_DECREF(all_fields); return NULL; } PyObject* field_value = GetAttr(self, py_field_name); if (field_value == NULL) { PyErr_SetObject(PyExc_ValueError, py_field_name); - Py_DECREF(all_fields); return NULL; } - Py_INCREF(field_descriptor); - PyTuple_SET_ITEM(t.get(), 0, field_descriptor); + PyTuple_SET_ITEM(t.get(), 0, field_descriptor.release()); PyTuple_SET_ITEM(t.get(), 1, field_value); } - PyList_SET_ITEM(all_fields, actual_size, t.release()); + PyList_SET_ITEM(all_fields.get(), actual_size, t.release()); ++actual_size; } - Py_SIZE(all_fields) = actual_size; - return all_fields; + Py_SIZE(all_fields.get()) = actual_size; + return all_fields.release(); } PyObject* FindInitializationErrors(CMessage* self) { - google::protobuf::Message* message = self->message; + Message* message = self->message; vector errors; message->FindInitializationErrors(&errors); @@ -1645,9 +1665,9 @@ static PyObject* RichCompare(CMessage* self, PyObject* other, int opid) { PyObject* InternalGetScalar( CMessage* self, - const google::protobuf::FieldDescriptor* field_descriptor) { - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); + const FieldDescriptor* field_descriptor) { + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); if (!CheckFieldBelongsToMessage(field_descriptor, message)) { return NULL; @@ -1655,50 +1675,51 @@ PyObject* InternalGetScalar( PyObject* result = NULL; switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { int32 value = reflection->GetInt32(*message, field_descriptor); result = PyInt_FromLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { int64 value = reflection->GetInt64(*message, field_descriptor); result = PyLong_FromLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { uint32 value = reflection->GetUInt32(*message, field_descriptor); result = PyInt_FromSize_t(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { uint64 value = reflection->GetUInt64(*message, field_descriptor); result = PyLong_FromUnsignedLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { float value = reflection->GetFloat(*message, field_descriptor); result = PyFloat_FromDouble(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { double value = reflection->GetDouble(*message, field_descriptor); result = PyFloat_FromDouble(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { bool value = reflection->GetBool(*message, field_descriptor); result = PyBool_FromLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { string value = reflection->GetString(*message, field_descriptor); result = ToStringObject(field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { - if (!message->GetReflection()->HasField(*message, field_descriptor)) { + case FieldDescriptor::CPPTYPE_ENUM: { + if (!message->GetReflection()->SupportsUnknownEnumValues() && + !message->GetReflection()->HasField(*message, field_descriptor)) { // Look for the value in the unknown fields. - google::protobuf::UnknownFieldSet* unknown_field_set = + UnknownFieldSet* unknown_field_set = message->GetReflection()->MutableUnknownFields(message); for (int i = 0; i < unknown_field_set->field_count(); ++i) { if (unknown_field_set->field(i).number() == @@ -1710,7 +1731,7 @@ PyObject* InternalGetScalar( } if (result == NULL) { - const google::protobuf::EnumValueDescriptor* enum_value = + const EnumValueDescriptor* enum_value = message->GetReflection()->GetEnum(*message, field_descriptor); result = PyInt_FromLong(enum_value->number()); } @@ -1726,13 +1747,13 @@ PyObject* InternalGetScalar( } PyObject* InternalGetSubMessage( - CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor) { - const google::protobuf::Reflection* reflection = self->message->GetReflection(); - const google::protobuf::Message& sub_message = reflection->GetMessage( - *self->message, field_descriptor, global_message_factory); + CMessage* self, const FieldDescriptor* field_descriptor) { + const Reflection* reflection = self->message->GetReflection(); + const Message& sub_message = reflection->GetMessage( + *self->message, field_descriptor, message_factory); PyObject *message_class = cdescriptor_pool::GetMessageClass( - descriptor_pool, field_descriptor->message_type()); + GetDescriptorPool(), field_descriptor->message_type()); if (message_class == NULL) { return NULL; } @@ -1747,17 +1768,17 @@ PyObject* InternalGetSubMessage( cmsg->parent = self; cmsg->parent_field_descriptor = field_descriptor; cmsg->read_only = !reflection->HasField(*self->message, field_descriptor); - cmsg->message = const_cast(&sub_message); + cmsg->message = const_cast(&sub_message); return reinterpret_cast(cmsg); } int InternalSetScalar( CMessage* self, - const google::protobuf::FieldDescriptor* field_descriptor, + const FieldDescriptor* field_descriptor, PyObject* arg) { - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); if (!CheckFieldBelongsToMessage(field_descriptor, message)) { return -1; @@ -1768,59 +1789,62 @@ int InternalSetScalar( } switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { GOOGLE_CHECK_GET_INT32(arg, value, -1); reflection->SetInt32(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { GOOGLE_CHECK_GET_INT64(arg, value, -1); reflection->SetInt64(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { GOOGLE_CHECK_GET_UINT32(arg, value, -1); reflection->SetUInt32(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { GOOGLE_CHECK_GET_UINT64(arg, value, -1); reflection->SetUInt64(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { GOOGLE_CHECK_GET_FLOAT(arg, value, -1); reflection->SetFloat(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { GOOGLE_CHECK_GET_DOUBLE(arg, value, -1); reflection->SetDouble(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { GOOGLE_CHECK_GET_BOOL(arg, value, -1); reflection->SetBool(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { if (!CheckAndSetString( arg, message, field_descriptor, reflection, false, -1)) { return -1; } break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + case FieldDescriptor::CPPTYPE_ENUM: { GOOGLE_CHECK_GET_INT32(arg, value, -1); - const google::protobuf::EnumDescriptor* enum_descriptor = - field_descriptor->enum_type(); - const google::protobuf::EnumValueDescriptor* enum_value = - enum_descriptor->FindValueByNumber(value); - if (enum_value != NULL) { - reflection->SetEnum(message, field_descriptor, enum_value); + if (reflection->SupportsUnknownEnumValues()) { + reflection->SetEnumValue(message, field_descriptor, value); } else { - PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); - return -1; + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->SetEnum(message, field_descriptor, enum_value); + } else { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %d", value); + return -1; + } } break; } @@ -1851,14 +1875,35 @@ PyObject* FromString(PyTypeObject* cls, PyObject* serialized) { return py_cmsg; } +// Add the number of a field descriptor to the containing message class. +// Equivalent to: +// _cls._FIELD_NUMBER = +static bool AddFieldNumberToClass( + PyObject* cls, const FieldDescriptor* field_descriptor) { + string constant_name = field_descriptor->name() + "_FIELD_NUMBER"; + UpperString(&constant_name); + ScopedPyObjectPtr attr_name(PyString_FromStringAndSize( + constant_name.c_str(), constant_name.size())); + if (attr_name == NULL) { + return false; + } + ScopedPyObjectPtr number(PyInt_FromLong(field_descriptor->number())); + if (number == NULL) { + return false; + } + if (PyObject_SetAttr(cls, attr_name, number) == -1) { + return false; + } + return true; +} + // Finalize the creation of the Message class. // Called from its metaclass: GeneratedProtocolMessageType.__init__(). -static PyObject* AddDescriptors(PyTypeObject* cls, - PyObject* descriptor) { - const google::protobuf::Descriptor* message_descriptor = +static PyObject* AddDescriptors(PyObject* cls, PyObject* descriptor) { + const Descriptor* message_descriptor = cdescriptor_pool::RegisterMessageClass( - descriptor_pool, reinterpret_cast(cls), descriptor); + GetDescriptorPool(), cls, descriptor); if (message_descriptor == NULL) { return NULL; } @@ -1867,169 +1912,80 @@ static PyObject* AddDescriptors(PyTypeObject* cls, // classes will register themselves in this class. if (message_descriptor->extension_range_count() > 0) { ScopedPyObjectPtr by_name(PyDict_New()); - if (PyObject_SetAttr(reinterpret_cast(cls), - k_extensions_by_name, by_name) < 0) { + if (PyObject_SetAttr(cls, k_extensions_by_name, by_name) < 0) { return NULL; } ScopedPyObjectPtr by_number(PyDict_New()); - if (PyObject_SetAttr(reinterpret_cast(cls), - k_extensions_by_number, by_number) < 0) { + if (PyObject_SetAttr(cls, k_extensions_by_number, by_number) < 0) { return NULL; } } - ScopedPyObjectPtr fields(PyObject_GetAttrString(descriptor, "fields")); - if (fields == NULL) { - return NULL; - } - - ScopedPyObjectPtr _NUMBER_string(PyString_FromString("_FIELD_NUMBER")); - if (_NUMBER_string == NULL) { - return NULL; - } - - const Py_ssize_t fields_size = PyList_GET_SIZE(fields.get()); - for (int i = 0; i < fields_size; ++i) { - PyObject* field = PyList_GET_ITEM(fields.get(), i); - ScopedPyObjectPtr field_name(PyObject_GetAttr(field, kname)); - ScopedPyObjectPtr full_field_name(PyObject_GetAttr(field, kfull_name)); - if (field_name == NULL || full_field_name == NULL) { - PyErr_SetString(PyExc_TypeError, "Name is null"); - return NULL; - } - - ScopedPyObjectPtr field_descriptor( - cdescriptor_pool::FindFieldByName(descriptor_pool, full_field_name)); - if (field_descriptor == NULL) { - PyErr_SetString(PyExc_TypeError, "Couldn't find field"); - return NULL; - } - CFieldDescriptor* cfield_descriptor = reinterpret_cast( - field_descriptor.get()); - - // The FieldDescriptor's name field might either be of type bytes or - // of type unicode, depending on whether the FieldDescriptor was - // parsed from a serialized message or read from the - // _pb2.py module. - ScopedPyObjectPtr field_name_upcased( - PyObject_CallMethod(field_name, "upper", NULL)); - if (field_name_upcased == NULL) { - return NULL; - } - - ScopedPyObjectPtr field_number_name(PyObject_CallMethod( - field_name_upcased, "__add__", "(O)", _NUMBER_string.get())); - if (field_number_name == NULL) { - return NULL; - } - - ScopedPyObjectPtr number(PyInt_FromLong( - cfield_descriptor->descriptor->number())); - if (number == NULL) { - return NULL; - } - if (PyObject_SetAttr(reinterpret_cast(cls), - field_number_name, number) == -1) { + // For each field set: cls._FIELD_NUMBER = + for (int i = 0; i < message_descriptor->field_count(); ++i) { + if (!AddFieldNumberToClass(cls, message_descriptor->field(i))) { return NULL; } } - // Enum Values - ScopedPyObjectPtr enum_types(PyObject_GetAttrString(descriptor, - "enum_types")); - if (enum_types == NULL) { - return NULL; - } - ScopedPyObjectPtr type_iter(PyObject_GetIter(enum_types)); - if (type_iter == NULL) { - return NULL; - } - ScopedPyObjectPtr enum_type; - while ((enum_type.reset(PyIter_Next(type_iter))) != NULL) { + // For each enum set cls. = EnumTypeWrapper(). + // + // The enum descriptor we get from + // .enum_types_by_name[name] + // which was built previously. + for (int i = 0; i < message_descriptor->enum_type_count(); ++i) { + const EnumDescriptor* enum_descriptor = message_descriptor->enum_type(i); + ScopedPyObjectPtr enum_type(PyEnumDescriptor_New(enum_descriptor)); + if (enum_type == NULL) { + return NULL; + } + // Add wrapped enum type to message class. ScopedPyObjectPtr wrapped(PyObject_CallFunctionObjArgs( EnumTypeWrapper_class, enum_type.get(), NULL)); if (wrapped == NULL) { return NULL; } - ScopedPyObjectPtr enum_name(PyObject_GetAttr(enum_type, kname)); - if (enum_name == NULL) { - return NULL; - } - if (PyObject_SetAttr(reinterpret_cast(cls), - enum_name, wrapped) == -1) { + if (PyObject_SetAttrString( + cls, enum_descriptor->name().c_str(), wrapped) == -1) { return NULL; } - ScopedPyObjectPtr enum_values(PyObject_GetAttrString(enum_type, "values")); - if (enum_values == NULL) { - return NULL; - } - ScopedPyObjectPtr values_iter(PyObject_GetIter(enum_values)); - if (values_iter == NULL) { - return NULL; - } - ScopedPyObjectPtr enum_value; - while ((enum_value.reset(PyIter_Next(values_iter))) != NULL) { - ScopedPyObjectPtr value_name(PyObject_GetAttr(enum_value, kname)); - if (value_name == NULL) { - return NULL; - } - ScopedPyObjectPtr value_number(PyObject_GetAttrString(enum_value, - "number")); + // For each enum value add cls. = + for (int j = 0; j < enum_descriptor->value_count(); ++j) { + const EnumValueDescriptor* enum_value_descriptor = + enum_descriptor->value(j); + ScopedPyObjectPtr value_number(PyInt_FromLong( + enum_value_descriptor->number())); if (value_number == NULL) { return NULL; } - if (PyObject_SetAttr(reinterpret_cast(cls), - value_name, value_number) == -1) { + if (PyObject_SetAttrString( + cls, enum_value_descriptor->name().c_str(), value_number) == -1) { return NULL; } } - if (PyErr_Occurred()) { // If PyIter_Next failed - return NULL; - } - } - if (PyErr_Occurred()) { // If PyIter_Next failed - return NULL; } - ScopedPyObjectPtr extension_dict( - PyObject_GetAttr(descriptor, kextensions_by_name)); - if (extension_dict == NULL || !PyDict_Check(extension_dict)) { - PyErr_SetString(PyExc_TypeError, "extensions_by_name not a dict"); - return NULL; - } - Py_ssize_t pos = 0; - PyObject* extension_name; - PyObject* extension_field; - - while (PyDict_Next(extension_dict, &pos, &extension_name, &extension_field)) { - if (PyObject_SetAttr(reinterpret_cast(cls), - extension_name, extension_field) == -1) { - return NULL; - } - const google::protobuf::FieldDescriptor* field_descriptor = - GetExtensionDescriptor(extension_field); - if (field_descriptor == NULL) { + // For each extension set cls. = . + // + // Extension descriptors come from + // .extensions_by_name[name] + // which was defined previously. + for (int i = 0; i < message_descriptor->extension_count(); ++i) { + const google::protobuf::FieldDescriptor* field = message_descriptor->extension(i); + ScopedPyObjectPtr extension_field(PyFieldDescriptor_New(field)); + if (extension_field == NULL) { return NULL; } - ScopedPyObjectPtr field_name_upcased( - PyObject_CallMethod(extension_name, "upper", NULL)); - if (field_name_upcased == NULL) { - return NULL; - } - ScopedPyObjectPtr field_number_name(PyObject_CallMethod( - field_name_upcased, "__add__", "(O)", _NUMBER_string.get())); - if (field_number_name == NULL) { - return NULL; - } - ScopedPyObjectPtr number(PyInt_FromLong( - field_descriptor->number())); - if (number == NULL) { + // Add the extension field to the message class. + if (PyObject_SetAttrString( + cls, field->name().c_str(), extension_field) == -1) { return NULL; } - if (PyObject_SetAttr(reinterpret_cast(cls), - field_number_name, number) == -1) { + + // For each extension set cls._FIELD_NUMBER = . + if (!AddFieldNumberToClass(cls, field)) { return NULL; } } @@ -2121,12 +2077,35 @@ PyObject* SetState(CMessage* self, PyObject* state) { } // CMessage static methods: +PyObject* _GetMessageDescriptor(PyObject* unused, PyObject* arg) { + return cdescriptor_pool::FindMessageByName(GetDescriptorPool(), arg); +} + PyObject* _GetFieldDescriptor(PyObject* unused, PyObject* arg) { - return cdescriptor_pool::FindFieldByName(descriptor_pool, arg); + return cdescriptor_pool::FindFieldByName(GetDescriptorPool(), arg); } PyObject* _GetExtensionDescriptor(PyObject* unused, PyObject* arg) { - return cdescriptor_pool::FindExtensionByName(descriptor_pool, arg); + return cdescriptor_pool::FindExtensionByName(GetDescriptorPool(), arg); +} + +PyObject* _GetEnumDescriptor(PyObject* unused, PyObject* arg) { + return cdescriptor_pool::FindEnumTypeByName(GetDescriptorPool(), arg); +} + +PyObject* _GetOneofDescriptor(PyObject* unused, PyObject* arg) { + return cdescriptor_pool::FindOneofByName(GetDescriptorPool(), arg); +} + +PyObject* _CheckCalledFromGeneratedFile(PyObject* unused, + PyObject* unused_arg) { + if (!_CalledFromGeneratedFile(1)) { + PyErr_SetString(PyExc_TypeError, + "Descriptors should not be created directly, " + "but only retrieved from their parent."); + return NULL; + } + Py_RETURN_NONE; } static PyMemberDef Members[] = { @@ -2191,91 +2170,102 @@ static PyMethodDef Methods[] = { // Static Methods. { "_BuildFile", (PyCFunction)Python_BuildFile, METH_O | METH_STATIC, "Registers a new protocol buffer file in the global C++ descriptor pool." }, + { "_GetMessageDescriptor", (PyCFunction)_GetMessageDescriptor, + METH_O | METH_STATIC, "Finds a message descriptor in the message pool." }, { "_GetFieldDescriptor", (PyCFunction)_GetFieldDescriptor, METH_O | METH_STATIC, "Finds a field descriptor in the message pool." }, { "_GetExtensionDescriptor", (PyCFunction)_GetExtensionDescriptor, METH_O | METH_STATIC, "Finds a extension descriptor in the message pool." }, + { "_GetEnumDescriptor", (PyCFunction)_GetEnumDescriptor, + METH_O | METH_STATIC, + "Finds an enum descriptor in the message pool." }, + { "_GetOneofDescriptor", (PyCFunction)_GetOneofDescriptor, + METH_O | METH_STATIC, + "Finds an oneof descriptor in the message pool." }, + { "_CheckCalledFromGeneratedFile", (PyCFunction)_CheckCalledFromGeneratedFile, + METH_NOARGS | METH_STATIC, + "Raises TypeError if the caller is not in a _pb2.py file."}, { NULL, NULL} }; +static bool SetCompositeField( + CMessage* self, PyObject* name, PyObject* value) { + if (self->composite_fields == NULL) { + self->composite_fields = PyDict_New(); + if (self->composite_fields == NULL) { + return false; + } + } + return PyDict_SetItem(self->composite_fields, name, value) == 0; +} + PyObject* GetAttr(CMessage* self, PyObject* name) { - PyObject* value = PyDict_GetItem(self->composite_fields, name); + PyObject* value = self->composite_fields ? + PyDict_GetItem(self->composite_fields, name) : NULL; if (value != NULL) { Py_INCREF(value); return value; } - const google::protobuf::FieldDescriptor* field_descriptor = GetFieldDescriptor( - self, name); - if (field_descriptor != NULL) { - if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { - if (field_descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject *message_class = cdescriptor_pool::GetMessageClass( - descriptor_pool, field_descriptor->message_type()); - if (message_class == NULL) { - return NULL; - } - PyObject* py_container = repeated_composite_container::NewContainer( - self, field_descriptor, message_class); - if (py_container == NULL) { - return NULL; - } - if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) { - Py_DECREF(py_container); - return NULL; - } - return py_container; - } else { - PyObject* py_container = repeated_scalar_container::NewContainer( - self, field_descriptor); - if (py_container == NULL) { - return NULL; - } - if (PyDict_SetItem(self->composite_fields, name, py_container) < 0) { - Py_DECREF(py_container); - return NULL; - } - return py_container; + const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name); + if (field_descriptor == NULL) { + return CMessage_Type.tp_base->tp_getattro( + reinterpret_cast(self), name); + } + + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { + PyObject* py_container = NULL; + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject *message_class = cdescriptor_pool::GetMessageClass( + GetDescriptorPool(), field_descriptor->message_type()); + if (message_class == NULL) { + return NULL; } + py_container = repeated_composite_container::NewContainer( + self, field_descriptor, message_class); } else { - if (field_descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { - PyObject* sub_message = InternalGetSubMessage(self, field_descriptor); - if (PyDict_SetItem(self->composite_fields, name, sub_message) < 0) { - Py_DECREF(sub_message); - return NULL; - } - return sub_message; - } else { - return InternalGetScalar(self, field_descriptor); - } + py_container = repeated_scalar_container::NewContainer( + self, field_descriptor); + } + if (py_container == NULL) { + return NULL; + } + if (!SetCompositeField(self, name, py_container)) { + Py_DECREF(py_container); + return NULL; + } + return py_container; + } + + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + PyObject* sub_message = InternalGetSubMessage(self, field_descriptor); + if (!SetCompositeField(self, name, sub_message)) { + Py_DECREF(sub_message); + return NULL; } + return sub_message; } - return CMessage_Type.tp_base->tp_getattro(reinterpret_cast(self), - name); + return InternalGetScalar(self, field_descriptor); } int SetAttr(CMessage* self, PyObject* name, PyObject* value) { - if (PyDict_Contains(self->composite_fields, name)) { + if (self->composite_fields && PyDict_Contains(self->composite_fields, name)) { PyErr_SetString(PyExc_TypeError, "Can't set composite field"); return -1; } - const google::protobuf::FieldDescriptor* field_descriptor = - GetFieldDescriptor(self, name); + const FieldDescriptor* field_descriptor = GetFieldDescriptor(self, name); if (field_descriptor != NULL) { AssureWritable(self); - if (field_descriptor->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) { + if (field_descriptor->label() == FieldDescriptor::LABEL_REPEATED) { PyErr_Format(PyExc_AttributeError, "Assignment not allowed to repeated " "field \"%s\" in protocol message object.", field_descriptor->name().c_str()); return -1; } else { - if (field_descriptor->cpp_type() == - google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE) { + if (field_descriptor->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { PyErr_Format(PyExc_AttributeError, "Assignment not allowed to " "field \"%s\" in protocol message object.", field_descriptor->name().c_str()); @@ -2339,7 +2329,7 @@ PyTypeObject CMessage_Type = { const Message* (*GetCProtoInsidePyProtoPtr)(PyObject* msg); Message* (*MutableCProtoInsidePyProtoPtr)(PyObject* msg); -static const google::protobuf::Message* GetCProtoInsidePyProtoImpl(PyObject* msg) { +static const Message* GetCProtoInsidePyProtoImpl(PyObject* msg) { if (!PyObject_TypeCheck(msg, &CMessage_Type)) { return NULL; } @@ -2347,12 +2337,12 @@ static const google::protobuf::Message* GetCProtoInsidePyProtoImpl(PyObject* msg return cmsg->message; } -static google::protobuf::Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { +static Message* MutableCProtoInsidePyProtoImpl(PyObject* msg) { if (!PyObject_TypeCheck(msg, &CMessage_Type)) { return NULL; } CMessage* cmsg = reinterpret_cast(msg); - if (PyDict_Size(cmsg->composite_fields) != 0 || + if ((cmsg->composite_fields && PyDict_Size(cmsg->composite_fields) != 0) || (cmsg->extensions != NULL && PyDict_Size(cmsg->extensions->values) != 0)) { // There is currently no way of accurately syncing arbitrary changes to @@ -2387,29 +2377,35 @@ void InitGlobals() { kDESCRIPTOR = PyString_FromString("DESCRIPTOR"); k_cdescriptor = PyString_FromString("_cdescriptor"); kfull_name = PyString_FromString("full_name"); - kextensions_by_name = PyString_FromString("extensions_by_name"); k_extensions_by_name = PyString_FromString("_extensions_by_name"); k_extensions_by_number = PyString_FromString("_extensions_by_number"); - kname = PyString_FromString("name"); - kfields_by_name = PyString_FromString("fields_by_name"); - descriptor_pool = cdescriptor_pool::NewDescriptorPool(); - - global_message_factory = new DynamicMessageFactory(descriptor_pool->pool); - global_message_factory->SetDelegateToGeneratedFactory(true); + message_factory = new DynamicMessageFactory(GetDescriptorPool()->pool); + message_factory->SetDelegateToGeneratedFactory(true); } bool InitProto2MessageModule(PyObject *m) { + // Initialize types and globals in descriptor.cc + if (!InitDescriptor()) { + return false; + } + + // Initialize types and globals in descriptor_pool.cc + if (!InitDescriptorPool()) { + return false; + } + + // Initialize constants defined in this file. InitGlobals(); - google::protobuf::python::CMessage_Type.tp_hash = PyObject_HashNotImplemented; - if (PyType_Ready(&google::protobuf::python::CMessage_Type) < 0) { + CMessage_Type.tp_hash = PyObject_HashNotImplemented; + if (PyType_Ready(&CMessage_Type) < 0) { return false; } // DESCRIPTOR is set on each protocol buffer message class elsewhere, but set // it here as well to document that subclasses need to set it. - PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); + PyDict_SetItem(CMessage_Type.tp_dict, kDESCRIPTOR, Py_None); // Subclasses with message extensions will override _extensions_by_name and // _extensions_by_number with fresh mutable dictionaries in AddDescriptors. // All other classes can share this same immutable mapping. @@ -2421,58 +2417,69 @@ bool InitProto2MessageModule(PyObject *m) { if (immutable_dict == NULL) { return false; } - if (PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, + if (PyDict_SetItem(CMessage_Type.tp_dict, k_extensions_by_name, immutable_dict) < 0) { return false; } - if (PyDict_SetItem(google::protobuf::python::CMessage_Type.tp_dict, + if (PyDict_SetItem(CMessage_Type.tp_dict, k_extensions_by_number, immutable_dict) < 0) { return false; } - PyModule_AddObject(m, "Message", reinterpret_cast( - &google::protobuf::python::CMessage_Type)); + PyModule_AddObject(m, "Message", reinterpret_cast(&CMessage_Type)); - google::protobuf::python::RepeatedScalarContainer_Type.tp_hash = + RepeatedScalarContainer_Type.tp_hash = PyObject_HashNotImplemented; - if (PyType_Ready(&google::protobuf::python::RepeatedScalarContainer_Type) < 0) { + if (PyType_Ready(&RepeatedScalarContainer_Type) < 0) { return false; } PyModule_AddObject(m, "RepeatedScalarContainer", reinterpret_cast( - &google::protobuf::python::RepeatedScalarContainer_Type)); + &RepeatedScalarContainer_Type)); - google::protobuf::python::RepeatedCompositeContainer_Type.tp_hash = - PyObject_HashNotImplemented; - if (PyType_Ready(&google::protobuf::python::RepeatedCompositeContainer_Type) < 0) { + RepeatedCompositeContainer_Type.tp_hash = PyObject_HashNotImplemented; + if (PyType_Ready(&RepeatedCompositeContainer_Type) < 0) { return false; } PyModule_AddObject( m, "RepeatedCompositeContainer", reinterpret_cast( - &google::protobuf::python::RepeatedCompositeContainer_Type)); + &RepeatedCompositeContainer_Type)); - google::protobuf::python::ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented; - if (PyType_Ready(&google::protobuf::python::ExtensionDict_Type) < 0) { + ExtensionDict_Type.tp_hash = PyObject_HashNotImplemented; + if (PyType_Ready(&ExtensionDict_Type) < 0) { return false; } PyModule_AddObject( m, "ExtensionDict", - reinterpret_cast(&google::protobuf::python::ExtensionDict_Type)); - - if (!google::protobuf::python::InitDescriptor()) { - return false; - } + reinterpret_cast(&ExtensionDict_Type)); + + // This implementation provides full Descriptor types, we advertise it so that + // descriptor.py can use them in replacement of the Python classes. + PyModule_AddIntConstant(m, "_USE_C_DESCRIPTORS", 1); + + PyModule_AddObject(m, "Descriptor", reinterpret_cast( + &PyMessageDescriptor_Type)); + PyModule_AddObject(m, "FieldDescriptor", reinterpret_cast( + &PyFieldDescriptor_Type)); + PyModule_AddObject(m, "EnumDescriptor", reinterpret_cast( + &PyEnumDescriptor_Type)); + PyModule_AddObject(m, "EnumValueDescriptor", reinterpret_cast( + &PyEnumValueDescriptor_Type)); + PyModule_AddObject(m, "FileDescriptor", reinterpret_cast( + &PyFileDescriptor_Type)); + PyModule_AddObject(m, "OneofDescriptor", reinterpret_cast( + &PyOneofDescriptor_Type)); PyObject* enum_type_wrapper = PyImport_ImportModule( "google.protobuf.internal.enum_type_wrapper"); if (enum_type_wrapper == NULL) { return false; } - google::protobuf::python::EnumTypeWrapper_class = + EnumTypeWrapper_class = PyObject_GetAttrString(enum_type_wrapper, "EnumTypeWrapper"); Py_DECREF(enum_type_wrapper); @@ -2481,25 +2488,20 @@ bool InitProto2MessageModule(PyObject *m) { if (message_module == NULL) { return false; } - google::protobuf::python::EncodeError_class = PyObject_GetAttrString(message_module, - "EncodeError"); - google::protobuf::python::DecodeError_class = PyObject_GetAttrString(message_module, - "DecodeError"); + EncodeError_class = PyObject_GetAttrString(message_module, "EncodeError"); + DecodeError_class = PyObject_GetAttrString(message_module, "DecodeError"); Py_DECREF(message_module); PyObject* pickle_module = PyImport_ImportModule("pickle"); if (pickle_module == NULL) { return false; } - google::protobuf::python::PickleError_class = PyObject_GetAttrString(pickle_module, - "PickleError"); + PickleError_class = PyObject_GetAttrString(pickle_module, "PickleError"); Py_DECREF(pickle_module); // Override {Get,Mutable}CProtoInsidePyProto. - google::protobuf::python::GetCProtoInsidePyProtoPtr = - google::protobuf::python::GetCProtoInsidePyProtoImpl; - google::protobuf::python::MutableCProtoInsidePyProtoPtr = - google::protobuf::python::MutableCProtoInsidePyProtoImpl; + GetCProtoInsidePyProtoPtr = GetCProtoInsidePyProtoImpl; + MutableCProtoInsidePyProtoPtr = MutableCProtoInsidePyProtoImpl; return true; } diff --git a/python/google/protobuf/pyext/message.h b/python/google/protobuf/pyext/message.h index 0fef92a0..2f2da795 100644 --- a/python/google/protobuf/pyext/message.h +++ b/python/google/protobuf/pyext/message.h @@ -42,7 +42,6 @@ #endif #include - namespace google { namespace protobuf { @@ -50,12 +49,12 @@ class Message; class Reflection; class FieldDescriptor; class Descriptor; +class DynamicMessageFactory; using internal::shared_ptr; namespace python { -struct PyDescriptorPool; struct ExtensionDict; typedef struct CMessage { @@ -84,7 +83,7 @@ typedef struct CMessage { // Used together with the parent's message when making a default message // instance mutable. // The pointer is owned by the global DescriptorPool. - const google::protobuf::FieldDescriptor* parent_field_descriptor; + const FieldDescriptor* parent_field_descriptor; // Pointer to the C++ Message object for this CMessage. The // CMessage does not own this pointer. @@ -115,27 +114,26 @@ namespace cmessage { // Internal function to create a new empty Message Python object, but with empty // pointers to the C++ objects. // The caller must fill self->message, self->owner and eventually self->parent. -CMessage* NewEmptyMessage(PyObject* type, - const google::protobuf::Descriptor* descriptor); +CMessage* NewEmptyMessage(PyObject* type, const Descriptor* descriptor); // Release a submessage from its proto tree, making it a new top-level messgae. // A new message will be created if this is a read-only default instance. // // Corresponds to reflection api method ReleaseMessage. -int ReleaseSubMessage(google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* field_descriptor, +int ReleaseSubMessage(Message* message, + const FieldDescriptor* field_descriptor, CMessage* child_cmessage); // Retrieves the C++ descriptor of a Python Extension descriptor. // On error, return NULL with an exception set. -const google::protobuf::FieldDescriptor* GetExtensionDescriptor(PyObject* extension); +const FieldDescriptor* GetExtensionDescriptor(PyObject* extension); // Initializes a new CMessage instance for a submessage. Only called once per // submessage as the result is cached in composite_fields. // // Corresponds to reflection api method GetMessage. PyObject* InternalGetSubMessage( - CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor); + CMessage* self, const FieldDescriptor* field_descriptor); // Deletes a range of C++ submessages in a repeated field (following a // removal in a RepeatedCompositeContainer). @@ -146,20 +144,20 @@ PyObject* InternalGetSubMessage( // by slice will be removed from cmessage_list by this function. // // Corresponds to reflection api method RemoveLast. -int InternalDeleteRepeatedField(google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* field_descriptor, +int InternalDeleteRepeatedField(Message* message, + const FieldDescriptor* field_descriptor, PyObject* slice, PyObject* cmessage_list); // Sets the specified scalar value to the message. int InternalSetScalar(CMessage* self, - const google::protobuf::FieldDescriptor* field_descriptor, + const FieldDescriptor* field_descriptor, PyObject* value); // Retrieves the specified scalar value from the message. // // Returns a new python reference. PyObject* InternalGetScalar(CMessage* self, - const google::protobuf::FieldDescriptor* field_descriptor); + const FieldDescriptor* field_descriptor); // Clears the message, removing all contained data. Extension dictionary and // submessages are released first if there are remaining external references. @@ -175,8 +173,7 @@ PyObject* Clear(CMessage* self); // // Corresponds to reflection api method ClearField. PyObject* ClearFieldByDescriptor( - CMessage* self, - const google::protobuf::FieldDescriptor* descriptor); + CMessage* self, const FieldDescriptor* descriptor); // Clears the data for the given field name. The message is released if there // are any external references. @@ -189,7 +186,7 @@ PyObject* ClearField(CMessage* self, PyObject* arg); // // Corresponds to reflection api method HasField PyObject* HasFieldByDescriptor( - CMessage* self, const google::protobuf::FieldDescriptor* field_descriptor); + CMessage* self, const FieldDescriptor* field_descriptor); // Checks if the message has the named field. // @@ -220,18 +217,16 @@ int SetOwner(CMessage* self, const shared_ptr& new_owner); int AssureWritable(CMessage* self); -} // namespace cmessage - +DynamicMessageFactory* GetMessageFactory(); -// Retrieve the global descriptor pool owned by the _message module. -PyDescriptorPool* GetDescriptorPool(); +} // namespace cmessage /* Is 64bit */ #define IS_64BIT (SIZEOF_LONG == 8) #define FIELD_IS_REPEATED(field_descriptor) \ - ((field_descriptor)->label() == google::protobuf::FieldDescriptor::LABEL_REPEATED) + ((field_descriptor)->label() == FieldDescriptor::LABEL_REPEATED) #define GOOGLE_CHECK_GET_INT32(arg, value, err) \ int32 value; \ @@ -294,18 +289,17 @@ bool CheckAndGetDouble(PyObject* arg, double* value); bool CheckAndGetFloat(PyObject* arg, float* value); bool CheckAndGetBool(PyObject* arg, bool* value); bool CheckAndSetString( - PyObject* arg, google::protobuf::Message* message, - const google::protobuf::FieldDescriptor* descriptor, - const google::protobuf::Reflection* reflection, + PyObject* arg, Message* message, + const FieldDescriptor* descriptor, + const Reflection* reflection, bool append, int index); -PyObject* ToStringObject( - const google::protobuf::FieldDescriptor* descriptor, string value); +PyObject* ToStringObject(const FieldDescriptor* descriptor, string value); // Check if the passed field descriptor belongs to the given message. // If not, return false and set a Python exception (a KeyError) -bool CheckFieldBelongsToMessage(const google::protobuf::FieldDescriptor* field_descriptor, - const google::protobuf::Message* message); +bool CheckFieldBelongsToMessage(const FieldDescriptor* field_descriptor, + const Message* message); extern PyObject* PickleError_class; diff --git a/python/google/protobuf/pyext/message_factory_cpp2_test.py b/python/google/protobuf/pyext/message_factory_cpp2_test.py deleted file mode 100644 index 32ab4f85..00000000 --- a/python/google/protobuf/pyext/message_factory_cpp2_test.py +++ /dev/null @@ -1,56 +0,0 @@ -#! /usr/bin/python -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Tests for google.protobuf.message_factory.""" - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' - -# We must set the implementation version above before the google3 imports. -# pylint: disable=g-import-not-at-top -from google.apputils import basetest -from google.protobuf.internal import api_implementation -# Run all tests from the original module by putting them in our namespace. -# pylint: disable=wildcard-import -from google.protobuf.internal.message_factory_test import * - - -class ConfirmCppApi2Test(basetest.TestCase): - - def testImplementationSetting(self): - self.assertEqual('cpp', api_implementation.Type()) - self.assertEqual(2, api_implementation.Version()) - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/pyext/reflection_cpp2_generated_test.py b/python/google/protobuf/pyext/reflection_cpp2_generated_test.py deleted file mode 100755 index 552efd48..00000000 --- a/python/google/protobuf/pyext/reflection_cpp2_generated_test.py +++ /dev/null @@ -1,94 +0,0 @@ -#! /usr/bin/python -# -*- coding: utf-8 -*- -# -# Protocol Buffers - Google's data interchange format -# Copyright 2008 Google Inc. All rights reserved. -# https://developers.google.com/protocol-buffers/ -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are -# met: -# -# * Redistributions of source code must retain the above copyright -# notice, this list of conditions and the following disclaimer. -# * Redistributions in binary form must reproduce the above -# copyright notice, this list of conditions and the following disclaimer -# in the documentation and/or other materials provided with the -# distribution. -# * Neither the name of Google Inc. nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS -# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT -# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR -# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT -# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, -# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT -# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, -# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY -# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -"""Unittest for reflection.py, which tests the generated C++ implementation.""" - -__author__ = 'jasonh@google.com (Jason Hsueh)' - -import os -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' -os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION'] = '2' - -from google.apputils import basetest -from google.protobuf.internal import api_implementation -from google.protobuf.internal import more_extensions_dynamic_pb2 -from google.protobuf.internal import more_extensions_pb2 -from google.protobuf.internal.reflection_test import * - - -class ReflectionCppTest(basetest.TestCase): - def testImplementationSetting(self): - self.assertEqual('cpp', api_implementation.Type()) - self.assertEqual(2, api_implementation.Version()) - - def testExtensionOfGeneratedTypeInDynamicFile(self): - """Tests that a file built dynamically can extend a generated C++ type. - - The C++ implementation uses a DescriptorPool that has the generated - DescriptorPool as an underlay. Typically, a type can only find - extensions in its own pool. With the python C-extension, the generated C++ - extendee may be available, but not the extension. This tests that the - C-extension implements the correct special handling to make such extensions - available. - """ - pb1 = more_extensions_pb2.ExtendedMessage() - # Test that basic accessors work. - self.assertFalse( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) - self.assertFalse( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) - pb1.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension] = 17 - pb1.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a = 24 - self.assertTrue( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) - self.assertTrue( - pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) - - # Now serialize the data and parse to a new message. - pb2 = more_extensions_pb2.ExtendedMessage() - pb2.MergeFromString(pb1.SerializeToString()) - - self.assertTrue( - pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) - self.assertTrue( - pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) - self.assertEqual( - 17, pb2.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension]) - self.assertEqual( - 24, - pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a) - - - -if __name__ == '__main__': - basetest.main() diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index 36fe86ae..49ef7991 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -56,8 +56,6 @@ namespace google { namespace protobuf { namespace python { -extern google::protobuf::DynamicMessageFactory* global_message_factory; - namespace repeated_composite_container { // TODO(tibell): We might also want to check: @@ -120,9 +118,9 @@ static int InternalQuickSort(RepeatedCompositeContainer* self, GOOGLE_CHECK_ATTACHED(self); - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); - const google::protobuf::FieldDescriptor* descriptor = self->parent_field_descriptor; + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + const FieldDescriptor* descriptor = self->parent_field_descriptor; Py_ssize_t left; Py_ssize_t right; @@ -199,7 +197,7 @@ static int InternalQuickSort(RepeatedCompositeContainer* self, // len() static Py_ssize_t Length(RepeatedCompositeContainer* self) { - google::protobuf::Message* message = self->message; + Message* message = self->message; if (message != NULL) { return message->GetReflection()->FieldSize(*message, self->parent_field_descriptor); @@ -221,8 +219,8 @@ static int UpdateChildMessages(RepeatedCompositeContainer* self) { // be removed in such a way so there's no need to worry about that. Py_ssize_t message_length = Length(self); Py_ssize_t child_length = PyList_GET_SIZE(self->child_messages); - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); for (Py_ssize_t i = child_length; i < message_length; ++i) { const Message& sub_message = reflection->GetRepeatedMessage( *(self->message), self->parent_field_descriptor, i); @@ -233,7 +231,7 @@ static int UpdateChildMessages(RepeatedCompositeContainer* self) { return -1; } cmsg->owner = self->owner; - cmsg->message = const_cast(&sub_message); + cmsg->message = const_cast(&sub_message); cmsg->parent = self->parent; if (PyList_Append(self->child_messages, py_cmsg) < 0) { return -1; @@ -255,8 +253,8 @@ static PyObject* AddToAttached(RepeatedCompositeContainer* self, } if (cmessage::AssureWritable(self->parent) == -1) return NULL; - google::protobuf::Message* message = self->message; - google::protobuf::Message* sub_message = + Message* message = self->message; + Message* sub_message = message->GetReflection()->AddMessage(message, self->parent_field_descriptor); CMessage* cmsg = cmessage::NewEmptyMessage(self->subclass_init, @@ -482,9 +480,9 @@ static PyObject* SortAttached(RepeatedCompositeContainer* self, // Finally reverse the result if requested. if (reverse) { - google::protobuf::Message* message = self->message; - const google::protobuf::Reflection* reflection = message->GetReflection(); - const google::protobuf::FieldDescriptor* descriptor = self->parent_field_descriptor; + Message* message = self->message; + const Reflection* reflection = message->GetReflection(); + const FieldDescriptor* descriptor = self->parent_field_descriptor; // Reverse the Message array. for (int i = 0; i < length / 2; ++i) @@ -554,6 +552,26 @@ static PyObject* Item(RepeatedCompositeContainer* self, Py_ssize_t index) { return item; } +static PyObject* Pop(RepeatedCompositeContainer* self, + PyObject* args) { + Py_ssize_t index = -1; + if (!PyArg_ParseTuple(args, "|n", &index)) { + return NULL; + } + PyObject* item = Item(self, index); + if (item == NULL) { + PyErr_Format(PyExc_IndexError, + "list index (%zd) out of range", + index); + return NULL; + } + ScopedPyObjectPtr py_index(PyLong_FromSsize_t(index)); + if (AssignSubscript(self, py_index, NULL) < 0) { + return NULL; + } + return item; +} + // The caller takes ownership of the returned Message. Message* ReleaseLast(const FieldDescriptor* field, const Descriptor* type, @@ -571,7 +589,8 @@ Message* ReleaseLast(const FieldDescriptor* field, // the latter points to the default instance via a const_cast<>, so we // have to reset it to a new mutable object since we are taking ownership. if (released_message == NULL) { - const Message* prototype = global_message_factory->GetPrototype(type); + const Message* prototype = + cmessage::GetMessageFactory()->GetPrototype(type); GOOGLE_CHECK_NOTNULL(prototype); return prototype->New(); } else { @@ -646,7 +665,7 @@ int SetOwner(RepeatedCompositeContainer* self, // The private constructor of RepeatedCompositeContainer objects. PyObject *NewContainer( CMessage* parent, - const google::protobuf::FieldDescriptor* parent_field_descriptor, + const FieldDescriptor* parent_field_descriptor, PyObject *concrete_class) { if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { return NULL; @@ -698,6 +717,8 @@ static PyMethodDef Methods[] = { "Adds an object to the repeated container." }, { "extend", (PyCFunction) Extend, METH_O, "Adds objects to the repeated container." }, + { "pop", (PyCFunction)Pop, METH_VARARGS, + "Removes an object from the repeated container and returns it." }, { "remove", (PyCFunction) Remove, METH_O, "Removes an object from the repeated container." }, { "sort", (PyCFunction) Sort, METH_VARARGS | METH_KEYWORDS, diff --git a/python/google/protobuf/pyext/repeated_composite_container.h b/python/google/protobuf/pyext/repeated_composite_container.h index 0969af08..ce7cee0f 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.h +++ b/python/google/protobuf/pyext/repeated_composite_container.h @@ -43,7 +43,6 @@ #include #include - namespace google { namespace protobuf { @@ -82,7 +81,7 @@ typedef struct RepeatedCompositeContainer { // A descriptor used to modify the underlying 'message'. // The pointer is owned by the global DescriptorPool. - const google::protobuf::FieldDescriptor* parent_field_descriptor; + const FieldDescriptor* parent_field_descriptor; // Pointer to the C++ Message that contains this container. The // RepeatedCompositeContainer does not own this pointer. @@ -106,7 +105,7 @@ namespace repeated_composite_container { // field descriptor. PyObject *NewContainer( CMessage* parent, - const google::protobuf::FieldDescriptor* parent_field_descriptor, + const FieldDescriptor* parent_field_descriptor, PyObject *concrete_class); // Returns the number of items in this repeated composite container. @@ -150,8 +149,7 @@ int AssignSubscript(RepeatedCompositeContainer* self, // Releases the messages in the container to the given message. // // Returns 0 on success, -1 on failure. -int ReleaseToMessage(RepeatedCompositeContainer* self, - google::protobuf::Message* new_message); +int ReleaseToMessage(RepeatedCompositeContainer* self, Message* new_message); // Releases the messages in the container to a new message. // diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc index 49d23fd6..9f8075b2 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.cc +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -60,8 +60,6 @@ namespace google { namespace protobuf { namespace python { -extern google::protobuf::DynamicMessageFactory* global_message_factory; - namespace repeated_scalar_container { static int InternalAssignRepeatedField( @@ -78,7 +76,7 @@ static int InternalAssignRepeatedField( } static Py_ssize_t Len(RepeatedScalarContainer* self) { - google::protobuf::Message* message = self->message; + Message* message = self->message; return message->GetReflection()->FieldSize(*message, self->parent_field_descriptor); } @@ -87,11 +85,10 @@ static int AssignItem(RepeatedScalarContainer* self, Py_ssize_t index, PyObject* arg) { cmessage::AssureWritable(self->parent); - google::protobuf::Message* message = self->message; - const google::protobuf::FieldDescriptor* field_descriptor = - self->parent_field_descriptor; + Message* message = self->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; - const google::protobuf::Reflection* reflection = message->GetReflection(); + const Reflection* reflection = message->GetReflection(); int field_size = reflection->FieldSize(*message, field_descriptor); if (index < 0) { index = field_size + index; @@ -115,64 +112,68 @@ static int AssignItem(RepeatedScalarContainer* self, } switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { GOOGLE_CHECK_GET_INT32(arg, value, -1); reflection->SetRepeatedInt32(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { GOOGLE_CHECK_GET_INT64(arg, value, -1); reflection->SetRepeatedInt64(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { GOOGLE_CHECK_GET_UINT32(arg, value, -1); reflection->SetRepeatedUInt32(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { GOOGLE_CHECK_GET_UINT64(arg, value, -1); reflection->SetRepeatedUInt64(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { GOOGLE_CHECK_GET_FLOAT(arg, value, -1); reflection->SetRepeatedFloat(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { GOOGLE_CHECK_GET_DOUBLE(arg, value, -1); reflection->SetRepeatedDouble(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { GOOGLE_CHECK_GET_BOOL(arg, value, -1); reflection->SetRepeatedBool(message, field_descriptor, index, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { if (!CheckAndSetString( arg, message, field_descriptor, reflection, false, index)) { return -1; } break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + case FieldDescriptor::CPPTYPE_ENUM: { GOOGLE_CHECK_GET_INT32(arg, value, -1); - const google::protobuf::EnumDescriptor* enum_descriptor = - field_descriptor->enum_type(); - const google::protobuf::EnumValueDescriptor* enum_value = - enum_descriptor->FindValueByNumber(value); - if (enum_value != NULL) { - reflection->SetRepeatedEnum(message, field_descriptor, index, - enum_value); + if (reflection->SupportsUnknownEnumValues()) { + reflection->SetRepeatedEnumValue(message, field_descriptor, index, + value); } else { - ScopedPyObjectPtr s(PyObject_Str(arg)); - if (s != NULL) { - PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", - PyString_AsString(s)); + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->SetRepeatedEnum(message, field_descriptor, index, + enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(arg)); + if (s != NULL) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s)); + } + return -1; } - return -1; } break; } @@ -186,10 +187,9 @@ static int AssignItem(RepeatedScalarContainer* self, } static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) { - google::protobuf::Message* message = self->message; - const google::protobuf::FieldDescriptor* field_descriptor = - self->parent_field_descriptor; - const google::protobuf::Reflection* reflection = message->GetReflection(); + Message* message = self->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; + const Reflection* reflection = message->GetReflection(); int field_size = reflection->FieldSize(*message, field_descriptor); if (index < 0) { @@ -197,80 +197,80 @@ static PyObject* Item(RepeatedScalarContainer* self, Py_ssize_t index) { } if (index < 0 || index >= field_size) { PyErr_Format(PyExc_IndexError, - "list assignment index (%d) out of range", - static_cast(index)); + "list index (%zd) out of range", + index); return NULL; } PyObject* result = NULL; switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { int32 value = reflection->GetRepeatedInt32( *message, field_descriptor, index); result = PyInt_FromLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { int64 value = reflection->GetRepeatedInt64( *message, field_descriptor, index); result = PyLong_FromLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { uint32 value = reflection->GetRepeatedUInt32( *message, field_descriptor, index); result = PyLong_FromLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { uint64 value = reflection->GetRepeatedUInt64( *message, field_descriptor, index); result = PyLong_FromUnsignedLongLong(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { float value = reflection->GetRepeatedFloat( *message, field_descriptor, index); result = PyFloat_FromDouble(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { double value = reflection->GetRepeatedDouble( *message, field_descriptor, index); result = PyFloat_FromDouble(value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { bool value = reflection->GetRepeatedBool( *message, field_descriptor, index); result = PyBool_FromLong(value ? 1 : 0); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { - const google::protobuf::EnumValueDescriptor* enum_value = + case FieldDescriptor::CPPTYPE_ENUM: { + const EnumValueDescriptor* enum_value = message->GetReflection()->GetRepeatedEnum( *message, field_descriptor, index); result = PyInt_FromLong(enum_value->number()); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { string value = reflection->GetRepeatedString( *message, field_descriptor, index); result = ToStringObject(field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_MESSAGE: { + case FieldDescriptor::CPPTYPE_MESSAGE: { PyObject* py_cmsg = PyObject_CallObject(reinterpret_cast( &CMessage_Type), NULL); if (py_cmsg == NULL) { return NULL; } CMessage* cmsg = reinterpret_cast(py_cmsg); - const google::protobuf::Message& msg = reflection->GetRepeatedMessage( + const Message& msg = reflection->GetRepeatedMessage( *message, field_descriptor, index); cmsg->owner = self->owner; cmsg->parent = self->parent; - cmsg->message = const_cast(&msg); + cmsg->message = const_cast(&msg); cmsg->read_only = false; result = reinterpret_cast(py_cmsg); break; @@ -351,69 +351,71 @@ static PyObject* Subscript(RepeatedScalarContainer* self, PyObject* slice) { PyObject* Append(RepeatedScalarContainer* self, PyObject* item) { cmessage::AssureWritable(self->parent); - google::protobuf::Message* message = self->message; - const google::protobuf::FieldDescriptor* field_descriptor = - self->parent_field_descriptor; + Message* message = self->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; - const google::protobuf::Reflection* reflection = message->GetReflection(); + const Reflection* reflection = message->GetReflection(); switch (field_descriptor->cpp_type()) { - case google::protobuf::FieldDescriptor::CPPTYPE_INT32: { + case FieldDescriptor::CPPTYPE_INT32: { GOOGLE_CHECK_GET_INT32(item, value, NULL); reflection->AddInt32(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_INT64: { + case FieldDescriptor::CPPTYPE_INT64: { GOOGLE_CHECK_GET_INT64(item, value, NULL); reflection->AddInt64(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT32: { + case FieldDescriptor::CPPTYPE_UINT32: { GOOGLE_CHECK_GET_UINT32(item, value, NULL); reflection->AddUInt32(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_UINT64: { + case FieldDescriptor::CPPTYPE_UINT64: { GOOGLE_CHECK_GET_UINT64(item, value, NULL); reflection->AddUInt64(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_FLOAT: { + case FieldDescriptor::CPPTYPE_FLOAT: { GOOGLE_CHECK_GET_FLOAT(item, value, NULL); reflection->AddFloat(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_DOUBLE: { + case FieldDescriptor::CPPTYPE_DOUBLE: { GOOGLE_CHECK_GET_DOUBLE(item, value, NULL); reflection->AddDouble(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_BOOL: { + case FieldDescriptor::CPPTYPE_BOOL: { GOOGLE_CHECK_GET_BOOL(item, value, NULL); reflection->AddBool(message, field_descriptor, value); break; } - case google::protobuf::FieldDescriptor::CPPTYPE_STRING: { + case FieldDescriptor::CPPTYPE_STRING: { if (!CheckAndSetString( item, message, field_descriptor, reflection, true, -1)) { return NULL; } break; } - case google::protobuf::FieldDescriptor::CPPTYPE_ENUM: { + case FieldDescriptor::CPPTYPE_ENUM: { GOOGLE_CHECK_GET_INT32(item, value, NULL); - const google::protobuf::EnumDescriptor* enum_descriptor = - field_descriptor->enum_type(); - const google::protobuf::EnumValueDescriptor* enum_value = - enum_descriptor->FindValueByNumber(value); - if (enum_value != NULL) { - reflection->AddEnum(message, field_descriptor, enum_value); + if (reflection->SupportsUnknownEnumValues()) { + reflection->AddEnumValue(message, field_descriptor, value); } else { - ScopedPyObjectPtr s(PyObject_Str(item)); - if (s != NULL) { - PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", - PyString_AsString(s)); + const EnumDescriptor* enum_descriptor = field_descriptor->enum_type(); + const EnumValueDescriptor* enum_value = + enum_descriptor->FindValueByNumber(value); + if (enum_value != NULL) { + reflection->AddEnum(message, field_descriptor, enum_value); + } else { + ScopedPyObjectPtr s(PyObject_Str(item)); + if (s != NULL) { + PyErr_Format(PyExc_ValueError, "Unknown enum value: %s", + PyString_AsString(s)); + } + return NULL; } - return NULL; } break; } @@ -438,8 +440,8 @@ static int AssSubscript(RepeatedScalarContainer* self, bool create_list = false; cmessage::AssureWritable(self->parent); - google::protobuf::Message* message = self->message; - const google::protobuf::FieldDescriptor* field_descriptor = + Message* message = self->message; + const FieldDescriptor* field_descriptor = self->parent_field_descriptor; #if PY_MAJOR_VERSION < 3 @@ -450,7 +452,7 @@ static int AssSubscript(RepeatedScalarContainer* self, if (PyLong_Check(slice)) { from = to = PyLong_AsLong(slice); } else if (PySlice_Check(slice)) { - const google::protobuf::Reflection* reflection = message->GetReflection(); + const Reflection* reflection = message->GetReflection(); length = reflection->FieldSize(*message, field_descriptor); #if PY_MAJOR_VERSION >= 3 if (PySlice_GetIndicesEx(slice, @@ -492,9 +494,15 @@ static int AssSubscript(RepeatedScalarContainer* self, PyObject* Extend(RepeatedScalarContainer* self, PyObject* value) { cmessage::AssureWritable(self->parent); - if (PyObject_Not(value)) { + + // TODO(ptucker): Deprecate this behavior. b/18413862 + if (value == Py_None) { + Py_RETURN_NONE; + } + if ((Py_TYPE(value)->tp_as_sequence == NULL) && PyObject_Not(value)) { Py_RETURN_NONE; } + ScopedPyObjectPtr iter(PyObject_GetIter(value)); if (iter == NULL) { PyErr_SetString(PyExc_TypeError, "Value must be iterable"); @@ -627,9 +635,28 @@ static PyObject* Sort(RepeatedScalarContainer* self, Py_RETURN_NONE; } +static PyObject* Pop(RepeatedScalarContainer* self, + PyObject* args) { + Py_ssize_t index = -1; + if (!PyArg_ParseTuple(args, "|n", &index)) { + return NULL; + } + PyObject* item = Item(self, index); + if (item == NULL) { + PyErr_Format(PyExc_IndexError, + "list index (%zd) out of range", + index); + return NULL; + } + if (AssignItem(self, index, NULL) < 0) { + return NULL; + } + return item; +} + // The private constructor of RepeatedScalarContainer objects. PyObject *NewContainer( - CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor) { + CMessage* parent, const FieldDescriptor* parent_field_descriptor) { if (!CheckFieldBelongsToMessage(parent_field_descriptor, parent->message)) { return NULL; } @@ -663,7 +690,7 @@ static int InitializeAndCopyToParentContainer( if (values == NULL) { return -1; } - google::protobuf::Message* new_message = global_message_factory->GetPrototype( + Message* new_message = cmessage::GetMessageFactory()->GetPrototype( from->message->GetDescriptor())->New(); to->parent = NULL; to->parent_field_descriptor = from->parent_field_descriptor; @@ -729,6 +756,8 @@ static PyMethodDef Methods[] = { "Appends objects to the repeated container." }, { "insert", (PyCFunction)Insert, METH_VARARGS, "Appends objects to the repeated container." }, + { "pop", (PyCFunction)Pop, METH_VARARGS, + "Removes an object from the repeated container and returns it." }, { "remove", (PyCFunction)Remove, METH_O, "Removes an object from the repeated container." }, { "sort", (PyCFunction)Sort, METH_VARARGS | METH_KEYWORDS, diff --git a/python/google/protobuf/pyext/repeated_scalar_container.h b/python/google/protobuf/pyext/repeated_scalar_container.h index 513bfe48..5dfa21e0 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.h +++ b/python/google/protobuf/pyext/repeated_scalar_container.h @@ -77,7 +77,7 @@ typedef struct RepeatedScalarContainer { // field. Used together with the parent's message when making a // default message instance mutable. // The pointer is owned by the global DescriptorPool. - const google::protobuf::FieldDescriptor* parent_field_descriptor; + const FieldDescriptor* parent_field_descriptor; } RepeatedScalarContainer; extern PyTypeObject RepeatedScalarContainer_Type; @@ -87,7 +87,7 @@ namespace repeated_scalar_container { // Builds a RepeatedScalarContainer object, from a parent message and a // field descriptor. extern PyObject *NewContainer( - CMessage* parent, const google::protobuf::FieldDescriptor* parent_field_descriptor); + CMessage* parent, const FieldDescriptor* parent_field_descriptor); // Appends the scalar 'item' to the end of the container 'self'. // diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 1fc704a2..55e653a0 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -56,18 +56,12 @@ _FieldDescriptor = descriptor_mod.FieldDescriptor if api_implementation.Type() == 'cpp': - if api_implementation.Version() == 2: - from google.protobuf.pyext import cpp_message - _NewMessage = cpp_message.NewMessage - _InitMessage = cpp_message.InitMessage - else: - from google.protobuf.internal import cpp_message - _NewMessage = cpp_message.NewMessage - _InitMessage = cpp_message.InitMessage + from google.protobuf.pyext import cpp_message as message_impl else: - from google.protobuf.internal import python_message - _NewMessage = python_message.NewMessage - _InitMessage = python_message.InitMessage + from google.protobuf.internal import python_message as message_impl + +_NewMessage = message_impl.NewMessage +_InitMessage = message_impl.InitMessage class GeneratedProtocolMessageType(type): @@ -127,7 +121,6 @@ class GeneratedProtocolMessageType(type): superclass = super(GeneratedProtocolMessageType, cls) new_class = superclass.__new__(cls, name, bases, dictionary) - setattr(descriptor, '_concrete_class', new_class) return new_class def __init__(cls, name, bases, dictionary): @@ -151,6 +144,7 @@ class GeneratedProtocolMessageType(type): _InitMessage(descriptor, cls) superclass = super(GeneratedProtocolMessageType, cls) superclass.__init__(name, bases, dictionary) + setattr(descriptor, '_concrete_class', cls) def ParseMessage(descriptor, byte_str): diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index fb54c50c..a47ce3e3 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -319,6 +319,11 @@ def _MergeField(tokenizer, message, allow_multiple_scalars): ParseError: In case of ASCII parsing problems. """ message_descriptor = message.DESCRIPTOR + if (hasattr(message_descriptor, 'syntax') and + message_descriptor.syntax == 'proto3'): + # Proto3 doesn't represent presence so we can't test if multiple + # scalars have occurred. We have to allow them. + allow_multiple_scalars = True if tokenizer.TryConsume('['): name = [tokenizer.ConsumeIdentifier()] while tokenizer.TryConsume('.'): diff --git a/python/setup.py b/python/setup.py index 45dd0761..c6ff7454 100755 --- a/python/setup.py +++ b/python/setup.py @@ -77,6 +77,7 @@ def GenerateUnittestProtos(): generate_proto("../src/google/protobuf/unittest_import_public.proto") generate_proto("../src/google/protobuf/unittest_mset.proto") generate_proto("../src/google/protobuf/unittest_no_generic_services.proto") + generate_proto("../src/google/protobuf/unittest_proto3_arena.proto") generate_proto("google/protobuf/internal/descriptor_pool_test1.proto") generate_proto("google/protobuf/internal/descriptor_pool_test2.proto") generate_proto("google/protobuf/internal/test_bad_identifiers.proto") @@ -90,23 +91,6 @@ def GenerateUnittestProtos(): generate_proto("google/protobuf/internal/import_test_package/outer.proto") generate_proto("google/protobuf/pyext/python.proto") -def MakeTestSuite(): - # Test C++ implementation - import unittest - import google.protobuf.pyext.descriptor_cpp2_test as descriptor_cpp2_test - import google.protobuf.pyext.message_factory_cpp2_test \ - as message_factory_cpp2_test - import google.protobuf.pyext.reflection_cpp2_generated_test \ - as reflection_cpp2_generated_test - - loader = unittest.defaultTestLoader - suite = unittest.TestSuite() - for test in [ descriptor_cpp2_test, - message_factory_cpp2_test, - reflection_cpp2_generated_test]: - suite.addTest(loader.loadTestsFromModule(test)) - return suite - class clean(_clean): def run(self): # Delete generated files in the code tree. @@ -152,6 +136,8 @@ if __name__ == '__main__': ext_module_list.append(Extension( "google.protobuf.pyext._message", [ "google/protobuf/pyext/descriptor.cc", + "google/protobuf/pyext/descriptor_containers.cc", + "google/protobuf/pyext/descriptor_pool.cc", "google/protobuf/pyext/message.cc", "google/protobuf/pyext/extension_dict.cc", "google/protobuf/pyext/repeated_scalar_container.cc", @@ -161,12 +147,12 @@ if __name__ == '__main__': libraries = [ "protobuf" ], library_dirs = [ '../src/.libs' ], )) + os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' setup(name = 'protobuf', version = '3.0.0-alpha-2', packages = [ 'google' ], namespace_packages = [ 'google' ], - test_suite = 'setup.MakeTestSuite', google_test_dir = "google/protobuf/internal", # Must list modules explicitly so that we don't install tests. py_modules = [ diff --git a/src/google/protobuf/compiler/python/python_generator.cc b/src/google/protobuf/compiler/python/python_generator.cc index 0c65b993..7b3b5fa3 100644 --- a/src/google/protobuf/compiler/python/python_generator.cc +++ b/src/google/protobuf/compiler/python/python_generator.cc @@ -273,6 +273,19 @@ string StringifyDefaultValue(const FieldDescriptor& field) { return ""; } +string StringifySyntax(FileDescriptor::Syntax syntax) { + switch (syntax) { + case FileDescriptor::SYNTAX_PROTO2: + return "proto2"; + case FileDescriptor::SYNTAX_PROTO3: + return "proto3"; + case FileDescriptor::SYNTAX_UNKNOWN: + default: + GOOGLE_LOG(FATAL) << "Unsupported syntax; this generator only supports proto2 " + "and proto3 syntax."; + return ""; + } +} } // namespace @@ -367,10 +380,12 @@ void Generator::PrintFileDescriptor() const { m["descriptor_name"] = kDescriptorKey; m["name"] = file_->name(); m["package"] = file_->package(); + m["syntax"] = StringifySyntax(file_->syntax()); const char file_descriptor_template[] = "$descriptor_name$ = _descriptor.FileDescriptor(\n" " name='$name$',\n" - " package='$package$',\n"; + " package='$package$',\n" + " syntax='$syntax$',\n"; printer_->Print(m, file_descriptor_template); printer_->Indent(); printer_->Print( @@ -414,7 +429,7 @@ void Generator::PrintTopLevelEnums() const { for (int j = 0; j < enum_descriptor.value_count(); ++j) { const EnumValueDescriptor& value_descriptor = *enum_descriptor.value(j); top_level_enum_values.push_back( - make_pair(value_descriptor.name(), value_descriptor.number())); + std::make_pair(value_descriptor.name(), value_descriptor.number())); } } @@ -665,10 +680,12 @@ void Generator::PrintDescriptor(const Descriptor& message_descriptor) const { message_descriptor.options().SerializeToString(&options_string); printer_->Print( "options=$options_value$,\n" - "is_extendable=$extendable$", + "is_extendable=$extendable$,\n" + "syntax='$syntax$'", "options_value", OptionsValue("MessageOptions", options_string), "extendable", message_descriptor.extension_range_count() > 0 ? - "True" : "False"); + "True" : "False", + "syntax", StringifySyntax(message_descriptor.file()->syntax())); printer_->Print(",\n"); // Extension ranges diff --git a/src/google/protobuf/testdata/golden_message_proto3 b/src/google/protobuf/testdata/golden_message_proto3 new file mode 100644 index 00000000..934f36fa Binary files /dev/null and b/src/google/protobuf/testdata/golden_message_proto3 differ diff --git a/src/google/protobuf/unittest_proto3_arena.proto b/src/google/protobuf/unittest_proto3_arena.proto index 53f23d87..6d7bada8 100644 --- a/src/google/protobuf/unittest_proto3_arena.proto +++ b/src/google/protobuf/unittest_proto3_arena.proto @@ -47,9 +47,10 @@ message TestAllTypes { } enum NestedEnum { - FOO = 0; - BAR = 1; - BAZ = 2; + ZERO = 0; + FOO = 1; + BAR = 2; + BAZ = 3; NEG = -1; // Intentionally negative. } @@ -81,6 +82,11 @@ message TestAllTypes { optional NestedEnum optional_nested_enum = 21; optional ForeignEnum optional_foreign_enum = 22; + // Omitted (compared to unittest.proto) because proto2 enums are not allowed + // inside proto2 messages. + // + // optional protobuf_unittest_import.ImportEnum optional_import_enum = 23; + optional string optional_string_piece = 24 [ctype=STRING_PIECE]; optional string optional_cord = 25 [ctype=CORD]; @@ -118,6 +124,11 @@ message TestAllTypes { repeated NestedEnum repeated_nested_enum = 51; repeated ForeignEnum repeated_foreign_enum = 52; + // Omitted (compared to unittest.proto) because proto2 enums are not allowed + // inside proto2 messages. + // + // repeated protobuf_unittest_import.ImportEnum repeated_import_enum = 53; + repeated string repeated_string_piece = 54 [ctype=STRING_PIECE]; repeated string repeated_cord = 55 [ctype=CORD]; @@ -131,6 +142,31 @@ message TestAllTypes { } } +// Test messages for packed fields + +message TestPackedTypes { + repeated int32 packed_int32 = 90 [packed = true]; + repeated int64 packed_int64 = 91 [packed = true]; + repeated uint32 packed_uint32 = 92 [packed = true]; + repeated uint64 packed_uint64 = 93 [packed = true]; + repeated sint32 packed_sint32 = 94 [packed = true]; + repeated sint64 packed_sint64 = 95 [packed = true]; + repeated fixed32 packed_fixed32 = 96 [packed = true]; + repeated fixed64 packed_fixed64 = 97 [packed = true]; + repeated sfixed32 packed_sfixed32 = 98 [packed = true]; + repeated sfixed64 packed_sfixed64 = 99 [packed = true]; + repeated float packed_float = 100 [packed = true]; + repeated double packed_double = 101 [packed = true]; + repeated bool packed_bool = 102 [packed = true]; + repeated ForeignEnum packed_enum = 103 [packed = true]; +} + +// This proto includes a recusively nested message. +message NestedTestAllTypes { + NestedTestAllTypes child = 1; + TestAllTypes payload = 2; +} + // Define these after TestAllTypes to make sure the compiler can handle // that. message ForeignMessage { @@ -138,7 +174,13 @@ message ForeignMessage { } enum ForeignEnum { - FOREIGN_FOO = 0; - FOREIGN_BAR = 1; - FOREIGN_BAZ = 2; + FOREIGN_ZERO = 0; + FOREIGN_FOO = 4; + FOREIGN_BAR = 5; + FOREIGN_BAZ = 6; +} + +// TestEmptyMessage is used to test behavior of unknown fields. +message TestEmptyMessage { } + -- cgit v1.2.3 From 0b70a43736fe070bee49141d493c74079ea68f97 Mon Sep 17 00:00:00 2001 From: Josh Haberman Date: Wed, 25 Feb 2015 20:17:32 -0800 Subject: Fixes for Python/C++ implementation in open-source: * Rosy hack doesn't apply (that test should be removed for the open-source release). * Added our own copy of parameterized.py (the open-source version of Google Apputils doesn't contain it). * The C++ Descriptor object didn't implement extension_ranges. * Had to implement a hack around returning EncodeError, to work around the module-loading behavior of the test runner. --- Makefile.am | 9 +- python/google/protobuf/internal/_parameterized.py | 438 +++++++++++++++++++++ python/google/protobuf/internal/message_test.py | 5 +- python/google/protobuf/internal/reflection_test.py | 10 - .../google/protobuf/internal/text_format_test.py | 4 +- python/google/protobuf/message.py | 1 - python/google/protobuf/pyext/descriptor.cc | 15 + python/google/protobuf/pyext/message.cc | 26 +- .../protobuf/pyext/repeated_composite_container.cc | 4 +- .../protobuf/pyext/repeated_scalar_container.cc | 4 +- src/Makefile.am | 1 + 11 files changed, 490 insertions(+), 27 deletions(-) create mode 100755 python/google/protobuf/internal/_parameterized.py diff --git a/Makefile.am b/Makefile.am index 1c64b710..de8ad96d 100644 --- a/Makefile.am +++ b/Makefile.am @@ -197,7 +197,6 @@ javanano_EXTRA_DIST= python_EXTRA_DIST= \ python/google/protobuf/internal/api_implementation.cc \ python/google/protobuf/internal/api_implementation.py \ - python/google/protobuf/internal/api_implementation_default_test.py \ python/google/protobuf/internal/containers.py \ python/google/protobuf/internal/cpp_message.py \ python/google/protobuf/internal/decoder.py \ @@ -221,6 +220,7 @@ python_EXTRA_DIST= \ python/google/protobuf/internal/more_extensions.proto \ python/google/protobuf/internal/more_extensions_dynamic.proto \ python/google/protobuf/internal/more_messages.proto \ + python/google/protobuf/internal/_parameterized.py \ python/google/protobuf/internal/proto_builder_test.py \ python/google/protobuf/internal/python_message.py \ python/google/protobuf/internal/reflection_test.py \ @@ -242,16 +242,17 @@ python_EXTRA_DIST= \ python/google/protobuf/pyext/cpp_message.py \ python/google/protobuf/pyext/descriptor.h \ python/google/protobuf/pyext/descriptor.cc \ - python/google/protobuf/pyext/descriptor_cpp2_test.py \ + python/google/protobuf/pyext/descriptor_pool.h \ + python/google/protobuf/pyext/descriptor_pool.cc \ + python/google/protobuf/pyext/descriptor_containers.h \ + python/google/protobuf/pyext/descriptor_containers.cc \ python/google/protobuf/pyext/extension_dict.h \ python/google/protobuf/pyext/extension_dict.cc \ python/google/protobuf/pyext/message.h \ python/google/protobuf/pyext/message.cc \ - python/google/protobuf/pyext/message_factory_cpp2_test.py \ python/google/protobuf/pyext/proto2_api_test.proto \ python/google/protobuf/pyext/python.proto \ python/google/protobuf/pyext/python_protobuf.h \ - python/google/protobuf/pyext/reflection_cpp2_generated_test.py \ python/google/protobuf/pyext/repeated_composite_container.h \ python/google/protobuf/pyext/repeated_composite_container.cc \ python/google/protobuf/pyext/repeated_scalar_container.h \ diff --git a/python/google/protobuf/internal/_parameterized.py b/python/google/protobuf/internal/_parameterized.py new file mode 100755 index 00000000..6ed23308 --- /dev/null +++ b/python/google/protobuf/internal/_parameterized.py @@ -0,0 +1,438 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Adds support for parameterized tests to Python's unittest TestCase class. + +A parameterized test is a method in a test case that is invoked with different +argument tuples. + +A simple example: + + class AdditionExample(parameterized.ParameterizedTestCase): + @parameterized.Parameters( + (1, 2, 3), + (4, 5, 9), + (1, 1, 3)) + def testAddition(self, op1, op2, result): + self.assertEquals(result, op1 + op2) + + +Each invocation is a separate test case and properly isolated just +like a normal test method, with its own setUp/tearDown cycle. In the +example above, there are three separate testcases, one of which will +fail due to an assertion error (1 + 1 != 3). + +Parameters for invididual test cases can be tuples (with positional parameters) +or dictionaries (with named parameters): + + class AdditionExample(parameterized.ParameterizedTestCase): + @parameterized.Parameters( + {'op1': 1, 'op2': 2, 'result': 3}, + {'op1': 4, 'op2': 5, 'result': 9}, + ) + def testAddition(self, op1, op2, result): + self.assertEquals(result, op1 + op2) + +If a parameterized test fails, the error message will show the +original test name (which is modified internally) and the arguments +for the specific invocation, which are part of the string returned by +the shortDescription() method on test cases. + +The id method of the test, used internally by the unittest framework, +is also modified to show the arguments. To make sure that test names +stay the same across several invocations, object representations like + + >>> class Foo(object): + ... pass + >>> repr(Foo()) + '<__main__.Foo object at 0x23d8610>' + +are turned into '<__main__.Foo>'. For even more descriptive names, +especially in test logs, you can use the NamedParameters decorator. In +this case, only tuples are supported, and the first parameters has to +be a string (or an object that returns an apt name when converted via +str()): + + class NamedExample(parameterized.ParameterizedTestCase): + @parameterized.NamedParameters( + ('Normal', 'aa', 'aaa', True), + ('EmptyPrefix', '', 'abc', True), + ('BothEmpty', '', '', True)) + def testStartsWith(self, prefix, string, result): + self.assertEquals(result, strings.startswith(prefix)) + +Named tests also have the benefit that they can be run individually +from the command line: + + $ testmodule.py NamedExample.testStartsWithNormal + . + -------------------------------------------------------------------- + Ran 1 test in 0.000s + + OK + +Parameterized Classes +===================== +If invocation arguments are shared across test methods in a single +ParameterizedTestCase class, instead of decorating all test methods +individually, the class itself can be decorated: + + @parameterized.Parameters( + (1, 2, 3) + (4, 5, 9)) + class ArithmeticTest(parameterized.ParameterizedTestCase): + def testAdd(self, arg1, arg2, result): + self.assertEqual(arg1 + arg2, result) + + def testSubtract(self, arg2, arg2, result): + self.assertEqual(result - arg1, arg2) + +Inputs from Iterables +===================== +If parameters should be shared across several test cases, or are dynamically +created from other sources, a single non-tuple iterable can be passed into +the decorator. This iterable will be used to obtain the test cases: + + class AdditionExample(parameterized.ParameterizedTestCase): + @parameterized.Parameters( + c.op1, c.op2, c.result for c in testcases + ) + def testAddition(self, op1, op2, result): + self.assertEquals(result, op1 + op2) + + +Single-Argument Test Methods +============================ +If a test method takes only one argument, the single argument does not need to +be wrapped into a tuple: + + class NegativeNumberExample(parameterized.ParameterizedTestCase): + @parameterized.Parameters( + -1, -3, -4, -5 + ) + def testIsNegative(self, arg): + self.assertTrue(IsNegative(arg)) +""" + +__author__ = 'tmarek@google.com (Torsten Marek)' + +import collections +import functools +import re +import types +import unittest +import uuid + +from google.apputils import basetest + +ADDR_RE = re.compile(r'\<([a-zA-Z0-9_\-\.]+) object at 0x[a-fA-F0-9]+\>') +_SEPARATOR = uuid.uuid1().hex +_FIRST_ARG = object() +_ARGUMENT_REPR = object() + + +def _CleanRepr(obj): + return ADDR_RE.sub(r'<\1>', repr(obj)) + + +# Helper function formerly from the unittest module, removed from it in +# Python 2.7. +def _StrClass(cls): + return '%s.%s' % (cls.__module__, cls.__name__) + + +def _NonStringIterable(obj): + return (isinstance(obj, collections.Iterable) and not + isinstance(obj, basestring)) + + +def _FormatParameterList(testcase_params): + if isinstance(testcase_params, collections.Mapping): + return ', '.join('%s=%s' % (argname, _CleanRepr(value)) + for argname, value in testcase_params.iteritems()) + elif _NonStringIterable(testcase_params): + return ', '.join(map(_CleanRepr, testcase_params)) + else: + return _FormatParameterList((testcase_params,)) + + +class _ParameterizedTestIter(object): + """Callable and iterable class for producing new test cases.""" + + def __init__(self, test_method, testcases, naming_type): + """Returns concrete test functions for a test and a list of parameters. + + The naming_type is used to determine the name of the concrete + functions as reported by the unittest framework. If naming_type is + _FIRST_ARG, the testcases must be tuples, and the first element must + have a string representation that is a valid Python identifier. + + Args: + test_method: The decorated test method. + testcases: (list of tuple/dict) A list of parameter + tuples/dicts for individual test invocations. + naming_type: The test naming type, either _NAMED or _ARGUMENT_REPR. + """ + self._test_method = test_method + self.testcases = testcases + self._naming_type = naming_type + + def __call__(self, *args, **kwargs): + raise RuntimeError('You appear to be running a parameterized test case ' + 'without having inherited from parameterized.' + 'ParameterizedTestCase. This is bad because none of ' + 'your test cases are actually being run.') + + def __iter__(self): + test_method = self._test_method + naming_type = self._naming_type + + def MakeBoundParamTest(testcase_params): + @functools.wraps(test_method) + def BoundParamTest(self): + if isinstance(testcase_params, collections.Mapping): + test_method(self, **testcase_params) + elif _NonStringIterable(testcase_params): + test_method(self, *testcase_params) + else: + test_method(self, testcase_params) + + if naming_type is _FIRST_ARG: + # Signal the metaclass that the name of the test function is unique + # and descriptive. + BoundParamTest.__x_use_name__ = True + BoundParamTest.__name__ += str(testcase_params[0]) + testcase_params = testcase_params[1:] + elif naming_type is _ARGUMENT_REPR: + # __x_extra_id__ is used to pass naming information to the __new__ + # method of TestGeneratorMetaclass. + # The metaclass will make sure to create a unique, but nondescriptive + # name for this test. + BoundParamTest.__x_extra_id__ = '(%s)' % ( + _FormatParameterList(testcase_params),) + else: + raise RuntimeError('%s is not a valid naming type.' % (naming_type,)) + + BoundParamTest.__doc__ = '%s(%s)' % ( + BoundParamTest.__name__, _FormatParameterList(testcase_params)) + if test_method.__doc__: + BoundParamTest.__doc__ += '\n%s' % (test_method.__doc__,) + return BoundParamTest + return (MakeBoundParamTest(c) for c in self.testcases) + + +def _IsSingletonList(testcases): + """True iff testcases contains only a single non-tuple element.""" + return len(testcases) == 1 and not isinstance(testcases[0], tuple) + + +def _ModifyClass(class_object, testcases, naming_type): + assert not getattr(class_object, '_id_suffix', None), ( + 'Cannot add parameters to %s,' + ' which already has parameterized methods.' % (class_object,)) + class_object._id_suffix = id_suffix = {} + for name, obj in class_object.__dict__.items(): + if (name.startswith(unittest.TestLoader.testMethodPrefix) + and isinstance(obj, types.FunctionType)): + delattr(class_object, name) + methods = {} + _UpdateClassDictForParamTestCase( + methods, id_suffix, name, + _ParameterizedTestIter(obj, testcases, naming_type)) + for name, meth in methods.iteritems(): + setattr(class_object, name, meth) + + +def _ParameterDecorator(naming_type, testcases): + """Implementation of the parameterization decorators. + + Args: + naming_type: The naming type. + testcases: Testcase parameters. + + Returns: + A function for modifying the decorated object. + """ + def _Apply(obj): + if isinstance(obj, type): + _ModifyClass( + obj, + list(testcases) if not isinstance(testcases, collections.Sequence) + else testcases, + naming_type) + return obj + else: + return _ParameterizedTestIter(obj, testcases, naming_type) + + if _IsSingletonList(testcases): + assert _NonStringIterable(testcases[0]), ( + 'Single parameter argument must be a non-string iterable') + testcases = testcases[0] + + return _Apply + + +def Parameters(*testcases): + """A decorator for creating parameterized tests. + + See the module docstring for a usage example. + Args: + *testcases: Parameters for the decorated method, either a single + iterable, or a list of tuples/dicts/objects (for tests + with only one argument). + + Returns: + A test generator to be handled by TestGeneratorMetaclass. + """ + return _ParameterDecorator(_ARGUMENT_REPR, testcases) + + +def NamedParameters(*testcases): + """A decorator for creating parameterized tests. + + See the module docstring for a usage example. The first element of + each parameter tuple should be a string and will be appended to the + name of the test method. + + Args: + *testcases: Parameters for the decorated method, either a single + iterable, or a list of tuples. + + Returns: + A test generator to be handled by TestGeneratorMetaclass. + """ + return _ParameterDecorator(_FIRST_ARG, testcases) + + +class TestGeneratorMetaclass(type): + """Metaclass for test cases with test generators. + + A test generator is an iterable in a testcase that produces callables. These + callables must be single-argument methods. These methods are injected into + the class namespace and the original iterable is removed. If the name of the + iterable conforms to the test pattern, the injected methods will be picked + up as tests by the unittest framework. + + In general, it is supposed to be used in conjuction with the + Parameters decorator. + """ + + def __new__(mcs, class_name, bases, dct): + dct['_id_suffix'] = id_suffix = {} + for name, obj in dct.items(): + if (name.startswith(unittest.TestLoader.testMethodPrefix) and + _NonStringIterable(obj)): + iterator = iter(obj) + dct.pop(name) + _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator) + + return type.__new__(mcs, class_name, bases, dct) + + +def _UpdateClassDictForParamTestCase(dct, id_suffix, name, iterator): + """Adds individual test cases to a dictionary. + + Args: + dct: The target dictionary. + id_suffix: The dictionary for mapping names to test IDs. + name: The original name of the test case. + iterator: The iterator generating the individual test cases. + """ + for idx, func in enumerate(iterator): + assert callable(func), 'Test generators must yield callables, got %r' % ( + func,) + if getattr(func, '__x_use_name__', False): + new_name = func.__name__ + else: + new_name = '%s%s%d' % (name, _SEPARATOR, idx) + assert new_name not in dct, ( + 'Name of parameterized test case "%s" not unique' % (new_name,)) + dct[new_name] = func + id_suffix[new_name] = getattr(func, '__x_extra_id__', '') + + +class ParameterizedTestCase(basetest.TestCase): + """Base class for test cases using the Parameters decorator.""" + __metaclass__ = TestGeneratorMetaclass + + def _OriginalName(self): + return self._testMethodName.split(_SEPARATOR)[0] + + def __str__(self): + return '%s (%s)' % (self._OriginalName(), _StrClass(self.__class__)) + + def id(self): # pylint: disable=invalid-name + """Returns the descriptive ID of the test. + + This is used internally by the unittesting framework to get a name + for the test to be used in reports. + + Returns: + The test id. + """ + return '%s.%s%s' % (_StrClass(self.__class__), + self._OriginalName(), + self._id_suffix.get(self._testMethodName, '')) + + +def CoopParameterizedTestCase(other_base_class): + """Returns a new base class with a cooperative metaclass base. + + This enables the ParameterizedTestCase to be used in combination + with other base classes that have custom metaclasses, such as + mox.MoxTestBase. + + Only works with metaclasses that do not override type.__new__. + + Example: + + import google3 + import mox + + from google3.testing.pybase import parameterized + + class ExampleTest(parameterized.CoopParameterizedTestCase(mox.MoxTestBase)): + ... + + Args: + other_base_class: (class) A test case base class. + + Returns: + A new class object. + """ + metaclass = type( + 'CoopMetaclass', + (other_base_class.__metaclass__, + TestGeneratorMetaclass), {}) + return metaclass( + 'CoopParameterizedTestCase', + (other_base_class, ParameterizedTestCase), {}) diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 7ab814cf..4d4de510 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -51,10 +51,10 @@ import sys import unittest from google.apputils import basetest -from google.apputils.pybase import parameterized from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation +from google.protobuf.internal import _parameterized from google.protobuf.internal import test_util from google.protobuf import message @@ -72,7 +72,7 @@ def IsNegInf(val): return isinf(val) and (val < 0) -@parameterized.Parameters( +@_parameterized.Parameters( (unittest_pb2), (unittest_proto3_arena_pb2)) class MessageTest(basetest.TestCase): @@ -982,7 +982,6 @@ class Proto2Test(basetest.TestCase): # This is still an incomplete proto - so serializing should fail self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) - # TODO(haberman): this isn't really a proto2-specific test except that this # message has a required field in it. Should probably be factored out so # that we can test the other parts with proto3. diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 8ac76c63..a62d9845 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -1803,16 +1803,6 @@ class ReflectionTest(basetest.TestCase): self.assertRaises(TypeError, unittest_pb2.TestAllTypes().__getattribute__, 42) - @basetest.unittest.skipIf( - api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, - 'CPPv2-specific test') - def testRosyHack(self): - from google.protobuf.pyext import _message - from google3.gdata.rosy.proto import core_api2_pb2 - from google3.gdata.rosy.proto import core_pb2 - self.assertEqual(_message.Message, core_pb2.PageSelection.__base__) - self.assertEqual(_message.Message, core_api2_pb2.PageSelection.__base__) - # Since we had so many tests for protocol buffer equality, we broke these out # into separate TestCase classes. diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 68ab9659..b307620d 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -37,12 +37,12 @@ __author__ = 'kenton@google.com (Kenton Varda)' import re from google.apputils import basetest -from google.apputils.pybase import parameterized from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 from google.protobuf.internal import api_implementation +from google.protobuf.internal import _parameterized from google.protobuf.internal import test_util from google.protobuf import text_format @@ -72,7 +72,7 @@ class TextFormatBase(basetest.TestCase): return text -@parameterized.Parameters( +@_parameterized.Parameters( (unittest_pb2), (unittest_proto3_arena_pb2)) class TextFormatTest(TextFormatBase): diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index 88ed9f4c..de2f5697 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -36,7 +36,6 @@ __author__ = 'robinson@google.com (Will Robinson)' - class Error(Exception): pass class DecodeError(Error): pass class EncodeError(Error): pass diff --git a/python/google/protobuf/pyext/descriptor.cc b/python/google/protobuf/pyext/descriptor.cc index d9b90ddb..6890cd04 100644 --- a/python/google/protobuf/pyext/descriptor.cc +++ b/python/google/protobuf/pyext/descriptor.cc @@ -433,6 +433,20 @@ static PyObject* IsExtendable(PyBaseDescriptor *self, void *closure) { } } +static PyObject* GetExtensionRanges(PyBaseDescriptor *self, void *closure) { + const Descriptor* descriptor = _GetDescriptor(self); + PyObject* range_list = PyList_New(descriptor->extension_range_count()); + + for (int i = 0; i < descriptor->extension_range_count(); i++) { + const Descriptor::ExtensionRange* range = descriptor->extension_range(i); + PyObject* start = PyInt_FromLong(range->start); + PyObject* end = PyInt_FromLong(range->end); + PyList_SetItem(range_list, i, PyTuple_Pack(2, start, end)); + } + + return range_list; +} + static PyObject* GetContainingType(PyBaseDescriptor *self, void *closure) { const Descriptor* containing_type = _GetDescriptor(self)->containing_type(); @@ -512,6 +526,7 @@ static PyGetSetDef Getters[] = { { C("nested_types_by_name"), (getter)GetNestedTypesByName, NULL, "Nested types by name", NULL}, { C("extensions"), (getter)GetExtensions, NULL, "Extensions Sequence", NULL}, { C("extensions_by_name"), (getter)GetExtensionsByName, NULL, "Extensions by name", NULL}, + { C("extension_ranges"), (getter)GetExtensionRanges, NULL, "Extension ranges", NULL}, { C("enum_types"), (getter)GetEnumsSeq, NULL, "Enum sequence", NULL}, { C("enum_types_by_name"), (getter)GetEnumTypesByName, NULL, "Enum types by name", NULL}, { C("enum_values_by_name"), (getter)GetEnumValuesByName, NULL, "Enum values by name", NULL}, diff --git a/python/google/protobuf/pyext/message.cc b/python/google/protobuf/pyext/message.cc index 137f5d5f..c48f94c3 100644 --- a/python/google/protobuf/pyext/message.cc +++ b/python/google/protobuf/pyext/message.cc @@ -1264,7 +1264,27 @@ static PyObject* SerializeToString(CMessage* self, PyObject* args) { if (joined == NULL) { return NULL; } - PyErr_Format(EncodeError_class, "Message %s is missing required fields: %s", + + // TODO(haberman): this is a (hopefully temporary) hack. The unit testing + // infrastructure reloads all pure-Python modules for every test, but not + // C++ modules (because that's generally impossible: + // http://bugs.python.org/issue1144263). But if we cache EncodeError, we'll + // return the EncodeError from a previous load of the module, which won't + // match a user's attempt to catch EncodeError. So we have to look it up + // again every time. + ScopedPyObjectPtr message_module(PyImport_ImportModule( + "google.protobuf.message")); + if (message_module.get() == NULL) { + return NULL; + } + + ScopedPyObjectPtr encode_error( + PyObject_GetAttrString(message_module, "EncodeError")); + if (encode_error.get() == NULL) { + return NULL; + } + PyErr_Format(encode_error.get(), + "Message %s is missing required fields: %s", GetMessageName(self).c_str(), PyString_AsString(joined)); return NULL; } @@ -2284,8 +2304,8 @@ int SetAttr(CMessage* self, PyObject* name, PyObject* value) { PyTypeObject CMessage_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - "google.protobuf.internal." - "cpp._message.CMessage", // tp_name + "google.protobuf." + "pyext._message.CMessage", // tp_name sizeof(CMessage), // tp_basicsize 0, // tp_itemsize (destructor)cmessage::Dealloc, // tp_dealloc diff --git a/python/google/protobuf/pyext/repeated_composite_container.cc b/python/google/protobuf/pyext/repeated_composite_container.cc index 49ef7991..abc859d7 100644 --- a/python/google/protobuf/pyext/repeated_composite_container.cc +++ b/python/google/protobuf/pyext/repeated_composite_container.cc @@ -732,8 +732,8 @@ static PyMethodDef Methods[] = { PyTypeObject RepeatedCompositeContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - "google.protobuf.internal." - "cpp._message.RepeatedCompositeContainer", // tp_name + "google.protobuf.pyext." + "_message.RepeatedCompositeContainer", // tp_name sizeof(RepeatedCompositeContainer), // tp_basicsize 0, // tp_itemsize (destructor)repeated_composite_container::Dealloc, // tp_dealloc diff --git a/python/google/protobuf/pyext/repeated_scalar_container.cc b/python/google/protobuf/pyext/repeated_scalar_container.cc index 9f8075b2..17474598 100644 --- a/python/google/protobuf/pyext/repeated_scalar_container.cc +++ b/python/google/protobuf/pyext/repeated_scalar_container.cc @@ -769,8 +769,8 @@ static PyMethodDef Methods[] = { PyTypeObject RepeatedScalarContainer_Type = { PyVarObject_HEAD_INIT(&PyType_Type, 0) - "google.protobuf.internal." - "cpp._message.RepeatedScalarContainer", // tp_name + "google.protobuf." + "pyext._message.RepeatedScalarContainer", // tp_name sizeof(RepeatedScalarContainer), // tp_basicsize 0, // tp_itemsize (destructor)repeated_scalar_container::Dealloc, // tp_dealloc diff --git a/src/Makefile.am b/src/Makefile.am index 0f0bd089..deb0f8a9 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -307,6 +307,7 @@ EXTRA_DIST = \ google/protobuf/io/gzip_stream_unittest.sh \ google/protobuf/testdata/golden_message \ google/protobuf/testdata/golden_message_oneof_implemented \ + google/protobuf/testdata/golden_message_proto3 \ google/protobuf/testdata/golden_packed_fields_message \ google/protobuf/testdata/bad_utf8_string \ google/protobuf/testdata/text_format_unittest_data.txt \ -- cgit v1.2.3