From b55a20fa2c669b181f47ea9219b8e74d1263da19 Mon Sep 17 00:00:00 2001 From: "xiaofeng@google.com" Date: Sat, 22 Sep 2012 02:40:50 +0000 Subject: Down-integrate from internal branch --- python/google/protobuf/descriptor.py | 125 ++++- python/google/protobuf/descriptor_database.py | 120 +++++ python/google/protobuf/descriptor_pool.py | 527 +++++++++++++++++++++ .../google/protobuf/internal/api_implementation.py | 23 + python/google/protobuf/internal/containers.py | 14 +- python/google/protobuf/internal/cpp_message.py | 117 +++-- python/google/protobuf/internal/decoder.py | 6 + .../protobuf/internal/descriptor_database_test.py | 63 +++ .../protobuf/internal/descriptor_pool_test.py | 220 +++++++++ python/google/protobuf/internal/descriptor_test.py | 279 +++++++++++ .../google/protobuf/internal/enum_type_wrapper.py | 89 ++++ .../google/protobuf/internal/factory_test1.proto | 55 +++ .../google/protobuf/internal/factory_test2.proto | 77 +++ python/google/protobuf/internal/generator_test.py | 25 + .../google/protobuf/internal/message_cpp_test.py | 45 ++ .../protobuf/internal/message_factory_test.py | 113 +++++ python/google/protobuf/internal/message_test.py | 181 ++++++- .../internal/more_extensions_dynamic.proto | 49 ++ python/google/protobuf/internal/python_message.py | 66 ++- .../internal/reflection_cpp_generated_test.py | 91 ++++ python/google/protobuf/internal/reflection_test.py | 288 ++++++++++- .../protobuf/internal/test_bad_identifiers.proto | 52 ++ python/google/protobuf/internal/test_util.py | 20 +- .../google/protobuf/internal/text_format_test.py | 54 ++- .../protobuf/internal/unknown_fields_test.py | 170 +++++++ python/google/protobuf/message.py | 12 + python/google/protobuf/message_factory.py | 113 +++++ python/google/protobuf/pyext/python-proto2.cc | 91 +++- python/google/protobuf/pyext/python_descriptor.cc | 9 +- python/google/protobuf/reflection.py | 35 +- python/google/protobuf/text_format.py | 257 +++++----- python/setup.py | 32 +- 32 files changed, 3185 insertions(+), 233 deletions(-) create mode 100644 python/google/protobuf/descriptor_database.py create mode 100644 python/google/protobuf/descriptor_pool.py create mode 100644 python/google/protobuf/internal/descriptor_database_test.py create mode 100644 python/google/protobuf/internal/descriptor_pool_test.py create mode 100644 python/google/protobuf/internal/enum_type_wrapper.py create mode 100644 python/google/protobuf/internal/factory_test1.proto create mode 100644 python/google/protobuf/internal/factory_test2.proto create mode 100644 python/google/protobuf/internal/message_cpp_test.py create mode 100644 python/google/protobuf/internal/message_factory_test.py create mode 100644 python/google/protobuf/internal/more_extensions_dynamic.proto create mode 100755 python/google/protobuf/internal/reflection_cpp_generated_test.py create mode 100644 python/google/protobuf/internal/test_bad_identifiers.proto create mode 100755 python/google/protobuf/internal/unknown_fields_test.py create mode 100644 python/google/protobuf/message_factory.py (limited to 'python') diff --git a/python/google/protobuf/descriptor.py b/python/google/protobuf/descriptor.py index cf609bee..eb13eda5 100755 --- a/python/google/protobuf/descriptor.py +++ b/python/google/protobuf/descriptor.py @@ -39,13 +39,20 @@ from google.protobuf.internal import api_implementation if api_implementation.Type() == 'cpp': - from google.protobuf.internal import cpp_message + if api_implementation.Version() == 2: + from google.protobuf.internal.cpp import _message + else: + from 
google.protobuf.internal import cpp_message class Error(Exception): """Base error for this module.""" +class TypeTransformationError(Error): + """Error transforming between python proto type and corresponding C++ type.""" + + class DescriptorBase(object): """Descriptors base class. @@ -72,6 +79,18 @@ class DescriptorBase(object): # Does this descriptor have non-default options? self.has_options = options is not None + def _SetOptions(self, options, options_class_name): + """Sets the descriptor's options + + This function is used in generated proto2 files to update descriptor + options. It must not be used outside proto2. + """ + self._options = options + self._options_class_name = options_class_name + + # Does this descriptor have non-default options? + self.has_options = options is not None + def GetOptions(self): """Retrieves descriptor options. @@ -250,6 +269,24 @@ class Descriptor(_NestedDescriptorBase): self._serialized_start = serialized_start self._serialized_end = serialized_end + def EnumValueName(self, enum, value): + """Returns the string name of an enum value. + + This is just a small helper method to simplify a common operation. + + Args: + enum: string name of the Enum. + value: int, value of the enum. + + Returns: + string name of the enum value. + + Raises: + KeyError if either the Enum doesn't exist or the value is not a valid + value for the enum. + """ + return self.enum_types_by_name[enum].values_by_number[value].name + def CopyToProto(self, proto): """Copies this to a descriptor_pb2.DescriptorProto. @@ -275,7 +312,7 @@ class FieldDescriptor(DescriptorBase): """Descriptor for a single field in a .proto file. - A FieldDescriptor instance has the following attriubtes: + A FieldDescriptor instance has the following attributes: name: (str) Name of this field, exactly as it appears in .proto. full_name: (str) Name of this field, including containing scope. This is @@ -358,6 +395,27 @@ class FieldDescriptor(DescriptorBase): CPPTYPE_MESSAGE = 10 MAX_CPPTYPE = 10 + _PYTHON_TO_CPP_PROTO_TYPE_MAP = { + TYPE_DOUBLE: CPPTYPE_DOUBLE, + TYPE_FLOAT: CPPTYPE_FLOAT, + TYPE_ENUM: CPPTYPE_ENUM, + TYPE_INT64: CPPTYPE_INT64, + TYPE_SINT64: CPPTYPE_INT64, + TYPE_SFIXED64: CPPTYPE_INT64, + TYPE_UINT64: CPPTYPE_UINT64, + TYPE_FIXED64: CPPTYPE_UINT64, + TYPE_INT32: CPPTYPE_INT32, + TYPE_SFIXED32: CPPTYPE_INT32, + TYPE_SINT32: CPPTYPE_INT32, + TYPE_UINT32: CPPTYPE_UINT32, + TYPE_FIXED32: CPPTYPE_UINT32, + TYPE_BYTES: CPPTYPE_STRING, + TYPE_STRING: CPPTYPE_STRING, + TYPE_BOOL: CPPTYPE_BOOL, + TYPE_MESSAGE: CPPTYPE_MESSAGE, + TYPE_GROUP: CPPTYPE_MESSAGE + } + # Must be consistent with C++ FieldDescriptor::Label enum in # descriptor.h. # @@ -395,12 +453,38 @@ class FieldDescriptor(DescriptorBase): self.extension_scope = extension_scope if api_implementation.Type() == 'cpp': if is_extension: - self._cdescriptor = cpp_message.GetExtensionDescriptor(full_name) + if api_implementation.Version() == 2: + self._cdescriptor = _message.GetExtensionDescriptor(full_name) + else: + self._cdescriptor = cpp_message.GetExtensionDescriptor(full_name) else: - self._cdescriptor = cpp_message.GetFieldDescriptor(full_name) + if api_implementation.Version() == 2: + self._cdescriptor = _message.GetFieldDescriptor(full_name) + else: + self._cdescriptor = cpp_message.GetFieldDescriptor(full_name) else: self._cdescriptor = None + @staticmethod + def ProtoTypeToCppProtoType(proto_type): + """Converts from a Python proto type to a C++ Proto Type. 
+ + The Python ProtocolBuffer classes specify both the 'Python' datatype and the + 'C++' datatype - and they're not the same. This helper method should + translate from one to another. + + Args: + proto_type: the Python proto type (descriptor.FieldDescriptor.TYPE_*) + Returns: + descriptor.FieldDescriptor.CPPTYPE_*, the C++ type. + Raises: + TypeTransformationError: when the Python proto type isn't known. + """ + try: + return FieldDescriptor._PYTHON_TO_CPP_PROTO_TYPE_MAP[proto_type] + except KeyError: + raise TypeTransformationError('Unknown proto_type: %s' % proto_type) + class EnumDescriptor(_NestedDescriptorBase): @@ -577,7 +661,10 @@ class FileDescriptor(DescriptorBase): self.serialized_pb = serialized_pb if (api_implementation.Type() == 'cpp' and self.serialized_pb is not None): - cpp_message.BuildFile(self.serialized_pb) + if api_implementation.Version() == 2: + _message.BuildFile(self.serialized_pb) + else: + cpp_message.BuildFile(self.serialized_pb) def CopyToProto(self, proto): """Copies this to a descriptor_pb2.FileDescriptorProto. @@ -596,3 +683,31 @@ def _ParseOptions(message, string): """ message.ParseFromString(string) return message + + +def MakeDescriptor(desc_proto, package=''): + """Make a protobuf Descriptor given a DescriptorProto protobuf. + + Args: + desc_proto: The descriptor_pb2.DescriptorProto protobuf message. + package: Optional package name for the new message Descriptor (string). + + Returns: + A Descriptor for protobuf messages. + """ + full_message_name = [desc_proto.name] + if package: full_message_name.insert(0, package) + fields = [] + for field_proto in desc_proto.field: + full_name = '.'.join(full_message_name + [field_proto.name]) + field = FieldDescriptor( + field_proto.name, full_name, field_proto.number - 1, + field_proto.number, field_proto.type, + FieldDescriptor.ProtoTypeToCppProtoType(field_proto.type), + field_proto.label, None, None, None, None, False, None, + has_default_value=False) + fields.append(field) + + desc_name = '.'.join(full_message_name) + return Descriptor(desc_proto.name, desc_name, None, None, fields, + [], [], []) diff --git a/python/google/protobuf/descriptor_database.py b/python/google/protobuf/descriptor_database.py new file mode 100644 index 00000000..8665d3c5 --- /dev/null +++ b/python/google/protobuf/descriptor_database.py @@ -0,0 +1,120 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides a container for DescriptorProtos.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + + +class DescriptorDatabase(object): + """A container accepting FileDescriptorProtos and maps DescriptorProtos.""" + + def __init__(self): + self._file_desc_protos_by_file = {} + self._file_desc_protos_by_symbol = {} + + def Add(self, file_desc_proto): + """Adds the FileDescriptorProto and its types to this database. + + Args: + file_desc_proto: The FileDescriptorProto to add. + """ + + self._file_desc_protos_by_file[file_desc_proto.name] = file_desc_proto + package = file_desc_proto.package + for message in file_desc_proto.message_type: + self._file_desc_protos_by_symbol.update( + (name, file_desc_proto) for name in _ExtractSymbols(message, package)) + for enum in file_desc_proto.enum_type: + self._file_desc_protos_by_symbol[ + '.'.join((package, enum.name))] = file_desc_proto + + def FindFileByName(self, name): + """Finds the file descriptor proto by file name. + + Typically the file name is a relative path ending to a .proto file. The + proto with the given name will have to have been added to this database + using the Add method or else an error will be raised. + + Args: + name: The file name to find. + + Returns: + The file descriptor proto matching the name. + + Raises: + KeyError if no file by the given name was added. + """ + + return self._file_desc_protos_by_file[name] + + def FindFileContainingSymbol(self, symbol): + """Finds the file descriptor proto containing the specified symbol. + + The symbol should be a fully qualified name including the file descriptor's + package and any containing messages. Some examples: + + 'some.package.name.Message' + 'some.package.name.Message.NestedEnum' + + The file descriptor proto containing the specified symbol must be added to + this database using the Add method or else an error will be raised. + + Args: + symbol: The fully qualified symbol name. + + Returns: + The file descriptor proto containing the symbol. + + Raises: + KeyError if no file contains the specified symbol. + """ + + return self._file_desc_protos_by_symbol[symbol] + + +def _ExtractSymbols(desc_proto, package): + """Pulls out all the symbols from a descriptor proto. + + Args: + desc_proto: The proto to extract symbols from. + package: The package containing the descriptor type. + + Yields: + The fully qualified name found in the descriptor. + """ + + message_name = '.'.join((package, desc_proto.name)) + yield message_name + for nested_type in desc_proto.nested_type: + for symbol in _ExtractSymbols(nested_type, message_name): + yield symbol + for enum_type in desc_proto.enum_type: + yield '.'.join((message_name, enum_type.name)) diff --git a/python/google/protobuf/descriptor_pool.py b/python/google/protobuf/descriptor_pool.py new file mode 100644 index 00000000..8f1f4457 --- /dev/null +++ b/python/google/protobuf/descriptor_pool.py @@ -0,0 +1,527 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. 
All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides DescriptorPool to use as a container for proto2 descriptors. + +The DescriptorPool is used in conjection with a DescriptorDatabase to maintain +a collection of protocol buffer descriptors for use when dynamically creating +message types at runtime. + +For most applications protocol buffers should be used via modules generated by +the protocol buffer compiler tool. This should only be used when the type of +protocol buffers used in an application or library cannot be predetermined. + +Below is a straightforward example on how to use this class: + + pool = DescriptorPool() + file_descriptor_protos = [ ... ] + for file_descriptor_proto in file_descriptor_protos: + pool.Add(file_descriptor_proto) + my_message_descriptor = pool.FindMessageTypeByName('some.package.MessageType') + +The message descriptor can be used in conjunction with the message_factory +module in order to create a protocol buffer class that can be encoded and +decoded. +""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor +from google.protobuf import descriptor_database + + +class DescriptorPool(object): + """A collection of protobufs dynamically constructed by descriptor protos.""" + + def __init__(self, descriptor_db=None): + """Initializes a Pool of proto buffs. + + The descriptor_db argument to the constructor is provided to allow + specialized file descriptor proto lookup code to be triggered on demand. An + example would be an implementation which will read and compile a file + specified in a call to FindFileByName() and not require the call to Add() + at all. Results from this database will be cached internally here as well. + + Args: + descriptor_db: A secondary source of file descriptors. 
+ """ + + self._internal_db = descriptor_database.DescriptorDatabase() + self._descriptor_db = descriptor_db + self._descriptors = {} + self._enum_descriptors = {} + self._file_descriptors = {} + + def Add(self, file_desc_proto): + """Adds the FileDescriptorProto and its types to this pool. + + Args: + file_desc_proto: The FileDescriptorProto to add. + """ + + self._internal_db.Add(file_desc_proto) + + def FindFileByName(self, file_name): + """Gets a FileDescriptor by file name. + + Args: + file_name: The path to the file to get a descriptor for. + + Returns: + A FileDescriptor for the named file. + + Raises: + KeyError: if the file can not be found in the pool. + """ + + try: + file_proto = self._internal_db.FindFileByName(file_name) + except KeyError as error: + if self._descriptor_db: + file_proto = self._descriptor_db.FindFileByName(file_name) + else: + raise error + if not file_proto: + raise KeyError('Cannot find a file named %s' % file_name) + return self._ConvertFileProtoToFileDescriptor(file_proto) + + def FindFileContainingSymbol(self, symbol): + """Gets the FileDescriptor for the file containing the specified symbol. + + Args: + symbol: The name of the symbol to search for. + + Returns: + A FileDescriptor that contains the specified symbol. + + Raises: + KeyError: if the file can not be found in the pool. + """ + + try: + file_proto = self._internal_db.FindFileContainingSymbol(symbol) + except KeyError as error: + if self._descriptor_db: + file_proto = self._descriptor_db.FindFileContainingSymbol(symbol) + else: + raise error + if not file_proto: + raise KeyError('Cannot find a file containing %s' % symbol) + return self._ConvertFileProtoToFileDescriptor(file_proto) + + def FindMessageTypeByName(self, full_name): + """Loads the named descriptor from the pool. + + Args: + full_name: The full name of the descriptor to load. + + Returns: + The descriptor for the named type. + """ + + full_name = full_name.lstrip('.') # fix inconsistent qualified name formats + if full_name not in self._descriptors: + self.FindFileContainingSymbol(full_name) + return self._descriptors[full_name] + + def FindEnumTypeByName(self, full_name): + """Loads the named enum descriptor from the pool. + + Args: + full_name: The full name of the enum descriptor to load. + + Returns: + The enum descriptor for the named type. + """ + + full_name = full_name.lstrip('.') # fix inconsistent qualified name formats + if full_name not in self._enum_descriptors: + self.FindFileContainingSymbol(full_name) + return self._enum_descriptors[full_name] + + def _ConvertFileProtoToFileDescriptor(self, file_proto): + """Creates a FileDescriptor from a proto or returns a cached copy. + + This method also has the side effect of loading all the symbols found in + the file into the appropriate dictionaries in the pool. + + Args: + file_proto: The proto to convert. + + Returns: + A FileDescriptor matching the passed in proto. + """ + + if file_proto.name not in self._file_descriptors: + file_descriptor = descriptor.FileDescriptor( + name=file_proto.name, + package=file_proto.package, + options=file_proto.options, + serialized_pb=file_proto.SerializeToString()) + scope = {} + dependencies = list(self._GetDeps(file_proto)) + + for dependency in dependencies: + dep_desc = self.FindFileByName(dependency.name) + dep_proto = descriptor_pb2.FileDescriptorProto.FromString( + dep_desc.serialized_pb) + package = '.' + dep_proto.package + package_prefix = package + '.' 
+ + def _strip_package(symbol): + if symbol.startswith(package_prefix): + return symbol[len(package_prefix):] + return symbol + + symbols = list(self._ExtractSymbols(dep_proto.message_type, package)) + scope.update(symbols) + scope.update((_strip_package(k), v) for k, v in symbols) + + symbols = list(self._ExtractEnums(dep_proto.enum_type, package)) + scope.update(symbols) + scope.update((_strip_package(k), v) for k, v in symbols) + + for message_type in file_proto.message_type: + message_desc = self._ConvertMessageDescriptor( + message_type, file_proto.package, file_descriptor, scope) + file_descriptor.message_types_by_name[message_desc.name] = message_desc + for enum_type in file_proto.enum_type: + self._ConvertEnumDescriptor(enum_type, file_proto.package, + file_descriptor, None, scope) + for desc_proto in self._ExtractMessages(file_proto.message_type): + self._SetFieldTypes(desc_proto, scope) + + for desc_proto in file_proto.message_type: + desc = scope[desc_proto.name] + file_descriptor.message_types_by_name[desc_proto.name] = desc + self.Add(file_proto) + self._file_descriptors[file_proto.name] = file_descriptor + + return self._file_descriptors[file_proto.name] + + def _ConvertMessageDescriptor(self, desc_proto, package=None, file_desc=None, + scope=None): + """Adds the proto to the pool in the specified package. + + Args: + desc_proto: The descriptor_pb2.DescriptorProto protobuf message. + package: The package the proto should be located in. + file_desc: The file containing this message. + scope: Dict mapping short and full symbols to message and enum types. + + Returns: + The added descriptor. + """ + + if package: + desc_name = '.'.join((package, desc_proto.name)) + else: + desc_name = desc_proto.name + + if file_desc is None: + file_name = None + else: + file_name = file_desc.name + + if scope is None: + scope = {} + + nested = [ + self._ConvertMessageDescriptor(nested, desc_name, file_desc, scope) + for nested in desc_proto.nested_type] + enums = [ + self._ConvertEnumDescriptor(enum, desc_name, file_desc, None, scope) + for enum in desc_proto.enum_type] + fields = [self._MakeFieldDescriptor(field, desc_name, index) + for index, field in enumerate(desc_proto.field)] + extensions = [self._MakeFieldDescriptor(extension, desc_name, True) + for index, extension in enumerate(desc_proto.extension)] + extension_ranges = [(r.start, r.end) for r in desc_proto.extension_range] + if extension_ranges: + is_extendable = True + else: + is_extendable = False + desc = descriptor.Descriptor( + name=desc_proto.name, + full_name=desc_name, + filename=file_name, + containing_type=None, + fields=fields, + nested_types=nested, + enum_types=enums, + extensions=extensions, + options=desc_proto.options, + is_extendable=is_extendable, + extension_ranges=extension_ranges, + file=file_desc, + serialized_start=None, + serialized_end=None) + for nested in desc.nested_types: + nested.containing_type = desc + for enum in desc.enum_types: + enum.containing_type = desc + scope[desc_proto.name] = desc + scope['.' + desc_name] = desc + self._descriptors[desc_name] = desc + return desc + + def _ConvertEnumDescriptor(self, enum_proto, package=None, file_desc=None, + containing_type=None, scope=None): + """Make a protobuf EnumDescriptor given an EnumDescriptorProto protobuf. + + Args: + enum_proto: The descriptor_pb2.EnumDescriptorProto protobuf message. + package: Optional package name for the new message EnumDescriptor. + file_desc: The file containing the enum descriptor. 
+ containing_type: The type containing this enum. + scope: Scope containing available types. + + Returns: + The added descriptor + """ + + if package: + enum_name = '.'.join((package, enum_proto.name)) + else: + enum_name = enum_proto.name + + if file_desc is None: + file_name = None + else: + file_name = file_desc.name + + values = [self._MakeEnumValueDescriptor(value, index) + for index, value in enumerate(enum_proto.value)] + desc = descriptor.EnumDescriptor(name=enum_proto.name, + full_name=enum_name, + filename=file_name, + file=file_desc, + values=values, + containing_type=containing_type, + options=enum_proto.options) + scope[enum_proto.name] = desc + scope['.%s' % enum_name] = desc + self._enum_descriptors[enum_name] = desc + return desc + + def _MakeFieldDescriptor(self, field_proto, message_name, index, + is_extension=False): + """Creates a field descriptor from a FieldDescriptorProto. + + For message and enum type fields, this method will do a look up + in the pool for the appropriate descriptor for that type. If it + is unavailable, it will fall back to the _source function to + create it. If this type is still unavailable, construction will + fail. + + Args: + field_proto: The proto describing the field. + message_name: The name of the containing message. + index: Index of the field + is_extension: Indication that this field is for an extension. + + Returns: + An initialized FieldDescriptor object + """ + + if message_name: + full_name = '.'.join((message_name, field_proto.name)) + else: + full_name = field_proto.name + + return descriptor.FieldDescriptor( + name=field_proto.name, + full_name=full_name, + index=index, + number=field_proto.number, + type=field_proto.type, + cpp_type=None, + message_type=None, + enum_type=None, + containing_type=None, + label=field_proto.label, + has_default_value=False, + default_value=None, + is_extension=is_extension, + extension_scope=None, + options=field_proto.options) + + def _SetFieldTypes(self, desc_proto, scope): + """Sets the field's type, cpp_type, message_type and enum_type. + + Args: + desc_proto: The message descriptor to update. + scope: Enclosing scope of available types. + """ + + desc = scope[desc_proto.name] + for field_proto, field_desc in zip(desc_proto.field, desc.fields): + if field_proto.type_name: + type_name = field_proto.type_name + if type_name not in scope: + type_name = '.' 
+ type_name + desc = scope[type_name] + else: + desc = None + + if not field_proto.HasField('type'): + if isinstance(desc, descriptor.Descriptor): + field_proto.type = descriptor.FieldDescriptor.TYPE_MESSAGE + else: + field_proto.type = descriptor.FieldDescriptor.TYPE_ENUM + + field_desc.cpp_type = descriptor.FieldDescriptor.ProtoTypeToCppProtoType( + field_proto.type) + + if (field_proto.type == descriptor.FieldDescriptor.TYPE_MESSAGE + or field_proto.type == descriptor.FieldDescriptor.TYPE_GROUP): + field_desc.message_type = desc + + if field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.enum_type = desc + + if field_proto.label == descriptor.FieldDescriptor.LABEL_REPEATED: + field_desc.has_default = False + field_desc.default_value = [] + elif field_proto.HasField('default_value'): + field_desc.has_default = True + if (field_proto.type == descriptor.FieldDescriptor.TYPE_DOUBLE or + field_proto.type == descriptor.FieldDescriptor.TYPE_FLOAT): + field_desc.default_value = float(field_proto.default_value) + elif field_proto.type == descriptor.FieldDescriptor.TYPE_STRING: + field_desc.default_value = field_proto.default_value + elif field_proto.type == descriptor.FieldDescriptor.TYPE_BOOL: + field_desc.default_value = field_proto.default_value.lower() == 'true' + elif field_proto.type == descriptor.FieldDescriptor.TYPE_ENUM: + field_desc.default_value = field_desc.enum_type.values_by_name[ + field_proto.default_value].index + else: + field_desc.default_value = int(field_proto.default_value) + else: + field_desc.has_default = False + field_desc.default_value = None + + field_desc.type = field_proto.type + + for nested_type in desc_proto.nested_type: + self._SetFieldTypes(nested_type, scope) + + def _MakeEnumValueDescriptor(self, value_proto, index): + """Creates a enum value descriptor object from a enum value proto. + + Args: + value_proto: The proto describing the enum value. + index: The index of the enum value. + + Returns: + An initialized EnumValueDescriptor object. + """ + + return descriptor.EnumValueDescriptor( + name=value_proto.name, + index=index, + number=value_proto.number, + options=value_proto.options, + type=None) + + def _ExtractSymbols(self, desc_protos, package): + """Pulls out all the symbols from descriptor protos. + + Args: + desc_protos: The protos to extract symbols from. + package: The package containing the descriptor type. + Yields: + A two element tuple of the type name and descriptor object. + """ + + for desc_proto in desc_protos: + if package: + message_name = '.'.join((package, desc_proto.name)) + else: + message_name = desc_proto.name + message_desc = self.FindMessageTypeByName(message_name) + yield (message_name, message_desc) + for symbol in self._ExtractSymbols(desc_proto.nested_type, message_name): + yield symbol + for symbol in self._ExtractEnums(desc_proto.enum_type, message_name): + yield symbol + + def _ExtractEnums(self, enum_protos, package): + """Pulls out all the symbols from enum protos. + + Args: + enum_protos: The protos to extract symbols from. + package: The package containing the enum type. + + Yields: + A two element tuple of the type name and enum descriptor object. + """ + + for enum_proto in enum_protos: + if package: + enum_name = '.'.join((package, enum_proto.name)) + else: + enum_name = enum_proto.name + enum_desc = self.FindEnumTypeByName(enum_name) + yield (enum_name, enum_desc) + + def _ExtractMessages(self, desc_protos): + """Pulls out all the message protos from descriptos. 
+ + Args: + desc_protos: The protos to extract symbols from. + + Yields: + Descriptor protos. + """ + + for desc_proto in desc_protos: + yield desc_proto + for message in self._ExtractMessages(desc_proto.nested_type): + yield message + + def _GetDeps(self, file_proto): + """Recursively finds dependencies for file protos. + + Args: + file_proto: The proto to get dependencies from. + + Yields: + Each direct and indirect dependency. + """ + + for dependency in file_proto.dependency: + dep_desc = self.FindFileByName(dependency) + dep_proto = descriptor_pb2.FileDescriptorProto.FromString( + dep_desc.serialized_pb) + yield dep_proto + for parent_dep in self._GetDeps(dep_proto): + yield parent_dep diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index b3e412e2..ce02a329 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -56,9 +56,32 @@ if _implementation_type != 'python': # _implementation_type = 'python' +# This environment variable can be used to switch between the two +# 'cpp' implementations. Right now only 1 and 2 are valid values. Any +# other value will be ignored. +_implementation_version_str = os.getenv( + 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', + '1') + + +if _implementation_version_str not in ('1', '2'): + raise ValueError( + "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" + + _implementation_version_str + "' (supported versions: 1, 2)" + ) + + +_implementation_version = int(_implementation_version_str) + + + # Usage of this function is discouraged. Clients shouldn't care which # implementation of the API is in use. Note that there is no guarantee # that differences between APIs will be maintained. # Please don't use this function if possible. def Type(): return _implementation_type + +# See comment on 'Type' above. +def Version(): + return _implementation_version diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 097a3c26..34b35f8a 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -78,8 +78,13 @@ class BaseContainer(object): def __repr__(self): return repr(self._values) - def sort(self, sort_function=cmp): - self._values.sort(sort_function) + def sort(self, *args, **kwargs): + # Continue to support the old sort_function keyword argument. + # This is expected to be a rare occurrence, so use LBYL to avoid + # the overhead of actually catching KeyError. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._values.sort(*args, **kwargs) class RepeatedScalarFieldContainer(BaseContainer): @@ -235,6 +240,11 @@ class RepeatedCompositeFieldContainer(BaseContainer): """ self.extend(other._values) + def remove(self, elem): + """Removes an item from the list. Similar to list.remove().""" + self._values.remove(elem) + self._message_listener.Modified() + def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" return self._values[start:stop] diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py index 3f426502..23ab9ba4 100755 --- a/python/google/protobuf/internal/cpp_message.py +++ b/python/google/protobuf/internal/cpp_message.py @@ -34,8 +34,10 @@ Descriptor objects at runtime backed by the protocol buffer C++ API. 
__author__ = 'petar@google.com (Petar Petrov)' +import copy_reg import operator from google.protobuf.internal import _net_proto2___python +from google.protobuf.internal import enum_type_wrapper from google.protobuf import message @@ -156,10 +158,12 @@ class RepeatedScalarContainer(object): def __hash__(self): raise TypeError('unhashable object') - def sort(self, sort_function=cmp): - values = self[slice(None, None, None)] - values.sort(sort_function) - self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) + def sort(self, *args, **kwargs): + # Maintain compatibility with the previous interface. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, + sorted(self, *args, **kwargs)) def RepeatedScalarProperty(cdescriptor): @@ -202,6 +206,12 @@ class RepeatedCompositeContainer(object): for message in elem_seq: self.add().MergeFrom(message) + def remove(self, value): + # TODO(protocol-devel): This is inefficient as it needs to generate a + # message pointer for each message only to do index(). Move this to a C++ + # extension function. + self.__delitem__(self[slice(None, None, None)].index(value)) + def MergeFrom(self, other): for message in other[:]: self.add().MergeFrom(message) @@ -236,27 +246,29 @@ class RepeatedCompositeContainer(object): def __hash__(self): raise TypeError('unhashable object') - def sort(self, sort_function=cmp): - messages = [] - for index in range(len(self)): - # messages[i][0] is where the i-th element of the new array has to come - # from. - # messages[i][1] is where the i-th element of the old array has to go. - messages.append([index, 0, self[index]]) - messages.sort(lambda x,y: sort_function(x[2], y[2])) + def sort(self, cmp=None, key=None, reverse=False, **kwargs): + # Maintain compatibility with the old interface. + if cmp is None and 'sort_function' in kwargs: + cmp = kwargs.pop('sort_function') - # Remember which position each elements has to move to. - for i in range(len(messages)): - messages[messages[i][0]][1] = i + # The cmp function, if provided, is passed the results of the key function, + # so we only need to wrap one of them. + if key is None: + index_key = self.__getitem__ + else: + index_key = lambda i: key(self[i]) + + # Sort the list of current indexes by the underlying object. + indexes = range(len(self)) + indexes.sort(cmp=cmp, key=index_key, reverse=reverse) # Apply the transposition. - for i in range(len(messages)): - from_position = messages[i][0] - if i == from_position: + for dest, src in enumerate(indexes): + if dest == src: continue - self._cmsg.SwapRepeatedFieldElements( - self._cfield_descriptor, i, from_position) - messages[messages[i][1]][0] = from_position + self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src) + # Don't swap the same value twice. 
+ indexes[src] = src def RepeatedCompositeProperty(cdescriptor, message_type): @@ -359,11 +371,12 @@ class ExtensionDict(object): return None -def NewMessage(message_descriptor, dictionary): +def NewMessage(bases, message_descriptor, dictionary): """Creates a new protocol message *class*.""" _AddClassAttributesForNestedExtensions(message_descriptor, dictionary) _AddEnumValues(message_descriptor, dictionary) _AddDescriptors(message_descriptor, dictionary) + return bases def InitMessage(message_descriptor, cls): @@ -372,6 +385,7 @@ def InitMessage(message_descriptor, cls): _AddInitMethod(message_descriptor, cls) _AddMessageMethods(message_descriptor, cls) _AddPropertiesForExtensions(message_descriptor, cls) + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) def _AddDescriptors(message_descriptor, dictionary): @@ -387,7 +401,7 @@ def _AddDescriptors(message_descriptor, dictionary): field.full_name) dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [ - '_cmsg', '_owner', '_composite_fields', 'Extensions'] + '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS'] def _AddEnumValues(message_descriptor, dictionary): @@ -398,6 +412,7 @@ def _AddEnumValues(message_descriptor, dictionary): dictionary: Class dictionary that should be populated. """ for enum_type in message_descriptor.enum_types: + dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type) for enum_value in enum_type.values: dictionary[enum_value.name] = enum_value.number @@ -439,28 +454,35 @@ def _AddInitMethod(message_descriptor, cls): def Init(self, **kwargs): """Message constructor.""" cmessage = kwargs.pop('__cmessage', None) - if cmessage is None: - self._cmsg = NewCMessage(message_descriptor.full_name) - else: + if cmessage: self._cmsg = cmessage + else: + self._cmsg = NewCMessage(message_descriptor.full_name) # Keep a reference to the owner, as the owner keeps a reference to the # underlying protocol buffer message. owner = kwargs.pop('__owner', None) - if owner is not None: + if owner: self._owner = owner - self.Extensions = ExtensionDict(self) + if message_descriptor.is_extendable: + self.Extensions = ExtensionDict(self) + else: + # Reference counting in the C++ code is broken and depends on + # the Extensions reference to keep this object alive during unit + # tests (see b/4856052). Remove this once b/4945904 is fixed. + self._HACK_REFCOUNTS = self self._composite_fields = {} for field_name, field_value in kwargs.iteritems(): field_cdescriptor = self.__descriptors.get(field_name, None) - if field_cdescriptor is None: + if not field_cdescriptor: raise ValueError('Protocol message has no "%s" field.' % field_name) if field_cdescriptor.label == _LABEL_REPEATED: if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: + field_name = getattr(self, field_name) for val in field_value: - getattr(self, field_name).add().MergeFrom(val) + field_name.add().MergeFrom(val) else: getattr(self, field_name).extend(field_value) elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: @@ -497,12 +519,34 @@ def _AddMessageMethods(message_descriptor, cls): return self._cmsg.HasField(field_name) def ClearField(self, field_name): + child_cmessage = None if field_name in self._composite_fields: + child_field = self._composite_fields[field_name] del self._composite_fields[field_name] - self._cmsg.ClearField(field_name) + + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. 
+ if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + child_cmessage = child_field._cmsg + + if child_cmessage is not None: + self._cmsg.ClearField(field_name, child_cmessage) + else: + self._cmsg.ClearField(field_name) def Clear(self): - return self._cmsg.Clear() + cmessages_to_release = [] + for field_name, child_field in self._composite_fields.iteritems(): + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. + if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + cmessages_to_release.append((child_cdescriptor, child_field._cmsg)) + self._composite_fields.clear() + self._cmsg.Clear(cmessages_to_release) def IsInitialized(self, errors=None): if self._cmsg.IsInitialized(): @@ -514,8 +558,8 @@ def _AddMessageMethods(message_descriptor, cls): def SerializeToString(self): if not self.IsInitialized(): raise message.EncodeError( - 'Message is missing required fields: ' + - ','.join(self.FindInitializationErrors())) + 'Message %s is missing required fields: %s' % ( + self._cmsg.full_name, ','.join(self.FindInitializationErrors()))) return self._cmsg.SerializeToString() def SerializePartialToString(self): @@ -534,7 +578,8 @@ def _AddMessageMethods(message_descriptor, cls): def MergeFrom(self, msg): if not isinstance(msg, cls): raise TypeError( - "Parameter to MergeFrom() must be instance of same class.") + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) self._cmsg.MergeFrom(msg._cmsg) def CopyFrom(self, msg): @@ -581,6 +626,8 @@ def _AddMessageMethods(message_descriptor, cls): raise TypeError('unhashable object') def __unicode__(self): + # Lazy import to prevent circular import when text_format imports this file. + from google.protobuf import text_format return text_format.MessageToString(self, as_utf8=True).decode('utf-8') # Attach the local methods to the message class. diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 55f746f5..cb6f5729 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -576,6 +576,7 @@ def MessageSetItemDecoder(extensions_by_number): local_SkipField = SkipField def DecodeItem(buffer, pos, end, message, field_dict): + message_set_item_start = pos type_id = -1 message_start = -1 message_end = -1 @@ -614,6 +615,11 @@ def MessageSetItemDecoder(extensions_by_number): # The only reason _InternalParse would return early is if it encountered # an end-group tag. raise _DecodeError('Unexpected end-group tag.') + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, + buffer[message_set_item_start:pos])) return pos diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py new file mode 100644 index 00000000..d0ca7892 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -0,0 +1,63 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. 
+# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.descriptor_database.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import unittest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor_database + + +class DescriptorDatabaseTest(unittest.TestCase): + + def testAdd(self): + db = descriptor_database.DescriptorDatabase() + file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + db.Add(file_desc_proto) + + self.assertEquals(file_desc_proto, db.FindFileByName( + 'net/proto2/python/internal/factory_test2.proto')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Message')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Message.NestedFactory2Message')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Enum')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Message.NestedFactory2Enum')) + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py new file mode 100644 index 00000000..a615d787 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -0,0 +1,220 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.descriptor_pool.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import unittest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test1_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool + + +class DescriptorPoolTest(unittest.TestCase): + + def setUp(self): + self.pool = descriptor_pool.DescriptorPool() + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(self.factory_test1_fd) + self.pool.Add(self.factory_test2_fd) + + def testFindFileByName(self): + name1 = 'net/proto2/python/internal/factory_test1.proto' + file_desc1 = self.pool.FindFileByName(name1) + self.assertIsInstance(file_desc1, descriptor.FileDescriptor) + self.assertEquals(name1, file_desc1.name) + self.assertEquals('net.proto2.python.internal', file_desc1.package) + self.assertIn('Factory1Message', file_desc1.message_types_by_name) + + name2 = 'net/proto2/python/internal/factory_test2.proto' + file_desc2 = self.pool.FindFileByName(name2) + self.assertIsInstance(file_desc2, descriptor.FileDescriptor) + self.assertEquals(name2, file_desc2.name) + self.assertEquals('net.proto2.python.internal', file_desc2.package) + self.assertIn('Factory2Message', file_desc2.message_types_by_name) + + def testFindFileByNameFailure(self): + try: + self.pool.FindFileByName('Does not exist') + self.fail('Expected KeyError') + except KeyError: + pass + + def testFindFileContainingSymbol(self): + file_desc1 = self.pool.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory1Message') + self.assertIsInstance(file_desc1, descriptor.FileDescriptor) + self.assertEquals('net/proto2/python/internal/factory_test1.proto', + file_desc1.name) + self.assertEquals('net.proto2.python.internal', file_desc1.package) + self.assertIn('Factory1Message', file_desc1.message_types_by_name) + + file_desc2 = self.pool.FindFileContainingSymbol( + 'net.proto2.python.internal.Factory2Message') + self.assertIsInstance(file_desc2, descriptor.FileDescriptor) + 
self.assertEquals('net/proto2/python/internal/factory_test2.proto', + file_desc2.name) + self.assertEquals('net.proto2.python.internal', file_desc2.package) + self.assertIn('Factory2Message', file_desc2.message_types_by_name) + + def testFindFileContainingSymbolFailure(self): + try: + self.pool.FindFileContainingSymbol('Does not exist') + self.fail('Expected KeyError') + except KeyError: + pass + + def testFindMessageTypeByName(self): + msg1 = self.pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory1Message') + self.assertIsInstance(msg1, descriptor.Descriptor) + self.assertEquals('Factory1Message', msg1.name) + self.assertEquals('net.proto2.python.internal.Factory1Message', + msg1.full_name) + self.assertEquals(None, msg1.containing_type) + + nested_msg1 = msg1.nested_types[0] + self.assertEquals('NestedFactory1Message', nested_msg1.name) + self.assertEquals(msg1, nested_msg1.containing_type) + + nested_enum1 = msg1.enum_types[0] + self.assertEquals('NestedFactory1Enum', nested_enum1.name) + self.assertEquals(msg1, nested_enum1.containing_type) + + self.assertEquals(nested_msg1, msg1.fields_by_name[ + 'nested_factory_1_message'].message_type) + self.assertEquals(nested_enum1, msg1.fields_by_name[ + 'nested_factory_1_enum'].enum_type) + + msg2 = self.pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory2Message') + self.assertIsInstance(msg2, descriptor.Descriptor) + self.assertEquals('Factory2Message', msg2.name) + self.assertEquals('net.proto2.python.internal.Factory2Message', + msg2.full_name) + self.assertIsNone(msg2.containing_type) + + nested_msg2 = msg2.nested_types[0] + self.assertEquals('NestedFactory2Message', nested_msg2.name) + self.assertEquals(msg2, nested_msg2.containing_type) + + nested_enum2 = msg2.enum_types[0] + self.assertEquals('NestedFactory2Enum', nested_enum2.name) + self.assertEquals(msg2, nested_enum2.containing_type) + + self.assertEquals(nested_msg2, msg2.fields_by_name[ + 'nested_factory_2_message'].message_type) + self.assertEquals(nested_enum2, msg2.fields_by_name[ + 'nested_factory_2_enum'].enum_type) + + self.assertTrue(msg2.fields_by_name['int_with_default'].has_default) + self.assertEquals( + 1776, msg2.fields_by_name['int_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['double_with_default'].has_default) + self.assertEquals( + 9.99, msg2.fields_by_name['double_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['string_with_default'].has_default) + self.assertEquals( + 'hello world', msg2.fields_by_name['string_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default) + self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default) + self.assertEquals( + 1, msg2.fields_by_name['enum_with_default'].default_value) + + msg3 = self.pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory2Message.NestedFactory2Message') + self.assertEquals(nested_msg2, msg3) + + def testFindMessageTypeByNameFailure(self): + try: + self.pool.FindMessageTypeByName('Does not exist') + self.fail('Expected KeyError') + except KeyError: + pass + + def testFindEnumTypeByName(self): + enum1 = self.pool.FindEnumTypeByName( + 'net.proto2.python.internal.Factory1Enum') + self.assertIsInstance(enum1, descriptor.EnumDescriptor) + self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) + self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) + + nested_enum1 
= self.pool.FindEnumTypeByName( + 'net.proto2.python.internal.Factory1Message.NestedFactory1Enum') + self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor) + self.assertEquals( + 0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number) + self.assertEquals( + 1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number) + + enum2 = self.pool.FindEnumTypeByName( + 'net.proto2.python.internal.Factory2Enum') + self.assertIsInstance(enum2, descriptor.EnumDescriptor) + self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number) + self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number) + + nested_enum2 = self.pool.FindEnumTypeByName( + 'net.proto2.python.internal.Factory2Message.NestedFactory2Enum') + self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor) + self.assertEquals( + 0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number) + self.assertEquals( + 1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number) + + def testFindEnumTypeByNameFailure(self): + try: + self.pool.FindEnumTypeByName('Does not exist') + self.fail('Expected KeyError') + except KeyError: + pass + + def testUserDefinedDB(self): + db = descriptor_database.DescriptorDatabase() + self.pool = descriptor_pool.DescriptorPool(db) + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + self.testFindMessageTypeByName() + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 05c27452..c74f882e 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -35,6 +35,7 @@ __author__ = 'robinson@google.com (Will Robinson)' import unittest +from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 @@ -101,6 +102,15 @@ class DescriptorTest(unittest.TestCase): self.my_method ]) + def testEnumValueName(self): + self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4), + 'FOREIGN_FOO') + + self.assertEqual( + self.my_message.enum_types_by_name[ + 'ForeignEnum'].values_by_number[4].name, + self.my_message.EnumValueName('ForeignEnum', 4)) + def testEnumFixups(self): self.assertEqual(self.my_enum, self.my_enum.values[0].type) @@ -125,6 +135,257 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_service.GetOptions(), descriptor_pb2.ServiceOptions()) + def testSimpleCustomOptions(self): + file_descriptor = unittest_custom_options_pb2.DESCRIPTOR + message_descriptor =\ + unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR + field_descriptor = message_descriptor.fields_by_name["field1"] + enum_descriptor = message_descriptor.enum_types_by_name["AnEnum"] + enum_value_descriptor =\ + message_descriptor.enum_values_by_name["ANENUM_VAL2"] + service_descriptor =\ + unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR + method_descriptor = service_descriptor.FindMethodByName("Foo") + + file_options = file_descriptor.GetOptions() + file_opt1 = unittest_custom_options_pb2.file_opt1 + self.assertEqual(9876543210, file_options.Extensions[file_opt1]) + message_options = message_descriptor.GetOptions() + message_opt1 = unittest_custom_options_pb2.message_opt1 + self.assertEqual(-56, message_options.Extensions[message_opt1]) + field_options = field_descriptor.GetOptions() + field_opt1 = 
unittest_custom_options_pb2.field_opt1 + self.assertEqual(8765432109, field_options.Extensions[field_opt1]) + field_opt2 = unittest_custom_options_pb2.field_opt2 + self.assertEqual(42, field_options.Extensions[field_opt2]) + enum_options = enum_descriptor.GetOptions() + enum_opt1 = unittest_custom_options_pb2.enum_opt1 + self.assertEqual(-789, enum_options.Extensions[enum_opt1]) + enum_value_options = enum_value_descriptor.GetOptions() + enum_value_opt1 = unittest_custom_options_pb2.enum_value_opt1 + self.assertEqual(123, enum_value_options.Extensions[enum_value_opt1]) + + service_options = service_descriptor.GetOptions() + service_opt1 = unittest_custom_options_pb2.service_opt1 + self.assertEqual(-9876543210, service_options.Extensions[service_opt1]) + method_options = method_descriptor.GetOptions() + method_opt1 = unittest_custom_options_pb2.method_opt1 + self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2, + method_options.Extensions[method_opt1]) + + def testDifferentCustomOptionTypes(self): + kint32min = -2**31 + kint64min = -2**63 + kint32max = 2**31 - 1 + kint64max = 2**63 - 1 + kuint32max = 2**32 - 1 + kuint64max = 2**64 - 1 + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionMinIntegerValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(False, message_options.Extensions[ + unittest_custom_options_pb2.bool_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.int64_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.uint32_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.uint64_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.sint32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.sint64_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.fixed32_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.fixed64_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.sfixed32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.sfixed64_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionMaxIntegerValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(True, message_options.Extensions[ + unittest_custom_options_pb2.bool_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.int64_opt]) + self.assertEqual(kuint32max, message_options.Extensions[ + unittest_custom_options_pb2.uint32_opt]) + self.assertEqual(kuint64max, message_options.Extensions[ + unittest_custom_options_pb2.uint64_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.sint32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.sint64_opt]) + self.assertEqual(kuint32max, message_options.Extensions[ + unittest_custom_options_pb2.fixed32_opt]) + self.assertEqual(kuint64max, message_options.Extensions[ + unittest_custom_options_pb2.fixed64_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.sfixed32_opt]) + 
self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.sfixed64_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionOtherValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(-100, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertAlmostEqual(12.3456789, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(1.234567890123456789, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + self.assertEqual("Hello, \"World\"", message_options.Extensions[ + unittest_custom_options_pb2.string_opt]) + self.assertEqual("Hello\0World", message_options.Extensions[ + unittest_custom_options_pb2.bytes_opt]) + dummy_enum = unittest_custom_options_pb2.DummyMessageContainingEnum + self.assertEqual( + dummy_enum.TEST_OPTION_ENUM_TYPE2, + message_options.Extensions[unittest_custom_options_pb2.enum_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.SettingRealsFromPositiveInts.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertAlmostEqual(12, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(154, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.SettingRealsFromNegativeInts.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertAlmostEqual(-12, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(-154, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + + def testComplexExtensionOptions(self): + descriptor =\ + unittest_custom_options_pb2.VariousComplexOptions.DESCRIPTOR + options = descriptor.GetOptions() + self.assertEqual(42, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].foo) + self.assertEqual(324, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(876, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(987, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].baz) + self.assertEqual(654, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.grault]) + self.assertEqual(743, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.foo) + self.assertEqual(1999, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(2008, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(741, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].foo) + self.assertEqual(1998, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(2121, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(1971, options.Extensions[ + unittest_custom_options_pb2.ComplexOptionType2 + .ComplexOptionType4.complex_opt4].waldo) + self.assertEqual(321, options.Extensions[ + 
unittest_custom_options_pb2.complex_opt2].fred.waldo) + self.assertEqual(9, options.Extensions[ + unittest_custom_options_pb2.complex_opt3].qux) + self.assertEqual(22, options.Extensions[ + unittest_custom_options_pb2.complex_opt3].complexoptiontype5.plugh) + self.assertEqual(24, options.Extensions[ + unittest_custom_options_pb2.complexopt6].xyzzy) + + # Check that aggregate options were parsed and saved correctly in + # the appropriate descriptors. + def testAggregateOptions(self): + file_descriptor = unittest_custom_options_pb2.DESCRIPTOR + message_descriptor =\ + unittest_custom_options_pb2.AggregateMessage.DESCRIPTOR + field_descriptor = message_descriptor.fields_by_name["fieldname"] + enum_descriptor = unittest_custom_options_pb2.AggregateEnum.DESCRIPTOR + enum_value_descriptor = enum_descriptor.values_by_name["VALUE"] + service_descriptor =\ + unittest_custom_options_pb2.AggregateService.DESCRIPTOR + method_descriptor = service_descriptor.FindMethodByName("Method") + + # Tests for the different types of data embedded in fileopt + file_options = file_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.fileopt] + self.assertEqual(100, file_options.i) + self.assertEqual("FileAnnotation", file_options.s) + self.assertEqual("NestedFileAnnotation", file_options.sub.s) + self.assertEqual("FileExtensionAnnotation", file_options.file.Extensions[ + unittest_custom_options_pb2.fileopt].s) + self.assertEqual("EmbeddedMessageSetElement", file_options.mset.Extensions[ + unittest_custom_options_pb2.AggregateMessageSetElement + .message_set_extension].s) + + # Simple tests for all the other types of annotations + self.assertEqual( + "MessageAnnotation", + message_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.msgopt].s) + self.assertEqual( + "FieldAnnotation", + field_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.fieldopt].s) + self.assertEqual( + "EnumAnnotation", + enum_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.enumopt].s) + self.assertEqual( + "EnumValueAnnotation", + enum_value_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.enumvalopt].s) + self.assertEqual( + "ServiceAnnotation", + service_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.serviceopt].s) + self.assertEqual( + "MethodAnnotation", + method_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.methodopt].s) + + def testNestedOptions(self): + nested_message =\ + unittest_custom_options_pb2.NestedOptionType.NestedMessage.DESCRIPTOR + self.assertEqual(1001, nested_message.GetOptions().Extensions[ + unittest_custom_options_pb2.message_opt1]) + nested_field = nested_message.fields_by_name["nested_field"] + self.assertEqual(1002, nested_field.GetOptions().Extensions[ + unittest_custom_options_pb2.field_opt1]) + outer_message =\ + unittest_custom_options_pb2.NestedOptionType.DESCRIPTOR + nested_enum = outer_message.enum_types_by_name["NestedEnum"] + self.assertEqual(1003, nested_enum.GetOptions().Extensions[ + unittest_custom_options_pb2.enum_opt1]) + nested_enum_value = outer_message.enum_values_by_name["NESTED_ENUM_VALUE"] + self.assertEqual(1004, nested_enum_value.GetOptions().Extensions[ + unittest_custom_options_pb2.enum_value_opt1]) + nested_extension = outer_message.extensions_by_name["nested_extension"] + self.assertEqual(1005, nested_extension.GetOptions().Extensions[ + unittest_custom_options_pb2.field_opt2]) + def testFileDescriptorReferences(self): self.assertEqual(self.my_enum.file, self.my_file) 
self.assertEqual(self.my_message.file, self.my_file) @@ -273,6 +534,7 @@ class DescriptorCopyToProtoTest(unittest.TestCase): UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" name: 'google/protobuf/unittest_import.proto' package: 'protobuf_unittest_import' + dependency: 'google/protobuf/unittest_import_public.proto' message_type: < name: 'ImportMessage' field: < @@ -302,6 +564,7 @@ class DescriptorCopyToProtoTest(unittest.TestCase): java_package: 'com.google.protobuf.test' optimize_for: 1 # SPEED > + public_dependency: 0 """) self._InternalTestCopyToProto( @@ -330,5 +593,21 @@ class DescriptorCopyToProtoTest(unittest.TestCase): TEST_SERVICE_ASCII) +class MakeDescriptorTest(unittest.TestCase): + def testMakeDescriptorWithUnsignedIntField(self): + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = 'Foo' + message_type = file_descriptor_proto.message_type.add() + message_type.name = file_descriptor_proto.name + field = message_type.field.add() + field.number = 1 + field.name = 'uint64_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_UINT64 + result = descriptor.MakeDescriptor(message_type) + self.assertEqual(result.fields[0].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_UINT64) + + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/enum_type_wrapper.py b/python/google/protobuf/internal/enum_type_wrapper.py new file mode 100644 index 00000000..7b28645a --- /dev/null +++ b/python/google/protobuf/internal/enum_type_wrapper.py @@ -0,0 +1,89 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""A simple wrapper around enum types to expose utility functions. + +Instances are created as properties with the same name as the enum they wrap +on proto classes. 
For usage, see: + reflection_test.py +""" + +__author__ = 'rabsatt@google.com (Kevin Rabsatt)' + + +class EnumTypeWrapper(object): + """A utility for finding the names of enum values.""" + + DESCRIPTOR = None + + def __init__(self, enum_type): + """Inits EnumTypeWrapper with an EnumDescriptor.""" + self._enum_type = enum_type + self.DESCRIPTOR = enum_type; + + def Name(self, number): + """Returns a string containing the name of an enum value.""" + if number in self._enum_type.values_by_number: + return self._enum_type.values_by_number[number].name + raise ValueError('Enum %s has no name defined for value %d' % ( + self._enum_type.name, number)) + + def Value(self, name): + """Returns the value coresponding to the given enum name.""" + if name in self._enum_type.values_by_name: + return self._enum_type.values_by_name[name].number + raise ValueError('Enum %s has no value defined for name %s' % ( + self._enum_type.name, name)) + + def keys(self): + """Return a list of the string names in the enum. + + These are returned in the order they were defined in the .proto file. + """ + + return [value_descriptor.name + for value_descriptor in self._enum_type.values] + + def values(self): + """Return a list of the integer values in the enum. + + These are returned in the order they were defined in the .proto file. + """ + + return [value_descriptor.number + for value_descriptor in self._enum_type.values] + + def items(self): + """Return a list of the (name, value) pairs of the enum. + + These are returned in the order they were defined in the .proto file. + """ + return [(value_descriptor.name, value_descriptor.number) + for value_descriptor in self._enum_type.values] diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto new file mode 100644 index 00000000..9f55e037 --- /dev/null +++ b/python/google/protobuf/internal/factory_test1.proto @@ -0,0 +1,55 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: matthewtoia@google.com (Matt Toia) + + +package google.protobuf.python.internal; + + +enum Factory1Enum { + FACTORY_1_VALUE_0 = 0; + FACTORY_1_VALUE_1 = 1; +} + +message Factory1Message { + optional Factory1Enum factory_1_enum = 1; + enum NestedFactory1Enum { + NESTED_FACTORY_1_VALUE_0 = 0; + NESTED_FACTORY_1_VALUE_1 = 1; + } + optional NestedFactory1Enum nested_factory_1_enum = 2; + message NestedFactory1Message { + optional string value = 1; + } + optional NestedFactory1Message nested_factory_1_message = 3; + optional int32 scalar_value = 4; + repeated string list_value = 5; +} diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto new file mode 100644 index 00000000..d3ce4d7f --- /dev/null +++ b/python/google/protobuf/internal/factory_test2.proto @@ -0,0 +1,77 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
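The two factory_test protos introduced by this patch exist mainly to exercise descriptor_pool and message_factory together. A minimal usage sketch, assuming the generated factory_test1_pb2/factory_test2_pb2 modules are importable (it mirrors the descriptor_pool_test cases above and the message_factory_test cases added later in this patch):

    from google.protobuf import descriptor_pb2
    from google.protobuf import descriptor_database
    from google.protobuf import descriptor_pool
    from google.protobuf import message_factory
    from google.protobuf.internal import factory_test1_pb2
    from google.protobuf.internal import factory_test2_pb2

    # Feed the serialized FileDescriptorProtos into a database-backed pool.
    db = descriptor_database.DescriptorDatabase()
    pool = descriptor_pool.DescriptorPool(db)
    db.Add(descriptor_pb2.FileDescriptorProto.FromString(
        factory_test1_pb2.DESCRIPTOR.serialized_pb))
    db.Add(descriptor_pb2.FileDescriptorProto.FromString(
        factory_test2_pb2.DESCRIPTOR.serialized_pb))

    # Build a message class dynamically from the pooled descriptor.
    factory = message_factory.MessageFactory()
    prototype = factory.GetPrototype(pool.FindMessageTypeByName(
        factory_test2_pb2.Factory2Message.DESCRIPTOR.full_name))
    msg = prototype(mandatory=42)
    # The dynamic class is distinct from the statically generated one.
    assert prototype is not factory_test2_pb2.Factory2Message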
+ +// Author: matthewtoia@google.com (Matt Toia) + + +package google.protobuf.python.internal; + +import "google/protobuf/internal/factory_test1.proto"; + + +enum Factory2Enum { + FACTORY_2_VALUE_0 = 0; + FACTORY_2_VALUE_1 = 1; +} + +message Factory2Message { + required int32 mandatory = 1; + optional Factory2Enum factory_2_enum = 2; + enum NestedFactory2Enum { + NESTED_FACTORY_2_VALUE_0 = 0; + NESTED_FACTORY_2_VALUE_1 = 1; + } + optional NestedFactory2Enum nested_factory_2_enum = 3; + message NestedFactory2Message { + optional string value = 1; + } + optional NestedFactory2Message nested_factory_2_message = 4; + optional Factory1Message factory_1_message = 5; + optional Factory1Enum factory_1_enum = 6; + optional Factory1Message.NestedFactory1Enum nested_factory_1_enum = 7; + optional Factory1Message.NestedFactory1Message nested_factory_1_message = 8; + optional Factory2Message circular_message = 9; + optional string scalar_value = 10; + repeated string list_value = 11; + repeated group Grouped = 12 { + optional string part_1 = 13; + optional string part_2 = 14; + } + optional LoopMessage loop = 15; + optional int32 int_with_default = 16 [default = 1776]; + optional double double_with_default = 17 [default = 9.99]; + optional string string_with_default = 18 [default = "hello world"]; + optional bool bool_with_default = 19 [default = false]; + optional Factory2Enum enum_with_default = 20 [default = FACTORY_2_VALUE_1]; +} + +message LoopMessage { + optional Factory2Message loop = 1; +} diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index b3f7d9b1..8343aba1 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -42,8 +42,10 @@ further ensures that we can use Python protocol message objects as we expect. __author__ = 'robinson@google.com (Will Robinson)' import unittest +from google.protobuf.internal import test_bad_identifiers_pb2 from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_no_generic_services_pb2 @@ -239,6 +241,29 @@ class GeneratorTest(unittest.TestCase): unittest_pb2._TESTALLTYPES_NESTEDMESSAGE.name in file_type.message_types_by_name) + def testPublicImports(self): + # Test public imports as embedded message. + all_type_proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, all_type_proto.optional_public_import_message.e) + + # PublicImportMessage is actually defined in unittest_import_public_pb2 + # module, and is public imported by unittest_import_pb2 module. + public_import_proto = unittest_import_pb2.PublicImportMessage() + self.assertEqual(0, public_import_proto.e) + self.assertTrue(unittest_import_public_pb2.PublicImportMessage is + unittest_import_pb2.PublicImportMessage) + + def testBadIdentifiers(self): + # We're just testing that the code was imported without problems. 
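+    # test_bad_identifiers.proto deliberately declares extensions whose names
+    # collide with identifiers the generated module itself needs ('message',
+    # 'descriptor', 'reflection', 'service'), so the lookups below also act as
+    # a regression test for the generator's handling of such names.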
+ message = test_bad_identifiers_pb2.TestBadIdentifiers() + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.message], + "foo") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.descriptor], + "bar") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.reflection], + "baz") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.service], + "qux") if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/message_cpp_test.py b/python/google/protobuf/internal/message_cpp_test.py new file mode 100644 index 00000000..0d84b320 --- /dev/null +++ b/python/google/protobuf/internal/message_cpp_test.py @@ -0,0 +1,45 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.internal.message_cpp.""" + +__author__ = 'shahms@google.com (Shahms King)' + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' + +import unittest +from google.protobuf.internal.message_test import * + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py new file mode 100644 index 00000000..0bc9be99 --- /dev/null +++ b/python/google/protobuf/internal/message_factory_test.py @@ -0,0 +1,113 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.message_factory.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import unittest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test1_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool +from google.protobuf import message_factory + + +class MessageFactoryTest(unittest.TestCase): + + def setUp(self): + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + + def _ExerciseDynamicClass(self, cls): + msg = cls() + msg.mandatory = 42 + msg.nested_factory_2_enum = 0 + msg.nested_factory_2_message.value = 'nested message value' + msg.factory_1_message.factory_1_enum = 1 + msg.factory_1_message.nested_factory_1_enum = 0 + msg.factory_1_message.nested_factory_1_message.value = ( + 'nested message value') + msg.factory_1_message.scalar_value = 22 + msg.factory_1_message.list_value.extend(['one', 'two', 'three']) + msg.factory_1_message.list_value.append('four') + msg.factory_1_enum = 1 + msg.nested_factory_1_enum = 0 + msg.nested_factory_1_message.value = 'nested message value' + msg.circular_message.mandatory = 1 + msg.circular_message.circular_message.mandatory = 2 + msg.circular_message.scalar_value = 'one deep' + msg.scalar_value = 'zero deep' + msg.list_value.extend(['four', 'three', 'two']) + msg.list_value.append('one') + msg.grouped.add() + msg.grouped[0].part_1 = 'hello' + msg.grouped[0].part_2 = 'world' + msg.grouped.add(part_1='testing', part_2='123') + msg.loop.loop.mandatory = 2 + msg.loop.loop.loop.loop.mandatory = 4 + serialized = msg.SerializeToString() + converted = factory_test2_pb2.Factory2Message.FromString(serialized) + reserialized = converted.SerializeToString() + self.assertEquals(serialized, reserialized) + result = cls.FromString(reserialized) + self.assertEquals(msg, result) + + def testGetPrototype(self): + db = descriptor_database.DescriptorDatabase() + pool = descriptor_pool.DescriptorPool(db) + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + factory = message_factory.MessageFactory() + cls = 
factory.GetPrototype(pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory2Message')) + self.assertIsNot(cls, factory_test2_pb2.Factory2Message) + self._ExerciseDynamicClass(cls) + cls2 = factory.GetPrototype(pool.FindMessageTypeByName( + 'net.proto2.python.internal.Factory2Message')) + self.assertIs(cls, cls2) + + def testGetMessages(self): + messages = message_factory.GetMessages([self.factory_test2_fd, + self.factory_test1_fd]) + self.assertContainsSubset( + ['net.proto2.python.internal.Factory2Message', + 'net.proto2.python.internal.Factory1Message'], + messages.keys()) + self._ExerciseDynamicClass( + messages['net.proto2.python.internal.Factory2Message']) + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 65174373..53e9d507 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -45,10 +45,15 @@ __author__ = 'gps@google.com (Gregory P. Smith)' import copy import math +import operator +import pickle + import unittest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 +from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util +from google.protobuf import message # Python pre-2.6 does not have isinf() or isnan() functions, so we have # to provide our own. @@ -70,9 +75,9 @@ class MessageTest(unittest.TestCase): golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) test_util.ExpectAllFieldsSet(self, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenExtensions(self): golden_data = test_util.GoldenFile('golden_message').read() @@ -81,9 +86,9 @@ class MessageTest(unittest.TestCase): all_set = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedMessage(self): golden_data = test_util.GoldenFile('golden_packed_fields_message').read() @@ -92,9 +97,9 @@ class MessageTest(unittest.TestCase): all_set = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedExtensions(self): golden_data = test_util.GoldenFile('golden_packed_fields_message').read() @@ -103,9 +108,28 @@ class MessageTest(unittest.TestCase): all_set = unittest_pb2.TestPackedExtensions() test_util.SetAllPackedExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) 
golden_copy = copy.deepcopy(golden_message) - self.assertTrue(golden_copy.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_copy.SerializeToString()) + + def testPickleSupport(self): + golden_data = test_util.GoldenFile('golden_message').read() + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + + def testPickleIncompleteProto(self): + golden_message = unittest_pb2.TestRequired(a=1) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + self.assertEquals(unpickled_message.a, 1) + # This is still an incomplete proto - so serializing should fail + self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) def testPositiveInfinity(self): golden_data = ('\x5D\x00\x00\x80\x7F' @@ -118,7 +142,7 @@ class MessageTest(unittest.TestCase): self.assertTrue(IsPosInf(golden_message.optional_double)) self.assertTrue(IsPosInf(golden_message.repeated_float[0])) self.assertTrue(IsPosInf(golden_message.repeated_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinity(self): golden_data = ('\x5D\x00\x00\x80\xFF' @@ -131,7 +155,7 @@ class MessageTest(unittest.TestCase): self.assertTrue(IsNegInf(golden_message.optional_double)) self.assertTrue(IsNegInf(golden_message.repeated_float[0])) self.assertTrue(IsNegInf(golden_message.repeated_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumber(self): golden_data = ('\x5D\x00\x00\xC0\x7F' @@ -144,7 +168,18 @@ class MessageTest(unittest.TestCase): self.assertTrue(isnan(golden_message.optional_double)) self.assertTrue(isnan(golden_message.repeated_float[0])) self.assertTrue(isnan(golden_message.repeated_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + + # The protocol buffer may serialize to any one of multiple different + # representations of a NaN. Rather than verify a specific representation, + # verify the serialized string can be converted into a correctly + # behaving protocol buffer. 
+ serialized = golden_message.SerializeToString() + message = unittest_pb2.TestAllTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.optional_float)) + self.assertTrue(isnan(message.optional_double)) + self.assertTrue(isnan(message.repeated_float[0])) + self.assertTrue(isnan(message.repeated_double[0])) def testPositiveInfinityPacked(self): golden_data = ('\xA2\x06\x04\x00\x00\x80\x7F' @@ -153,7 +188,7 @@ class MessageTest(unittest.TestCase): golden_message.ParseFromString(golden_data) self.assertTrue(IsPosInf(golden_message.packed_float[0])) self.assertTrue(IsPosInf(golden_message.packed_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNegativeInfinityPacked(self): golden_data = ('\xA2\x06\x04\x00\x00\x80\xFF' @@ -162,7 +197,7 @@ class MessageTest(unittest.TestCase): golden_message.ParseFromString(golden_data) self.assertTrue(IsNegInf(golden_message.packed_float[0])) self.assertTrue(IsNegInf(golden_message.packed_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) def testNotANumberPacked(self): golden_data = ('\xA2\x06\x04\x00\x00\xC0\x7F' @@ -171,7 +206,12 @@ class MessageTest(unittest.TestCase): golden_message.ParseFromString(golden_data) self.assertTrue(isnan(golden_message.packed_float[0])) self.assertTrue(isnan(golden_message.packed_double[0])) - self.assertTrue(golden_message.SerializeToString() == golden_data) + + serialized = golden_message.SerializeToString() + message = unittest_pb2.TestPackedTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.packed_float[0])) + self.assertTrue(isnan(message.packed_double[0])) def testExtremeFloatValues(self): message = unittest_pb2.TestAllTypes() @@ -218,7 +258,7 @@ class MessageTest(unittest.TestCase): message.ParseFromString(message.SerializeToString()) self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) - def testExtremeFloatValues(self): + def testExtremeDoubleValues(self): message = unittest_pb2.TestAllTypes() # Most positive exponent, no significand bits set. 
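The pickle tests above are enabled by the copy_reg registration added to python_message.py later in this patch (for the pure-Python implementation): a message reduces to its class plus its __getstate__() payload, so an unpickled copy compares equal to the original even when required fields are unset. A rough round trip, assuming unittest_pb2 is importable:

    import pickle
    from google.protobuf import unittest_pb2

    original = unittest_pb2.TestAllTypes(optional_int32=7)
    restored = pickle.loads(pickle.dumps(original))
    assert restored == original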
@@ -338,6 +378,117 @@ class MessageTest(unittest.TestCase): self.assertEqual(message.repeated_nested_message[4].bb, 5) self.assertEqual(message.repeated_nested_message[5].bb, 6) + def testRepeatedCompositeFieldSortArguments(self): + """Check sorting a repeated composite field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + get_bb = operator.attrgetter('bb') + cmp_bb = lambda a, b: cmp(a.bb, b.bb) + message.repeated_nested_message.add().bb = 1 + message.repeated_nested_message.add().bb = 3 + message.repeated_nested_message.add().bb = 2 + message.repeated_nested_message.add().bb = 6 + message.repeated_nested_message.add().bb = 5 + message.repeated_nested_message.add().bb = 4 + message.repeated_nested_message.sort(key=get_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(key=get_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + message.repeated_nested_message.sort(sort_function=cmp_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + + def testRepeatedScalarFieldSortArguments(self): + """Check sorting a scalar field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + abs_cmp = lambda a, b: cmp(abs(a), abs(b)) + message.repeated_int32.append(-3) + message.repeated_int32.append(-2) + message.repeated_int32.append(-1) + message.repeated_int32.sort(key=abs) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(key=abs, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + message.repeated_int32.sort(sort_function=abs_cmp) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(cmp=abs_cmp, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + + len_cmp = lambda a, b: cmp(len(a), len(b)) + message.repeated_string.append('aaa') + message.repeated_string.append('bb') + message.repeated_string.append('c') + message.repeated_string.sort(key=len) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(key=len, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + message.repeated_string.sort(sort_function=len_cmp) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(cmp=len_cmp, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + + def testParsingMerge(self): + """Check the merge behavior when a required or optional field appears + multiple times in the input.""" + messages = [ + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes() ] + messages[0].optional_int32 = 1 + messages[1].optional_int64 = 2 + messages[2].optional_int32 = 3 + messages[2].optional_string = 'hello' + + merged_message = unittest_pb2.TestAllTypes() + merged_message.optional_int32 = 3 + merged_message.optional_int64 = 2 + merged_message.optional_string = 'hello' + + generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() + generator.field1.extend(messages) + generator.field2.extend(messages) + generator.field3.extend(messages) + generator.ext1.extend(messages) + generator.ext2.extend(messages) + 
generator.group1.add().field1.MergeFrom(messages[0]) + generator.group1.add().field1.MergeFrom(messages[1]) + generator.group1.add().field1.MergeFrom(messages[2]) + generator.group2.add().field1.MergeFrom(messages[0]) + generator.group2.add().field1.MergeFrom(messages[1]) + generator.group2.add().field1.MergeFrom(messages[2]) + + data = generator.SerializeToString() + parsing_merge = unittest_pb2.TestParsingMerge() + parsing_merge.ParseFromString(data) + + # Required and optional fields should be merged. + self.assertEqual(parsing_merge.required_all_types, merged_message) + self.assertEqual(parsing_merge.optional_all_types, merged_message) + self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, + merged_message) + self.assertEqual(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.optional_ext], + merged_message) + + # Repeated fields should not be merged. + self.assertEqual(len(parsing_merge.repeated_all_types), 3) + self.assertEqual(len(parsing_merge.repeatedgroup), 3) + self.assertEqual(len(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.repeated_ext]), 3) + + + def testSortEmptyRepeatedCompositeContainer(self): + """Exercise a scenario that has led to segfaults in the past. + """ + m = unittest_pb2.TestAllTypes() + m.repeated_nested_message.sort() + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto new file mode 100644 index 00000000..df98ac4b --- /dev/null +++ b/python/google/protobuf/internal/more_extensions_dynamic.proto @@ -0,0 +1,49 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
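The python_message.py changes below thread a _unknown_fields list through parsing, equality, ByteSize and serialization, so fields the local schema does not define now survive a decode/encode round trip. A small illustration, assuming unittest_pb2.TestEmptyMessage (a message that declares no fields) is available:

    from google.protobuf import unittest_pb2

    full = unittest_pb2.TestAllTypes(optional_int32=1, optional_string='x')
    data = full.SerializeToString()

    empty = unittest_pb2.TestEmptyMessage()
    empty.ParseFromString(data)
    # Nothing matched a known field, so every (tag, value) pair was kept as
    # unknown data and is written back out on serialization.
    assert empty.SerializeToString() == data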
+ +// Author: jasonh@google.com (Jason Hsueh) +// +// This file is used to test a corner case in the CPP implementation where the +// generated C++ type is available for the extendee, but the extension is +// defined in a file whose C++ type is not in the binary. + + +import "google/protobuf/internal/more_extensions.proto"; + +package google.protobuf.internal; + +message DynamicMessageType { + optional int32 a = 1; +} + +extend ExtendedMessage { + optional int32 dynamic_int32_extension = 100; + optional DynamicMessageType dynamic_message_extension = 101; +} diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 66fca918..4bea57ac 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -54,6 +54,7 @@ try: from cStringIO import StringIO except ImportError: from StringIO import StringIO +import copy_reg import struct import weakref @@ -61,6 +62,7 @@ import weakref from google.protobuf.internal import containers from google.protobuf.internal import decoder from google.protobuf.internal import encoder +from google.protobuf.internal import enum_type_wrapper from google.protobuf.internal import message_listener as message_listener_mod from google.protobuf.internal import type_checkers from google.protobuf.internal import wire_format @@ -71,9 +73,10 @@ from google.protobuf import text_format _FieldDescriptor = descriptor_mod.FieldDescriptor -def NewMessage(descriptor, dictionary): +def NewMessage(bases, descriptor, dictionary): _AddClassAttributesForNestedExtensions(descriptor, dictionary) _AddSlots(descriptor, dictionary) + return bases def InitMessage(descriptor, cls): @@ -96,6 +99,7 @@ def InitMessage(descriptor, cls): _AddStaticMethods(cls) _AddMessageMethods(descriptor, cls) _AddPrivateHelperMethods(cls) + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) # Stateless helpers for GeneratedProtocolMessageType below. @@ -145,6 +149,10 @@ def _VerifyExtensionHandle(message, extension_handle): if not extension_handle.is_extension: raise KeyError('"%s" is not an extension.' % extension_handle.full_name) + if not extension_handle.containing_type: + raise KeyError('"%s" is missing a containing_type.' + % extension_handle.full_name) + if extension_handle.containing_type is not message.DESCRIPTOR: raise KeyError('Extension "%s" extends message type "%s", but this ' 'message is of type "%s".' % @@ -164,6 +172,7 @@ def _AddSlots(message_descriptor, dictionary): dictionary['__slots__'] = ['_cached_byte_size', '_cached_byte_size_dirty', '_fields', + '_unknown_fields', '_is_present_in_parent', '_listener', '_listener_for_children', @@ -224,11 +233,14 @@ def _AddClassAttributesForNestedExtensions(descriptor, dictionary): def _AddEnumValues(descriptor, cls): """Sets class-level attributes for all enum fields defined in this message. + Also exporting a class-level object that can name enum values. + Args: descriptor: Descriptor object for this message type. cls: Class we're constructing for this message type. 
""" for enum_type in descriptor.enum_types: + setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) for enum_value in enum_type.values: setattr(cls, enum_value.name, enum_value.number) @@ -248,7 +260,7 @@ def _DefaultValueConstructorForField(field): """ if field.label == _FieldDescriptor.LABEL_REPEATED: - if field.default_value != []: + if field.has_default_value and field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( field.default_value)) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: @@ -276,6 +288,8 @@ def _DefaultValueConstructorForField(field): return MakeSubMessageDefault def MakeScalarDefault(message): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. return field.default_value return MakeScalarDefault @@ -287,6 +301,9 @@ def _AddInitMethod(message_descriptor, cls): self._cached_byte_size = 0 self._cached_byte_size_dirty = len(kwargs) > 0 self._fields = {} + # _unknown_fields is () when empty for efficiency, and will be turned into + # a list if fields are added. + self._unknown_fields = () self._is_present_in_parent = False self._listener = message_listener_mod.NullMessageListener() self._listener_for_children = _Listener(self) @@ -428,6 +445,8 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): valid_values = set() def getter(self): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name @@ -462,13 +481,18 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): # for non-repeated scalars. proto_field_name = field.name property_name = _PropertyName(proto_field_name) + + # TODO(komarek): Can anyone explain to me why we cache the message_type this + # way, instead of referring to field.message_type inside of getter(self)? + # What if someone sets message_type later on (which makes for simpler + # dyanmic proto descriptor and class creation code). message_type = field.message_type def getter(self): field_value = self._fields.get(field) if field_value is None: # Construct a new object to represent this field. - field_value = message_type._concrete_class() + field_value = message_type._concrete_class() # use field.message_type? field_value._SetListener(self._listener_for_children) # Atomically check if another thread has preempted us and, if not, swap @@ -620,6 +644,7 @@ def _AddClearMethod(message_descriptor, cls): def Clear(self): # Clear fields. self._fields = {} + self._unknown_fields = () self._Modified() cls.Clear = Clear @@ -649,7 +674,16 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True - return self.ListFields() == other.ListFields() + if not self.ListFields() == other.ListFields(): + return False + + # Sort unknown fields because their order shouldn't affect equality test. 
+ unknown_fields = list(self._unknown_fields) + unknown_fields.sort() + other_unknown_fields = list(other._unknown_fields) + other_unknown_fields.sort() + + return unknown_fields == other_unknown_fields cls.__eq__ = __eq__ @@ -710,6 +744,9 @@ def _AddByteSizeMethod(message_descriptor, cls): for field_descriptor, field_value in self.ListFields(): size += field_descriptor._sizer(field_value) + for tag_bytes, value_bytes in self._unknown_fields: + size += len(tag_bytes) + len(value_bytes) + self._cached_byte_size = size self._cached_byte_size_dirty = False self._listener_for_children.dirty = False @@ -726,8 +763,8 @@ def _AddSerializeToStringMethod(message_descriptor, cls): errors = [] if not self.IsInitialized(): raise message_mod.EncodeError( - 'Message is missing required fields: ' + - ','.join(self.FindInitializationErrors())) + 'Message %s is missing required fields: %s' % ( + self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) return self.SerializePartialToString() cls.SerializeToString = SerializeToString @@ -744,6 +781,9 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): def InternalSerialize(self, write_bytes): for field_descriptor, field_value in self.ListFields(): field_descriptor._encoder(write_bytes, field_value) + for tag_bytes, value_bytes in self._unknown_fields: + write_bytes(tag_bytes) + write_bytes(value_bytes) cls._InternalSerialize = InternalSerialize @@ -770,13 +810,18 @@ def _AddMergeFromStringMethod(message_descriptor, cls): def InternalParse(self, buffer, pos, end): self._Modified() field_dict = self._fields + unknown_field_list = self._unknown_fields while pos != end: (tag_bytes, new_pos) = local_ReadTag(buffer, pos) field_decoder = decoders_by_tag.get(tag_bytes) if field_decoder is None: + value_start_pos = new_pos new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) if new_pos == -1: return pos + if not unknown_field_list: + unknown_field_list = self._unknown_fields = [] + unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) pos = new_pos else: pos = field_decoder(buffer, new_pos, end, self, field_dict) @@ -873,7 +918,8 @@ def _AddMergeFromMethod(cls): def MergeFrom(self, msg): if not isinstance(msg, cls): raise TypeError( - "Parameter to MergeFrom() must be instance of same class.") + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) assert msg is not self self._Modified() @@ -898,6 +944,12 @@ def _AddMergeFromMethod(cls): field_value.MergeFrom(value) else: self._fields[field] = value + + if msg._unknown_fields: + if not self._unknown_fields: + self._unknown_fields = [] + self._unknown_fields.extend(msg._unknown_fields) + cls.MergeFrom = MergeFrom diff --git a/python/google/protobuf/internal/reflection_cpp_generated_test.py b/python/google/protobuf/internal/reflection_cpp_generated_test.py new file mode 100755 index 00000000..2a0a5124 --- /dev/null +++ b/python/google/protobuf/internal/reflection_cpp_generated_test.py @@ -0,0 +1,91 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. 
+# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for reflection.py, which tests the generated C++ implementation.""" + +__author__ = 'jasonh@google.com (Jason Hsueh)' + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'cpp' + +import unittest +from google.protobuf.internal import api_implementation +from google.protobuf.internal import more_extensions_dynamic_pb2 +from google.protobuf.internal import more_extensions_pb2 +from google.protobuf.internal.reflection_test import * + + +class ReflectionCppTest(unittest.TestCase): + def testImplementationSetting(self): + self.assertEqual('cpp', api_implementation.Type()) + + def testExtensionOfGeneratedTypeInDynamicFile(self): + """Tests that a file built dynamically can extend a generated C++ type. + + The C++ implementation uses a DescriptorPool that has the generated + DescriptorPool as an underlay. Typically, a type can only find + extensions in its own pool. With the python C-extension, the generated C++ + extendee may be available, but not the extension. This tests that the + C-extension implements the correct special handling to make such extensions + available. + """ + pb1 = more_extensions_pb2.ExtendedMessage() + # Test that basic accessors work. + self.assertFalse( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertFalse( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + pb1.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension] = 17 + pb1.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a = 24 + self.assertTrue( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertTrue( + pb1.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + + # Now serialize the data and parse to a new message. 
+ pb2 = more_extensions_pb2.ExtendedMessage() + pb2.MergeFromString(pb1.SerializeToString()) + + self.assertTrue( + pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_int32_extension)) + self.assertTrue( + pb2.HasExtension(more_extensions_dynamic_pb2.dynamic_message_extension)) + self.assertEqual( + 17, pb2.Extensions[more_extensions_dynamic_pb2.dynamic_int32_extension]) + self.assertEqual( + 24, + pb2.Extensions[more_extensions_dynamic_pb2.dynamic_message_extension].a) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 7b9d3398..ed286461 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -37,6 +37,7 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' +import gc import operator import struct @@ -318,15 +319,6 @@ class ReflectionTest(unittest.TestCase): # ...and ensure that the scalar field has returned to its default. self.assertEqual(0, getattr(composite_field, scalar_field_name)) - # Finally, ensure that modifications to the old composite field object - # don't have any effect on the parent. Possible only with the pure-python - # implementation of the API. - # - # (NOTE that when we clear the composite field in the parent, we actually - # don't recursively clear down the tree. Instead, we just disconnect the - # cleared composite from the tree.) - if api_implementation.Type() != 'python': - return self.assertTrue(old_composite_field is not composite_field) setattr(old_composite_field, scalar_field_name, new_val) self.assertTrue(not composite_field.HasField(scalar_field_name)) @@ -348,8 +340,6 @@ class ReflectionTest(unittest.TestCase): nested.bb = 23 def testDisconnectingNestedMessageBeforeSettingField(self): - if api_implementation.Type() != 'python': - return proto = unittest_pb2.TestAllTypes() nested = proto.optional_nested_message proto.ClearField('optional_nested_message') # Should disconnect from parent @@ -358,6 +348,64 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(not proto.HasField('optional_nested_message')) self.assertEqual(0, proto.optional_nested_message.bb) + def testGetDefaultMessageAfterDisconnectingDefaultMessage(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.ClearField('optional_nested_message') + del proto + del nested + # Force a garbage collect so that the underlying CMessages are freed along + # with the Messages they point to. This is to make sure we're not deleting + # default message instances. 
+ gc.collect() + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + + def testDisconnectingNestedMessageAfterSettingField(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + self.assertTrue(proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') # Should disconnect from parent + self.assertEqual(5, nested.bb) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + def testDisconnectingNestedMessageBeforeGettingField(self): + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') + self.assertTrue(not proto.HasField('optional_nested_message')) + + def testDisconnectingNestedMessageAfterMerge(self): + # This test exercises the code path that does not use ReleaseMessage(). + # The underlying fear is that if we use ReleaseMessage() incorrectly, + # we will have memory leaks. It's hard to check that that doesn't happen, + # but at least we can exercise that code path to make sure it works. + proto1 = unittest_pb2.TestAllTypes() + proto2 = unittest_pb2.TestAllTypes() + proto2.optional_nested_message.bb = 5 + proto1.MergeFrom(proto2) + self.assertTrue(proto1.HasField('optional_nested_message')) + proto1.ClearField('optional_nested_message') + self.assertTrue(not proto1.HasField('optional_nested_message')) + + def testDisconnectingLazyNestedMessage(self): + # This test exercises releasing a nested message that is lazy. This test + # only exercises real code in the C++ implementation as Python does not + # support lazy parsing, but the current C++ implementation results in + # memory corruption and a crash. + if api_implementation.Type() != 'python': + return + proto = unittest_pb2.TestAllTypes() + proto.optional_lazy_message.bb = 5 + proto.ClearField('optional_lazy_message') + del proto + gc.collect() + def testHasBitsWhenModifyingRepeatedFields(self): # Test nesting when we add an element to a repeated field in a submessage. 
proto = unittest_pb2.TestNestedMessageHasBits() @@ -635,6 +683,77 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(3, proto.BAZ) self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + def testEnum_Name(self): + self.assertEqual('FOREIGN_FOO', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO)) + self.assertEqual('FOREIGN_BAR', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR)) + self.assertEqual('FOREIGN_BAZ', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ)) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Name, 11312) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual('FOO', + proto.NestedEnum.Name(proto.FOO)) + self.assertEqual('FOO', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO)) + self.assertEqual('BAR', + proto.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAR', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAZ', + proto.NestedEnum.Name(proto.BAZ)) + self.assertEqual('BAZ', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ)) + self.assertRaises(ValueError, + proto.NestedEnum.Name, 11312) + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Name, 11312) + + def testEnum_Value(self): + self.assertEqual(unittest_pb2.FOREIGN_FOO, + unittest_pb2.ForeignEnum.Value('FOREIGN_FOO')) + self.assertEqual(unittest_pb2.FOREIGN_BAR, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAR')) + self.assertEqual(unittest_pb2.FOREIGN_BAZ, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ')) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Value, 'FO') + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(proto.FOO, + proto.NestedEnum.Value('FOO')) + self.assertEqual(proto.FOO, + unittest_pb2.TestAllTypes.NestedEnum.Value('FOO')) + self.assertEqual(proto.BAR, + proto.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAR, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAZ, + proto.NestedEnum.Value('BAZ')) + self.assertEqual(proto.BAZ, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ')) + self.assertRaises(ValueError, + proto.NestedEnum.Value, 'Foo') + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo') + + def testEnum_KeysAndValues(self): + self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], + unittest_pb2.ForeignEnum.keys()) + self.assertEqual([4, 5, 6], + unittest_pb2.ForeignEnum.values()) + self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), + ('FOREIGN_BAZ', 6)], + unittest_pb2.ForeignEnum.items()) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(['FOO', 'BAR', 'BAZ'], proto.NestedEnum.keys()) + self.assertEqual([1, 2, 3], proto.NestedEnum.values()) + self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3)], + proto.NestedEnum.items()) + def testRepeatedScalars(self): proto = unittest_pb2.TestAllTypes() @@ -826,6 +945,35 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(1, len(proto.repeated_nested_message)) self.assertEqual(23, proto.repeated_nested_message[0].bb) + def testRepeatedCompositeRemove(self): + proto = unittest_pb2.TestAllTypes() + + self.assertEqual(0, len(proto.repeated_nested_message)) + m0 = proto.repeated_nested_message.add() + # Need to set some differentiating variable so m0 != m1 != m2: + m0.bb = len(proto.repeated_nested_message) + m1 = proto.repeated_nested_message.add() + m1.bb = len(proto.repeated_nested_message) + self.assertTrue(m0 != m1) + m2 = proto.repeated_nested_message.add() + m2.bb = len(proto.repeated_nested_message) + self.assertListsEqual([m0, m1, m2], 
proto.repeated_nested_message) + + self.assertEqual(3, len(proto.repeated_nested_message)) + proto.repeated_nested_message.remove(m0) + self.assertEqual(2, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + self.assertEqual(m2, proto.repeated_nested_message[1]) + + # Removing m0 again or removing None should raise error + self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) + self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) + self.assertEqual(2, len(proto.repeated_nested_message)) + + proto.repeated_nested_message.remove(m2) + self.assertEqual(1, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + def testHandWrittenReflection(self): # Hand written extensions are only supported by the pure-Python # implementation of the API. @@ -856,6 +1004,68 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(23, myproto_instance.foo_field) self.assertTrue(myproto_instance.HasField('foo_field')) + def testDescriptorProtoSupport(self): + # Hand written descriptors/reflection are only supported by the pure-Python + # implementation of the API. + if api_implementation.Type() != 'python': + return + + def AddDescriptorField(proto, field_name, field_type): + AddDescriptorField.field_index += 1 + new_field = proto.field.add() + new_field.name = field_name + new_field.type = field_type + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + + AddDescriptorField.field_index = 0 + + desc_proto = descriptor_pb2.DescriptorProto() + desc_proto.name = 'Car' + fdp = descriptor_pb2.FieldDescriptorProto + AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) + AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) + AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) + AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) + # Add a repeated field + AddDescriptorField.field_index += 1 + new_field = desc_proto.field.add() + new_field.name = 'owners' + new_field.type = fdp.TYPE_STRING + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED + + desc = descriptor.MakeDescriptor(desc_proto) + self.assertTrue(desc.fields_by_name.has_key('name')) + self.assertTrue(desc.fields_by_name.has_key('year')) + self.assertTrue(desc.fields_by_name.has_key('automatic')) + self.assertTrue(desc.fields_by_name.has_key('price')) + self.assertTrue(desc.fields_by_name.has_key('owners')) + + class CarMessage(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = desc + + prius = CarMessage() + prius.name = 'prius' + prius.year = 2010 + prius.automatic = True + prius.price = 25134.75 + prius.owners.extend(['bob', 'susan']) + + serialized_prius = prius.SerializeToString() + new_prius = reflection.ParseMessage(desc, serialized_prius) + self.assertTrue(new_prius is not prius) + self.assertEqual(prius, new_prius) + + # these are unnecessary assuming message equality works as advertised but + # explicitly check to be safe since we're mucking about in metaclass foo + self.assertEqual(prius.name, new_prius.name) + self.assertEqual(prius.year, new_prius.year) + self.assertEqual(prius.automatic, new_prius.automatic) + self.assertEqual(prius.price, new_prius.price) + self.assertEqual(prius.owners, new_prius.owners) + def testTopLevelExtensionsForOptionalScalar(self): extendee_proto = unittest_pb2.TestAllExtensions() extension = 
unittest_pb2.optional_int32_extension @@ -1243,7 +1453,12 @@ class ReflectionTest(unittest.TestCase): def testClear(self): proto = unittest_pb2.TestAllTypes() - test_util.SetAllFields(proto) + # C++ implementation does not support lazy fields right now so leave it + # out for now. + if api_implementation.Type() == 'python': + test_util.SetAllFields(proto) + else: + test_util.SetAllNonLazyFields(proto) # Clear the message. proto.Clear() self.assertEquals(proto.ByteSize(), 0) @@ -1259,6 +1474,33 @@ class ReflectionTest(unittest.TestCase): empty_proto = unittest_pb2.TestAllExtensions() self.assertEquals(proto, empty_proto) + def testDisconnectingBeforeClear(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + foreign = proto.optional_foreign_message + foreign.c = 6 + + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + self.assertTrue(foreign is not proto.optional_foreign_message) + self.assertEqual(5, nested.bb) + self.assertEqual(6, foreign.c) + nested.bb = 15 + foreign.c = 16 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertTrue(not proto.HasField('optional_foreign_message')) + self.assertEqual(0, proto.optional_foreign_message.c) + def assertInitialized(self, proto): self.assertTrue(proto.IsInitialized()) # Neither method should raise an exception. @@ -1408,7 +1650,7 @@ class ReflectionTest(unittest.TestCase): unicode_decode_failed = False try: message2.MergeFromString(bytes) - except UnicodeDecodeError, e: + except UnicodeDecodeError as e: unicode_decode_failed = True string_field = message2.str self.assertTrue(unicode_decode_failed or type(string_field) == str) @@ -2119,7 +2361,7 @@ class SerializationTest(unittest.TestCase): """This method checks if the excpetion type and message are as expected.""" try: callable_obj() - except exc_class, ex: + except exc_class as ex: # Check if the exception message is the right one. self.assertEqual(exception, str(ex)) return @@ -2131,15 +2373,22 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: a,b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: ' + 'a,b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() + proto2 = unittest_pb2.TestRequired() + self.assertFalse(proto2.HasField('a')) + # proto2 ParseFromString does not check that required fields are set. + proto2.ParseFromString(partial) + self.assertFalse(proto2.HasField('a')) + proto.a = 1 self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -2147,7 +2396,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: c') + 'Message protobuf_unittest.TestRequired is missing required fields: c') # Shouldn't raise exceptions. 
partial = proto.SerializePartialToString() @@ -2176,7 +2425,8 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign ' + 'is missing required fields: ' 'optional_message.b,optional_message.c') proto.optional_message.b = 2 @@ -2188,7 +2438,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign is missing required fields: ' 'repeated_message[0].b,repeated_message[0].c,' 'repeated_message[1].a,repeated_message[1].c') diff --git a/python/google/protobuf/internal/test_bad_identifiers.proto b/python/google/protobuf/internal/test_bad_identifiers.proto new file mode 100644 index 00000000..6a82299a --- /dev/null +++ b/python/google/protobuf/internal/test_bad_identifiers.proto @@ -0,0 +1,52 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// http://code.google.com/p/protobuf/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) + + +package protobuf_unittest; + +option py_generic_services = true; + +message TestBadIdentifiers { + extensions 100 to max; +} + +// Make sure these reasonable extension names don't conflict with internal +// variables. 
+extend TestBadIdentifiers { + optional string message = 100 [default="foo"]; + optional string descriptor = 101 [default="bar"]; + optional string reflection = 102 [default="baz"]; + optional string service = 103 [default="qux"]; +} + +message AnotherMessage {} +service AnotherService {} diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 1df16194..be8ae7be 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -42,8 +42,8 @@ from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 -def SetAllFields(message): - """Sets every field in the message to a unique value. +def SetAllNonLazyFields(message): + """Sets every non-lazy field in the message to a unique value. Args: message: A unittest_pb2.TestAllTypes instance. @@ -79,6 +79,7 @@ def SetAllFields(message): message.optional_nested_message.bb = 118 message.optional_foreign_message.c = 119 message.optional_import_message.d = 120 + message.optional_public_import_message.e = 126 message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ @@ -111,6 +112,7 @@ def SetAllFields(message): message.repeated_nested_message.add().bb = 218 message.repeated_foreign_message.add().c = 219 message.repeated_import_message.add().d = 220 + message.repeated_lazy_message.add().bb = 227 message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) @@ -140,6 +142,7 @@ def SetAllFields(message): message.repeated_nested_message.add().bb = 318 message.repeated_foreign_message.add().c = 319 message.repeated_import_message.add().d = 320 + message.repeated_lazy_message.add().bb = 327 message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) @@ -176,6 +179,11 @@ def SetAllFields(message): message.default_cord = '425' +def SetAllFields(message): + SetAllNonLazyFields(message) + message.optional_lazy_message.bb = 127 + + def SetAllExtensions(message): """Sets every extension in the message to a unique value. 
@@ -211,6 +219,8 @@ def SetAllExtensions(message): extensions[pb2.optional_nested_message_extension].bb = 118 extensions[pb2.optional_foreign_message_extension].c = 119 extensions[pb2.optional_import_message_extension].d = 120 + extensions[pb2.optional_public_import_message_extension].e = 126 + extensions[pb2.optional_lazy_message_extension].bb = 127 extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ @@ -244,6 +254,7 @@ def SetAllExtensions(message): extensions[pb2.repeated_nested_message_extension].add().bb = 218 extensions[pb2.repeated_foreign_message_extension].add().c = 219 extensions[pb2.repeated_import_message_extension].add().d = 220 + extensions[pb2.repeated_lazy_message_extension].add().bb = 227 extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR) extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR) @@ -273,6 +284,7 @@ def SetAllExtensions(message): extensions[pb2.repeated_nested_message_extension].add().bb = 318 extensions[pb2.repeated_foreign_message_extension].add().c = 319 extensions[pb2.repeated_import_message_extension].add().d = 320 + extensions[pb2.repeated_lazy_message_extension].add().bb = 327 extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ) extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ) @@ -407,6 +419,8 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(118, message.optional_nested_message.bb) test_case.assertEqual(119, message.optional_foreign_message.c) test_case.assertEqual(120, message.optional_import_message.d) + test_case.assertEqual(126, message.optional_public_import_message.e) + test_case.assertEqual(127, message.optional_lazy_message.bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.optional_nested_enum) @@ -464,6 +478,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(218, message.repeated_nested_message[0].bb) test_case.assertEqual(219, message.repeated_foreign_message[0].c) test_case.assertEqual(220, message.repeated_import_message[0].d) + test_case.assertEqual(227, message.repeated_lazy_message[0].bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAR, message.repeated_nested_enum[0]) @@ -492,6 +507,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(318, message.repeated_nested_message[1].bb) test_case.assertEqual(319, message.repeated_foreign_message[1].c) test_case.assertEqual(320, message.repeated_import_message[1].d) + test_case.assertEqual(327, message.repeated_lazy_message[1].bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.repeated_nested_enum[1]) diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 73d97d18..23b50eb5 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -94,6 +94,28 @@ class TextFormatTest(unittest.TestCase): ' }\n' '}\n') + def testPrintBadEnumValue(self): + message = unittest_pb2.TestAllTypes() + message.optional_nested_enum = 100 + message.optional_foreign_enum = 101 + message.optional_import_enum = 102 + self.CompareToGoldenText( + text_format.MessageToString(message), + 'optional_nested_enum: 100\n' + 'optional_foreign_enum: 101\n' + 'optional_import_enum: 102\n') + + def testPrintBadEnumValueExtensions(self): + message = unittest_pb2.TestAllExtensions() + 
message.Extensions[unittest_pb2.optional_nested_enum_extension] = 100 + message.Extensions[unittest_pb2.optional_foreign_enum_extension] = 101 + message.Extensions[unittest_pb2.optional_import_enum_extension] = 102 + self.CompareToGoldenText( + text_format.MessageToString(message), + '[protobuf_unittest.optional_nested_enum_extension]: 100\n' + '[protobuf_unittest.optional_foreign_enum_extension]: 101\n' + '[protobuf_unittest.optional_import_enum_extension]: 102\n') + def testPrintExotic(self): message = unittest_pb2.TestAllTypes() message.repeated_int64.append(-9223372036854775808) @@ -399,6 +421,14 @@ class TextFormatTest(unittest.TestCase): 'has no value with number 100.'), text_format.Merge, text, message) + def testMergeBadIntValue(self): + message = unittest_pb2.TestAllTypes() + text = 'optional_int32: bork' + self.assertRaisesWithMessage( + text_format.ParseError, + ('1:17 : Couldn\'t parse integer: bork'), + text_format.Merge, text, message) + def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs): """Same as assertRaises, but also compares the exception message.""" if hasattr(e_class, '__name__'): @@ -408,7 +438,7 @@ class TextFormatTest(unittest.TestCase): try: func(*args, **kwargs) - except e_class, expr: + except e_class as expr: if str(expr) != e: msg = '%s raised, but with wrong message: "%s" instead of "%s"' raise self.failureException(msg % (exc_name, @@ -427,7 +457,7 @@ class TokenizerTest(unittest.TestCase): 'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n' 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n' 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' - 'ID12: 2222222222222222222 ' + 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f ' 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ' ) tokenizer = text_format._Tokenizer(text) methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), @@ -456,10 +486,10 @@ class TokenizerTest(unittest.TestCase): '{', (tokenizer.ConsumeIdentifier, 'A'), ':', - (tokenizer.ConsumeFloat, text_format._INFINITY), + (tokenizer.ConsumeFloat, float('inf')), (tokenizer.ConsumeIdentifier, 'B'), ':', - (tokenizer.ConsumeFloat, -text_format._INFINITY), + (tokenizer.ConsumeFloat, -float('inf')), (tokenizer.ConsumeIdentifier, 'C'), ':', (tokenizer.ConsumeBool, True), @@ -479,6 +509,12 @@ class TokenizerTest(unittest.TestCase): (tokenizer.ConsumeIdentifier, 'ID12'), ':', (tokenizer.ConsumeUint64, 2222222222222222222), + (tokenizer.ConsumeIdentifier, 'ID13'), + ':', + (tokenizer.ConsumeFloat, 1.23456), + (tokenizer.ConsumeIdentifier, 'ID14'), + ':', + (tokenizer.ConsumeFloat, 1.2e+2), (tokenizer.ConsumeIdentifier, 'false_bool'), ':', (tokenizer.ConsumeBool, False), @@ -556,16 +592,6 @@ class TokenizerTest(unittest.TestCase): tokenizer = text_format._Tokenizer(text) self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) - def testInfNan(self): - # Make sure our infinity and NaN definitions are sound. 
- self.assertEquals(float, type(text_format._INFINITY)) - self.assertEquals(float, type(text_format._NAN)) - self.assertTrue(text_format._NAN != text_format._NAN) - - inf_times_zero = text_format._INFINITY * 0 - self.assertTrue(inf_times_zero != inf_times_zero) - self.assertTrue(text_format._INFINITY > 0) - if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py new file mode 100755 index 00000000..84984b40 --- /dev/null +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -0,0 +1,170 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +"""Test for preservation of unknown fields in the pure Python implementation.""" + +__author__ = 'bohdank@google.com (Bohdan Koval)' + +import unittest +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf.internal import encoder +from google.protobuf.internal import test_util +from google.protobuf.internal import type_checkers + + +class UnknownFieldsTest(unittest.TestCase): + + def setUp(self): + self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + self.all_fields = unittest_pb2.TestAllTypes() + test_util.SetAllFields(self.all_fields) + self.all_fields_data = self.all_fields.SerializeToString() + self.empty_message = unittest_pb2.TestEmptyMessage() + self.empty_message.ParseFromString(self.all_fields_data) + self.unknown_fields = self.empty_message._unknown_fields + + def GetField(self, name): + field_descriptor = self.descriptor.fields_by_name[name] + wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] + field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + for tag_bytes, value in self.unknown_fields: + if tag_bytes == field_tag: + decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes] + result_dict = {} + decoder(value, 0, len(value), self.all_fields, result_dict) + return result_dict[field_descriptor] + + def testVarint(self): + value = self.GetField('optional_int32') + self.assertEqual(self.all_fields.optional_int32, value) + + def testFixed32(self): + value = self.GetField('optional_fixed32') + self.assertEqual(self.all_fields.optional_fixed32, value) + + def testFixed64(self): + value = self.GetField('optional_fixed64') + self.assertEqual(self.all_fields.optional_fixed64, value) + + def testLengthDelimited(self): + value = self.GetField('optional_string') + self.assertEqual(self.all_fields.optional_string, value) + + def testGroup(self): + value = self.GetField('optionalgroup') + self.assertEqual(self.all_fields.optionalgroup, value) + + def testSerialize(self): + data = self.empty_message.SerializeToString() + + # Don't use assertEqual because we don't want to dump raw binary data to + # stdout. + self.assertTrue(data == self.all_fields_data) + + def testCopyFrom(self): + message = unittest_pb2.TestEmptyMessage() + message.CopyFrom(self.empty_message) + self.assertEqual(self.unknown_fields, message._unknown_fields) + + def testMergeFrom(self): + message = unittest_pb2.TestAllTypes() + message.optional_int32 = 1 + message.optional_uint32 = 2 + source = unittest_pb2.TestEmptyMessage() + source.ParseFromString(message.SerializeToString()) + + message.ClearField('optional_int32') + message.optional_int64 = 3 + message.optional_uint32 = 4 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + unknown_fields = destination._unknown_fields[:] + + destination.MergeFrom(source) + self.assertEqual(unknown_fields + source._unknown_fields, + destination._unknown_fields) + + def testClear(self): + self.empty_message.Clear() + self.assertEqual(0, len(self.empty_message._unknown_fields)) + + def testByteSize(self): + self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) + + def testUnknownExtensions(self): + message = unittest_pb2.TestEmptyMessageWithExtensions() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message._unknown_fields, + message._unknown_fields) + + def testListFields(self): + # Make sure ListFields doesn't return unknown fields. 
+ self.assertEqual(0, len(self.empty_message.ListFields())) + + def testSerializeMessageSetWireFormatUnknownExtension(self): + # Create a message using the message set wire format with an unknown + # message. + raw = unittest_mset_pb2.RawMessageSet() + + # Add an unknown extension. + item = raw.item.add() + item.type_id = 1545009 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12345 + item.message = message1.SerializeToString() + + serialized = raw.SerializeToString() + + # Parse message using the message set wire format. + proto = unittest_mset_pb2.TestMessageSet() + proto.MergeFromString(serialized) + + # Verify that the unknown extension is serialized unchanged + reserialized = proto.SerializeToString() + new_raw = unittest_mset_pb2.RawMessageSet() + new_raw.MergeFromString(reserialized) + self.assertEqual(raw, new_raw) + + def testEquals(self): + message = unittest_pb2.TestEmptyMessage() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message, message) + + self.all_fields.ClearField('optional_string') + message.ParseFromString(self.all_fields.SerializeToString()) + self.assertNotEqual(self.empty_message, message) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py index 6f19f85f..6ec2f8be 100755 --- a/python/google/protobuf/message.py +++ b/python/google/protobuf/message.py @@ -73,6 +73,7 @@ class Message(object): return clone def __eq__(self, other_msg): + """Recursively compares two messages by value and structure.""" raise NotImplementedError def __ne__(self, other_msg): @@ -83,9 +84,11 @@ class Message(object): raise TypeError('unhashable object') def __str__(self): + """Outputs a human-readable representation of the message.""" raise NotImplementedError def __unicode__(self): + """Outputs a human-readable representation of the message.""" raise NotImplementedError def MergeFrom(self, other_msg): @@ -266,3 +269,12 @@ class Message(object): via a previous _SetListener() call. """ raise NotImplementedError + + def __getstate__(self): + """Support the pickle protocol.""" + return dict(serialized=self.SerializePartialToString()) + + def __setstate__(self, state): + """Support the pickle protocol.""" + self.__init__() + self.ParseFromString(state['serialized']) diff --git a/python/google/protobuf/message_factory.py b/python/google/protobuf/message_factory.py new file mode 100644 index 00000000..36e2fef0 --- /dev/null +++ b/python/google/protobuf/message_factory.py @@ -0,0 +1,113 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# http://code.google.com/p/protobuf/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Provides a factory class for generating dynamic messages.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool +from google.protobuf import message +from google.protobuf import reflection + + +class MessageFactory(object): + """Factory for creating Proto2 messages from descriptors in a pool.""" + + def __init__(self): + """Initializes a new factory.""" + self._classes = {} + + def GetPrototype(self, descriptor): + """Builds a proto2 message class based on the passed in descriptor. + + Passing a descriptor with a fully qualified name matching a previous + invocation will cause the same class to be returned. + + Args: + descriptor: The descriptor to build from. + + Returns: + A class describing the passed in descriptor. + """ + + if descriptor.full_name not in self._classes: + result_class = reflection.GeneratedProtocolMessageType( + descriptor.name.encode('ascii', 'ignore'), + (message.Message,), + {'DESCRIPTOR': descriptor}) + self._classes[descriptor.full_name] = result_class + for field in descriptor.fields: + if field.message_type: + self.GetPrototype(field.message_type) + return self._classes[descriptor.full_name] + + +_DB = descriptor_database.DescriptorDatabase() +_POOL = descriptor_pool.DescriptorPool(_DB) +_FACTORY = MessageFactory() + + +def GetMessages(file_protos): + """Builds a dictionary of all the messages available in a set of files. + + Args: + file_protos: A sequence of file protos to build messages out of. + + Returns: + A dictionary containing all the message types in the files mapping the + fully qualified name to a Message subclass for the descriptor. + """ + + result = {} + for file_proto in file_protos: + _DB.Add(file_proto) + for file_proto in file_protos: + for desc in _GetAllDescriptors(file_proto.message_type, file_proto.package): + result[desc.full_name] = _FACTORY.GetPrototype(desc) + return result + + +def _GetAllDescriptors(desc_protos, package): + """Gets all levels of nested message types as a flattened list of descriptors. + + Args: + desc_protos: The descriptor protos to process. + package: The package where the protos are defined. + + Yields: + Each message descriptor for each nested type. 
+ """ + + for desc_proto in desc_protos: + name = '.'.join((package, desc_proto.name)) + yield _POOL.FindMessageTypeByName(name) + for nested_desc in _GetAllDescriptors(desc_proto.nested_type, name): + yield nested_desc diff --git a/python/google/protobuf/pyext/python-proto2.cc b/python/google/protobuf/pyext/python-proto2.cc index 181d7e97..62753413 100644 --- a/python/google/protobuf/pyext/python-proto2.cc +++ b/python/google/protobuf/pyext/python-proto2.cc @@ -207,9 +207,9 @@ static PyMethodDef CMessageMethods[] = { "Clears and sets the values of a repeated scalar field."), CMETHOD(ByteSize, METH_NOARGS, "Returns the size of the message in bytes."), - CMETHOD(Clear, METH_NOARGS, + CMETHOD(Clear, METH_O, "Clears a protocol message."), - CMETHOD(ClearField, METH_O, + CMETHOD(ClearField, METH_VARARGS, "Clears a protocol message field by name."), CMETHOD(ClearFieldByDescriptor, METH_O, "Clears a protocol message field by descriptor."), @@ -274,7 +274,7 @@ static PyMemberDef CMessageMembers[] = { PyTypeObject CMessage_Type = { PyObject_HEAD_INIT(&PyType_Type) 0, - C("google3.net.google.protobuf.python.internal." + C("google.protobuf.internal." "_net_proto2___python." "CMessage"), // tp_name sizeof(CMessage), // tp_basicsize @@ -319,14 +319,12 @@ PyTypeObject CMessage_Type = { // ------ Helper Functions: static void FormatTypeError(PyObject* arg, char* expected_types) { - PyObject* s = PyObject_Str(PyObject_Type(arg)); - PyObject* repr = PyObject_Repr(PyObject_Type(arg)); + PyObject* repr = PyObject_Repr(arg); PyErr_Format(PyExc_TypeError, "%.100s has type %.100s, but expected one of: %s", PyString_AS_STRING(repr), - PyString_AS_STRING(s), + arg->ob_type->tp_name, expected_types); - Py_DECREF(s); Py_DECREF(repr); } @@ -398,6 +396,28 @@ static const google::protobuf::Message* CreateMessage(const char* message_type) return global_message_factory->GetPrototype(descriptor); } +static void ReleaseSubMessage(google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field_descriptor, + CMessage* child_cmessage) { + Message* released_message = message->GetReflection()->ReleaseMessage( + message, field_descriptor, global_message_factory); + GOOGLE_DCHECK(child_cmessage->message != NULL); + // ReleaseMessage will return NULL which differs from + // child_cmessage->message, if the field does not exist. In this case, + // the latter points to the default instance via a const_cast<>, so we + // have to reset it to a new mutable object since we are taking ownership. 
+ if (released_message == NULL) { + const Message* prototype = global_message_factory->GetPrototype( + child_cmessage->message->GetDescriptor()); + GOOGLE_DCHECK(prototype != NULL); + child_cmessage->message = prototype->New(); + } + child_cmessage->parent = NULL; + child_cmessage->parent_field = NULL; + child_cmessage->free_message = true; + child_cmessage->read_only = false; +} + static bool CheckAndSetString( PyObject* arg, google::protobuf::Message* message, const google::protobuf::FieldDescriptor* descriptor, @@ -407,6 +427,9 @@ static bool CheckAndSetString( GOOGLE_DCHECK(descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING || descriptor->type() == google::protobuf::FieldDescriptor::TYPE_BYTES); if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { +#else + if (descriptor->file()->options().cc_api_version() == 2 && + descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { if (!PyString_Check(arg) && !PyUnicode_Check(arg)) { FormatTypeError(arg, "str, unicode"); return false; @@ -434,6 +457,9 @@ static bool CheckAndSetString( PyObject* encoded_string = NULL; if (descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { +#else + if (descriptor->file()->options().cc_api_version() == 2 && + descriptor->type() == google::protobuf::FieldDescriptor::TYPE_STRING) { if (PyString_Check(arg)) { encoded_string = PyString_AsEncodedObject(arg, "utf-8", NULL); } else { @@ -504,8 +530,6 @@ static void AssureWritable(CMessage* self) { self->message = reflection->MutableMessage( message, self->parent_field->descriptor, global_message_factory); self->read_only = false; - self->parent = NULL; - self->parent_field = NULL; } static PyObject* InternalGetScalar( @@ -955,9 +979,41 @@ static void CMessageDealloc(CMessage* self) { // ------ Methods: -static PyObject* CMessage_Clear(CMessage* self, PyObject* args) { +static PyObject* CMessage_Clear(CMessage* self, PyObject* arg) { AssureWritable(self); - self->message->Clear(); + google::protobuf::Message* message = self->message; + + // This block of code is equivalent to the following: + // for cfield_descriptor, child_cmessage in arg: + // ReleaseSubMessage(cfield_descriptor, child_cmessage) + if (!PyList_Check(arg)) { + PyErr_SetString(PyExc_TypeError, "Must be a list"); + return NULL; + } + PyObject* messages_to_clear = arg; + Py_ssize_t num_messages_to_clear = PyList_GET_SIZE(messages_to_clear); + for(int i = 0; i < num_messages_to_clear; ++i) { + PyObject* message_tuple = PyList_GET_ITEM(messages_to_clear, i); + if (!PyTuple_Check(message_tuple) || PyTuple_GET_SIZE(message_tuple) != 2) { + PyErr_SetString(PyExc_TypeError, "Must be a tuple of size 2"); + return NULL; + } + + PyObject* py_cfield_descriptor = PyTuple_GET_ITEM(message_tuple, 0); + PyObject* py_child_cmessage = PyTuple_GET_ITEM(message_tuple, 1); + if (!PyObject_TypeCheck(py_cfield_descriptor, &CFieldDescriptor_Type) || + !PyObject_TypeCheck(py_child_cmessage, &CMessage_Type)) { + PyErr_SetString(PyExc_ValueError, "Invalid Tuple"); + return NULL; + } + + CFieldDescriptor* cfield_descriptor = reinterpret_cast( + py_cfield_descriptor); + CMessage* child_cmessage = reinterpret_cast(py_child_cmessage); + ReleaseSubMessage(message, cfield_descriptor->descriptor, child_cmessage); + } + + message->Clear(); Py_RETURN_NONE; } @@ -1039,9 +1095,11 @@ static PyObject* CMessage_ClearFieldByDescriptor( Py_RETURN_NONE; } -static PyObject* CMessage_ClearField(CMessage* self, PyObject* arg) { +static PyObject* CMessage_ClearField(CMessage* 
self, PyObject* args) { char* field_name; - if (PyString_AsStringAndSize(arg, &field_name, NULL) < 0) { + CMessage* child_cmessage = NULL; + if (!PyArg_ParseTuple(args, C("s|O!:ClearField"), &field_name, + &CMessage_Type, &child_cmessage)) { return NULL; } @@ -1054,7 +1112,11 @@ static PyObject* CMessage_ClearField(CMessage* self, PyObject* arg) { return NULL; } - message->GetReflection()->ClearField(message, field_descriptor); + if (child_cmessage != NULL && !FIELD_IS_REPEATED(field_descriptor)) { + ReleaseSubMessage(message, field_descriptor, child_cmessage); + } else { + message->GetReflection()->ClearField(message, field_descriptor); + } Py_RETURN_NONE; } @@ -1313,6 +1375,7 @@ static PyObject* CMessage_MergeFromString(CMessage* self, PyObject* arg) { AssureWritable(self); google::protobuf::io::CodedInputStream input( reinterpret_cast(data), data_length); + input.SetExtensionRegistry(GetDescriptorPool(), global_message_factory); bool success = self->message->MergePartialFromCodedStream(&input); if (success) { return PyInt_FromLong(self->message->ByteSize()); diff --git a/python/google/protobuf/pyext/python_descriptor.cc b/python/google/protobuf/pyext/python_descriptor.cc index fb87bad1..5e3e9ea4 100644 --- a/python/google/protobuf/pyext/python_descriptor.cc +++ b/python/google/protobuf/pyext/python_descriptor.cc @@ -31,6 +31,7 @@ // Author: petar@google.com (Petar Petrov) #include +#include #include #include @@ -41,6 +42,7 @@ namespace google { namespace protobuf { namespace python { + static void CFieldDescriptorDealloc(CFieldDescriptor* self); static google::protobuf::DescriptorPool* g_descriptor_pool = NULL; @@ -93,7 +95,7 @@ static PyGetSetDef CFieldDescriptorGetters[] = { PyTypeObject CFieldDescriptor_Type = { PyObject_HEAD_INIT(&PyType_Type) 0, - C("google3.net.google.protobuf.python.internal." + C("google.protobuf.internal." "_net_proto2___python." "CFieldDescriptor"), // tp_name sizeof(CFieldDescriptor), // tp_basicsize @@ -181,6 +183,8 @@ static PyObject* CDescriptorPool_FindFieldByName( const google::protobuf::FieldDescriptor* field_descriptor = NULL; field_descriptor = self->pool->FindFieldByName(full_field_name); + + if (field_descriptor == NULL) { PyErr_Format(PyExc_TypeError, "Couldn't find field %.200s", full_field_name); @@ -223,7 +227,7 @@ static PyMethodDef CDescriptorPoolMethods[] = { PyTypeObject CDescriptorPool_Type = { PyObject_HEAD_INIT(&PyType_Type) 0, - C("google3.net.google.protobuf.python.internal." + C("google.protobuf.internal." "_net_proto2___python." "CFieldDescriptor"), // tp_name sizeof(CDescriptorPool), // tp_basicsize @@ -301,7 +305,6 @@ PyObject* Python_BuildFile(PyObject* ignored, PyObject* arg) { return NULL; } - // If this file is already in the generated pool, don't add it again. 
if (google::protobuf::DescriptorPool::generated_pool()->FindFileByName( file_proto.name()) != NULL) { Py_RETURN_NONE; diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index 1373c882..9570fd50 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -50,13 +50,20 @@ __author__ = 'robinson@google.com (Will Robinson)' from google.protobuf.internal import api_implementation from google.protobuf import descriptor as descriptor_mod +from google.protobuf import message + _FieldDescriptor = descriptor_mod.FieldDescriptor if api_implementation.Type() == 'cpp': - from google.protobuf.internal import cpp_message - _NewMessage = cpp_message.NewMessage - _InitMessage = cpp_message.InitMessage + if api_implementation.Version() == 2: + from google.protobuf.internal.cpp import cpp_message + _NewMessage = cpp_message.NewMessage + _InitMessage = cpp_message.InitMessage + else: + from google.protobuf.internal import cpp_message + _NewMessage = cpp_message.NewMessage + _InitMessage = cpp_message.InitMessage else: from google.protobuf.internal import python_message _NewMessage = python_message.NewMessage @@ -112,7 +119,7 @@ class GeneratedProtocolMessageType(type): Newly-allocated class. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] - _NewMessage(descriptor, dictionary) + bases = _NewMessage(bases, descriptor, dictionary) superclass = super(GeneratedProtocolMessageType, cls) new_class = superclass.__new__(cls, name, bases, dictionary) @@ -140,3 +147,23 @@ class GeneratedProtocolMessageType(type): _InitMessage(descriptor, cls) superclass = super(GeneratedProtocolMessageType, cls) superclass.__init__(name, bases, dictionary) + + +def ParseMessage(descriptor, byte_str): + """Generate a new Message instance from this Descriptor and a byte string. + + Args: + descriptor: Protobuf Descriptor object + byte_str: Serialized protocol buffer byte string + + Returns: + Newly created protobuf Message object. + """ + + class _ResultClass(message.Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = descriptor + + new_msg = _ResultClass() + new_msg.ParseFromString(byte_str) + return new_msg diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py index c3a1cf60..0714c39d 100755 --- a/python/google/protobuf/text_format.py +++ b/python/google/protobuf/text_format.py @@ -43,10 +43,12 @@ __all__ = [ 'MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue', 'Merge' ] -# Infinity and NaN are not explicitly supported by Python pre-2.6, and -# float('inf') does not work on Windows (pre-2.6). -_INFINITY = 1e10000 # overflows, thus will actually be infinity. 
-_NAN = _INFINITY * 0 +_INTEGER_CHECKERS = (type_checkers.Uint32ValueChecker(), + type_checkers.Int32ValueChecker(), + type_checkers.Uint64ValueChecker(), + type_checkers.Int64ValueChecker()) +_FLOAT_INFINITY = re.compile('-?inf(?:inity)?f?', re.IGNORECASE) +_FLOAT_NAN = re.compile('nanf?', re.IGNORECASE) class ParseError(Exception): @@ -120,7 +122,11 @@ def PrintFieldValue(field, value, out, indent=0, PrintMessage(value, out, indent + 2, as_utf8, as_one_line) out.write(' ' * indent + '}') elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: - out.write(field.enum_type.values_by_number[value].name) + enum_value = field.enum_type.values_by_number.get(value, None) + if enum_value is not None: + out.write(enum_value.name) + else: + out.write(str(value)) elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: out.write('\"') if type(value) is unicode: @@ -271,24 +277,7 @@ def _MergeScalarField(tokenizer, message, field): elif field.type == descriptor.FieldDescriptor.TYPE_BYTES: value = tokenizer.ConsumeByteString() elif field.type == descriptor.FieldDescriptor.TYPE_ENUM: - # Enum can be specified by a number (the enum value), or by - # a string literal (the enum name). - enum_descriptor = field.enum_type - if tokenizer.LookingAtInteger(): - number = tokenizer.ConsumeInt32() - enum_value = enum_descriptor.values_by_number.get(number, None) - if enum_value is None: - raise tokenizer.ParseErrorPreviousToken( - 'Enum type "%s" has no value with number %d.' % ( - enum_descriptor.full_name, number)) - else: - identifier = tokenizer.ConsumeIdentifier() - enum_value = enum_descriptor.values_by_name.get(identifier, None) - if enum_value is None: - raise tokenizer.ParseErrorPreviousToken( - 'Enum type "%s" has no value named %s.' % ( - enum_descriptor.full_name, identifier)) - value = enum_value.number + value = tokenizer.ConsumeEnum(field) else: raise RuntimeError('Unknown field type %d' % field.type) @@ -320,12 +309,6 @@ class _Tokenizer(object): '\"([^\"\n\\\\]|\\\\.)*(\"|\\\\?$)|' # a double-quoted string '\'([^\'\n\\\\]|\\\\.)*(\'|\\\\?$)') # a single-quoted string _IDENTIFIER = re.compile('\w+') - _INTEGER_CHECKERS = [type_checkers.Uint32ValueChecker(), - type_checkers.Int32ValueChecker(), - type_checkers.Uint64ValueChecker(), - type_checkers.Int64ValueChecker()] - _FLOAT_INFINITY = re.compile('-?inf(inity)?f?', re.IGNORECASE) - _FLOAT_NAN = re.compile("nanf?", re.IGNORECASE) def __init__(self, text_message): self._text_message = text_message @@ -394,17 +377,6 @@ class _Tokenizer(object): if not self.TryConsume(token): raise self._ParseError('Expected "%s".' % token) - def LookingAtInteger(self): - """Checks if the current token is an integer. - - Returns: - True iff the current token is an integer. - """ - if not self.token: - return False - c = self.token[0] - return (c >= '0' and c <= '9') or c == '-' or c == '+' - def ConsumeIdentifier(self): """Consumes protocol message field identifier. @@ -430,9 +402,9 @@ class _Tokenizer(object): ParseError: If a signed 32bit integer couldn't be consumed. """ try: - result = self._ParseInteger(self.token, is_signed=True, is_long=False) + result = ParseInteger(self.token, is_signed=True, is_long=False) except ValueError, e: - raise self._IntegerParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -446,9 +418,9 @@ class _Tokenizer(object): ParseError: If an unsigned 32bit integer couldn't be consumed. 
""" try: - result = self._ParseInteger(self.token, is_signed=False, is_long=False) + result = ParseInteger(self.token, is_signed=False, is_long=False) except ValueError, e: - raise self._IntegerParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -462,9 +434,9 @@ class _Tokenizer(object): ParseError: If a signed 64bit integer couldn't be consumed. """ try: - result = self._ParseInteger(self.token, is_signed=True, is_long=True) + result = ParseInteger(self.token, is_signed=True, is_long=True) except ValueError, e: - raise self._IntegerParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -478,9 +450,9 @@ class _Tokenizer(object): ParseError: If an unsigned 64bit integer couldn't be consumed. """ try: - result = self._ParseInteger(self.token, is_signed=False, is_long=True) + result = ParseInteger(self.token, is_signed=False, is_long=True) except ValueError, e: - raise self._IntegerParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -493,21 +465,10 @@ class _Tokenizer(object): Raises: ParseError: If a floating point number couldn't be consumed. """ - text = self.token - if self._FLOAT_INFINITY.match(text): - self.NextToken() - if text.startswith('-'): - return -_INFINITY - return _INFINITY - - if self._FLOAT_NAN.match(text): - self.NextToken() - return _NAN - try: - result = float(text) + result = ParseFloat(self.token) except ValueError, e: - raise self._FloatParseError(e) + raise self._ParseError(str(e)) self.NextToken() return result @@ -520,14 +481,12 @@ class _Tokenizer(object): Raises: ParseError: If a boolean value couldn't be consumed. """ - if self.token in ('true', 't', '1'): - self.NextToken() - return True - elif self.token in ('false', 'f', '0'): - self.NextToken() - return False - else: - raise self._ParseError('Expected "true" or "false".') + try: + result = ParseBool(self.token) + except ValueError, e: + raise self._ParseError(str(e)) + self.NextToken() + return result def ConsumeString(self): """Consumes a string value. @@ -567,7 +526,7 @@ class _Tokenizer(object): """ text = self.token if len(text) < 1 or text[0] not in ('\'', '"'): - raise self._ParseError('Exptected string.') + raise self._ParseError('Expected string.') if len(text) < 2 or text[-1] != text[0]: raise self._ParseError('String missing ending quote.') @@ -579,36 +538,12 @@ class _Tokenizer(object): self.NextToken() return result - def _ParseInteger(self, text, is_signed=False, is_long=False): - """Parses an integer. - - Args: - text: The text to parse. - is_signed: True if a signed integer must be parsed. - is_long: True if a long integer must be parsed. - - Returns: - The integer value. - - Raises: - ValueError: Thrown Iff the text is not a valid integer. - """ - pos = 0 - if text.startswith('-'): - pos += 1 - - base = 10 - if text.startswith('0x', pos) or text.startswith('0X', pos): - base = 16 - elif text.startswith('0', pos): - base = 8 - - # Do the actual parsing. Exception handling is propagated to caller. - result = int(text, base) - - # Check if the integer is sane. Exceptions handled by callers. 
-    checker = self._INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
-    checker.CheckValue(result)
+  def ConsumeEnum(self, field):
+    try:
+      result = ParseEnum(field, self.token)
+    except ValueError, e:
+      raise self._ParseError(str(e))
+    self.NextToken()
     return result
 
   def ParseErrorPreviousToken(self, message):
@@ -626,13 +561,7 @@ class _Tokenizer(object):
   def _ParseError(self, message):
     """Creates and *returns* a ParseError for the current token."""
     return ParseError('%d:%d : %s' % (
-        self._line + 1, self._column - len(self.token) + 1, message))
-
-  def _IntegerParseError(self, e):
-    return self._ParseError('Couldn\'t parse integer: ' + str(e))
-
-  def _FloatParseError(self, e):
-    return self._ParseError('Couldn\'t parse number: ' + str(e))
+        self._line + 1, self._column + 1, message))
 
   def _StringParseError(self, e):
     return self._ParseError('Couldn\'t parse string: ' + str(e))
@@ -689,3 +618,117 @@ def _CUnescape(text):
   # allow single-digit hex escapes (like '\xf').
   result = _CUNESCAPE_HEX.sub(ReplaceHex, text)
   return result.decode('string_escape')
+
+
+def ParseInteger(text, is_signed=False, is_long=False):
+  """Parses an integer.
+
+  Args:
+    text: The text to parse.
+    is_signed: True if a signed integer must be parsed.
+    is_long: True if a long integer must be parsed.
+
+  Returns:
+    The integer value.
+
+  Raises:
+    ValueError: If the text is not a valid integer.
+  """
+  # Do the actual parsing. Exception handling is propagated to caller.
+  try:
+    result = int(text, 0)
+  except ValueError:
+    raise ValueError('Couldn\'t parse integer: %s' % text)
+
+  # Check if the integer is sane. Exceptions handled by callers.
+  checker = _INTEGER_CHECKERS[2 * int(is_long) + int(is_signed)]
+  checker.CheckValue(result)
+  return result
+
+
+def ParseFloat(text):
+  """Parse a floating point number.
+
+  Args:
+    text: Text to parse.
+
+  Returns:
+    The number parsed.
+
+  Raises:
+    ValueError: If a floating point number couldn't be parsed.
+  """
+  try:
+    # Assume Python compatible syntax.
+    return float(text)
+  except ValueError:
+    # Check alternative spellings.
+    if _FLOAT_INFINITY.match(text):
+      if text[0] == '-':
+        return float('-inf')
+      else:
+        return float('inf')
+    elif _FLOAT_NAN.match(text):
+      return float('nan')
+    else:
+      # assume '1.0f' format
+      try:
+        return float(text.rstrip('f'))
+      except ValueError:
+        raise ValueError('Couldn\'t parse float: %s' % text)
+
+
+def ParseBool(text):
+  """Parse a boolean value.
+
+  Args:
+    text: Text to parse.
+
+  Returns:
+    The boolean value parsed.
+
+  Raises:
+    ValueError: If text is not a valid boolean.
+  """
+  if text in ('true', 't', '1'):
+    return True
+  elif text in ('false', 'f', '0'):
+    return False
+  else:
+    raise ValueError('Expected "true" or "false".')
+
+
+def ParseEnum(field, value):
+  """Parse an enum value.
+
+  The value can be specified by a number (the enum value), or by
+  a string literal (the enum name).
+
+  Args:
+    field: Enum field descriptor.
+    value: String value.
+
+  Returns:
+    Enum value number.
+
+  Raises:
+    ValueError: If the enum value could not be parsed.
+  """
+  enum_descriptor = field.enum_type
+  try:
+    number = int(value, 0)
+  except ValueError:
+    # Identifier.
+    enum_value = enum_descriptor.values_by_name.get(value, None)
+    if enum_value is None:
+      raise ValueError(
+          'Enum type "%s" has no value named %s.' % (
+              enum_descriptor.full_name, value))
+  else:
+    # Numeric value.
+ enum_value = enum_descriptor.values_by_number.get(number, None) + if enum_value is None: + raise ValueError( + 'Enum type "%s" has no value with number %d.' % ( + enum_descriptor.full_name, number)) + return enum_value.number diff --git a/python/setup.py b/python/setup.py index 5f0acbac..fbe27660 100755 --- a/python/setup.py +++ b/python/setup.py @@ -63,19 +63,26 @@ def generate_proto(source): if subprocess.call(protoc_command) != 0: sys.exit(-1) -def MakeTestSuite(): - # This is apparently needed on some systems to make sure that the tests - # work even if a previous version is already installed. - if 'google' in sys.modules: - del sys.modules['google'] - +def GenerateUnittestProtos(): generate_proto("../src/google/protobuf/unittest.proto") generate_proto("../src/google/protobuf/unittest_custom_options.proto") generate_proto("../src/google/protobuf/unittest_import.proto") + generate_proto("../src/google/protobuf/unittest_import_public.proto") generate_proto("../src/google/protobuf/unittest_mset.proto") generate_proto("../src/google/protobuf/unittest_no_generic_services.proto") + generate_proto("google/protobuf/internal/test_bad_identifiers.proto") generate_proto("google/protobuf/internal/more_extensions.proto") + generate_proto("google/protobuf/internal/more_extensions_dynamic.proto") generate_proto("google/protobuf/internal/more_messages.proto") + generate_proto("google/protobuf/internal/factory_test1.proto") + generate_proto("google/protobuf/internal/factory_test2.proto") + +def MakeTestSuite(): + # This is apparently needed on some systems to make sure that the tests + # work even if a previous version is already installed. + if 'google' in sys.modules: + del sys.modules['google'] + GenerateUnittestProtos() import unittest import google.protobuf.internal.generator_test as generator_test @@ -85,6 +92,14 @@ def MakeTestSuite(): as service_reflection_test import google.protobuf.internal.text_format_test as text_format_test import google.protobuf.internal.wire_format_test as wire_format_test + import google.protobuf.internal.unknown_fields_test as unknown_fields_test + import google.protobuf.internal.descriptor_database_test \ + as descriptor_database_test + import google.protobuf.internal.descriptor_pool_test as descriptor_pool_test + import google.protobuf.internal.message_factory_test as message_factory_test + import google.protobuf.internal.message_cpp_test as message_cpp_test + import google.protobuf.internal.reflection_cpp_generated_test \ + as reflection_cpp_generated_test loader = unittest.defaultTestLoader suite = unittest.TestSuite() @@ -117,6 +132,8 @@ class build_py(_build_py): # Generate necessary .proto file if it doesn't exist. generate_proto("../src/google/protobuf/descriptor.proto") generate_proto("../src/google/protobuf/compiler/plugin.proto") + + GenerateUnittestProtos() # Make sure google.protobuf.compiler is a valid package. open('google/protobuf/compiler/__init__.py', 'a').close() # _build_py is an old-style class, so super() doesn't work. @@ -156,6 +173,9 @@ if __name__ == '__main__': 'google.protobuf.descriptor_pb2', 'google.protobuf.compiler.plugin_pb2', 'google.protobuf.message', + 'google.protobuf.descriptor_database', + 'google.protobuf.descriptor_pool', + 'google.protobuf.message_factory', 'google.protobuf.reflection', 'google.protobuf.service', 'google.protobuf.service_reflection', -- cgit v1.2.3
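
For readers of this patch: the text_format.py hunks above lift the tokenizer's private
integer/float/bool parsing into module-level ParseInteger, ParseFloat and ParseBool
functions, so they can be reused without a _Tokenizer instance. A minimal usage sketch
(not part of the patch; the literal inputs are illustrative only), assuming the change
is applied and google.protobuf is importable:

# Sketch: the module-level helpers added to text_format.py by this patch.
from google.protobuf import text_format

# ParseInteger parses with int(text, 0), so decimal, hex and octal spellings work,
# and the result is range-checked by the type_checkers for the requested width.
assert text_format.ParseInteger('0x1f') == 31
assert text_format.ParseInteger('-42', is_signed=True) == -42

# ParseFloat accepts Python float syntax plus 'inf'/'infinity', 'nan' and a
# trailing 'f' suffix.
assert text_format.ParseFloat('1.5f') == 1.5
assert text_format.ParseFloat('-inf') == float('-inf')

# ParseBool accepts true/t/1 and false/f/0.
assert text_format.ParseBool('t') is True
assert text_format.ParseBool('0') is False

Malformed or out-of-range input raises ValueError, which the tokenizer's Consume*
methods wrap into ParseError with line/column information.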
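
Two enum behaviours also change above: PrintFieldValue falls back to printing the raw
number when a value has no name in the enum descriptor, and _MergeScalarField now goes
through ConsumeEnum/ParseEnum, which accepts either the symbolic name or the numeric
value. A hedged sketch of the merge side, using the unittest_pb2 module that setup.py
generates from unittest.proto for the test suite (TestAllTypes, optional_nested_enum
and the BAR/BAZ values are assumed from that stock test proto):

# Sketch: enum fields in text format can be given by name or by number.
from google.protobuf import text_format
from google.protobuf import unittest_pb2  # generated test proto

field = unittest_pb2.TestAllTypes.DESCRIPTOR.fields_by_name[
    'optional_nested_enum']

# ParseEnum resolves both spellings to the enum number.
assert text_format.ParseEnum(field, 'BAR') == unittest_pb2.TestAllTypes.BAR
assert text_format.ParseEnum(
    field, str(unittest_pb2.TestAllTypes.BAR)) == unittest_pb2.TestAllTypes.BAR

# Merge() therefore accepts either form in text input; unknown names or numbers
# are reported as ParseError built from the ValueError raised by ParseEnum.
message = unittest_pb2.TestAllTypes()
text_format.Merge('optional_nested_enum: BAZ', message)
assert message.optional_nested_enum == unittest_pb2.TestAllTypes.BAZ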