diff options
Diffstat (limited to 'python/google/protobuf/internal/decoder.py')
-rwxr-xr-x | python/google/protobuf/internal/decoder.py | 196 |
1 files changed, 179 insertions, 17 deletions
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 52b64915..5a540184 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -81,12 +81,17 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. __author__ = 'kenton@google.com (Kenton Varda)' import struct - +import sys import six +_UCS2_MAXUNICODE = 65535 if six.PY3: long = int +else: + import re # pylint: disable=g-import-not-at-top + _SURROGATE_PATTERN = re.compile(six.u(r'[\ud800-\udfff]')) +from google.protobuf.internal import containers from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message @@ -167,7 +172,7 @@ _DecodeSignedVarint32 = _SignedVarintDecoder(32, int) def ReadTag(buffer, pos): - """Read a tag from the buffer, and return a (tag_bytes, new_pos) tuple. + """Read a tag from the memoryview, and return a (tag_bytes, new_pos) tuple. We return the raw bytes of the tag rather than decoding them. The raw bytes can then be used to look up the proper decoder. This effectively allows @@ -175,13 +180,21 @@ def ReadTag(buffer, pos): for work that is done in C (searching for a byte string in a hash table). In a low-level language it would be much cheaper to decode the varint and use that, but not in Python. - """ + Args: + buffer: memoryview object of the encoded bytes + pos: int of the current position to start from + + Returns: + Tuple[bytes, int] of the tag data and new position. + """ start = pos while six.indexbytes(buffer, pos) & 0x80: pos += 1 pos += 1 - return (six.binary_type(buffer[start:pos]), pos) + + tag_bytes = buffer[start:pos].tobytes() + return tag_bytes, pos # -------------------------------------------------------------------- @@ -295,10 +308,20 @@ def _FloatDecoder(): local_unpack = struct.unpack def InnerDecode(buffer, pos): + """Decode serialized float to a float and new position. + + Args: + buffer: memoryview of the serialized bytes + pos: int, position in the memory view to start at. + + Returns: + Tuple[float, int] of the deserialized float value and new position + in the serialized data. + """ # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. new_pos = pos + 4 - float_bytes = buffer[pos:new_pos] + float_bytes = buffer[pos:new_pos].tobytes() # If this value has all its exponent bits set, then it's non-finite. # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. @@ -329,10 +352,20 @@ def _DoubleDecoder(): local_unpack = struct.unpack def InnerDecode(buffer, pos): + """Decode serialized double to a double and new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + + Returns: + Tuple[float, int] of the decoded double value and new position + in the serialized data. + """ # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. new_pos = pos + 8 - double_bytes = buffer[pos:new_pos] + double_bytes = buffer[pos:new_pos].tobytes() # If this value has all its exponent bits set and at least one significand # bit set, it's not a number. In Python 2.4, struct.unpack will treat it @@ -355,6 +388,18 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): if is_packed: local_DecodeVarint = _DecodeVarint def DecodePackedField(buffer, pos, end, message, field_dict): + """Decode serialized packed enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) @@ -365,6 +410,7 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): while pos < endpoint: value_start_pos = pos (element, pos) = _DecodeSignedVarint32(buffer, pos) + # pylint: disable=protected-access if element in enum_type.values_by_number: value.append(element) else: @@ -372,8 +418,10 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): message._unknown_fields = [] tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( - (tag_bytes, buffer[value_start_pos:pos])) + (tag_bytes, buffer[value_start_pos:pos].tobytes())) + # pylint: enable=protected-access if pos > endpoint: if element in enum_type.values_by_number: del value[-1] # Discard corrupt value. @@ -386,18 +434,32 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) tag_len = len(tag_bytes) def DecodeRepeatedField(buffer, pos, end, message, field_dict): + """Decode serialized repeated enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ value = field_dict.get(key) if value is None: value = field_dict.setdefault(key, new_default(message)) while 1: (element, new_pos) = _DecodeSignedVarint32(buffer, pos) + # pylint: disable=protected-access if element in enum_type.values_by_number: value.append(element) else: if not message._unknown_fields: message._unknown_fields = [] message._unknown_fields.append( - (tag_bytes, buffer[pos:new_pos])) + (tag_bytes, buffer[pos:new_pos].tobytes())) + # pylint: enable=protected-access # Predict that the next tag is another copy of the same repeated # field. pos = new_pos + tag_len @@ -409,10 +471,23 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): return DecodeRepeatedField else: def DecodeField(buffer, pos, end, message, field_dict): + """Decode serialized repeated enum to its value and a new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ value_start_pos = pos (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) if pos > end: raise _DecodeError('Truncated message.') + # pylint: disable=protected-access if enum_value in enum_type.values_by_number: field_dict[key] = enum_value else: @@ -421,7 +496,8 @@ def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) message._unknown_fields.append( - (tag_bytes, buffer[value_start_pos:pos])) + (tag_bytes, buffer[value_start_pos:pos].tobytes())) + # pylint: enable=protected-access return pos return DecodeField @@ -458,20 +534,34 @@ BoolDecoder = _ModifiedDecoder( wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) -def StringDecoder(field_number, is_repeated, is_packed, key, new_default): +def StringDecoder(field_number, is_repeated, is_packed, key, new_default, + is_strict_utf8=False): """Returns a decoder for a string field.""" local_DecodeVarint = _DecodeVarint local_unicode = six.text_type - def _ConvertToUnicode(byte_str): + def _ConvertToUnicode(memview): + """Convert byte to unicode.""" + byte_str = memview.tobytes() try: - return local_unicode(byte_str, 'utf-8') + value = local_unicode(byte_str, 'utf-8') except UnicodeDecodeError as e: # add more information to the error message and re-raise it. e.reason = '%s in field: %s' % (e, key.full_name) raise + if is_strict_utf8 and six.PY2 and sys.maxunicode > _UCS2_MAXUNICODE: + # Only do the check for python2 ucs4 when is_strict_utf8 enabled + if _SURROGATE_PATTERN.search(value): + reason = ('String field %s contains invalid UTF-8 data when parsing' + 'a protocol buffer: surrogates not allowed. Use' + 'the bytes type if you intend to send raw bytes.') % ( + key.full_name) + raise message.DecodeError(reason) + + return value + assert not is_packed if is_repeated: tag_bytes = encoder.TagBytes(field_number, @@ -523,7 +613,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - value.append(buffer[pos:new_pos]) + value.append(buffer[pos:new_pos].tobytes()) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: @@ -536,7 +626,7 @@ def BytesDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - field_dict[key] = buffer[pos:new_pos] + field_dict[key] = buffer[pos:new_pos].tobytes() return new_pos return DecodeField @@ -665,6 +755,18 @@ def MessageSetItemDecoder(descriptor): local_SkipField = SkipField def DecodeItem(buffer, pos, end, message, field_dict): + """Decode serialized message set to its value and new position. + + Args: + buffer: memoryview of the serialized bytes. + pos: int, position in the memory view to start at. + end: int, end position of serialized data + message: Message object to store unknown fields in + field_dict: Map[Descriptor, Any] to store decoded values in. + + Returns: + int, new position in serialized data. + """ message_set_item_start = pos type_id = -1 message_start = -1 @@ -695,6 +797,7 @@ def MessageSetItemDecoder(descriptor): raise _DecodeError('MessageSet item missing message.') extension = message.Extensions._FindExtensionByNumber(type_id) + # pylint: disable=protected-access if extension is not None: value = field_dict.get(extension) if value is None: @@ -707,8 +810,9 @@ def MessageSetItemDecoder(descriptor): else: if not message._unknown_fields: message._unknown_fields = [] - message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, - buffer[message_set_item_start:pos])) + message._unknown_fields.append( + (MESSAGE_SET_ITEM_TAG, buffer[message_set_item_start:pos].tobytes())) + # pylint: enable=protected-access return pos @@ -767,7 +871,7 @@ def _SkipVarint(buffer, pos, end): # Previously ord(buffer[pos]) raised IndexError when pos is out of range. # With this code, ord(b'') raises TypeError. Both are handled in # python_message.py to generate a 'Truncated message' error. - while ord(buffer[pos:pos+1]) & 0x80: + while ord(buffer[pos:pos+1].tobytes()) & 0x80: pos += 1 pos += 1 if pos > end: @@ -782,6 +886,13 @@ def _SkipFixed64(buffer, pos, end): raise _DecodeError('Truncated message.') return pos + +def _DecodeFixed64(buffer, pos): + """Decode a fixed64.""" + new_pos = pos + 8 + return (struct.unpack('<Q', buffer[pos:new_pos])[0], new_pos) + + def _SkipLengthDelimited(buffer, pos, end): """Skip a length-delimited value. Returns the new position.""" @@ -791,6 +902,7 @@ def _SkipLengthDelimited(buffer, pos, end): raise _DecodeError('Truncated message.') return pos + def _SkipGroup(buffer, pos, end): """Skip sub-group. Returns the new position.""" @@ -801,11 +913,53 @@ def _SkipGroup(buffer, pos, end): return pos pos = new_pos + +def _DecodeGroup(buffer, pos): + """Decode group. Returns the UnknownFieldSet and new position.""" + + unknown_field_set = containers.UnknownFieldSet() + while 1: + (tag_bytes, pos) = ReadTag(buffer, pos) + (tag, _) = _DecodeVarint(tag_bytes, 0) + field_number, wire_type = wire_format.UnpackTag(tag) + if wire_type == wire_format.WIRETYPE_END_GROUP: + break + (data, pos) = _DecodeUnknownField(buffer, pos, wire_type) + # pylint: disable=protected-access + unknown_field_set._add(field_number, wire_type, data) + + return (unknown_field_set, pos) + + +def _DecodeUnknownField(buffer, pos, wire_type): + """Decode a unknown field. Returns the UnknownField and new position.""" + + if wire_type == wire_format.WIRETYPE_VARINT: + (data, pos) = _DecodeVarint(buffer, pos) + elif wire_type == wire_format.WIRETYPE_FIXED64: + (data, pos) = _DecodeFixed64(buffer, pos) + elif wire_type == wire_format.WIRETYPE_FIXED32: + (data, pos) = _DecodeFixed32(buffer, pos) + elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: + (size, pos) = _DecodeVarint(buffer, pos) + data = buffer[pos:pos+size] + pos += size + elif wire_type == wire_format.WIRETYPE_START_GROUP: + (data, pos) = _DecodeGroup(buffer, pos) + elif wire_type == wire_format.WIRETYPE_END_GROUP: + return (0, -1) + else: + raise _DecodeError('Wrong wire type in tag.') + + return (data, pos) + + def _EndGroup(buffer, pos, end): """Skipping an END_GROUP tag returns -1 to tell the parent loop to break.""" return -1 + def _SkipFixed32(buffer, pos, end): """Skip a fixed32 value. Returns the new position.""" @@ -814,6 +968,14 @@ def _SkipFixed32(buffer, pos, end): raise _DecodeError('Truncated message.') return pos + +def _DecodeFixed32(buffer, pos): + """Decode a fixed32.""" + + new_pos = pos + 4 + return (struct.unpack('<I', buffer[pos:new_pos])[0], new_pos) + + def _RaiseInvalidWireType(buffer, pos, end): """Skip function for unknown wire types. Raises an exception.""" |