Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/ml/recommendation.py | 30
-rw-r--r--  python/pyspark/mllib/common.py      |  5
2 files changed, 32 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index b3e0dd7abf..b06099ac0a 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     indicated user preferences rather than explicit ratings given to
     items.
 
+    >>> df = sqlContext.createDataFrame(
+    ...     [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
+    ...     ["user", "item", "rating"])
     >>> als = ALS(rank=10, maxIter=5)
     >>> model = als.fit(df)
+    >>> model.rank
+    10
+    >>> model.userFactors.orderBy("id").collect()
+    [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
     >>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
     >>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
     >>> predictions[0]
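
For context, the scenario the new doctest exercises, written as a standalone script (a sketch: it uses only the objects the patch itself sets up, and the predicted values vary from run to run because the ALS factors are randomly initialized):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.ml.recommendation import ALS

    sc = SparkContext(appName="als_standalone")
    sqlContext = SQLContext(sc)

    # Same toy ratings as the doctest above: (user, item, rating) triples.
    df = sqlContext.createDataFrame(
        [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
        ["user", "item", "rating"])
    model = ALS(rank=10, maxIter=5).fit(df)

    test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
    for row in sorted(model.transform(test).collect(), key=lambda r: r[0]):
        print(row)  # predicted ratings; exact values differ per run
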
@@ -260,6 +267,27 @@ class ALSModel(JavaModel):
     Model fitted by ALS.
     """
 
+    @property
+    def rank(self):
+        """rank of the matrix factorization model"""
+        return self._call_java("rank")
+
+    @property
+    def userFactors(self):
+        """
+        a DataFrame that stores user factors in two columns: `id` and
+        `features`
+        """
+        return self._call_java("userFactors")
+
+    @property
+    def itemFactors(self):
+        """
+        a DataFrame that stores item factors in two columns: `id` and
+        `features`
+        """
+        return self._call_java("itemFactors")
+
 
 if __name__ == "__main__":
     import doctest
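
How the three new properties read from the Python side, as a rough sketch (it assumes the `model` fitted in the doctest above; the column names come straight from the docstrings, and the counts follow from the three distinct users and items in the toy data):

    model.rank                                # 10, forwarded via _call_java("rank")
    model.userFactors.columns                 # ['id', 'features']
    model.itemFactors.columns                 # ['id', 'features']
    model.userFactors.orderBy("id").count()   # 3, one row per user id
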
@@ -272,8 +300,6 @@ if __name__ == "__main__":
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0),
-                                              (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"])
     (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
     sc.stop()
     if failure_count:
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index ba60589788..855e85f571 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -27,7 +27,7 @@ from py4j.java_collections import ListConverter, JavaArray, JavaList
 from pyspark import RDD, SparkContext
 from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-
+from pyspark.sql import DataFrame, SQLContext
 
 # Hack to support float('inf') in Py4j
 _old_smart_decode = py4j.protocol.smart_decode
@@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"):
             jrdd = sc._jvm.SerDe.javaToPython(r)
             return RDD(jrdd, sc)
 
+        if clsName == 'DataFrame':
+            return DataFrame(r, SQLContext(sc))
+
         if clsName in _picklable_classes:
             r = sc._jvm.SerDe.dumps(r)
         elif isinstance(r, (JavaArray, JavaList)):
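
To exercise the new branch in isolation, a minimal round-trip sketch (it reaches for the private `_jdf` handle on pyspark.sql.DataFrame, so treat it as illustrative rather than supported API):

    from pyspark import SparkContext
    from pyspark.sql import SQLContext, DataFrame
    from pyspark.mllib.common import _java2py

    sc = SparkContext(appName="java2py_demo")
    sqlContext = SQLContext(sc)

    df = sqlContext.createDataFrame([(0, 1.0)], ["id", "value"])
    jdf = df._jdf                  # py4j JavaObject whose simple class name is 'DataFrame'
    roundtrip = _java2py(sc, jdf)  # hits the new clsName == 'DataFrame' branch
    assert isinstance(roundtrip, DataFrame)
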