about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
author Hossein Falaki <falaki@gmail.com> 2014-01-06 12:19:43 -0800
committer Hossein Falaki <falaki@gmail.com> 2014-01-06 12:19:43 -0800
commit 754f5300a1e0a214b62cbd6db2398dea4dfbceb4 (patch)
tree e63bb33ef2ca98fac640bbd2f258367ec635f48b /python
parent 04132ea9b20a95cd68482605d4022f692bb556e5 (diff)
download spark-754f5300a1e0a214b62cbd6db2398dea4dfbceb4.tar.gz
spark-754f5300a1e0a214b62cbd6db2398dea4dfbceb4.tar.bz2
spark-754f5300a1e0a214b62cbd6db2398dea4dfbceb4.zip
Added predictAll python function to MatrixFactorizationModel
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/recommendation.py 10
1 files changed, 6 insertions, 4 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index c81b482a87..0eeb5bb66b 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -21,8 +21,7 @@ from pyspark.mllib._common import \
_serialize_double_matrix, _deserialize_double_matrix, \
_serialize_double_vector, _deserialize_double_vector, \
_get_initial_weights, _serialize_rating, _regression_train_wrapper, \
- _serialize_tuple, _deserialize_rating
-from pyspark.serializers import BatchedSerializer
+ _serialize_tuple, RatingDeserializer
from pyspark.rdd import RDD
class MatrixFactorizationModel(object):
@@ -36,6 +35,9 @@ class MatrixFactorizationModel(object):
>>> model = ALS.trainImplicit(sc, ratings, 1)
>>> model.predict(2,2) is not None
True
+ >>> testset = sc.parallelize([(1, 2), (1, 1)])
+ >>> model.predictAll(testset).count() == 2
+ True
"""
def __init__(self, sc, java_model):
@@ -50,8 +52,8 @@ class MatrixFactorizationModel(object):
def predictAll(self, usersProducts):
usersProductsJRDD = _get_unmangled_rdd(usersProducts, _serialize_tuple)
- return RDD(self._java_model.predictJavaRDD(usersProductsJRDD._jrdd),
- self._context, BatchedSerializer(_deserialize_rating, self._context._batchSize))
+ return RDD(self._java_model.predict(usersProductsJRDD._jrdd),
+ self._context, RatingDeserializer())
class ALS(object):
@classmethod