aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorHossein Falaki <falaki@gmail.com>2014-01-06 12:19:08 -0800
committerHossein Falaki <falaki@gmail.com>2014-01-06 12:19:08 -0800
commit04132ea9b20a95cd68482605d4022f692bb556e5 (patch)
tree189eb8160ce60fda452b9aea36b20074af742037 /python
parent11a93fb5a8fafa940db27b652e4c21f6713ed8d1 (diff)
downloadspark-04132ea9b20a95cd68482605d4022f692bb556e5.tar.gz
spark-04132ea9b20a95cd68482605d4022f692bb556e5.tar.bz2
spark-04132ea9b20a95cd68482605d4022f692bb556e5.zip
Added Rating deserializer
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/_common.py21
1 files changed, 18 insertions, 3 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index c818fc4d97..769d88dfb9 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -18,6 +18,9 @@
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
from pyspark import SparkContext
+from pyspark.serializers import Serializer
+import struct
+
# Double vector format:
#
# [8-byte 1] [8-byte length] [length*8 bytes of data]
@@ -213,9 +216,21 @@ def _serialize_rating(r):
intpart[0], intpart[1], doublepart[0] = r
return ba
-def _deserialize_rating(ba):
- ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
- return ar.copy()
+class RatingDeserializer(Serializer):
+ def loads(self, stream):
+ length = struct.unpack("!i", stream.read(4))[0]
+ ba = stream.read(length)
+ res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4)
+ return int(res[0]), int(res[1]), res[2]
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self.loads(stream)
+ except struct.error:
+ return
+ except EOFError:
+ return
def _serialize_tuple(t):
ba = bytearray(8)