about summary refs log tree commit diff
path: root/python/google/protobuf/internal
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/internal')
-rwxr-xr-x python/google/protobuf/internal/__init__.py 30
-rwxr-xr-x python/google/protobuf/internal/api_implementation.py 26
-rwxr-xr-x python/google/protobuf/internal/containers.py 127
-rwxr-xr-x python/google/protobuf/internal/decoder.py 196
-rw-r--r-- python/google/protobuf/internal/descriptor_database_test.py 13
-rw-r--r-- python/google/protobuf/internal/descriptor_pool_test.py 63
-rwxr-xr-x python/google/protobuf/internal/descriptor_test.py 11
-rw-r--r-- python/google/protobuf/internal/factory_test1.proto 14
-rw-r--r-- python/google/protobuf/internal/message_factory_test.py 4
-rwxr-xr-x python/google/protobuf/internal/message_test.py 279
-rw-r--r-- python/google/protobuf/internal/no_package.proto 30
-rwxr-xr-x python/google/protobuf/internal/python_message.py 145
-rwxr-xr-x python/google/protobuf/internal/reflection_test.py 57
-rwxr-xr-x python/google/protobuf/internal/text_format_test.py 398
-rwxr-xr-x python/google/protobuf/internal/type_checkers.py 8
-rwxr-xr-x python/google/protobuf/internal/unknown_fields_test.py 165
-rw-r--r-- python/google/protobuf/internal/well_known_types.py 15
-rw-r--r-- python/google/protobuf/internal/well_known_types_test.py 37
18 files changed, 1407 insertions, 211 deletions
diff --git a/python/google/protobuf/internal/__init__.py b/python/google/protobuf/internal/__init__.py
index e69de29b..7d2e571a 100755
--- a/python/google/protobuf/internal/__init__.py
+++ b/python/google/protobuf/internal/__init__.py
@@ -0,0 +1,30 @@
+# Protocol Buffers - Google's data interchange format
+# Copyright 2008 Google Inc. All rights reserved.
+# https://developers.google.com/protocol-buffers/
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are
+# met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+# * Redistributions in binary form must reproduce the above
+# copyright notice, this list of conditions and the following disclaimer
+# in the documentation and/or other materials provided with the
+# distribution.
+# * Neither the name of Google Inc. nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py
index ab9e7812..23cc2c0a 100755
--- a/python/google/protobuf/internal/api_implementation.py
+++ b/python/google/protobuf/internal/api_implementation.py
@@ -145,29 +145,3 @@ def Version():
# For internal use only
def IsPythonDefaultSerializationDeterministic():
return _python_deterministic_proto_serialization
-
-# DO NOT USE: For migration and testing only. Will be removed when Proto3
-# defaults to preserve unknowns.
-if _implementation_type == 'cpp':
- try:
- # pylint: disable=g-import-not-at-top
- from google.protobuf.pyext import _message
-
- def GetPythonProto3PreserveUnknownsDefault():
- return _message.GetPythonProto3PreserveUnknownsDefault()
-
- def SetPythonProto3PreserveUnknownsDefault(preserve):
- _message.SetPythonProto3PreserveUnknownsDefault(preserve)
- except ImportError:
- # Unrecognized cpp implementation. Skipping the unknown fields APIs.
- pass
-else:
- _python_proto3_preserve_unknowns_default = True
-
- def GetPythonProto3PreserveUnknownsDefault():
- return _python_proto3_preserve_unknowns_default
-
- def SetPythonProto3PreserveUnknownsDefault(preserve):
- global _python_proto3_preserve_unknowns_default
- _python_proto3_preserve_unknowns_default = preserve
-
diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py
index c6a3692a..182cac99 100755
--- a/python/google/protobuf/internal/containers.py
+++ b/python/google/protobuf/internal/containers.py
@@ -628,3 +628,130 @@ class MessageMap(MutableMapping):
def GetEntryClass(self):
return self._entry_descriptor._concrete_class
+
+
+class _UnknownField(object):
+
+ """A parsed unknown field."""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_field_number', '_wire_type', '_data']
+
+ def __init__(self, field_number, wire_type, data):
+ self._field_number = field_number
+ self._wire_type = wire_type
+ self._data = data
+ return
+
+ def __lt__(self, other):
+ # pylint: disable=protected-access
+ return self._field_number < other._field_number
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ # pylint: disable=protected-access
+ return (self._field_number == other._field_number and
+ self._wire_type == other._wire_type and
+ self._data == other._data)
+
+
+class UnknownFieldRef(object):
+
+ def __init__(self, parent, index):
+ self._parent = parent
+ self._index = index
+ return
+
+ def _check_valid(self):
+ if not self._parent:
+ raise ValueError('UnknownField does not exist. '
+ 'The parent message might be cleared.')
+ if self._index >= len(self._parent):
+ raise ValueError('UnknownField does not exist. '
+ 'The parent message might be cleared.')
+
+ @property
+ def field_number(self):
+ self._check_valid()
+ # pylint: disable=protected-access
+ return self._parent._internal_get(self._index)._field_number
+
+ @property
+ def wire_type(self):
+ self._check_valid()
+ # pylint: disable=protected-access
+ return self._parent._internal_get(self._index)._wire_type
+
+ @property
+ def data(self):
+ self._check_valid()
+ # pylint: disable=protected-access
+ return self._parent._internal_get(self._index)._data
+
+
+class UnknownFieldSet(object):
+
+ """UnknownField container"""
+
+ # Disallows assignment to other attributes.
+ __slots__ = ['_values']
+
+ def __init__(self):
+ self._values = []
+
+ def __getitem__(self, index):
+ if self._values is None:
+ raise ValueError('UnknownFields does not exist. '
+ 'The parent message might be cleared.')
+ size = len(self._values)
+ if index < 0:
+ index += size
+ if index < 0 or index >= size:
+ raise IndexError('index %d out of range'.index)
+
+ return UnknownFieldRef(self, index)
+
+ def _internal_get(self, index):
+ return self._values[index]
+
+ def __len__(self):
+ if self._values is None:
+ raise ValueError('UnknownFields does not exist. '
+ 'The parent message might be cleared.')
+ return len(self._values)
+
+ def _add(self, field_number, wire_type, data):
+ unknown_field = _UnknownField(field_number, wire_type, data)
+ self._values.append(unknown_field)
+ return unknown_field
+
+ def __iter__(self):
+ for i in range(len(self)):
+ yield UnknownFieldRef(self, i)
+
+ def _extend(self, other):
+ if other is None:
+ return
+ # pylint: disable=protected-access
+ self._values.extend(other._values)
+
+ def __eq__(self, other):
+ if self is other:
+ return True
+ # Sort unknown fields because their order shouldn't
+ # affect equality test.
+ values = list(self._values)
+ if other is None:
+ return not values
+ values.sort()
+ # pylint: disable=protected-access
+ other_values = sorted(other._values)
+ return values == other_values
+
+ def _clear(self):
+ for value in self._values:
+ # pylint: disable=protected-access
+ if isinstance(value._data, UnknownFieldSet):
+ value._data._clear() # pylint: disable=protected-access
+ self._values = None
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 52b64915..5a540184 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -81,12 +81,17 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it.
__author__ = 'kenton@google.com (Kenton Varda)'
import struct
-
+import sys
import six
+_UCS2_MAXUNICODE = 65535
if six.PY3:
long = int
+else:
+ import re # pylint: disable=g-import-not-at-top
+ _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]'))
+from google.protobuf.internal import containers
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message
@@ -167,7 +172,7 @@ _DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
def ReadTag(buffer, pos):
- """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.
+ """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.
We return the raw bytes of the tag rather than decoding them. The raw
bytes can then be used to look up the proper decoder. This effectively allows
@@ -175,13 +180,21 @@ def ReadTag(buffer, pos):
for work that is done in C (searching for a byte string in a hash table).
In a low-level language it would be much cheaper to decode the varint and
use that, but not in Python.
- """
+ Args:
+ buffer: memoryview object of the encoded bytes
+ pos: int of the current position to start from
+
+ Returns:
+ Tuple[bytes, int] of the tag data and new position.
+ """
start = pos
while six.indexbytes(buffer, pos) & 0x80:
pos += 1
pos += 1
- return (six.binary_type(buffer[start:pos]), pos)
+
+ tag_bytes = buffer[start:pos].tobytes()
+ return tag_bytes, pos
# --------------------------------------------------------------------
@@ -295,10 +308,20 @@ def _FloatDecoder():
local_unpack = struct.unpack
def InnerDecode(buffer, pos):
+ """Decode serialized float to a float and new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes
+ pos: int, position in the memory view to start at.
+
+ Returns:
+ Tuple[float, int] of the deserialized float value and new position
+ in the serialized data.
+ """
# We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
# bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
new_pos = pos + 4
- float_bytes = buffer[pos:new_pos]
+ float_bytes = buffer[pos:new_pos].tobytes()
# If this value has all its exponent bits set, then it's non-finite.
# In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
@@ -329,10 +352,20 @@ def _DoubleDecoder():
local_unpack = struct.unpack
def InnerDecode(buffer, pos):
+ """Decode serialized double to a double and new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+
+ Returns:
+ Tuple[float, int] of the decoded double value and new position
+ in the serialized data.
+ """
# We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
# bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
new_pos = pos + 8
- double_bytes = buffer[pos:new_pos]
+ double_bytes = buffer[pos:new_pos].tobytes()
# If this value has all its exponent bits set and at least one significand
# bit set, it's not a number. In Python 2.4, struct.unpack will treat it
@@ -355,6 +388,18 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
if is_packed:
local_DecodeVarint = _DecodeVarint
def DecodePackedField(buffer, pos, end, message, field_dict):
+ """Decode serialized packed enum to its value and a new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -365,6 +410,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
while pos < endpoint:
value_start_pos = pos
(element, pos) = _DecodeSignedVarint32(buffer, pos)
+ # pylint: disable=protected-access
if element in enum_type.values_by_number:
value.append(element)
else:
@@ -372,8 +418,10 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
message._unknown_fields = []
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_VARINT)
+
message._unknown_fields.append(
- (tag_bytes, buffer[value_start_pos:pos]))
+ (tag_bytes, buffer[value_start_pos:pos].tobytes()))
+ # pylint: enable=protected-access
if pos > endpoint:
if element in enum_type.values_by_number:
del value[-1] # Discard corrupt value.
@@ -386,18 +434,32 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
tag_len = len(tag_bytes)
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ """Decode serialized repeated enum to its value and a new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
(element, new_pos) = _DecodeSignedVarint32(buffer, pos)
+ # pylint: disable=protected-access
if element in enum_type.values_by_number:
value.append(element)
else:
if not message._unknown_fields:
message._unknown_fields = []
message._unknown_fields.append(
- (tag_bytes, buffer[pos:new_pos]))
+ (tag_bytes, buffer[pos:new_pos].tobytes()))
+ # pylint: enable=protected-access
# Predict that the next tag is another copy of the same repeated
# field.
pos = new_pos + tag_len
@@ -409,10 +471,23 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
return DecodeRepeatedField
else:
def DecodeField(buffer, pos, end, message, field_dict):
+ """Decode serialized enum to its value and a new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
value_start_pos = pos
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
if pos > end:
raise _DecodeError('Truncated message.')
+ # pylint: disable=protected-access
if enum_value in enum_type.values_by_number:
field_dict[key] = enum_value
else:
@@ -421,7 +496,8 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_VARINT)
message._unknown_fields.append(
- (tag_bytes, buffer[value_start_pos:pos]))
+ (tag_bytes, buffer[value_start_pos:pos].tobytes()))
+ # pylint: enable=protected-access
return pos
return DecodeField
@@ -458,20 +534,34 @@ BoolDecoder = _ModifiedDecoder(
wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
-def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
+def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
+ is_strict_utf8=False):
"""Returns a decoder for a string field."""
local_DecodeVarint = _DecodeVarint
local_unicode = six.text_type
- def _ConvertToUnicode(byte_str):
+ def _ConvertToUnicode(memview):
+ """Convert a memoryview of UTF-8 bytes to unicode."""
+ byte_str = memview.tobytes()
try:
- return local_unicode(byte_str, 'utf-8')
+ value = local_unicode(byte_str, 'utf-8')
except UnicodeDecodeError as e:
# add more information to the error message and re-raise it.
e.reason = '%s in field: %s' % (e, key.full_name)
raise
+ if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE:
+ # Only do the check for python2 ucs4 when is_strict_utf8 enabled
+ if _SURROGATE_PATTERN.search(value):
+ reason = ('String field %s contains invalid UTF-8 data when parsing'
+ 'a protocol buffer: surrogates not allowed. Use'
+ 'the bytes type if you intend to send raw bytes.') % (
+ key.full_name)
+ raise message.DecodeError(reason)
+
+ return value
+
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
@@ -523,7 +613,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- value.append(buffer[pos:new_pos])
+ value.append(buffer[pos:new_pos].tobytes())
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -536,7 +626,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- field_dict[key] = buffer[pos:new_pos]
+ field_dict[key] = buffer[pos:new_pos].tobytes()
return new_pos
return DecodeField
@@ -665,6 +755,18 @@ def MessageSetItemDecoder(descriptor):
local_SkipField = SkipField
def DecodeItem(buffer, pos, end, message, field_dict):
+ """Decode serialized message set to its value and new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
message_set_item_start = pos
type_id = -1
message_start = -1
@@ -695,6 +797,7 @@ def MessageSetItemDecoder(descriptor):
raise _DecodeError('MessageSet item missing message.')
extension = message.Extensions._FindExtensionByNumber(type_id)
+ # pylint: disable=protected-access
if extension is not None:
value = field_dict.get(extension)
if value is None:
@@ -707,8 +810,9 @@ def MessageSetItemDecoder(descriptor):
else:
if not message._unknown_fields:
message._unknown_fields = []
- message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
- buffer[message_set_item_start:pos]))
+ message._unknown_fields.append(
+ (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
+ # pylint: enable=protected-access
return pos
@@ -767,7 +871,7 @@ def _SkipVarint(buffer, pos, end):
# Previously ord(buffer[pos]) raised IndexError when pos is out of range.
# With this code, ord(b'') raises TypeError. Both are handled in
# python_message.py to generate a 'Truncated message' error.
- while ord(buffer[pos:pos+1]) & 0x80:
+ while ord(buffer[pos:pos+1].tobytes()) & 0x80:
pos += 1
pos += 1
if pos > end:
@@ -782,6 +886,13 @@ def _SkipFixed64(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+
+def _DecodeFixed64(buffer, pos):
+ """Decode a fixed64."""
+ new_pos = pos + 8
+ return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
+
+
def _SkipLengthDelimited(buffer, pos, end):
"""Skip a length-delimited value. Returns the new position."""
@@ -791,6 +902,7 @@ def _SkipLengthDelimited(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+
def _SkipGroup(buffer, pos, end):
"""Skip sub-group. Returns the new position."""
@@ -801,11 +913,53 @@ def _SkipGroup(buffer, pos, end):
return pos
pos = new_pos
+
+def _DecodeGroup(buffer, pos):
+ """Decode group. Returns the UnknownFieldSet and new position."""
+
+ unknown_field_set = containers.UnknownFieldSet()
+ while 1:
+ (tag_bytes, pos) = ReadTag(buffer, pos)
+ (tag, _) = _DecodeVarint(tag_bytes, 0)
+ field_number, wire_type = wire_format.UnpackTag(tag)
+ if wire_type == wire_format.WIRETYPE_END_GROUP:
+ break
+ (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
+ # pylint: disable=protected-access
+ unknown_field_set._add(field_number, wire_type, data)
+
+ return (unknown_field_set, pos)
+
+
+def _DecodeUnknownField(buffer, pos, wire_type):
+ """Decode an unknown field. Returns the field data and new position."""
+
+ if wire_type == wire_format.WIRETYPE_VARINT:
+ (data, pos) = _DecodeVarint(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_FIXED64:
+ (data, pos) = _DecodeFixed64(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_FIXED32:
+ (data, pos) = _DecodeFixed32(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
+ (size, pos) = _DecodeVarint(buffer, pos)
+ data = buffer[pos:pos+size]
+ pos += size
+ elif wire_type == wire_format.WIRETYPE_START_GROUP:
+ (data, pos) = _DecodeGroup(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_END_GROUP:
+ return (0, -1)
+ else:
+ raise _DecodeError('Wrong wire type in tag.')
+
+ return (data, pos)
+
+
def _EndGroup(buffer, pos, end):
"""Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
return -1
+
def _SkipFixed32(buffer, pos, end):
"""Skip a fixed32 value. Returns the new position."""
@@ -814,6 +968,14 @@ def _SkipFixed32(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+
+def _DecodeFixed32(buffer, pos):
+ """Decode a fixed32."""
+
+ new_pos = pos + 4
+ return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
+
+
def _RaiseInvalidWireType(buffer, pos, end):
"""Skip function for unknown wire types. Raises an exception."""
diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py
index f97477b3..da5dbd92 100644
--- a/python/google/protobuf/internal/descriptor_database_test.py
+++ b/python/google/protobuf/internal/descriptor_database_test.py
@@ -43,6 +43,7 @@ import warnings
from google.protobuf import unittest_pb2
from google.protobuf import descriptor_pb2
from google.protobuf.internal import factory_test2_pb2
+from google.protobuf.internal import no_package_pb2
from google.protobuf import descriptor_database
@@ -52,7 +53,10 @@ class DescriptorDatabaseTest(unittest.TestCase):
db = descriptor_database.DescriptorDatabase()
file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString(
factory_test2_pb2.DESCRIPTOR.serialized_pb)
+ file_desc_proto2 = descriptor_pb2.FileDescriptorProto.FromString(
+ no_package_pb2.DESCRIPTOR.serialized_pb)
db.Add(file_desc_proto)
+ db.Add(file_desc_proto2)
self.assertEqual(file_desc_proto, db.FindFileByName(
'google/protobuf/internal/factory_test2.proto'))
@@ -76,6 +80,10 @@ class DescriptorDatabaseTest(unittest.TestCase):
# Can find enum value.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Enum.FACTORY_2_VALUE_0'))
+ self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.FACTORY_2_VALUE_0'))
+ self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
+ '.NO_PACKAGE_VALUE_0'))
# Can find top level extension.
self.assertEqual(file_desc_proto, db.FindFileContainingSymbol(
'google.protobuf.python.internal.another_field'))
@@ -95,9 +103,8 @@ class DescriptorDatabaseTest(unittest.TestCase):
self.assertEqual(file_desc_proto2, db.FindFileContainingSymbol(
'protobuf_unittest.TestAllTypes.none_field'))
- self.assertRaises(KeyError,
- db.FindFileContainingSymbol,
- 'protobuf_unittest.NoneMessage')
+ with self.assertRaisesRegexp(KeyError, r'\'protobuf_unittest\.NoneMessage\''):
+ db.FindFileContainingSymbol('protobuf_unittest.NoneMessage')
def testConflictRegister(self):
db = descriptor_database.DescriptorDatabase()
diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py
index 2cbf7813..1b72b0b9 100644
--- a/python/google/protobuf/internal/descriptor_pool_test.py
+++ b/python/google/protobuf/internal/descriptor_pool_test.py
@@ -36,7 +36,6 @@ __author__ = 'matthewtoia@google.com (Matt Toia)'
import copy
import os
-import sys
import warnings
try:
@@ -55,6 +54,7 @@ from google.protobuf.internal import factory_test1_pb2
from google.protobuf.internal import factory_test2_pb2
from google.protobuf.internal import file_options_test_pb2
from google.protobuf.internal import more_messages_pb2
+from google.protobuf.internal import no_package_pb2
from google.protobuf import descriptor
from google.protobuf import descriptor_database
from google.protobuf import descriptor_pool
@@ -120,7 +120,6 @@ class DescriptorPoolTestBase(object):
self.assertIsInstance(file_desc5, descriptor.FileDescriptor)
self.assertEqual('google/protobuf/unittest.proto',
file_desc5.name)
-
# Tests the generated pool.
assert descriptor_pool.Default().FindFileContainingSymbol(
'google.protobuf.python.internal.Factory2Message.one_more_field')
@@ -129,6 +128,32 @@ class DescriptorPoolTestBase(object):
assert descriptor_pool.Default().FindFileContainingSymbol(
'protobuf_unittest.TestService')
+ # Can find field.
+ file_desc6 = self.pool.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.Factory1Message.list_value')
+ self.assertIsInstance(file_desc6, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/internal/factory_test1.proto',
+ file_desc6.name)
+
+ # Can find top level Enum value.
+ file_desc7 = self.pool.FindFileContainingSymbol(
+ 'google.protobuf.python.internal.FACTORY_1_VALUE_0')
+ self.assertIsInstance(file_desc7, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/internal/factory_test1.proto',
+ file_desc7.name)
+
+ # Can find nested Enum value.
+ file_desc8 = self.pool.FindFileContainingSymbol(
+ 'protobuf_unittest.TestAllTypes.FOO')
+ self.assertIsInstance(file_desc8, descriptor.FileDescriptor)
+ self.assertEqual('google/protobuf/unittest.proto',
+ file_desc8.name)
+
+ # TODO(jieluo): Add tests for no package when b/13860351 is fixed.
+
+ self.assertRaises(KeyError, self.pool.FindFileContainingSymbol,
+ 'google.protobuf.python.internal.Factory1Message.none_field')
+
def testFindFileContainingSymbolFailure(self):
with self.assertRaises(KeyError):
self.pool.FindFileContainingSymbol('Does not exist')
@@ -217,11 +242,10 @@ class DescriptorPoolTestBase(object):
def testFindTypeErrors(self):
self.assertRaises(TypeError, self.pool.FindExtensionByNumber, '')
+ self.assertRaises(KeyError, self.pool.FindMethodByName, '')
# TODO(jieluo): Fix python to raise correct errors.
if api_implementation.Type() == 'cpp':
- self.assertRaises(TypeError, self.pool.FindMethodByName, 0)
- self.assertRaises(KeyError, self.pool.FindMethodByName, '')
error_type = TypeError
else:
error_type = AttributeError
@@ -231,6 +255,7 @@ class DescriptorPoolTestBase(object):
self.assertRaises(error_type, self.pool.FindEnumTypeByName, 0)
self.assertRaises(error_type, self.pool.FindOneofByName, 0)
self.assertRaises(error_type, self.pool.FindServiceByName, 0)
+ self.assertRaises(error_type, self.pool.FindMethodByName, 0)
self.assertRaises(error_type, self.pool.FindFileContainingSymbol, 0)
if api_implementation.Type() == 'python':
error_type = KeyError
@@ -275,11 +300,6 @@ class DescriptorPoolTestBase(object):
self.pool.FindEnumTypeByName('Does not exist')
def testFindFieldByName(self):
- if isinstance(self, SecondaryDescriptorFromDescriptorDB):
- if api_implementation.Type() == 'cpp':
- # TODO(jieluo): Fix cpp extension to find field correctly
- # when descriptor pool is using an underlying database.
- return
field = self.pool.FindFieldByName(
'google.protobuf.python.internal.Factory1Message.list_value')
self.assertEqual(field.name, 'list_value')
@@ -290,11 +310,6 @@ class DescriptorPoolTestBase(object):
self.pool.FindFieldByName('Does not exist')
def testFindOneofByName(self):
- if isinstance(self, SecondaryDescriptorFromDescriptorDB):
- if api_implementation.Type() == 'cpp':
- # TODO(jieluo): Fix cpp extension to find oneof correctly
- # when descriptor pool is using an underlying database.
- return
oneof = self.pool.FindOneofByName(
'google.protobuf.python.internal.Factory2Message.oneof_field')
self.assertEqual(oneof.name, 'oneof_field')
@@ -302,11 +317,6 @@ class DescriptorPoolTestBase(object):
self.pool.FindOneofByName('Does not exist')
def testFindExtensionByName(self):
- if isinstance(self, SecondaryDescriptorFromDescriptorDB):
- if api_implementation.Type() == 'cpp':
- # TODO(jieluo): Fix cpp extension to find extension correctly
- # when descriptor pool is using an underlying database.
- return
# An extension defined in a message.
extension = self.pool.FindExtensionByName(
'google.protobuf.python.internal.Factory2Message.one_more_field')
@@ -382,6 +392,11 @@ class DescriptorPoolTestBase(object):
with self.assertRaises(KeyError):
self.pool.FindServiceByName('Does not exist')
+ method = self.pool.FindMethodByName('protobuf_unittest.TestService.Foo')
+ self.assertIs(method.containing_service, service)
+ with self.assertRaises(KeyError):
+ self.pool.FindMethodByName('protobuf_unittest.TestService.Doesnotexist')
+
def testUserDefinedDB(self):
db = descriptor_database.DescriptorDatabase()
self.pool = descriptor_pool.DescriptorPool(db)
@@ -601,6 +616,8 @@ class CreateDescriptorPoolTest(DescriptorPoolTestBase, unittest.TestCase):
unittest_import_pb2.DESCRIPTOR.serialized_pb))
self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
unittest_pb2.DESCRIPTOR.serialized_pb))
+ self.pool.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ no_package_pb2.DESCRIPTOR.serialized_pb))
class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase,
@@ -620,6 +637,8 @@ class SecondaryDescriptorFromDescriptorDB(DescriptorPoolTestBase,
unittest_import_pb2.DESCRIPTOR.serialized_pb))
db.Add(descriptor_pb2.FileDescriptorProto.FromString(
unittest_pb2.DESCRIPTOR.serialized_pb))
+ db.Add(descriptor_pb2.FileDescriptorProto.FromString(
+ no_package_pb2.DESCRIPTOR.serialized_pb))
self.pool = descriptor_pool.DescriptorPool(descriptor_db=db)
@@ -746,11 +765,7 @@ class MessageField(object):
test.assertEqual(msg_desc, field_desc.containing_type)
test.assertEqual(field_type_desc, field_desc.message_type)
test.assertEqual(file_desc, field_desc.file)
- # TODO(jieluo): Fix python and cpp extension diff for message field
- # default value.
- if api_implementation.Type() == 'cpp':
- test.assertRaises(
- NotImplementedError, getattr, field_desc, 'default_value')
+ test.assertEqual(field_desc.default_value, None)
class StringField(object):
diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py
index 02a43d15..af6bece1 100755
--- a/python/google/protobuf/internal/descriptor_test.py
+++ b/python/google/protobuf/internal/descriptor_test.py
@@ -452,6 +452,17 @@ class DescriptorTest(unittest.TestCase):
self.assertEqual('attribute is not writable: has_options',
str(e.exception))
+ def testDefault(self):
+ message_descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR
+ field = message_descriptor.fields_by_name['repeated_int32']
+ self.assertEqual(field.default_value, [])
+ field = message_descriptor.fields_by_name['repeated_nested_message']
+ self.assertEqual(field.default_value, [])
+ field = message_descriptor.fields_by_name['optionalgroup']
+ self.assertEqual(field.default_value, None)
+ field = message_descriptor.fields_by_name['optional_nested_message']
+ self.assertEqual(field.default_value, None)
+
class NewDescriptorTest(DescriptorTest):
"""Redo the same tests as above, but with a separate DescriptorPool."""
diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto
index d2fbbeec..f5bd0383 100644
--- a/python/google/protobuf/internal/factory_test1.proto
+++ b/python/google/protobuf/internal/factory_test1.proto
@@ -56,3 +56,17 @@ message Factory1Message {
extensions 1000 to max;
}
+
+message Factory1MethodRequest {
+ optional string argument = 1;
+}
+
+message Factory1MethodResponse {
+ optional string result = 1;
+}
+
+service Factory1Service {
+ // Dummy method for this dummy service.
+ rpc Factory1Method(Factory1MethodRequest) returns (Factory1MethodResponse) {
+ }
+}
diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py
index 6df52ed2..b97e3f65 100644
--- a/python/google/protobuf/internal/message_factory_test.py
+++ b/python/google/protobuf/internal/message_factory_test.py
@@ -142,10 +142,8 @@ class MessageFactoryTest(unittest.TestCase):
self.assertEqual('test2', msg1.Extensions[ext2])
self.assertEqual(None,
msg1.Extensions._FindExtensionByNumber(12321))
+ self.assertRaises(TypeError, len, msg1.Extensions)
if api_implementation.Type() == 'cpp':
- # TODO(jieluo): Fix len to return the correct value.
- # self.assertEqual(2, len(msg1.Extensions))
- self.assertEqual(len(msg1.Extensions), len(msg1.Extensions))
self.assertRaises(TypeError,
msg1.Extensions._FindExtensionByName, 0)
self.assertRaises(TypeError,
diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py
index 61a56a67..4dd1104a 100755
--- a/python/google/protobuf/internal/message_test.py
+++ b/python/google/protobuf/internal/message_test.py
@@ -1,4 +1,5 @@
#! /usr/bin/env python
+# -*- coding: utf-8 -*-
#
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
@@ -49,6 +50,7 @@ import copy
import math
import operator
import pickle
+import pydoc
import six
import sys
import warnings
@@ -72,12 +74,14 @@ from google.protobuf import message_factory
from google.protobuf import text_format
from google.protobuf.internal import api_implementation
from google.protobuf.internal import encoder
+from google.protobuf.internal import more_extensions_pb2
from google.protobuf.internal import packed_field_test_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf import message
from google.protobuf.internal import _parameterized
+UCS2_MAXUNICODE = 65535
if six.PY3:
long = int
@@ -415,6 +419,37 @@ class MessageTest(BaseTestCase):
empty.ParseFromString(populated.SerializeToString())
self.assertEqual(str(empty), '')
+ def testMergeFromRepeatedField(self, message_module):
+ msg = message_module.TestAllTypes()
+ msg.repeated_int32.append(1)
+ msg.repeated_int32.append(3)
+ msg.repeated_nested_message.add(bb=1)
+ msg.repeated_nested_message.add(bb=2)
+ other_msg = message_module.TestAllTypes()
+ other_msg.repeated_nested_message.add(bb=3)
+ other_msg.repeated_nested_message.add(bb=4)
+ other_msg.repeated_int32.append(5)
+ other_msg.repeated_int32.append(7)
+
+ msg.repeated_int32.MergeFrom(other_msg.repeated_int32)
+ self.assertEqual(4, len(msg.repeated_int32))
+
+ msg.repeated_nested_message.MergeFrom(other_msg.repeated_nested_message)
+ self.assertEqual([1, 2, 3, 4],
+ [m.bb for m in msg.repeated_nested_message])
+
+ def testAddWrongRepeatedNestedField(self, message_module):
+ msg = message_module.TestAllTypes()
+ try:
+ msg.repeated_nested_message.add('wrong')
+ except TypeError:
+ pass
+ try:
+ msg.repeated_nested_message.add(value_field='wrong')
+ except ValueError:
+ pass
+ self.assertEqual(len(msg.repeated_nested_message), 0)
+
def testRepeatedNestedFieldIteration(self, message_module):
msg = message_module.TestAllTypes()
msg.repeated_nested_message.add(bb=1)
@@ -645,6 +680,82 @@ class MessageTest(BaseTestCase):
m.payload.repeated_int32.extend([])
self.assertTrue(m.HasField('payload'))
+ def testMergeFrom(self, message_module):
+ m1 = message_module.TestAllTypes()
+ m2 = message_module.TestAllTypes()
+ # Cpp extension will lazily create a sub message which is immutable.
+ self.assertEqual(0, m1.optional_nested_message.bb)
+ m2.optional_nested_message.bb = 1
+ # Make sure cmessage points to a mutable message after merge instead of
+ # the lazily created message.
+ m1.MergeFrom(m2)
+ self.assertEqual(1, m1.optional_nested_message.bb)
+
+ # Test more nested sub message.
+ msg1 = message_module.NestedTestAllTypes()
+ msg2 = message_module.NestedTestAllTypes()
+ self.assertEqual(0, msg1.child.payload.optional_nested_message.bb)
+ msg2.child.payload.optional_nested_message.bb = 1
+ msg1.MergeFrom(msg2)
+ self.assertEqual(1, msg1.child.payload.optional_nested_message.bb)
+
+ # Test repeated field.
+ self.assertEqual(msg1.payload.repeated_nested_message,
+ msg1.payload.repeated_nested_message)
+ msg2.payload.repeated_nested_message.add().bb = 1
+ msg1.MergeFrom(msg2)
+ self.assertEqual(1, len(msg1.payload.repeated_nested_message))
+ self.assertEqual(1, msg1.payload.repeated_nested_message[0].bb)
+
+ def testMergeFromString(self, message_module):
+ m1 = message_module.TestAllTypes()
+ m2 = message_module.TestAllTypes()
+ # Cpp extension will lazily create a sub message which is immutable.
+ self.assertEqual(0, m1.optional_nested_message.bb)
+ m2.optional_nested_message.bb = 1
+ # Make sure cmessage points to a mutable message after merge instead of
+ # the lazily created message.
+ m1.MergeFromString(m2.SerializeToString())
+ self.assertEqual(1, m1.optional_nested_message.bb)
+
+ @unittest.skipIf(six.PY2, 'memoryview objects are not supported on py2')
+ def testMergeFromStringUsingMemoryViewWorksInPy3(self, message_module):
+ m2 = message_module.TestAllTypes()
+ m2.optional_string = 'scalar string'
+ m2.repeated_string.append('repeated string')
+ m2.optional_bytes = b'scalar bytes'
+ m2.repeated_bytes.append(b'repeated bytes')
+
+ serialized = m2.SerializeToString()
+ memview = memoryview(serialized)
+ m1 = message_module.TestAllTypes.FromString(memview)
+
+ self.assertEqual(m1.optional_bytes, b'scalar bytes')
+ self.assertEqual(m1.repeated_bytes, [b'repeated bytes'])
+ self.assertEqual(m1.optional_string, 'scalar string')
+ self.assertEqual(m1.repeated_string, ['repeated string'])
+ # Make sure that the memoryview was correctly converted to bytes, and
+ # that a sub-sliced memoryview is not being used.
+ self.assertIsInstance(m1.optional_bytes, bytes)
+ self.assertIsInstance(m1.repeated_bytes[0], bytes)
+ self.assertIsInstance(m1.optional_string, six.text_type)
+ self.assertIsInstance(m1.repeated_string[0], six.text_type)
+
+ @unittest.skipIf(six.PY3, 'memoryview is supported by py3')
+ def testMergeFromStringUsingMemoryViewIsPy2Error(self, message_module):
+ memview = memoryview(b'')
+ with self.assertRaises(TypeError):
+ message_module.TestAllTypes.FromString(memview)
+
+ def testMergeFromEmpty(self, message_module):
+ m1 = message_module.TestAllTypes()
+ # Cpp extension will lazily create a sub message which is immutable.
+ self.assertEqual(0, m1.optional_nested_message.bb)
+ self.assertFalse(m1.HasField('optional_nested_message'))
+ # Make sure the sub message is still immutable after merge from empty.
+ m1.MergeFromString(b'') # field state should not change
+ self.assertFalse(m1.HasField('optional_nested_message'))
+
def ensureNestedMessageExists(self, msg, attribute):
"""Make sure that a nested message object exists.
@@ -1067,14 +1178,8 @@ class MessageTest(BaseTestCase):
with self.assertRaises(AttributeError):
m.repeated_int32 = []
m.repeated_int32.append(1)
- if api_implementation.Type() == 'cpp':
- # For test coverage: cpp has a different path if composite
- # field is in cache
- with self.assertRaises(TypeError):
- m.repeated_int32 = []
- else:
- with self.assertRaises(AttributeError):
- m.repeated_int32 = []
+ with self.assertRaises(AttributeError):
+ m.repeated_int32 = []
# Class to test proto2-only features (required, extensions, etc.)
@@ -1112,13 +1217,13 @@ class Proto2Test(BaseTestCase):
message.optional_bool = True
message.optional_nested_message.bb = 15
- self.assertTrue(message.HasField("optional_int32"))
+ self.assertTrue(message.HasField(u"optional_int32"))
self.assertTrue(message.HasField("optional_bool"))
self.assertTrue(message.HasField("optional_nested_message"))
# Clearing the fields unsets them and resets their value to default.
message.ClearField("optional_int32")
- message.ClearField("optional_bool")
+ message.ClearField(u"optional_bool")
message.ClearField("optional_nested_message")
self.assertFalse(message.HasField("optional_int32"))
@@ -1169,6 +1274,21 @@ class Proto2Test(BaseTestCase):
msg = unittest_pb2.TestAllTypes()
self.assertRaises(AttributeError, getattr, msg, 'Extensions')
+ def testMergeFromExtensions(self):
+ msg1 = more_extensions_pb2.TopLevelMessage()
+ msg2 = more_extensions_pb2.TopLevelMessage()
+ # Cpp extension will lazily create a sub message which is immutable.
+ self.assertEqual(0, msg1.submessage.Extensions[
+ more_extensions_pb2.optional_int_extension])
+ self.assertFalse(msg1.HasField('submessage'))
+ msg2.submessage.Extensions[
+ more_extensions_pb2.optional_int_extension] = 123
+ # Make sure cmessage and extensions point to a mutable message
+ # after merge instead of the lazily created message.
+ msg1.MergeFrom(msg2)
+ self.assertEqual(123, msg1.submessage.Extensions[
+ more_extensions_pb2.optional_int_extension])
+
def testGoldenExtensions(self):
golden_data = test_util.GoldenFileData('golden_message')
golden_message = unittest_pb2.TestAllExtensions()
@@ -1315,6 +1435,25 @@ class Proto2Test(BaseTestCase):
with self.assertRaises(ValueError):
unittest_pb2.TestAllTypes(repeated_nested_enum='FOO')
+ def testPythonicInitWithDict(self):
+ # Both string/unicode field name keys should work.
+ kwargs = {
+ 'optional_int32': 100,
+ u'optional_fixed32': 200,
+ }
+ msg = unittest_pb2.TestAllTypes(**kwargs)
+ self.assertEqual(100, msg.optional_int32)
+ self.assertEqual(200, msg.optional_fixed32)
+
+
+ def test_documentation(self):
+ # Also used by the interactive help() function.
+ doc = pydoc.html.document(unittest_pb2.TestAllTypes, 'message')
+ self.assertIn('class TestAllTypes', doc)
+ self.assertIn('SerializePartialToString', doc)
+ self.assertIn('repeated_float', doc)
+ base = unittest_pb2.TestAllTypes.__bases__[0]
+ self.assertRaises(AttributeError, getattr, base, '_extensions_by_name')
# Class to test proto3-only features/behavior (updated field presence & enums)
@@ -1539,10 +1678,8 @@ class Proto3Test(BaseTestCase):
self.assertEqual(True, msg2.map_bool_bool[True])
self.assertEqual(2, msg2.map_int32_enum[888])
self.assertEqual(456, msg2.map_int32_enum[123])
- # TODO(jieluo): Add cpp extension support.
- if api_implementation.Type() == 'python':
- self.assertEqual('{-123: -456}',
- str(msg2.map_int32_int32))
+ self.assertEqual('{-123: -456}',
+ str(msg2.map_int32_int32))
def testMapEntryAlwaysSerialized(self):
msg = map_unittest_pb2.TestMap()
@@ -1603,11 +1740,10 @@ class Proto3Test(BaseTestCase):
self.assertIn(123, msg2.map_int32_foreign_message)
self.assertIn(-456, msg2.map_int32_foreign_message)
self.assertEqual(2, len(msg2.map_int32_foreign_message))
+ msg2.map_int32_foreign_message[123].c = 1
# TODO(jieluo): Fix text format for message map.
- # TODO(jieluo): Add cpp extension support.
- if api_implementation.Type() == 'python':
- self.assertEqual(15,
- len(str(msg2.map_int32_foreign_message)))
+ self.assertIn(str(msg2.map_int32_foreign_message),
+ ('{-456: , 123: c: 1\n}', '{123: c: 1\n, -456: }'))
def testNestedMessageMapItemDelete(self):
msg = map_unittest_pb2.TestMap()
@@ -1721,6 +1857,15 @@ class Proto3Test(BaseTestCase):
self.assertEqual(10, msg2.map_int32_foreign_message[222].c)
self.assertFalse(msg2.map_int32_foreign_message[222].HasField('d'))
+ # Test when cpp extension caches a map.
+ m1 = map_unittest_pb2.TestMap()
+ m2 = map_unittest_pb2.TestMap()
+ self.assertEqual(m1.map_int32_foreign_message,
+ m1.map_int32_foreign_message)
+ m2.map_int32_foreign_message[123].c = 10
+ m1.MergeFrom(m2)
+ self.assertEqual(10, m2.map_int32_foreign_message[123].c)
+
def testMergeFromBadType(self):
msg = map_unittest_pb2.TestMap()
with self.assertRaisesRegexp(
@@ -1972,7 +2117,7 @@ class Proto3Test(BaseTestCase):
def testMapValidAfterFieldCleared(self):
# Map needs to work even if field is cleared.
# For the C++ implementation this tests the correctness of
- # ScalarMapContainer::Release()
+ # MapContainer::Release()
msg = map_unittest_pb2.TestMap()
int32_map = msg.map_int32_int32
@@ -1988,7 +2133,7 @@ class Proto3Test(BaseTestCase):
def testMessageMapValidAfterFieldCleared(self):
# Map needs to work even if field is cleared.
# For the C++ implementation this tests the correctness of
- # ScalarMapContainer::Release()
+ # MapContainer::Release()
msg = map_unittest_pb2.TestMap()
int32_foreign_message = msg.map_int32_foreign_message
@@ -1998,6 +2143,24 @@ class Proto3Test(BaseTestCase):
self.assertEqual(b'', msg.SerializeToString())
self.assertTrue(2 in int32_foreign_message.keys())
+ def testMessageMapItemValidAfterTopMessageCleared(self):
+ # Message map item needs to work even if it is cleared.
+ # For the C++ implementation this tests the correctness of
+ # MapContainer::Release()
+ msg = map_unittest_pb2.TestMap()
+ msg.map_int32_all_types[2].optional_string = 'bar'
+
+ if api_implementation.Type() == 'cpp':
+ # Need to keep the map reference because of b/27942626.
+ # TODO(jieluo): Remove it.
+ unused_map = msg.map_int32_all_types # pylint: disable=unused-variable
+ msg_value = msg.map_int32_all_types[2]
+ msg.Clear()
+
+ # Reset to trigger sync between repeated field and map in c++.
+ msg.map_int32_all_types[3].optional_string = 'foo'
+ self.assertEqual(msg_value.optional_string, 'bar')
+
def testMapIterInvalidatedByClearField(self):
# Map iterator is invalidated when field is cleared.
# But this case does need to not crash the interpreter.
@@ -2058,6 +2221,82 @@ class Proto3Test(BaseTestCase):
msg.map_string_foreign_message['foo'].c = 5
self.assertEqual(0, len(msg.FindInitializationErrors()))
+ @unittest.skipIf(sys.maxunicode == UCS2_MAXUNICODE, 'Skip for ucs2')
+ def testStrictUtf8Check(self):
+ # Test u'\ud801' is rejected at parser in both python2 and python3.
+ serialized = (b'r\x03\xed\xa0\x81')
+ msg = unittest_proto3_arena_pb2.TestAllTypes()
+ with self.assertRaises(Exception) as context:
+ msg.MergeFromString(serialized)
+ if api_implementation.Type() == 'python':
+ self.assertIn('optional_string', str(context.exception))
+ else:
+ self.assertIn('Error parsing message', str(context.exception))
+
+ # Test optional_string=u'😍' is accepted.
+ serialized = unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'😍').SerializeToString()
+ msg2 = unittest_proto3_arena_pb2.TestAllTypes()
+ msg2.MergeFromString(serialized)
+ self.assertEqual(msg2.optional_string, u'😍')
+
+ msg = unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud001')
+ self.assertEqual(msg.optional_string, u'\ud001')
+
+ @unittest.skipIf(six.PY2, 'Surrogates are acceptable in python2')
+ def testSurrogatesInPython3(self):
+ # A surrogate such as U+D83D is an invalid unicode character; it is
+ # supported by Python2 only because in some builds, unicode strings
+ # use 2-byte code units. Since Python 3.3, we don't have this problem.
+ #
+ # Surrogates are utf16 code units; in a unicode string they are invalid
+ # characters even when they appear in pairs like u'\ud801\udc01'. Protobuf
+ # Python3 rejects such cases at setters and parsers. Python2 accepts them
+ # to stay consistent with the language itself. Unpaired surrogates
+ # like u'\ud801' are rejected at parsers when strict utf8 check is enabled
+ # in proto3 to keep the same behavior as the c extension.
+
+ # Surrogates are rejected at setters in Python3.
+ with self.assertRaises(ValueError):
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801\udc01')
+ with self.assertRaises(ValueError):
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=b'\xed\xa0\x81')
+ with self.assertRaises(ValueError):
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801')
+ with self.assertRaises(ValueError):
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801\ud801')
+
+ @unittest.skipIf(six.PY3 or sys.maxunicode == UCS2_MAXUNICODE,
+ 'Surrogates are rejected at setters in Python3')
+ def testSurrogatesInPython2(self):
+ # Test optional_string=u'\ud801\udc01'.
+ # surrogate pair is acceptable in python2.
+ msg = unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801\udc01')
+ # TODO(jieluo): Make pure python behave the same as the c extension.
+ # Some python2 builds consider u'\ud801\udc01' and u'\U00010401' to be
+ # equal; some do not.
+ if api_implementation.Type() == 'python':
+ self.assertEqual(msg.optional_string, u'\ud801\udc01')
+ else:
+ self.assertEqual(msg.optional_string, u'\U00010401')
+ serialized = msg.SerializeToString()
+ msg2 = unittest_proto3_arena_pb2.TestAllTypes()
+ msg2.MergeFromString(serialized)
+ self.assertEqual(msg2.optional_string, u'\U00010401')
+
+ # Python2 does not reject surrogates at setters.
+ msg = unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=b'\xed\xa0\x81')
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801')
+ unittest_proto3_arena_pb2.TestAllTypes(
+ optional_string=u'\ud801\ud801')
class ValidTypeNamesTest(BaseTestCase):
diff --git a/python/google/protobuf/internal/no_package.proto b/python/google/protobuf/internal/no_package.proto
index 3546dcc3..49eda959 100644
--- a/python/google/protobuf/internal/no_package.proto
+++ b/python/google/protobuf/internal/no_package.proto
@@ -1,3 +1,33 @@
+// Protocol Buffers - Google's data interchange format
+// Copyright 2008 Google Inc. All rights reserved.
+// https://developers.google.com/protocol-buffers/
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+// * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+// * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+// * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
syntax = "proto2";
enum NoPackageEnum {
diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py
index 975e3b4d..4e0f545c 100755
--- a/python/google/protobuf/internal/python_message.py
+++ b/python/google/protobuf/internal/python_message.py
@@ -56,6 +56,7 @@ import sys
import weakref
import six
+from six.moves import range
# We use "as" to avoid name collisions with variables.
from google.protobuf.internal import api_implementation
@@ -124,6 +125,21 @@ class GeneratedProtocolMessageType(type):
Newly-allocated class.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+
+ # If a concrete class already exists for this descriptor, don't try to
+ # create another. Doing so will break any messages that already exist with
+ # the existing class.
+ #
+ # The C++ implementation appears to have its own internal `PyMessageFactory`
+ # to achieve similar results.
+ #
+ # This most commonly happens in `text_format.py` when using descriptors from
+ # a custom pool; it calls symbol_database.Global().GetPrototype() on a
+ # descriptor which already has an existing concrete class.
+ new_class = getattr(descriptor, '_concrete_class', None)
+ if new_class:
+ return new_class
+
if descriptor.full_name in well_known_types.WKTBASES:
bases += (well_known_types.WKTBASES[descriptor.full_name],)
_AddClassAttributesForNestedExtensions(descriptor, dictionary)
@@ -151,6 +167,16 @@ class GeneratedProtocolMessageType(type):
type.
"""
descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY]
+
+ # If this is an _existing_ class looked up via `_concrete_class` in the
+ # __new__ method above, then we don't need to re-initialize anything.
+ existing_class = getattr(descriptor, '_concrete_class', None)
+ if existing_class:
+ assert existing_class is cls, (
+ 'Duplicate `GeneratedProtocolMessageType` created for descriptor %r'
+ % (descriptor.full_name))
+ return
+
cls._decoders_by_tag = {}
if (descriptor.has_options and
descriptor.GetOptions().message_set_wire_format):
@@ -245,6 +271,7 @@ def _AddSlots(message_descriptor, dictionary):
'_cached_byte_size_dirty',
'_fields',
'_unknown_fields',
+ '_unknown_field_set',
'_is_present_in_parent',
'_listener',
'_listener_for_children',
@@ -271,6 +298,13 @@ def _IsMessageMapField(field):
return value_type.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE
+def _IsStrictUtf8Check(field):
+ if field.containing_type.syntax != 'proto3':
+ return False
+ enforce_utf8 = True
+ return enforce_utf8
+
+
def _AttachFieldHelpers(cls, field_descriptor):
is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED)
is_packable = (is_repeated and
@@ -322,10 +356,16 @@ def _AttachFieldHelpers(cls, field_descriptor):
field_decoder = decoder.MapDecoder(
field_descriptor, _GetInitializeDefaultForMap(field_descriptor),
is_message_map)
+ elif decode_type == _FieldDescriptor.TYPE_STRING:
+ is_strict_utf8_check = _IsStrictUtf8Check(field_descriptor)
+ field_decoder = decoder.StringDecoder(
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor,
+ is_strict_utf8_check)
else:
field_decoder = type_checkers.TYPE_TO_DECODER[decode_type](
- field_descriptor.number, is_repeated, is_packed,
- field_descriptor, field_descriptor._default_constructor)
+ field_descriptor.number, is_repeated, is_packed,
+ field_descriptor, field_descriptor._default_constructor)
cls._decoders_by_tag[tag_bytes] = (field_decoder, oneof_descriptor)
@@ -422,6 +462,9 @@ def _DefaultValueConstructorForField(field):
# _concrete_class may not yet be initialized.
message_type = field.message_type
def MakeSubMessageDefault(message):
+ assert getattr(message_type, '_concrete_class', None), (
+ 'Uninitialized concrete class found for field %r (message type %r)'
+ % (field.full_name, message_type.full_name))
result = message_type._concrete_class()
result._SetListener(
_OneofListener(message, field)
@@ -477,6 +520,9 @@ def _AddInitMethod(message_descriptor, cls):
# _unknown_fields is () when empty for efficiency, and will be turned into
# a list if fields are added.
self._unknown_fields = ()
+ # _unknown_field_set is None when empty for efficiency, and will be
+ # turned into UnknownFieldSet struct if fields are added.
+ self._unknown_field_set = None # pylint: disable=protected-access
self._is_present_in_parent = False
self._listener = message_listener_mod.NullMessageListener()
self._listener_for_children = _Listener(self)
@@ -584,6 +630,14 @@ def _AddPropertiesForField(field, cls):
_AddPropertiesForNonRepeatedScalarField(field, cls)
+class _FieldProperty(property):
+ __slots__ = ('DESCRIPTOR',)
+
+ def __init__(self, descriptor, getter, setter, doc):
+ property.__init__(self, getter, setter, doc=doc)
+ self.DESCRIPTOR = descriptor
+
+
def _AddPropertiesForRepeatedField(field, cls):
"""Adds a public property for a "repeated" protocol message field. Clients
can use this property to get the value of the field, which will be either a
@@ -625,7 +679,7 @@ def _AddPropertiesForRepeatedField(field, cls):
'"%s" in protocol message object.' % proto_field_name)
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
+ setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedScalarField(field, cls):
@@ -681,7 +735,7 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls):
# Add a property to encapsulate the getter/setter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
+ setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForNonRepeatedCompositeField(field, cls):
@@ -725,7 +779,7 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls):
# Add a property to encapsulate the getter.
doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name
- setattr(cls, property_name, property(getter, setter, doc=doc))
+ setattr(cls, property_name, _FieldProperty(field, getter, setter, doc=doc))
def _AddPropertiesForExtensions(descriptor, cls):
@@ -949,12 +1003,12 @@ def _AddEqualsMethod(message_descriptor, cls):
if not self.ListFields() == other.ListFields():
return False
- # Sort unknown fields because their order shouldn't affect equality test.
+ # TODO(jieluo): Fix UnknownFieldSet to consider MessageSet extensions,
+ # then use it for the comparison.
unknown_fields = list(self._unknown_fields)
unknown_fields.sort()
other_unknown_fields = list(other._unknown_fields)
other_unknown_fields.sort()
-
return unknown_fields == other_unknown_fields
cls.__eq__ = __eq__
@@ -1078,6 +1132,13 @@ def _AddSerializePartialToStringMethod(message_descriptor, cls):
def _AddMergeFromStringMethod(message_descriptor, cls):
"""Helper for _AddMessageMethods()."""
def MergeFromString(self, serialized):
+ if isinstance(serialized, memoryview) and six.PY2:
+ raise TypeError(
+ 'memoryview not supported in Python 2 with the pure Python proto '
+ 'implementation: this is to maintain compatibility with the C++ '
+ 'implementation')
+
+ serialized = memoryview(serialized)
length = len(serialized)
try:
if self._InternalParse(serialized, 0, length) != length:
@@ -1095,26 +1156,54 @@ def _AddMergeFromStringMethod(message_descriptor, cls):
local_ReadTag = decoder.ReadTag
local_SkipField = decoder.SkipField
decoders_by_tag = cls._decoders_by_tag
- is_proto3 = message_descriptor.syntax == "proto3"
def InternalParse(self, buffer, pos, end):
+ """Create a message from serialized bytes.
+
+ Args:
+ self: Message, instance of the proto message object.
+ buffer: memoryview of the serialized data.
+ pos: int, position to start in the serialized data.
+ end: int, end position of the serialized data.
+
+ Returns:
+ Message object.
+ """
+ # Guard against internal misuse, since this function is called internally
+ # quite extensively, and it's easy to accidentally pass bytes.
+ assert isinstance(buffer, memoryview)
self._Modified()
field_dict = self._fields
- unknown_field_list = self._unknown_fields
+ # pylint: disable=protected-access
+ unknown_field_set = self._unknown_field_set
while pos != end:
(tag_bytes, new_pos) = local_ReadTag(buffer, pos)
field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None))
if field_decoder is None:
- value_start_pos = new_pos
- new_pos = local_SkipField(buffer, new_pos, end, tag_bytes)
+ if not self._unknown_fields: # pylint: disable=protected-access
+ self._unknown_fields = [] # pylint: disable=protected-access
+ if unknown_field_set is None:
+ # pylint: disable=protected-access
+ self._unknown_field_set = containers.UnknownFieldSet()
+ # pylint: disable=protected-access
+ unknown_field_set = self._unknown_field_set
+ # pylint: disable=protected-access
+ (tag, _) = decoder._DecodeVarint(tag_bytes, 0)
+ field_number, wire_type = wire_format.UnpackTag(tag)
+ # TODO(jieluo): remove old_pos.
+ old_pos = new_pos
+ (data, new_pos) = decoder._DecodeUnknownField(
+ buffer, new_pos, wire_type) # pylint: disable=protected-access
if new_pos == -1:
return pos
- if (not is_proto3 or
- api_implementation.GetPythonProto3PreserveUnknownsDefault()):
- if not unknown_field_list:
- unknown_field_list = self._unknown_fields = []
- unknown_field_list.append(
- (tag_bytes, buffer[value_start_pos:new_pos]))
+ # pylint: disable=protected-access
+ unknown_field_set._add(field_number, wire_type, data)
+ # TODO(jieluo): remove _unknown_fields.
+ new_pos = local_SkipField(buffer, old_pos, end, tag_bytes)
+ if new_pos == -1:
+ return pos
+ self._unknown_fields.append(
+ (tag_bytes, buffer[old_pos:new_pos].tobytes()))
pos = new_pos
else:
pos = field_decoder(buffer, new_pos, end, self, field_dict)
@@ -1259,6 +1348,10 @@ def _AddMergeFromMethod(cls):
if not self._unknown_fields:
self._unknown_fields = []
self._unknown_fields.extend(msg._unknown_fields)
+ # pylint: disable=protected-access
+ if self._unknown_field_set is None:
+ self._unknown_field_set = containers.UnknownFieldSet()
+ self._unknown_field_set._extend(msg._unknown_field_set)
cls.MergeFrom = MergeFrom
@@ -1291,12 +1384,25 @@ def _Clear(self):
# Clear fields.
self._fields = {}
self._unknown_fields = ()
+ # pylint: disable=protected-access
+ if self._unknown_field_set is not None:
+ self._unknown_field_set._clear()
+ self._unknown_field_set = None
+
self._oneofs = {}
self._Modified()
+def _UnknownFields(self):
+ if self._unknown_field_set is None: # pylint: disable=protected-access
+ # pylint: disable=protected-access
+ self._unknown_field_set = containers.UnknownFieldSet()
+ return self._unknown_field_set # pylint: disable=protected-access
+
+
def _DiscardUnknownFields(self):
self._unknown_fields = []
+ self._unknown_field_set = None # pylint: disable=protected-access
for field, value in self.ListFields():
if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
if field.label == _FieldDescriptor.LABEL_REPEATED:
@@ -1335,6 +1441,7 @@ def _AddMessageMethods(message_descriptor, cls):
_AddReduceMethod(cls)
# Adds methods which do not depend on cls.
cls.Clear = _Clear
+ cls.UnknownFields = _UnknownFields
cls.DiscardUnknownFields = _DiscardUnknownFields
cls._SetListener = _SetListener
@@ -1471,6 +1578,10 @@ class _ExtensionDict(object):
if extension_handle.label == _FieldDescriptor.LABEL_REPEATED:
result = extension_handle._default_constructor(self._extended_message)
elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
+ assert getattr(extension_handle.message_type, '_concrete_class', None), (
+ 'Uninitialized concrete class found for field %r (message type %r)'
+ % (extension_handle.full_name,
+ extension_handle.message_type.full_name))
result = extension_handle.message_type._concrete_class()
try:
result._SetListener(self._extended_message._listener_for_children)
diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py
index 0306ff46..90d2fe3c 100755
--- a/python/google/protobuf/internal/reflection_test.py
+++ b/python/google/protobuf/internal/reflection_test.py
@@ -64,6 +64,10 @@ from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import decoder
+if six.PY3:
+ long = int # pylint: disable=redefined-builtin,invalid-name
+
+
BaseTestCase = testing_refleaks.BaseTestCase
@@ -647,10 +651,7 @@ class ReflectionTest(BaseTestCase):
TestGetAndDeserialize('optional_int32', 1, int)
TestGetAndDeserialize('optional_int32', 1 << 30, int)
TestGetAndDeserialize('optional_uint32', 1 << 30, int)
- try:
- integer_64 = long
- except NameError: # Python3
- integer_64 = int
+ integer_64 = long
if struct.calcsize('L') == 4:
# Python only has signed ints, so 32-bit python can't fit an uint32
# in an int.
@@ -1103,6 +1104,7 @@ class ReflectionTest(BaseTestCase):
self.assertEqual(23, myproto_instance.foo_field)
self.assertTrue(myproto_instance.HasField('foo_field'))
+ @testing_refleaks.SkipReferenceLeakChecker('MakeDescriptor is not repeatable')
def testDescriptorProtoSupport(self):
# Hand written descriptors/reflection are only supported by the pure-Python
# implementation of the API.
@@ -1141,7 +1143,8 @@ class ReflectionTest(BaseTestCase):
self.assertTrue('price' in desc.fields_by_name)
self.assertTrue('owners' in desc.fields_by_name)
- class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType, message.Message)):
+ class CarMessage(six.with_metaclass(reflection.GeneratedProtocolMessageType,
+ message.Message)):
DESCRIPTOR = desc
prius = CarMessage()
@@ -1576,6 +1579,8 @@ class ReflectionTest(BaseTestCase):
proto1.repeated_int32.append(3)
container = copy.deepcopy(proto1.repeated_int32)
self.assertEqual([2, 3], container)
+ container.remove(container[0])
+ self.assertEqual([3], container)
message1 = proto1.repeated_nested_message.add()
message1.bb = 1
@@ -1583,6 +1588,8 @@ class ReflectionTest(BaseTestCase):
self.assertEqual(proto1.repeated_nested_message, messages)
message1.bb = 2
self.assertNotEqual(proto1.repeated_nested_message, messages)
+ messages.remove(messages[0])
+ self.assertEqual(len(messages), 0)
# TODO(anuraag): Implement deepcopy for extension dict
@@ -2435,7 +2442,7 @@ class SerializationTest(BaseTestCase):
first_proto = unittest_pb2.TestAllTypes()
test_util.SetAllFields(first_proto)
- serialized = first_proto.SerializeToString()
+ serialized = memoryview(first_proto.SerializeToString())
for truncation_point in range(len(serialized) + 1):
try:
@@ -2857,6 +2864,38 @@ class SerializationTest(BaseTestCase):
self.assertEqual(unittest_pb2.REPEATED_NESTED_ENUM_EXTENSION_FIELD_NUMBER,
51)
+ def testFieldProperties(self):
+ cls = unittest_pb2.TestAllTypes
+ self.assertIs(cls.optional_int32.DESCRIPTOR,
+ cls.DESCRIPTOR.fields_by_name['optional_int32'])
+ self.assertEqual(cls.OPTIONAL_INT32_FIELD_NUMBER,
+ cls.optional_int32.DESCRIPTOR.number)
+ self.assertIs(cls.optional_nested_message.DESCRIPTOR,
+ cls.DESCRIPTOR.fields_by_name['optional_nested_message'])
+ self.assertEqual(cls.OPTIONAL_NESTED_MESSAGE_FIELD_NUMBER,
+ cls.optional_nested_message.DESCRIPTOR.number)
+ self.assertIs(cls.repeated_int32.DESCRIPTOR,
+ cls.DESCRIPTOR.fields_by_name['repeated_int32'])
+ self.assertEqual(cls.REPEATED_INT32_FIELD_NUMBER,
+ cls.repeated_int32.DESCRIPTOR.number)
+
+ def testFieldDataDescriptor(self):
+ msg = unittest_pb2.TestAllTypes()
+ msg.optional_int32 = 42
+ self.assertEqual(unittest_pb2.TestAllTypes.optional_int32.__get__(msg), 42)
+ unittest_pb2.TestAllTypes.optional_int32.__set__(msg, 25)
+ self.assertEqual(msg.optional_int32, 25)
+ with self.assertRaises(AttributeError):
+ del msg.optional_int32
+ try:
+ unittest_pb2.ForeignMessage.c.__get__(msg)
+ except TypeError:
+ pass # The cpp implementation cannot mix fields from other messages.
+ # This test exercises a specific check that avoids a crash.
+ else:
+ pass # The python implementation allows fields from other messages.
+ # This is useless, but works.
+
def testInitKwargs(self):
proto = unittest_pb2.TestAllTypes(
optional_int32=1,
@@ -2963,6 +3002,7 @@ class ClassAPITest(BaseTestCase):
@unittest.skipIf(
api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
'C++ implementation requires a call to MakeDescriptor()')
+ @testing_refleaks.SkipReferenceLeakChecker('MakeClass is not repeatable')
def testMakeClassWithNestedDescriptor(self):
leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '',
containing_type=None, fields=[],
@@ -2980,10 +3020,7 @@ class ClassAPITest(BaseTestCase):
containing_type=None, fields=[],
nested_types=[child_desc, sibling_desc],
enum_types=[], extensions=[])
- message_class = reflection.MakeClass(parent_desc)
- self.assertIn('child', message_class.__dict__)
- self.assertIn('sibling', message_class.__dict__)
- self.assertIn('leaf', message_class.child.__dict__)
+ reflection.MakeClass(parent_desc)
def _GetSerializedFileDescriptor(self, name):
"""Get a serialized representation of a test FileDescriptorProto.
diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py
index 237a2d50..ccf8ac16 100755
--- a/python/google/protobuf/internal/text_format_test.py
+++ b/python/google/protobuf/internal/text_format_test.py
@@ -33,20 +33,19 @@
"""Test for google.protobuf.text_format."""
-__author__ = 'kenton@google.com (Kenton Varda)'
-
-
+import io
import math
import re
-import six
import string
+import textwrap
+
+import six
+# pylint: disable=g-import-not-at-top
try:
- import unittest2 as unittest # PY26, pylint: disable=g-import-not-at-top
+ import unittest2 as unittest # PY26
except ImportError:
- import unittest # pylint: disable=g-import-not-at-top
-
-from google.protobuf.internal import _parameterized
+ import unittest
from google.protobuf import any_pb2
from google.protobuf import any_test_pb2
@@ -54,12 +53,13 @@ from google.protobuf import map_unittest_pb2
from google.protobuf import unittest_mset_pb2
from google.protobuf import unittest_pb2
from google.protobuf import unittest_proto3_arena_pb2
-from google.protobuf.internal import api_implementation
from google.protobuf.internal import any_test_pb2 as test_extend_any
from google.protobuf.internal import message_set_extensions_pb2
from google.protobuf.internal import test_util
from google.protobuf import descriptor_pool
from google.protobuf import text_format
+from google.protobuf.internal import _parameterized
+# pylint: enable=g-import-not-at-top
# Low-level nuts-n-bolts tests.
@@ -100,8 +100,8 @@ class TextFormatBase(unittest.TestCase):
return text
-@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2))
-class TextFormatTest(TextFormatBase):
+@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
+class TextFormatMessageToStringTests(TextFormatBase):
def testPrintExotic(self, message_module):
message = message_module.TestAllTypes()
@@ -154,6 +154,40 @@ class TextFormatTest(TextFormatBase):
'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 '
'repeated_string: "Google" repeated_string: "Zurich"')
+ def VerifyPrintShortFormatRepeatedFields(self, message_module, as_one_line):
+ message = message_module.TestAllTypes()
+ message.repeated_int32.append(1)
+ message.repeated_string.append('Google')
+ message.repeated_string.append('Hello,World')
+ message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_FOO)
+ message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR)
+ message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ)
+ message.optional_nested_message.bb = 3
+ for i in (21, 32):
+ msg = message.repeated_nested_message.add()
+ msg.bb = i
+ expected_ascii = (
+ 'optional_nested_message {\n bb: 3\n}\n'
+ 'repeated_int32: [1]\n'
+ 'repeated_string: "Google"\n'
+ 'repeated_string: "Hello,World"\n'
+ 'repeated_nested_message {\n bb: 21\n}\n'
+ 'repeated_nested_message {\n bb: 32\n}\n'
+ 'repeated_foreign_enum: [FOREIGN_FOO, FOREIGN_BAR, FOREIGN_BAZ]\n')
+ if as_one_line:
+ expected_ascii = expected_ascii.replace('\n ', '').replace('\n', '')
+ actual_ascii = text_format.MessageToString(
+ message, use_short_repeated_primitives=True,
+ as_one_line=as_one_line)
+ self.CompareToGoldenText(actual_ascii, expected_ascii)
+ parsed_message = message_module.TestAllTypes()
+ text_format.Parse(actual_ascii, parsed_message)
+ self.assertEqual(parsed_message, message)
+
+ def tesPrintShortFormatRepeatedFields(self, message_module, as_one_line):
+ self.VerifyPrintShortFormatRepeatedFields(message_module, False)
+ self.VerifyPrintShortFormatRepeatedFields(message_module, True)
+
def testPrintNestedNewLineInStringAsOneLine(self, message_module):
message = message_module.TestAllTypes()
message.optional_string = 'a\nnew\nline'
@@ -213,13 +247,18 @@ class TextFormatTest(TextFormatBase):
def testPrintRawUtf8String(self, message_module):
message = message_module.TestAllTypes()
- message.repeated_string.append(u'\u00fc\ua71f')
+ message.repeated_string.append(u'\u00fc\t\ua71f')
text = text_format.MessageToString(message, as_utf8=True)
- self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n')
+ golden_unicode = u'repeated_string: "\u00fc\\t\ua71f"\n'
+ golden_text = golden_unicode if six.PY3 else golden_unicode.encode('utf-8')
+ # MessageToString always returns a native str.
+ self.CompareToGoldenText(text, golden_text)
parsed_message = message_module.TestAllTypes()
text_format.Parse(text, parsed_message)
- self.assertEqual(message, parsed_message,
- '\n%s != %s' % (message, parsed_message))
+ self.assertEqual(
+ message, parsed_message, '\n%s != %s (%s != %s)' %
+ (message, parsed_message, message.repeated_string[0],
+ parsed_message.repeated_string[0]))
def testPrintFloatFormat(self, message_module):
# Check that float_format argument is passed to sub-message formatting.
@@ -259,6 +298,36 @@ class TextFormatTest(TextFormatBase):
message.c = 123
self.assertEqual('c: 123\n', str(message))
+ def testMessageToStringUnicode(self, message_module):
+ golden_unicode = u'Á short desçription and a 🍌.'
+ golden_bytes = golden_unicode.encode('utf-8')
+ message = message_module.TestAllTypes()
+ message.optional_string = golden_unicode
+ message.optional_bytes = golden_bytes
+ text = text_format.MessageToString(message, as_utf8=True)
+ golden_message = textwrap.dedent(
+ 'optional_string: "Á short desçription and a 🍌."\n'
+ 'optional_bytes: '
+ r'"\303\201 short des\303\247ription and a \360\237\215\214."'
+ '\n')
+ self.CompareToGoldenText(text, golden_message)
+
+ def testMessageToStringASCII(self, message_module):
+ golden_unicode = u'Á short desçription and a 🍌.'
+ golden_bytes = golden_unicode.encode('utf-8')
+ message = message_module.TestAllTypes()
+ message.optional_string = golden_unicode
+ message.optional_bytes = golden_bytes
+ text = text_format.MessageToString(message, as_utf8=False) # ASCII
+ golden_message = (
+ 'optional_string: '
+ r'"\303\201 short des\303\247ription and a \360\237\215\214."'
+ '\n'
+ 'optional_bytes: '
+ r'"\303\201 short des\303\247ription and a \360\237\215\214."'
+ '\n')
+ self.CompareToGoldenText(text, golden_message)
+
def testPrintField(self, message_module):
message = message_module.TestAllTypes()
field = message.DESCRIPTOR.fields_by_name['optional_float']
@@ -289,6 +358,45 @@ class TextFormatTest(TextFormatBase):
self.assertEqual('0.0', out.getvalue())
out.close()
+
+@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
+class TextFormatMessageToTextBytesTests(TextFormatBase):
+
+ def testMessageToBytes(self, message_module):
+ message = message_module.ForeignMessage()
+ message.c = 123
+ self.assertEqual(b'c: 123\n', text_format.MessageToBytes(message))
+
+ def testRawUtf8RoundTrip(self, message_module):
+ message = message_module.TestAllTypes()
+ message.repeated_string.append(u'\u00fc\t\ua71f')
+ utf8_text = text_format.MessageToBytes(message, as_utf8=True)
+ golden_bytes = b'repeated_string: "\xc3\xbc\\t\xea\x9c\x9f"\n'
+ self.CompareToGoldenText(utf8_text, golden_bytes)
+ parsed_message = message_module.TestAllTypes()
+ text_format.Parse(utf8_text, parsed_message)
+ self.assertEqual(
+ message, parsed_message, '\n%s != %s (%s != %s)' %
+ (message, parsed_message, message.repeated_string[0],
+ parsed_message.repeated_string[0]))
+
+ def testEscapedUtf8ASCIIRoundTrip(self, message_module):
+ message = message_module.TestAllTypes()
+ message.repeated_string.append(u'\u00fc\t\ua71f')
+ ascii_text = text_format.MessageToBytes(message) # as_utf8=False default
+ golden_bytes = b'repeated_string: "\\303\\274\\t\\352\\234\\237"\n'
+ self.CompareToGoldenText(ascii_text, golden_bytes)
+ parsed_message = message_module.TestAllTypes()
+ text_format.Parse(ascii_text, parsed_message)
+ self.assertEqual(
+ message, parsed_message, '\n%s != %s (%s != %s)' %
+ (message, parsed_message, message.repeated_string[0],
+ parsed_message.repeated_string[0]))
+
+
+@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
+class TextFormatParserTests(TextFormatBase):
+
def testParseAllFields(self, message_module):
message = message_module.TestAllTypes()
test_util.SetAllFields(message)
@@ -318,14 +426,14 @@ class TextFormatTest(TextFormatBase):
if message_module is unittest_pb2:
test_util.ExpectAllFieldsSet(self, message)
- if six.PY2:
- msg2 = message_module.TestAllTypes()
- text = (u'optional_string: "café"')
- text_format.Merge(text, msg2)
- self.assertEqual(msg2.optional_string, u'café')
- msg2.Clear()
- text_format.Parse(text, msg2)
- self.assertEqual(msg2.optional_string, u'café')
+ msg2 = message_module.TestAllTypes()
+ text = (u'optional_string: "café"')
+ text_format.Merge(text, msg2)
+ self.assertEqual(msg2.optional_string, u'café')
+ msg2.Clear()
+ self.assertEqual(msg2.optional_string, u'')
+ text_format.Parse(text, msg2)
+ self.assertEqual(msg2.optional_string, u'café')
def testParseExotic(self, message_module):
message = message_module.TestAllTypes()
@@ -425,7 +533,8 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes()
text = 'optional_nested_enum: BARR'
six.assertRaisesRegex(self, text_format.ParseError,
- (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ (r'1:23 : \'optional_nested_enum: BARR\': '
+ r'Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value named BARR.'), text_format.Parse,
text, message)
@@ -433,7 +542,8 @@ class TextFormatTest(TextFormatBase):
message = message_module.TestAllTypes()
text = 'optional_int32: bork'
six.assertRaisesRegex(self, text_format.ParseError,
- ('1:17 : Couldn\'t parse integer: bork'),
+ ('1:17 : \'optional_int32: bork\': '
+ 'Couldn\'t parse integer: bork'),
text_format.Parse, text, message)
def testParseStringFieldUnescape(self, message_module):
@@ -457,6 +567,96 @@ class TextFormatTest(TextFormatBase):
message.repeated_string[4])
self.assertEqual(SLASH + 'x20', message.repeated_string[5])
+ def testParseOneof(self, message_module):
+ m = message_module.TestAllTypes()
+ m.oneof_uint32 = 11
+ m2 = message_module.TestAllTypes()
+ text_format.Parse(text_format.MessageToString(m), m2)
+ self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
+
+ def testParseMultipleOneof(self, message_module):
+ m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
+ m2 = message_module.TestAllTypes()
+ with six.assertRaisesRegex(self, text_format.ParseError,
+ ' is specified along with field '):
+ text_format.Parse(m_string, m2)
+
+ # This example contains non-ASCII unicode code points as literals, which
+ # should come through as UTF-8 for bytes fields and as the unicode text
+ # itself for string fields. It also demonstrates escaped binary data.
+ # Python 3 has no ur"" string prefix, so the backslashes are doubled
+ # to make them come through literally.
+ _UNICODE_SAMPLE = u"""
+ optional_bytes: 'Á short desçription'
+ optional_string: 'Á short desçription'
+ repeated_bytes: '\\303\\201 short des\\303\\247ription'
+ repeated_bytes: '\\x12\\x34\\x56\\x78\\x90\\xab\\xcd\\xef'
+ repeated_string: '\\xd0\\x9f\\xd1\\x80\\xd0\\xb8\\xd0\\xb2\\xd0\\xb5\\xd1\\x82'
+ """
+ _BYTES_SAMPLE = _UNICODE_SAMPLE.encode('utf-8')
+ _GOLDEN_UNICODE = u'Á short desçription'
+ _GOLDEN_BYTES = _GOLDEN_UNICODE.encode('utf-8')
+ _GOLDEN_BYTES_1 = b'\x12\x34\x56\x78\x90\xab\xcd\xef'
+ _GOLDEN_STR_0 = u'Привет'
+
+ def testParseUnicode(self, message_module):
+ m = message_module.TestAllTypes()
+ text_format.Parse(self._UNICODE_SAMPLE, m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+ # repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data.
+ self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1)
+ # repeated_string[0] contained \ escaped data representing the UTF-8
+ # representation of _GOLDEN_STR_0 - it needs to decode as such.
+ self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0)
+
+ def testParseBytes(self, message_module):
+ m = message_module.TestAllTypes()
+ text_format.Parse(self._BYTES_SAMPLE, m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+ # repeated_bytes[1] contained simple \ escaped non-UTF-8 raw binary data.
+ self.assertEqual(m.repeated_bytes[1], self._GOLDEN_BYTES_1)
+ # repeated_string[0] contained \ escaped data representing the UTF-8
+ # representation of _GOLDEN_STR_0 - it needs to decode as such.
+ self.assertEqual(m.repeated_string[0], self._GOLDEN_STR_0)
+
+ def testFromBytesFile(self, message_module):
+ m = message_module.TestAllTypes()
+ f = io.BytesIO(self._BYTES_SAMPLE)
+ text_format.ParseLines(f, m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+
+ def testFromUnicodeFile(self, message_module):
+ m = message_module.TestAllTypes()
+ f = io.StringIO(self._UNICODE_SAMPLE)
+ text_format.ParseLines(f, m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+
+ def testFromBytesLines(self, message_module):
+ m = message_module.TestAllTypes()
+ text_format.ParseLines(self._BYTES_SAMPLE.split(b'\n'), m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+
+ def testFromUnicodeLines(self, message_module):
+ m = message_module.TestAllTypes()
+ text_format.ParseLines(self._UNICODE_SAMPLE.split(u'\n'), m)
+ self.assertEqual(m.optional_bytes, self._GOLDEN_BYTES)
+ self.assertEqual(m.optional_string, self._GOLDEN_UNICODE)
+ self.assertEqual(m.repeated_bytes[0], self._GOLDEN_BYTES)
+
+
+@_parameterized.parameters(unittest_pb2, unittest_proto3_arena_pb2)
+class TextFormatMergeTests(TextFormatBase):
+
def testMergeDuplicateScalars(self, message_module):
message = message_module.TestAllTypes()
text = ('optional_int32: 42 ' 'optional_int32: 67')
@@ -472,26 +672,12 @@ class TextFormatTest(TextFormatBase):
self.assertTrue(r is message)
self.assertEqual(2, message.optional_nested_message.bb)
- def testParseOneof(self, message_module):
- m = message_module.TestAllTypes()
- m.oneof_uint32 = 11
- m2 = message_module.TestAllTypes()
- text_format.Parse(text_format.MessageToString(m), m2)
- self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field'))
-
def testMergeMultipleOneof(self, message_module):
m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
m2 = message_module.TestAllTypes()
text_format.Merge(m_string, m2)
self.assertEqual('oneof_string', m2.WhichOneof('oneof_field'))
- def testParseMultipleOneof(self, message_module):
- m_string = '\n'.join(['oneof_uint32: 11', 'oneof_string: "foo"'])
- m2 = message_module.TestAllTypes()
- with self.assertRaisesRegexp(text_format.ParseError,
- ' is specified along with field '):
- text_format.Parse(m_string, m2)
-
# These are tests that aren't fundamentally specific to proto2, but are at
# the moment because of differences between the proto2 and proto3 test schemas.
@@ -649,6 +835,29 @@ class OnlyWorksWithProto2RightNowTests(TextFormatBase):
' }\n'
'}\n')
+ # In cpp implementation, __str__ calls the cpp implementation of text format.
+ def testPrintMapUsingCppImplementation(self):
+ message = map_unittest_pb2.TestMap()
+ inner_msg = message.map_int32_foreign_message[111]
+ inner_msg.c = 1
+ self.assertEqual(
+ str(message),
+ 'map_int32_foreign_message {\n'
+ ' key: 111\n'
+ ' value {\n'
+ ' c: 1\n'
+ ' }\n'
+ '}\n')
+ inner_msg.c = 2
+ self.assertEqual(
+ str(message),
+ 'map_int32_foreign_message {\n'
+ ' key: 111\n'
+ ' value {\n'
+ ' c: 2\n'
+ ' }\n'
+ '}\n')
+
def testMapOrderEnforcement(self):
message = map_unittest_pb2.TestMap()
for letter in string.ascii_uppercase[13:26]:
@@ -938,7 +1147,7 @@ class Proto2Tests(TextFormatBase):
'}\n')
six.assertRaisesRegex(self,
text_format.ParseError,
- '5:1 : Expected ">".',
+ '5:1 : \'}\': Expected ">".',
text_format.Parse,
malformed,
message,
@@ -981,7 +1190,8 @@ class Proto2Tests(TextFormatBase):
with self.assertRaises(text_format.ParseError) as e:
text_format.Parse(text, message)
self.assertEqual(str(e.exception),
- '1:27 : Expected identifier or number, got "bb".')
+ '1:27 : \'optional_nested_message { "bb": 1 }\': '
+ 'Expected identifier or number, got "bb".')
def testParseBadExtension(self):
message = unittest_pb2.TestAllExtensions()
@@ -998,7 +1208,8 @@ class Proto2Tests(TextFormatBase):
message = unittest_pb2.TestAllTypes()
text = 'optional_nested_enum: 100'
six.assertRaisesRegex(self, text_format.ParseError,
- (r'1:23 : Enum type "\w+.TestAllTypes.NestedEnum" '
+ (r'1:23 : \'optional_nested_enum: 100\': '
+ r'Enum type "\w+.TestAllTypes.NestedEnum" '
r'has no value with number 100.'), text_format.Parse,
text, message)
@@ -1209,6 +1420,24 @@ class Proto3Tests(unittest.TestCase):
' < data: "string" > '
'>')
+ def testPrintAndParseMessageInvalidAny(self):
+ packed_message = unittest_pb2.OneString()
+ packed_message.data = 'string'
+ message = any_test_pb2.TestAny()
+ message.any_value.Pack(packed_message)
+ # Only include string after last '/' in type_url.
+ message.any_value.type_url = message.any_value.TypeName()
+ text = text_format.MessageToString(message)
+ self.assertEqual(
+ text, 'any_value {\n'
+ ' type_url: "protobuf_unittest.OneString"\n'
+ ' value: "\\n\\006string"\n'
+ '}\n')
+
+ parsed_message = any_test_pb2.TestAny()
+ text_format.Parse(text, parsed_message)
+ self.assertEqual(message, parsed_message)
+
def testUnknownEnums(self):
message = unittest_proto3_arena_pb2.TestAllTypes()
message2 = unittest_proto3_arena_pb2.TestAllTypes()
@@ -1448,6 +1677,26 @@ class TokenizerTest(unittest.TestCase):
self.assertEqual(0, text_format._ConsumeUint64(tokenizer))
self.assertTrue(tokenizer.AtEnd())
+ def testConsumeOctalIntegers(self):
+ """Test support for C style octal integers."""
+ text = '00 -00 04 0755 -010 007 -0033 08 -09 01'
+ tokenizer = text_format.Tokenizer(text.splitlines())
+ self.assertEqual(0, tokenizer.ConsumeInteger())
+ self.assertEqual(0, tokenizer.ConsumeInteger())
+ self.assertEqual(4, tokenizer.ConsumeInteger())
+ self.assertEqual(0o755, tokenizer.ConsumeInteger())
+ self.assertEqual(-0o10, tokenizer.ConsumeInteger())
+ self.assertEqual(7, tokenizer.ConsumeInteger())
+ self.assertEqual(-0o033, tokenizer.ConsumeInteger())
+ with self.assertRaises(text_format.ParseError):
+ tokenizer.ConsumeInteger() # 08
+ tokenizer.NextToken()
+ with self.assertRaises(text_format.ParseError):
+ tokenizer.ConsumeInteger() # -09
+ tokenizer.NextToken()
+ self.assertEqual(1, tokenizer.ConsumeInteger())
+ self.assertTrue(tokenizer.AtEnd())
+
def testConsumeByteString(self):
text = '"string1\''
tokenizer = text_format.Tokenizer(text.splitlines())
@@ -1556,6 +1805,12 @@ class TokenizerTest(unittest.TestCase):
tokenizer.ConsumeCommentOrTrailingComment())
self.assertTrue(tokenizer.AtEnd())
+ def testHugeString(self):
+ # With pathologic backtracking, fails with Forge OOM.
+ text = '"' + 'a' * (10 * 1024 * 1024) + '"'
+ tokenizer = text_format.Tokenizer(text.splitlines(), skip_comments=False)
+ tokenizer.ConsumeString()
+
# Tests for pretty printer functionality.
@_parameterized.parameters((unittest_pb2), (unittest_proto3_arena_pb2))
@@ -1652,5 +1907,64 @@ class PrettyPrinterTest(TextFormatBase):
'repeated_nested_message { My lucky number is 42 } '
'repeated_nested_message { My lucky number is 99 }'))
+
+class WhitespaceTest(TextFormatBase):
+
+ def setUp(self):
+ self.out = text_format.TextWriter(False)
+ self.addCleanup(self.out.close)
+ self.message = unittest_pb2.NestedTestAllTypes()
+ self.message.child.payload.optional_string = 'value'
+ self.field = self.message.DESCRIPTOR.fields_by_name['child']
+ self.value = self.message.child
+
+ def testMessageToString(self):
+ self.CompareToGoldenText(
+ text_format.MessageToString(self.message),
+ textwrap.dedent("""\
+ child {
+ payload {
+ optional_string: "value"
+ }
+ }
+ """))
+
+ def testPrintMessage(self):
+ text_format.PrintMessage(self.message, self.out)
+ self.CompareToGoldenText(
+ self.out.getvalue(),
+ textwrap.dedent("""\
+ child {
+ payload {
+ optional_string: "value"
+ }
+ }
+ """))
+
+ def testPrintField(self):
+ text_format.PrintField(self.field, self.value, self.out)
+ self.CompareToGoldenText(
+ self.out.getvalue(),
+ textwrap.dedent("""\
+ child {
+ payload {
+ optional_string: "value"
+ }
+ }
+ """))
+
+ def testPrintFieldValue(self):
+ text_format.PrintFieldValue(
+ self.field, self.value, self.out)
+ self.CompareToGoldenText(
+ self.out.getvalue(),
+ textwrap.dedent("""\
+ {
+ payload {
+ optional_string: "value"
+ }
+ }"""))
+
+
if __name__ == '__main__':
unittest.main()
diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py
index 4a76cd4e..0807e7f7 100755
--- a/python/google/protobuf/internal/type_checkers.py
+++ b/python/google/protobuf/internal/type_checkers.py
@@ -185,6 +185,14 @@ class UnicodeValueChecker(object):
'encoding. Non-UTF-8 strings must be converted to '
'unicode objects before being added.' %
(proposed_value))
+ else:
+ try:
+ proposed_value.encode('utf8')
+ except UnicodeEncodeError:
+ raise ValueError('%.1024r isn\'t a valid unicode string and '
+ 'can\'t be encoded in UTF-8.'%
+ (proposed_value))
+
return proposed_value
def DefaultValue(self):
diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py
index 8b7de2e7..fceadf71 100755
--- a/python/google/protobuf/internal/unknown_fields_test.py
+++ b/python/google/protobuf/internal/unknown_fields_test.py
@@ -49,20 +49,12 @@ from google.protobuf.internal import missing_enum_values_pb2
from google.protobuf.internal import test_util
from google.protobuf.internal import testing_refleaks
from google.protobuf.internal import type_checkers
+from google.protobuf import descriptor
BaseTestCase = testing_refleaks.BaseTestCase
-# CheckUnknownField() cannot be used by the C++ implementation because
-# some protect members are called. It is not a behavior difference
-# for python and C++ implementation.
-def SkipCheckUnknownFieldIfCppImplementation(func):
- return unittest.skipIf(
- api_implementation.Type() == 'cpp' and api_implementation.Version() == 2,
- 'Addtional test for pure python involved protect members')(func)
-
-
class UnknownFieldsTest(BaseTestCase):
def setUp(self):
@@ -80,23 +72,11 @@ class UnknownFieldsTest(BaseTestCase):
# stdout.
self.assertTrue(data == self.all_fields_data)
- def expectSerializeProto3(self, preserve):
+ def testSerializeProto3(self):
+ # Verify proto3 unknown fields behavior.
message = unittest_proto3_arena_pb2.TestEmptyMessage()
message.ParseFromString(self.all_fields_data)
- if preserve:
- self.assertEqual(self.all_fields_data, message.SerializeToString())
- else:
- self.assertEqual(0, len(message.SerializeToString()))
-
- def testSerializeProto3(self):
- # Verify that proto3 unknown fields behavior.
- default_preserve = (api_implementation
- .GetPythonProto3PreserveUnknownsDefault())
- self.expectSerializeProto3(default_preserve)
- api_implementation.SetPythonProto3PreserveUnknownsDefault(
- not default_preserve)
- self.expectSerializeProto3(not default_preserve)
- api_implementation.SetPythonProto3PreserveUnknownsDefault(default_preserve)
+ self.assertEqual(self.all_fields_data, message.SerializeToString())
def testByteSize(self):
self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize())
@@ -169,13 +149,15 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
self.empty_message = unittest_pb2.TestEmptyMessage()
self.empty_message.ParseFromString(self.all_fields_data)
- # CheckUnknownField() is an additional Pure Python check which checks
+ # InternalCheckUnknownField() is an additional Pure Python check which checks
# a detail of unknown fields. It cannot be used by the C++
# implementation because some protect members are called.
# The test is added for historical reasons. It is not necessary as
# serialized string is checked.
-
- def CheckUnknownField(self, name, expected_value):
+ # TODO(jieluo): Remove message._unknown_fields.
+ def InternalCheckUnknownField(self, name, expected_value):
+ if api_implementation.Type() == 'cpp':
+ return
field_descriptor = self.descriptor.fields_by_name[name]
wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
@@ -183,36 +165,80 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
for tag_bytes, value in self.empty_message._unknown_fields:
if tag_bytes == field_tag:
decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0]
- decoder(value, 0, len(value), self.all_fields, result_dict)
+ decoder(memoryview(value), 0, len(value), self.all_fields, result_dict)
self.assertEqual(expected_value, result_dict[field_descriptor])
- @SkipCheckUnknownFieldIfCppImplementation
+ def CheckUnknownField(self, name, unknown_fields, expected_value):
+ field_descriptor = self.descriptor.fields_by_name[name]
+ expected_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[
+ field_descriptor.type]
+ for unknown_field in unknown_fields:
+ if unknown_field.field_number == field_descriptor.number:
+ self.assertEqual(expected_type, unknown_field.wire_type)
+ if expected_type == 3:
+ # Check group
+ self.assertEqual(expected_value[0],
+ unknown_field.data[0].field_number)
+ self.assertEqual(expected_value[1], unknown_field.data[0].wire_type)
+ self.assertEqual(expected_value[2], unknown_field.data[0].data)
+ continue
+ if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ self.assertIn(unknown_field.data, expected_value)
+ else:
+ self.assertEqual(expected_value, unknown_field.data)
+
def testCheckUnknownFieldValue(self):
+ unknown_fields = self.empty_message.UnknownFields()
# Test enum.
self.CheckUnknownField('optional_nested_enum',
+ unknown_fields,
self.all_fields.optional_nested_enum)
+ self.InternalCheckUnknownField('optional_nested_enum',
+ self.all_fields.optional_nested_enum)
+
# Test repeated enum.
self.CheckUnknownField('repeated_nested_enum',
+ unknown_fields,
self.all_fields.repeated_nested_enum)
+ self.InternalCheckUnknownField('repeated_nested_enum',
+ self.all_fields.repeated_nested_enum)
# Test varint.
self.CheckUnknownField('optional_int32',
+ unknown_fields,
self.all_fields.optional_int32)
+ self.InternalCheckUnknownField('optional_int32',
+ self.all_fields.optional_int32)
+
# Test fixed32.
self.CheckUnknownField('optional_fixed32',
+ unknown_fields,
self.all_fields.optional_fixed32)
+ self.InternalCheckUnknownField('optional_fixed32',
+ self.all_fields.optional_fixed32)
# Test fixed64.
self.CheckUnknownField('optional_fixed64',
+ unknown_fields,
self.all_fields.optional_fixed64)
+ self.InternalCheckUnknownField('optional_fixed64',
+ self.all_fields.optional_fixed64)
# Test lengthd elimited.
self.CheckUnknownField('optional_string',
- self.all_fields.optional_string)
+ unknown_fields,
+ self.all_fields.optional_string.encode('utf-8'))
+ self.InternalCheckUnknownField('optional_string',
+ self.all_fields.optional_string)
# Test group.
self.CheckUnknownField('optionalgroup',
- self.all_fields.optionalgroup)
+ unknown_fields,
+ (17, 0, 117))
+ self.InternalCheckUnknownField('optionalgroup',
+ self.all_fields.optionalgroup)
+
+ self.assertEqual(97, len(unknown_fields))
def testCopyFrom(self):
message = unittest_pb2.TestEmptyMessage()
@@ -230,9 +256,18 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
message.optional_int64 = 3
message.optional_uint32 = 4
destination = unittest_pb2.TestEmptyMessage()
+ unknown_fields = destination.UnknownFields()
+ self.assertEqual(0, len(unknown_fields))
destination.ParseFromString(message.SerializeToString())
-
+ # ParseFromString clears the message, so the old unknown_fields is invalid.
+ with self.assertRaises(ValueError) as context:
+ len(unknown_fields)
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+ unknown_fields = destination.UnknownFields()
+ self.assertEqual(2, len(unknown_fields))
destination.MergeFrom(source)
+ self.assertEqual(4, len(unknown_fields))
# Check that the fields where correctly merged, even stored in the unknown
# fields set.
message.ParseFromString(destination.SerializeToString())
@@ -241,9 +276,58 @@ class UnknownFieldsAccessorsTest(BaseTestCase):
self.assertEqual(message.optional_int64, 3)
def testClear(self):
+ unknown_fields = self.empty_message.UnknownFields()
self.empty_message.Clear()
# All cleared, even unknown fields.
self.assertEqual(self.empty_message.SerializeToString(), b'')
+ with self.assertRaises(ValueError) as context:
+ len(unknown_fields)
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+
+ def testSubUnknownFields(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optionalgroup.a = 123
+ destination = unittest_pb2.TestEmptyMessage()
+ destination.ParseFromString(message.SerializeToString())
+ sub_unknown_fields = destination.UnknownFields()[0].data
+ self.assertEqual(1, len(sub_unknown_fields))
+ self.assertEqual(sub_unknown_fields[0].data, 123)
+ destination.Clear()
+ with self.assertRaises(ValueError) as context:
+ len(sub_unknown_fields)
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+ with self.assertRaises(ValueError) as context:
+ # pylint: disable=pointless-statement
+ sub_unknown_fields[0]
+ self.assertIn('UnknownFields does not exist.',
+ str(context.exception))
+ message.Clear()
+ message.optional_uint32 = 456
+ nested_message = unittest_pb2.NestedTestAllTypes()
+ nested_message.payload.optional_nested_message.ParseFromString(
+ message.SerializeToString())
+ unknown_fields = (
+ nested_message.payload.optional_nested_message.UnknownFields())
+ self.assertEqual(unknown_fields[0].data, 456)
+ nested_message.ClearField('payload')
+ self.assertEqual(unknown_fields[0].data, 456)
+ unknown_fields = (
+ nested_message.payload.optional_nested_message.UnknownFields())
+ self.assertEqual(0, len(unknown_fields))
+
+ def testUnknownField(self):
+ message = unittest_pb2.TestAllTypes()
+ message.optional_int32 = 123
+ destination = unittest_pb2.TestEmptyMessage()
+ destination.ParseFromString(message.SerializeToString())
+ unknown_field = destination.UnknownFields()[0]
+ destination.Clear()
+ with self.assertRaises(ValueError) as context:
+ unknown_field.data # pylint: disable=pointless-statement
+ self.assertIn('The parent message might be cleared.',
+ str(context.exception))
def testUnknownExtensions(self):
message = unittest_pb2.TestEmptyMessageWithExtensions()
@@ -280,15 +364,13 @@ class UnknownEnumValuesTest(BaseTestCase):
def CheckUnknownField(self, name, expected_value):
field_descriptor = self.descriptor.fields_by_name[name]
- wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type]
- field_tag = encoder.TagBytes(field_descriptor.number, wire_type)
- result_dict = {}
- for tag_bytes, value in self.missing_message._unknown_fields:
- if tag_bytes == field_tag:
- decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[
- tag_bytes][0]
- decoder(value, 0, len(value), self.message, result_dict)
- self.assertEqual(expected_value, result_dict[field_descriptor])
+ unknown_fields = self.missing_message.UnknownFields()
+ for field in unknown_fields:
+ if field.field_number == field_descriptor.number:
+ if field_descriptor.label == descriptor.FieldDescriptor.LABEL_REPEATED:
+ self.assertIn(field.data, expected_value)
+ else:
+ self.assertEqual(expected_value, field.data)
def testUnknownParseMismatchEnumValue(self):
just_string = missing_enum_values_pb2.JustString()
@@ -317,7 +399,6 @@ class UnknownEnumValuesTest(BaseTestCase):
def testUnknownPackedEnumValue(self):
self.assertEqual([], self.missing_message.packed_nested_enum)
- @SkipCheckUnknownFieldIfCppImplementation
def testCheckUnknownFieldValueForEnum(self):
self.CheckUnknownField('optional_nested_enum',
self.message.optional_nested_enum)
diff --git a/python/google/protobuf/internal/well_known_types.py b/python/google/protobuf/internal/well_known_types.py
index 37a65cfa..95c5615f 100644
--- a/python/google/protobuf/internal/well_known_types.py
+++ b/python/google/protobuf/internal/well_known_types.py
@@ -40,6 +40,7 @@ This files defines well known classes which need extra maintenance including:
__author__ = 'jieluo@google.com (Jie Luo)'
+import calendar
import collections
from datetime import datetime
from datetime import timedelta
@@ -92,7 +93,7 @@ class Any(object):
def Is(self, descriptor):
"""Checks if this Any represents the given protobuf type."""
- return self.TypeName() == descriptor.full_name
+ return '/' in self.type_url and self.TypeName() == descriptor.full_name
class Timestamp(object):
@@ -233,9 +234,15 @@ class Timestamp(object):
def FromDatetime(self, dt):
"""Converts datetime to Timestamp."""
- td = dt - datetime(1970, 1, 1)
- self.seconds = td.seconds + td.days * _SECONDS_PER_DAY
- self.nanos = td.microseconds * _NANOS_PER_MICROSECOND
+ # Using this guide: http://wiki.python.org/moin/WorkingWithTime
+ # And this conversion guide: http://docs.python.org/library/time.html
+
+ # Turn the date parameter into a tuple (struct_time) that can then be
+ # manipulated into a long value of seconds. During the conversion from
+ # struct_time to long, the source date is in UTC, and so it follows that the
+ # correct transformation is calendar.timegm()
+ self.seconds = calendar.timegm(dt.utctimetuple())
+ self.nanos = dt.microsecond * _NANOS_PER_MICROSECOND
class Duration(object):
diff --git a/python/google/protobuf/internal/well_known_types_test.py b/python/google/protobuf/internal/well_known_types_test.py
index 965940b2..4dc2ae4f 100644
--- a/python/google/protobuf/internal/well_known_types_test.py
+++ b/python/google/protobuf/internal/well_known_types_test.py
@@ -35,7 +35,7 @@
__author__ = 'jieluo@google.com (Jie Luo)'
import collections
-from datetime import datetime
+import datetime
try:
import unittest2 as unittest #PY26
@@ -240,14 +240,34 @@ class TimeUtilTest(TimeUtilTestBase):
def testDatetimeConverison(self):
message = timestamp_pb2.Timestamp()
- dt = datetime(1970, 1, 1)
+ dt = datetime.datetime(1970, 1, 1)
message.FromDatetime(dt)
self.assertEqual(dt, message.ToDatetime())
message.FromMilliseconds(1999)
- self.assertEqual(datetime(1970, 1, 1, 0, 0, 1, 999000),
+ self.assertEqual(datetime.datetime(1970, 1, 1, 0, 0, 1, 999000),
message.ToDatetime())
+ def testDatetimeConversionWithTimezone(self):
+ class TZ(datetime.tzinfo):
+
+ def utcoffset(self, _):
+ return datetime.timedelta(hours=1)
+
+ def dst(self, _):
+ return datetime.timedelta(0)
+
+ def tzname(self, _):
+ return 'UTC+1'
+
+ message1 = timestamp_pb2.Timestamp()
+ dt = datetime.datetime(1970, 1, 1, 1, tzinfo=TZ())
+ message1.FromDatetime(dt)
+ message2 = timestamp_pb2.Timestamp()
+ dt = datetime.datetime(1970, 1, 1, 0)
+ message2.FromDatetime(dt)
+ self.assertEqual(message1, message2)
+
def testTimedeltaConversion(self):
message = duration_pb2.Duration()
message.FromNanoseconds(1999999999)
@@ -879,6 +899,17 @@ class AnyTest(unittest.TestCase):
raise AttributeError('%s should not have Pack method.' %
msg_descriptor.full_name)
+ def testUnpackWithNoSlashInTypeUrl(self):
+ msg = any_test_pb2.TestAny()
+ all_types = unittest_pb2.TestAllTypes()
+ all_descriptor = all_types.DESCRIPTOR
+ msg.value.Pack(all_types)
+ # Reset type_url to part of type_url after '/'
+ msg.value.type_url = msg.value.TypeName()
+ self.assertFalse(msg.value.Is(all_descriptor))
+ unpacked_message = unittest_pb2.TestAllTypes()
+ self.assertFalse(msg.value.Unpack(unpacked_message))
+
def testMessageName(self):
# Creates and sets message.
submessage = any_test_pb2.TestAny()