aboutsummaryrefslogtreecommitdiff
path: root/python/google/protobuf/internal/decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/internal/decoder.py')
-rwxr-xr-xpython/google/protobuf/internal/decoder.py196
1 files changed, 179 insertions, 17 deletions
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 52b64915..5a540184 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -81,12 +81,17 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it.
__author__ = 'kenton@google.com (Kenton Varda)'
import struct
-
+import sys
import six
+_UCS2_MAXUNICODE = 65535
if six.PY3:
long = int
+else:
+ import re # pylint: disable=g-import-not-at-top
+ _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]'))
+from google.protobuf.internal import containers
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message
@@ -167,7 +172,7 @@ _DecodeSignedVarint32 = _SignedVarintDecoder(32, int)
def ReadTag(buffer, pos):
- """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple.
+ """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple.
We return the raw bytes of the tag rather than decoding them. The raw
bytes can then be used to look up the proper decoder. This effectively allows
@@ -175,13 +180,21 @@ def ReadTag(buffer, pos):
for work that is done in C (searching for a byte string in a hash table).
In a low-level language it would be much cheaper to decode the varint and
use that, but not in Python.
- """
+ Args:
+ buffer: memoryview object of the encoded bytes
+ pos: int of the current position to start from
+
+ Returns:
+ Tuple[bytes, int] of the tag data and new position.
+ """
start = pos
while six.indexbytes(buffer, pos) & 0x80:
pos += 1
pos += 1
- return (six.binary_type(buffer[start:pos]), pos)
+
+ tag_bytes = buffer[start:pos].tobytes()
+ return tag_bytes, pos
# --------------------------------------------------------------------
@@ -295,10 +308,20 @@ def _FloatDecoder():
local_unpack = struct.unpack
def InnerDecode(buffer, pos):
+ """Decode serialized float to a float and new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes
+ pos: int, position in the memory view to start at.
+
+ Returns:
+ Tuple[float, int] of the deserialized float value and new position
+ in the serialized data.
+ """
# We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
# bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
new_pos = pos + 4
- float_bytes = buffer[pos:new_pos]
+ float_bytes = buffer[pos:new_pos].tobytes()
# If this value has all its exponent bits set, then it's non-finite.
# In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
@@ -329,10 +352,20 @@ def _DoubleDecoder():
local_unpack = struct.unpack
def InnerDecode(buffer, pos):
+ """Decode serialized double to a double and new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+
+ Returns:
+ Tuple[float, int] of the decoded double value and new position
+ in the serialized data.
+ """
# We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
# bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
new_pos = pos + 8
- double_bytes = buffer[pos:new_pos]
+ double_bytes = buffer[pos:new_pos].tobytes()
# If this value has all its exponent bits set and at least one significand
# bit set, it's not a number. In Python 2.4, struct.unpack will treat it
@@ -355,6 +388,18 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
if is_packed:
local_DecodeVarint = _DecodeVarint
def DecodePackedField(buffer, pos, end, message, field_dict):
+ """Decode serialized packed enum to its value and a new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
@@ -365,6 +410,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
while pos < endpoint:
value_start_pos = pos
(element, pos) = _DecodeSignedVarint32(buffer, pos)
+ # pylint: disable=protected-access
if element in enum_type.values_by_number:
value.append(element)
else:
@@ -372,8 +418,10 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
message._unknown_fields = []
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_VARINT)
+
message._unknown_fields.append(
- (tag_bytes, buffer[value_start_pos:pos]))
+ (tag_bytes, buffer[value_start_pos:pos].tobytes()))
+ # pylint: enable=protected-access
if pos > endpoint:
if element in enum_type.values_by_number:
del value[-1] # Discard corrupt value.
@@ -386,18 +434,32 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
tag_len = len(tag_bytes)
def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ """Decode serialized repeated enum to its value and a new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
value = field_dict.get(key)
if value is None:
value = field_dict.setdefault(key, new_default(message))
while 1:
(element, new_pos) = _DecodeSignedVarint32(buffer, pos)
+ # pylint: disable=protected-access
if element in enum_type.values_by_number:
value.append(element)
else:
if not message._unknown_fields:
message._unknown_fields = []
message._unknown_fields.append(
- (tag_bytes, buffer[pos:new_pos]))
+ (tag_bytes, buffer[pos:new_pos].tobytes()))
+ # pylint: enable=protected-access
# Predict that the next tag is another copy of the same repeated
# field.
pos = new_pos + tag_len
@@ -409,10 +471,23 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
return DecodeRepeatedField
else:
def DecodeField(buffer, pos, end, message, field_dict):
+ """Decode serialized repeated enum to its value and a new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
value_start_pos = pos
(enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
if pos > end:
raise _DecodeError('Truncated message.')
+ # pylint: disable=protected-access
if enum_value in enum_type.values_by_number:
field_dict[key] = enum_value
else:
@@ -421,7 +496,8 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
tag_bytes = encoder.TagBytes(field_number,
wire_format.WIRETYPE_VARINT)
message._unknown_fields.append(
- (tag_bytes, buffer[value_start_pos:pos]))
+ (tag_bytes, buffer[value_start_pos:pos].tobytes()))
+ # pylint: enable=protected-access
return pos
return DecodeField
@@ -458,20 +534,34 @@ BoolDecoder = _ModifiedDecoder(
wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
-def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
+def StringDecoder(field_number, is_repeated, is_packed, key, new_default,
+ is_strict_utf8=False):
"""Returns a decoder for a string field."""
local_DecodeVarint = _DecodeVarint
local_unicode = six.text_type
- def _ConvertToUnicode(byte_str):
+ def _ConvertToUnicode(memview):
+ """Convert byte to unicode."""
+ byte_str = memview.tobytes()
try:
- return local_unicode(byte_str, 'utf-8')
+ value = local_unicode(byte_str, 'utf-8')
except UnicodeDecodeError as e:
# add more information to the error message and re-raise it.
e.reason = '%s in field: %s' % (e, key.full_name)
raise
+ if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE:
+ # Only do the check for python2 ucs4 when is_strict_utf8 enabled
+ if _SURROGATE_PATTERN.search(value):
+ reason = ('String field %s contains invalid UTF-8 data when parsing'
+ 'a protocol buffer: surrogates not allowed. Use'
+ 'the bytes type if you intend to send raw bytes.') % (
+ key.full_name)
+ raise message.DecodeError(reason)
+
+ return value
+
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
@@ -523,7 +613,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- value.append(buffer[pos:new_pos])
+ value.append(buffer[pos:new_pos].tobytes())
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -536,7 +626,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- field_dict[key] = buffer[pos:new_pos]
+ field_dict[key] = buffer[pos:new_pos].tobytes()
return new_pos
return DecodeField
@@ -665,6 +755,18 @@ def MessageSetItemDecoder(descriptor):
local_SkipField = SkipField
def DecodeItem(buffer, pos, end, message, field_dict):
+ """Decode serialized message set to its value and new position.
+
+ Args:
+ buffer: memoryview of the serialized bytes.
+ pos: int, position in the memory view to start at.
+ end: int, end position of serialized data
+ message: Message object to store unknown fields in
+ field_dict: Map[Descriptor, Any] to store decoded values in.
+
+ Returns:
+ int, new position in serialized data.
+ """
message_set_item_start = pos
type_id = -1
message_start = -1
@@ -695,6 +797,7 @@ def MessageSetItemDecoder(descriptor):
raise _DecodeError('MessageSet item missing message.')
extension = message.Extensions._FindExtensionByNumber(type_id)
+ # pylint: disable=protected-access
if extension is not None:
value = field_dict.get(extension)
if value is None:
@@ -707,8 +810,9 @@ def MessageSetItemDecoder(descriptor):
else:
if not message._unknown_fields:
message._unknown_fields = []
- message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
- buffer[message_set_item_start:pos]))
+ message._unknown_fields.append(
+ (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes()))
+ # pylint: enable=protected-access
return pos
@@ -767,7 +871,7 @@ def _SkipVarint(buffer, pos, end):
# Previously ord(buffer[pos]) raised IndexError when pos is out of range.
# With this code, ord(b'') raises TypeError. Both are handled in
# python_message.py to generate a 'Truncated message' error.
- while ord(buffer[pos:pos+1]) & 0x80:
+ while ord(buffer[pos:pos+1].tobytes()) & 0x80:
pos += 1
pos += 1
if pos > end:
@@ -782,6 +886,13 @@ def _SkipFixed64(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+
+def _DecodeFixed64(buffer, pos):
+ """Decode a fixed64."""
+ new_pos = pos + 8
+ return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos)
+
+
def _SkipLengthDelimited(buffer, pos, end):
"""Skip a length-delimited value. Returns the new position."""
@@ -791,6 +902,7 @@ def _SkipLengthDelimited(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+
def _SkipGroup(buffer, pos, end):
"""Skip sub-group. Returns the new position."""
@@ -801,11 +913,53 @@ def _SkipGroup(buffer, pos, end):
return pos
pos = new_pos
+
+def _DecodeGroup(buffer, pos):
+ """Decode group. Returns the UnknownFieldSet and new position."""
+
+ unknown_field_set = containers.UnknownFieldSet()
+ while 1:
+ (tag_bytes, pos) = ReadTag(buffer, pos)
+ (tag, _) = _DecodeVarint(tag_bytes, 0)
+ field_number, wire_type = wire_format.UnpackTag(tag)
+ if wire_type == wire_format.WIRETYPE_END_GROUP:
+ break
+ (data, pos) = _DecodeUnknownField(buffer, pos, wire_type)
+ # pylint: disable=protected-access
+ unknown_field_set._add(field_number, wire_type, data)
+
+ return (unknown_field_set, pos)
+
+
+def _DecodeUnknownField(buffer, pos, wire_type):
+ """Decode a unknown field. Returns the UnknownField and new position."""
+
+ if wire_type == wire_format.WIRETYPE_VARINT:
+ (data, pos) = _DecodeVarint(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_FIXED64:
+ (data, pos) = _DecodeFixed64(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_FIXED32:
+ (data, pos) = _DecodeFixed32(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED:
+ (size, pos) = _DecodeVarint(buffer, pos)
+ data = buffer[pos:pos+size]
+ pos += size
+ elif wire_type == wire_format.WIRETYPE_START_GROUP:
+ (data, pos) = _DecodeGroup(buffer, pos)
+ elif wire_type == wire_format.WIRETYPE_END_GROUP:
+ return (0, -1)
+ else:
+ raise _DecodeError('Wrong wire type in tag.')
+
+ return (data, pos)
+
+
def _EndGroup(buffer, pos, end):
"""Skipping an END_GROUP tag returns -1 to tell the parent loop to break."""
return -1
+
def _SkipFixed32(buffer, pos, end):
"""Skip a fixed32 value. Returns the new position."""
@@ -814,6 +968,14 @@ def _SkipFixed32(buffer, pos, end):
raise _DecodeError('Truncated message.')
return pos
+
+def _DecodeFixed32(buffer, pos):
+ """Decode a fixed32."""
+
+ new_pos = pos + 4
+ return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos)
+
+
def _RaiseInvalidWireType(buffer, pos, end):
"""Skip function for unknown wire types. Raises an exception."""