diff options
Diffstat (limited to 'python/google/protobuf/internal')
18 files changed, 1407 insertions, 211 deletions
diff --git a/python/google/protobuf/internal/__init__.py b/python/google/protobuf/internal/__init__.py index e69de29b..7d2e571a 100755 --- a/python/google/protobuf/internal/__init__.py +++ b/python/google/protobuf/internal/__init__.py @@ -0,0 +1,30 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py index ab9e7812..23cc2c0a 100755 --- a/python/google/protobuf/internal/api_implementation.py +++ b/python/google/protobuf/internal/api_implementation.py @@ -145,29 +145,3 @@ def Version(): # For internal use only def IsPythonDefaultSerializationDeterministic(): return _python_deterministic_proto_serialization - -# DO NOT USE: For migration and testing only. Will be removed when Proto3 -# defaults to preserve unknowns. -if _implementation_type == 'cpp': - try: - # pylint: disable=g-import-not-at-top - from google.protobuf.pyext import _message - - def GetPythonProto3PreserveUnknownsDefault(): - return _message.GetPythonProto3PreserveUnknownsDefault() - - def SetPythonProto3PreserveUnknownsDefault(preserve): - _message.SetPythonProto3PreserveUnknownsDefault(preserve) - except ImportError: - # Unrecognized cpp implementation. Skipping the unknown fields APIs. - pass -else: - _python_proto3_preserve_unknowns_default = True - - def GetPythonProto3PreserveUnknownsDefault(): - return _python_proto3_preserve_unknowns_default - - def SetPythonProto3PreserveUnknownsDefault(preserve): - global _python_proto3_preserve_unknowns_default - _python_proto3_preserve_unknowns_default = preserve - diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index c6a3692a..182cac99 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -628,3 +628,130 @@ class MessageMap(MutableMapping): def GetEntryClass(self): return self._entry_descriptor._concrete_class + + +class _UnknownField(object): + + """A parsed unknown field.""" + + # Disallows assignment to other attributes. + __slots__ = ['_field_number', '_wire_type', '_data'] + + def __init__(self, field_number, wire_type, data): + self._field_number = field_number + self._wire_type = wire_type + self._data = data + return + + def __lt__(self, other): + # pylint: disable=protected-access + return self._field_number < other._field_number + + def __eq__(self, other): + if self is other: + return True + # pylint: disable=protected-access + return (self._field_number == other._field_number and + self._wire_type == other._wire_type and + self._data == other._data) + + +class UnknownFieldRef(object): + + def __init__(self, parent, index): + self._parent = parent + self._index = index + return + + def _check_valid(self): + if not self._parent: + raise ValueError('UnknownField does not exist. ' + 'The parent message might be cleared.') + if self._index >= len(self._parent): + raise ValueError('UnknownField does not exist. ' + 'The parent message might be cleared.') + + @property + def field_number(self): + self._check_valid() + # pylint: disable=protected-access + return self._parent._internal_get(self._index)._field_number + + @property + def wire_type(self): + self._check_valid() + # pylint: disable=protected-access + return self._parent._internal_get(self._index)._wire_type + + @property + def data(self): + self._check_valid() + # pylint: disable=protected-access + return self._parent._internal_get(self._index)._data + + +class UnknownFieldSet(object): + + """UnknownField container""" + + # Disallows assignment to other attributes. + __slots__ = ['_values'] + + def __init__(self): + self._values = [] + + def __getitem__(self, index): + if self._values is None: + raise ValueError('UnknownFields does not exist. ' + 'The parent message might be cleared.') + size = len(self._values) + if index < 0: + index += size + if index < 0 or index >= size: + raise IndexError('index %d out of range'.index) + + return UnknownFieldRef(self, index) + + def _internal_get(self, index): + return self._values[index] + + def __len__(self): + if self._values is None: + raise ValueError('UnknownFields does not exist. ' + 'The parent message might be cleared.') + return len(self._values) + + def _add(self, field_number, wire_type, data): + unknown_field = _UnknownField(field_number, wire_type, data) + self._values.append(unknown_field) + return unknown_field + + def __iter__(self): + for i in range(len(self)): + yield UnknownFieldRef(self, i) + + def _extend(self, other): + if other is None: + return + # pylint: disable=protected-access + self._values.extend(other._values) + + def __eq__(self, other): + if self is other: + return True + # Sort unknown fields because their order shouldn't + # affect equality test. + values = list(self._values) + if other is None: + return not values + values.sort() + # pylint: disable=protected-access + other_values = sorted(other._values) + return values == other_values + + def _clear(self): + for value in self._values: + # pylint: disable=protected-access + if isinstance(value._data, UnknownFieldSet): + value._data._clear() # pylint: disable=protected-access + self._values = None diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 52b64915..5a540184 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -81,12 +81,17 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. __author__ = 'kenton@google.com (Kenton Varda)' import struct - +import sys import six +_UCS2_MAXUNICODE = 65535 if six.PY3: long = int +else: + import re # pylint: disable=g-import-not-at-top + _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]')) +from google.protobuf.internal import containers from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message @@ -167,7 +172,7 @@ _DecodeSignedVarint32 = _SignedVarintDecoder(32, int) def ReadTag(buffer, pos): - """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple. + """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple. We return the raw bytes of the tag rather than decoding them. The raw bytes can then be used to look up the proper decoder. This effectively allows @@ -175,13 +180,21 @@ def ReadTag(buffer, pos): for work that is done in C (searching for a byte string in a hash table). In a low-level language it would be much cheaper to decode the varint and use that, but not in Python. - """ + Args: + buffer: memoryview object of the encoded bytes + pos: int of the current position to start from + + Returns: + Tuple[bytes, int] of the tag data and new position. + """ start = pos while six.indexbytes(buffer, pos) & 0x80: pos += 1 pos += 1 - return (six.binary_type(buffer[start:pos]), pos) + + tag_bytes = buffer[start:pos].tobytes() + return tag_bytes, pos # -------------------------------------------------------------------- @@ -295,10 +308,20 @@ def _FloatDecoder(): local_unpack = struct.unpack def InnerDecode(buffer, pos): + """Decode serialized float to a float and new position. + + Args: + buffer: memoryview of the serialized bytes + pos: int, position in the memory view to start at. + + Returns: + Tuple[float, int] of the deserialized float value and new position + in the serialized data. + """ # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. new_pos = pos + 4 - float_bytes = buffer[pos:new_pos] + float_bytes = buffer[pos:new_pos].tobytes() # If this value has all its exponent bits set, then it's non-finite. # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. @@ -329,10 +352,20 @@ def _DoubleDecoder(): local_unpack = struct.unpack def InnerDecode(buffer, pos): + """Decode serialized double to a double and new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + + Returns: + Tuple[float, int] of the decoded double value and new position + in the serialized data. + """ # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. new_pos = pos + 8 - double_bytes = buffer[pos:new_pos] + double_bytes = buffer[pos:new_pos].tobytes() # If this value has all its exponent bits set and at least one significand # bit set, it's not a number. In Python 2.4, struct.unpack will treat it @@ -355,6 +388,18 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): if is_packed: local_DecodeVarint = _DecodeVarint def DecodePackedField(buffer, pos, end, message, field_dict): + """Decode serialized packed enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -365,6 +410,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): while pos < endpoint: value_start_pos = pos (element, pos) = _DecodeSignedVarint32(buffer, pos) + # pylint: disable=protected-access if element in enum_type.values_by_number: value.append(element) else: @@ -372,8 +418,10 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): message._unknown_fields = [] tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( - (tag_bytes, buffer[value_start_pos:pos])) + (tag_bytes, buffer[value_start_pos:pos].tobytes())) + # pylint: enable=protected-access if pos > endpoint: if element in enum_type.values_by_number: del value[-1] # Discard corrupt value. @@ -386,18 +434,32 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) tag_len = len(tag_bytes) def DecodeRepeatedField(buffer, pos, end, message, field_dict): + """Decode serialized repeated enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) while 1: (element, new_pos) = _DecodeSignedVarint32(buffer, pos) + # pylint: disable=protected-access if element in enum_type.values_by_number: value.append(element) else: if not message._unknown_fields: message._unknown_fields = [] message._unknown_fields.append( - (tag_bytes, buffer[pos:new_pos])) + (tag_bytes, buffer[pos:new_pos].tobytes())) + # pylint: enable=protected-access # Predict that the next tag is another copy of the same repeated # field. pos = new_pos + tag_len @@ -409,10 +471,23 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): return DecodeRepeatedField else: def DecodeField(buffer, pos, end, message, field_dict): + """Decode serialized repeated enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ value_start_pos = pos (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) if pos > end: raise _DecodeError('Truncated message.') + # pylint: disable=protected-access if enum_value in enum_type.values_by_number: field_dict[key] = enum_value else: @@ -421,7 +496,8 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) message._unknown_fields.append( - (tag_bytes, buffer[value_start_pos:pos])) + (tag_bytes, buffer[value_start_pos:pos].tobytes())) + # pylint: enable=protected-access return pos return DecodeField @@ -458,20 +534,34 @@ BoolDecoder = _ModifiedDecoder( wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) -def StringDecoder(field_number, is_repeated, is_packed, key, new_default): +def StringDecoder(field_number, is_repeated, is_packed, key, new_default, + is_strict_utf8=False): """Returns a decoder for a string field.""" local_DecodeVarint = _DecodeVarint local_unicode = six.text_type - def _ConvertToUnicode(byte_str): + def _ConvertToUnicode(memview): + """Convert byte to unicode.""" + byte_str = memview.tobytes() try: - return local_unicode(byte_str, 'utf-8') + value = local_unicode(byte_str, 'utf-8') except UnicodeDecodeError as e: # add more information to the error message and re-raise it. e.reason = '%s in field: %s' % (e, key.full_name) raise + if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE: + # Only do the check for python2 ucs4 when is_strict_utf8 enabled + if _SURROGATE_PATTERN.search(value): + reason = ('String field %s contains invalid UTF-8 data when parsing' + 'a protocol buffer: surrogates not allowed. Use' + 'the bytes type if you intend to send raw bytes.') % ( + key.full_name) + raise message.DecodeError(reason) + + return value + assert not is_packed if is_repeated: tag_bytes = encoder.TagBytes(field_number, @@ -523,7 +613,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - value.append(buffer[pos:new_pos]) + value.append(buffer[pos:new_pos].tobytes()) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: @@ -536,7 +626,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - field_dict[key] = buffer[pos:new_pos] + field_dict[key] = buffer[pos:new_pos].tobytes() return new_pos return DecodeField @@ -665,6 +755,18 @@ def MessageSetItemDecoder(descriptor): local_SkipField = SkipField def DecodeItem(buffer, pos, end, message, field_dict): + """Decode serialized message set to its value and new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ message_set_item_start = pos type_id = -1 message_start = -1 @@ -695,6 +797,7 @@ def MessageSetItemDecoder(descriptor): raise _DecodeError('MessageSet item missing message.') extension = message.Extensions._FindExtensionByNumber(type_id) + # pylint: disable=protected-access if extension is not None: value = field_dict.get(extension) if value is None: @@ -707,8 +810,9 @@ def MessageSetItemDecoder(descriptor): else: if not message._unknown_fields: message._unknown_fields = [] - message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, - buffer[message_set_item_start:pos])) + message._unknown_fields.append( + (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes())) + # pylint: enable=protected-access return pos @@ -767,7 +871,7 @@ def _SkipVarint(buffer, pos, end): # Previously ord(buffer[pos]) raised IndexError when pos is out of range. # With this code, ord(b'') raises TypeError. Both are handled in # python_message.py to generate a 'Truncated message' error. - while ord(buffer[pos:pos+1]) & 0x80: + while ord(buffer[pos:pos+1].tobytes()) & 0x80: pos += 1 pos += 1 if pos > end: @@ -782,6 +886,13 @@ def _SkipFixed64(buffer, pos, end): raise _DecodeError('Truncated message.') return pos + +def _DecodeFixed64(buffer, pos): + """Decode a fixed64.""" + new_pos = pos + 8 + return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) + + def _SkipLengthDelimited(buffer, pos, end): """Skip a length-delimited value. Returns the new position.""" @@ -791,6 +902,7 @@ def _SkipLengthDelimited(buffer, pos, end): raise _DecodeError('Truncated message.') return pos + def _SkipGroup(buffer, pos, end): """Skip sub-group. Returns the new position.""" @@ -801,11 +913,53 @@ def _SkipGroup(buffer, pos, end): return pos pos = new_pos + +def _DecodeGroup(buffer, pos): + """Decode group. Returns the UnknownFieldSet and new position.""" + + unknown_field_set = containers.UnknownFieldSet() + while 1: + (tag_bytes, pos) = ReadTag(buffer, pos) + (tag, _) = _DecodeVarint(tag_bytes, 0) + field_number, wire_type = wire_format.UnpackTag(tag) + if wire_type == wire_format.WIRETYPE_END_GROUP: + break + (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) + # pylint: disable=protected-access + unknown_field_set._add(field_number, wire_type, data) + + return (unknown_field_set, pos) + + +def _DecodeUnknownField(buffer, pos, wire_type): + """Decode a unknown field. Returns the UnknownField and new position.""" + + if wire_type == wire_format.WIRETYPE_VARINT: + (data, pos) = _DecodeVarint(buffer, pos) + elif wire_type == wire_format.WIRETYPE_FIXED64: + (data, pos) = _DecodeFixed64(buffer, pos) + elif wire_type == wire_format.WIRETYPE_FIXED32: + (data, pos) = _DecodeFixed32(buffer, pos) + elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: + (size, pos) = _DecodeVarint(buffer, pos) + data = buffer[pos:pos+size] + pos += size + elif wire_type == wire_format.WIRETYPE_START_GROUP: + (data, pos) = _DecodeGroup(buffer, pos) + elif wire_type == wire_format.WIRETYPE_END_GROUP: + return (0, -1) + else: + raise _DecodeError('Wrong wire type in tag.') + + return (data, pos) + + def _EndGroup(buffer, pos, end): """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" return -1 + def _SkipFixed32(buffer, pos, end): """Skip a fixed32 value. Returns the new position.""" @@ -814,6 +968,14 @@ def _SkipFixed32(buffer, pos, end): raise _DecodeError('Truncated message.') return pos + +def _DecodeFixed32(buffer, pos): + """Decode a fixed32.""" + + new_pos = pos + 4 + return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) + + def _RaiseInvalidWireType(buffer, pos, end): """Skip function for unknown wire types. Raises an exception.""" diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py index f97477b3..da5dbd92 100644 --- a/python/google/protobuf/internal/descriptor_database_test.py +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -43,6 +43,7 @@ import warnings from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 from google.protobuf.internal import factory_test2_pb2 +from google.protobuf.internal import no_package_pb2 from google.protobuf import descriptor_database @@ -52,7 +53,10 @@ class DescriptorDatabaseTest(unittest.TestCase): db = descriptor_database.DescriptorDatabase() file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString( factory_test2_pb2.DESCRIPTOR.serialized_pb) + file_desc_proto2 = descriptor_pb2.FileDescriptorProto.FromString( + no_package_pb2.DESCRIPTOR.serialized_pb) db.Add(file_desc_proto) + db.Add(file_desc_proto2) self.assertEqual(file_desc_proto, db.FindFileByName( 'google/protobuf/internal/factory_test2.proto')) @@ -76,6 +80,10 @@ class DescriptorDatabaseTest(unittest.TestCase): # Can find enum value. self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Enum.FACTORY_2_VALUE_0')) + self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.FACTORY_2_VALUE_0')) + self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol( + '.NO_PACKAGE_VALUE_0')) # Can find top level extension. self.assertEqual(file_desc_proto, db.FindFileContainingSymbol( 'google.protobuf.python.internal.another_field')) @@ -95,9 +103,8 @@ class DescriptorDatabaseTest(unittest.TestCase): self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol( 'protobuf_unittest.TestAllTypes.none_field')) - self.assertRaises(KeyError, - db.FindFileContainingSymbol, - 'protobuf_unittest.NoneMessage') + with self.assertRaisesRegexp(KeyError, r'\'protobuf_unittest\.NoneMessage\''): + db.FindFileContainingSymbol('protobuf_unittest.NoneMessage') def testConflictRegister(self): db = descriptor_database.DescriptorDatabase() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py index 2cbf7813..1b72b0b9 100644 --- a/python/google/protobuf/internal/descriptor_pool_test.py +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -36,7 +36,6 @@ __author__ = 'matthewtoia@google.com (Matt Toia)' import copy import os -import sys import warnings try: @@ -55,6 +54,7 @@ from google.protobuf.internal import factory_test1_pb2 from google.protobuf.internal import factory_test2_pb2 from google.protobuf.internal import file_options_test_pb2 from google.protobuf.internal import more_messages_pb2 +from google.protobuf.internal import no_package_pb2 from google.protobuf import descriptor from google.protobuf import descriptor_database from google.protobuf import descriptor_pool @@ -120,7 +120,6 @@ class DescriptorPoolTestBase(object): self.assertIsInstance(file_desc5, descriptor.FileDescriptor) self.assertEqual('google/protobuf/unittest.proto', file_desc5.name) - # Tests the generated pool. assert descriptor_pool.Default().FindFileContainingSymbol( 'google.protobuf.python.internal.Factory2Message.one_more_field') @@ -129,6 +128,32 @@ class DescriptorPoolTestBase(object): assert descriptor_pool.Default().FindFileContainingSymbol( 'protobuf_unittest.TestService') + # Can find field. + file_desc6 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory1Message.list_value') + self.assertIsInstance(file_desc6, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/internal/factory_test1.proto', + file_desc6.name) + + # Can find top level Enum value. + file_desc7 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.FACTORY_1_VALUE_0') + self.assertIsInstance(file_desc7, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/internal/factory_test1.proto', + file_desc7.name) + + # Can find nested Enum value. + file_desc8 = self.pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes.FOO') + self.assertIsInstance(file_desc8, descriptor.FileDescriptor) + self.assertEqual('google/protobuf/unittest.proto', + file_desc8.name) + + # TODO(jieluo): Add tests for no package when b/13860351 is fixed. + + self.assertRaises(KeyError, self.pool.FindFileContainingSymbol, + 'google.protobuf.python.internal.Factory1Message.none_field') + def testFindFileContainingSymbolFailure(self): with self.assertRaises(KeyError): self.pool.FindFileContainingSymbol('Does not exist') @@ -217,11 +242,10 @@ class DescriptorPoolTestBase(object): def testFindTypeErrors(self): self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '') + self.assertRaises(KeyError, self.pool.FindMethodByName, '') # TODO(jieluo): Fix python to raise correct errors. if api_implementation.Type() == 'cpp': - self.assertRaises(TypeError, self.pool.FindMethodByName, 0) - self.assertRaises(KeyError, self.pool.FindMethodByName, '') error_type = TypeError else: error_type = AttributeError @@ -231,6 +255,7 @@ class DescriptorPoolTestBase(object): self.assertRaises(error_type, self.pool.FindEnumTypeByName, 0) self.assertRaises(error_type, self.pool.FindOneofByName, 0) self.assertRaises(error_type, self.pool.FindServiceByName, 0) + self.assertRaises(error_type, self.pool.FindMethodByName, 0) self.assertRaises(error_type, self.pool.FindFileContainingSymbol, 0) if api_implementation.Type() == 'python': error_type = KeyError @@ -275,11 +300,6 @@ class DescriptorPoolTestBase(object): self.pool.FindEnumTypeByName('Does not exist') def testFindFieldByName(self): - if isinstance(self, SecondaryDescriptorFromDescriptorDB): - if api_implementation.Type() == 'cpp': - # TODO(jieluo): Fix cpp extension to find field correctly - # when descriptor pool is using an underlying database. - return field = self.pool.FindFieldByName( 'google.protobuf.python.internal.Factory1Message.list_value') self.assertEqual(field.name, 'list_value') @@ -290,11 +310,6 @@ class DescriptorPoolTestBase(object): self.pool.FindFieldByName('Does not exist') def testFindOneofByName(self): - if isinstance(self, SecondaryDescriptorFromDescriptorDB): - if api_implementation.Type() == 'cpp': - # TODO(jieluo): Fix cpp extension to find oneof correctly - # when descriptor pool is using an underlying database. - return oneof = self.pool.FindOneofByName( 'google.protobuf.python.internal.Factory2Message.oneof_field') self.assertEqual(oneof.name, 'oneof_field') @@ -302,11 +317,6 @@ class DescriptorPoolTestBase(object): self.pool.FindOneofByName('Does not exist') def testFindExtensionByName(self): - if isinstance(self, SecondaryDescriptorFromDescriptorDB): - if api_implementation.Type() == 'cpp': - # TODO(jieluo): Fix cpp extension to find extension correctly - # when descriptor pool is using an underlying database. - return # An extension defined in a message. extension = self.pool.FindExtensionByName( 'google.protobuf.python.internal.Factory2Message.one_more_field') @@ -382,6 +392,11 @@ class DescriptorPoolTestBase(object): with self.assertRaises(KeyError): self.pool.FindServiceByName('Does not exist') + method = self.pool.FindMethodByName('protobuf_unittest.TestService.Foo') + self.assertIs(method.containing_service, service) + with self.assertRaises(KeyError): + self.pool.FindMethodByName('protobuf_unittest.TestService.Doesnotexist') + def testUserDefinedDB(self): db = descriptor_database.DescriptorDatabase() self.pool = descriptor_pool.DescriptorPool(db) @@ -601,6 +616,8 @@ class CreateDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase): unittest_import_pb2.DESCRIPTOR.serialized_pb)) self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( unittest_pb2.DESCRIPTOR.serialized_pb)) + self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString( + no_package_pb2.DESCRIPTOR.serialized_pb)) class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase, @@ -620,6 +637,8 @@ class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase, unittest_import_pb2.DESCRIPTOR.serialized_pb)) db.Add(descriptor_pb2.FileDescriptorProto.FromString( unittest_pb2.DESCRIPTOR.serialized_pb)) + db.Add(descriptor_pb2.FileDescriptorProto.FromString( + no_package_pb2.DESCRIPTOR.serialized_pb)) self.pool = descriptor_pool.DescriptorPool(descriptor_db=db) @@ -746,11 +765,7 @@ class MessageField(object): test.assertEqual(msg_desc, field_desc.containing_type) test.assertEqual(field_type_desc, field_desc.message_type) test.assertEqual(file_desc, field_desc.file) - # TODO(jieluo): Fix python and cpp extension diff for message field - # default value. - if api_implementation.Type() == 'cpp': - test.assertRaises( - NotImplementedError, getattr, field_desc, 'default_value') + test.assertEqual(field_desc.default_value, None) class StringField(object): diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 02a43d15..af6bece1 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -452,6 +452,17 @@ class DescriptorTest(unittest.TestCase): self.assertEqual('attribute is not writable: has_options', str(e.exception)) + def testDefault(self): + message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + field = message_descriptor.fields_by_name['repeated_int32'] + self.assertEqual(field.default_value, []) + field = message_descriptor.fields_by_name['repeated_nested_message'] + self.assertEqual(field.default_value, []) + field = message_descriptor.fields_by_name['optionalgroup'] + self.assertEqual(field.default_value, None) + field = message_descriptor.fields_by_name['optional_nested_message'] + self.assertEqual(field.default_value, None) + class NewDescriptorTest(DescriptorTest): """Redo the same tests as above, but with a separate DescriptorPool.""" diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto index d2fbbeec..f5bd0383 100644 --- a/python/google/protobuf/internal/factory_test1.proto +++ b/python/google/protobuf/internal/factory_test1.proto @@ -56,3 +56,17 @@ message Factory1Message { extensions 1000 to max; } + +message Factory1MethodRequest { + optional string argument = 1; +} + +message Factory1MethodResponse { + optional string result = 1; +} + +service Factory1Service { + // Dummy method for this dummy service. + rpc Factory1Method(Factory1MethodRequest) returns (Factory1MethodResponse) { + } +} diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py index 6df52ed2..b97e3f65 100644 --- a/python/google/protobuf/internal/message_factory_test.py +++ b/python/google/protobuf/internal/message_factory_test.py @@ -142,10 +142,8 @@ class MessageFactoryTest(unittest.TestCase): self.assertEqual('test2', msg1.Extensions[ext2]) self.assertEqual(None, msg1.Extensions._FindExtensionByNumber(12321)) + self.assertRaises(TypeError, len, msg1.Extensions) if api_implementation.Type() == 'cpp': - # TODO(jieluo): Fix len to return the correct value. - # self.assertEqual(2, len(msg1.Extensions)) - self.assertEqual(len(msg1.Extensions), len(msg1.Extensions)) self.assertRaises(TypeError, msg1.Extensions._FindExtensionByName, 0) self.assertRaises(TypeError, diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 61a56a67..4dd1104a 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -1,4 +1,5 @@ #! /usr/bin/env python +# -*- coding: utf-8 -*- # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. @@ -49,6 +50,7 @@ import copy import math import operator import pickle +import pydoc import six import sys import warnings @@ -72,12 +74,14 @@ from google.protobuf import message_factory from google.protobuf import text_format from google.protobuf.internal import api_implementation from google.protobuf.internal import encoder +from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import packed_field_test_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import testing_refleaks from google.protobuf import message from google.protobuf.internal import _parameterized +UCS2_MAXUNICODE = 65535 if six.PY3: long = int @@ -415,6 +419,37 @@ class MessageTest(BaseTestCase): empty.ParseFromString(populated.SerializeToString()) self.assertEqual(str(empty), '') + def testMergeFromRepeatedField(self, message_module): + msg = message_module.TestAllTypes() + msg.repeated_int32.append(1) + msg.repeated_int32.append(3) + msg.repeated_nested_message.add(bb=1) + msg.repeated_nested_message.add(bb=2) + other_msg = message_module.TestAllTypes() + other_msg.repeated_nested_message.add(bb=3) + other_msg.repeated_nested_message.add(bb=4) + other_msg.repeated_int32.append(5) + other_msg.repeated_int32.append(7) + + msg.repeated_int32.MergeFrom(other_msg.repeated_int32) + self.assertEqual(4, len(msg.repeated_int32)) + + msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message) + self.assertEqual([1, 2, 3, 4], + [m.bb for m in msg.repeated_nested_message]) + + def testAddWrongRepeatedNestedField(self, message_module): + msg = message_module.TestAllTypes() + try: + msg.repeated_nested_message.add('wrong') + except TypeError: + pass + try: + msg.repeated_nested_message.add(value_field='wrong') + except ValueError: + pass + self.assertEqual(len(msg.repeated_nested_message), 0) + def testRepeatedNestedFieldIteration(self, message_module): msg = message_module.TestAllTypes() msg.repeated_nested_message.add(bb=1) @@ -645,6 +680,82 @@ class MessageTest(BaseTestCase): m.payload.repeated_int32.extend([]) self.assertTrue(m.HasField('payload')) + def testMergeFrom(self, message_module): + m1 = message_module.TestAllTypes() + m2 = message_module.TestAllTypes() + # Cpp extension will lazily create a sub message which is immutable. + self.assertEqual(0, m1.optional_nested_message.bb) + m2.optional_nested_message.bb = 1 + # Make sure cmessage pointing to a mutable message after merge instead of + # the lazily created message. + m1.MergeFrom(m2) + self.assertEqual(1, m1.optional_nested_message.bb) + + # Test more nested sub message. + msg1 = message_module.NestedTestAllTypes() + msg2 = message_module.NestedTestAllTypes() + self.assertEqual(0, msg1.child.payload.optional_nested_message.bb) + msg2.child.payload.optional_nested_message.bb = 1 + msg1.MergeFrom(msg2) + self.assertEqual(1, msg1.child.payload.optional_nested_message.bb) + + # Test repeated field. + self.assertEqual(msg1.payload.repeated_nested_message, + msg1.payload.repeated_nested_message) + msg2.payload.repeated_nested_message.add().bb = 1 + msg1.MergeFrom(msg2) + self.assertEqual(1, len(msg1.payload.repeated_nested_message)) + self.assertEqual(1, msg1.payload.repeated_nested_message[0].bb) + + def testMergeFromString(self, message_module): + m1 = message_module.TestAllTypes() + m2 = message_module.TestAllTypes() + # Cpp extension will lazily create a sub message which is immutable. + self.assertEqual(0, m1.optional_nested_message.bb) + m2.optional_nested_message.bb = 1 + # Make sure cmessage pointing to a mutable message after merge instead of + # the lazily created message. + m1.MergeFromString(m2.SerializeToString()) + self.assertEqual(1, m1.optional_nested_message.bb) + + @unittest.skipIf(six.PY2, 'memoryview objects are not supported on py2') + def testMergeFromStringUsingMemoryViewWorksInPy3(self, message_module): + m2 = message_module.TestAllTypes() + m2.optional_string = 'scalar string' + m2.repeated_string.append('repeated string') + m2.optional_bytes = b'scalar bytes' + m2.repeated_bytes.append(b'repeated bytes') + + serialized = m2.SerializeToString() + memview = memoryview(serialized) + m1 = message_module.TestAllTypes.FromString(memview) + + self.assertEqual(m1.optional_bytes, b'scalar bytes') + self.assertEqual(m1.repeated_bytes, [b'repeated bytes']) + self.assertEqual(m1.optional_string, 'scalar string') + self.assertEqual(m1.repeated_string, ['repeated string']) + # Make sure that the memoryview was correctly converted to bytes, and + # that a sub-sliced memoryview is not being used. + self.assertIsInstance(m1.optional_bytes, bytes) + self.assertIsInstance(m1.repeated_bytes[0], bytes) + self.assertIsInstance(m1.optional_string, six.text_type) + self.assertIsInstance(m1.repeated_string[0], six.text_type) + + @unittest.skipIf(six.PY3, 'memoryview is supported by py3') + def testMergeFromStringUsingMemoryViewIsPy2Error(self, message_module): + memview = memoryview(b'') + with self.assertRaises(TypeError): + message_module.TestAllTypes.FromString(memview) + + def testMergeFromEmpty(self, message_module): + m1 = message_module.TestAllTypes() + # Cpp extension will lazily create a sub message which is immutable. + self.assertEqual(0, m1.optional_nested_message.bb) + self.assertFalse(m1.HasField('optional_nested_message')) + # Make sure the sub message is still immutable after merge from empty. + m1.MergeFromString(b'') # field state should not change + self.assertFalse(m1.HasField('optional_nested_message')) + def ensureNestedMessageExists(self, msg, attribute): """Make sure that a nested message object exists. @@ -1067,14 +1178,8 @@ class MessageTest(BaseTestCase): with self.assertRaises(AttributeError): m.repeated_int32 = [] m.repeated_int32.append(1) - if api_implementation.Type() == 'cpp': - # For test coverage: cpp has a different path if composite - # field is in cache - with self.assertRaises(TypeError): - m.repeated_int32 = [] - else: - with self.assertRaises(AttributeError): - m.repeated_int32 = [] + with self.assertRaises(AttributeError): + m.repeated_int32 = [] # Class to test proto2-only features (required, extensions, etc.) @@ -1112,13 +1217,13 @@ class Proto2Test(BaseTestCase): message.optional_bool = True message.optional_nested_message.bb = 15 - self.assertTrue(message.HasField("optional_int32")) + self.assertTrue(message.HasField(u"optional_int32")) self.assertTrue(message.HasField("optional_bool")) self.assertTrue(message.HasField("optional_nested_message")) # Clearing the fields unsets them and resets their value to default. message.ClearField("optional_int32") - message.ClearField("optional_bool") + message.ClearField(u"optional_bool") message.ClearField("optional_nested_message") self.assertFalse(message.HasField("optional_int32")) @@ -1169,6 +1274,21 @@ class Proto2Test(BaseTestCase): msg = unittest_pb2.TestAllTypes() self.assertRaises(AttributeError, getattr, msg, 'Extensions') + def testMergeFromExtensions(self): + msg1 = more_extensions_pb2.TopLevelMessage() + msg2 = more_extensions_pb2.TopLevelMessage() + # Cpp extension will lazily create a sub message which is immutable. + self.assertEqual(0, msg1.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + self.assertFalse(msg1.HasField('submessage')) + msg2.submessage.Extensions[ + more_extensions_pb2.optional_int_extension] = 123 + # Make sure cmessage and extensions pointing to a mutable message + # after merge instead of the lazily created message. + msg1.MergeFrom(msg2) + self.assertEqual(123, msg1.submessage.Extensions[ + more_extensions_pb2.optional_int_extension]) + def testGoldenExtensions(self): golden_data = test_util.GoldenFileData('golden_message') golden_message = unittest_pb2.TestAllExtensions() @@ -1315,6 +1435,25 @@ class Proto2Test(BaseTestCase): with self.assertRaises(ValueError): unittest_pb2.TestAllTypes(repeated_nested_enum='FOO') + def testPythonicInitWithDict(self): + # Both string/unicode field name keys should work. + kwargs = { + 'optional_int32': 100, + u'optional_fixed32': 200, + } + msg = unittest_pb2.TestAllTypes(**kwargs) + self.assertEqual(100, msg.optional_int32) + self.assertEqual(200, msg.optional_fixed32) + + + def test_documentation(self): + # Also used by the interactive help() function. + doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message') + self.assertIn('class TestAllTypes', doc) + self.assertIn('SerializePartialToString', doc) + self.assertIn('repeated_float', doc) + base = unittest_pb2.TestAllTypes.__bases__[0] + self.assertRaises(AttributeError, getattr, base, '_extensions_by_name') # Class to test proto3-only features/behavior (updated field presence & enums) @@ -1539,10 +1678,8 @@ class Proto3Test(BaseTestCase): self.assertEqual(True, msg2.map_bool_bool[True]) self.assertEqual(2, msg2.map_int32_enum[888]) self.assertEqual(456, msg2.map_int32_enum[123]) - # TODO(jieluo): Add cpp extension support. - if api_implementation.Type() == 'python': - self.assertEqual('{-123: -456}', - str(msg2.map_int32_int32)) + self.assertEqual('{-123: -456}', + str(msg2.map_int32_int32)) def testMapEntryAlwaysSerialized(self): msg = map_unittest_pb2.TestMap() @@ -1603,11 +1740,10 @@ class Proto3Test(BaseTestCase): self.assertIn(123, msg2.map_int32_foreign_message) self.assertIn(-456, msg2.map_int32_foreign_message) self.assertEqual(2, len(msg2.map_int32_foreign_message)) + msg2.map_int32_foreign_message[123].c = 1 # TODO(jieluo): Fix text format for message map. - # TODO(jieluo): Add cpp extension support. - if api_implementation.Type() == 'python': - self.assertEqual(15, - len(str(msg2.map_int32_foreign_message))) + self.assertIn(str(msg2.map_int32_foreign_message), + ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }')) def testNestedMessageMapItemDelete(self): msg = map_unittest_pb2.TestMap() @@ -1721,6 +1857,15 @@ class Proto3Test(BaseTestCase): self.assertEqual(10, msg2.map_int32_foreign_message[222].c) self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d')) + # Test when cpp extension cache a map. + m1 = map_unittest_pb2.TestMap() + m2 = map_unittest_pb2.TestMap() + self.assertEqual(m1.map_int32_foreign_message, + m1.map_int32_foreign_message) + m2.map_int32_foreign_message[123].c = 10 + m1.MergeFrom(m2) + self.assertEqual(10, m2.map_int32_foreign_message[123].c) + def testMergeFromBadType(self): msg = map_unittest_pb2.TestMap() with self.assertRaisesRegexp( @@ -1972,7 +2117,7 @@ class Proto3Test(BaseTestCase): def testMapValidAfterFieldCleared(self): # Map needs to work even if field is cleared. # For the C++ implementation this tests the correctness of - # ScalarMapContainer::Release() + # MapContainer::Release() msg = map_unittest_pb2.TestMap() int32_map = msg.map_int32_int32 @@ -1988,7 +2133,7 @@ class Proto3Test(BaseTestCase): def testMessageMapValidAfterFieldCleared(self): # Map needs to work even if field is cleared. # For the C++ implementation this tests the correctness of - # ScalarMapContainer::Release() + # MapContainer::Release() msg = map_unittest_pb2.TestMap() int32_foreign_message = msg.map_int32_foreign_message @@ -1998,6 +2143,24 @@ class Proto3Test(BaseTestCase): self.assertEqual(b'', msg.SerializeToString()) self.assertTrue(2 in int32_foreign_message.keys()) + def testMessageMapItemValidAfterTopMessageCleared(self): + # Message map item needs to work even if it is cleared. + # For the C++ implementation this tests the correctness of + # MapContainer::Release() + msg = map_unittest_pb2.TestMap() + msg.map_int32_all_types[2].optional_string = 'bar' + + if api_implementation.Type() == 'cpp': + # Need to keep the map reference because of b/27942626. + # TODO(jieluo): Remove it. + unused_map = msg.map_int32_all_types # pylint: disable=unused-variable + msg_value = msg.map_int32_all_types[2] + msg.Clear() + + # Reset to trigger sync between repeated field and map in c++. + msg.map_int32_all_types[3].optional_string = 'foo' + self.assertEqual(msg_value.optional_string, 'bar') + def testMapIterInvalidatedByClearField(self): # Map iterator is invalidated when field is cleared. # But this case does need to not crash the interpreter. @@ -2058,6 +2221,82 @@ class Proto3Test(BaseTestCase): msg.map_string_foreign_message['foo'].c = 5 self.assertEqual(0, len(msg.FindInitializationErrors())) + @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2') + def testStrictUtf8Check(self): + # Test u'\ud801' is rejected at parser in both python2 and python3. + serialized = (b'r\x03\xed\xa0\x81') + msg = unittest_proto3_arena_pb2.TestAllTypes() + with self.assertRaises(Exception) as context: + msg.MergeFromString(serialized) + if api_implementation.Type() == 'python': + self.assertIn('optional_string', str(context.exception)) + else: + self.assertIn('Error parsing message', str(context.exception)) + + # Test optional_string=u'😍' is accepted. + serialized = unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'😍').SerializeToString() + msg2 = unittest_proto3_arena_pb2.TestAllTypes() + msg2.MergeFromString(serialized) + self.assertEqual(msg2.optional_string, u'😍') + + msg = unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud001') + self.assertEqual(msg.optional_string, u'\ud001') + + @unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2') + def testSurrogatesInPython3(self): + # Surrogates like U+D83D is an invalid unicode character, it is + # supported by Python2 only because in some builds, unicode strings + # use 2-bytes code units. Since Python 3.3, we don't have this problem. + # + # Surrogates are utf16 code units, in a unicode string they are invalid + # characters even when they appear in pairs like u'\ud801\udc01'. Protobuf + # Python3 reject such cases at setters and parsers. Python2 accpect it + # to keep same features with the language itself. 'Unpaired pairs' + # like u'\ud801' are rejected at parsers when strict utf8 check is enabled + # in proto3 to keep same behavior with c extension. + + # Surrogates are rejected at setters in Python3. + with self.assertRaises(ValueError): + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801\udc01') + with self.assertRaises(ValueError): + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=b'\xed\xa0\x81') + with self.assertRaises(ValueError): + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801') + with self.assertRaises(ValueError): + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801\ud801') + + @unittest.skipIf(six.PY3 or sys.maxunicode == UCS2_MAXUNICODE, + 'Surrogates are rejected at setters in Python3') + def testSurrogatesInPython2(self): + # Test optional_string=u'\ud801\udc01'. + # surrogate pair is acceptable in python2. + msg = unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801\udc01') + # TODO(jieluo): Change pure python to have same behavior with c extension. + # Some build in python2 consider u'\ud801\udc01' and u'\U00010401' are + # equal, some are not equal. + if api_implementation.Type() == 'python': + self.assertEqual(msg.optional_string, u'\ud801\udc01') + else: + self.assertEqual(msg.optional_string, u'\U00010401') + serialized = msg.SerializeToString() + msg2 = unittest_proto3_arena_pb2.TestAllTypes() + msg2.MergeFromString(serialized) + self.assertEqual(msg2.optional_string, u'\U00010401') + + # Python2 does not reject surrogates at setters. + msg = unittest_proto3_arena_pb2.TestAllTypes( + optional_string=b'\xed\xa0\x81') + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801') + unittest_proto3_arena_pb2.TestAllTypes( + optional_string=u'\ud801\ud801') class ValidTypeNamesTest(BaseTestCase): diff --git a/python/google/protobuf/internal/no_package.proto b/python/google/protobuf/internal/no_package.proto index 3546dcc3..49eda959 100644 --- a/python/google/protobuf/internal/no_package.proto +++ b/python/google/protobuf/internal/no_package.proto @@ -1,3 +1,33 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + syntax = "proto2"; enum NoPackageEnum { diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py index 975e3b4d..4e0f545c 100755 --- a/python/google/protobuf/internal/python_message.py +++ b/python/google/protobuf/internal/python_message.py @@ -56,6 +56,7 @@ import sys import weakref import six +from six.moves import range # We use "as" to avoid name collisions with variables. from google.protobuf.internal import api_implementation @@ -124,6 +125,21 @@ class GeneratedProtocolMessageType(type): Newly-allocated class. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + + # If a concrete class already exists for this descriptor, don't try to + # create another. Doing so will break any messages that already exist with + # the existing class. + # + # The C++ implementation appears to have its own internal `PyMessageFactory` + # to achieve similar results. + # + # This most commonly happens in `text_format.py` when using descriptors from + # a custom pool; it calls symbol_database.Global().getPrototype() on a + # descriptor which already has an existing concrete class. + new_class = getattr(descriptor, '_concrete_class', None) + if new_class: + return new_class + if descriptor.full_name in well_known_types.WKTBASES: bases += (well_known_types.WKTBASES[descriptor.full_name],) _AddClassAttributesForNestedExtensions(descriptor, dictionary) @@ -151,6 +167,16 @@ class GeneratedProtocolMessageType(type): type. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + + # If this is an _existing_ class looked up via `_concrete_class` in the + # __new__ method above, then we don't need to re-initialize anything. + existing_class = getattr(descriptor, '_concrete_class', None) + if existing_class: + assert existing_class is cls, ( + 'Duplicate `GeneratedProtocolMessageType` created for descriptor %r' + % (descriptor.full_name)) + return + cls._decoders_by_tag = {} if (descriptor.has_options and descriptor.GetOptions().message_set_wire_format): @@ -245,6 +271,7 @@ def _AddSlots(message_descriptor, dictionary): '_cached_byte_size_dirty', '_fields', '_unknown_fields', + '_unknown_field_set', '_is_present_in_parent', '_listener', '_listener_for_children', @@ -271,6 +298,13 @@ def _IsMessageMapField(field): return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE +def _IsStrictUtf8Check(field): + if field.containing_type.syntax != 'proto3': + return False + enforce_utf8 = True + return enforce_utf8 + + def _AttachFieldHelpers(cls, field_descriptor): is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) is_packable = (is_repeated and @@ -322,10 +356,16 @@ def _AttachFieldHelpers(cls, field_descriptor): field_decoder = decoder.MapDecoder( field_descriptor, _GetInitializeDefaultForMap(field_descriptor), is_message_map) + elif decode_type == _FieldDescriptor.TYPE_STRING: + is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor) + field_decoder = decoder.StringDecoder( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor, + is_strict_utf8_check) else: field_decoder = type_checkers.TYPE_TO_DECODER[decode_type]( - field_descriptor.number, is_repeated, is_packed, - field_descriptor, field_descriptor._default_constructor) + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor) cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor) @@ -422,6 +462,9 @@ def _DefaultValueConstructorForField(field): # _concrete_class may not yet be initialized. message_type = field.message_type def MakeSubMessageDefault(message): + assert getattr(message_type, '_concrete_class', None), ( + 'Uninitialized concrete class found for field %r (message type %r)' + % (field.full_name, message_type.full_name)) result = message_type._concrete_class() result._SetListener( _OneofListener(message, field) @@ -477,6 +520,9 @@ def _AddInitMethod(message_descriptor, cls): # _unknown_fields is () when empty for efficiency, and will be turned into # a list if fields are added. self._unknown_fields = () + # _unknown_field_set is None when empty for efficiency, and will be + # turned into UnknownFieldSet struct if fields are added. + self._unknown_field_set = None # pylint: disable=protected-access self._is_present_in_parent = False self._listener = message_listener_mod.NullMessageListener() self._listener_for_children = _Listener(self) @@ -584,6 +630,14 @@ def _AddPropertiesForField(field, cls): _AddPropertiesForNonRepeatedScalarField(field, cls) +class _FieldProperty(property): + __slots__ = ('DESCRIPTOR',) + + def __init__(self, descriptor, getter, setter, doc): + property.__init__(self, getter, setter, doc=doc) + self.DESCRIPTOR = descriptor + + def _AddPropertiesForRepeatedField(field, cls): """Adds a public property for a "repeated" protocol message field. Clients can use this property to get the value of the field, which will be either a @@ -625,7 +679,7 @@ def _AddPropertiesForRepeatedField(field, cls): '"%s" in protocol message object.' % proto_field_name) doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name - setattr(cls, property_name, property(getter, setter, doc=doc)) + setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc)) def _AddPropertiesForNonRepeatedScalarField(field, cls): @@ -681,7 +735,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): # Add a property to encapsulate the getter/setter. doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name - setattr(cls, property_name, property(getter, setter, doc=doc)) + setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc)) def _AddPropertiesForNonRepeatedCompositeField(field, cls): @@ -725,7 +779,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): # Add a property to encapsulate the getter. doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name - setattr(cls, property_name, property(getter, setter, doc=doc)) + setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc)) def _AddPropertiesForExtensions(descriptor, cls): @@ -949,12 +1003,12 @@ def _AddEqualsMethod(message_descriptor, cls): if not self.ListFields() == other.ListFields(): return False - # Sort unknown fields because their order shouldn't affect equality test. + # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions, + # then use it for the comparison. unknown_fields = list(self._unknown_fields) unknown_fields.sort() other_unknown_fields = list(other._unknown_fields) other_unknown_fields.sort() - return unknown_fields == other_unknown_fields cls.__eq__ = __eq__ @@ -1078,6 +1132,13 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls): def _AddMergeFromStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def MergeFromString(self, serialized): + if isinstance(serialized, memoryview) and six.PY2: + raise TypeError( + 'memoryview not supported in Python 2 with the pure Python proto ' + 'implementation: this is to maintain compatibility with the C++ ' + 'implementation') + + serialized = memoryview(serialized) length = len(serialized) try: if self._InternalParse(serialized, 0, length) != length: @@ -1095,26 +1156,54 @@ def _AddMergeFromStringMethod(message_descriptor, cls): local_ReadTag = decoder.ReadTag local_SkipField = decoder.SkipField decoders_by_tag = cls._decoders_by_tag - is_proto3 = message_descriptor.syntax == "proto3" def InternalParse(self, buffer, pos, end): + """Create a message from serialized bytes. + + Args: + self: Message, instance of the proto message object. + buffer: memoryview of the serialized data. + pos: int, position to start in the serialized data. + end: int, end position of the serialized data. + + Returns: + Message object. + """ + # Guard against internal misuse, since this function is called internally + # quite extensively, and its easy to accidentally pass bytes. + assert isinstance(buffer, memoryview) self._Modified() field_dict = self._fields - unknown_field_list = self._unknown_fields + # pylint: disable=protected-access + unknown_field_set = self._unknown_field_set while pos != end: (tag_bytes, new_pos) = local_ReadTag(buffer, pos) field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) if field_decoder is None: - value_start_pos = new_pos - new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) + if not self._unknown_fields: # pylint: disable=protected-access + self._unknown_fields = [] # pylint: disable=protected-access + if unknown_field_set is None: + # pylint: disable=protected-access + self._unknown_field_set = containers.UnknownFieldSet() + # pylint: disable=protected-access + unknown_field_set = self._unknown_field_set + # pylint: disable=protected-access + (tag, _) = decoder._DecodeVarint(tag_bytes, 0) + field_number, wire_type = wire_format.UnpackTag(tag) + # TODO(jieluo): remove old_pos. + old_pos = new_pos + (data, new_pos) = decoder._DecodeUnknownField( + buffer, new_pos, wire_type) # pylint: disable=protected-access if new_pos == -1: return pos - if (not is_proto3 or - api_implementation.GetPythonProto3PreserveUnknownsDefault()): - if not unknown_field_list: - unknown_field_list = self._unknown_fields = [] - unknown_field_list.append( - (tag_bytes, buffer[value_start_pos:new_pos])) + # pylint: disable=protected-access + unknown_field_set._add(field_number, wire_type, data) + # TODO(jieluo): remove _unknown_fields. + new_pos = local_SkipField(buffer, old_pos, end, tag_bytes) + if new_pos == -1: + return pos + self._unknown_fields.append( + (tag_bytes, buffer[old_pos:new_pos].tobytes())) pos = new_pos else: pos = field_decoder(buffer, new_pos, end, self, field_dict) @@ -1259,6 +1348,10 @@ def _AddMergeFromMethod(cls): if not self._unknown_fields: self._unknown_fields = [] self._unknown_fields.extend(msg._unknown_fields) + # pylint: disable=protected-access + if self._unknown_field_set is None: + self._unknown_field_set = containers.UnknownFieldSet() + self._unknown_field_set._extend(msg._unknown_field_set) cls.MergeFrom = MergeFrom @@ -1291,12 +1384,25 @@ def _Clear(self): # Clear fields. self._fields = {} self._unknown_fields = () + # pylint: disable=protected-access + if self._unknown_field_set is not None: + self._unknown_field_set._clear() + self._unknown_field_set = None + self._oneofs = {} self._Modified() +def _UnknownFields(self): + if self._unknown_field_set is None: # pylint: disable=protected-access + # pylint: disable=protected-access + self._unknown_field_set = containers.UnknownFieldSet() + return self._unknown_field_set # pylint: disable=protected-access + + def _DiscardUnknownFields(self): self._unknown_fields = [] + self._unknown_field_set = None # pylint: disable=protected-access for field, value in self.ListFields(): if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: if field.label == _FieldDescriptor.LABEL_REPEATED: @@ -1335,6 +1441,7 @@ def _AddMessageMethods(message_descriptor, cls): _AddReduceMethod(cls) # Adds methods which do not depend on cls. cls.Clear = _Clear + cls.UnknownFields = _UnknownFields cls.DiscardUnknownFields = _DiscardUnknownFields cls._SetListener = _SetListener @@ -1471,6 +1578,10 @@ class _ExtensionDict(object): if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: result = extension_handle._default_constructor(self._extended_message) elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + assert getattr(extension_handle.message_type, '_concrete_class', None), ( + 'Uninitialized concrete class found for field %r (message type %r)' + % (extension_handle.full_name, + extension_handle.message_type.full_name)) result = extension_handle.message_type._concrete_class() try: result._SetListener(self._extended_message._listener_for_children) diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 0306ff46..90d2fe3c 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -64,6 +64,10 @@ from google.protobuf.internal import testing_refleaks from google.protobuf.internal import decoder +if six.PY3: + long = int # pylint: disable=redefined-builtin,invalid-name + + BaseTestCase = testing_refleaks.BaseTestCase @@ -647,10 +651,7 @@ class ReflectionTest(BaseTestCase): TestGetAndDeserialize('optional_int32', 1, int) TestGetAndDeserialize('optional_int32', 1 << 30, int) TestGetAndDeserialize('optional_uint32', 1 << 30, int) - try: - integer_64 = long - except NameError: # Python3 - integer_64 = int + integer_64 = long if struct.calcsize('L') == 4: # Python only has signed ints, so 32-bit python can't fit an uint32 # in an int. @@ -1103,6 +1104,7 @@ class ReflectionTest(BaseTestCase): self.assertEqual(23, myproto_instance.foo_field) self.assertTrue(myproto_instance.HasField('foo_field')) + @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable') def testDescriptorProtoSupport(self): # Hand written descriptors/reflection are only supported by the pure-Python # implementation of the API. @@ -1141,7 +1143,8 @@ class ReflectionTest(BaseTestCase): self.assertTrue('price' in desc.fields_by_name) self.assertTrue('owners' in desc.fields_by_name) - class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)): + class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, + message.Message)): DESCRIPTOR = desc prius = CarMessage() @@ -1576,6 +1579,8 @@ class ReflectionTest(BaseTestCase): proto1.repeated_int32.append(3) container = copy.deepcopy(proto1.repeated_int32) self.assertEqual([2, 3], container) + container.remove(container[0]) + self.assertEqual([3], container) message1 = proto1.repeated_nested_message.add() message1.bb = 1 @@ -1583,6 +1588,8 @@ class ReflectionTest(BaseTestCase): self.assertEqual(proto1.repeated_nested_message, messages) message1.bb = 2 self.assertNotEqual(proto1.repeated_nested_message, messages) + messages.remove(messages[0]) + self.assertEqual(len(messages), 0) # TODO(anuraag): Implement deepcopy for extension dict @@ -2435,7 +2442,7 @@ class SerializationTest(BaseTestCase): first_proto = unittest_pb2.TestAllTypes() test_util.SetAllFields(first_proto) - serialized = first_proto.SerializeToString() + serialized = memoryview(first_proto.SerializeToString()) for truncation_point in range(len(serialized) + 1): try: @@ -2857,6 +2864,38 @@ class SerializationTest(BaseTestCase): self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER, 51) + def testFieldProperties(self): + cls = unittest_pb2.TestAllTypes + self.assertIs(cls.optional_int32.DESCRIPTOR, + cls.DESCRIPTOR.fields_by_name['optional_int32']) + self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER, + cls.optional_int32.DESCRIPTOR.number) + self.assertIs(cls.optional_nested_message.DESCRIPTOR, + cls.DESCRIPTOR.fields_by_name['optional_nested_message']) + self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER, + cls.optional_nested_message.DESCRIPTOR.number) + self.assertIs(cls.repeated_int32.DESCRIPTOR, + cls.DESCRIPTOR.fields_by_name['repeated_int32']) + self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER, + cls.repeated_int32.DESCRIPTOR.number) + + def testFieldDataDescriptor(self): + msg = unittest_pb2.TestAllTypes() + msg.optional_int32 = 42 + self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42) + unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25) + self.assertEqual(msg.optional_int32, 25) + with self.assertRaises(AttributeError): + del msg.optional_int32 + try: + unittest_pb2.ForeignMessage.c.__get__(msg) + except TypeError: + pass # The cpp implementation cannot mix fields from other messages. + # This test exercises a specific check that avoids a crash. + else: + pass # The python implementation allows fields from other messages. + # This is useless, but works. + def testInitKwargs(self): proto = unittest_pb2.TestAllTypes( optional_int32=1, @@ -2963,6 +3002,7 @@ class ClassAPITest(BaseTestCase): @unittest.skipIf( api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, 'C++ implementation requires a call to MakeDescriptor()') + @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable') def testMakeClassWithNestedDescriptor(self): leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', containing_type=None, fields=[], @@ -2980,10 +3020,7 @@ class ClassAPITest(BaseTestCase): containing_type=None, fields=[], nested_types=[child_desc, sibling_desc], enum_types=[], extensions=[]) - message_class = reflection.MakeClass(parent_desc) - self.assertIn('child', message_class.__dict__) - self.assertIn('sibling', message_class.__dict__) - self.assertIn('leaf', message_class.child.__dict__) + reflection.MakeClass(parent_desc) def _GetSerializedFileDescriptor(self, name): """Get a serialized representation of a test FileDescriptorProto. diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index 237a2d50..ccf8ac16 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -33,20 +33,19 @@ """Test for google.protobuf.text_format.""" -__author__ = 'kenton@google.com (Kenton Varda)' - - +import io import math import re -import six import string +import textwrap + +import six +# pylint: disable=g-import-not-at-top try: - import unittest2 as unittest # PY26, pylint: disable=g-import-not-at-top + import unittest2 as unittest # PY26 except ImportError: - import unittest # pylint: disable=g-import-not-at-top - -from google.protobuf.internal import _parameterized + import unittest from google.protobuf import any_pb2 from google.protobuf import any_test_pb2 @@ -54,12 +53,13 @@ from google.protobuf import map_unittest_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 from google.protobuf import unittest_proto3_arena_pb2 -from google.protobuf.internal import api_implementation from google.protobuf.internal import any_test_pb2 as test_extend_any from google.protobuf.internal import message_set_extensions_pb2 from google.protobuf.internal import test_util from google.protobuf import descriptor_pool from google.protobuf import text_format +from google.protobuf.internal import _parameterized +# pylint: enable=g-import-not-at-top # Low-level nuts-n-bolts tests. @@ -100,8 +100,8 @@ class TextFormatBase(unittest.TestCase): return text -@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2)) -class TextFormatTest(TextFormatBase): +@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2) +class TextFormatMessageToStringTests(TextFormatBase): def testPrintExotic(self, message_module): message = message_module.TestAllTypes() @@ -154,6 +154,40 @@ class TextFormatTest(TextFormatBase): 'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 ' 'repeated_string: "Google" repeated_string: "Zurich"') + def VerifyPrintShortFormatRepeatedFields(self, message_module, as_one_line): + message = message_module.TestAllTypes() + message.repeated_int32.append(1) + message.repeated_string.append('Google') + message.repeated_string.append('Hello,World') + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_FOO) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) + message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) + message.optional_nested_message.bb = 3 + for i in (21, 32): + msg = message.repeated_nested_message.add() + msg.bb = i + expected_ascii = ( + 'optional_nested_message {\n bb: 3\n}\n' + 'repeated_int32: [1]\n' + 'repeated_string: "Google"\n' + 'repeated_string: "Hello,World"\n' + 'repeated_nested_message {\n bb: 21\n}\n' + 'repeated_nested_message {\n bb: 32\n}\n' + 'repeated_foreign_enum: [FOREIGN_FOO, FOREIGN_BAR, FOREIGN_BAZ]\n') + if as_one_line: + expected_ascii = expected_ascii.replace('\n ', '').replace('\n', '') + actual_ascii = text_format.MessageToString( + message, use_short_repeated_primitives=True, + as_one_line=as_one_line) + self.CompareToGoldenText(actual_ascii, expected_ascii) + parsed_message = message_module.TestAllTypes() + text_format.Parse(actual_ascii, parsed_message) + self.assertEqual(parsed_message, message) + + def tesPrintShortFormatRepeatedFields(self, message_module, as_one_line): + self.VerifyPrintShortFormatRepeatedFields(message_module, False) + self.VerifyPrintShortFormatRepeatedFields(message_module, True) + def testPrintNestedNewLineInStringAsOneLine(self, message_module): message = message_module.TestAllTypes() message.optional_string = 'a\nnew\nline' @@ -213,13 +247,18 @@ class TextFormatTest(TextFormatBase): def testPrintRawUtf8String(self, message_module): message = message_module.TestAllTypes() - message.repeated_string.append(u'\u00fc\ua71f') + message.repeated_string.append(u'\u00fc\t\ua71f') text = text_format.MessageToString(message, as_utf8=True) - self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n') + golden_unicode = u'repeated_string: "\u00fc\\t\ua71f"\n' + golden_text = golden_unicode if six.PY3 else golden_unicode.encode('utf-8') + # MessageToString always returns a native str. + self.CompareToGoldenText(text, golden_text) parsed_message = message_module.TestAllTypes() text_format.Parse(text, parsed_message) - self.assertEqual(message, parsed_message, - '\n%s != %s' % (message, parsed_message)) + self.assertEqual( + message, parsed_message, '\n%s != %s (%s != %s)' % + (message, parsed_message, message.repeated_string[0], + parsed_message.repeated_string[0])) def testPrintFloatFormat(self, message_module): # Check that float_format argument is passed to sub-message formatting. @@ -259,6 +298,36 @@ class TextFormatTest(TextFormatBase): message.c = 123 self.assertEqual('c: 123\n', str(message)) + def testMessageToStringUnicode(self, message_module): + golden_unicode = u'Á short desçription and a 🍌.' + golden_bytes = golden_unicode.encode('utf-8') + message = message_module.TestAllTypes() + message.optional_string = golden_unicode + message.optional_bytes = golden_bytes + text = text_format.MessageToString(message, as_utf8=True) + golden_message = textwrap.dedent( + 'optional_string: "Á short desçription and a 🍌."\n' + 'optional_bytes: ' + r'"\303\201 short des\303\247ription and a \360\237\215\214."' + '\n') + self.CompareToGoldenText(text, golden_message) + + def testMessageToStringASCII(self, message_module): + golden_unicode = u'Á short desçription and a 🍌.' + golden_bytes = golden_unicode.encode('utf-8') + message = message_module.TestAllTypes() + message.optional_string = golden_unicode + message.optional_bytes = golden_bytes + text = text_format.MessageToString(message, as_utf8=False) # ASCII + golden_message = ( + 'optional_string: ' + r'"\303\201 short des\303\247ription and a \360\237\215\214."' + '\n' + 'optional_bytes: ' + r'"\303\201 short des\303\247ription and a \360\237\215\214."' + '\n') + self.CompareToGoldenText(text, golden_message) + def testPrintField(self, message_module): message = message_module.TestAllTypes() field = message.DESCRIPTOR.fields_by_name['optional_float'] @@ -289,6 +358,45 @@ class TextFormatTest(TextFormatBase): self.assertEqual('0.0', out.getvalue()) out.close() + +@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2) +class TextFormatMessageToTextBytesTests(TextFormatBase): + + def testMessageToBytes(self, message_module): + message = message_module.ForeignMessage() + message.c = 123 + self.assertEqual(b'c: 123\n', text_format.MessageToBytes(message)) + + def testRawUtf8RoundTrip(self, message_module): + message = message_module.TestAllTypes() + message.repeated_string.append(u'\u00fc\t\ua71f') + utf8_text = text_format.MessageToBytes(message, as_utf8=True) + golden_bytes = b'repeated_string: "\xc3\xbc\\t\xea\x9c\x9f"\n' + self.CompareToGoldenText(utf8_text, golden_bytes) + parsed_message = message_module.TestAllTypes() + text_format.Parse(utf8_text, parsed_message) + self.assertEqual( + message, parsed_message, '\n%s != %s (%s != %s)' % + (message, parsed_message, message.repeated_string[0], + parsed_message.repeated_string[0])) + + def testEscapedUtf8ASCIIRoundTrip(self, message_module): + message = message_module.TestAllTypes() + message.repeated_string.append(u'\u00fc\t\ua71f') + ascii_text = text_format.MessageToBytes(message) # as_utf8=False default + golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n' + self.CompareToGoldenText(ascii_text, golden_bytes) + parsed_message = message_module.TestAllTypes() + text_format.Parse(ascii_text, parsed_message) + self.assertEqual( + message, parsed_message, '\n%s != %s (%s != %s)' % + (message, parsed_message, message.repeated_string[0], + parsed_message.repeated_string[0])) + + +@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2) +class TextFormatParserTests(TextFormatBase): + def testParseAllFields(self, message_module): message = message_module.TestAllTypes() test_util.SetAllFields(message) @@ -318,14 +426,14 @@ class TextFormatTest(TextFormatBase): if message_module is unittest_pb2: test_util.ExpectAllFieldsSet(self, message) - if six.PY2: - msg2 = message_module.TestAllTypes() - text = (u'optional_string: "café"') - text_format.Merge(text, msg2) - self.assertEqual(msg2.optional_string, u'café') - msg2.Clear() - text_format.Parse(text, msg2) - self.assertEqual(msg2.optional_string, u'café') + msg2 = message_module.TestAllTypes() + text = (u'optional_string: "café"') + text_format.Merge(text, msg2) + self.assertEqual(msg2.optional_string, u'café') + msg2.Clear() + self.assertEqual(msg2.optional_string, u'') + text_format.Parse(text, msg2) + self.assertEqual(msg2.optional_string, u'café') def testParseExotic(self, message_module): message = message_module.TestAllTypes() @@ -425,7 +533,8 @@ class TextFormatTest(TextFormatBase): message = message_module.TestAllTypes() text = 'optional_nested_enum: BARR' six.assertRaisesRegex(self, text_format.ParseError, - (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + (r'1:23 : \'optional_nested_enum: BARR\': ' + r'Enum type "\w+.TestAllTypes.NestedEnum" ' r'has no value named BARR.'), text_format.Parse, text, message) @@ -433,7 +542,8 @@ class TextFormatTest(TextFormatBase): message = message_module.TestAllTypes() text = 'optional_int32: bork' six.assertRaisesRegex(self, text_format.ParseError, - ('1:17 : Couldn\'t parse integer: bork'), + ('1:17 : \'optional_int32: bork\': ' + 'Couldn\'t parse integer: bork'), text_format.Parse, text, message) def testParseStringFieldUnescape(self, message_module): @@ -457,6 +567,96 @@ class TextFormatTest(TextFormatBase): message.repeated_string[4]) self.assertEqual(SLASH + 'x20', message.repeated_string[5]) + def testParseOneof(self, message_module): + m = message_module.TestAllTypes() + m.oneof_uint32 = 11 + m2 = message_module.TestAllTypes() + text_format.Parse(text_format.MessageToString(m), m2) + self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + + def testParseMultipleOneof(self, message_module): + m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) + m2 = message_module.TestAllTypes() + with six.assertRaisesRegex(self, text_format.ParseError, + ' is specified along with field '): + text_format.Parse(m_string, m2) + + # This example contains non-ASCII codepoint unicode data as literals + # which should come through as utf-8 for bytes, and as the unicode + # itself for string fields. It also demonstrates escaped binary data. + # The ur"" string prefix is unfortunately missing from Python 3 + # so we resort to double escaping our \s so that they come through. + _UNICODE_SAMPLE = u""" + optional_bytes: 'Á short desçription' + optional_string: 'Á short desçription' + repeated_bytes: '\\303\\201 short des\\303\\247ription' + repeated_bytes: '\\x12\\x34\\x56\\x78\\x90\\xab\\xcd\\xef' + repeated_string: '\\xd0\\x9f\\xd1\\x80\\xd0\\xb8\\xd0\\xb2\\xd0\\xb5\\xd1\\x82' + """ + _BYTES_SAMPLE = _UNICODE_SAMPLE.encode('utf-8') + _GOLDEN_UNICODE = u'Á short desçription' + _GOLDEN_BYTES = _GOLDEN_UNICODE.encode('utf-8') + _GOLDEN_BYTES_1 = b'\x12\x34\x56\x78\x90\xab\xcd\xef' + _GOLDEN_STR_0 = u'Привет' + + def testParseUnicode(self, message_module): + m = message_module.TestAllTypes() + text_format.Parse(self._UNICODE_SAMPLE, m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + # repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data. + self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1) + # repeated_string[0] contained \ escaped data representing the UTF-8 + # representation of _GOLDEN_STR_0 - it needs to decode as such. + self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0) + + def testParseBytes(self, message_module): + m = message_module.TestAllTypes() + text_format.Parse(self._BYTES_SAMPLE, m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + # repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data. + self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1) + # repeated_string[0] contained \ escaped data representing the UTF-8 + # representation of _GOLDEN_STR_0 - it needs to decode as such. + self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0) + + def testFromBytesFile(self, message_module): + m = message_module.TestAllTypes() + f = io.BytesIO(self._BYTES_SAMPLE) + text_format.ParseLines(f, m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + + def testFromUnicodeFile(self, message_module): + m = message_module.TestAllTypes() + f = io.StringIO(self._UNICODE_SAMPLE) + text_format.ParseLines(f, m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + + def testFromBytesLines(self, message_module): + m = message_module.TestAllTypes() + text_format.ParseLines(self._BYTES_SAMPLE.split(b'\n'), m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + + def testFromUnicodeLines(self, message_module): + m = message_module.TestAllTypes() + text_format.ParseLines(self._UNICODE_SAMPLE.split(u'\n'), m) + self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES) + self.assertEqual(m.optional_string, self._GOLDEN_UNICODE) + self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES) + + +@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2) +class TextFormatMergeTests(TextFormatBase): + def testMergeDuplicateScalars(self, message_module): message = message_module.TestAllTypes() text = ('optional_int32: 42 ' 'optional_int32: 67') @@ -472,26 +672,12 @@ class TextFormatTest(TextFormatBase): self.assertTrue(r is message) self.assertEqual(2, message.optional_nested_message.bb) - def testParseOneof(self, message_module): - m = message_module.TestAllTypes() - m.oneof_uint32 = 11 - m2 = message_module.TestAllTypes() - text_format.Parse(text_format.MessageToString(m), m2) - self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) - def testMergeMultipleOneof(self, message_module): m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) m2 = message_module.TestAllTypes() text_format.Merge(m_string, m2) self.assertEqual('oneof_string', m2.WhichOneof('oneof_field')) - def testParseMultipleOneof(self, message_module): - m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"']) - m2 = message_module.TestAllTypes() - with self.assertRaisesRegexp(text_format.ParseError, - ' is specified along with field '): - text_format.Parse(m_string, m2) - # These are tests that aren't fundamentally specific to proto2, but are at # the moment because of differences between the proto2 and proto3 test schemas. @@ -649,6 +835,29 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase): ' }\n' '}\n') + # In cpp implementation, __str__ calls the cpp implementation of text format. + def testPrintMapUsingCppImplementation(self): + message = map_unittest_pb2.TestMap() + inner_msg = message.map_int32_foreign_message[111] + inner_msg.c = 1 + self.assertEqual( + str(message), + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 1\n' + ' }\n' + '}\n') + inner_msg.c = 2 + self.assertEqual( + str(message), + 'map_int32_foreign_message {\n' + ' key: 111\n' + ' value {\n' + ' c: 2\n' + ' }\n' + '}\n') + def testMapOrderEnforcement(self): message = map_unittest_pb2.TestMap() for letter in string.ascii_uppercase[13:26]: @@ -938,7 +1147,7 @@ class Proto2Tests(TextFormatBase): '}\n') six.assertRaisesRegex(self, text_format.ParseError, - '5:1 : Expected ">".', + '5:1 : \'}\': Expected ">".', text_format.Parse, malformed, message, @@ -981,7 +1190,8 @@ class Proto2Tests(TextFormatBase): with self.assertRaises(text_format.ParseError) as e: text_format.Parse(text, message) self.assertEqual(str(e.exception), - '1:27 : Expected identifier or number, got "bb".') + '1:27 : \'optional_nested_message { "bb": 1 }\': ' + 'Expected identifier or number, got "bb".') def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() @@ -998,7 +1208,8 @@ class Proto2Tests(TextFormatBase): message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: 100' six.assertRaisesRegex(self, text_format.ParseError, - (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" ' + (r'1:23 : \'optional_nested_enum: 100\': ' + r'Enum type "\w+.TestAllTypes.NestedEnum" ' r'has no value with number 100.'), text_format.Parse, text, message) @@ -1209,6 +1420,24 @@ class Proto3Tests(unittest.TestCase): ' < data: "string" > ' '>') + def testPrintAndParseMessageInvalidAny(self): + packed_message = unittest_pb2.OneString() + packed_message.data = 'string' + message = any_test_pb2.TestAny() + message.any_value.Pack(packed_message) + # Only include string after last '/' in type_url. + message.any_value.type_url = message.any_value.TypeName() + text = text_format.MessageToString(message) + self.assertEqual( + text, 'any_value {\n' + ' type_url: "protobuf_unittest.OneString"\n' + ' value: "\\n\\006string"\n' + '}\n') + + parsed_message = any_test_pb2.TestAny() + text_format.Parse(text, parsed_message) + self.assertEqual(message, parsed_message) + def testUnknownEnums(self): message = unittest_proto3_arena_pb2.TestAllTypes() message2 = unittest_proto3_arena_pb2.TestAllTypes() @@ -1448,6 +1677,26 @@ class TokenizerTest(unittest.TestCase): self.assertEqual(0, text_format._ConsumeUint64(tokenizer)) self.assertTrue(tokenizer.AtEnd()) + def testConsumeOctalIntegers(self): + """Test support for C style octal integers.""" + text = '00 -00 04 0755 -010 007 -0033 08 -09 01' + tokenizer = text_format.Tokenizer(text.splitlines()) + self.assertEqual(0, tokenizer.ConsumeInteger()) + self.assertEqual(0, tokenizer.ConsumeInteger()) + self.assertEqual(4, tokenizer.ConsumeInteger()) + self.assertEqual(0o755, tokenizer.ConsumeInteger()) + self.assertEqual(-0o10, tokenizer.ConsumeInteger()) + self.assertEqual(7, tokenizer.ConsumeInteger()) + self.assertEqual(-0o033, tokenizer.ConsumeInteger()) + with self.assertRaises(text_format.ParseError): + tokenizer.ConsumeInteger() # 08 + tokenizer.NextToken() + with self.assertRaises(text_format.ParseError): + tokenizer.ConsumeInteger() # -09 + tokenizer.NextToken() + self.assertEqual(1, tokenizer.ConsumeInteger()) + self.assertTrue(tokenizer.AtEnd()) + def testConsumeByteString(self): text = '"string1\'' tokenizer = text_format.Tokenizer(text.splitlines()) @@ -1556,6 +1805,12 @@ class TokenizerTest(unittest.TestCase): tokenizer.ConsumeCommentOrTrailingComment()) self.assertTrue(tokenizer.AtEnd()) + def testHugeString(self): + # With pathologic backtracking, fails with Forge OOM. + text = '"' + 'a' * (10 * 1024 * 1024) + '"' + tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False) + tokenizer.ConsumeString() + # Tests for pretty printer functionality. @_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2)) @@ -1652,5 +1907,64 @@ class PrettyPrinterTest(TextFormatBase): 'repeated_nested_message { My lucky number is 42 } ' 'repeated_nested_message { My lucky number is 99 }')) + +class WhitespaceTest(TextFormatBase): + + def setUp(self): + self.out = text_format.TextWriter(False) + self.addCleanup(self.out.close) + self.message = unittest_pb2.NestedTestAllTypes() + self.message.child.payload.optional_string = 'value' + self.field = self.message.DESCRIPTOR.fields_by_name['child'] + self.value = self.message.child + + def testMessageToString(self): + self.CompareToGoldenText( + text_format.MessageToString(self.message), + textwrap.dedent("""\ + child { + payload { + optional_string: "value" + } + } + """)) + + def testPrintMessage(self): + text_format.PrintMessage(self.message, self.out) + self.CompareToGoldenText( + self.out.getvalue(), + textwrap.dedent("""\ + child { + payload { + optional_string: "value" + } + } + """)) + + def testPrintField(self): + text_format.PrintField(self.field, self.value, self.out) + self.CompareToGoldenText( + self.out.getvalue(), + textwrap.dedent("""\ + child { + payload { + optional_string: "value" + } + } + """)) + + def testPrintFieldValue(self): + text_format.PrintFieldValue( + self.field, self.value, self.out) + self.CompareToGoldenText( + self.out.getvalue(), + textwrap.dedent("""\ + { + payload { + optional_string: "value" + } + }""")) + + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 4a76cd4e..0807e7f7 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -185,6 +185,14 @@ class UnicodeValueChecker(object): 'encoding. Non-UTF-8 strings must be converted to ' 'unicode objects before being added.' % (proposed_value)) + else: + try: + proposed_value.encode('utf8') + except UnicodeEncodeError: + raise ValueError('%.1024r isn\'t a valid unicode string and ' + 'can\'t be encoded in UTF-8.'% + (proposed_value)) + return proposed_value def DefaultValue(self): diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py index 8b7de2e7..fceadf71 100755 --- a/python/google/protobuf/internal/unknown_fields_test.py +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -49,20 +49,12 @@ from google.protobuf.internal import missing_enum_values_pb2 from google.protobuf.internal import test_util from google.protobuf.internal import testing_refleaks from google.protobuf.internal import type_checkers +from google.protobuf import descriptor BaseTestCase = testing_refleaks.BaseTestCase -# CheckUnknownField() cannot be used by the C++ implementation because -# some protect members are called. It is not a behavior difference -# for python and C++ implementation. -def SkipCheckUnknownFieldIfCppImplementation(func): - return unittest.skipIf( - api_implementation.Type() == 'cpp' and api_implementation.Version() == 2, - 'Addtional test for pure python involved protect members')(func) - - class UnknownFieldsTest(BaseTestCase): def setUp(self): @@ -80,23 +72,11 @@ class UnknownFieldsTest(BaseTestCase): # stdout. self.assertTrue(data == self.all_fields_data) - def expectSerializeProto3(self, preserve): + def testSerializeProto3(self): + # Verify proto3 unknown fields behavior. message = unittest_proto3_arena_pb2.TestEmptyMessage() message.ParseFromString(self.all_fields_data) - if preserve: - self.assertEqual(self.all_fields_data, message.SerializeToString()) - else: - self.assertEqual(0, len(message.SerializeToString())) - - def testSerializeProto3(self): - # Verify that proto3 unknown fields behavior. - default_preserve = (api_implementation - .GetPythonProto3PreserveUnknownsDefault()) - self.expectSerializeProto3(default_preserve) - api_implementation.SetPythonProto3PreserveUnknownsDefault( - not default_preserve) - self.expectSerializeProto3(not default_preserve) - api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve) + self.assertEqual(self.all_fields_data, message.SerializeToString()) def testByteSize(self): self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) @@ -169,13 +149,15 @@ class UnknownFieldsAccessorsTest(BaseTestCase): self.empty_message = unittest_pb2.TestEmptyMessage() self.empty_message.ParseFromString(self.all_fields_data) - # CheckUnknownField() is an additional Pure Python check which checks + # InternalCheckUnknownField() is an additional Pure Python check which checks # a detail of unknown fields. It cannot be used by the C++ # implementation because some protect members are called. # The test is added for historical reasons. It is not necessary as # serialized string is checked. - - def CheckUnknownField(self, name, expected_value): + # TODO(jieluo): Remove message._unknown_fields. + def InternalCheckUnknownField(self, name, expected_value): + if api_implementation.Type() == 'cpp': + return field_descriptor = self.descriptor.fields_by_name[name] wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] field_tag = encoder.TagBytes(field_descriptor.number, wire_type) @@ -183,36 +165,80 @@ class UnknownFieldsAccessorsTest(BaseTestCase): for tag_bytes, value in self.empty_message._unknown_fields: if tag_bytes == field_tag: decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] - decoder(value, 0, len(value), self.all_fields, result_dict) + decoder(memoryview(value), 0, len(value), self.all_fields, result_dict) self.assertEqual(expected_value, result_dict[field_descriptor]) - @SkipCheckUnknownFieldIfCppImplementation + def CheckUnknownField(self, name, unknown_fields, expected_value): + field_descriptor = self.descriptor.fields_by_name[name] + expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[ + field_descriptor.type] + for unknown_field in unknown_fields: + if unknown_field.field_number == field_descriptor.number: + self.assertEqual(expected_type, unknown_field.wire_type) + if expected_type == 3: + # Check group + self.assertEqual(expected_value[0], + unknown_field.data[0].field_number) + self.assertEqual(expected_value[1], unknown_field.data[0].wire_type) + self.assertEqual(expected_value[2], unknown_field.data[0].data) + continue + if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED: + self.assertIn(unknown_field.data, expected_value) + else: + self.assertEqual(expected_value, unknown_field.data) + def testCheckUnknownFieldValue(self): + unknown_fields = self.empty_message.UnknownFields() # Test enum. self.CheckUnknownField('optional_nested_enum', + unknown_fields, self.all_fields.optional_nested_enum) + self.InternalCheckUnknownField('optional_nested_enum', + self.all_fields.optional_nested_enum) + # Test repeated enum. self.CheckUnknownField('repeated_nested_enum', + unknown_fields, self.all_fields.repeated_nested_enum) + self.InternalCheckUnknownField('repeated_nested_enum', + self.all_fields.repeated_nested_enum) # Test varint. self.CheckUnknownField('optional_int32', + unknown_fields, self.all_fields.optional_int32) + self.InternalCheckUnknownField('optional_int32', + self.all_fields.optional_int32) + # Test fixed32. self.CheckUnknownField('optional_fixed32', + unknown_fields, self.all_fields.optional_fixed32) + self.InternalCheckUnknownField('optional_fixed32', + self.all_fields.optional_fixed32) # Test fixed64. self.CheckUnknownField('optional_fixed64', + unknown_fields, self.all_fields.optional_fixed64) + self.InternalCheckUnknownField('optional_fixed64', + self.all_fields.optional_fixed64) # Test lengthd elimited. self.CheckUnknownField('optional_string', - self.all_fields.optional_string) + unknown_fields, + self.all_fields.optional_string.encode('utf-8')) + self.InternalCheckUnknownField('optional_string', + self.all_fields.optional_string) # Test group. self.CheckUnknownField('optionalgroup', - self.all_fields.optionalgroup) + unknown_fields, + (17, 0, 117)) + self.InternalCheckUnknownField('optionalgroup', + self.all_fields.optionalgroup) + + self.assertEqual(97, len(unknown_fields)) def testCopyFrom(self): message = unittest_pb2.TestEmptyMessage() @@ -230,9 +256,18 @@ class UnknownFieldsAccessorsTest(BaseTestCase): message.optional_int64 = 3 message.optional_uint32 = 4 destination = unittest_pb2.TestEmptyMessage() + unknown_fields = destination.UnknownFields() + self.assertEqual(0, len(unknown_fields)) destination.ParseFromString(message.SerializeToString()) - + # ParseFromString clears the message thus unknown fields is invalid. + with self.assertRaises(ValueError) as context: + len(unknown_fields) + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + unknown_fields = destination.UnknownFields() + self.assertEqual(2, len(unknown_fields)) destination.MergeFrom(source) + self.assertEqual(4, len(unknown_fields)) # Check that the fields where correctly merged, even stored in the unknown # fields set. message.ParseFromString(destination.SerializeToString()) @@ -241,9 +276,58 @@ class UnknownFieldsAccessorsTest(BaseTestCase): self.assertEqual(message.optional_int64, 3) def testClear(self): + unknown_fields = self.empty_message.UnknownFields() self.empty_message.Clear() # All cleared, even unknown fields. self.assertEqual(self.empty_message.SerializeToString(), b'') + with self.assertRaises(ValueError) as context: + len(unknown_fields) + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + + def testSubUnknownFields(self): + message = unittest_pb2.TestAllTypes() + message.optionalgroup.a = 123 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + sub_unknown_fields = destination.UnknownFields()[0].data + self.assertEqual(1, len(sub_unknown_fields)) + self.assertEqual(sub_unknown_fields[0].data, 123) + destination.Clear() + with self.assertRaises(ValueError) as context: + len(sub_unknown_fields) + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + with self.assertRaises(ValueError) as context: + # pylint: disable=pointless-statement + sub_unknown_fields[0] + self.assertIn('UnknownFields does not exist.', + str(context.exception)) + message.Clear() + message.optional_uint32 = 456 + nested_message = unittest_pb2.NestedTestAllTypes() + nested_message.payload.optional_nested_message.ParseFromString( + message.SerializeToString()) + unknown_fields = ( + nested_message.payload.optional_nested_message.UnknownFields()) + self.assertEqual(unknown_fields[0].data, 456) + nested_message.ClearField('payload') + self.assertEqual(unknown_fields[0].data, 456) + unknown_fields = ( + nested_message.payload.optional_nested_message.UnknownFields()) + self.assertEqual(0, len(unknown_fields)) + + def testUnknownField(self): + message = unittest_pb2.TestAllTypes() + message.optional_int32 = 123 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + unknown_field = destination.UnknownFields()[0] + destination.Clear() + with self.assertRaises(ValueError) as context: + unknown_field.data # pylint: disable=pointless-statement + self.assertIn('The parent message might be cleared.', + str(context.exception)) def testUnknownExtensions(self): message = unittest_pb2.TestEmptyMessageWithExtensions() @@ -280,15 +364,13 @@ class UnknownEnumValuesTest(BaseTestCase): def CheckUnknownField(self, name, expected_value): field_descriptor = self.descriptor.fields_by_name[name] - wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] - field_tag = encoder.TagBytes(field_descriptor.number, wire_type) - result_dict = {} - for tag_bytes, value in self.missing_message._unknown_fields: - if tag_bytes == field_tag: - decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ - tag_bytes][0] - decoder(value, 0, len(value), self.message, result_dict) - self.assertEqual(expected_value, result_dict[field_descriptor]) + unknown_fields = self.missing_message.UnknownFields() + for field in unknown_fields: + if field.field_number == field_descriptor.number: + if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED: + self.assertIn(field.data, expected_value) + else: + self.assertEqual(expected_value, field.data) def testUnknownParseMismatchEnumValue(self): just_string = missing_enum_values_pb2.JustString() @@ -317,7 +399,6 @@ class UnknownEnumValuesTest(BaseTestCase): def testUnknownPackedEnumValue(self): self.assertEqual([], self.missing_message.packed_nested_enum) - @SkipCheckUnknownFieldIfCppImplementation def testCheckUnknownFieldValueForEnum(self): self.CheckUnknownField('optional_nested_enum', self.message.optional_nested_enum) diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py index 37a65cfa..95c5615f 100644 --- a/python/google/protobuf/internal/well_known_types.py +++ b/python/google/protobuf/internal/well_known_types.py @@ -40,6 +40,7 @@ This files defines well known classes which need extra maintenance including: __author__ = 'jieluo@google.com (Jie Luo)' +import calendar import collections from datetime import datetime from datetime import timedelta @@ -92,7 +93,7 @@ class Any(object): def Is(self, descriptor): """Checks if this Any represents the given protobuf type.""" - return self.TypeName() == descriptor.full_name + return '/' in self.type_url and self.TypeName() == descriptor.full_name class Timestamp(object): @@ -233,9 +234,15 @@ class Timestamp(object): def FromDatetime(self, dt): """Converts datetime to Timestamp.""" - td = dt - datetime(1970, 1, 1) - self.seconds = td.seconds + td.days * _SECONDS_PER_DAY - self.nanos = td.microseconds * _NANOS_PER_MICROSECOND + # Using this guide: http://wiki.python.org/moin/WorkingWithTime + # And this conversion guide: http://docs.python.org/library/time.html + + # Turn the date parameter into a tuple (struct_time) that can then be + # manipulated into a long value of seconds. During the conversion from + # struct_time to long, the source date in UTC, and so it follows that the + # correct transformation is calendar.timegm() + self.seconds = calendar.timegm(dt.utctimetuple()) + self.nanos = dt.microsecond * _NANOS_PER_MICROSECOND class Duration(object): diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py index 965940b2..4dc2ae4f 100644 --- a/python/google/protobuf/internal/well_known_types_test.py +++ b/python/google/protobuf/internal/well_known_types_test.py @@ -35,7 +35,7 @@ __author__ = 'jieluo@google.com (Jie Luo)' import collections -from datetime import datetime +import datetime try: import unittest2 as unittest #PY26 @@ -240,14 +240,34 @@ class TimeUtilTest(TimeUtilTestBase): def testDatetimeConverison(self): message = timestamp_pb2.Timestamp() - dt = datetime(1970, 1, 1) + dt = datetime.datetime(1970, 1, 1) message.FromDatetime(dt) self.assertEqual(dt, message.ToDatetime()) message.FromMilliseconds(1999) - self.assertEqual(datetime(1970, 1, 1, 0, 0, 1, 999000), + self.assertEqual(datetime.datetime(1970, 1, 1, 0, 0, 1, 999000), message.ToDatetime()) + def testDatetimeConversionWithTimezone(self): + class TZ(datetime.tzinfo): + + def utcoffset(self, _): + return datetime.timedelta(hours=1) + + def dst(self, _): + return datetime.timedelta(0) + + def tzname(self, _): + return 'UTC+1' + + message1 = timestamp_pb2.Timestamp() + dt = datetime.datetime(1970, 1, 1, 1, tzinfo=TZ()) + message1.FromDatetime(dt) + message2 = timestamp_pb2.Timestamp() + dt = datetime.datetime(1970, 1, 1, 0) + message2.FromDatetime(dt) + self.assertEqual(message1, message2) + def testTimedeltaConversion(self): message = duration_pb2.Duration() message.FromNanoseconds(1999999999) @@ -879,6 +899,17 @@ class AnyTest(unittest.TestCase): raise AttributeError('%s should not have Pack method.' % msg_descriptor.full_name) + def testUnpackWithNoSlashInTypeUrl(self): + msg = any_test_pb2.TestAny() + all_types = unittest_pb2.TestAllTypes() + all_descriptor = all_types.DESCRIPTOR + msg.value.Pack(all_types) + # Reset type_url to part of type_url after '/' + msg.value.type_url = msg.value.TypeName() + self.assertFalse(msg.value.Is(all_descriptor)) + unpacked_message = unittest_pb2.TestAllTypes() + self.assertFalse(msg.value.Unpack(unpacked_message)) + def testMessageName(self): # Creates and sets message. submessage = any_test_pb2.TestAny() |