From d37d46dfbcedadeb439ad0367f8afcf8867dca43 Mon Sep 17 00:00:00 2001 From: "kenton@google.com" Date: Sat, 25 Apr 2009 02:53:47 +0000 Subject: Integrate recent changes from Google-internal code tree. See CHANGES.txt for details. --- python/google/protobuf/internal/input_stream.py | 115 ++++++++++++++++++++- .../google/protobuf/internal/input_stream_test.py | 21 +++- python/google/protobuf/internal/reflection_test.py | 9 ++ python/google/protobuf/reflection.py | 4 + 4 files changed, 147 insertions(+), 2 deletions(-) (limited to 'python/google') diff --git a/python/google/protobuf/internal/input_stream.py b/python/google/protobuf/internal/input_stream.py index 2cff93db..7bda17e3 100755 --- a/python/google/protobuf/internal/input_stream.py +++ b/python/google/protobuf/internal/input_stream.py @@ -36,6 +36,7 @@ the InputStream primitives provided here. __author__ = 'robinson@google.com (Will Robinson)' +import array import struct from google.protobuf import message from google.protobuf.internal import wire_format @@ -46,7 +47,7 @@ from google.protobuf.internal import wire_format # proto2 implementation. -class InputStream(object): +class InputStreamBuffer(object): """Contains all logic for reading bits, and dealing with stream position. @@ -223,3 +224,115 @@ class InputStream(object): shift += 7 if not (b & 0x80): return result + + +class InputStreamArray(object): + + """Contains all logic for reading bits, and dealing with stream position. + + If an InputStream method ever raises an exception, the stream is left + in an indeterminate state and is not safe for further use. + + This alternative to InputStreamBuffer is used in environments where buffer() + is unavailble, such as Google App Engine. + """ + + def __init__(self, s): + self._buffer = array.array('B', s) + self._pos = 0 + + def EndOfStream(self): + return self._pos >= len(self._buffer) + + def Position(self): + return self._pos + + def GetSubBuffer(self, size=None): + if size is None: + return self._buffer[self._pos : ].tostring() + else: + if size < 0: + raise message.DecodeError('Negative size %d' % size) + return self._buffer[self._pos : self._pos + size].tostring() + + def SkipBytes(self, num_bytes): + if num_bytes < 0: + raise message.DecodeError('Negative num_bytes %d' % num_bytes) + self._pos += num_bytes + self._pos = min(self._pos, len(self._buffer)) + + def ReadBytes(self, size): + if size < 0: + raise message.DecodeError('Negative size %d' % size) + s = self._buffer[self._pos : self._pos + size].tostring() + self._pos += len(s) # Only advance by the number of bytes actually read. + return s + + def ReadLittleEndian32(self): + try: + i = struct.unpack(wire_format.FORMAT_UINT32_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 4]) + self._pos += 4 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadLittleEndian64(self): + try: + i = struct.unpack(wire_format.FORMAT_UINT64_LITTLE_ENDIAN, + self._buffer[self._pos : self._pos + 8]) + self._pos += 8 + return i[0] # unpack() result is a 1-element tuple. + except struct.error, e: + raise message.DecodeError(e) + + def ReadVarint32(self): + i = self.ReadVarint64() + if not wire_format.INT32_MIN <= i <= wire_format.INT32_MAX: + raise message.DecodeError('Value out of range for int32: %d' % i) + return int(i) + + def ReadVarUInt32(self): + i = self.ReadVarUInt64() + if i > wire_format.UINT32_MAX: + raise message.DecodeError('Value out of range for uint32: %d' % i) + return i + + def ReadVarint64(self): + i = self.ReadVarUInt64() + if i > wire_format.INT64_MAX: + i -= (1 << 64) + return i + + def ReadVarUInt64(self): + i = self._ReadVarintHelper() + if not 0 <= i <= wire_format.UINT64_MAX: + raise message.DecodeError('Value out of range for uint64: %d' % i) + return i + + def _ReadVarintHelper(self): + result = 0 + shift = 0 + while 1: + if shift >= 64: + raise message.DecodeError('Too many bytes when decoding varint.') + try: + b = self._buffer[self._pos] + except IndexError: + raise message.DecodeError('Truncated varint.') + self._pos += 1 + result |= ((b & 0x7f) << shift) + shift += 7 + if not (b & 0x80): + return result + + +try: + buffer('') + InputStream = InputStreamBuffer +except NotImplementedError: + # Google App Engine: dev_appserver.py + InputStream = InputStreamArray +except RuntimeError: + # Google App Engine: production + InputStream = InputStreamArray diff --git a/python/google/protobuf/internal/input_stream_test.py b/python/google/protobuf/internal/input_stream_test.py index 8cc1d126..ecec7f7d 100755 --- a/python/google/protobuf/internal/input_stream_test.py +++ b/python/google/protobuf/internal/input_stream_test.py @@ -40,7 +40,14 @@ from google.protobuf.internal import wire_format from google.protobuf.internal import input_stream -class InputStreamTest(unittest.TestCase): +class InputStreamBufferTest(unittest.TestCase): + + def setUp(self): + self.__original_input_stream = input_stream.InputStream + input_stream.InputStream = input_stream.InputStreamBuffer + + def tearDown(self): + input_stream.InputStream = self.__original_input_stream def testEndOfStream(self): stream = input_stream.InputStream('abcd') @@ -291,5 +298,17 @@ class InputStreamTest(unittest.TestCase): stream = input_stream.InputStream(s) self.assertRaises(message.DecodeError, stream.ReadVarUInt64) + +class InputStreamArrayTest(InputStreamBufferTest): + + def setUp(self): + # Test InputStreamArray against the same tests in InputStreamBuffer + self.__original_input_stream = input_stream.InputStream + input_stream.InputStream = input_stream.InputStreamArray + + def tearDown(self): + input_stream.InputStream = self.__original_input_stream + + if __name__ == '__main__': unittest.main() diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 1d88c1cc..e2da769a 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -1102,6 +1102,15 @@ class FullProtosEqualityTest(unittest.TestCase): test_util.SetAllFields(self.first_proto) test_util.SetAllFields(self.second_proto) + def testNoneNotEqual(self): + self.assertNotEqual(self.first_proto, None) + self.assertNotEqual(None, self.second_proto) + + def testNotEqualToOtherMessage(self): + third_proto = unittest_pb2.TestRequired() + self.assertNotEqual(self.first_proto, third_proto) + self.assertNotEqual(third_proto, self.second_proto) + def testAllFieldsFilledEquality(self): self.assertEqual(self.first_proto, self.second_proto) diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index f345067a..5ab7a1b1 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -599,6 +599,10 @@ def _AddHasExtensionMethod(cls): def _AddEqualsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def __eq__(self, other): + if (not isinstance(other, message_mod.Message) or + other.DESCRIPTOR != self.DESCRIPTOR): + return False + if self is other: return True -- cgit v1.2.3