diff options
author | Hossein Falaki <falaki@gmail.com> | 2014-01-06 12:19:08 -0800 |
---|---|---|
committer | Hossein Falaki <falaki@gmail.com> | 2014-01-06 12:19:08 -0800 |
commit | 04132ea9b20a95cd68482605d4022f692bb556e5 (patch) | |
tree | 189eb8160ce60fda452b9aea36b20074af742037 /python/pyspark | |
parent | 11a93fb5a8fafa940db27b652e4c21f6713ed8d1 (diff) | |
download | spark-04132ea9b20a95cd68482605d4022f692bb556e5.tar.gz spark-04132ea9b20a95cd68482605d4022f692bb556e5.tar.bz2 spark-04132ea9b20a95cd68482605d4022f692bb556e5.zip |
Added Rating deserializer
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/mllib/_common.py | 21 |
1 file changed, 18 insertions, 3 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index c818fc4d97..769d88dfb9 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -18,6 +18,9 @@
 from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
 from pyspark import SparkContext
 
+from pyspark.serializers import Serializer
+import struct
+
 # Double vector format:
 #
 # [8-byte 1] [8-byte length] [length*8 bytes of data]
@@ -213,9 +216,21 @@ def _serialize_rating(r):
     intpart[0], intpart[1], doublepart[0] = r
     return ba
 
-def _deserialize_rating(ba):
-    ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
-    return ar.copy()
+class RatingDeserializer(Serializer):
+    def loads(self, stream):
+        length = struct.unpack("!i", stream.read(4))[0]
+        ba = stream.read(length)
+        res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4)
+        return int(res[0]), int(res[1]), res[2]
+
+    def load_stream(self, stream):
+        while True:
+            try:
+                yield self.loads(stream)
+            except struct.error:
+                return
+            except EOFError:
+                return
 
 def _serialize_tuple(t):
     ba = bytearray(8)