about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala 5
-rw-r--r-- python/pyspark/mllib/recommendation.py 31
2 files changed, 36 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 9a100170b7..b478c21537 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -673,6 +673,11 @@ private[spark] object SerDe extends Serializable {
rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
}
+  /**
+   * Convert an RDD of pairs into an RDD of two-element arrays: index 0 holds the tuple's first
+   * component and index 1 holds its second component.
+   */
+  def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = {
+    rdd.map(pair => Array[Any](pair._1, pair._2))
+  }
+
/**
* Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
* PySpark.
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 17f96b8700..22872dbbe3 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -53,6 +53,23 @@ class MatrixFactorizationModel(object):
>>> model = ALS.train(ratings, 1)
>>> model.predictAll(testset).count() == 2
True
+
+ >>> model = ALS.train(ratings, 4)
+ >>> model.userFeatures().count() == 2
+ True
+
+ >>> first_user = model.userFeatures().take(1)[0]
+ >>> latents = first_user[1]
+ >>> len(latents) == 4
+ True
+
+ >>> model.productFeatures().count() == 2
+ True
+
+ >>> first_product = model.productFeatures().take(1)[0]
+ >>> latents = first_product[1]
+ >>> len(latents) == 4
+ True
"""
def __init__(self, sc, java_model):
@@ -83,6 +100,20 @@ class MatrixFactorizationModel(object):
return RDD(sc._jvm.SerDe.javaToPython(jresult), sc,
AutoBatchedSerializer(PickleSerializer()))
+ def userFeatures(self):
+ sc = self._context
+ juf = self._java_model.userFeatures()
+ juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD()
+ return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc,
+ AutoBatchedSerializer(PickleSerializer()))
+
+ def productFeatures(self):
+ sc = self._context
+ jpf = self._java_model.productFeatures()
+ jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD()
+ return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc,
+ AutoBatchedSerializer(PickleSerializer()))
+
class ALS(object):