diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 5 | ||||
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 31 |
2 files changed, 36 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index 9a100170b7..b478c21537 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -673,6 +673,11 @@ private[spark] object SerDe extends Serializable { rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int])) } + /* convert RDD[Tuple2[,]] to RDD[Array[Any]] */ + def fromTuple2RDD(rdd: RDD[Tuple2[Any, Any]]): RDD[Array[Any]] = { + rdd.map(x => Array(x._1, x._2)) + } + /** * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by * PySpark. diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py index 17f96b8700..22872dbbe3 100644 --- a/python/pyspark/mllib/recommendation.py +++ b/python/pyspark/mllib/recommendation.py @@ -53,6 +53,23 @@ class MatrixFactorizationModel(object): >>> model = ALS.train(ratings, 1) >>> model.predictAll(testset).count() == 2 True + + >>> model = ALS.train(ratings, 4) + >>> model.userFeatures().count() == 2 + True + + >>> first_user = model.userFeatures().take(1)[0] + >>> latents = first_user[1] + >>> len(latents) == 4 + True + + >>> model.productFeatures().count() == 2 + True + + >>> first_product = model.productFeatures().take(1)[0] + >>> latents = first_product[1] + >>> len(latents) == 4 + True """ def __init__(self, sc, java_model): @@ -83,6 +100,20 @@ class MatrixFactorizationModel(object): return RDD(sc._jvm.SerDe.javaToPython(jresult), sc, AutoBatchedSerializer(PickleSerializer())) + def userFeatures(self): + sc = self._context + juf = self._java_model.userFeatures() + juf = sc._jvm.SerDe.fromTuple2RDD(juf).toJavaRDD() + return RDD(sc._jvm.PythonRDD.javaToPython(juf), sc, + AutoBatchedSerializer(PickleSerializer())) + + def productFeatures(self): + sc = self._context + jpf = self._java_model.productFeatures() + jpf = sc._jvm.SerDe.fromTuple2RDD(jpf).toJavaRDD() + return RDD(sc._jvm.PythonRDD.javaToPython(jpf), sc, + AutoBatchedSerializer(PickleSerializer())) + class ALS(object): |