From 40ee551715c3a784ea6132dbf604b0e665ca2def Mon Sep 17 00:00:00 2001 From: temporal Date: Thu, 10 Jul 2008 02:12:20 +0000 Subject: Initial checkin. --- python/google/__init__.py | 1 + python/google/protobuf/__init__.py | 0 python/google/protobuf/descriptor.py | 419 +++++ python/google/protobuf/internal/__init__.py | 0 python/google/protobuf/internal/decoder.py | 194 +++ python/google/protobuf/internal/decoder_test.py | 230 +++ python/google/protobuf/internal/descriptor_test.py | 97 ++ python/google/protobuf/internal/encoder.py | 192 +++ python/google/protobuf/internal/encoder_test.py | 211 +++ python/google/protobuf/internal/generator_test.py | 84 + python/google/protobuf/internal/input_stream.py | 211 +++ .../google/protobuf/internal/input_stream_test.py | 279 ++++ .../google/protobuf/internal/message_listener.py | 55 + .../google/protobuf/internal/more_extensions.proto | 44 + .../google/protobuf/internal/more_messages.proto | 37 + python/google/protobuf/internal/output_stream.py | 112 ++ .../google/protobuf/internal/output_stream_test.py | 162 ++ python/google/protobuf/internal/reflection_test.py | 1300 +++++++++++++++ .../protobuf/internal/service_reflection_test.py | 98 ++ python/google/protobuf/internal/test_util.py | 354 ++++ .../google/protobuf/internal/text_format_test.py | 97 ++ python/google/protobuf/internal/wire_format.py | 222 +++ .../google/protobuf/internal/wire_format_test.py | 232 +++ python/google/protobuf/message.py | 184 +++ python/google/protobuf/reflection.py | 1734 ++++++++++++++++++++ python/google/protobuf/service.py | 194 +++ python/google/protobuf/service_reflection.py | 275 ++++ python/google/protobuf/text_format.py | 111 ++ 28 files changed, 7129 insertions(+) create mode 100755 python/google/__init__.py create mode 100755 python/google/protobuf/__init__.py create mode 100755 python/google/protobuf/descriptor.py create mode 100755 python/google/protobuf/internal/__init__.py create mode 100755 python/google/protobuf/internal/decoder.py 
create mode 100755 python/google/protobuf/internal/decoder_test.py create mode 100755 python/google/protobuf/internal/descriptor_test.py create mode 100755 python/google/protobuf/internal/encoder.py create mode 100755 python/google/protobuf/internal/encoder_test.py create mode 100755 python/google/protobuf/internal/generator_test.py create mode 100755 python/google/protobuf/internal/input_stream.py create mode 100755 python/google/protobuf/internal/input_stream_test.py create mode 100755 python/google/protobuf/internal/message_listener.py create mode 100644 python/google/protobuf/internal/more_extensions.proto create mode 100644 python/google/protobuf/internal/more_messages.proto create mode 100755 python/google/protobuf/internal/output_stream.py create mode 100755 python/google/protobuf/internal/output_stream_test.py create mode 100755 python/google/protobuf/internal/reflection_test.py create mode 100755 python/google/protobuf/internal/service_reflection_test.py create mode 100755 python/google/protobuf/internal/test_util.py create mode 100755 python/google/protobuf/internal/text_format_test.py create mode 100755 python/google/protobuf/internal/wire_format.py create mode 100755 python/google/protobuf/internal/wire_format_test.py create mode 100755 python/google/protobuf/message.py create mode 100755 python/google/protobuf/reflection.py create mode 100755 python/google/protobuf/service.py create mode 100755 python/google/protobuf/service_reflection.py create mode 100755 python/google/protobuf/text_format.py (limited to 'python/google') diff --git a/python/google/__init__.py b/python/google/__init__.py new file mode 100755 index 00000000..de40ea7c --- /dev/null +++ b/python/google/__init__.py @@ -0,0 +1 @@ +__import__('pkg_resources').declare_namespace(__name__) diff --git a/python/google/protobuf/__init__.py b/python/google/protobuf/__init__.py new file mode 100755 index 00000000..e69de29b diff --git a/python/google/protobuf/descriptor.py 
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc.
# http://code.google.com/p/protobuf/
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# TODO(robinson): We probably need to provide deep-copy methods for
# descriptor types.  When a FieldDescriptor is passed into
# Descriptor.__init__(), we should make a deep copy and then set
# containing_type on it.  Alternatively, we could just get
# rid of containing_type (it's not needed for reflection.py, at least).
#
# TODO(robinson): Print method?
#
# TODO(robinson): Useful __repr__?

"""Descriptors essentially contain exactly the information found in a .proto
file, in types that make this information accessible in Python.
"""

__author__ = 'robinson@google.com (Will Robinson)'


class DescriptorBase(object):

  """Descriptors base class.

  This class is the base of all descriptor classes.  It provides common
  options-related functionality.
  """

  def __init__(self, options, options_class_name):
    """Initialize the descriptor given its options message and the name of the
    class of the options message.  The name of the class is required in case
    the options message is None and has to be created.
    """
    self._options = options
    self._options_class_name = options_class_name

  def GetOptions(self):
    """Retrieves descriptor options.

    This method returns the options set or creates the default options for the
    descriptor.
    """
    if self._options:
      return self._options
    # Imported lazily to avoid a hard dependency on the generated
    # descriptor_pb2 module when options are never requested.
    from google.protobuf import descriptor_pb2
    try:
      options_class = getattr(descriptor_pb2, self._options_class_name)
    except AttributeError:
      raise RuntimeError('Unknown options class name %s!' %
                         (self._options_class_name))
    self._options = options_class()
    return self._options


class Descriptor(DescriptorBase):

  """Descriptor for a protocol message type.

  A Descriptor instance has the following attributes:

    name: (str) Name of this protocol message type.
    full_name: (str) Fully-qualified name of this protocol message type,
      which will include protocol "package" name and the name of any
      enclosing types.

    filename: (str) Name of the .proto file containing this message.

    containing_type: (Descriptor) Reference to the descriptor of the
      type containing us, or None if we have no containing type.

    fields: (list of FieldDescriptors) Field descriptors for all
      fields in this type.
    fields_by_number: (dict int -> FieldDescriptor) Same FieldDescriptor
      objects as in |fields|, but indexed by "number" attribute in each
      FieldDescriptor.
    fields_by_name: (dict str -> FieldDescriptor) Same FieldDescriptor
      objects as in |fields|, but indexed by "name" attribute in each
      FieldDescriptor.

    nested_types: (list of Descriptors) Descriptor references
      for all protocol message types nested within this one.
    nested_types_by_name: (dict str -> Descriptor) Same Descriptor
      objects as in |nested_types|, but indexed by "name" attribute
      in each Descriptor.

    enum_types: (list of EnumDescriptors) EnumDescriptor references
      for all enums contained within this type.
    enum_types_by_name: (dict str -> EnumDescriptor) Same EnumDescriptor
      objects as in |enum_types|, but indexed by "name" attribute
      in each EnumDescriptor.
    enum_values_by_name: (dict str -> EnumValueDescriptor) Dict mapping
      from enum value name to EnumValueDescriptor for that value.

    extensions: (list of FieldDescriptor) All extensions defined directly
      within this message type (NOT within a nested type).
    extensions_by_name: (dict, string -> FieldDescriptor) Same FieldDescriptor
      objects as |extensions|, but indexed by "name" attribute of each
      FieldDescriptor.

    options: (descriptor_pb2.MessageOptions) Protocol message options or None
      to use default message options.
  """

  def __init__(self, name, full_name, filename, containing_type,
               fields, nested_types, enum_types, extensions, options=None):
    """Arguments to __init__() are as described in the description
    of Descriptor fields above.
    """
    super(Descriptor, self).__init__(options, 'MessageOptions')
    self.name = name
    self.full_name = full_name
    self.filename = filename
    self.containing_type = containing_type

    # We have fields in addition to fields_by_name and fields_by_number,
    # so that:
    #   1. Clients can index fields by "order in which they're listed."
    #   2. Clients can easily iterate over all fields with the terse
    #      syntax: for f in descriptor.fields: ...
    self.fields = fields
    for field in self.fields:
      field.containing_type = self
    self.fields_by_number = dict((f.number, f) for f in fields)
    self.fields_by_name = dict((f.name, f) for f in fields)

    self.nested_types = nested_types
    self.nested_types_by_name = dict((t.name, t) for t in nested_types)

    self.enum_types = enum_types
    for enum_type in self.enum_types:
      enum_type.containing_type = self
    self.enum_types_by_name = dict((t.name, t) for t in enum_types)
    self.enum_values_by_name = dict(
        (v.name, v) for t in enum_types for v in t.values)

    self.extensions = extensions
    for extension in self.extensions:
      extension.extension_scope = self
    self.extensions_by_name = dict((f.name, f) for f in extensions)


# TODO(robinson): We should have aggressive checking here,
# for example:
#   * If you specify a repeated field, you should not be allowed
#     to specify a default value.
#   * [Other examples here as needed].
#
# TODO(robinson): for this and other *Descriptor classes, we
# might also want to lock things down aggressively (e.g.,
# prevent clients from setting the attributes).  Having
# stronger invariants here in general will reduce the number
# of runtime checks we must do in reflection.py...
class FieldDescriptor(DescriptorBase):

  """Descriptor for a single field in a .proto file.

  A FieldDescriptor instance has the following attributes:

    name: (str) Name of this field, exactly as it appears in .proto.
    full_name: (str) Name of this field, including containing scope.  This is
      particularly relevant for extensions.
    index: (int) Dense, 0-indexed index giving the order that this
      field textually appears within its message in the .proto file.
    number: (int) Tag number declared for this field in the .proto file.

    type: (One of the TYPE_* constants below) Declared type.
    cpp_type: (One of the CPPTYPE_* constants below) C++ type used to
      represent this field.

    label: (One of the LABEL_* constants below) Tells whether this
      field is optional, required, or repeated.
    default_value: (Varies) Default value of this field.  Only
      meaningful for non-repeated scalar fields.  Repeated fields
      should always set this to [], and non-repeated composite
      fields should always set this to None.

    containing_type: (Descriptor) Descriptor of the protocol message
      type that contains this field.  Set by the Descriptor constructor
      if we're passed into one.
      Somewhat confusingly, for extension fields, this is the
      descriptor of the EXTENDED message, not the descriptor
      of the message containing this field.  (See is_extension and
      extension_scope below).
    message_type: (Descriptor) If a composite field, a descriptor
      of the message type contained in this field.  Otherwise, this is None.
    enum_type: (EnumDescriptor) If this field contains an enum, a
      descriptor of that enum.  Otherwise, this is None.

    is_extension: True iff this describes an extension field.
    extension_scope: (Descriptor) Only meaningful if is_extension is True.
      Gives the message that immediately contains this extension field.
      Will be None iff we're a top-level (file-level) extension field.

    options: (descriptor_pb2.FieldOptions) Protocol message field options or
      None to use default field options.
  """

  # Must be consistent with C++ FieldDescriptor::Type enum in
  # descriptor.h.
  #
  # TODO(robinson): Find a way to eliminate this repetition.
  TYPE_DOUBLE = 1
  TYPE_FLOAT = 2
  TYPE_INT64 = 3
  TYPE_UINT64 = 4
  TYPE_INT32 = 5
  TYPE_FIXED64 = 6
  TYPE_FIXED32 = 7
  TYPE_BOOL = 8
  TYPE_STRING = 9
  TYPE_GROUP = 10
  TYPE_MESSAGE = 11
  TYPE_BYTES = 12
  TYPE_UINT32 = 13
  TYPE_ENUM = 14
  TYPE_SFIXED32 = 15
  TYPE_SFIXED64 = 16
  TYPE_SINT32 = 17
  TYPE_SINT64 = 18
  MAX_TYPE = 18

  # Must be consistent with C++ FieldDescriptor::CppType enum in
  # descriptor.h.
  #
  # TODO(robinson): Find a way to eliminate this repetition.
  CPPTYPE_INT32 = 1
  CPPTYPE_INT64 = 2
  CPPTYPE_UINT32 = 3
  CPPTYPE_UINT64 = 4
  CPPTYPE_DOUBLE = 5
  CPPTYPE_FLOAT = 6
  CPPTYPE_BOOL = 7
  CPPTYPE_ENUM = 8
  CPPTYPE_STRING = 9
  CPPTYPE_MESSAGE = 10
  MAX_CPPTYPE = 10

  # Must be consistent with C++ FieldDescriptor::Label enum in
  # descriptor.h.
  #
  # TODO(robinson): Find a way to eliminate this repetition.
  LABEL_OPTIONAL = 1
  LABEL_REQUIRED = 2
  LABEL_REPEATED = 3
  MAX_LABEL = 3

  def __init__(self, name, full_name, index, number, type, cpp_type, label,
               default_value, message_type, enum_type, containing_type,
               is_extension, extension_scope, options=None):
    """The arguments are as described in the description of FieldDescriptor
    attributes above.

    Note that containing_type may be None, and may be set later if necessary
    (to deal with circular references between message types, for example).
    Likewise for extension_scope.
    """
    super(FieldDescriptor, self).__init__(options, 'FieldOptions')
    self.name = name
    self.full_name = full_name
    self.index = index
    self.number = number
    self.type = type
    self.cpp_type = cpp_type
    self.label = label
    self.default_value = default_value
    self.containing_type = containing_type
    self.message_type = message_type
    self.enum_type = enum_type
    self.is_extension = is_extension
    self.extension_scope = extension_scope


class EnumDescriptor(DescriptorBase):

  """Descriptor for an enum defined in a .proto file.

  An EnumDescriptor instance has the following attributes:

    name: (str) Name of the enum type.
    full_name: (str) Full name of the type, including package name
      and any enclosing type(s).
    filename: (str) Name of the .proto file in which this appears.

    values: (list of EnumValueDescriptors) List of the values
      in this enum.
    values_by_name: (dict str -> EnumValueDescriptor) Same as |values|,
      but indexed by the "name" field of each EnumValueDescriptor.
    values_by_number: (dict int -> EnumValueDescriptor) Same as |values|,
      but indexed by the "number" field of each EnumValueDescriptor.
    containing_type: (Descriptor) Descriptor of the immediate containing
      type of this enum, or None if this is an enum defined at the
      top level in a .proto file.  Set by Descriptor's constructor
      if we're passed into one.
    options: (descriptor_pb2.EnumOptions) Enum options message or
      None to use default enum options.
  """

  def __init__(self, name, full_name, filename, values,
               containing_type=None, options=None):
    """Arguments are as described in the attribute description above."""
    super(EnumDescriptor, self).__init__(options, 'EnumOptions')
    self.name = name
    self.full_name = full_name
    self.filename = filename
    self.values = values
    for value in self.values:
      value.type = self
    self.values_by_name = dict((v.name, v) for v in values)
    self.values_by_number = dict((v.number, v) for v in values)
    self.containing_type = containing_type


class EnumValueDescriptor(DescriptorBase):

  """Descriptor for a single value within an enum.

    name: (str) Name of this value.
    index: (int) Dense, 0-indexed index giving the order that this
      value appears textually within its enum in the .proto file.
    number: (int) Actual number assigned to this enum value.
    type: (EnumDescriptor) EnumDescriptor to which this value
      belongs.  Set by EnumDescriptor's constructor if we're
      passed into one.
    options: (descriptor_pb2.EnumValueOptions) Enum value options message or
      None to use default enum value options.
  """

  def __init__(self, name, index, number, type=None, options=None):
    """Arguments are as described in the attribute description above."""
    super(EnumValueDescriptor, self).__init__(options, 'EnumValueOptions')
    self.name = name
    self.index = index
    self.number = number
    self.type = type


class ServiceDescriptor(DescriptorBase):

  """Descriptor for a service.

    name: (str) Name of the service.
    full_name: (str) Full name of the service, including package name.
    index: (int) 0-indexed index giving the order that this service's
      definition appears within the .proto file.
    methods: (list of MethodDescriptor) List of methods provided by this
      service.
    options: (descriptor_pb2.ServiceOptions) Service options message or
      None to use default service options.
  """

  def __init__(self, name, full_name, index, methods, options=None):
    super(ServiceDescriptor, self).__init__(options, 'ServiceOptions')
    self.name = name
    self.full_name = full_name
    self.index = index
    self.methods = methods
    # Set the containing service for each method in this service.
    for method in self.methods:
      method.containing_service = self

  def FindMethodByName(self, name):
    """Searches for the specified method, and returns its descriptor.

    Returns None if no method with the given name exists.
    """
    for method in self.methods:
      if name == method.name:
        return method
    return None


class MethodDescriptor(DescriptorBase):

  """Descriptor for a method in a service.

    name: (str) Name of the method within the service.
    full_name: (str) Full name of method.
    index: (int) 0-indexed index of the method inside the service.
    containing_service: (ServiceDescriptor) The service that contains this
      method.
    input_type: The descriptor of the message that this method accepts.
    output_type: The descriptor of the message that this method returns.
    options: (descriptor_pb2.MethodOptions) Method options message or
      None to use default method options.
  """

  def __init__(self, name, full_name, index, containing_service,
               input_type, output_type, options=None):
    """The arguments are as described in the description of MethodDescriptor
    attributes above.

    Note that containing_service may be None, and may be set later if
    necessary.
    """
    super(MethodDescriptor, self).__init__(options, 'MethodOptions')
    self.name = name
    self.full_name = full_name
    self.index = index
    self.containing_service = containing_service
    self.input_type = input_type
    self.output_type = output_type


def _ParseOptions(message, string):
  """Parses serialized options.

  This helper function is used to parse serialized options in generated
  proto2 files.  It must not be used outside proto2.
  """
  message.ParseFromString(string)
  return message
+""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from google.protobuf import message +from google.protobuf.internal import input_stream +from google.protobuf.internal import wire_format + + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by WireFormat from the C++ proto2 +# implementation. + + +class Decoder(object): + + """Decodes logical protocol buffer fields from the wire.""" + + def __init__(self, s): + """Initializes the decoder to read from s. + + Args: + s: An immutable sequence of bytes, which must be accessible + via the Python buffer() primitive (i.e., buffer(s)). + """ + self._stream = input_stream.InputStream(s) + + def EndOfStream(self): + """Returns true iff we've reached the end of the bytes we're reading.""" + return self._stream.EndOfStream() + + def Position(self): + """Returns the 0-indexed position in |s|.""" + return self._stream.Position() + + def ReadFieldNumberAndWireType(self): + """Reads a tag from the wire. Returns a (field_number, wire_type) pair.""" + tag_and_type = self.ReadUInt32() + return wire_format.UnpackTag(tag_and_type) + + def SkipBytes(self, bytes): + """Skips the specified number of bytes on the wire.""" + self._stream.SkipBytes(bytes) + + # Note that the Read*() methods below are not exactly symmetrical with the + # corresponding Encoder.Append*() methods. Those Encoder methods first + # encode a tag, but the Read*() methods below assume that the tag has already + # been read, and that the client wishes to read a field of the specified type + # starting at the current position. 
+ + def ReadInt32(self): + """Reads and returns a signed, varint-encoded, 32-bit integer.""" + return self._stream.ReadVarint32() + + def ReadInt64(self): + """Reads and returns a signed, varint-encoded, 64-bit integer.""" + return self._stream.ReadVarint64() + + def ReadUInt32(self): + """Reads and returns an signed, varint-encoded, 32-bit integer.""" + return self._stream.ReadVarUInt32() + + def ReadUInt64(self): + """Reads and returns an signed, varint-encoded,64-bit integer.""" + return self._stream.ReadVarUInt64() + + def ReadSInt32(self): + """Reads and returns a signed, zigzag-encoded, varint-encoded, + 32-bit integer.""" + return wire_format.ZigZagDecode(self._stream.ReadVarUInt32()) + + def ReadSInt64(self): + """Reads and returns a signed, zigzag-encoded, varint-encoded, + 64-bit integer.""" + return wire_format.ZigZagDecode(self._stream.ReadVarUInt64()) + + def ReadFixed32(self): + """Reads and returns an unsigned, fixed-width, 32-bit integer.""" + return self._stream.ReadLittleEndian32() + + def ReadFixed64(self): + """Reads and returns an unsigned, fixed-width, 64-bit integer.""" + return self._stream.ReadLittleEndian64() + + def ReadSFixed32(self): + """Reads and returns a signed, fixed-width, 32-bit integer.""" + value = self._stream.ReadLittleEndian32() + if value >= (1 << 31): + value -= (1 << 32) + return value + + def ReadSFixed64(self): + """Reads and returns a signed, fixed-width, 64-bit integer.""" + value = self._stream.ReadLittleEndian64() + if value >= (1 << 63): + value -= (1 << 64) + return value + + def ReadFloat(self): + """Reads and returns a 4-byte floating-point number.""" + serialized = self._stream.ReadString(4) + return struct.unpack('f', serialized)[0] + + def ReadDouble(self): + """Reads and returns an 8-byte floating-point number.""" + serialized = self._stream.ReadString(8) + return struct.unpack('d', serialized)[0] + + def ReadBool(self): + """Reads and returns a bool.""" + i = self._stream.ReadVarUInt32() + return bool(i) + 
+ def ReadEnum(self): + """Reads and returns an enum value.""" + return self._stream.ReadVarUInt32() + + def ReadString(self): + """Reads and returns a length-delimited string.""" + length = self._stream.ReadVarUInt32() + return self._stream.ReadString(length) + + def ReadBytes(self): + """Reads and returns a length-delimited byte sequence.""" + return self.ReadString() + + def ReadMessageInto(self, msg): + """Calls msg.MergeFromString() to merge + length-delimited serialized message data into |msg|. + + REQUIRES: The decoder must be positioned at the serialized "length" + prefix to a length-delmiited serialized message. + + POSTCONDITION: The decoder is positioned just after the + serialized message, and we have merged those serialized + contents into |msg|. + """ + length = self._stream.ReadVarUInt32() + sub_buffer = self._stream.GetSubBuffer(length) + num_bytes_used = msg.MergeFromString(sub_buffer) + if num_bytes_used != length: + raise message.DecodeError( + 'Submessage told to deserialize from %d-byte encoding, ' + 'but used only %d bytes' % (length, num_bytes_used)) + self._stream.SkipBytes(num_bytes_used) + + def ReadGroupInto(self, expected_field_number, group): + """Calls group.MergeFromString() to merge + END_GROUP-delimited serialized message data into |group|. + We'll raise an exception if we don't find an END_GROUP + tag immediately after the serialized message contents. + + REQUIRES: The decoder is positioned just after the START_GROUP + tag for this group. + + POSTCONDITION: The decoder is positioned just after the + END_GROUP tag for this group, and we have merged + the contents of the group into |group|. + """ + sub_buffer = self._stream.GetSubBuffer() # No a priori length limit. 
+ num_bytes_used = group.MergeFromString(sub_buffer) + if num_bytes_used < 0: + raise message.DecodeError('Group message reported negative bytes read.') + self._stream.SkipBytes(num_bytes_used) + field_number, field_type = self.ReadFieldNumberAndWireType() + if field_type != wire_format.WIRETYPE_END_GROUP: + raise message.DecodeError('Group message did not end with an END_GROUP.') + if field_number != expected_field_number: + raise message.DecodeError('END_GROUP tag had field ' + 'number %d, was expecting field number %d' % ( + field_number, expected_field_number)) + # We're now positioned just after the END_GROUP tag. Perfect. diff --git a/python/google/protobuf/internal/decoder_test.py b/python/google/protobuf/internal/decoder_test.py new file mode 100755 index 00000000..e36a96fc --- /dev/null +++ b/python/google/protobuf/internal/decoder_test.py @@ -0,0 +1,230 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test for google.protobuf.internal.decoder.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +import unittest +from google.protobuf.internal import wire_format +from google.protobuf.internal import encoder +from google.protobuf.internal import decoder +import logging +import mox +from google.protobuf.internal import input_stream +from google.protobuf import message + + +class DecoderTest(unittest.TestCase): + + def setUp(self): + self.mox = mox.Mox() + self.mock_stream = self.mox.CreateMock(input_stream.InputStream) + self.mock_message = self.mox.CreateMock(message.Message) + + def testReadFieldNumberAndWireType(self): + # Test field numbers that will require various varint sizes. + for expected_field_number in (1, 15, 16, 2047, 2048): + for expected_wire_type in range(6): # Highest-numbered wiretype is 5. + e = encoder.Encoder() + e._AppendTag(expected_field_number, expected_wire_type) + s = e.ToString() + d = decoder.Decoder(s) + field_number, wire_type = d.ReadFieldNumberAndWireType() + self.assertEqual(expected_field_number, field_number) + self.assertEqual(expected_wire_type, wire_type) + + def ReadScalarTestHelper(self, test_name, decoder_method, expected_result, + expected_stream_method_name, + stream_method_return, *args): + """Helper for testReadScalars below. + + Calls one of the Decoder.Read*() methods and ensures that the results are + as expected. + + Args: + test_name: Name of this test, used for logging only. + decoder_method: Unbound decoder.Decoder method to call. + expected_result: Value we expect returned from decoder_method(). + expected_stream_method_name: (string) Name of the InputStream + method we expect Decoder to call to actually read the value + on the wire. + stream_method_return: Value our mocked-out stream method should + return to the decoder. + args: Additional arguments that we expect to be passed to the + stream method. 
+ """ + logging.info('Testing %s scalar input.\n' + 'Calling %r(), and expecting that to call the ' + 'stream method %s(%r), which will return %r. Finally, ' + 'expecting the Decoder method to return %r'% ( + test_name, decoder_method, + expected_stream_method_name, args, stream_method_return, + expected_result)) + + d = decoder.Decoder('') + d._stream = self.mock_stream + if decoder_method in (decoder.Decoder.ReadString, + decoder.Decoder.ReadBytes): + self.mock_stream.ReadVarUInt32().AndReturn(len(stream_method_return)) + # We have to use names instead of methods to work around some + # mox weirdness. (ResetAll() is overzealous). + expected_stream_method = getattr(self.mock_stream, + expected_stream_method_name) + expected_stream_method(*args).AndReturn(stream_method_return) + + self.mox.ReplayAll() + self.assertEqual(expected_result, decoder_method(d)) + self.mox.VerifyAll() + self.mox.ResetAll() + + def testReadScalars(self): + test_string = 'I can feel myself getting sutpider.' + scalar_tests = [ + ['int32', decoder.Decoder.ReadInt32, 0, 'ReadVarint32', 0], + ['int64', decoder.Decoder.ReadInt64, 0, 'ReadVarint64', 0], + ['uint32', decoder.Decoder.ReadUInt32, 0, 'ReadVarUInt32', 0], + ['uint64', decoder.Decoder.ReadUInt64, 0, 'ReadVarUInt64', 0], + ['fixed32', decoder.Decoder.ReadFixed32, 0xffffffff, + 'ReadLittleEndian32', 0xffffffff], + ['fixed64', decoder.Decoder.ReadFixed64, 0xffffffffffffffff, + 'ReadLittleEndian64', 0xffffffffffffffff], + ['sfixed32', decoder.Decoder.ReadSFixed32, -1, + 'ReadLittleEndian32', 0xffffffff], + ['sfixed64', decoder.Decoder.ReadSFixed64, -1, + 'ReadLittleEndian64', 0xffffffffffffffff], + ['float', decoder.Decoder.ReadFloat, 0.0, + 'ReadString', struct.pack('f', 0.0), 4], + ['double', decoder.Decoder.ReadDouble, 0.0, + 'ReadString', struct.pack('d', 0.0), 8], + ['bool', decoder.Decoder.ReadBool, True, 'ReadVarUInt32', 1], + ['enum', decoder.Decoder.ReadEnum, 23, 'ReadVarUInt32', 23], + ['string', decoder.Decoder.ReadString, + 
test_string, 'ReadString', test_string, len(test_string)], + ['bytes', decoder.Decoder.ReadBytes, + test_string, 'ReadString', test_string, len(test_string)], + # We test zigzag decoding routines more extensively below. + ['sint32', decoder.Decoder.ReadSInt32, -1, 'ReadVarUInt32', 1], + ['sint64', decoder.Decoder.ReadSInt64, -1, 'ReadVarUInt64', 1], + ] + # Ensure that we're testing different Decoder methods and using + # different test names in all test cases above. + self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests))) + self.assertEqual(len(scalar_tests), len(set(t[1] for t in scalar_tests))) + for args in scalar_tests: + self.ReadScalarTestHelper(*args) + + def testReadMessageInto(self): + length = 23 + def Test(simulate_error): + d = decoder.Decoder('') + d._stream = self.mock_stream + self.mock_stream.ReadVarUInt32().AndReturn(length) + sub_buffer = object() + self.mock_stream.GetSubBuffer(length).AndReturn(sub_buffer) + + if simulate_error: + self.mock_message.MergeFromString(sub_buffer).AndReturn(length - 1) + self.mox.ReplayAll() + self.assertRaises( + message.DecodeError, d.ReadMessageInto, self.mock_message) + else: + self.mock_message.MergeFromString(sub_buffer).AndReturn(length) + self.mock_stream.SkipBytes(length) + self.mox.ReplayAll() + d.ReadMessageInto(self.mock_message) + + self.mox.VerifyAll() + self.mox.ResetAll() + + Test(simulate_error=False) + Test(simulate_error=True) + + def testReadGroupInto_Success(self): + # Test both the empty and nonempty cases. 
+ for num_bytes in (5, 0): + field_number = expected_field_number = 10 + d = decoder.Decoder('') + d._stream = self.mock_stream + sub_buffer = object() + self.mock_stream.GetSubBuffer().AndReturn(sub_buffer) + self.mock_message.MergeFromString(sub_buffer).AndReturn(num_bytes) + self.mock_stream.SkipBytes(num_bytes) + self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag( + field_number, wire_format.WIRETYPE_END_GROUP)) + self.mox.ReplayAll() + d.ReadGroupInto(expected_field_number, self.mock_message) + self.mox.VerifyAll() + self.mox.ResetAll() + + def ReadGroupInto_FailureTestHelper(self, bytes_read): + d = decoder.Decoder('') + d._stream = self.mock_stream + sub_buffer = object() + self.mock_stream.GetSubBuffer().AndReturn(sub_buffer) + self.mock_message.MergeFromString(sub_buffer).AndReturn(bytes_read) + return d + + def testReadGroupInto_NegativeBytesReported(self): + expected_field_number = 10 + d = self.ReadGroupInto_FailureTestHelper(bytes_read=-1) + self.mox.ReplayAll() + self.assertRaises(message.DecodeError, + d.ReadGroupInto, expected_field_number, + self.mock_message) + self.mox.VerifyAll() + + def testReadGroupInto_NoEndGroupTag(self): + field_number = expected_field_number = 10 + num_bytes = 5 + d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes) + self.mock_stream.SkipBytes(num_bytes) + # Right field number, wrong wire type. + self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag( + field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)) + self.mox.ReplayAll() + self.assertRaises(message.DecodeError, + d.ReadGroupInto, expected_field_number, + self.mock_message) + self.mox.VerifyAll() + + def testReadGroupInto_WrongFieldNumberInEndGroupTag(self): + expected_field_number = 10 + field_number = expected_field_number + 1 + num_bytes = 5 + d = self.ReadGroupInto_FailureTestHelper(bytes_read=num_bytes) + self.mock_stream.SkipBytes(num_bytes) + # Wrong field number, right wire type. 
+ self.mock_stream.ReadVarUInt32().AndReturn(wire_format.PackTag( + field_number, wire_format.WIRETYPE_END_GROUP)) + self.mox.ReplayAll() + self.assertRaises(message.DecodeError, + d.ReadGroupInto, expected_field_number, + self.mock_message) + self.mox.VerifyAll() + + def testSkipBytes(self): + d = decoder.Decoder('') + num_bytes = 1024 + self.mock_stream.SkipBytes(num_bytes) + d._stream = self.mock_stream + self.mox.ReplayAll() + d.SkipBytes(num_bytes) + self.mox.VerifyAll() + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py new file mode 100755 index 00000000..625d0326 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_test.py @@ -0,0 +1,97 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Unittest for google.protobuf.internal.descriptor.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor + +class DescriptorTest(unittest.TestCase): + + def setUp(self): + self.my_enum = descriptor.EnumDescriptor( + name='ForeignEnum', + full_name='protobuf_unittest.ForeignEnum', + filename='ForeignEnum', + values=[ + descriptor.EnumValueDescriptor(name='FOREIGN_FOO', index=0, number=4), + descriptor.EnumValueDescriptor(name='FOREIGN_BAR', index=1, number=5), + descriptor.EnumValueDescriptor(name='FOREIGN_BAZ', index=2, number=6), + ]) + self.my_message = descriptor.Descriptor( + name='NestedMessage', + full_name='protobuf_unittest.TestAllTypes.NestedMessage', + filename='some/filename/some.proto', + containing_type=None, + fields=[ + descriptor.FieldDescriptor( + name='bb', + full_name='protobuf_unittest.TestAllTypes.NestedMessage.bb', + index=0, number=1, + type=5, cpp_type=1, label=1, + default_value=0, + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None), + ], + nested_types=[], + enum_types=[ + self.my_enum, + ], + extensions=[]) + self.my_method = descriptor.MethodDescriptor( + name='Bar', + full_name='protobuf_unittest.TestService.Bar', + index=0, + containing_service=None, + input_type=None, + output_type=None) + self.my_service = descriptor.ServiceDescriptor( + name='TestServiceWithOptions', + full_name='protobuf_unittest.TestServiceWithOptions', + index=0, + methods=[ + self.my_method + ]) + + def testEnumFixups(self): + self.assertEqual(self.my_enum, self.my_enum.values[0].type) + + def testContainingTypeFixups(self): + self.assertEqual(self.my_message, self.my_message.fields[0].containing_type) + self.assertEqual(self.my_message, self.my_enum.containing_type) + + def testContainingServiceFixups(self): + self.assertEqual(self.my_service, self.my_method.containing_service) + + def 
testGetOptions(self): + self.assertEqual(self.my_enum.GetOptions(), + descriptor_pb2.EnumOptions()) + self.assertEqual(self.my_enum.values[0].GetOptions(), + descriptor_pb2.EnumValueOptions()) + self.assertEqual(self.my_message.GetOptions(), + descriptor_pb2.MessageOptions()) + self.assertEqual(self.my_message.fields[0].GetOptions(), + descriptor_pb2.FieldOptions()) + self.assertEqual(self.my_method.GetOptions(), + descriptor_pb2.MethodOptions()) + self.assertEqual(self.my_service.GetOptions(), + descriptor_pb2.ServiceOptions()) + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py new file mode 100755 index 00000000..29c78b23 --- /dev/null +++ b/python/google/protobuf/internal/encoder.py @@ -0,0 +1,192 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Class for encoding protocol message primitives. + +Contains the logic for encoding every logical protocol field type +into one of the 5 physical wire types. 
+""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from google.protobuf import message +from google.protobuf.internal import wire_format +from google.protobuf.internal import output_stream + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by WireFormat from the C++ proto2 +# implementation. + + +class Encoder(object): + + """Encodes logical protocol buffer fields to the wire format.""" + + def __init__(self): + self._stream = output_stream.OutputStream() + + def ToString(self): + """Returns all values encoded in this object as a string.""" + return self._stream.ToString() + + # All the Append*() methods below first append a tag+type pair to the buffer + # before appending the specified value. + + def AppendInt32(self, field_number, value): + """Appends a 32-bit integer to our buffer, varint-encoded.""" + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self._stream.AppendVarint32(value) + + def AppendInt64(self, field_number, value): + """Appends a 64-bit integer to our buffer, varint-encoded.""" + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self._stream.AppendVarint64(value) + + def AppendUInt32(self, field_number, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, varint-encoded.""" + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self._stream.AppendVarUInt32(unsigned_value) + + def AppendUInt64(self, field_number, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, varint-encoded.""" + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + self._stream.AppendVarUInt64(unsigned_value) + + def AppendSInt32(self, field_number, value): + """Appends a 32-bit integer to our buffer, zigzag-encoded and then + varint-encoded. 
+ """ + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + zigzag_value = wire_format.ZigZagEncode(value) + self._stream.AppendVarUInt32(zigzag_value) + + def AppendSInt64(self, field_number, value): + """Appends a 64-bit integer to our buffer, zigzag-encoded and then + varint-encoded. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_VARINT) + zigzag_value = wire_format.ZigZagEncode(value) + self._stream.AppendVarUInt64(zigzag_value) + + def AppendFixed32(self, field_number, unsigned_value): + """Appends an unsigned 32-bit integer to our buffer, in little-endian + byte-order. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self._stream.AppendLittleEndian32(unsigned_value) + + def AppendFixed64(self, field_number, unsigned_value): + """Appends an unsigned 64-bit integer to our buffer, in little-endian + byte-order. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self._stream.AppendLittleEndian64(unsigned_value) + + def AppendSFixed32(self, field_number, value): + """Appends a signed 32-bit integer to our buffer, in little-endian + byte-order. + """ + sign = (value & 0x80000000) and -1 or 0 + if value >> 32 != sign: + raise message.EncodeError('SFixed32 out of range: %d' % value) + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self._stream.AppendLittleEndian32(value & 0xffffffff) + + def AppendSFixed64(self, field_number, value): + """Appends a signed 64-bit integer to our buffer, in little-endian + byte-order. 
+ """ + sign = (value & 0x8000000000000000) and -1 or 0 + if value >> 64 != sign: + raise message.EncodeError('SFixed64 out of range: %d' % value) + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self._stream.AppendLittleEndian64(value & 0xffffffffffffffff) + + def AppendFloat(self, field_number, value): + """Appends a floating-point number to our buffer.""" + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED32) + self._stream.AppendRawBytes(struct.pack('f', value)) + + def AppendDouble(self, field_number, value): + """Appends a double-precision floating-point number to our buffer.""" + self._AppendTag(field_number, wire_format.WIRETYPE_FIXED64) + self._stream.AppendRawBytes(struct.pack('d', value)) + + def AppendBool(self, field_number, value): + """Appends a boolean to our buffer.""" + self.AppendInt32(field_number, value) + + def AppendEnum(self, field_number, value): + """Appends an enum value to our buffer.""" + self.AppendInt32(field_number, value) + + def AppendString(self, field_number, value): + """Appends a length-prefixed string to our buffer, with the + length varint-encoded. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + self._stream.AppendVarUInt32(len(value)) + self._stream.AppendRawBytes(value) + + def AppendBytes(self, field_number, value): + """Appends a length-prefixed sequence of bytes to our buffer, with the + length varint-encoded. + """ + self.AppendString(field_number, value) + + # TODO(robinson): For AppendGroup() and AppendMessage(), we'd really like to + # avoid the extra string copy here. We can do so if we widen the Message + # interface to be able to serialize to a stream in addition to a string. The + # challenge when thinking ahead to the Python/C API implementation of Message + # is finding a stream-like Python thing to which we can write raw bytes + # from C. I'm not sure such a thing exists(?). 
(array.array is pretty much + # what we want, but it's not directly exposed in the Python/C API). + + def AppendGroup(self, field_number, group): + """Appends a group to our buffer. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_START_GROUP) + self._stream.AppendRawBytes(group.SerializeToString()) + self._AppendTag(field_number, wire_format.WIRETYPE_END_GROUP) + + def AppendMessage(self, field_number, msg): + """Appends a nested message to our buffer. + """ + self._AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + self._stream.AppendVarUInt32(msg.ByteSize()) + self._stream.AppendRawBytes(msg.SerializeToString()) + + def AppendMessageSetItem(self, field_number, msg): + """Appends an item using the message set wire format. + + The message set message looks like this: + message MessageSet { + repeated group Item = 1 { + required int32 type_id = 2; + required string message = 3; + } + } + """ + self._AppendTag(1, wire_format.WIRETYPE_START_GROUP) + self.AppendInt32(2, field_number) + self.AppendMessage(3, msg) + self._AppendTag(1, wire_format.WIRETYPE_END_GROUP) + + def _AppendTag(self, field_number, wire_type): + """Appends a tag containing field number and wire type information.""" + self._stream.AppendVarUInt32(wire_format.PackTag(field_number, wire_type)) diff --git a/python/google/protobuf/internal/encoder_test.py b/python/google/protobuf/internal/encoder_test.py new file mode 100755 index 00000000..5d690da7 --- /dev/null +++ b/python/google/protobuf/internal/encoder_test.py @@ -0,0 +1,211 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for google.protobuf.internal.encoder.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +import logging +import unittest +import mox +from google.protobuf.internal import wire_format +from google.protobuf.internal import encoder +from google.protobuf.internal import output_stream +from google.protobuf import message + + +class EncoderTest(unittest.TestCase): + + def setUp(self): + self.mox = mox.Mox() + self.encoder = encoder.Encoder() + self.mock_stream = self.mox.CreateMock(output_stream.OutputStream) + self.mock_message = self.mox.CreateMock(message.Message) + self.encoder._stream = self.mock_stream + + def PackTag(self, field_number, wire_type): + return wire_format.PackTag(field_number, wire_type) + + def AppendScalarTestHelper(self, test_name, encoder_method, + expected_stream_method_name, + wire_type, field_value, expected_value=None): + """Helper for testAppendScalars. + + Calls one of the Encoder methods, and ensures that the Encoder + in turn makes the expected calls into its OutputStream. + + Args: + test_name: Name of this test, used only for logging. + encoder_method: Callable on self.encoder, which should + accept |field_value| as an argument. This is the Encoder + method we're testing. + expected_stream_method_name: (string) Name of the OutputStream + method we expect Encoder to call to actually put the value + on the wire. + wire_type: The WIRETYPE_* constant we expect encoder to + use in the specified encoder_method. + field_value: The value we're trying to encode. Passed + into encoder_method. 
+ expected_value: The value we expect Encoder to pass into + the OutputStream method. If None, we expect field_value + to pass through unmodified. + """ + if expected_value is None: + expected_value = field_value + + logging.info('Testing %s scalar output.\n' + 'Calling %r(%r), and expecting that to call the ' + 'stream method %s(%r).' % ( + test_name, encoder_method, field_value, + expected_stream_method_name, expected_value)) + + field_number = 10 + # Should first append the field number and type information. + self.mock_stream.AppendVarUInt32(self.PackTag(field_number, wire_type)) + # If we're length-delimited, we should then append the length. + if wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: + self.mock_stream.AppendVarUInt32(len(field_value)) + # Should then append the value itself. + # We have to use names instead of methods to work around some + # mox weirdness. (ResetAll() is overzealous). + expected_stream_method = getattr(self.mock_stream, + expected_stream_method_name) + expected_stream_method(expected_value) + + self.mox.ReplayAll() + encoder_method(field_number, field_value) + self.mox.VerifyAll() + self.mox.ResetAll() + + def testAppendScalars(self): + scalar_tests = [ + ['int32', self.encoder.AppendInt32, 'AppendVarint32', + wire_format.WIRETYPE_VARINT, 0], + ['int64', self.encoder.AppendInt64, 'AppendVarint64', + wire_format.WIRETYPE_VARINT, 0], + ['uint32', self.encoder.AppendUInt32, 'AppendVarUInt32', + wire_format.WIRETYPE_VARINT, 0], + ['uint64', self.encoder.AppendUInt64, 'AppendVarUInt64', + wire_format.WIRETYPE_VARINT, 0], + ['fixed32', self.encoder.AppendFixed32, 'AppendLittleEndian32', + wire_format.WIRETYPE_FIXED32, 0], + ['fixed64', self.encoder.AppendFixed64, 'AppendLittleEndian64', + wire_format.WIRETYPE_FIXED64, 0], + ['sfixed32', self.encoder.AppendSFixed32, 'AppendLittleEndian32', + wire_format.WIRETYPE_FIXED32, -1, 0xffffffff], + ['sfixed64', self.encoder.AppendSFixed64, 'AppendLittleEndian64', + 
wire_format.WIRETYPE_FIXED64, -1, 0xffffffffffffffff], + ['float', self.encoder.AppendFloat, 'AppendRawBytes', + wire_format.WIRETYPE_FIXED32, 0.0, struct.pack('f', 0.0)], + ['double', self.encoder.AppendDouble, 'AppendRawBytes', + wire_format.WIRETYPE_FIXED64, 0.0, struct.pack('d', 0.0)], + ['bool', self.encoder.AppendBool, 'AppendVarint32', + wire_format.WIRETYPE_VARINT, False], + ['enum', self.encoder.AppendEnum, 'AppendVarint32', + wire_format.WIRETYPE_VARINT, 0], + ['string', self.encoder.AppendString, 'AppendRawBytes', + wire_format.WIRETYPE_LENGTH_DELIMITED, + "You're in a maze of twisty little passages, all alike."], + # We test zigzag encoding routines more extensively below. + ['sint32', self.encoder.AppendSInt32, 'AppendVarUInt32', + wire_format.WIRETYPE_VARINT, -1, 1], + ['sint64', self.encoder.AppendSInt64, 'AppendVarUInt64', + wire_format.WIRETYPE_VARINT, -1, 1], + ] + # Ensure that we're testing different Encoder methods and using + # different test names in all test cases above. + self.assertEqual(len(scalar_tests), len(set(t[0] for t in scalar_tests))) + self.assertEqual(len(scalar_tests), len(set(t[1] for t in scalar_tests))) + for args in scalar_tests: + self.AppendScalarTestHelper(*args) + + def testAppendGroup(self): + field_number = 23 + # Should first append the start-group marker. + self.mock_stream.AppendVarUInt32( + self.PackTag(field_number, wire_format.WIRETYPE_START_GROUP)) + # Should then serialize itself. + self.mock_message.SerializeToString().AndReturn('foo') + self.mock_stream.AppendRawBytes('foo') + # Should finally append the end-group marker. + self.mock_stream.AppendVarUInt32( + self.PackTag(field_number, wire_format.WIRETYPE_END_GROUP)) + + self.mox.ReplayAll() + self.encoder.AppendGroup(field_number, self.mock_message) + self.mox.VerifyAll() + + def testAppendMessage(self): + field_number = 23 + byte_size = 42 + # Should first append the field number and type information. 
+ self.mock_stream.AppendVarUInt32( + self.PackTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED)) + # Should then append its length. + self.mock_message.ByteSize().AndReturn(byte_size) + self.mock_stream.AppendVarUInt32(byte_size) + # Should then serialize itself to the encoder. + self.mock_message.SerializeToString().AndReturn('foo') + self.mock_stream.AppendRawBytes('foo') + + self.mox.ReplayAll() + self.encoder.AppendMessage(field_number, self.mock_message) + self.mox.VerifyAll() + + def testAppendMessageSetItem(self): + field_number = 23 + byte_size = 42 + # Should first append the field number and type information. + self.mock_stream.AppendVarUInt32( + self.PackTag(1, wire_format.WIRETYPE_START_GROUP)) + self.mock_stream.AppendVarUInt32( + self.PackTag(2, wire_format.WIRETYPE_VARINT)) + self.mock_stream.AppendVarint32(field_number) + self.mock_stream.AppendVarUInt32( + self.PackTag(3, wire_format.WIRETYPE_LENGTH_DELIMITED)) + # Should then append its length. + self.mock_message.ByteSize().AndReturn(byte_size) + self.mock_stream.AppendVarUInt32(byte_size) + # Should then serialize itself to the encoder. + self.mock_message.SerializeToString().AndReturn('foo') + self.mock_stream.AppendRawBytes('foo') + self.mock_stream.AppendVarUInt32( + self.PackTag(1, wire_format.WIRETYPE_END_GROUP)) + + self.mox.ReplayAll() + self.encoder.AppendMessageSetItem(field_number, self.mock_message) + self.mox.VerifyAll() + + def testAppendSFixed(self): + # Most of our bounds-checking is done in output_stream.py, + # but encoder.py is responsible for transforming signed + # fixed-width integers into unsigned ones, so we test here + # to ensure that we're not losing any entropy when we do + # that conversion. 
+ field_number = 10 + self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32, + 10, wire_format.UINT32_MAX + 1) + self.assertRaises(message.EncodeError, self.encoder.AppendSFixed32, + 10, -(1 << 32)) + self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64, + 10, wire_format.UINT64_MAX + 1) + self.assertRaises(message.EncodeError, self.encoder.AppendSFixed64, + 10, -(1 << 64)) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py new file mode 100755 index 00000000..02f993f7 --- /dev/null +++ b/python/google/protobuf/internal/generator_test.py @@ -0,0 +1,84 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO(robinson): Flesh this out considerably. We focused on reflection_test.py +# first, since it's testing the subtler code, and since it provides decent +# indirect testing of the protocol compiler output. + +"""Unittest that directly tests the output of the pure-Python protocol +compiler. See //net/proto2/internal/reflection_test.py for a test which +further ensures that we can use Python protocol message objects as we expect. 
+""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 + + +class GeneratorTest(unittest.TestCase): + + def testNestedMessageDescriptor(self): + field_name = 'optional_nested_message' + proto_type = unittest_pb2.TestAllTypes + self.assertEqual( + proto_type.NestedMessage.DESCRIPTOR, + proto_type.DESCRIPTOR.fields_by_name[field_name].message_type) + + def testEnums(self): + # We test only module-level enums here. + # TODO(robinson): Examine descriptors directly to check + # enum descriptor output. + self.assertEqual(4, unittest_pb2.FOREIGN_FOO) + self.assertEqual(5, unittest_pb2.FOREIGN_BAR) + self.assertEqual(6, unittest_pb2.FOREIGN_BAZ) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(1, proto.FOO) + self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) + self.assertEqual(2, proto.BAR) + self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) + self.assertEqual(3, proto.BAZ) + self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + + def testContainingTypeBehaviorForExtensions(self): + self.assertEqual(unittest_pb2.optional_int32_extension.containing_type, + unittest_pb2.TestAllExtensions.DESCRIPTOR) + self.assertEqual(unittest_pb2.TestRequired.single.containing_type, + unittest_pb2.TestAllExtensions.DESCRIPTOR) + + def testExtensionScope(self): + self.assertEqual(unittest_pb2.optional_int32_extension.extension_scope, + None) + self.assertEqual(unittest_pb2.TestRequired.single.extension_scope, + unittest_pb2.TestRequired.DESCRIPTOR) + + def testIsExtension(self): + self.assertTrue(unittest_pb2.optional_int32_extension.is_extension) + self.assertTrue(unittest_pb2.TestRequired.single.is_extension) + + message_descriptor = unittest_pb2.TestRequired.DESCRIPTOR + non_extension_descriptor = message_descriptor.fields_by_name['a'] + self.assertTrue(not non_extension_descriptor.is_extension) + + def testOptions(self): + proto = unittest_mset_pb2.TestMessageSet() + 
self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/input_stream.py b/python/google/protobuf/internal/input_stream.py new file mode 100755 index 00000000..9f3b0f5a --- /dev/null +++ b/python/google/protobuf/internal/input_stream.py @@ -0,0 +1,211 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""InputStream is the primitive interface for reading bits from the wire. + +All protocol buffer deserialization can be expressed in terms of +the InputStream primitives provided here. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from google.protobuf import message +from google.protobuf.internal import wire_format + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by CodedInputStream from the C++ +# proto2 implementation. + + +class InputStream(object): + + """Contains all logic for reading bits, and dealing with stream position. + + If an InputStream method ever raises an exception, the stream is left + in an indeterminate state and is not safe for further use. + """ + + def __init__(self, s): + # What we really want is something like array('B', s), where elements we + # read from the array are already given to us as one-byte integers. 
BUT + # using array() instead of buffer() would force full string copies to result + # from each GetSubBuffer() call. + # + # So, if the N serialized bytes of a single protocol buffer object are + # split evenly between 2 child messages, and so on recursively, using + # array('B', s) instead of buffer() would incur an additional N*logN bytes + # copied during deserialization. + # + # The higher constant overhead of having to ord() for every byte we read + # from the buffer in _ReadVarintHelper() could definitely lead to worse + # performance in many real-world scenarios, even if the asymptotic + # complexity is better. However, our real answer is that the mythical + # Python/C extension module output mode for the protocol compiler will + # be blazing-fast and will eliminate most use of this class anyway. + self._buffer = buffer(s) + self._pos = 0 + + def EndOfStream(self): + """Returns true iff we're at the end of the stream. + If this returns true, then a call to any other InputStream method + will raise an exception. + """ + return self._pos >= len(self._buffer) + + def Position(self): + """Returns the current position in the stream, or equivalently, the + number of bytes read so far. + """ + return self._pos + + def GetSubBuffer(self, size=None): + """Returns a sequence-like object that represents a portion of our + underlying sequence. + + Position 0 in the returned object corresponds to self.Position() + in this stream. + + If size is specified, then the returned object ends after the + next "size" bytes in this stream. If size is not specified, + then the returned object ends at the end of this stream. + + We guarantee that the returned object R supports the Python buffer + interface (and thus that the call buffer(R) will work). + + Note that the returned buffer is read-only. 
+ + The intended use for this method is for nested-message and nested-group + deserialization, where we want to make a recursive MergeFromString() + call on the portion of the original sequence that contains the serialized + nested message. (And we'd like to do so without making unnecessary string + copies). + + REQUIRES: size is nonnegative. + """ + # Note that buffer() doesn't perform any actual string copy. + if size is None: + return buffer(self._buffer, self._pos) + else: + if size < 0: + raise message.DecodeError('Negative size %d' % size) + return buffer(self._buffer, self._pos, size) + + def SkipBytes(self, num_bytes): + """Skip num_bytes bytes ahead, or go to the end of the stream, whichever + comes first. + + REQUIRES: num_bytes is nonnegative. + """ + if num_bytes < 0: + raise message.DecodeError('Negative num_bytes %d' % num_bytes) + self._pos += num_bytes + self._pos = min(self._pos, len(self._buffer)) + + def ReadString(self, size): + """Reads up to 'size' bytes from the stream, stopping early + only if we reach the end of the stream. Returns the bytes read + as a string. + """ + if size < 0: + raise message.DecodeError('Negative size %d' % size) + s = (self._buffer[self._pos : self._pos + size]) + self._pos += len(s) # Only advance by the number of bytes actually read. + return s + + def ReadLittleEndian32(self): + """Interprets the next 4 bytes of the stream as a little-endian + encoded, unsigned 32-bit integer, and returns that integer. + """ + try: + i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 4]) + self._pos += 4 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadLittleEndian64(self): + """Interprets the next 8 bytes of the stream as a little-endian + encoded, unsigned 64-bit integer, and returns that integer. 
+ """ + try: + i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 8]) + self._pos += 8 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadVarint32(self): + """Reads a varint from the stream, interprets this varint + as a signed, 32-bit integer, and returns the integer. + """ + i = self.ReadVarint64() + if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX: + raise message.DecodeError('Value out of range for int32: %d' % i) + return int(i) + + def ReadVarUInt32(self): + """Reads a varint from the stream, interprets this varint + as an unsigned, 32-bit integer, and returns the integer. + """ + i = self.ReadVarUInt64() + if i > wire_format.UINT32_MAX: + raise message.DecodeError('Value out of range for uint32: %d' % i) + return i + + def ReadVarint64(self): + """Reads a varint from the stream, interprets this varint + as a signed, 64-bit integer, and returns the integer. + """ + i = self.ReadVarUInt64() + if i > wire_format.INT64_MAX: + i -= (1 << 64) + return i + + def ReadVarUInt64(self): + """Reads a varint from the stream, interprets this varint + as an unsigned, 64-bit integer, and returns the integer. + """ + i = self._ReadVarintHelper() + if not 0 <= i <= wire_format.UINT64_MAX: + raise message.DecodeError('Value out of range for uint64: %d' % i) + return i + + def _ReadVarintHelper(self): + """Helper for the various varint-reading methods above. + Reads an unsigned, varint-encoded integer from the stream and + returns this integer. + + Does no bounds checking except to ensure that we read at most as many bytes + as could possibly be present in a varint-encoded 64-bit number. 
+ """ + result = 0 + shift = 0 + while 1: + if shift >= 64: + raise message.DecodeError('Too many bytes when decoding varint.') + try: + b = ord(self._buffer[self._pos]) + except IndexError: + raise message.DecodeError('Truncated varint.') + self._pos += 1 + result |= ((b & 0x7f) << shift) + shift += 7 + if not (b & 0x80): + return result diff --git a/python/google/protobuf/internal/input_stream_test.py b/python/google/protobuf/internal/input_stream_test.py new file mode 100755 index 00000000..2d685545 --- /dev/null +++ b/python/google/protobuf/internal/input_stream_test.py @@ -0,0 +1,279 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for google.protobuf.internal.input_stream.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import message +from google.protobuf.internal import wire_format +from google.protobuf.internal import input_stream + + +class InputStreamTest(unittest.TestCase): + + def testEndOfStream(self): + stream = input_stream.InputStream('abcd') + self.assertFalse(stream.EndOfStream()) + self.assertEqual('abcd', stream.ReadString(10)) + self.assertTrue(stream.EndOfStream()) + + def testPosition(self): + stream = input_stream.InputStream('abcd') + self.assertEqual(0, stream.Position()) + self.assertEqual(0, stream.Position()) # No side-effects. 
+    stream.ReadString(1)
+    self.assertEqual(1, stream.Position())
+    stream.ReadString(1)
+    self.assertEqual(2, stream.Position())
+    stream.ReadString(10)
+    self.assertEqual(4, stream.Position())  # Can't go past end of stream.
+
+  def testGetSubBuffer(self):
+    stream = input_stream.InputStream('abcd')
+    # Try leaving out the size.
+    self.assertEqual('abcd', str(stream.GetSubBuffer()))
+    stream.SkipBytes(1)
+    # GetSubBuffer() always starts at the current position.
+    self.assertEqual('bcd', str(stream.GetSubBuffer()))
+    # Try 0-size.
+    self.assertEqual('', str(stream.GetSubBuffer(0)))
+    # Negative sizes should raise an error.
+    self.assertRaises(message.DecodeError, stream.GetSubBuffer, -1)
+    # Positive sizes should work as expected.
+    self.assertEqual('b', str(stream.GetSubBuffer(1)))
+    self.assertEqual('bc', str(stream.GetSubBuffer(2)))
+    # Sizes longer than remaining bytes in the buffer should
+    # return the whole remaining buffer.
+    self.assertEqual('bcd', str(stream.GetSubBuffer(1000)))
+
+  def testSkipBytes(self):
+    stream = input_stream.InputStream('')
+    # Skipping bytes when at the end of stream
+    # should have no effect.
+    stream.SkipBytes(0)
+    stream.SkipBytes(1)
+    stream.SkipBytes(2)
+    self.assertTrue(stream.EndOfStream())
+    self.assertEqual(0, stream.Position())
+
+    # Try skipping within a stream.
+    stream = input_stream.InputStream('abcd')
+    self.assertEqual(0, stream.Position())
+    stream.SkipBytes(1)
+    self.assertEqual(1, stream.Position())
+    stream.SkipBytes(10)  # Can't skip past the end.
+    self.assertEqual(4, stream.Position())
+
+    # Ensure that a negative skip raises an exception.
+    stream = input_stream.InputStream('abcd')
+    stream.SkipBytes(1)
+    self.assertRaises(message.DecodeError, stream.SkipBytes, -1)
+
+  def testReadString(self):
+    s = 'abcd'
+    # Also test going past the total stream length.
+ for i in range(len(s) + 10): + stream = input_stream.InputStream(s) + self.assertEqual(s[:i], stream.ReadString(i)) + self.assertEqual(min(i, len(s)), stream.Position()) + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadString, -1) + + def EnsureFailureOnEmptyStream(self, input_stream_method): + """Helper for integer-parsing tests below. + Ensures that the given InputStream method raises a DecodeError + if called on a stream with no bytes remaining. + """ + stream = input_stream.InputStream('') + self.assertRaises(message.DecodeError, input_stream_method, stream) + + def testReadLittleEndian32(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian32) + s = '' + # Read 0. + s += '\x00\x00\x00\x00' + # Read 1. + s += '\x01\x00\x00\x00' + # Read a bunch of different bytes. + s += '\x01\x02\x03\x04' + # Read max unsigned 32-bit int. + s += '\xff\xff\xff\xff' + # Try a read with fewer than 4 bytes left in the stream. + s += '\x00\x00\x00' + stream = input_stream.InputStream(s) + self.assertEqual(0, stream.ReadLittleEndian32()) + self.assertEqual(4, stream.Position()) + self.assertEqual(1, stream.ReadLittleEndian32()) + self.assertEqual(8, stream.Position()) + self.assertEqual(0x04030201, stream.ReadLittleEndian32()) + self.assertEqual(12, stream.Position()) + self.assertEqual(wire_format.UINT32_MAX, stream.ReadLittleEndian32()) + self.assertEqual(16, stream.Position()) + self.assertRaises(message.DecodeError, stream.ReadLittleEndian32) + + def testReadLittleEndian64(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadLittleEndian64) + s = '' + # Read 0. + s += '\x00\x00\x00\x00\x00\x00\x00\x00' + # Read 1. + s += '\x01\x00\x00\x00\x00\x00\x00\x00' + # Read a bunch of different bytes. + s += '\x01\x02\x03\x04\x05\x06\x07\x08' + # Read max unsigned 64-bit int. + s += '\xff\xff\xff\xff\xff\xff\xff\xff' + # Try a read with fewer than 8 bytes left in the stream. 
+ s += '\x00\x00\x00' + stream = input_stream.InputStream(s) + self.assertEqual(0, stream.ReadLittleEndian64()) + self.assertEqual(8, stream.Position()) + self.assertEqual(1, stream.ReadLittleEndian64()) + self.assertEqual(16, stream.Position()) + self.assertEqual(0x0807060504030201, stream.ReadLittleEndian64()) + self.assertEqual(24, stream.Position()) + self.assertEqual(wire_format.UINT64_MAX, stream.ReadLittleEndian64()) + self.assertEqual(32, stream.Position()) + self.assertRaises(message.DecodeError, stream.ReadLittleEndian64) + + def ReadVarintSuccessTestHelper(self, varints_and_ints, read_method): + """Helper for tests below that test successful reads of various varints. + + Args: + varints_and_ints: Iterable of (str, integer) pairs, where the string + gives the wire encoding and the integer gives the value we expect + to be returned by the read_method upon encountering this string. + read_method: Unbound InputStream method that is capable of reading + the encoded strings provided in the first elements of varints_and_ints. + """ + s = ''.join(s for s, i in varints_and_ints) + stream = input_stream.InputStream(s) + expected_pos = 0 + self.assertEqual(expected_pos, stream.Position()) + for s, expected_int in varints_and_ints: + self.assertEqual(expected_int, read_method(stream)) + expected_pos += len(s) + self.assertEqual(expected_pos, stream.Position()) + + def testReadVarint32Success(self): + varints_and_ints = [ + ('\x00', 0), + ('\x01', 1), + ('\x7f', 127), + ('\x80\x01', 128), + ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1), + ('\xff\xff\xff\xff\x07', wire_format.INT32_MAX), + ('\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01', wire_format.INT32_MIN), + ] + self.ReadVarintSuccessTestHelper(varints_and_ints, + input_stream.InputStream.ReadVarint32) + + def testReadVarint32Failure(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint32) + + # Try and fail to read INT32_MAX + 1. 
+ s = '\x80\x80\x80\x80\x08' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint32) + + # Try and fail to read INT32_MIN - 1. + s = '\xfe\xff\xff\xff\xf7\xff\xff\xff\xff\x01' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint32) + + # Try and fail to read something that looks like + # a varint with more than 10 bytes. + s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint32) + + def testReadVarUInt32Success(self): + varints_and_ints = [ + ('\x00', 0), + ('\x01', 1), + ('\x7f', 127), + ('\x80\x01', 128), + ('\xff\xff\xff\xff\x0f', wire_format.UINT32_MAX), + ] + self.ReadVarintSuccessTestHelper(varints_and_ints, + input_stream.InputStream.ReadVarUInt32) + + def testReadVarUInt32Failure(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt32) + # Try and fail to read UINT32_MAX + 1 + s = '\x80\x80\x80\x80\x10' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarUInt32) + + # Try and fail to read something that looks like + # a varint with more than 10 bytes. + s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarUInt32) + + def testReadVarint64Success(self): + varints_and_ints = [ + ('\x00', 0), + ('\x01', 1), + ('\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01', -1), + ('\x7f', 127), + ('\x80\x01', 128), + ('\xff\xff\xff\xff\xff\xff\xff\xff\x7f', wire_format.INT64_MAX), + ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', wire_format.INT64_MIN), + ] + self.ReadVarintSuccessTestHelper(varints_and_ints, + input_stream.InputStream.ReadVarint64) + + def testReadVarint64Failure(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarint64) + # Try and fail to read something with the mythical 64th bit set. 
+ s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint64) + + # Try and fail to read something that looks like + # a varint with more than 10 bytes. + s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarint64) + + def testReadVarUInt64Success(self): + varints_and_ints = [ + ('\x00', 0), + ('\x01', 1), + ('\x7f', 127), + ('\x80\x01', 128), + ('\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01', 1 << 63), + ] + self.ReadVarintSuccessTestHelper(varints_and_ints, + input_stream.InputStream.ReadVarUInt64) + + def testReadVarUInt64Failure(self): + self.EnsureFailureOnEmptyStream(input_stream.InputStream.ReadVarUInt64) + # Try and fail to read something with the mythical 64th bit set. + s = '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarUInt64) + + # Try and fail to read something that looks like + # a varint with more than 10 bytes. + s = '\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\x00' + stream = input_stream.InputStream(s) + self.assertRaises(message.DecodeError, stream.ReadVarUInt64) + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/message_listener.py b/python/google/protobuf/internal/message_listener.py new file mode 100755 index 00000000..3747909e --- /dev/null +++ b/python/google/protobuf/internal/message_listener.py @@ -0,0 +1,55 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Defines a listener interface for observing certain +state transitions on Message objects. + +Also defines a null implementation of this interface. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + + +class MessageListener(object): + + """Listens for transitions to nonempty and for invalidations of cached + byte sizes. Meant to be registered via Message._SetListener(). + """ + + def TransitionToNonempty(self): + """Called the *first* time that this message becomes nonempty. + Implementations are free (but not required) to call this method multiple + times after the message has become nonempty. + """ + raise NotImplementedError + + def ByteSizeDirty(self): + """Called *every* time the cached byte size value + for this object is invalidated (transitions from being + "clean" to "dirty"). + """ + raise NotImplementedError + + +class NullMessageListener(object): + + """No-op MessageListener implementation.""" + + def TransitionToNonempty(self): + pass + + def ByteSizeDirty(self): + pass diff --git a/python/google/protobuf/internal/more_extensions.proto b/python/google/protobuf/internal/more_extensions.proto new file mode 100644 index 00000000..48df6f55 --- /dev/null +++ b/python/google/protobuf/internal/more_extensions.proto @@ -0,0 +1,44 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. +// http://code.google.com/p/protobuf/ +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Author: robinson@google.com (Will Robinson) + + +package google.protobuf.internal; + + +message TopLevelMessage { + optional ExtendedMessage submessage = 1; +} + + +message ExtendedMessage { + extensions 1 to max; +} + + +message ForeignMessage { + optional int32 foreign_message_int = 1; +} + + +extend ExtendedMessage { + optional int32 optional_int_extension = 1; + optional ForeignMessage optional_message_extension = 2; + + repeated int32 repeated_int_extension = 3; + repeated ForeignMessage repeated_message_extension = 4; +} diff --git a/python/google/protobuf/internal/more_messages.proto b/python/google/protobuf/internal/more_messages.proto new file mode 100644 index 00000000..bfa12273 --- /dev/null +++ b/python/google/protobuf/internal/more_messages.proto @@ -0,0 +1,37 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. +// http://code.google.com/p/protobuf/ +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Author: robinson@google.com (Will Robinson) + + +package google.protobuf.internal; + +// A message where tag numbers are listed out of order, to allow us to test our +// canonicalization of serialized output, which should always be in tag order. +// We also mix in some extensions for extra fun. +message OutOfOrderFields { + optional sint32 optional_sint32 = 5; + extensions 4 to 4; + optional uint32 optional_uint32 = 3; + extensions 2 to 2; + optional int32 optional_int32 = 1; +}; + + +extend OutOfOrderFields { + optional uint64 optional_uint64 = 4; + optional int64 optional_int64 = 2; +} diff --git a/python/google/protobuf/internal/output_stream.py b/python/google/protobuf/internal/output_stream.py new file mode 100755 index 00000000..767e9725 --- /dev/null +++ b/python/google/protobuf/internal/output_stream.py @@ -0,0 +1,112 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""OutputStream is the primitive interface for sticking bits on the wire. + +All protocol buffer serialization can be expressed in terms of +the OutputStream primitives provided here. 
+""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import array +import struct +from google.protobuf import message +from google.protobuf.internal import wire_format + + + +# Note that much of this code is ported from //net/proto/ProtocolBuffer, and +# that the interface is strongly inspired by CodedOutputStream from the C++ +# proto2 implementation. + + +class OutputStream(object): + + """Contains all logic for writing bits, and ToString() to get the result.""" + + def __init__(self): + self._buffer = array.array('B') + + def AppendRawBytes(self, raw_bytes): + """Appends raw_bytes to our internal buffer.""" + self._buffer.fromstring(raw_bytes) + + def AppendLittleEndian32(self, unsigned_value): + """Appends an unsigned 32-bit integer to the internal buffer, + in little-endian byte order. + """ + if not 0 <= unsigned_value <= wire_format.UINT32_MAX: + raise message.EncodeError( + 'Unsigned 32-bit out of range: %d' % unsigned_value) + self._buffer.fromstring(struct.pack( + wire_format.FORMAT_UINT32_LITTLE_ENDIAN, unsigned_value)) + + def AppendLittleEndian64(self, unsigned_value): + """Appends an unsigned 64-bit integer to the internal buffer, + in little-endian byte order. + """ + if not 0 <= unsigned_value <= wire_format.UINT64_MAX: + raise message.EncodeError( + 'Unsigned 64-bit out of range: %d' % unsigned_value) + self._buffer.fromstring(struct.pack( + wire_format.FORMAT_UINT64_LITTLE_ENDIAN, unsigned_value)) + + def AppendVarint32(self, value): + """Appends a signed 32-bit integer to the internal buffer, + encoded as a varint. (Note that a negative varint32 will + always require 10 bytes of space.) + """ + if not wire_format.INT32_MIN <= value <= wire_format.INT32_MAX: + raise message.EncodeError('Value out of range: %d' % value) + self.AppendVarint64(value) + + def AppendVarUInt32(self, value): + """Appends an unsigned 32-bit integer to the internal buffer, + encoded as a varint. 
+ """ + if not 0 <= value <= wire_format.UINT32_MAX: + raise message.EncodeError('Value out of range: %d' % value) + self.AppendVarUInt64(value) + + def AppendVarint64(self, value): + """Appends a signed 64-bit integer to the internal buffer, + encoded as a varint. + """ + if not wire_format.INT64_MIN <= value <= wire_format.INT64_MAX: + raise message.EncodeError('Value out of range: %d' % value) + if value < 0: + value += (1 << 64) + self.AppendVarUInt64(value) + + def AppendVarUInt64(self, unsigned_value): + """Appends an unsigned 64-bit integer to the internal buffer, + encoded as a varint. + """ + if not 0 <= unsigned_value <= wire_format.UINT64_MAX: + raise message.EncodeError('Value out of range: %d' % unsigned_value) + while True: + bits = unsigned_value & 0x7f + unsigned_value >>= 7 + if unsigned_value: + bits |= 0x80 + self._buffer.append(bits) + if not unsigned_value: + break + + def ToString(self): + """Returns a string containing the bytes in our internal buffer.""" + return self._buffer.tostring() diff --git a/python/google/protobuf/internal/output_stream_test.py b/python/google/protobuf/internal/output_stream_test.py new file mode 100755 index 00000000..026f6161 --- /dev/null +++ b/python/google/protobuf/internal/output_stream_test.py @@ -0,0 +1,162 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Test for google.protobuf.internal.output_stream.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import message +from google.protobuf.internal import output_stream +from google.protobuf.internal import wire_format + + +class OutputStreamTest(unittest.TestCase): + + def setUp(self): + self.stream = output_stream.OutputStream() + + def testAppendRawBytes(self): + # Empty string. + self.stream.AppendRawBytes('') + self.assertEqual('', self.stream.ToString()) + + # Nonempty string. + self.stream.AppendRawBytes('abc') + self.assertEqual('abc', self.stream.ToString()) + + # Ensure that we're actually appending. + self.stream.AppendRawBytes('def') + self.assertEqual('abcdef', self.stream.ToString()) + + def AppendNumericTestHelper(self, append_fn, values_and_strings): + """For each (value, expected_string) pair in values_and_strings, + calls an OutputStream.Append*(value) method on an OutputStream and ensures + that the string written to that stream matches expected_string. + + Args: + append_fn: Unbound OutputStream method that takes an integer or + long value as input. + values_and_strings: Iterable of (value, expected_string) pairs. + """ + for conversion in (int, long): + for value, string in values_and_strings: + stream = output_stream.OutputStream() + expected_string = '' + append_fn(stream, conversion(value)) + expected_string += string + self.assertEqual(expected_string, stream.ToString()) + + def AppendOverflowTestHelper(self, append_fn, value): + """Calls an OutputStream.Append*(value) method and asserts + that the method raises message.EncodeError. + + Args: + append_fn: Unbound OutputStream method that takes an integer or + long value as input. + value: Value to pass to append_fn which should cause an + message.EncodeError. 
+ """ + stream = output_stream.OutputStream() + self.assertRaises(message.EncodeError, append_fn, stream, value) + + def testAppendLittleEndian32(self): + append_fn = output_stream.OutputStream.AppendLittleEndian32 + values_and_expected_strings = [ + (0, '\x00\x00\x00\x00'), + (1, '\x01\x00\x00\x00'), + ((1 << 32) - 1, '\xff\xff\xff\xff'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, 1 << 32) + self.AppendOverflowTestHelper(append_fn, -1) + + def testAppendLittleEndian64(self): + append_fn = output_stream.OutputStream.AppendLittleEndian64 + values_and_expected_strings = [ + (0, '\x00\x00\x00\x00\x00\x00\x00\x00'), + (1, '\x01\x00\x00\x00\x00\x00\x00\x00'), + ((1 << 64) - 1, '\xff\xff\xff\xff\xff\xff\xff\xff'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, 1 << 64) + self.AppendOverflowTestHelper(append_fn, -1) + + def testAppendVarint32(self): + append_fn = output_stream.OutputStream.AppendVarint32 + values_and_expected_strings = [ + (0, '\x00'), + (1, '\x01'), + (127, '\x7f'), + (128, '\x80\x01'), + (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'), + (wire_format.INT32_MAX, '\xff\xff\xff\xff\x07'), + (wire_format.INT32_MIN, '\x80\x80\x80\x80\xf8\xff\xff\xff\xff\x01'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MAX + 1) + self.AppendOverflowTestHelper(append_fn, wire_format.INT32_MIN - 1) + + def testAppendVarUInt32(self): + append_fn = output_stream.OutputStream.AppendVarUInt32 + values_and_expected_strings = [ + (0, '\x00'), + (1, '\x01'), + (127, '\x7f'), + (128, '\x80\x01'), + (wire_format.UINT32_MAX, '\xff\xff\xff\xff\x0f'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, -1) + self.AppendOverflowTestHelper(append_fn, wire_format.UINT32_MAX + 1) + + 
def testAppendVarint64(self): + append_fn = output_stream.OutputStream.AppendVarint64 + values_and_expected_strings = [ + (0, '\x00'), + (1, '\x01'), + (127, '\x7f'), + (128, '\x80\x01'), + (-1, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'), + (wire_format.INT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\x7f'), + (wire_format.INT64_MIN, '\x80\x80\x80\x80\x80\x80\x80\x80\x80\x01'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MAX + 1) + self.AppendOverflowTestHelper(append_fn, wire_format.INT64_MIN - 1) + + def testAppendVarUInt64(self): + append_fn = output_stream.OutputStream.AppendVarUInt64 + values_and_expected_strings = [ + (0, '\x00'), + (1, '\x01'), + (127, '\x7f'), + (128, '\x80\x01'), + (wire_format.UINT64_MAX, '\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01'), + ] + self.AppendNumericTestHelper(append_fn, values_and_expected_strings) + + self.AppendOverflowTestHelper(append_fn, -1) + self.AppendOverflowTestHelper(append_fn, wire_format.UINT64_MAX + 1) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py new file mode 100755 index 00000000..5947f97a --- /dev/null +++ b/python/google/protobuf/internal/reflection_test.py @@ -0,0 +1,1300 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Unittest for reflection.py, which also indirectly tests the output of the +pure-Python protocol compiler. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import operator + +import unittest +# TODO(robinson): When we split this test in two, only some of these imports +# will be necessary in each test. +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf import descriptor_pb2 +from google.protobuf import descriptor +from google.protobuf import message +from google.protobuf import reflection +from google.protobuf.internal import more_extensions_pb2 +from google.protobuf.internal import more_messages_pb2 +from google.protobuf.internal import wire_format +from google.protobuf.internal import test_util +from google.protobuf.internal import decoder + + +class RefectionTest(unittest.TestCase): + + def testSimpleHasBits(self): + # Test a scalar. + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.HasField('optional_int32')) + self.assertEqual(0, proto.optional_int32) + # HasField() shouldn't be true if all we've done is + # read the default value. + self.assertTrue(not proto.HasField('optional_int32')) + proto.optional_int32 = 1 + # Setting a value however *should* set the "has" bit. + self.assertTrue(proto.HasField('optional_int32')) + proto.ClearField('optional_int32') + # And clearing that value should unset the "has" bit. + self.assertTrue(not proto.HasField('optional_int32')) + + def testHasBitsWithSinglyNestedScalar(self): + # Helper used to test foreign messages and groups. + # + # composite_field_name should be the name of a non-repeated + # composite (i.e., foreign or group) field in TestAllTypes, + # and scalar_field_name should be the name of an integer-valued + # scalar field within that composite. 
+    #
+    # I never thought I'd miss C++ macros and templates so much. :(
+    # This helper is semantically just:
+    #
+    #   assert proto.composite_field.scalar_field == 0
+    #   assert not proto.composite_field.HasField('scalar_field')
+    #   assert not proto.HasField('composite_field')
+    #
+    #   proto.composite_field.scalar_field = 10
+    #   old_composite_field = proto.composite_field
+    #
+    #   assert proto.composite_field.scalar_field == 10
+    #   assert proto.composite_field.HasField('scalar_field')
+    #   assert proto.HasField('composite_field')
+    #
+    #   proto.ClearField('composite_field')
+    #
+    #   assert not proto.composite_field.HasField('scalar_field')
+    #   assert not proto.HasField('composite_field')
+    #   assert proto.composite_field.scalar_field == 0
+    #
+    #   # Now ensure that ClearField('composite_field') disconnected
+    #   # the old field object from the object tree...
+    #   assert old_composite_field is not proto.composite_field
+    #   old_composite_field.scalar_field = 20
+    #   assert not proto.composite_field.HasField('scalar_field')
+    #   assert not proto.HasField('composite_field')
+    def TestCompositeHasBits(composite_field_name, scalar_field_name):
+      proto = unittest_pb2.TestAllTypes()
+      # First, check that we can get the scalar value, and see that it's the
+      # default (0), but that proto.HasField('composite') and
+      # proto.composite.HasField('scalar') will still return False.
+      composite_field = getattr(proto, composite_field_name)
+      original_scalar_value = getattr(composite_field, scalar_field_name)
+      self.assertEqual(0, original_scalar_value)
+      # Assert that the composite object does not "have" the scalar.
+      self.assertTrue(not composite_field.HasField(scalar_field_name))
+      # Assert that proto does not "have" the composite field.
+      self.assertTrue(not proto.HasField(composite_field_name))
+
+      # Now set the scalar within the composite field. Ensure that the setting
+      # is reflected, and that proto.HasField('composite') and
+      # proto.composite.HasField('scalar') now both return True.
+ new_val = 20 + setattr(composite_field, scalar_field_name, new_val) + self.assertEqual(new_val, getattr(composite_field, scalar_field_name)) + # Hold on to a reference to the current composite_field object. + old_composite_field = composite_field + # Assert that the has methods now return true. + self.assertTrue(composite_field.HasField(scalar_field_name)) + self.assertTrue(proto.HasField(composite_field_name)) + + # Now call the clear method... + proto.ClearField(composite_field_name) + + # ...and ensure that the "has" bits are all back to False... + composite_field = getattr(proto, composite_field_name) + self.assertTrue(not composite_field.HasField(scalar_field_name)) + self.assertTrue(not proto.HasField(composite_field_name)) + # ...and ensure that the scalar field has returned to its default. + self.assertEqual(0, getattr(composite_field, scalar_field_name)) + + # Finally, ensure that modifications to the old composite field object + # don't have any effect on the parent. + # + # (NOTE that when we clear the composite field in the parent, we actually + # don't recursively clear down the tree. Instead, we just disconnect the + # cleared composite from the tree.) + self.assertTrue(old_composite_field is not composite_field) + setattr(old_composite_field, scalar_field_name, new_val) + self.assertTrue(not composite_field.HasField(scalar_field_name)) + self.assertTrue(not proto.HasField(composite_field_name)) + self.assertEqual(0, getattr(composite_field, scalar_field_name)) + + # Test simple, single-level nesting when we set a scalar. 
+ TestCompositeHasBits('optionalgroup', 'a') + TestCompositeHasBits('optional_nested_message', 'bb') + TestCompositeHasBits('optional_foreign_message', 'c') + TestCompositeHasBits('optional_import_message', 'd') + + def testReferencesToNestedMessage(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + del proto + # A previous version had a bug where this would raise an exception when + # hitting a now-dead weak reference. + nested.bb = 23 + + def testDisconnectingNestedMessageBeforeSettingField(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.ClearField('optional_nested_message') # Should disconnect from parent + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + def testHasBitsWhenModifyingRepeatedFields(self): + # Test nesting when we add an element to a repeated field in a submessage. + proto = unittest_pb2.TestNestedMessageHasBits() + proto.optional_nested_message.nestedmessage_repeated_int32.append(5) + self.assertEqual( + [5], proto.optional_nested_message.nestedmessage_repeated_int32) + self.assertTrue(proto.HasField('optional_nested_message')) + + # Do the same test, but with a repeated composite field within the + # submessage. + proto.ClearField('optional_nested_message') + self.assertTrue(not proto.HasField('optional_nested_message')) + proto.optional_nested_message.nestedmessage_repeated_foreignmessage.add() + self.assertTrue(proto.HasField('optional_nested_message')) + + def testHasBitsForManyLevelsOfNesting(self): + # Test nesting many levels deep. 
+ recursive_proto = unittest_pb2.TestMutualRecursionA() + self.assertTrue(not recursive_proto.HasField('bb')) + self.assertEqual(0, recursive_proto.bb.a.bb.a.bb.optional_int32) + self.assertTrue(not recursive_proto.HasField('bb')) + recursive_proto.bb.a.bb.a.bb.optional_int32 = 5 + self.assertEqual(5, recursive_proto.bb.a.bb.a.bb.optional_int32) + self.assertTrue(recursive_proto.HasField('bb')) + self.assertTrue(recursive_proto.bb.HasField('a')) + self.assertTrue(recursive_proto.bb.a.HasField('bb')) + self.assertTrue(recursive_proto.bb.a.bb.HasField('a')) + self.assertTrue(recursive_proto.bb.a.bb.a.HasField('bb')) + self.assertTrue(not recursive_proto.bb.a.bb.a.bb.HasField('a')) + self.assertTrue(recursive_proto.bb.a.bb.a.bb.HasField('optional_int32')) + + def testSingularListFields(self): + proto = unittest_pb2.TestAllTypes() + proto.optional_fixed32 = 1 + proto.optional_int32 = 5 + proto.optional_string = 'foo' + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 5), + (proto.DESCRIPTOR.fields_by_name['optional_fixed32'], 1), + (proto.DESCRIPTOR.fields_by_name['optional_string' ], 'foo') ], + proto.ListFields()) + + def testRepeatedListFields(self): + proto = unittest_pb2.TestAllTypes() + proto.repeated_fixed32.append(1) + proto.repeated_int32.append(5) + proto.repeated_int32.append(11) + proto.repeated_string.append('foo') + proto.repeated_string.append('bar') + proto.repeated_string.append('baz') + proto.optional_int32 = 21 + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['optional_int32' ], 21), + (proto.DESCRIPTOR.fields_by_name['repeated_int32' ], [5, 11]), + (proto.DESCRIPTOR.fields_by_name['repeated_fixed32'], [1]), + (proto.DESCRIPTOR.fields_by_name['repeated_string' ], + ['foo', 'bar', 'baz']) ], + proto.ListFields()) + + def testSingularListExtensions(self): + proto = unittest_pb2.TestAllExtensions() + proto.Extensions[unittest_pb2.optional_fixed32_extension] = 1 + proto.Extensions[unittest_pb2.optional_int32_extension ] 
= 5 + proto.Extensions[unittest_pb2.optional_string_extension ] = 'foo' + self.assertEqual( + [ (unittest_pb2.optional_int32_extension , 5), + (unittest_pb2.optional_fixed32_extension, 1), + (unittest_pb2.optional_string_extension , 'foo') ], + proto.ListFields()) + + def testRepeatedListExtensions(self): + proto = unittest_pb2.TestAllExtensions() + proto.Extensions[unittest_pb2.repeated_fixed32_extension].append(1) + proto.Extensions[unittest_pb2.repeated_int32_extension ].append(5) + proto.Extensions[unittest_pb2.repeated_int32_extension ].append(11) + proto.Extensions[unittest_pb2.repeated_string_extension ].append('foo') + proto.Extensions[unittest_pb2.repeated_string_extension ].append('bar') + proto.Extensions[unittest_pb2.repeated_string_extension ].append('baz') + proto.Extensions[unittest_pb2.optional_int32_extension ] = 21 + self.assertEqual( + [ (unittest_pb2.optional_int32_extension , 21), + (unittest_pb2.repeated_int32_extension , [5, 11]), + (unittest_pb2.repeated_fixed32_extension, [1]), + (unittest_pb2.repeated_string_extension , ['foo', 'bar', 'baz']) ], + proto.ListFields()) + + def testListFieldsAndExtensions(self): + proto = unittest_pb2.TestFieldOrderings() + test_util.SetAllFieldsAndExtensions(proto) + unittest_pb2.my_extension_int + self.assertEqual( + [ (proto.DESCRIPTOR.fields_by_name['my_int' ], 1), + (unittest_pb2.my_extension_int , 23), + (proto.DESCRIPTOR.fields_by_name['my_string'], 'foo'), + (unittest_pb2.my_extension_string , 'bar'), + (proto.DESCRIPTOR.fields_by_name['my_float' ], 1.0) ], + proto.ListFields()) + + def testDefaultValues(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, proto.optional_int32) + self.assertEqual(0, proto.optional_int64) + self.assertEqual(0, proto.optional_uint32) + self.assertEqual(0, proto.optional_uint64) + self.assertEqual(0, proto.optional_sint32) + self.assertEqual(0, proto.optional_sint64) + self.assertEqual(0, proto.optional_fixed32) + self.assertEqual(0, proto.optional_fixed64) 
+ self.assertEqual(0, proto.optional_sfixed32) + self.assertEqual(0, proto.optional_sfixed64) + self.assertEqual(0.0, proto.optional_float) + self.assertEqual(0.0, proto.optional_double) + self.assertEqual(False, proto.optional_bool) + self.assertEqual('', proto.optional_string) + self.assertEqual('', proto.optional_bytes) + + self.assertEqual(41, proto.default_int32) + self.assertEqual(42, proto.default_int64) + self.assertEqual(43, proto.default_uint32) + self.assertEqual(44, proto.default_uint64) + self.assertEqual(-45, proto.default_sint32) + self.assertEqual(46, proto.default_sint64) + self.assertEqual(47, proto.default_fixed32) + self.assertEqual(48, proto.default_fixed64) + self.assertEqual(49, proto.default_sfixed32) + self.assertEqual(-50, proto.default_sfixed64) + self.assertEqual(51.5, proto.default_float) + self.assertEqual(52e3, proto.default_double) + self.assertEqual(True, proto.default_bool) + self.assertEqual('hello', proto.default_string) + self.assertEqual('world', proto.default_bytes) + self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) + self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) + self.assertEqual(unittest_import_pb2.IMPORT_BAR, + proto.default_import_enum) + + def testHasFieldWithUnknownFieldName(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, proto.HasField, 'nonexistent_field') + + def testClearFieldWithUnknownFieldName(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') + + def testDisallowedAssignments(self): + # It's illegal to assign values directly to repeated fields + # or to nonrepeated composite fields. Ensure that this fails. + proto = unittest_pb2.TestAllTypes() + # Repeated fields. + self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', 10) + # Lists shouldn't work, either. + self.assertRaises(AttributeError, setattr, proto, 'repeated_int32', [10]) + # Composite fields. 
+ self.assertRaises(AttributeError, setattr, proto, + 'optional_nested_message', 23) + # proto.nonexistent_field = 23 should fail as well. + self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) + + # TODO(robinson): Add type-safety check for enums. + def testSingleScalarTypeSafety(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) + self.assertRaises(TypeError, setattr, proto, 'optional_int32', 'foo') + self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) + self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + + def testSingleScalarBoundsChecking(self): + def TestMinAndMaxIntegers(field_name, expected_min, expected_max): + pb = unittest_pb2.TestAllTypes() + setattr(pb, field_name, expected_min) + setattr(pb, field_name, expected_max) + self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1) + self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1) + + TestMinAndMaxIntegers('optional_int32', -(1 << 31), (1 << 31) - 1) + TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) + TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) + TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) + TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1) + + def testRepeatedScalarTypeSafety(self): + proto = unittest_pb2.TestAllTypes() + self.assertRaises(TypeError, proto.repeated_int32.append, 1.1) + self.assertRaises(TypeError, proto.repeated_int32.append, 'foo') + self.assertRaises(TypeError, proto.repeated_string, 10) + self.assertRaises(TypeError, proto.repeated_bytes, 10) + + proto.repeated_int32.append(10) + proto.repeated_int32[0] = 23 + self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) + self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') + + def testSingleScalarGettersAndSetters(self): + proto = unittest_pb2.TestAllTypes() + 
self.assertEqual(0, proto.optional_int32) + proto.optional_int32 = 1 + self.assertEqual(1, proto.optional_int32) + # TODO(robinson): Test all other scalar field types. + + def testSingleScalarClearField(self): + proto = unittest_pb2.TestAllTypes() + # Should be allowed to clear something that's not there (a no-op). + proto.ClearField('optional_int32') + proto.optional_int32 = 1 + self.assertTrue(proto.HasField('optional_int32')) + proto.ClearField('optional_int32') + self.assertEqual(0, proto.optional_int32) + self.assertTrue(not proto.HasField('optional_int32')) + # TODO(robinson): Test all other scalar field types. + + def testEnums(self): + proto = unittest_pb2.TestAllTypes() + self.assertEqual(1, proto.FOO) + self.assertEqual(1, unittest_pb2.TestAllTypes.FOO) + self.assertEqual(2, proto.BAR) + self.assertEqual(2, unittest_pb2.TestAllTypes.BAR) + self.assertEqual(3, proto.BAZ) + self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + + def testRepeatedScalars(self): + proto = unittest_pb2.TestAllTypes() + + self.assertTrue(not proto.repeated_int32) + self.assertEqual(0, len(proto.repeated_int32)) + proto.repeated_int32.append(5); + proto.repeated_int32.append(10); + self.assertTrue(proto.repeated_int32) + self.assertEqual(2, len(proto.repeated_int32)) + + self.assertEqual([5, 10], proto.repeated_int32) + self.assertEqual(5, proto.repeated_int32[0]) + self.assertEqual(10, proto.repeated_int32[-1]) + # Test out-of-bounds indices. + self.assertRaises(IndexError, proto.repeated_int32.__getitem__, 1234) + self.assertRaises(IndexError, proto.repeated_int32.__getitem__, -1234) + # Test incorrect types passed to __getitem__. + self.assertRaises(TypeError, proto.repeated_int32.__getitem__, 'foo') + self.assertRaises(TypeError, proto.repeated_int32.__getitem__, None) + + # Test that we can use the field as an iterator. + result = [] + for i in proto.repeated_int32: + result.append(i) + self.assertEqual([5, 10], result) + + # Test clearing. 
+ proto.ClearField('repeated_int32') + self.assertTrue(not proto.repeated_int32) + self.assertEqual(0, len(proto.repeated_int32)) + + def testRepeatedComposites(self): + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.repeated_nested_message) + self.assertEqual(0, len(proto.repeated_nested_message)) + m0 = proto.repeated_nested_message.add() + m1 = proto.repeated_nested_message.add() + self.assertTrue(proto.repeated_nested_message) + self.assertEqual(2, len(proto.repeated_nested_message)) + self.assertTrue(m0 is proto.repeated_nested_message[0]) + self.assertTrue(m1 is proto.repeated_nested_message[1]) + self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) + + # Test out-of-bounds indices. + self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, + 1234) + self.assertRaises(IndexError, proto.repeated_nested_message.__getitem__, + -1234) + + # Test incorrect types passed to __getitem__. + self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, + 'foo') + self.assertRaises(TypeError, proto.repeated_nested_message.__getitem__, + None) + + # Test that we can use the field as an iterator. + result = [] + for i in proto.repeated_nested_message: + result.append(i) + self.assertEqual(2, len(result)) + self.assertTrue(m0 is result[0]) + self.assertTrue(m1 is result[1]) + + # Test clearing. + proto.ClearField('repeated_nested_message') + self.assertTrue(not proto.repeated_nested_message) + self.assertEqual(0, len(proto.repeated_nested_message)) + + def testHandWrittenReflection(self): + # TODO(robinson): We probably need a better way to specify + # protocol types by hand. But then again, this isn't something + # we expect many people to do. Hmm. 
+ FieldDescriptor = descriptor.FieldDescriptor + foo_field_descriptor = FieldDescriptor( + name='foo_field', full_name='MyProto.foo_field', + index=0, number=1, type=FieldDescriptor.TYPE_INT64, + cpp_type=FieldDescriptor.CPPTYPE_INT64, + label=FieldDescriptor.LABEL_OPTIONAL, default_value=0, + containing_type=None, message_type=None, enum_type=None, + is_extension=False, extension_scope=None, + options=descriptor_pb2.FieldOptions()) + mydescriptor = descriptor.Descriptor( + name='MyProto', full_name='MyProto', filename='ignored', + containing_type=None, nested_types=[], enum_types=[], + fields=[foo_field_descriptor], extensions=[], + options=descriptor_pb2.MessageOptions()) + class MyProtoClass(message.Message): + DESCRIPTOR = mydescriptor + __metaclass__ = reflection.GeneratedProtocolMessageType + myproto_instance = MyProtoClass() + self.assertEqual(0, myproto_instance.foo_field) + self.assertTrue(not myproto_instance.HasField('foo_field')) + myproto_instance.foo_field = 23 + self.assertEqual(23, myproto_instance.foo_field) + self.assertTrue(myproto_instance.HasField('foo_field')) + + def testTopLevelExtensionsForOptionalScalar(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.optional_int32_extension + self.assertTrue(not extendee_proto.HasExtension(extension)) + self.assertEqual(0, extendee_proto.Extensions[extension]) + # As with normal scalar fields, just doing a read doesn't actually set the + # "has" bit. + self.assertTrue(not extendee_proto.HasExtension(extension)) + # Actually set the thing. + extendee_proto.Extensions[extension] = 23 + self.assertEqual(23, extendee_proto.Extensions[extension]) + self.assertTrue(extendee_proto.HasExtension(extension)) + # Ensure that clearing works as well. 
+ extendee_proto.ClearExtension(extension) + self.assertEqual(0, extendee_proto.Extensions[extension]) + self.assertTrue(not extendee_proto.HasExtension(extension)) + + def testTopLevelExtensionsForRepeatedScalar(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.repeated_string_extension + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + extendee_proto.Extensions[extension].append('foo') + self.assertEqual(['foo'], extendee_proto.Extensions[extension]) + string_list = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + self.assertTrue(string_list is not extendee_proto.Extensions[extension]) + # Shouldn't be allowed to do Extensions[extension] = 'a' + self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, + extension, 'a') + + def testTopLevelExtensionsForOptionalMessage(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.optional_foreign_message_extension + self.assertTrue(not extendee_proto.HasExtension(extension)) + self.assertEqual(0, extendee_proto.Extensions[extension].c) + # As with normal (non-extension) fields, merely reading from the + # thing shouldn't set the "has" bit. + self.assertTrue(not extendee_proto.HasExtension(extension)) + extendee_proto.Extensions[extension].c = 23 + self.assertEqual(23, extendee_proto.Extensions[extension].c) + self.assertTrue(extendee_proto.HasExtension(extension)) + # Save a reference here. + foreign_message = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + self.assertTrue(foreign_message is not extendee_proto.Extensions[extension]) + # Setting a field on foreign_message now shouldn't set + # any "has" bits on extendee_proto. 
+ foreign_message.c = 42 + self.assertEqual(42, foreign_message.c) + self.assertTrue(foreign_message.HasField('c')) + self.assertTrue(not extendee_proto.HasExtension(extension)) + # Shouldn't be allowed to do Extensions[extension] = 'a' + self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, + extension, 'a') + + def testTopLevelExtensionsForRepeatedMessage(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.repeatedgroup_extension + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + group = extendee_proto.Extensions[extension].add() + group.a = 23 + self.assertEqual(23, extendee_proto.Extensions[extension][0].a) + group.a = 42 + self.assertEqual(42, extendee_proto.Extensions[extension][0].a) + group_list = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + self.assertEqual(0, len(extendee_proto.Extensions[extension])) + self.assertTrue(group_list is not extendee_proto.Extensions[extension]) + # Shouldn't be allowed to do Extensions[extension] = 'a' + self.assertRaises(TypeError, operator.setitem, extendee_proto.Extensions, + extension, 'a') + + def testNestedExtensions(self): + extendee_proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.TestRequired.single + + # We just test the non-repeated case. 
+ self.assertTrue(not extendee_proto.HasExtension(extension)) + required = extendee_proto.Extensions[extension] + self.assertEqual(0, required.a) + self.assertTrue(not extendee_proto.HasExtension(extension)) + required.a = 23 + self.assertEqual(23, extendee_proto.Extensions[extension].a) + self.assertTrue(extendee_proto.HasExtension(extension)) + extendee_proto.ClearExtension(extension) + self.assertTrue(required is not extendee_proto.Extensions[extension]) + self.assertTrue(not extendee_proto.HasExtension(extension)) + + # If message A directly contains message B, and + # a.HasField('b') is currently False, then mutating any + # extension in B should change a.HasField('b') to True + # (and so on up the object tree). + def testHasBitsForAncestorsOfExtendedMessage(self): + # Optional scalar extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual(0, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + self.assertTrue(not toplevel.HasField('submessage')) + toplevel.submessage.Extensions[ + more_extensions_pb2.optional_int_extension] = 23 + self.assertEqual(23, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + self.assertTrue(toplevel.HasField('submessage')) + + # Repeated scalar extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual([], toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_int_extension]) + self.assertTrue(not toplevel.HasField('submessage')) + toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_int_extension].append(23) + self.assertEqual([23], toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_int_extension]) + self.assertTrue(toplevel.HasField('submessage')) + + # Optional message extension. 
+ toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual(0, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_message_extension].foreign_message_int) + self.assertTrue(not toplevel.HasField('submessage')) + toplevel.submessage.Extensions[ + more_extensions_pb2.optional_message_extension].foreign_message_int = 23 + self.assertEqual(23, toplevel.submessage.Extensions[ + more_extensions_pb2.optional_message_extension].foreign_message_int) + self.assertTrue(toplevel.HasField('submessage')) + + # Repeated message extension. + toplevel = more_extensions_pb2.TopLevelMessage() + self.assertTrue(not toplevel.HasField('submessage')) + self.assertEqual(0, len(toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_message_extension])) + self.assertTrue(not toplevel.HasField('submessage')) + foreign = toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_message_extension].add() + self.assertTrue(foreign is toplevel.submessage.Extensions[ + more_extensions_pb2.repeated_message_extension][0]) + self.assertTrue(toplevel.HasField('submessage')) + + def testDisconnectionAfterClearingEmptyMessage(self): + toplevel = more_extensions_pb2.TopLevelMessage() + extendee_proto = toplevel.submessage + extension = more_extensions_pb2.optional_message_extension + extension_proto = extendee_proto.Extensions[extension] + extendee_proto.ClearExtension(extension) + extension_proto.foreign_message_int = 23 + + self.assertTrue(not toplevel.HasField('submessage')) + self.assertTrue(extension_proto is not extendee_proto.Extensions[extension]) + + def testExtensionFailureModes(self): + extendee_proto = unittest_pb2.TestAllExtensions() + + # Try non-extension-handle arguments to HasExtension, + # ClearExtension(), and Extensions[]... 
+ self.assertRaises(KeyError, extendee_proto.HasExtension, 1234) + self.assertRaises(KeyError, extendee_proto.ClearExtension, 1234) + self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, 1234) + self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, 1234, 5) + + # Try something that *is* an extension handle, just not for + # this message... + unknown_handle = more_extensions_pb2.optional_int_extension + self.assertRaises(KeyError, extendee_proto.HasExtension, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.ClearExtension, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.Extensions.__getitem__, + unknown_handle) + self.assertRaises(KeyError, extendee_proto.Extensions.__setitem__, + unknown_handle, 5) + + # Try call HasExtension() with a valid handle, but for a + # *repeated* field. (Just as with non-extension repeated + # fields, Has*() isn't supported for extension repeated fields). + self.assertRaises(KeyError, extendee_proto.HasExtension, + unittest_pb2.repeated_string_extension) + + def testCopyFrom(self): + # TODO(robinson): Implement. + pass + + def testClear(self): + proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto) + # Clear the message. + proto.Clear() + self.assertEquals(proto.ByteSize(), 0) + empty_proto = unittest_pb2.TestAllTypes() + self.assertEquals(proto, empty_proto) + + # Test if extensions which were set are cleared. + proto = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(proto) + # Clear the message. + proto.Clear() + self.assertEquals(proto.ByteSize(), 0) + empty_proto = unittest_pb2.TestAllExtensions() + self.assertEquals(proto, empty_proto) + + def testIsInitialized(self): + # Trivial cases - all optional fields and extensions. + proto = unittest_pb2.TestAllTypes() + self.assertTrue(proto.IsInitialized()) + proto = unittest_pb2.TestAllExtensions() + self.assertTrue(proto.IsInitialized()) + + # The case of uninitialized required fields. 
+ proto = unittest_pb2.TestRequired() + self.assertFalse(proto.IsInitialized()) + proto.a = proto.b = proto.c = 2 + self.assertTrue(proto.IsInitialized()) + + # The case of uninitialized submessage. + proto = unittest_pb2.TestRequiredForeign() + self.assertTrue(proto.IsInitialized()) + proto.optional_message.a = 1 + self.assertFalse(proto.IsInitialized()) + proto.optional_message.b = 0 + proto.optional_message.c = 0 + self.assertTrue(proto.IsInitialized()) + + # Uninitialized repeated submessage. + message1 = proto.repeated_message.add() + self.assertFalse(proto.IsInitialized()) + message1.a = message1.b = message1.c = 0 + self.assertTrue(proto.IsInitialized()) + + # Uninitialized repeated group in an extension. + proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.TestRequired.multi + message1 = proto.Extensions[extension].add() + message2 = proto.Extensions[extension].add() + self.assertFalse(proto.IsInitialized()) + message1.a = 1 + message1.b = 1 + message1.c = 1 + self.assertFalse(proto.IsInitialized()) + message2.a = 2 + message2.b = 2 + message2.c = 2 + self.assertTrue(proto.IsInitialized()) + + # Uninitialized nonrepeated message in an extension. + proto = unittest_pb2.TestAllExtensions() + extension = unittest_pb2.TestRequired.single + proto.Extensions[extension].a = 1 + self.assertFalse(proto.IsInitialized()) + proto.Extensions[extension].b = 2 + proto.Extensions[extension].c = 3 + self.assertTrue(proto.IsInitialized()) + + +# Since we had so many tests for protocol buffer equality, we broke these out +# into separate TestCase classes. 
+ + +class TestAllTypesEqualityTest(unittest.TestCase): + + def setUp(self): + self.first_proto = unittest_pb2.TestAllTypes() + self.second_proto = unittest_pb2.TestAllTypes() + + def testSelfEquality(self): + self.assertEqual(self.first_proto, self.first_proto) + + def testEmptyProtosEqual(self): + self.assertEqual(self.first_proto, self.second_proto) + + +class FullProtosEqualityTest(unittest.TestCase): + + """Equality tests using completely-full protos as a starting point.""" + + def setUp(self): + self.first_proto = unittest_pb2.TestAllTypes() + self.second_proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(self.first_proto) + test_util.SetAllFields(self.second_proto) + + def testAllFieldsFilledEquality(self): + self.assertEqual(self.first_proto, self.second_proto) + + def testNonRepeatedScalar(self): + # Nonrepeated scalar field change should cause inequality. + self.first_proto.optional_int32 += 1 + self.assertNotEqual(self.first_proto, self.second_proto) + # ...as should clearing a field. + self.first_proto.ClearField('optional_int32') + self.assertNotEqual(self.first_proto, self.second_proto) + + def testNonRepeatedComposite(self): + # Change a nonrepeated composite field. + self.first_proto.optional_nested_message.bb += 1 + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.optional_nested_message.bb -= 1 + self.assertEqual(self.first_proto, self.second_proto) + # Clear a field in the nested message. + self.first_proto.optional_nested_message.ClearField('bb') + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.optional_nested_message.bb = ( + self.second_proto.optional_nested_message.bb) + self.assertEqual(self.first_proto, self.second_proto) + # Remove the nested message entirely. + self.first_proto.ClearField('optional_nested_message') + self.assertNotEqual(self.first_proto, self.second_proto) + + def testRepeatedScalar(self): + # Change a repeated scalar field. 
+ self.first_proto.repeated_int32.append(5) + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.ClearField('repeated_int32') + self.assertNotEqual(self.first_proto, self.second_proto) + + def testRepeatedComposite(self): + # Change value within a repeated composite field. + self.first_proto.repeated_nested_message[0].bb += 1 + self.assertNotEqual(self.first_proto, self.second_proto) + self.first_proto.repeated_nested_message[0].bb -= 1 + self.assertEqual(self.first_proto, self.second_proto) + # Add a value to a repeated composite field. + self.first_proto.repeated_nested_message.add() + self.assertNotEqual(self.first_proto, self.second_proto) + self.second_proto.repeated_nested_message.add() + self.assertEqual(self.first_proto, self.second_proto) + + def testNonRepeatedScalarHasBits(self): + # Ensure that we test "has" bits as well as value for + # nonrepeated scalar field. + self.first_proto.ClearField('optional_int32') + self.second_proto.optional_int32 = 0 + self.assertNotEqual(self.first_proto, self.second_proto) + + def testNonRepeatedCompositeHasBits(self): + # Ensure that we test "has" bits as well as value for + # nonrepeated composite field. + self.first_proto.ClearField('optional_nested_message') + self.second_proto.optional_nested_message.ClearField('bb') + self.assertNotEqual(self.first_proto, self.second_proto) + # TODO(robinson): Replace next two lines with method + # to set the "has" bit without changing the value, + # if/when such a method exists. 
+ self.first_proto.optional_nested_message.bb = 0 + self.first_proto.optional_nested_message.ClearField('bb') + self.assertEqual(self.first_proto, self.second_proto) + + +class ExtensionEqualityTest(unittest.TestCase): + + def testExtensionEquality(self): + first_proto = unittest_pb2.TestAllExtensions() + second_proto = unittest_pb2.TestAllExtensions() + self.assertEqual(first_proto, second_proto) + test_util.SetAllExtensions(first_proto) + self.assertNotEqual(first_proto, second_proto) + test_util.SetAllExtensions(second_proto) + self.assertEqual(first_proto, second_proto) + + # Ensure that we check value equality. + first_proto.Extensions[unittest_pb2.optional_int32_extension] += 1 + self.assertNotEqual(first_proto, second_proto) + first_proto.Extensions[unittest_pb2.optional_int32_extension] -= 1 + self.assertEqual(first_proto, second_proto) + + # Ensure that we also look at "has" bits. + first_proto.ClearExtension(unittest_pb2.optional_int32_extension) + second_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 + self.assertNotEqual(first_proto, second_proto) + first_proto.Extensions[unittest_pb2.optional_int32_extension] = 0 + self.assertEqual(first_proto, second_proto) + + # Ensure that differences in cached values + # don't matter if "has" bits are both false. 
+ first_proto = unittest_pb2.TestAllExtensions() + second_proto = unittest_pb2.TestAllExtensions() + self.assertEqual( + 0, first_proto.Extensions[unittest_pb2.optional_int32_extension]) + self.assertEqual(first_proto, second_proto) + + +class MutualRecursionEqualityTest(unittest.TestCase): + + def testEqualityWithMutualRecursion(self): + first_proto = unittest_pb2.TestMutualRecursionA() + second_proto = unittest_pb2.TestMutualRecursionA() + self.assertEqual(first_proto, second_proto) + first_proto.bb.a.bb.optional_int32 = 23 + self.assertNotEqual(first_proto, second_proto) + second_proto.bb.a.bb.optional_int32 = 23 + self.assertEqual(first_proto, second_proto) + + +class ByteSizeTest(unittest.TestCase): + + def setUp(self): + self.proto = unittest_pb2.TestAllTypes() + self.extended_proto = more_extensions_pb2.ExtendedMessage() + + def Size(self): + return self.proto.ByteSize() + + def testEmptyMessage(self): + self.assertEqual(0, self.proto.ByteSize()) + + def testVarints(self): + def Test(i, expected_varint_size): + self.proto.Clear() + self.proto.optional_int64 = i + # Add one to the varint size for the tag info + # for tag 1. + self.assertEqual(expected_varint_size + 1, self.Size()) + Test(0, 1) + Test(1, 1) + for i, num_bytes in zip(range(7, 63, 7), range(1, 10000)): + Test((1 << i) - 1, num_bytes) + Test(-1, 10) + Test(-2, 10) + Test(-(1 << 63), 10) + + def testStrings(self): + self.proto.optional_string = '' + # Need one byte for tag info (tag #14), and one byte for length. + self.assertEqual(2, self.Size()) + + self.proto.optional_string = 'abc' + # Need one byte for tag info (tag #14), and one byte for length. + self.assertEqual(2 + len(self.proto.optional_string), self.Size()) + + self.proto.optional_string = 'x' * 128 + # Need one byte for tag info (tag #14), and TWO bytes for length. 
+ self.assertEqual(3 + len(self.proto.optional_string), self.Size()) + + def testOtherNumerics(self): + self.proto.optional_fixed32 = 1234 + # One byte for tag and 4 bytes for fixed32. + self.assertEqual(5, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_fixed64 = 1234 + # One byte for tag and 8 bytes for fixed64. + self.assertEqual(9, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_float = 1.234 + # One byte for tag and 4 bytes for float. + self.assertEqual(5, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_double = 1.234 + # One byte for tag and 8 bytes for float. + self.assertEqual(9, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + self.proto.optional_sint32 = 64 + # One byte for tag and 2 bytes for zig-zag-encoded 64. + self.assertEqual(3, self.Size()) + self.proto = unittest_pb2.TestAllTypes() + + def testComposites(self): + # 3 bytes. + self.proto.optional_nested_message.bb = (1 << 14) + # Plus one byte for bb tag. + # Plus 1 byte for optional_nested_message serialized size. + # Plus two bytes for optional_nested_message tag. + self.assertEqual(3 + 1 + 1 + 2, self.Size()) + + def testGroups(self): + # 4 bytes. + self.proto.optionalgroup.a = (1 << 21) + # Plus two bytes for |a| tag. + # Plus 2 * two bytes for START_GROUP and END_GROUP tags. + self.assertEqual(4 + 2 + 2*2, self.Size()) + + def testRepeatedScalars(self): + self.proto.repeated_int32.append(10) # 1 byte. + self.proto.repeated_int32.append(128) # 2 bytes. + # Also need 2 bytes for each entry for tag. + self.assertEqual(1 + 2 + 2*2, self.Size()) + + def testRepeatedComposites(self): + # Empty message. 2 bytes tag plus 1 byte length. + foreign_message_0 = self.proto.repeated_nested_message.add() + # 2 bytes tag plus 1 byte length plus 1 byte bb tag 1 byte int. 
+ foreign_message_1 = self.proto.repeated_nested_message.add() + foreign_message_1.bb = 7 + self.assertEqual(2 + 1 + 2 + 1 + 1 + 1, self.Size()) + + def testRepeatedGroups(self): + # 2-byte START_GROUP plus 2-byte END_GROUP. + group_0 = self.proto.repeatedgroup.add() + # 2-byte START_GROUP plus 2-byte |a| tag + 1-byte |a| + # plus 2-byte END_GROUP. + group_1 = self.proto.repeatedgroup.add() + group_1.a = 7 + self.assertEqual(2 + 2 + 2 + 2 + 1 + 2, self.Size()) + + def testExtensions(self): + proto = unittest_pb2.TestAllExtensions() + self.assertEqual(0, proto.ByteSize()) + extension = unittest_pb2.optional_int32_extension # Field #1, 1 byte. + proto.Extensions[extension] = 23 + # 1 byte for tag, 1 byte for value. + self.assertEqual(2, proto.ByteSize()) + + def testCacheInvalidationForNonrepeatedScalar(self): + # Test non-extension. + self.proto.optional_int32 = 1 + self.assertEqual(2, self.proto.ByteSize()) + self.proto.optional_int32 = 128 + self.assertEqual(3, self.proto.ByteSize()) + self.proto.ClearField('optional_int32') + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.optional_int_extension + self.extended_proto.Extensions[extension] = 1 + self.assertEqual(2, self.extended_proto.ByteSize()) + self.extended_proto.Extensions[extension] = 128 + self.assertEqual(3, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testCacheInvalidationForRepeatedScalar(self): + # Test non-extension. + self.proto.repeated_int32.append(1) + self.assertEqual(3, self.proto.ByteSize()) + self.proto.repeated_int32.append(1) + self.assertEqual(6, self.proto.ByteSize()) + self.proto.repeated_int32[1] = 128 + self.assertEqual(7, self.proto.ByteSize()) + self.proto.ClearField('repeated_int32') + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. 
+ extension = more_extensions_pb2.repeated_int_extension + repeated = self.extended_proto.Extensions[extension] + repeated.append(1) + self.assertEqual(2, self.extended_proto.ByteSize()) + repeated.append(1) + self.assertEqual(4, self.extended_proto.ByteSize()) + repeated[1] = 128 + self.assertEqual(5, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testCacheInvalidationForNonrepeatedMessage(self): + # Test non-extension. + self.proto.optional_foreign_message.c = 1 + self.assertEqual(5, self.proto.ByteSize()) + self.proto.optional_foreign_message.c = 128 + self.assertEqual(6, self.proto.ByteSize()) + self.proto.optional_foreign_message.ClearField('c') + self.assertEqual(3, self.proto.ByteSize()) + self.proto.ClearField('optional_foreign_message') + self.assertEqual(0, self.proto.ByteSize()) + child = self.proto.optional_foreign_message + self.proto.ClearField('optional_foreign_message') + child.c = 128 + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. + extension = more_extensions_pb2.optional_message_extension + child = self.extended_proto.Extensions[extension] + self.assertEqual(0, self.extended_proto.ByteSize()) + child.foreign_message_int = 1 + self.assertEqual(4, self.extended_proto.ByteSize()) + child.foreign_message_int = 128 + self.assertEqual(5, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + def testCacheInvalidationForRepeatedMessage(self): + # Test non-extension. + child0 = self.proto.repeated_foreign_message.add() + self.assertEqual(3, self.proto.ByteSize()) + self.proto.repeated_foreign_message.add() + self.assertEqual(6, self.proto.ByteSize()) + child0.c = 1 + self.assertEqual(8, self.proto.ByteSize()) + self.proto.ClearField('repeated_foreign_message') + self.assertEqual(0, self.proto.ByteSize()) + + # Test within extension. 
+ extension = more_extensions_pb2.repeated_message_extension + child_list = self.extended_proto.Extensions[extension] + child0 = child_list.add() + self.assertEqual(2, self.extended_proto.ByteSize()) + child_list.add() + self.assertEqual(4, self.extended_proto.ByteSize()) + child0.foreign_message_int = 1 + self.assertEqual(6, self.extended_proto.ByteSize()) + child0.ClearField('foreign_message_int') + self.assertEqual(4, self.extended_proto.ByteSize()) + self.extended_proto.ClearExtension(extension) + self.assertEqual(0, self.extended_proto.ByteSize()) + + +# TODO(robinson): We need cross-language serialization consistency tests. +# Issues to be sure to cover include: +# * Handling of unrecognized tags ("uninterpreted_bytes"). +# * Handling of MessageSets. +# * Consistent ordering of tags in the wire format, +# including ordering between extensions and non-extension +# fields. +# * Consistent serialization of negative numbers, especially +# negative int32s. +# * Handling of empty submessages (with and without "has" +# bits set). 
+ +class SerializationTest(unittest.TestCase): + + def testSerializeEmtpyMessage(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + serialized = first_proto.SerializeToString() + self.assertEqual(first_proto.ByteSize(), len(serialized)) + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testSerializeAllFields(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(first_proto) + serialized = first_proto.SerializeToString() + self.assertEqual(first_proto.ByteSize(), len(serialized)) + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testSerializeAllExtensions(self): + first_proto = unittest_pb2.TestAllExtensions() + second_proto = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(first_proto) + serialized = first_proto.SerializeToString() + second_proto.MergeFromString(serialized) + self.assertEqual(first_proto, second_proto) + + def testCanonicalSerializationOrder(self): + proto = more_messages_pb2.OutOfOrderFields() + # These are also their tag numbers. Even though we're setting these in + # reverse-tag order AND they're listed in reverse tag-order in the .proto + # file, they should nonetheless be serialized in tag order. 
+ proto.optional_sint32 = 5 + proto.Extensions[more_messages_pb2.optional_uint64] = 4 + proto.optional_uint32 = 3 + proto.Extensions[more_messages_pb2.optional_int64] = 2 + proto.optional_int32 = 1 + serialized = proto.SerializeToString() + self.assertEqual(proto.ByteSize(), len(serialized)) + d = decoder.Decoder(serialized) + ReadTag = d.ReadFieldNumberAndWireType + self.assertEqual((1, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(1, d.ReadInt32()) + self.assertEqual((2, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(2, d.ReadInt64()) + self.assertEqual((3, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(3, d.ReadUInt32()) + self.assertEqual((4, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(4, d.ReadUInt64()) + self.assertEqual((5, wire_format.WIRETYPE_VARINT), ReadTag()) + self.assertEqual(5, d.ReadSInt32()) + + def testCanonicalSerializationOrderSameAsCpp(self): + # Copy of the same test we use for C++. + proto = unittest_pb2.TestFieldOrderings() + test_util.SetAllFieldsAndExtensions(proto) + serialized = proto.SerializeToString() + test_util.ExpectAllFieldsAndExtensionsInOrder(serialized) + + def testMergeFromStringWhenFieldsAlreadySet(self): + first_proto = unittest_pb2.TestAllTypes() + first_proto.repeated_string.append('foobar') + first_proto.optional_int32 = 23 + first_proto.optional_nested_message.bb = 42 + serialized = first_proto.SerializeToString() + + second_proto = unittest_pb2.TestAllTypes() + second_proto.repeated_string.append('baz') + second_proto.optional_int32 = 100 + second_proto.optional_nested_message.bb = 999 + + second_proto.MergeFromString(serialized) + # Ensure that we append to repeated fields. + self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) + # Ensure that we overwrite nonrepeatd scalars. + self.assertEqual(23, second_proto.optional_int32) + # Ensure that we recursively call MergeFromString() on + # submessages. 
+ self.assertEqual(42, second_proto.optional_nested_message.bb) + + def testMessageSetWireFormat(self): + proto = unittest_mset_pb2.TestMessageSet() + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + extension_message2 = unittest_mset_pb2.TestMessageSetExtension2 + extension1 = extension_message1.message_set_extension + extension2 = extension_message2.message_set_extension + proto.Extensions[extension1].i = 123 + proto.Extensions[extension2].str = 'foo' + + # Serialize using the MessageSet wire format (this is specified in the + # .proto file). + serialized = proto.SerializeToString() + + raw = unittest_mset_pb2.RawMessageSet() + self.assertEqual(False, + raw.DESCRIPTOR.GetOptions().message_set_wire_format) + raw.MergeFromString(serialized) + self.assertEqual(2, len(raw.item)) + + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.MergeFromString(raw.item[0].message) + self.assertEqual(123, message1.i) + + message2 = unittest_mset_pb2.TestMessageSetExtension2() + message2.MergeFromString(raw.item[1].message) + self.assertEqual('foo', message2.str) + + # Deserialize using the MessageSet wire format. + proto2 = unittest_mset_pb2.TestMessageSet() + proto2.MergeFromString(serialized) + self.assertEqual(123, proto2.Extensions[extension1].i) + self.assertEqual('foo', proto2.Extensions[extension2].str) + + # Check byte size. + self.assertEqual(proto2.ByteSize(), len(serialized)) + self.assertEqual(proto.ByteSize(), len(serialized)) + + def testMessageSetWireFormatUnknownExtension(self): + # Create a message using the message set wire format with an unknown + # message. + raw = unittest_mset_pb2.RawMessageSet() + + # Add an item. + item = raw.item.add() + item.type_id = 1545008 + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12345 + item.message = message1.SerializeToString() + + # Add a second, unknown extension. 
+ item = raw.item.add() + item.type_id = 1545009 + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12346 + item.message = message1.SerializeToString() + + # Add another unknown extension. + item = raw.item.add() + item.type_id = 1545010 + message1 = unittest_mset_pb2.TestMessageSetExtension2() + message1.str = 'foo' + item.message = message1.SerializeToString() + + serialized = raw.SerializeToString() + + # Parse message using the message set wire format. + proto = unittest_mset_pb2.TestMessageSet() + proto.MergeFromString(serialized) + + # Check that the message parsed well. + extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 + extension1 = extension_message1.message_set_extension + self.assertEquals(12345, proto.Extensions[extension1].i) + + def testUnknownFields(self): + proto = unittest_pb2.TestAllTypes() + test_util.SetAllFields(proto) + + serialized = proto.SerializeToString() + + # The empty message should be parsable with all of the fields + # unknown. + proto2 = unittest_pb2.TestEmptyMessage() + + # Parsing this message should succeed. + proto2.MergeFromString(serialized) + + +class OptionsTest(unittest.TestCase): + + def testMessageOptions(self): + proto = unittest_mset_pb2.TestMessageSet() + self.assertEqual(True, + proto.DESCRIPTOR.GetOptions().message_set_wire_format) + proto = unittest_pb2.TestAllTypes() + self.assertEqual(False, + proto.DESCRIPTOR.GetOptions().message_set_wire_format) + + +class UtilityTest(unittest.TestCase): + + def testImergeSorted(self): + ImergeSorted = reflection._ImergeSorted + # Various types of emptiness. + self.assertEqual([], list(ImergeSorted())) + self.assertEqual([], list(ImergeSorted([]))) + self.assertEqual([], list(ImergeSorted([], []))) + + # One nonempty list. 
+ self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3]))) + self.assertEqual([1, 2, 3], list(ImergeSorted([1, 2, 3], []))) + self.assertEqual([1, 2, 3], list(ImergeSorted([], [1, 2, 3]))) + + # Merging some nonempty lists together. + self.assertEqual([1, 2, 3], list(ImergeSorted([1, 3], [2]))) + self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2]))) + self.assertEqual([1, 2, 3], list(ImergeSorted([1], [3], [2], []))) + + # Elements repeated across component iterators. + self.assertEqual([1, 2, 2, 3, 3], + list(ImergeSorted([1, 2], [3], [2, 3]))) + + # Elements repeated within an iterator. + self.assertEqual([1, 2, 2, 3, 3], + list(ImergeSorted([1, 2, 2], [3], [3]))) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py new file mode 100755 index 00000000..895d24d3 --- /dev/null +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -0,0 +1,98 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Tests for google.protobuf.internal.service_reflection.""" + +__author__ = 'petar@google.com (Petar Petrov)' + +import unittest +from google.protobuf import unittest_pb2 +from google.protobuf import service_reflection +from google.protobuf import service + + +class FooUnitTest(unittest.TestCase): + + def testService(self): + class MockRpcChannel(service.RpcChannel): + def CallMethod(self, method, controller, request, response, callback): + self.method = method + self.controller = controller + self.request = request + callback(response) + + class MockRpcController(service.RpcController): + def SetFailed(self, msg): + self.failure_message = msg + + self.callback_response = None + + class MyService(unittest_pb2.TestService): + pass + + self.callback_response = None + + def MyCallback(response): + self.callback_response = response + + rpc_controller = MockRpcController() + channel = MockRpcChannel() + srvc = MyService() + srvc.Foo(rpc_controller, unittest_pb2.FooRequest(), MyCallback) + self.assertEqual('Method Foo not implemented.', + rpc_controller.failure_message) + self.assertEqual(None, self.callback_response) + + rpc_controller.failure_message = None + + service_descriptor = unittest_pb2.TestService.DESCRIPTOR + srvc.CallMethod(service_descriptor.methods[1], rpc_controller, + unittest_pb2.BarRequest(), MyCallback) + self.assertEqual('Method Bar not implemented.', + rpc_controller.failure_message) + self.assertEqual(None, self.callback_response) + + def testServiceStub(self): + class MockRpcChannel(service.RpcChannel): + def CallMethod(self, method, controller, request, + response_class, callback): + self.method = method + self.controller = controller + self.request = request + callback(response_class()) + + self.callback_response = None + + def MyCallback(response): + self.callback_response = response + + channel = MockRpcChannel() + stub = unittest_pb2.TestService_Stub(channel) + rpc_controller = 'controller' + request = 'request' + + # Invoke method. 
+ stub.Foo(rpc_controller, request, MyCallback) + + self.assertTrue(isinstance(self.callback_response, + unittest_pb2.FooResponse)) + self.assertEqual(request, channel.request) + self.assertEqual(rpc_controller, channel.controller) + self.assertEqual(stub.GetDescriptor().methods[0], channel.method) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py new file mode 100755 index 00000000..d9106421 --- /dev/null +++ b/python/google/protobuf/internal/test_util.py @@ -0,0 +1,354 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Utilities for Python proto2 tests. + +This is intentionally modeled on C++ code in +//net/proto2/internal/test_util.*. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import os.path + +from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_pb2 + + +def SetAllFields(message): + """Sets every field in the message to a unique value. + + Args: + message: A unittest_pb2.TestAllTypes instance. + """ + + # + # Optional fields. 
+ # + + message.optional_int32 = 101 + message.optional_int64 = 102 + message.optional_uint32 = 103 + message.optional_uint64 = 104 + message.optional_sint32 = 105 + message.optional_sint64 = 106 + message.optional_fixed32 = 107 + message.optional_fixed64 = 108 + message.optional_sfixed32 = 109 + message.optional_sfixed64 = 110 + message.optional_float = 111 + message.optional_double = 112 + message.optional_bool = True + # TODO(robinson): Firmly spec out and test how + # protos interact with unicode. One specific example: + # what happens if we change the literal below to + # u'115'? What *should* happen? Still some discussion + # to finish with Kenton about bytes vs. strings + # and forcing everything to be utf8. :-/ + message.optional_string = '115' + message.optional_bytes = '116' + + message.optionalgroup.a = 117 + message.optional_nested_message.bb = 118 + message.optional_foreign_message.c = 119 + message.optional_import_message.d = 120 + + message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ + message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ + message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ + + message.optional_string_piece = '124' + message.optional_cord = '125' + + # + # Repeated fields. 
+ # + + message.repeated_int32.append(201) + message.repeated_int64.append(202) + message.repeated_uint32.append(203) + message.repeated_uint64.append(204) + message.repeated_sint32.append(205) + message.repeated_sint64.append(206) + message.repeated_fixed32.append(207) + message.repeated_fixed64.append(208) + message.repeated_sfixed32.append(209) + message.repeated_sfixed64.append(210) + message.repeated_float.append(211) + message.repeated_double.append(212) + message.repeated_bool.append(True) + message.repeated_string.append('215') + message.repeated_bytes.append('216') + + message.repeatedgroup.add().a = 217 + message.repeated_nested_message.add().bb = 218 + message.repeated_foreign_message.add().c = 219 + message.repeated_import_message.add().d = 220 + + message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) + + message.repeated_string_piece.append('224') + message.repeated_cord.append('225') + + # Add a second one of each field. 
+ message.repeated_int32.append(301) + message.repeated_int64.append(302) + message.repeated_uint32.append(303) + message.repeated_uint64.append(304) + message.repeated_sint32.append(305) + message.repeated_sint64.append(306) + message.repeated_fixed32.append(307) + message.repeated_fixed64.append(308) + message.repeated_sfixed32.append(309) + message.repeated_sfixed64.append(310) + message.repeated_float.append(311) + message.repeated_double.append(312) + message.repeated_bool.append(False) + message.repeated_string.append('315') + message.repeated_bytes.append('316') + + message.repeatedgroup.add().a = 317 + message.repeated_nested_message.add().bb = 318 + message.repeated_foreign_message.add().c = 319 + message.repeated_import_message.add().d = 320 + + message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) + message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) + + message.repeated_string_piece.append('324') + message.repeated_cord.append('325') + + # + # Fields that have defaults. + # + + message.default_int32 = 401 + message.default_int64 = 402 + message.default_uint32 = 403 + message.default_uint64 = 404 + message.default_sint32 = 405 + message.default_sint64 = 406 + message.default_fixed32 = 407 + message.default_fixed64 = 408 + message.default_sfixed32 = 409 + message.default_sfixed64 = 410 + message.default_float = 411 + message.default_double = 412 + message.default_bool = False + message.default_string = '415' + message.default_bytes = '416' + + message.default_nested_enum = unittest_pb2.TestAllTypes.FOO + message.default_foreign_enum = unittest_pb2.FOREIGN_FOO + message.default_import_enum = unittest_import_pb2.IMPORT_FOO + + message.default_string_piece = '424' + message.default_cord = '425' + + +def SetAllExtensions(message): + """Sets every extension in the message to a unique value. + + Args: + message: A unittest_pb2.TestAllExtensions instance. 
+ """ + + extensions = message.Extensions + pb2 = unittest_pb2 + import_pb2 = unittest_import_pb2 + + # + # Optional fields. + # + + extensions[pb2.optional_int32_extension] = 101 + extensions[pb2.optional_int64_extension] = 102 + extensions[pb2.optional_uint32_extension] = 103 + extensions[pb2.optional_uint64_extension] = 104 + extensions[pb2.optional_sint32_extension] = 105 + extensions[pb2.optional_sint64_extension] = 106 + extensions[pb2.optional_fixed32_extension] = 107 + extensions[pb2.optional_fixed64_extension] = 108 + extensions[pb2.optional_sfixed32_extension] = 109 + extensions[pb2.optional_sfixed64_extension] = 110 + extensions[pb2.optional_float_extension] = 111 + extensions[pb2.optional_double_extension] = 112 + extensions[pb2.optional_bool_extension] = True + extensions[pb2.optional_string_extension] = '115' + extensions[pb2.optional_bytes_extension] = '116' + + extensions[pb2.optionalgroup_extension].a = 117 + extensions[pb2.optional_nested_message_extension].bb = 118 + extensions[pb2.optional_foreign_message_extension].c = 119 + extensions[pb2.optional_import_message_extension].d = 120 + + extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ + extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ + extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ + extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ + + extensions[pb2.optional_string_piece_extension] = '124' + extensions[pb2.optional_cord_extension] = '125' + + # + # Repeated fields. 
+ # + + extensions[pb2.repeated_int32_extension].append(201) + extensions[pb2.repeated_int64_extension].append(202) + extensions[pb2.repeated_uint32_extension].append(203) + extensions[pb2.repeated_uint64_extension].append(204) + extensions[pb2.repeated_sint32_extension].append(205) + extensions[pb2.repeated_sint64_extension].append(206) + extensions[pb2.repeated_fixed32_extension].append(207) + extensions[pb2.repeated_fixed64_extension].append(208) + extensions[pb2.repeated_sfixed32_extension].append(209) + extensions[pb2.repeated_sfixed64_extension].append(210) + extensions[pb2.repeated_float_extension].append(211) + extensions[pb2.repeated_double_extension].append(212) + extensions[pb2.repeated_bool_extension].append(True) + extensions[pb2.repeated_string_extension].append('215') + extensions[pb2.repeated_bytes_extension].append('216') + + extensions[pb2.repeatedgroup_extension].add().a = 217 + extensions[pb2.repeated_nested_message_extension].add().bb = 218 + extensions[pb2.repeated_foreign_message_extension].add().c = 219 + extensions[pb2.repeated_import_message_extension].add().d = 220 + + extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR) + extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR) + extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR) + + extensions[pb2.repeated_string_piece_extension].append('224') + extensions[pb2.repeated_cord_extension].append('225') + + # Append a second one of each field. 
+ extensions[pb2.repeated_int32_extension].append(301) + extensions[pb2.repeated_int64_extension].append(302) + extensions[pb2.repeated_uint32_extension].append(303) + extensions[pb2.repeated_uint64_extension].append(304) + extensions[pb2.repeated_sint32_extension].append(305) + extensions[pb2.repeated_sint64_extension].append(306) + extensions[pb2.repeated_fixed32_extension].append(307) + extensions[pb2.repeated_fixed64_extension].append(308) + extensions[pb2.repeated_sfixed32_extension].append(309) + extensions[pb2.repeated_sfixed64_extension].append(310) + extensions[pb2.repeated_float_extension].append(311) + extensions[pb2.repeated_double_extension].append(312) + extensions[pb2.repeated_bool_extension].append(False) + extensions[pb2.repeated_string_extension].append('315') + extensions[pb2.repeated_bytes_extension].append('316') + + extensions[pb2.repeatedgroup_extension].add().a = 317 + extensions[pb2.repeated_nested_message_extension].add().bb = 318 + extensions[pb2.repeated_foreign_message_extension].add().c = 319 + extensions[pb2.repeated_import_message_extension].add().d = 320 + + extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ) + extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ) + extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ) + + extensions[pb2.repeated_string_piece_extension].append('324') + extensions[pb2.repeated_cord_extension].append('325') + + # + # Fields with defaults. 
+ # + + extensions[pb2.default_int32_extension] = 401 + extensions[pb2.default_int64_extension] = 402 + extensions[pb2.default_uint32_extension] = 403 + extensions[pb2.default_uint64_extension] = 404 + extensions[pb2.default_sint32_extension] = 405 + extensions[pb2.default_sint64_extension] = 406 + extensions[pb2.default_fixed32_extension] = 407 + extensions[pb2.default_fixed64_extension] = 408 + extensions[pb2.default_sfixed32_extension] = 409 + extensions[pb2.default_sfixed64_extension] = 410 + extensions[pb2.default_float_extension] = 411 + extensions[pb2.default_double_extension] = 412 + extensions[pb2.default_bool_extension] = False + extensions[pb2.default_string_extension] = '415' + extensions[pb2.default_bytes_extension] = '416' + + extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO + extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO + extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO + + extensions[pb2.default_string_piece_extension] = '424' + extensions[pb2.default_cord_extension] = '425' + + +def SetAllFieldsAndExtensions(message): + """Sets every field and extension in the message to a unique value. + + Args: + message: A unittest_pb2.TestAllExtensions message. + """ + message.my_int = 1 + message.my_string = 'foo' + message.my_float = 1.0 + message.Extensions[unittest_pb2.my_extension_int] = 23 + message.Extensions[unittest_pb2.my_extension_string] = 'bar' + + +def ExpectAllFieldsAndExtensionsInOrder(serialized): + """Ensures that serialized is the serialization we expect for a message + filled with SetAllFieldsAndExtensions(). (Specifically, ensures that the + serialization is in canonical, tag-number order). + """ + my_extension_int = unittest_pb2.my_extension_int + my_extension_string = unittest_pb2.my_extension_string + expected_strings = [] + message = unittest_pb2.TestFieldOrderings() + message.my_int = 1 # Field 1. 
+ expected_strings.append(message.SerializeToString()) + message.Clear() + message.Extensions[my_extension_int] = 23 # Field 5. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.my_string = 'foo' # Field 11. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.Extensions[my_extension_string] = 'bar' # Field 50. + expected_strings.append(message.SerializeToString()) + message.Clear() + message.my_float = 1.0 + expected_strings.append(message.SerializeToString()) + message.Clear() + expected = ''.join(expected_strings) + + if expected != serialized: + raise ValueError('Expected %r, found %r' % (expected, serialized)) + +def GoldenFile(filename): + """Finds the given golden file and returns a file object representing it.""" + + # Search up the directory tree looking for the C++ protobuf source code. + path = '.' + while os.path.exists(path): + if os.path.exists(os.path.join(path, 'src/google/protobuf')): + # Found it. Load the golden file from the testdata directory. + return file(os.path.join(path, 'src/google/protobuf/testdata', filename)) + path = os.path.join(path, '..') + + raise RuntimeError( + 'Could not find golden files. This test must be run from within the ' + 'protobuf source package so that it can read test data files from the ' + 'C++ source tree.') diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py new file mode 100755 index 00000000..c2074db5 --- /dev/null +++ b/python/google/protobuf/internal/text_format_test.py @@ -0,0 +1,97 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for google.protobuf.text_format.""" + +__author__ = 'kenton@google.com (Kenton Varda)' + +import difflib + +import unittest +from google.protobuf import text_format +from google.protobuf.internal import test_util +from google.protobuf import unittest_pb2 +from google.protobuf import unittest_mset_pb2 + +class TextFormatTest(unittest.TestCase): + def CompareToGoldenFile(self, text, golden_filename): + f = test_util.GoldenFile(golden_filename) + golden_lines = f.readlines() + f.close() + self.CompareToGoldenLines(text, golden_lines) + + def CompareToGoldenText(self, text, golden_text): + self.CompareToGoldenLines(text, golden_text.splitlines(1)) + + def CompareToGoldenLines(self, text, golden_lines): + actual_lines = text.splitlines(1) + self.assertEqual(golden_lines, actual_lines, + "Text doesn't match golden. 
Diff:\n" + + ''.join(difflib.ndiff(golden_lines, actual_lines))) + + def testPrintAllFields(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile(text_format.MessageToString(message), + 'text_format_unittest_data.txt') + + def testPrintAllExtensions(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile(text_format.MessageToString(message), + 'text_format_unittest_extensions_data.txt') + + def testPrintMessageSet(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + self.CompareToGoldenText(text_format.MessageToString(message), + 'message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') + + def testPrintExotic(self): + message = unittest_pb2.TestAllTypes() + message.repeated_int64.append(-9223372036854775808); + message.repeated_uint64.append(18446744073709551615); + message.repeated_double.append(123.456); + message.repeated_double.append(1.23e22); + message.repeated_double.append(1.23e-18); + message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'\"'); + self.CompareToGoldenText(text_format.MessageToString(message), + 'repeated_int64: -9223372036854775808\n' + 'repeated_uint64: 18446744073709551615\n' + 'repeated_double: 123.456\n' + 'repeated_double: 1.23e+22\n' + 'repeated_double: 1.23e-18\n' + 'repeated_string: ' + '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n') + + def testMessageToString(self): + message = unittest_pb2.ForeignMessage() + message.c = 123 + self.assertEqual('c: 123\n', str(message)) + +if __name__ == '__main__': + unittest.main() + diff 
--git a/python/google/protobuf/internal/wire_format.py b/python/google/protobuf/internal/wire_format.py new file mode 100755 index 00000000..69aa4abf --- /dev/null +++ b/python/google/protobuf/internal/wire_format.py @@ -0,0 +1,222 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Constants and static functions to support protocol buffer wire format.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import struct +from google.protobuf import message + + +TAG_TYPE_BITS = 3 # Number of bits used to hold type info in a proto tag. +_TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1 # 0x7 + +# These numbers identify the wire type of a protocol buffer value. +# We use the least-significant TAG_TYPE_BITS bits of the varint-encoded +# tag-and-type to store one of these WIRETYPE_* constants. +# These values must match WireType enum in //net/proto2/public/wire_format.h. +WIRETYPE_VARINT = 0 +WIRETYPE_FIXED64 = 1 +WIRETYPE_LENGTH_DELIMITED = 2 +WIRETYPE_START_GROUP = 3 +WIRETYPE_END_GROUP = 4 +WIRETYPE_FIXED32 = 5 +_WIRETYPE_MAX = 5 + + +# Bounds for various integer types. +INT32_MAX = int((1 << 31) - 1) +INT32_MIN = int(-(1 << 31)) +UINT32_MAX = (1 << 32) - 1 + +INT64_MAX = (1 << 63) - 1 +INT64_MIN = -(1 << 63) +UINT64_MAX = (1 << 64) - 1 + +# "struct" format strings that will encode/decode the specified formats. 
+FORMAT_UINT32_LITTLE_ENDIAN = '> TAG_TYPE_BITS), (tag & _TAG_TYPE_MASK) + + +def ZigZagEncode(value): + """ZigZag Transform: Encodes signed integers so that they can be + effectively used with varint encoding. See wire_format.h for + more details. + """ + if value >= 0: + return value << 1 + return ((value << 1) ^ (~0)) | 0x1 + + +def ZigZagDecode(value): + """Inverse of ZigZagEncode().""" + if not value & 0x1: + return value >> 1 + return (value >> 1) ^ (~0) + + + +# The *ByteSize() functions below return the number of bytes required to +# serialize "field number + type" information and then serialize the value. + + +def Int32ByteSize(field_number, int32): + return Int64ByteSize(field_number, int32) + + +def Int64ByteSize(field_number, int64): + # Have to convert to uint before calling UInt64ByteSize(). + return UInt64ByteSize(field_number, 0xffffffffffffffff & int64) + + +def UInt32ByteSize(field_number, uint32): + return UInt64ByteSize(field_number, uint32) + + +def UInt64ByteSize(field_number, uint64): + return _TagByteSize(field_number) + _VarUInt64ByteSizeNoTag(uint64) + + +def SInt32ByteSize(field_number, int32): + return UInt32ByteSize(field_number, ZigZagEncode(int32)) + + +def SInt64ByteSize(field_number, int64): + return UInt64ByteSize(field_number, ZigZagEncode(int64)) + + +def Fixed32ByteSize(field_number, fixed32): + return _TagByteSize(field_number) + 4 + + +def Fixed64ByteSize(field_number, fixed64): + return _TagByteSize(field_number) + 8 + + +def SFixed32ByteSize(field_number, sfixed32): + return _TagByteSize(field_number) + 4 + + +def SFixed64ByteSize(field_number, sfixed64): + return _TagByteSize(field_number) + 8 + + +def FloatByteSize(field_number, flt): + return _TagByteSize(field_number) + 4 + + +def DoubleByteSize(field_number, double): + return _TagByteSize(field_number) + 8 + + +def BoolByteSize(field_number, b): + return _TagByteSize(field_number) + 1 + + +def EnumByteSize(field_number, enum): + return UInt32ByteSize(field_number, enum) 
+ + +def StringByteSize(field_number, string): + return (_TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(len(string)) + + len(string)) + + +def BytesByteSize(field_number, b): + return StringByteSize(field_number, b) + + +def GroupByteSize(field_number, message): + return (2 * _TagByteSize(field_number) # START and END group. + + message.ByteSize()) + + +def MessageByteSize(field_number, message): + return (_TagByteSize(field_number) + + _VarUInt64ByteSizeNoTag(message.ByteSize()) + + message.ByteSize()) + + +def MessageSetItemByteSize(field_number, msg): + # First compute the sizes of the tags. + # There are 2 tags for the beginning and ending of the repeated group, that + # is field number 1, one with field number 2 (type_id) and one with field + # number 3 (message). + total_size = (2 * _TagByteSize(1) + _TagByteSize(2) + _TagByteSize(3)) + + # Add the number of bytes for type_id. + total_size += _VarUInt64ByteSizeNoTag(field_number) + + message_size = msg.ByteSize() + + # The number of bytes for encoding the length of the message. + total_size += _VarUInt64ByteSizeNoTag(message_size) + + # The size of the message. + total_size += message_size + return total_size + + +# Private helper functions for the *ByteSize() functions above. + + +def _TagByteSize(field_number): + """Returns the bytes required to serialize a tag with this field number.""" + # Just pass in type 0, since the type won't affect the tag+type size. + return _VarUInt64ByteSizeNoTag(PackTag(field_number, 0)) + + +def _VarUInt64ByteSizeNoTag(uint64): + """Returns the bytes required to serialize a single varint. + uint64 must be unsigned. 
+ """ + if uint64 > UINT64_MAX: + raise message.EncodeError('Value out of range: %d' % uint64) + bytes = 1 + while uint64 > 0x7f: + bytes += 1 + uint64 >>= 7 + return bytes diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py new file mode 100755 index 00000000..87e0ddf5 --- /dev/null +++ b/python/google/protobuf/internal/wire_format_test.py @@ -0,0 +1,232 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for google.protobuf.internal.wire_format.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import unittest +from google.protobuf import message +from google.protobuf.internal import wire_format + + +class WireFormatTest(unittest.TestCase): + + def testPackTag(self): + field_number = 0xabc + tag_type = 2 + self.assertEqual((field_number << 3) | tag_type, + wire_format.PackTag(field_number, tag_type)) + PackTag = wire_format.PackTag + # Number too high. + self.assertRaises(message.EncodeError, PackTag, field_number, 6) + # Number too low. + self.assertRaises(message.EncodeError, PackTag, field_number, -1) + + def testUnpackTag(self): + # Test field numbers that will require various varint sizes. + for expected_field_number in (1, 15, 16, 2047, 2048): + for expected_wire_type in range(6): # Highest-numbered wiretype is 5. 
+ field_number, wire_type = wire_format.UnpackTag( + wire_format.PackTag(expected_field_number, expected_wire_type)) + self.assertEqual(expected_field_number, field_number) + self.assertEqual(expected_wire_type, wire_type) + + self.assertRaises(TypeError, wire_format.UnpackTag, None) + self.assertRaises(TypeError, wire_format.UnpackTag, 'abc') + self.assertRaises(TypeError, wire_format.UnpackTag, 0.0) + self.assertRaises(TypeError, wire_format.UnpackTag, object()) + + def testZigZagEncode(self): + Z = wire_format.ZigZagEncode + self.assertEqual(0, Z(0)) + self.assertEqual(1, Z(-1)) + self.assertEqual(2, Z(1)) + self.assertEqual(3, Z(-2)) + self.assertEqual(4, Z(2)) + self.assertEqual(0xfffffffe, Z(0x7fffffff)) + self.assertEqual(0xffffffff, Z(-0x80000000)) + self.assertEqual(0xfffffffffffffffe, Z(0x7fffffffffffffff)) + self.assertEqual(0xffffffffffffffff, Z(-0x8000000000000000)) + + self.assertRaises(TypeError, Z, None) + self.assertRaises(TypeError, Z, 'abcd') + self.assertRaises(TypeError, Z, 0.0) + self.assertRaises(TypeError, Z, object()) + + def testZigZagDecode(self): + Z = wire_format.ZigZagDecode + self.assertEqual(0, Z(0)) + self.assertEqual(-1, Z(1)) + self.assertEqual(1, Z(2)) + self.assertEqual(-2, Z(3)) + self.assertEqual(2, Z(4)) + self.assertEqual(0x7fffffff, Z(0xfffffffe)) + self.assertEqual(-0x80000000, Z(0xffffffff)) + self.assertEqual(0x7fffffffffffffff, Z(0xfffffffffffffffe)) + self.assertEqual(-0x8000000000000000, Z(0xffffffffffffffff)) + + self.assertRaises(TypeError, Z, None) + self.assertRaises(TypeError, Z, 'abcd') + self.assertRaises(TypeError, Z, 0.0) + self.assertRaises(TypeError, Z, object()) + + def NumericByteSizeTestHelper(self, byte_size_fn, value, expected_value_size): + # Use field numbers that cause various byte sizes for the tag information. 
+ for field_number, tag_bytes in ((15, 1), (16, 2), (2047, 2), (2048, 3)): + expected_size = expected_value_size + tag_bytes + actual_size = byte_size_fn(field_number, value) + self.assertEqual(expected_size, actual_size, + 'byte_size_fn: %s, field_number: %d, value: %r\n' + 'Expected: %d, Actual: %d'% ( + byte_size_fn, field_number, value, expected_size, actual_size)) + + def testByteSizeFunctions(self): + # Test all numeric *ByteSize() functions. + NUMERIC_ARGS = [ + # Int32ByteSize(). + [wire_format.Int32ByteSize, 0, 1], + [wire_format.Int32ByteSize, 127, 1], + [wire_format.Int32ByteSize, 128, 2], + [wire_format.Int32ByteSize, -1, 10], + # Int64ByteSize(). + [wire_format.Int64ByteSize, 0, 1], + [wire_format.Int64ByteSize, 127, 1], + [wire_format.Int64ByteSize, 128, 2], + [wire_format.Int64ByteSize, -1, 10], + # UInt32ByteSize(). + [wire_format.UInt32ByteSize, 0, 1], + [wire_format.UInt32ByteSize, 127, 1], + [wire_format.UInt32ByteSize, 128, 2], + [wire_format.UInt32ByteSize, wire_format.UINT32_MAX, 5], + # UInt64ByteSize(). + [wire_format.UInt64ByteSize, 0, 1], + [wire_format.UInt64ByteSize, 127, 1], + [wire_format.UInt64ByteSize, 128, 2], + [wire_format.UInt64ByteSize, wire_format.UINT64_MAX, 10], + # SInt32ByteSize(). + [wire_format.SInt32ByteSize, 0, 1], + [wire_format.SInt32ByteSize, -1, 1], + [wire_format.SInt32ByteSize, 1, 1], + [wire_format.SInt32ByteSize, -63, 1], + [wire_format.SInt32ByteSize, 63, 1], + [wire_format.SInt32ByteSize, -64, 1], + [wire_format.SInt32ByteSize, 64, 2], + # SInt64ByteSize(). + [wire_format.SInt64ByteSize, 0, 1], + [wire_format.SInt64ByteSize, -1, 1], + [wire_format.SInt64ByteSize, 1, 1], + [wire_format.SInt64ByteSize, -63, 1], + [wire_format.SInt64ByteSize, 63, 1], + [wire_format.SInt64ByteSize, -64, 1], + [wire_format.SInt64ByteSize, 64, 2], + # Fixed32ByteSize(). + [wire_format.Fixed32ByteSize, 0, 4], + [wire_format.Fixed32ByteSize, wire_format.UINT32_MAX, 4], + # Fixed64ByteSize(). 
+ [wire_format.Fixed64ByteSize, 0, 8], + [wire_format.Fixed64ByteSize, wire_format.UINT64_MAX, 8], + # SFixed32ByteSize(). + [wire_format.SFixed32ByteSize, 0, 4], + [wire_format.SFixed32ByteSize, wire_format.INT32_MIN, 4], + [wire_format.SFixed32ByteSize, wire_format.INT32_MAX, 4], + # SFixed64ByteSize(). + [wire_format.SFixed64ByteSize, 0, 8], + [wire_format.SFixed64ByteSize, wire_format.INT64_MIN, 8], + [wire_format.SFixed64ByteSize, wire_format.INT64_MAX, 8], + # FloatByteSize(). + [wire_format.FloatByteSize, 0.0, 4], + [wire_format.FloatByteSize, 1000000000.0, 4], + [wire_format.FloatByteSize, -1000000000.0, 4], + # DoubleByteSize(). + [wire_format.DoubleByteSize, 0.0, 8], + [wire_format.DoubleByteSize, 1000000000.0, 8], + [wire_format.DoubleByteSize, -1000000000.0, 8], + # BoolByteSize(). + [wire_format.BoolByteSize, False, 1], + [wire_format.BoolByteSize, True, 1], + # EnumByteSize(). + [wire_format.EnumByteSize, 0, 1], + [wire_format.EnumByteSize, 127, 1], + [wire_format.EnumByteSize, 128, 2], + [wire_format.EnumByteSize, wire_format.UINT32_MAX, 5], + ] + for args in NUMERIC_ARGS: + self.NumericByteSizeTestHelper(*args) + + # Test strings and bytes. + for byte_size_fn in (wire_format.StringByteSize, wire_format.BytesByteSize): + # 1 byte for tag, 1 byte for length, 3 bytes for contents. + self.assertEqual(5, byte_size_fn(10, 'abc')) + # 2 bytes for tag, 1 byte for length, 3 bytes for contents. + self.assertEqual(6, byte_size_fn(16, 'abc')) + # 2 bytes for tag, 2 bytes for length, 128 bytes for contents. + self.assertEqual(132, byte_size_fn(16, 'a' * 128)) + + class MockMessage(object): + def __init__(self, byte_size): + self.byte_size = byte_size + def ByteSize(self): + return self.byte_size + + message_byte_size = 10 + mock_message = MockMessage(byte_size=message_byte_size) + # Test groups. + # (2 * 1) bytes for begin and end tags, plus message_byte_size. 
+ self.assertEqual(2 + message_byte_size, + wire_format.GroupByteSize(1, mock_message)) + # (2 * 2) bytes for begin and end tags, plus message_byte_size. + self.assertEqual(4 + message_byte_size, + wire_format.GroupByteSize(16, mock_message)) + + # Test messages. + # 1 byte for tag, plus 1 byte for length, plus contents. + self.assertEqual(2 + mock_message.byte_size, + wire_format.MessageByteSize(1, mock_message)) + # 2 bytes for tag, plus 1 byte for length, plus contents. + self.assertEqual(3 + mock_message.byte_size, + wire_format.MessageByteSize(16, mock_message)) + # 2 bytes for tag, plus 2 bytes for length, plus contents. + mock_message.byte_size = 128 + self.assertEqual(4 + mock_message.byte_size, + wire_format.MessageByteSize(16, mock_message)) + + + # Test message set item byte size. + # 4 bytes for tags, plus 1 byte for length, plus 1 byte for type_id, + # plus contents. + mock_message.byte_size = 10 + self.assertEqual(mock_message.byte_size + 6, + wire_format.MessageSetItemByteSize(1, mock_message)) + + # 4 bytes for tags, plus 2 bytes for length, plus 1 byte for type_id, + # plus contents. + mock_message.byte_size = 128 + self.assertEqual(mock_message.byte_size + 7, + wire_format.MessageSetItemByteSize(1, mock_message)) + + # 4 bytes for tags, plus 2 bytes for length, plus 2 byte for type_id, + # plus contents. + self.assertEqual(mock_message.byte_size + 8, + wire_format.MessageSetItemByteSize(128, mock_message)) + + # Too-long varint. + self.assertRaises(message.EncodeError, + wire_format.UInt64ByteSize, 1, 1 << 128) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/google/protobuf/message.py b/python/google/protobuf/message.py new file mode 100755 index 00000000..9b48f889 --- /dev/null +++ b/python/google/protobuf/message.py @@ -0,0 +1,184 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. 
+# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# TODO(robinson): We should just make these methods all "pure-virtual" and move +# all implementation out, into reflection.py for now. + + +"""Contains an abstract base class for protocol messages.""" + +__author__ = 'robinson@google.com (Will Robinson)' + +from google.protobuf import text_format + +class Error(Exception): pass +class DecodeError(Error): pass +class EncodeError(Error): pass + + +class Message(object): + + """Abstract base class for protocol messages. + + Protocol message classes are almost always generated by the protocol + compiler. These generated types subclass Message and implement the methods + shown below. + + TODO(robinson): Link to an HTML document here. + + TODO(robinson): Document that instances of this class will also + have an Extensions attribute with __getitem__ and __setitem__. + Again, not sure how to best convey this. + + TODO(robinson): Document that the class must also have a static + RegisterExtension(extension_field) method. + Not sure how to best express at this point. + """ + + # TODO(robinson): Document these fields and methods. + + __slots__ = [] + + DESCRIPTOR = None + + def __eq__(self, other_msg): + raise NotImplementedError + + def __ne__(self, other_msg): + # Can't just say self != other_msg, since that would infinitely recurse. 
:) + return not self == other_msg + + def __str__(self): + return text_format.MessageToString(self) + + def MergeFrom(self, other_msg): + raise NotImplementedError + + def CopyFrom(self, other_msg): + raise NotImplementedError + + def Clear(self): + raise NotImplementedError + + def IsInitialized(self): + raise NotImplementedError + + # TODO(robinson): MergeFromString() should probably return None and be + # implemented in terms of a helper that returns the # of bytes read. Our + # deserialization routines would use the helper when recursively + # deserializing, but the end user would almost always just want the no-return + # MergeFromString(). + + def MergeFromString(self, serialized): + """Merges serialized protocol buffer data into this message. + + When we find a field in |serialized| that is already present + in this message: + - If it's a "repeated" field, we append to the end of our list. + - Else, if it's a scalar, we overwrite our field. + - Else, (it's a nonrepeated composite), we recursively merge + into the existing composite. + + TODO(robinson): Document handling of unknown fields. + + Args: + serialized: Any object that allows us to call buffer(serialized) + to access a string of bytes using the buffer interface. + + TODO(robinson): When we switch to a helper, this will return None. + + Returns: + The number of bytes read from |serialized|. + For non-group messages, this will always be len(serialized), + but for messages which are actually groups, this will + generally be less than len(serialized), since we must + stop when we reach an END_GROUP tag. Note that if + we *do* stop because of an END_GROUP tag, the number + of bytes returned does not include the bytes + for the END_GROUP tag information. 
+ """ + raise NotImplementedError + + def ParseFromString(self, serialized): + """Like MergeFromString(), except we clear the object first.""" + self.Clear() + self.MergeFromString(serialized) + + def SerializeToString(self): + raise NotImplementedError + + # TODO(robinson): Decide whether we like these better + # than auto-generated has_foo() and clear_foo() methods + # on the instances themselves. This way is less consistent + # with C++, but it makes reflection-type access easier and + # reduces the number of magically autogenerated things. + # + # TODO(robinson): Be sure to document (and test) exactly + # which field names are accepted here. Are we case-sensitive? + # What do we do with fields that share names with Python keywords + # like 'lambda' and 'yield'? + # + # nnorwitz says: + # """ + # Typically (in python), an underscore is appended to names that are + # keywords. So they would become lambda_ or yield_. + # """ + def ListFields(self, field_name): + """Returns a list of (FieldDescriptor, value) tuples for all + fields in the message which are not empty. A singular field is non-empty + if HasField() would return true, and a repeated field is non-empty if + it contains at least one element. The fields are ordered by field + number""" + raise NotImplementedError + + def HasField(self, field_name): + raise NotImplementedError + + def ClearField(self, field_name): + raise NotImplementedError + + def HasExtension(self, extension_handle): + raise NotImplementedError + + def ClearExtension(self, extension_handle): + raise NotImplementedError + + def ByteSize(self): + """Returns the serialized size of this message. + Recursively calls ByteSize() on all contained messages. + """ + raise NotImplementedError + + def _SetListener(self, message_listener): + """Internal method used by the protocol message implementation. + Clients should not call this directly. + + Sets a listener that this message will call on certain state transitions. 
+ + The purpose of this method is to register back-edges from children to + parents at runtime, for the purpose of setting "has" bits and + byte-size-dirty bits in the parent and ancestor objects whenever a child or + descendant object is modified. + + If the client wants to disconnect this Message from the object tree, she + explicitly sets callback to None. + + If message_listener is None, unregisters any existing listener. Otherwise, + message_listener must implement the MessageListener interface in + internal/message_listener.py, and we discard any listener registered + via a previous _SetListener() call. + """ + raise NotImplementedError diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py new file mode 100755 index 00000000..75202c4e --- /dev/null +++ b/python/google/protobuf/reflection.py @@ -0,0 +1,1734 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# This code is meant to work on Python 2.4 and above only. +# +# TODO(robinson): Helpers for verbose, common checks like seeing if a +# descriptor's cpp_type is CPPTYPE_MESSAGE. + +"""Contains a metaclass and helper functions used to create +protocol message classes from Descriptor objects at runtime. + +Recall that a metaclass is the "type" of a class. +(A class is to a metaclass what an instance is to a class.) 
+ +In this case, we use the GeneratedProtocolMessageType metaclass +to inject all the useful functionality into the classes +output by the protocol compiler at compile-time. + +The upshot of all this is that the real implementation +details for ALL pure-Python protocol buffers are *here in +this file*. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import heapq +import threading +import weakref +# We use "as" to avoid name collisions with variables. +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import message_listener as message_listener_mod +from google.protobuf.internal import wire_format +from google.protobuf import descriptor as descriptor_mod +from google.protobuf import message as message_mod + +_FieldDescriptor = descriptor_mod.FieldDescriptor + + +class GeneratedProtocolMessageType(type): + + """Metaclass for protocol message classes created at runtime from Descriptors. + + We add implementations for all methods described in the Message class. We + also create properties to allow getting/setting all fields in the protocol + message. Finally, we create slots to prevent users from accidentally + "setting" nonexistent fields in the protocol message, which then wouldn't get + serialized / deserialized properly. + + The protocol compiler currently uses this metaclass to create protocol + message classes at runtime. Clients can also manually create their own + classes at runtime, as in this example: + + mydescriptor = Descriptor(.....) + class MyProtoClass(Message): + __metaclass__ = GeneratedProtocolMessageType + DESCRIPTOR = mydescriptor + myproto_instance = MyProtoClass() + myproto.foo_field = 23 + ... + """ + + # Must be consistent with the protocol-compiler code in + # proto2/compiler/internal/generator.*. + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __new__(cls, name, bases, dictionary): + """Custom allocation for runtime-generated class types. 
+ + We override __new__ because this is apparently the only place + where we can meaningfully set __slots__ on the class we're creating(?). + (The interplay between metaclasses and slots is not very well-documented). + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + + Returns: + Newly-allocated class. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + _AddSlots(descriptor, dictionary) + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + superclass = super(GeneratedProtocolMessageType, cls) + return superclass.__new__(cls, name, bases, dictionary) + + def __init__(cls, name, bases, dictionary): + """Here we perform the majority of our work on the class. + We add enum getters, an __init__ method, implementations + of all Message methods, and properties for all fields + in the protocol type. + + Args: + name: Name of the class (ignored, but required by the + metaclass protocol). + bases: Base classes of the class we're constructing. + (Should be message.Message). We ignore this field, but + it's required by the metaclass protocol + dictionary: The class dictionary of the class we're + constructing. dictionary[_DESCRIPTOR_KEY] must contain + a Descriptor object describing this protocol message + type. + """ + descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + # We act as a "friend" class of the descriptor, setting + # its _concrete_class attribute the first time we use a + # given descriptor to initialize a concrete protocol message + # class. 
+ concrete_class_attr_name = '_concrete_class' + if not hasattr(descriptor, concrete_class_attr_name): + setattr(descriptor, concrete_class_attr_name, cls) + cls._known_extensions = [] + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(cls) + superclass = super(GeneratedProtocolMessageType, cls) + superclass.__init__(cls, name, bases, dictionary) + + +# Stateless helpers for GeneratedProtocolMessageType below. +# Outside clients should not access these directly. +# +# I opted not to make any of these methods on the metaclass, to make it more +# clear that I'm not really using any state there and to keep clients from +# thinking that they have direct access to these construction helpers. + + +def _PropertyName(proto_field_name): + """Returns the name of the public property attribute which + clients can use to get and (in some cases) set the value + of a protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. + # nnorwitz makes my day by writing: + # """ + # FYI. See the keyword module in the stdlib. This could be as simple as: + # + # if keyword.iskeyword(proto_field_name): + # return proto_field_name + "_" + # return proto_field_name + # """ + return proto_field_name + + +def _ValueFieldName(proto_field_name): + """Returns the name of the (internal) instance attribute which objects + should use to store the current value for a given protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. 
+ """ + return '_value_' + proto_field_name + + +def _HasFieldName(proto_field_name): + """Returns the name of the (internal) instance attribute which + objects should use to store a boolean telling whether this field + is explicitly set or not. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + return '_has_' + proto_field_name + + +def _AddSlots(message_descriptor, dictionary): + """Adds a __slots__ entry to dictionary, containing the names of all valid + attributes for this message type. + + Args: + message_descriptor: A Descriptor instance describing this message type. + dictionary: Class dictionary to which we'll add a '__slots__' entry. + """ + field_names = [_ValueFieldName(f.name) for f in message_descriptor.fields] + field_names.extend(_HasFieldName(f.name) for f in message_descriptor.fields + if f.label != _FieldDescriptor.LABEL_REPEATED) + field_names.extend(('Extensions', + '_cached_byte_size', + '_cached_byte_size_dirty', + '_called_transition_to_nonempty', + '_listener', + '_lock', '__weakref__')) + dictionary['__slots__'] = field_names + + +def _AddClassAttributesForNestedExtensions(descriptor, dictionary): + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + assert extension_name not in dictionary + dictionary[extension_name] = extension_field + + +def _AddEnumValues(descriptor, cls): + """Sets class-level attributes for all enum fields defined in this message. + + Args: + descriptor: Descriptor object for this message type. + cls: Class we're constructing for this message type. + """ + for enum_type in descriptor.enum_types: + for enum_value in enum_type.values: + setattr(cls, enum_value.name, enum_value.number) + + +def _DefaultValueForField(message, field): + """Returns a default value for a field. + + Args: + message: Message instance containing this field, or a weakref proxy + of same. 
+ field: FieldDescriptor object for this field. + + Returns: A default value for this field. May refer back to |message| + via a weak reference. + """ + # TODO(robinson): Only the repeated fields need a reference to 'message' (so + # that they can set the 'has' bit on the containing Message when someone + # append()s a value). We could special-case this, and avoid an extra + # function call on __init__() and Clear() for non-repeated fields. + + # TODO(robinson): Find a better place for the default value assertion in this + # function. No need to repeat them every time the client calls Clear('foo'). + # (We should probably just assert these things once and as early as possible, + # by tightening checking in the descriptor classes.) + if field.label == _FieldDescriptor.LABEL_REPEATED: + if field.default_value != []: + raise ValueError('Repeated field default value not empty list: %s' % ( + field.default_value)) + listener = _Listener(message, None) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # We can't look at _concrete_class yet since it might not have + # been set. (Depends on order in which we initialize the classes). + return _RepeatedCompositeFieldContainer(listener, field.message_type) + else: + return _RepeatedScalarFieldContainer(listener, + _VALUE_CHECKERS[field.cpp_type]) + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + assert field.default_value is None + + return field.default_value + + +def _AddInitMethod(message_descriptor, cls): + """Adds an __init__ method to cls.""" + fields = message_descriptor.fields + def init(self): + self._cached_byte_size = 0 + self._cached_byte_size_dirty = False + self._listener = message_listener_mod.NullMessageListener() + self._called_transition_to_nonempty = False + # TODO(robinson): We should only create a lock if we really need one + # in this class. 
    self._lock = threading.Lock()
    for field in fields:
      default_value = _DefaultValueForField(self, field)
      python_field_name = _ValueFieldName(field.name)
      setattr(self, python_field_name, default_value)
      if field.label != _FieldDescriptor.LABEL_REPEATED:
        setattr(self, _HasFieldName(field.name), False)
    self.Extensions = _ExtensionDict(self, cls._known_extensions)

  # NOTE(review): __module__/__doc__ are cleared on the synthesized
  # function, presumably so generated classes carry no stale metadata —
  # confirm before relying on either attribute.
  init.__module__ = None
  init.__doc__ = None
  cls.__init__ = init


def _AddPropertiesForFields(descriptor, cls):
  """Adds properties for all fields in this protocol message type."""
  for field in descriptor.fields:
    _AddPropertiesForField(field, cls)


def _AddPropertiesForField(field, cls):
  """Adds a public property for a protocol message field.
  Clients can use this property to get and (in the case
  of non-repeated scalar fields) directly set the value
  of a protocol message field.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # Catch it if we add other types that we should
  # handle specially here.
  assert _FieldDescriptor.MAX_CPPTYPE == 10

  if field.label == _FieldDescriptor.LABEL_REPEATED:
    _AddPropertiesForRepeatedField(field, cls)
  elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    _AddPropertiesForNonRepeatedCompositeField(field, cls)
  else:
    _AddPropertiesForNonRepeatedScalarField(field, cls)


def _AddPropertiesForRepeatedField(field, cls):
  """Adds a public property for a "repeated" protocol message field.  Clients
  can use this property to get the value of the field, which will be either a
  _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see
  below).

  Note that when clients add values to these containers, we perform
  type-checking in the case of repeated scalar fields, and we also set any
  necessary "has" bits as a side-effect.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  python_field_name = _ValueFieldName(proto_field_name)
  property_name = _PropertyName(proto_field_name)

  def getter(self):
    return getattr(self, python_field_name)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to repeated field '
                         '"%s" in protocol message object.' % proto_field_name)

  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedScalarField(field, cls):
  """Adds a public property for a nonrepeated, scalar protocol message field.
  Clients can use this property to get and directly set the value of the field.
  Note that when the client sets the value of a field by using this property,
  all necessary "has" bits are set as a side-effect, and we also perform
  type-checking.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  proto_field_name = field.name
  python_field_name = _ValueFieldName(proto_field_name)
  has_field_name = _HasFieldName(proto_field_name)
  property_name = _PropertyName(proto_field_name)
  type_checker = _VALUE_CHECKERS[field.cpp_type]

  def getter(self):
    return getattr(self, python_field_name)
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name
  def setter(self, new_value):
    # Validate first so a bad value leaves the message untouched; only
    # then flip the "has" bit, invalidate cached sizes, notify parents,
    # and finally store the value.
    type_checker.CheckValue(new_value)
    setattr(self, has_field_name, True)
    self._MarkByteSizeDirty()
    self._MaybeCallTransitionToNonemptyCallback()
    setattr(self, python_field_name, new_value)
  setter.__module__ = None
  setter.__doc__ = 'Setter for %s.' % proto_field_name

  # Add a property to encapsulate the getter/setter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddPropertiesForNonRepeatedCompositeField(field, cls):
  """Adds a public property for a nonrepeated, composite protocol message field.
  A composite field is a "group" or "message" field.

  Clients can use this property to get the value of the field, but cannot
  assign to the property directly.

  Args:
    field: A FieldDescriptor for this field.
    cls: The class we're constructing.
  """
  # TODO(robinson): Remove duplication with similar method
  # for non-repeated scalars.
  proto_field_name = field.name
  python_field_name = _ValueFieldName(proto_field_name)
  has_field_name = _HasFieldName(proto_field_name)
  property_name = _PropertyName(proto_field_name)
  message_type = field.message_type

  def getter(self):
    # Double-checked locking: the unlocked read is the fast path; the
    # second read under self._lock guards against another thread having
    # created the submessage between our first read and lock acquisition.
    # NOTE(review): the usual double-checked-locking hazards around
    # memory visibility apply; safe under CPython's GIL — confirm for
    # other runtimes.
    field_value = getattr(self, python_field_name)
    if field_value is None:
      self._lock.acquire()
      try:
        field_value = getattr(self, python_field_name)
        if field_value is None:
          field_class = message_type._concrete_class
          field_value = field_class()
          field_value._SetListener(_Listener(self, has_field_name))
          setattr(self, python_field_name, field_value)
      finally:
        self._lock.release()
    return field_value
  getter.__module__ = None
  getter.__doc__ = 'Getter for %s.' % proto_field_name

  # We define a setter just so we can throw an exception with a more
  # helpful error message.
  def setter(self, new_value):
    raise AttributeError('Assignment not allowed to composite field '
                         '"%s" in protocol message object.' % proto_field_name)

  # Add a property to encapsulate the getter.
  doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
  setattr(cls, property_name, property(getter, setter, doc=doc))


def _AddStaticMethods(cls):
  # TODO(robinson): This probably needs to be thread-safe(?)
+ def RegisterExtension(extension_handle): + extension_handle.containing_type = cls.DESCRIPTOR + cls._known_extensions.append(extension_handle) + cls.RegisterExtension = staticmethod(RegisterExtension) + + +def _AddListFieldsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + # Ensure that we always list in ascending field-number order. + # For non-extension fields, we can do the sort once, here, at import-time. + # For extensions, we sort on each ListFields() call, though + # we could do better if we have to. + fields = sorted(message_descriptor.fields, key=lambda f: f.number) + has_field_names = (_HasFieldName(f.name) for f in fields) + value_field_names = (_ValueFieldName(f.name) for f in fields) + triplets = zip(has_field_names, value_field_names, fields) + + def ListFields(self): + # We need to list all extension and non-extension fields + # together, in sorted order by field number. + + # Step 0: Get an iterator over all "set" non-extension fields, + # sorted by field number. + # This iterator yields (field_number, field_descriptor, value) tuples. + def SortedSetFieldsIter(): + # Note that triplets is already sorted by field number. + for has_field_name, value_field_name, field_descriptor in triplets: + if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: + value = getattr(self, _ValueFieldName(field_descriptor.name)) + if len(value) > 0: + yield (field_descriptor.number, field_descriptor, value) + elif getattr(self, _HasFieldName(field_descriptor.name)): + value = getattr(self, _ValueFieldName(field_descriptor.name)) + yield (field_descriptor.number, field_descriptor, value) + sorted_fields = SortedSetFieldsIter() + + # Step 1: Get an iterator over all "set" extension fields, + # sorted by field number. + # This iterator ALSO yields (field_number, field_descriptor, value) tuples. + # TODO(robinson): It's not necessary to repeat this with each + # serialization call. We can do better. 
+ sorted_extension_fields = sorted( + [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()]) + + # Step 2: Create a composite iterator that merges the extension- + # and non-extension fields, and that still yields fields in + # sorted order. + all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields) + + # Step 3: Strip off the field numbers and return. + return [field[1:] for field in all_set_fields] + + cls.ListFields = ListFields + +def _AddHasFieldMethod(cls): + """Helper for _AddMessageMethods().""" + def HasField(self, field_name): + try: + return getattr(self, _HasFieldName(field_name)) + except AttributeError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + cls.HasField = HasField + + +def _AddClearFieldMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearField(self, field_name): + try: + field = self.DESCRIPTOR.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + proto_field_name = field.name + python_field_name = _ValueFieldName(proto_field_name) + has_field_name = _HasFieldName(proto_field_name) + default_value = _DefaultValueForField(self, field) + if field.label == _FieldDescriptor.LABEL_REPEATED: + self._MarkByteSizeDirty() + else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + old_field_value = getattr(self, python_field_name) + if old_field_value is not None: + # Snip the old object out of the object tree. + old_field_value._SetListener(None) + if getattr(self, has_field_name): + setattr(self, has_field_name, False) + # Set dirty bit on ourself and parents only if + # we're actually changing state. 
+ self._MarkByteSizeDirty() + setattr(self, python_field_name, default_value) + cls.ClearField = ClearField + + +def _AddClearExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearExtension(self, extension_handle): + self.Extensions._ClearExtension(extension_handle) + cls.ClearExtension = ClearExtension + + +def _AddClearMethod(cls): + """Helper for _AddMessageMethods().""" + def Clear(self): + # Clear fields. + fields = self.DESCRIPTOR.fields + for field in fields: + self.ClearField(field.name) + # Clear extensions. + extensions = self.Extensions._ListSetExtensions() + for extension in extensions: + self.ClearExtension(extension[0]) + cls.Clear = Clear + + +def _AddHasExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def HasExtension(self, extension_handle): + return self.Extensions._HasExtension(extension_handle) + cls.HasExtension = HasExtension + + +def _AddEqualsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __eq__(self, other): + if self is other: + return True + + # Compare all fields contained directly in this message. + for field_descriptor in message_descriptor.fields: + label = field_descriptor.label + property_name = _PropertyName(field_descriptor.name) + # Non-repeated field equality requires matching "has" bits as well + # as having an equal value. + if label != _FieldDescriptor.LABEL_REPEATED: + self_has = self.HasField(property_name) + other_has = other.HasField(property_name) + if self_has != other_has: + return False + if not self_has: + # If the "has" bit for this field is False, we must stop here. + # Otherwise we will recurse forever on recursively-defined protos. + continue + if getattr(self, property_name) != getattr(other, property_name): + return False + + # Compare the extensions present in both messages. 
+ return self.Extensions == other.Extensions + cls.__eq__ = __eq__ + + +def _AddSetListenerMethod(cls): + """Helper for _AddMessageMethods().""" + def SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + cls._SetListener = SetListener + + +def _BytesForNonRepeatedElement(value, field_number, field_type): + """Returns the number of bytes needed to serialize a non-repeated element. + The returned byte count includes space for tag information and any + other additional space associated with serializing value. + + Args: + value: Value we're serializing. + field_number: Field number of this value. (Since the field number + is stored as part of a varint-encoded tag, this has an impact + on the total bytes required to serialize the value). + field_type: The type of the field. One of the TYPE_* constants + within FieldDescriptor. + """ + try: + fn = _TYPE_TO_BYTE_SIZE_FN[field_type] + return fn(field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) + + +def _AddByteSizeMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def BytesForField(message, field, value): + """Returns the number of bytes required to serialize a single field + in message. The field may be repeated or not, composite or not. + + Args: + message: The Message instance containing a field of the given type. + field: A FieldDescriptor describing the field of interest. + value: The value whose byte size we're interested in. + + Returns: The number of bytes required to serialize the current value + of "field" in "message", including space for tags and any other + necessary information. + """ + + if _MessageSetField(field): + return wire_format.MessageSetItemByteSize(field.number, value) + + field_number, field_type = field.number, field.type + + # Repeated fields. 
+ if field.label == _FieldDescriptor.LABEL_REPEATED: + elements = value + else: + elements = [value] + + size = sum(_BytesForNonRepeatedElement(element, field_number, field_type) + for element in elements) + return size + + fields = message_descriptor.fields + has_field_names = (_HasFieldName(f.name) for f in fields) + zipped = zip(has_field_names, fields) + + def ByteSize(self): + if not self._cached_byte_size_dirty: + return self._cached_byte_size + + size = 0 + # Hardcoded fields first. + for has_field_name, field in zipped: + if (field.label == _FieldDescriptor.LABEL_REPEATED + or getattr(self, has_field_name)): + value = getattr(self, _ValueFieldName(field.name)) + size += BytesForField(self, field, value) + # Extensions next. + for field, value in self.Extensions._ListSetExtensions(): + size += BytesForField(self, field, value) + + self._cached_byte_size = size + self._cached_byte_size_dirty = False + return size + cls.ByteSize = ByteSize + + +def _MessageSetField(field_descriptor): + """Checks if a field should be serialized using the message set wire format. + + Args: + field_descriptor: Descriptor of the field. + + Returns: + True if the field should be serialized using the message set wire format, + false otherwise. + """ + return (field_descriptor.is_extension and + field_descriptor.label != _FieldDescriptor.LABEL_REPEATED and + field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and + field_descriptor.containing_type.GetOptions().message_set_wire_format) + + +def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder): + """Appends the serialization of a single value to encoder. + + Args: + value: Value to serialize. + field_number: Field number of this value. + field_descriptor: Descriptor of the field to serialize. + encoder: encoder.Encoder object to which we should serialize this value. 
+ """ + if _MessageSetField(field_descriptor): + encoder.AppendMessageSetItem(field_number, value) + return + + try: + method = _TYPE_TO_SERIALIZE_METHOD[field_descriptor.type] + method(encoder, field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % + field_descriptor.type) + + +def _ImergeSorted(*streams): + """Merges N sorted iterators into a single sorted iterator. + Each element in streams must be an iterable that yields + its elements in sorted order, and the elements contained + in each stream must all be comparable. + + There may be repeated elements in the component streams or + across the streams; the repeated elements will all be repeated + in the merged iterator as well. + + I believe that the heapq module at HEAD in the Python + sources has a method like this, but for now we roll our own. + """ + iters = [iter(stream) for stream in streams] + heap = [] + for index, it in enumerate(iters): + try: + heap.append((it.next(), index)) + except StopIteration: + pass + heapq.heapify(heap) + + while heap: + smallest_value, idx = heap[0] + yield smallest_value + try: + next_element = iters[idx].next() + heapq.heapreplace(heap, (next_element, idx)) + except StopIteration: + heapq.heappop(heap) + + +def _AddSerializeToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + Encoder = encoder.Encoder + + def SerializeToString(self): + encoder = Encoder() + # We need to serialize all extension and non-extension fields + # together, in sorted order by field number. + + # Step 3: Iterate over all extension and non-extension fields, sorted + # in order of tag number, and serialize each one to the wire. 
    for field_descriptor, field_value in self.ListFields():
      if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED:
        repeated_value = field_value
      else:
        # Wrap the scalar so both cases share one serialization loop.
        repeated_value = [field_value]
      for element in repeated_value:
        _SerializeValueToEncoder(element, field_descriptor.number,
                                 field_descriptor, encoder)
    return encoder.ToString()
  cls.SerializeToString = SerializeToString


def _WireTypeForFieldType(field_type):
  """Given a field type, returns the expected wire type.

  Raises:
    message_mod.DecodeError: If field_type is not a known TYPE_* constant.
  """
  try:
    return _FIELD_TYPE_TO_WIRE_TYPE[field_type]
  except KeyError:
    raise message_mod.DecodeError('Unknown field type: %d' % field_type)


def _RecursivelyMerge(field_number, field_type, decoder, message):
  """Decodes a message from decoder into message.
  message is either a group or a nested message within some containing
  protocol message.  If it's a group, we use the group protocol to
  deserialize, and if it's a nested message, we use the nested-message
  protocol.

  Args:
    field_number: The field number of message in its enclosing protocol buffer.
    field_type: The field type of message.  Must be either TYPE_MESSAGE
      or TYPE_GROUP.
    decoder: Decoder to read from.
    message: Message to deserialize into.

  Raises:
    message_mod.DecodeError: If field_type is neither TYPE_MESSAGE
      nor TYPE_GROUP.
  """
  if field_type == _FieldDescriptor.TYPE_MESSAGE:
    decoder.ReadMessageInto(message)
  elif field_type == _FieldDescriptor.TYPE_GROUP:
    decoder.ReadGroupInto(field_number, message)
  else:
    raise message_mod.DecodeError('Unexpected field type: %d' % field_type)


def _DeserializeScalarFromDecoder(field_type, decoder):
  """Deserializes a scalar of the requested type from decoder.  field_type must
  be a scalar (non-group, non-message) FieldDescriptor.FIELD_* constant.

  Raises:
    message_mod.DecodeError: If field_type is not a known scalar type.
  """
  try:
    method = _TYPE_TO_DESERIALIZE_METHOD[field_type]
    return method(decoder)
  except KeyError:
    raise message_mod.DecodeError('Unrecognized field type: %d' % field_type)


def _SkipField(field_number, wire_type, decoder):
  """Skips a field with the specified wire type.

  Args:
    field_number: Tag number of the field to skip.
    wire_type: Wire type of the field to skip.
    decoder: Decoder used to deserialize the message.  It must be positioned
      just after reading the tag and wire type of the field.
  """
  if wire_type == wire_format.WIRETYPE_VARINT:
    # NOTE(review): skipping a varint via ReadInt32() assumes the decoder
    # tolerates values wider than 32 bits here — confirm against decoder.py.
    decoder.ReadInt32()
  elif wire_type == wire_format.WIRETYPE_FIXED64:
    decoder.ReadFixed64()
  elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
    decoder.SkipBytes(decoder.ReadInt32())
  elif wire_type == wire_format.WIRETYPE_START_GROUP:
    _SkipGroup(field_number, decoder)
  elif wire_type == wire_format.WIRETYPE_END_GROUP:
    # Nothing to consume; the caller handles group termination.
    pass
  elif wire_type == wire_format.WIRETYPE_FIXED32:
    decoder.ReadFixed32()
  else:
    raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type)


def _SkipGroup(group_number, decoder):
  """Skips a nested group from the decoder.

  Args:
    group_number: Tag number of the group to skip.
    decoder: Decoder used to deserialize the message.  It must be positioned
      exactly at the beginning of the message that should be skipped.
  """
  while True:
    field_number, wire_type = decoder.ReadFieldNumberAndWireType()
    if (wire_type == wire_format.WIRETYPE_END_GROUP and
        field_number == group_number):
      return
    _SkipField(field_number, wire_type, decoder)


def _DeserializeMessageSetItem(message, decoder):
  """Deserializes a message using the message set wire format.

  Args:
    message: Message to be parsed to.
    decoder: The decoder to be used to deserialize encoded data.  Note that the
      decoder should be positioned just after reading the START_GROUP tag that
      began the messageset item.

  Raises:
    message_mod.DecodeError: If the stream does not follow the expected
      message-set item layout (type_id in field 2, message bytes in field 3,
      terminated by an END_GROUP tag for field 1).
  """
  # Field 2 carries the varint type_id identifying the extension.
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()
  if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2:
    raise message_mod.DecodeError(
        'Incorrect message set wire format. '
        'wire_type: %d, field_number: %d' % (wire_type, field_number))

  type_id = decoder.ReadInt32()
  # Field 3 carries the length-delimited serialized submessage.
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()
  if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3:
    raise message_mod.DecodeError(
        'Incorrect message set wire format. '
        'wire_type: %d, field_number: %d' % (wire_type, field_number))

  extension_dict = message.Extensions
  extensions_by_number = extension_dict._AllExtensionsByNumber()
  if type_id not in extensions_by_number:
    # Unknown extension: skip the payload rather than failing.
    _SkipField(field_number, wire_type, decoder)
    return

  field_descriptor = extensions_by_number[type_id]
  value = extension_dict[field_descriptor]
  decoder.ReadMessageInto(value)
  # Read the END_GROUP tag.
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()
  if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1:
    raise message_mod.DecodeError(
        'Incorrect message set wire format. '
        'wire_type: %d, field_number: %d' % (wire_type, field_number))


def _DeserializeOneEntity(message_descriptor, message, decoder):
  """Deserializes the next wire entity from decoder into message.
  The next wire entity is either a scalar or a nested message,
  and may also be an element in a repeated field (the wire encoding
  is the same).

  Args:
    message_descriptor: A Descriptor instance describing all fields
      in message.
    message: The Message instance into which we're decoding our fields.
    decoder: The Decoder we're using to deserialize encoded data.

  Returns: The number of bytes read from decoder during this method.
  """
  initial_position = decoder.Position()
  field_number, wire_type = decoder.ReadFieldNumberAndWireType()
  extension_dict = message.Extensions
  extensions_by_number = extension_dict._AllExtensionsByNumber()
  if field_number in message_descriptor.fields_by_number:
    # Non-extension field.
+ field_descriptor = message_descriptor.fields_by_number[field_number] + value = getattr(message, _PropertyName(field_descriptor.name)) + def nonextension_setter_fn(scalar): + setattr(message, _PropertyName(field_descriptor.name), scalar) + scalar_setter_fn = nonextension_setter_fn + elif field_number in extensions_by_number: + # Extension field. + field_descriptor = extensions_by_number[field_number] + value = extension_dict[field_descriptor] + def extension_setter_fn(scalar): + extension_dict[field_descriptor] = scalar + scalar_setter_fn = extension_setter_fn + elif wire_type == wire_format.WIRETYPE_END_GROUP: + # We assume we're being parsed as the group that's ended. + return 0 + elif (wire_type == wire_format.WIRETYPE_START_GROUP and + field_number == 1 and + message_descriptor.GetOptions().message_set_wire_format): + # A Message Set item. + _DeserializeMessageSetItem(message, decoder) + return decoder.Position() - initial_position + else: + _SkipField(field_number, wire_type, decoder) + return decoder.Position() - initial_position + + # If we reach this point, we've identified the field as either + # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|, + # and |value| appropriately. Now actually deserialize the thing. + # + # field_descriptor: Describes the field we're deserializing. + # value: The value currently stored in the field to deserialize. + # Used only if the field is composite and/or repeated. + # scalar_setter_fn: A function F such that F(scalar) will + # set a nonrepeated scalar value for this field. Used only + # if this field is a nonrepeated scalar. + + field_number = field_descriptor.number + field_type = field_descriptor.type + expected_wire_type = _WireTypeForFieldType(field_type) + if wire_type != expected_wire_type: + # Need to fill in uninterpreted_bytes. Work for the next CL. 
+ raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.') + + property_name = _PropertyName(field_descriptor.name) + label = field_descriptor.label + cpp_type = field_descriptor.cpp_type + + # Nonrepeated scalar. Just set the field directly. + if (label != _FieldDescriptor.LABEL_REPEATED + and cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): + scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder)) + return decoder.Position() - initial_position + + # Nonrepeated composite. Recursively deserialize. + if label != _FieldDescriptor.LABEL_REPEATED: + composite = value + _RecursivelyMerge(field_number, field_type, decoder, composite) + return decoder.Position() - initial_position + + # Now we know we're dealing with a repeated field of some kind. + element_list = value + + if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: + # Repeated scalar. + element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) + return decoder.Position() - initial_position + else: + # Repeated composite. + composite = element_list.add() + _RecursivelyMerge(field_number, field_type, decoder, composite) + return decoder.Position() - initial_position + + +def _FieldOrExtensionValues(message, field_or_extension): + """Retrieves the list of values for the specified field or extension. + + The target field or extension can be optional, required or repeated, but it + must have value(s) set. The assumption is that the target field or extension + is set (e.g. _HasFieldOrExtension holds true). + + Args: + message: Message which contains the target field or extension. + field_or_extension: Field or extension for which the list of values is + required. Must be an instance of FieldDescriptor. + + Returns: + A list of values for the specified field or extension. This list will only + contain a single element if the field is non-repeated. 
+ """ + if field_or_extension.is_extension: + value = message.Extensions[field_or_extension] + else: + value = getattr(message, _ValueFieldName(field_or_extension.name)) + if field_or_extension.label != _FieldDescriptor.LABEL_REPEATED: + return [value] + else: + # In this case value is a list or repeated values. + return value + + +def _HasFieldOrExtension(message, field_or_extension): + """Checks if a message has the specified field or extension set. + + The field or extension specified can be optional, required or repeated. If + it is repeated, this function returns True. Otherwise it checks the has bit + of the field or extension. + + Args: + message: Message which contains the target field or extension. + field_or_extension: Field or extension to check. This must be a + FieldDescriptor instance. + + Returns: + True if the message has a value set for the specified field or extension, + or if the field or extension is repeated. + """ + if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED: + return True + if field_or_extension.is_extension: + return message.HasExtension(field_or_extension) + else: + return message.HasField(field_or_extension.name) + + +def _IsFieldOrExtensionInitialized(message, field): + """Checks if a message field or extension is initialized. + + Args: + message: The message which contains the field or extension. + field: Field or extension to check. This must be a FieldDescriptor instance. + + Returns: + True if the field/extension can be considered initialized. + """ + # If the field is required and is not set, it isn't initialized. + if field.label == _FieldDescriptor.LABEL_REQUIRED: + if not _HasFieldOrExtension(message, field): + return False + + # If the field is optional and is not set, or if it + # isn't a submessage then the field is initialized. 
+ if field.label == _FieldDescriptor.LABEL_OPTIONAL: + if not _HasFieldOrExtension(message, field): + return True + if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: + return True + + # The field is set and is either a single or a repeated submessage. + messages = _FieldOrExtensionValues(message, field) + # If all submessages in this field are initialized, the field is + # considered initialized. + for message in messages: + if not message.IsInitialized(): + return False + return True + + +def _AddMergeFromStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + Decoder = decoder.Decoder + def MergeFromString(self, serialized): + decoder = Decoder(serialized) + byte_count = 0 + while not decoder.EndOfStream(): + bytes_read = _DeserializeOneEntity(message_descriptor, self, decoder) + if not bytes_read: + break + byte_count += bytes_read + return byte_count + cls.MergeFromString = MergeFromString + + +def _AddIsInitializedMethod(message_descriptor, cls): + """Adds the IsInitialized method to the protocol message class.""" + def IsInitialized(self): + fields_and_extensions = [] + fields_and_extensions.extend(message_descriptor.fields) + fields_and_extensions.extend( + self.Extensions._AllExtensionsByNumber().values()) + for field_or_extension in fields_and_extensions: + if not _IsFieldOrExtensionInitialized(self, field_or_extension): + return False + return True + cls.IsInitialized = IsInitialized + + +def _AddMessageMethods(message_descriptor, cls): + """Adds implementations of all Message methods to cls.""" + + # TODO(robinson): Add support for remaining Message methods. 

  _AddListFieldsMethod(message_descriptor, cls)
  _AddHasFieldMethod(cls)
  _AddClearFieldMethod(cls)
  _AddClearExtensionMethod(cls)
  _AddClearMethod(cls)
  _AddHasExtensionMethod(cls)
  _AddEqualsMethod(message_descriptor, cls)
  _AddSetListenerMethod(cls)
  _AddByteSizeMethod(message_descriptor, cls)
  _AddSerializeToStringMethod(message_descriptor, cls)
  _AddMergeFromStringMethod(message_descriptor, cls)
  _AddIsInitializedMethod(message_descriptor, cls)


def _AddPrivateHelperMethods(cls):
  """Adds implementation of private helper methods to cls."""

  def MaybeCallTransitionToNonemptyCallback(self):
    """Calls self._listener.TransitionToNonempty() the first time this
    method is called.  On all subsequent calls, this is a no-op.
    """
    if not self._called_transition_to_nonempty:
      self._listener.TransitionToNonempty()
      self._called_transition_to_nonempty = True
  cls._MaybeCallTransitionToNonemptyCallback = (
      MaybeCallTransitionToNonemptyCallback)

  def MarkByteSizeDirty(self):
    """Sets the _cached_byte_size_dirty bit to true,
    and propagates this to our listener iff this was a state change.
    """
    if not self._cached_byte_size_dirty:
      self._cached_byte_size_dirty = True
      self._listener.ByteSizeDirty()
  cls._MarkByteSizeDirty = MarkByteSizeDirty


class _Listener(object):

  """MessageListener implementation that a parent message registers with its
  child message.

  In order to support semantics like:

    foo.bar.baz = 23
    assert foo.HasField('bar')

  ...child objects must have back references to their parents.
  This helper class is at the heart of this support.
  """

  def __init__(self, parent_message, has_field_name):
    """Args:
      parent_message: The message whose _MaybeCallTransitionToNonemptyCallback()
        and _MarkByteSizeDirty() methods we should call when we receive
        TransitionToNonempty() and ByteSizeDirty() messages.
      has_field_name: The name of the "has" field that we should set in
        the parent message when we receive a TransitionToNonempty message,
        or None if there's no "has" field to set.  (This will be the case
        for child objects in "repeated" fields).
    """
    # This listener establishes a back reference from a child (contained) object
    # to its parent (containing) object.  We make this a weak reference to avoid
    # creating cyclic garbage when the client finishes with the 'parent' object
    # in the tree.
    if isinstance(parent_message, weakref.ProxyType):
      # Already a proxy (e.g. handed down from another listener); wrapping
      # it again would proxy a proxy.
      self._parent_message_weakref = parent_message
    else:
      self._parent_message_weakref = weakref.proxy(parent_message)
    self._has_field_name = has_field_name

  def TransitionToNonempty(self):
    try:
      if self._has_field_name is not None:
        setattr(self._parent_message_weakref, self._has_field_name, True)
      # Propagate the signal to our parents iff this is the first field set.
      self._parent_message_weakref._MaybeCallTransitionToNonemptyCallback()
    except ReferenceError:
      # We can get here if a client has kept a reference to a child object,
      # and is now setting a field on it, but the child's parent has been
      # garbage-collected.  This is not an error.
      pass

  def ByteSizeDirty(self):
    try:
      self._parent_message_weakref._MarkByteSizeDirty()
    except ReferenceError:
      # Same as above.
      pass


# TODO(robinson): Move elsewhere?
# TODO(robinson): Provide a clear() method here in addition to ClearField()?
class _RepeatedScalarFieldContainer(object):

  """Simple, type-checked, list-like container for holding repeated scalars.
  """

  def __init__(self, message_listener, type_checker):
    """
    Args:
      message_listener: A MessageListener implementation.
        The _RepeatedScalarFieldContaininer will call this object's
        TransitionToNonempty() method when it transitions from being empty to
        being nonempty.
      type_checker: A _ValueChecker instance to run on elements inserted
        into this container.
    """
    self._message_listener = message_listener
    self._type_checker = type_checker
    self._values = []

  def append(self, elem):
    # Validate before mutating so a bad element leaves us unchanged.
    self._type_checker.CheckValue(elem)
    self._values.append(elem)
    self._message_listener.ByteSizeDirty()
    if len(self._values) == 1:
      self._message_listener.TransitionToNonempty()

  # List-like __getitem__() support also makes us iterable (via "iter(foo)"
  # or implicitly via "for i in mylist:") for free.
  def __getitem__(self, key):
    return self._values[key]

  def __setitem__(self, key, value):
    # No need to call TransitionToNonempty(), since if we're able to
    # set the element at this index, we were already nonempty before
    # this method was called.
    self._message_listener.ByteSizeDirty()
    self._type_checker.CheckValue(value)
    self._values[key] = value

  def __len__(self):
    return len(self._values)

  def __eq__(self, other):
    # NOTE(review): __eq__ is defined without __hash__; under Python 3 this
    # makes instances unhashable (Python 2 falls back to identity hashing) —
    # confirm nothing relies on hashing these containers.
    if self is other:
      return True
    # Special case for the same type which should be common and fast.
    if isinstance(other, self.__class__):
      return other._values == self._values
    # We are presumably comparing against some other sequence type.
    return other == self._values

  def __ne__(self, other):
    # Can't use != here since it would infinitely recurse.
    return not self == other


# TODO(robinson): Move elsewhere?
# TODO(robinson): Provide a clear() method here in addition to ClearField()?
# TODO(robinson): Unify common functionality with
# _RepeatedScalarFieldContaininer?
class _RepeatedCompositeFieldContainer(object):

  """Simple, list-like container for holding repeated composite fields.
  """

  def __init__(self, message_listener, message_descriptor):
    """Note that we pass in a descriptor instead of the generated directly,
    since at the time we construct a _RepeatedCompositeFieldContainer we
    haven't yet necessarily initialized the type that will be contained in the
    container.

    Args:
      message_listener: A MessageListener implementation.
+ The _RepeatedCompositeFieldContainer will call this object's + TransitionToNonempty() method when it transitions from being empty to + being nonempty. + message_descriptor: A Descriptor instance describing the protocol type + that should be present in this container. We'll use the + _concrete_class field of this descriptor when the client calls add(). + """ + self._message_listener = message_listener + self._message_descriptor = message_descriptor + self._values = [] + + def add(self): + new_element = self._message_descriptor._concrete_class() + new_element._SetListener(self._message_listener) + self._values.append(new_element) + self._message_listener.ByteSizeDirty() + self._message_listener.TransitionToNonempty() + return new_element + + # List-like __getitem__() support also makes us iterable (via "iter(foo)" + # or implicitly via "for i in mylist:") for free. + def __getitem__(self, key): + return self._values[key] + + def __len__(self): + return len(self._values) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, self.__class__): + raise TypeError('Can only compare repeated composite fields against ' + 'other repeated composite fields.') + return self._values == other._values + + def __ne__(self, other): + # Can't use != here since it would infinitely recurse. + return not self == other + + # TODO(robinson): Implement, document, and test slicing support. + + +# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... +# TODO(robinson): Unify error handling of "unknown extension" crap. +# TODO(robinson): There's so much similarity between the way that +# extensions behave and the way that normal fields behave that it would +# be really nice to unify more code. It's not immediately obvious +# how to do this, though, and I'd rather get the full functionality +# implemented (and, crucially, get all the tests and specs fleshed out +# and passing), and then come back to this thorny unification problem. 
+# TODO(robinson): Support iteritems()-style iteration over all +# extensions with the "has" bits turned on? +class _ExtensionDict(object): + + """Dict-like container for supporting an indexable "Extensions" + field on proto instances. + + Note that in all cases we expect extension handles to be + FieldDescriptors. + """ + + class _ExtensionListener(object): + + """Adapts an _ExtensionDict to behave as a MessageListener.""" + + def __init__(self, extension_dict, handle_id): + self._extension_dict = extension_dict + self._handle_id = handle_id + + def TransitionToNonempty(self): + self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id) + + def ByteSizeDirty(self): + self._extension_dict._SubmessageByteSizeBecameDirty() + + # TODO(robinson): Somewhere, we need to blow up if people + # try to register two extensions with the same field number. + # (And we need a test for this of course). + + def __init__(self, extended_message, known_extensions): + """extended_message: Message instance for which we are the Extensions dict. + known_extensions: Iterable of known extension handles. + These must be FieldDescriptors. + """ + # We keep a weak reference to extended_message, since + # it has a reference to this instance in turn. + self._extended_message = weakref.proxy(extended_message) + # We make a deep copy of known_extensions to avoid any + # thread-safety concerns, since the argument passed in + # is the global (class-level) dict of known extensions for + # this type of message, which could be modified at any time + # via a RegisterExtension() call. + # + # This dict maps from handle id to handle (a FieldDescriptor). + # + # XXX + # TODO(robinson): This isn't good enough. The client could + # instantiate an object in module A, then afterward import + # module B and pass the instance to B.Foo(). If B imports + # an extender of this proto and then tries to use it, B + # will get a KeyError, even though the extension *is* registered + # at the time of use. 
+ # XXX + self._known_extensions = dict((id(e), e) for e in known_extensions) + # Read lock around self._values, which may be modified by multiple + # concurrent readers in the conceptually "const" __getitem__ method. + # So, we grab this lock in every "read-only" method to ensure + # that concurrent read access is safe without external locking. + self._lock = threading.Lock() + # Maps from extension handle ID to current value of that extension. + self._values = {} + # Maps from extension handle ID to a boolean "has" bit, but only + # for non-repeated extension fields. + keys = (id for id, extension in self._known_extensions.iteritems() + if extension.label != _FieldDescriptor.LABEL_REPEATED) + self._has_bits = dict.fromkeys(keys, False) + + def __getitem__(self, extension_handle): + """Returns the current value of the given extension handle.""" + # We don't care as much about keeping critical sections short in the + # extension support, since it's presumably much less of a common case. + self._lock.acquire() + try: + handle_id = id(extension_handle) + if handle_id not in self._known_extensions: + raise KeyError('Extension not known to this class') + if handle_id not in self._values: + self._AddMissingHandle(extension_handle, handle_id) + return self._values[handle_id] + finally: + self._lock.release() + + def __eq__(self, other): + # We have to grab read locks since we're accessing _values + # in a "const" method. See the comment in the constructor. + if self is other: + return True + self._lock.acquire() + try: + other._lock.acquire() + try: + if self._has_bits != other._has_bits: + return False + # If there's a "has" bit, then only compare values where it is true. 
+ for k, v in self._values.iteritems(): + if self._has_bits.get(k, False) and v != other._values[k]: + return False + return True + finally: + other._lock.release() + finally: + self._lock.release() + + def __ne__(self, other): + return not self == other + + # Note that this is only meaningful for non-repeated, scalar extension + # fields. Note also that we may have to call + # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field + # this way, to set any necssary "has" bits in the ancestors of the extended + # message. + def __setitem__(self, extension_handle, value): + """If extension_handle specifies a non-repeated, scalar extension + field, sets the value of that field. + """ + handle_id = id(extension_handle) + if handle_id not in self._known_extensions: + raise KeyError('Extension not known to this class') + field = extension_handle # Just shorten the name. + if (field.label == _FieldDescriptor.LABEL_OPTIONAL + and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): + # It's slightly wasteful to lookup the type checker each time, + # but we expect this to be a vanishingly uncommon case anyway. + type_checker = _VALUE_CHECKERS[field.cpp_type] + type_checker.CheckValue(value) + self._values[handle_id] = value + self._has_bits[handle_id] = True + self._extended_message._MarkByteSizeDirty() + self._extended_message._MaybeCallTransitionToNonemptyCallback() + else: + raise TypeError('Extension is repeated and/or a composite type.') + + def _AddMissingHandle(self, extension_handle, handle_id): + """Helper internal to ExtensionDict.""" + # Special handling for non-repeated message extensions, which (like + # normal fields of this kind) are initialized lazily. + # REQUIRES: _lock already held. 
+ cpp_type = extension_handle.cpp_type + label = extension_handle.label + if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE + and label != _FieldDescriptor.LABEL_REPEATED): + self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id) + else: + self._values[handle_id] = _DefaultValueForField( + self._extended_message, extension_handle) + + def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id): + """Helper internal to ExtensionDict.""" + # REQUIRES: _lock already held. + value = extension_handle.message_type._concrete_class() + value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id)) + self._values[handle_id] = value + + def _SubmessageTransitionedToNonempty(self, handle_id): + """Called when a submessage with a given handle id first transitions to + being nonempty. Called by _ExtensionListener. + """ + assert handle_id in self._has_bits + self._has_bits[handle_id] = True + self._extended_message._MaybeCallTransitionToNonemptyCallback() + + def _SubmessageByteSizeBecameDirty(self): + """Called whenever a submessage's cached byte size becomes invalid + (goes from being "clean" to being "dirty"). Called by _ExtensionListener. + """ + self._extended_message._MarkByteSizeDirty() + + # We may wish to widen the public interface of Message.Extensions + # to expose some of this private functionality in the future. + # For now, we make all this functionality module-private and just + # implement what we need for serialization/deserialization, + # HasField()/ClearField(), etc. + + def _HasExtension(self, extension_handle): + """Method for internal use by this module. + Returns true iff we "have" this extension in the sense of the + "has" bit being set. + """ + handle_id = id(extension_handle) + # Note that this is different from the other checks. 
+ if handle_id not in self._has_bits: + raise KeyError('Extension not known to this class, or is repeated field.') + return self._has_bits[handle_id] + + # Intentionally pretty similar to ClearField() above. + def _ClearExtension(self, extension_handle): + """Method for internal use by this module. + Clears the specified extension, unsetting its "has" bit. + """ + handle_id = id(extension_handle) + if handle_id not in self._known_extensions: + raise KeyError('Extension not known to this class') + default_value = _DefaultValueForField(self._extended_message, + extension_handle) + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + self._extended_message._MarkByteSizeDirty() + else: + cpp_type = extension_handle.cpp_type + if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if handle_id in self._values: + # Future modifications to this object shouldn't set any + # "has" bits here. + self._values[handle_id]._SetListener(None) + if self._has_bits[handle_id]: + self._has_bits[handle_id] = False + self._extended_message._MarkByteSizeDirty() + if handle_id in self._values: + del self._values[handle_id] + + def _ListSetExtensions(self): + """Method for internal use by this module. + + Returns an sequence of all extensions that are currently "set" + in this extension dict. A "set" extension is a repeated extension, + or a non-repeated extension with its "has" bit set. + + The returned sequence contains (field_descriptor, value) pairs, + where value is the current value of the extension with the given + field descriptor. + + The sequence values are in arbitrary order. + """ + self._lock.acquire() # Read-only methods must lock around self._values. 
+ try: + set_extensions = [] + for handle_id, value in self._values.iteritems(): + handle = self._known_extensions[handle_id] + if (handle.label == _FieldDescriptor.LABEL_REPEATED + or self._has_bits[handle_id]): + set_extensions.append((handle, value)) + return set_extensions + finally: + self._lock.release() + + def _AllExtensionsByNumber(self): + """Method for internal use by this module. + + Returns: A dict mapping field_number to (handle, field_descriptor), + for *all* registered extensions for this dict. + """ + # TODO(robinson): Precompute and store this away. Note that we'll have to + # be careful when we move away from having _known_extensions as a + # deep-copied member of this object. + return dict((f.number, f) for f in self._known_extensions.itervalues()) + + +# None of the typecheckers below make any attempt to guard against people +# subclassing builtin types and doing weird things. We're not trying to +# protect against malicious clients here, just people accidentally shooting +# themselves in the foot in obvious ways. + +class _TypeChecker(object): + + """Type checker used to catch type errors as early as possible + when the client is setting scalar fields in protocol messages. + """ + + def __init__(self, *acceptable_types): + self._acceptable_types = acceptable_types + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, self._acceptable_types): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), self._acceptable_types)) + raise TypeError(message) + + +# _IntValueChecker and its subclasses perform integer type-checks +# and bounds-checks. +class _IntValueChecker(object): + + """Checker used for integer fields. 
Performs type-check and range check.""" + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (int, long)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (int, long))) + raise TypeError(message) + if not self._MIN <= proposed_value <= self._MAX: + raise ValueError('Value out of range: %d' % proposed_value) + +class _Int32ValueChecker(_IntValueChecker): + # We're sure to use ints instead of longs here since comparison may be more + # efficient. + _MIN = -2147483648 + _MAX = 2147483647 + +class _Uint32ValueChecker(_IntValueChecker): + _MIN = 0 + _MAX = (1 << 32) - 1 + +class _Int64ValueChecker(_IntValueChecker): + _MIN = -(1 << 63) + _MAX = (1 << 63) - 1 + +class _Uint64ValueChecker(_IntValueChecker): + _MIN = 0 + _MAX = (1 << 64) - 1 + + +# Type-checkers for all scalar CPPTYPEs. +_VALUE_CHECKERS = { + _FieldDescriptor.CPPTYPE_INT32: _Int32ValueChecker(), + _FieldDescriptor.CPPTYPE_INT64: _Int64ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT32: _Uint32ValueChecker(), + _FieldDescriptor.CPPTYPE_UINT64: _Uint64ValueChecker(), + _FieldDescriptor.CPPTYPE_DOUBLE: _TypeChecker( + float, int, long), + _FieldDescriptor.CPPTYPE_FLOAT: _TypeChecker( + float, int, long), + _FieldDescriptor.CPPTYPE_BOOL: _TypeChecker(bool, int), + _FieldDescriptor.CPPTYPE_ENUM: _Int32ValueChecker(), + _FieldDescriptor.CPPTYPE_STRING: _TypeChecker(str), + } + + +# Map from field type to a function F, such that F(field_num, value) +# gives the total byte size for a value of the given type. This +# byte size includes tag information and any other additional space +# associated with serializing "value". 
+_TYPE_TO_BYTE_SIZE_FN = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.DoubleByteSize, + _FieldDescriptor.TYPE_FLOAT: wire_format.FloatByteSize, + _FieldDescriptor.TYPE_INT64: wire_format.Int64ByteSize, + _FieldDescriptor.TYPE_UINT64: wire_format.UInt64ByteSize, + _FieldDescriptor.TYPE_INT32: wire_format.Int32ByteSize, + _FieldDescriptor.TYPE_FIXED64: wire_format.Fixed64ByteSize, + _FieldDescriptor.TYPE_FIXED32: wire_format.Fixed32ByteSize, + _FieldDescriptor.TYPE_BOOL: wire_format.BoolByteSize, + _FieldDescriptor.TYPE_STRING: wire_format.StringByteSize, + _FieldDescriptor.TYPE_GROUP: wire_format.GroupByteSize, + _FieldDescriptor.TYPE_MESSAGE: wire_format.MessageByteSize, + _FieldDescriptor.TYPE_BYTES: wire_format.BytesByteSize, + _FieldDescriptor.TYPE_UINT32: wire_format.UInt32ByteSize, + _FieldDescriptor.TYPE_ENUM: wire_format.EnumByteSize, + _FieldDescriptor.TYPE_SFIXED32: wire_format.SFixed32ByteSize, + _FieldDescriptor.TYPE_SFIXED64: wire_format.SFixed64ByteSize, + _FieldDescriptor.TYPE_SINT32: wire_format.SInt32ByteSize, + _FieldDescriptor.TYPE_SINT64: wire_format.SInt64ByteSize + } + +# Maps from field type to an unbound Encoder method F, such that +# F(encoder, field_number, value) will append the serialization +# of a value of this type to the encoder. 
+_Encoder = encoder.Encoder +_TYPE_TO_SERIALIZE_METHOD = { + _FieldDescriptor.TYPE_DOUBLE: _Encoder.AppendDouble, + _FieldDescriptor.TYPE_FLOAT: _Encoder.AppendFloat, + _FieldDescriptor.TYPE_INT64: _Encoder.AppendInt64, + _FieldDescriptor.TYPE_UINT64: _Encoder.AppendUInt64, + _FieldDescriptor.TYPE_INT32: _Encoder.AppendInt32, + _FieldDescriptor.TYPE_FIXED64: _Encoder.AppendFixed64, + _FieldDescriptor.TYPE_FIXED32: _Encoder.AppendFixed32, + _FieldDescriptor.TYPE_BOOL: _Encoder.AppendBool, + _FieldDescriptor.TYPE_STRING: _Encoder.AppendString, + _FieldDescriptor.TYPE_GROUP: _Encoder.AppendGroup, + _FieldDescriptor.TYPE_MESSAGE: _Encoder.AppendMessage, + _FieldDescriptor.TYPE_BYTES: _Encoder.AppendBytes, + _FieldDescriptor.TYPE_UINT32: _Encoder.AppendUInt32, + _FieldDescriptor.TYPE_ENUM: _Encoder.AppendEnum, + _FieldDescriptor.TYPE_SFIXED32: _Encoder.AppendSFixed32, + _FieldDescriptor.TYPE_SFIXED64: _Encoder.AppendSFixed64, + _FieldDescriptor.TYPE_SINT32: _Encoder.AppendSInt32, + _FieldDescriptor.TYPE_SINT64: _Encoder.AppendSInt64, + } + +# Maps from field type to expected wiretype. 
+_FIELD_TYPE_TO_WIRE_TYPE = { + _FieldDescriptor.TYPE_DOUBLE: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FLOAT: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_INT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_UINT64: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_INT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_FIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_FIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_BOOL: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_STRING: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_GROUP: wire_format.WIRETYPE_START_GROUP, + _FieldDescriptor.TYPE_MESSAGE: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_BYTES: + wire_format.WIRETYPE_LENGTH_DELIMITED, + _FieldDescriptor.TYPE_UINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_ENUM: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SFIXED32: wire_format.WIRETYPE_FIXED32, + _FieldDescriptor.TYPE_SFIXED64: wire_format.WIRETYPE_FIXED64, + _FieldDescriptor.TYPE_SINT32: wire_format.WIRETYPE_VARINT, + _FieldDescriptor.TYPE_SINT64: wire_format.WIRETYPE_VARINT, + } + +# Maps from field type to an unbound Decoder method F, +# such that F(decoder) will read a field of the requested type. +# +# Note that Message and Group are intentionally missing here. +# They're handled by _RecursivelyMerge(). 
+_Decoder = decoder.Decoder +_TYPE_TO_DESERIALIZE_METHOD = { + _FieldDescriptor.TYPE_DOUBLE: _Decoder.ReadDouble, + _FieldDescriptor.TYPE_FLOAT: _Decoder.ReadFloat, + _FieldDescriptor.TYPE_INT64: _Decoder.ReadInt64, + _FieldDescriptor.TYPE_UINT64: _Decoder.ReadUInt64, + _FieldDescriptor.TYPE_INT32: _Decoder.ReadInt32, + _FieldDescriptor.TYPE_FIXED64: _Decoder.ReadFixed64, + _FieldDescriptor.TYPE_FIXED32: _Decoder.ReadFixed32, + _FieldDescriptor.TYPE_BOOL: _Decoder.ReadBool, + _FieldDescriptor.TYPE_STRING: _Decoder.ReadString, + _FieldDescriptor.TYPE_BYTES: _Decoder.ReadBytes, + _FieldDescriptor.TYPE_UINT32: _Decoder.ReadUInt32, + _FieldDescriptor.TYPE_ENUM: _Decoder.ReadEnum, + _FieldDescriptor.TYPE_SFIXED32: _Decoder.ReadSFixed32, + _FieldDescriptor.TYPE_SFIXED64: _Decoder.ReadSFixed64, + _FieldDescriptor.TYPE_SINT32: _Decoder.ReadSInt32, + _FieldDescriptor.TYPE_SINT64: _Decoder.ReadSInt64, + } diff --git a/python/google/protobuf/service.py b/python/google/protobuf/service.py new file mode 100755 index 00000000..461031b7 --- /dev/null +++ b/python/google/protobuf/service.py @@ -0,0 +1,194 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Declares the RPC service interfaces. + +This module declares the abstract interfaces underlying proto2 RPC +services. 
These are intended to be independent of any particular RPC
+implementation, so that proto2 services can be used on top of a variety
+of implementations.
+"""
+
+__author__ = 'petar@google.com (Petar Petrov)'
+
+
+class Service(object):
+
+  """Abstract base interface for protocol-buffer-based RPC services.
+
+  Services themselves are abstract classes (implemented either by servers or as
+  stubs), but they subclass this base interface. The methods of this
+  interface can be used to call the methods of the service without knowing
+  its exact type at compile time (analogous to the Message interface).
+  """
+
+  def GetDescriptor(self):
+    """Retrieves this service's descriptor."""
+    raise NotImplementedError
+
+  def CallMethod(self, method_descriptor, rpc_controller,
+                 request, done):
+    """Calls a method of the service specified by method_descriptor.
+
+    Preconditions:
+    * method_descriptor.service == GetDescriptor
+    * request is of the exact same classes as returned by
+      GetRequestClass(method).
+    * After the call has started, the request must not be modified.
+    * "rpc_controller" is of the correct type for the RPC implementation being
+      used by this Service. For stubs, the "correct type" depends on the
+      RpcChannel which the stub is using.
+
+    Postconditions:
+    * "done" will be called when the method is complete. This may be
+      before CallMethod() returns or it may be at some point in the future.
+    """
+    raise NotImplementedError
+
+  def GetRequestClass(self, method_descriptor):
+    """Returns the class of the request message for the specified method.
+
+    CallMethod() requires that the request is of a particular subclass of
+    Message. GetRequestClass() gets the default instance of this required
+    type.
+
+    Example:
+      method = service.GetDescriptor().FindMethodByName("Foo")
+      request = stub.GetRequestClass(method)()
+      request.ParseFromString(input)
+      service.CallMethod(method, request, callback)
+    """
+    raise NotImplementedError
+
+  def GetResponseClass(self, method_descriptor):
+    """Returns the class of the response message for the specified method.
+
+    This method isn't really needed, as the RpcChannel's CallMethod constructs
+    the response protocol message. It's provided anyway in case it is useful
+    for the caller to know the response type in advance.
+    """
+    raise NotImplementedError
+
+
+class RpcController(object):
+
+  """An RpcController mediates a single method call.
+
+  The primary purpose of the controller is to provide a way to manipulate
+  settings specific to the RPC implementation and to find out about RPC-level
+  errors. The methods provided by the RpcController interface are intended
+  to be a "least common denominator" set of features which we expect all
+  implementations to support. Specific implementations may provide more
+  advanced features (e.g. deadline propagation).
+  """
+
+  # Client-side methods below
+
+  def Reset(self):
+    """Resets the RpcController to its initial state.
+
+    After the RpcController has been reset, it may be reused in
+    a new call. Must not be called while an RPC is in progress.
+    """
+    raise NotImplementedError
+
+  def Failed(self):
+    """Returns true if the call failed.
+
+    After a call has finished, returns true if the call failed. The possible
+    reasons for failure depend on the RPC implementation. Failed() must not
+    be called before a call has finished. If Failed() returns true, the
+    contents of the response message are undefined.
+    """
+    raise NotImplementedError
+
+  def ErrorText(self):
+    """If Failed is true, returns a human-readable description of the error."""
+    raise NotImplementedError
+
+  def StartCancel(self):
+    """Initiate cancellation.
+
+    Advises the RPC system that the caller desires that the RPC call be
+    canceled. The RPC system may cancel it immediately, may wait awhile and
+    then cancel it, or may not even cancel the call at all. If the call is
+    canceled, the "done" callback will still be called and the RpcController
+    will indicate that the call failed at that time.
+    """
+    raise NotImplementedError
+
+  # Server-side methods below
+
+  def SetFailed(self, reason):
+    """Sets a failure reason.
+
+    Causes Failed() to return true on the client side. "reason" will be
+    incorporated into the message returned by ErrorText(). If you find
+    you need to return machine-readable information about failures, you
+    should incorporate it into your response protocol buffer and should
+    NOT call SetFailed().
+    """
+    raise NotImplementedError
+
+  def IsCanceled(self):
+    """Checks if the client cancelled the RPC.
+
+    If true, indicates that the client canceled the RPC, so the server may
+    as well give up on replying to it. The server should still call the
+    final "done" callback.
+    """
+    raise NotImplementedError
+
+  def NotifyOnCancel(self, callback):
+    """Sets a callback to invoke on cancel.
+
+    Asks that the given callback be called when the RPC is canceled. The
+    callback will always be called exactly once. If the RPC completes without
+    being canceled, the callback will be called after completion. If the RPC
+    has already been canceled when NotifyOnCancel() is called, the callback
+    will be called immediately.
+
+    NotifyOnCancel() must be called no more than once per request.
+    """
+    raise NotImplementedError
+
+
+class RpcChannel(object):
+
+  """Abstract interface for an RPC channel.
+
+  An RpcChannel represents a communication line to a service which can be used
+  to call that service's methods. The service may be running on another
+  machine. Normally, you should not use an RpcChannel directly, but instead
+  construct a stub Service wrapping it.
+
+  Example:
+    channel = rpcImpl.Channel("remotehost.example.com:1234")
+    controller = rpcImpl.Controller()
+    service = MyService_Stub(channel)
+    service.MyMethod(controller, request, callback)
+  """
+
+  def CallMethod(self, method_descriptor, rpc_controller,
+                 request, response_class, done):
+    """Calls the method identified by the descriptor.
+
+    Call the given method of the remote service. The signature of this
+    procedure looks the same as Service.CallMethod(), but the requirements
+    are less strict in one important way: the request object doesn't have to
+    be of any specific class as long as its descriptor is method.input_type.
+    """
+    raise NotImplementedError
diff --git a/python/google/protobuf/service_reflection.py b/python/google/protobuf/service_reflection.py
new file mode 100755
index 00000000..6e3bf14e
--- /dev/null
+++ b/python/google/protobuf/service_reflection.py
@@ -0,0 +1,275 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc.
+# http://code.google.com/p/protobuf/
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""Contains metaclasses used to create protocol service and service stub
+classes from ServiceDescriptor objects at runtime.
+ +The GeneratedServiceType and GeneratedServiceStubType metaclasses are used to +inject all useful functionality into the classes output by the protocol +compiler at compile-time. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + + +class GeneratedServiceType(type): + + """Metaclass for service classes created at runtime from ServiceDescriptors. + + Implementations for all methods described in the Service class are added here + by this class. We also create properties to allow getting/setting all fields + in the protocol message. + + The protocol compiler currently uses this metaclass to create protocol service + classes at runtime. Clients can also manually create their own classes at + runtime, as in this example: + + mydescriptor = ServiceDescriptor(.....) + class MyProtoService(service.Service): + __metaclass__ = GeneratedServiceType + DESCRIPTOR = mydescriptor + myservice_instance = MyProtoService() + ... + """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service class. + + Args: + name: Name of the class (ignored, but required by the metaclass + protocol). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service class is subclassed. + if GeneratedServiceType._DESCRIPTOR_KEY not in dictionary: + return + descriptor = dictionary[GeneratedServiceType._DESCRIPTOR_KEY] + service_builder = _ServiceBuilder(descriptor) + service_builder.BuildService(cls) + + +class GeneratedServiceStubType(GeneratedServiceType): + + """Metaclass for service stubs created at runtime from ServiceDescriptors. + + This class has similar responsibilities as GeneratedServiceType, except that + it creates the service stub classes. 
+ """ + + _DESCRIPTOR_KEY = 'DESCRIPTOR' + + def __init__(cls, name, bases, dictionary): + """Creates a message service stub class. + + Args: + name: Name of the class (ignored, here). + bases: Base classes of the class being constructed. + dictionary: The class dictionary of the class being constructed. + dictionary[_DESCRIPTOR_KEY] must contain a ServiceDescriptor object + describing this protocol service type. + """ + super(GeneratedServiceStubType, cls).__init__(name, bases, dictionary) + # Don't do anything if this class doesn't have a descriptor. This happens + # when a service stub is subclassed. + if GeneratedServiceStubType._DESCRIPTOR_KEY not in dictionary: + return + descriptor = dictionary[GeneratedServiceStubType._DESCRIPTOR_KEY] + service_stub_builder = _ServiceStubBuilder(descriptor) + service_stub_builder.BuildServiceStub(cls) + + +class _ServiceBuilder(object): + + """This class constructs a protocol service class using a service descriptor. + + Given a service descriptor, this class constructs a class that represents + the specified service descriptor. One service builder instance constructs + exactly one service class. That means all instances of that class share the + same builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + service class. + """ + self.descriptor = service_descriptor + + def BuildService(self, cls): + """Constructs the service class. + + Args: + cls: The class that will be constructed. + """ + + # CallMethod needs to operate with an instance of the Service class. This + # internal wrapper function exists only to be able to pass the service + # instance to the method that does the real CallMethod work. 
+ def _WrapCallMethod(srvc, method_descriptor, + rpc_controller, request, callback): + self._CallMethod(srvc, method_descriptor, + rpc_controller, request, callback) + self.cls = cls + cls.CallMethod = _WrapCallMethod + cls.GetDescriptor = self._GetDescriptor + cls.GetRequestClass = self._GetRequestClass + cls.GetResponseClass = self._GetResponseClass + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateNonImplementedMethod(method)) + + def _GetDescriptor(self): + """Retrieves the service descriptor. + + Returns: + The descriptor of the service (of type ServiceDescriptor). + """ + return self.descriptor + + def _CallMethod(self, srvc, method_descriptor, + rpc_controller, request, callback): + """Calls the method described by a given method descriptor. + + Args: + srvc: Instance of the service for which this method is called. + method_descriptor: Descriptor that represent the method to call. + rpc_controller: RPC controller to use for this method's execution. + request: Request protocol message. + callback: A callback to invoke after the method has completed. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'CallMethod() given method descriptor for wrong service type.') + method = getattr(self.cls, method_descriptor.name) + method(srvc, rpc_controller, request, callback) + + def _GetRequestClass(self, method_descriptor): + """Returns the class of the request protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + request protocol message class. + + Returns: + A class that represents the input protocol message of the specified + method. 
+ """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetRequestClass() given method descriptor for wrong service type.') + return method_descriptor.input_type._concrete_class + + def _GetResponseClass(self, method_descriptor): + """Returns the class of the response protocol message. + + Args: + method_descriptor: Descriptor of the method for which to return the + response protocol message class. + + Returns: + A class that represents the output protocol message of the specified + method. + """ + if method_descriptor.containing_service != self.descriptor: + raise RuntimeError( + 'GetResponseClass() given method descriptor for wrong service type.') + return method_descriptor.output_type._concrete_class + + def _GenerateNonImplementedMethod(self, method): + """Generates and returns a method that can be set for a service methods. + + Args: + method: Descriptor of the service method for which a method is to be + generated. + + Returns: + A method that can be added to the service class. + """ + return lambda inst, rpc_controller, request, callback: ( + self._NonImplementedMethod(method.name, rpc_controller, callback)) + + def _NonImplementedMethod(self, method_name, rpc_controller, callback): + """The body of all methods in the generated service class. + + Args: + method_name: Name of the method being executed. + rpc_controller: RPC controller used to execute this method. + callback: A callback which will be invoked when the method finishes. + """ + rpc_controller.SetFailed('Method %s not implemented.' % method_name) + callback(None) + + +class _ServiceStubBuilder(object): + + """Constructs a protocol service stub class using a service descriptor. + + Given a service descriptor, this class constructs a suitable stub class. + A stub is just a type-safe wrapper around an RpcChannel which emulates a + local implementation of the service. + + One service stub builder instance constructs exactly one class. 
It means all + instances of that class share the same service stub builder. + """ + + def __init__(self, service_descriptor): + """Initializes an instance of the service stub class builder. + + Args: + service_descriptor: ServiceDescriptor to use when constructing the + stub class. + """ + self.descriptor = service_descriptor + + def BuildServiceStub(self, cls): + """Constructs the stub class. + + Args: + cls: The class that will be constructed. + """ + + def _ServiceStubInit(stub, rpc_channel): + stub.rpc_channel = rpc_channel + self.cls = cls + cls.__init__ = _ServiceStubInit + for method in self.descriptor.methods: + setattr(cls, method.name, self._GenerateStubMethod(method)) + + def _GenerateStubMethod(self, method): + return lambda inst, rpc_controller, request, callback: self._StubMethod( + inst, method, rpc_controller, request, callback) + + def _StubMethod(self, stub, method_descriptor, + rpc_controller, request, callback): + """The body of all service methods in the generated stub class. + + Args: + stub: Stub instance. + method_descriptor: Descriptor of the invoked method. + rpc_controller: Rpc controller to execute the method. + request: Request protocol message. + callback: A callback to execute when the method finishes. + """ + stub.rpc_channel.CallMethod( + method_descriptor, rpc_controller, request, + method_descriptor.output_type._concrete_class, callback) diff --git a/python/google/protobuf/text_format.py b/python/google/protobuf/text_format.py new file mode 100755 index 00000000..428b4c0a --- /dev/null +++ b/python/google/protobuf/text_format.py @@ -0,0 +1,111 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. +# http://code.google.com/p/protobuf/ +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Contains routines for printing protocol messages in text format.""" + +__author__ = 'kenton@google.com (Kenton Varda)' + +import cStringIO + +from google.protobuf import descriptor + +__all__ = [ 'MessageToString', 'PrintMessage', 'PrintField', 'PrintFieldValue' ] + +def MessageToString(message): + out = cStringIO.StringIO() + PrintMessage(message, out) + result = out.getvalue() + out.close() + return result + +def PrintMessage(message, out, indent = 0): + for field, value in message.ListFields(): + if field.label == descriptor.FieldDescriptor.LABEL_REPEATED: + for element in value: + PrintField(field, element, out, indent) + else: + PrintField(field, value, out, indent) + +def PrintField(field, value, out, indent = 0): + """Print a single field name/value pair. For repeated fields, the value + should be a single element.""" + + out.write(' ' * indent); + if field.is_extension: + out.write('[') + if (field.containing_type.GetOptions().message_set_wire_format and + field.type == descriptor.FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == descriptor.FieldDescriptor.LABEL_OPTIONAL): + out.write(field.message_type.full_name) + else: + out.write(field.full_name) + out.write(']') + elif field.type == descriptor.FieldDescriptor.TYPE_GROUP: + # For groups, use the capitalized name. 
+ out.write(field.message_type.name) + else: + out.write(field.name) + + if field.cpp_type != descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + # The colon is optional in this case, but our cross-language golden files + # don't include it. + out.write(': ') + + PrintFieldValue(field, value, out, indent) + out.write('\n') + +def PrintFieldValue(field, value, out, indent = 0): + """Print a single field value (not including name). For repeated fields, + the value should be a single element.""" + + if field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_MESSAGE: + out.write(' {\n') + PrintMessage(value, out, indent + 2) + out.write(' ' * indent + '}') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_ENUM: + out.write(field.enum_type.values_by_number[value].name) + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_STRING: + out.write('\"') + out.write(_CEscape(value)) + out.write('\"') + elif field.cpp_type == descriptor.FieldDescriptor.CPPTYPE_BOOL: + if value: + out.write("true") + else: + out.write("false") + else: + out.write(str(value)) + +# text.encode('string_escape') does not seem to satisfy our needs as it +# encodes unprintable characters using two-digit hex escapes whereas our +# C++ unescaping function allows hex escapes to be any length. So, +# "\0011".encode('string_escape') ends up being "\\x011", which will be +# decoded in C++ as a single-character string with char code 0x11. +def _CEscape(text): + def escape(c): + o = ord(c) + if o == 10: return r"\n" # optional escape + if o == 13: return r"\r" # optional escape + if o == 9: return r"\t" # optional escape + if o == 39: return r"\'" # optional escape + + if o == 34: return r'\"' # necessary escape + if o == 92: return r"\\" # necessary escape + + if o >= 127 or o < 32: return "\\%03o" % o # necessary escapes + return c + return "".join([escape(c) for c in text]) -- cgit v1.2.3