aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala9
-rw-r--r--python/pyspark/mllib/_common.py21
2 files changed, 26 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 2c3e828300..443fc5de5b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -67,7 +67,14 @@ class MatrixFactorizationModel(
}
}
- def predictJavaRDD(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
+ /**
+ * Predict the rating of many users for many products.
+ * This is a Java stub for python predictAll()
+ *
+ * @param usersProductsJRDD A JavaRDD with serialized tuples (user, product)
+ * @return JavaRDD of serialized Rating objects.
+ */
+ def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
val pythonAPI = new PythonMLLibAPI()
val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index c818fc4d97..769d88dfb9 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -18,6 +18,9 @@
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
from pyspark import SparkContext
+from pyspark.serializers import Serializer
+import struct
+
# Double vector format:
#
# [8-byte 1] [8-byte length] [length*8 bytes of data]
@@ -213,9 +216,21 @@ def _serialize_rating(r):
intpart[0], intpart[1], doublepart[0] = r
return ba
-def _deserialize_rating(ba):
- ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C')
- return ar.copy()
+class RatingDeserializer(Serializer):
+ def loads(self, stream):
+ length = struct.unpack("!i", stream.read(4))[0]
+ ba = stream.read(length)
+ res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4)
+ return int(res[0]), int(res[1]), res[2]
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self.loads(stream)
+ except struct.error:
+ return
+ except EOFError:
+ return
def _serialize_tuple(t):
ba = bytearray(8)