diff options
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r-- | python/pyspark/mllib/_common.py | 25 | ||||
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 12 |
2 files changed, 36 insertions, 1 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index e74ba0fabc..769d88dfb9 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -18,6 +18,9 @@ from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape from pyspark import SparkContext +from pyspark.serializers import Serializer +import struct + # Double vector format: # # [8-byte 1] [8-byte length] [length*8 bytes of data] @@ -213,6 +216,28 @@ def _serialize_rating(r): intpart[0], intpart[1], doublepart[0] = r return ba +class RatingDeserializer(Serializer): + def loads(self, stream): + length = struct.unpack("!i", stream.read(4))[0] + ba = stream.read(length) + res = ndarray(shape=(3, ), buffer=ba, dtype="float64", offset=4) + return int(res[0]), int(res[1]), res[2] + + def load_stream(self, stream): + while True: + try: + yield self.loads(stream) + except struct.error: + return + except EOFError: + return + +def _serialize_tuple(t): + ba = bytearray(8) + intpart = ndarray(shape=[2], buffer=ba, dtype=int32) + intpart[0], intpart[1] = t + return ba + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 14d06cba21..0eeb5bb66b 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -20,7 +20,9 @@ from pyspark.mllib._common import \ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ _serialize_double_matrix, _deserialize_double_matrix, \ _serialize_double_vector, _deserialize_double_vector, \ - _get_initial_weights, _serialize_rating, _regression_train_wrapper + _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ + _serialize_tuple, RatingDeserializer +from pyspark.rdd import RDD class MatrixFactorizationModel(object): """A matrix factorisation model trained by regularized alternating @@ -33,6 +35,9 @@ class MatrixFactorizationModel(object): >>> model = ALS.trainImplicit(sc, ratings, 1) >>> model.predict(2,2) is not None True + >>> testset = sc.parallelize([(1, 2), (1, 1)]) + >>> model.predictAll(testset).count == 2 + True """ def __init__(self, sc, java_model): @@ -45,6 +50,11 @@ class MatrixFactorizationModel(object): def predict(self, user, product): return self._java_model.predict(user, product) + def predictAll(self, usersProducts): + usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple) + return RDD(self._java_model.predict(usersProductsJRDD._jrdd), + self._context, RatingDeserializer()) + class ALS(object): @classmethod def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): |