From 8d0c2f7399ebf7a38346a60cf84d7020c0b1dba1 Mon Sep 17 00:00:00 2001 From: Hossein Falaki Date: Sat, 4 Jan 2014 16:23:17 -0800 Subject: Added python binding for bulk recommendation --- python/pyspark/mllib/_common.py | 10 ++++++++++ python/pyspark/mllib/recommendation.py | 10 +++++++++- 2 files changed, 19 insertions(+), 1 deletion(-) (limited to 'python') diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index e74ba0fabc..c818fc4d97 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -213,6 +213,16 @@ def _serialize_rating(r): intpart[0], intpart[1], doublepart[0] = r return ba +def _deserialize_rating(ba): + ar = ndarray(shape=(3, ), buffer=ba, dtype="float64", order='C') + return ar.copy() + +def _serialize_tuple(t): + ba = bytearray(8) + intpart = ndarray(shape=[2], buffer=ba, dtype=int32) + intpart[0], intpart[1] = t + return ba + def _test(): import doctest globs = globals().copy() diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 14d06cba21..c81b482a87 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -20,7 +20,10 @@ from pyspark.mllib._common import \ _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \ _serialize_double_matrix, _deserialize_double_matrix, \ _serialize_double_vector, _deserialize_double_vector, \ - _get_initial_weights, _serialize_rating, _regression_train_wrapper + _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ + _serialize_tuple, _deserialize_rating +from pyspark.serializers import BatchedSerializer +from pyspark.rdd import RDD class MatrixFactorizationModel(object): """A matrix factorisation model trained by regularized alternating @@ -45,6 +48,11 @@ class MatrixFactorizationModel(object): def predict(self, user, product): return self._java_model.predict(user, product) + def predictAll(self, usersProducts): + usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple) + return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd), + self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize)) + class ALS(object): @classmethod def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): -- cgit v1.2.3