aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r--python/pyspark/mllib/_common.py25
-rw-r--r--python/pyspark/mllib/recommendation.py12
2 files changed, 36 insertions, 1 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index e74ba0fabc..769d88dfb9 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -18,6 +18,9 @@
from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
from pyspark import SparkContext
+from pyspark.serializers import Serializer
+import struct
+
# Double vector format:
#
# [8-byte 1] [8-byte length] [length*8 bytes of data]
@@ -213,6 +216,28 @@ def _serialize_rating(r):
intpart[0], intpart[1], doublepart[0] = r
return ba
+class RatingDeserializer(Serializer):
+ def loads(self, stream):
+ length = struct.unpack("!i", stream.read(4))[0]
+ ba = stream.read(length)
+ res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4)
+ return int(res[0]), int(res[1]), res[2]
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self.loads(stream)
+ except struct.error:
+ return
+ except EOFError:
+ return
+
+def _serialize_tuple(t):
+ ba = bytearray(8)
+ intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
+ intpart[0], intpart[1] = t
+ return ba
+
def _test():
import doctest
globs = globals().copy()
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 14d06cba21..0eeb5bb66b 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -20,7 +20,9 @@ from pyspark.mllib._common import \
_get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
_serialize_double_matrix, _deserialize_double_matrix, \
_serialize_double_vector, _deserialize_double_vector, \
- _get_initial_weights, _serialize_rating, _regression_train_wrapper
+ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \
+ _serialize_tuple, RatingDeserializer
+from pyspark.rdd import RDD
class MatrixFactorizationModel(object):
"""A matrix factorisation model trained by regularized alternating
@@ -33,6 +35,9 @@ class MatrixFactorizationModel(object):
>>> model = ALS.trainImplicit(sc, ratings, 1)
>>> model.predict(2,2) is not None
True
+ >>> testset = sc.parallelize([(1, 2), (1, 1)])
+ >>> model.predictAll(testset).count == 2
+ True
"""
def __init__(self, sc, java_model):
@@ -45,6 +50,11 @@ class MatrixFactorizationModel(object):
def predict(self, user, product):
return self._java_model.predict(user, product)
+ def predictAll(self, usersProducts):
+ usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
+ return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
+ self._context, RatingDeserializer())
+
class ALS(object):
@classmethod
def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):