diff options
author | Hossein Falaki <falaki@gmail.com> | 2014-01-06 12:19:43 -0800 |
---|---|---|
committer | Hossein Falaki <falaki@gmail.com> | 2014-01-06 12:19:43 -0800 |
commit | 754f5300a1e0a214b62cbd6db2398dea4dfbceb4 (patch) | |
tree | e63bb33ef2ca98fac640bbd2f258367ec635f48b | |
parent | 04132ea9b20a95cd68482605d4022f692bb556e5 (diff) | |
download | spark-754f5300a1e0a214b62cbd6db2398dea4dfbceb4.tar.gz spark-754f5300a1e0a214b62cbd6db2398dea4dfbceb4.tar.bz2 spark-754f5300a1e0a214b62cbd6db2398dea4dfbceb4.zip |
Added predictAll python function to MatrixFactorizationModel
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index c81b482a87..0eeb5bb66b 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -21,8 +21,7 @@ from pyspark.mllib._common import \ _serialize_double_matrix, _deserialize_double_matrix, \ _serialize_double_vector, _deserialize_double_vector, \ _get_initial_weights, _serialize_rating, _regression_train_wrapper, \ - _serialize_tuple, _deserialize_rating -from pyspark.serializers import BatchedSerializer + _serialize_tuple, RatingDeserializer from pyspark.rdd import RDD class MatrixFactorizationModel(object): @@ -36,6 +35,9 @@ class MatrixFactorizationModel(object): >>> model = ALS.trainImplicit(sc, ratings, 1) >>> model.predict(2,2) is not None True + >>> testset = sc.parallelize([(1, 2), (1, 1)]) + >>> model.predictAll(testset).count == 2 + True """ def __init__(self, sc, java_model): @@ -50,8 +52,8 @@ class MatrixFactorizationModel(object): def predictAll(self, usersProducts): usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple) - return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd), - self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize)) + return RDD(self._java_model.predict(usersProductsJRDD._jrdd), + self._context, RatingDeserializer()) class ALS(object): @classmethod |