aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorMichelangelo D'Agostino <mdagostino@civisanalytics.com>2014-10-21 11:49:39 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-21 11:49:39 -0700
commit1a623b2e163da3a9112cb9b68bda22b9e398ed5c (patch)
tree41b232df4b481db742a24928a45e43a6deccaad9 /python
parent61ca7742d21dd66f5a7b3bb826e3aaca6f049b68 (diff)
downloadspark-1a623b2e163da3a9112cb9b68bda22b9e398ed5c.tar.gz
spark-1a623b2e163da3a9112cb9b68bda22b9e398ed5c.tar.bz2
spark-1a623b2e163da3a9112cb9b68bda22b9e398ed5c.zip
SPARK-3770: Make userFeatures accessible from python
https://issues.apache.org/jira/browse/SPARK-3770 We need access to the underlying latent user features from python. However, the userFeatures RDD from the MatrixFactorizationModel isn't accessible from the python bindings. I've added a method to the underlying scala class to turn the RDD[(Int, Array[Double])] to an RDD[String]. This is then accessed from the python recommendation.py Author: Michelangelo D'Agostino <mdagostino@civisanalytics.com> Closes #2636 from mdagost/mf_user_features and squashes the following commits: c98f9e2 [Michelangelo D'Agostino] Added unit tests for userFeatures and productFeatures and merged master. d5eadf8 [Michelangelo D'Agostino] Merge branch 'master' into mf_user_features 2481a2a [Michelangelo D'Agostino] Merged master and resolved conflict. a6ffb96 [Michelangelo D'Agostino] Eliminated a function from our first approach to this problem that is no longer needed now that we added the fromTuple2RDD function. 2aa1bf8 [Michelangelo D'Agostino] Implemented a function called fromTuple2RDD in PythonMLLibAPI and used it to expose the MF userFeatures and productFeatures in python. 34cb2a2 [Michelangelo D'Agostino] A couple of lint cleanups and a comment. cdd98e3 [Michelangelo D'Agostino] It's working now. e1fbe5e [Michelangelo D'Agostino] Added scala function to stringify userFeatures for access in python.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/recommendation.py31
1 files changed, 31 insertions, 0 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 17f96b8700..22872dbbe3 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -53,6 +53,23 @@ class MatrixFactorizationModel(object):
>>> model = ALS.train(ratings, 1)
>>> model.predictAll(testset).count() == 2
True
+
+ >>> model = ALS.train(ratings, 4)
+ >>> model.userFeatures().count() == 2
+ True
+
+ >>> first_user = model.userFeatures().take(1)[0]
+ >>> latents = first_user[1]
+ >>> len(latents) == 4
+ True
+
+ >>> model.productFeatures().count() == 2
+ True
+
+ >>> first_product = model.productFeatures().take(1)[0]
+ >>> latents = first_product[1]
+ >>> len(latents) == 4
+ True
"""
def __init__(self, sc, java_model):
@@ -83,6 +100,20 @@ class MatrixFactorizationModel(object):
return RDD(sc._jvm.SerDe.javaToPython(jresult), sc,
AutoBatchedSerializer(PickleSerializer()))
+ def userFeatures(self):
+ sc = self._context
+ juf = self._java_model.userFeatures()
+ juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD()
+ return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc,
+ AutoBatchedSerializer(PickleSerializer()))
+
+ def productFeatures(self):
+ sc = self._context
+ jpf = self._java_model.productFeatures()
+ jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD()
+ return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc,
+ AutoBatchedSerializer(PickleSerializer()))
+
class ALS(object):