aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/recommendation.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-11-10 22:26:16 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-10 22:26:16 -0800
commit65083e93ddd552b7d3e4eb09f87c091ef2ae83a2 (patch)
tree38c82101a74f8dd099f062b574e6b7781b4801d8 /python/pyspark/mllib/recommendation.py
parent3c07b8f08240bafcdff5d174989fb433f4bc80b6 (diff)
downloadspark-65083e93ddd552b7d3e4eb09f87c091ef2ae83a2.tar.gz
spark-65083e93ddd552b7d3e4eb09f87c091ef2ae83a2.tar.bz2
spark-65083e93ddd552b7d3e4eb09f87c091ef2ae83a2.zip
[SPARK-4324] [PySpark] [MLlib] support numpy.array for all MLlib API
This PR check all of the existing Python MLlib API to make sure that numpy.array is supported as Vector (also RDD of numpy.array). It also improve some docstring and doctest. cc mateiz mengxr Author: Davies Liu <davies@databricks.com> Closes #3189 from davies/numpy and squashes the following commits: d5057c4 [Davies Liu] fix tests 6987611 [Davies Liu] support numpy.array for all MLlib API
Diffstat (limited to 'python/pyspark/mllib/recommendation.py')
-rw-r--r--python/pyspark/mllib/recommendation.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index e26b152e0c..41bbd9a779 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -32,7 +32,7 @@ class Rating(object):
return Rating, (self.user, self.product, self.rating)
def __repr__(self):
- return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating)
+ return "Rating(%d, %d, %s)" % (self.user, self.product, self.rating)
class MatrixFactorizationModel(JavaModelWrapper):
@@ -51,7 +51,7 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> testset = sc.parallelize([(1, 2), (1, 1)])
>>> model = ALS.train(ratings, 1, seed=10)
>>> model.predictAll(testset).collect()
- [Rating(1, 1, 1), Rating(1, 2, 1)]
+ [Rating(1, 1, 1.0471...), Rating(1, 2, 1.9679...)]
>>> model = ALS.train(ratings, 4, seed=10)
>>> model.userFeatures().collect()
@@ -79,7 +79,7 @@ class MatrixFactorizationModel(JavaModelWrapper):
0.4473...
"""
def predict(self, user, product):
- return self._java_model.predict(user, product)
+ return self._java_model.predict(int(user), int(product))
def predictAll(self, user_product):
assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)"