diff options
Diffstat (limited to 'python/pyspark/mllib/_common.py')
-rw-r--r-- | python/pyspark/mllib/_common.py | 25 |
1 files changed, 25 insertions, 0 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index e74ba0fabc..769d88dfb9 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -18,6 +18,9 @@ from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape from pyspark import SparkContext +from pyspark.serializers import Serializer +import struct + # Double vector format: # # [8-byte 1] [8-byte length] [length*8 bytes of data] @@ -213,6 +216,28 @@ def _serialize_rating(r): intpart[0], intpart[1], doublepart[0] = r return ba +class RatingDeserializer(Serializer): + def loads(self, stream): + length = struct.unpack("!i", stream.read(4))[0] + ba = stream.read(length) + res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4) + return int(res[0]), int(res[1]), res[2] + + def load_stream(self, stream): + while True: + try: + yield self.loads(stream) + except struct.error: + return + except EOFError: + return + +def _serialize_tuple(t): + ba = bytearray(8) + intpart = ndarray(shape=[2], buffer=ba, dtype=int32) + intpart[0], intpart[1] = t + return ba + def _test(): import doctest globs = globals().copy() |