author     Xiangrui Meng <meng@databricks.com>  2015-05-28 22:38:38 -0700
committer  Xiangrui Meng <meng@databricks.com>  2015-05-28 22:38:38 -0700
commit     db9513789756da4f16bb1fe8cf1d19500f231f54 (patch)
tree       aaef83386cdad3975181b554d68527abf41407cb /python
parent     cd3d9a5c0c3e77098a72c85dffe4a27737009ae7 (diff)
[SPARK-7922] [MLLIB] use DataFrames for user/item factors in ALSModel
Expose user/item factors in DataFrames. This is to be more consistent with the pipeline API. It also helps maintain consistent APIs across languages. This PR also removed fitting params from `ALSModel`.

coderxiang

Author: Xiangrui Meng <meng@databricks.com>

Closes #6468 from mengxr/SPARK-7922 and squashes the following commits:

7bfb1d5 [Xiangrui Meng] update ALSModel in PySpark
1ba5607 [Xiangrui Meng] use DataFrames for user/item factors in ALS
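In practice, the new API surface looks like the following. This is a minimal sketch based on the doctest added in this patch; it assumes an existing SparkContext `sc`, and the learned factor values depend on the random initialization:

    from pyspark.sql import SQLContext
    from pyspark.ml.recommendation import ALS

    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame(
        [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
        ["user", "item", "rating"])
    model = ALS(rank=10, maxIter=5).fit(df)
    model.rank           # 10
    model.userFactors    # DataFrame with columns `id` and `features`
    model.itemFactors    # DataFrame with columns `id` and `features`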
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/recommendation.py  30
-rw-r--r--  python/pyspark/mllib/common.py  5
2 files changed, 32 insertions(+), 3 deletions(-)
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index b3e0dd7abf..b06099ac0a 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -63,8 +63,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
indicated user preferences rather than explicit ratings given to
items.
+ >>> df = sqlContext.createDataFrame(
+ ... [(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0), (2, 1, 1.0), (2, 2, 5.0)],
+ ... ["user", "item", "rating"])
>>> als = ALS(rank=10, maxIter=5)
>>> model = als.fit(df)
+ >>> model.rank
+ 10
+ >>> model.userFactors.orderBy("id").collect()
+ [Row(id=0, features=[...]), Row(id=1, ...), Row(id=2, ...)]
>>> test = sqlContext.createDataFrame([(0, 2), (1, 0), (2, 0)], ["user", "item"])
>>> predictions = sorted(model.transform(test).collect(), key=lambda r: r[0])
>>> predictions[0]
@@ -260,6 +267,27 @@ class ALSModel(JavaModel):
Model fitted by ALS.
"""
+ @property
+ def rank(self):
+ """rank of the matrix factorization model"""
+ return self._call_java("rank")
+
+ @property
+ def userFactors(self):
+ """
+ a DataFrame that stores user factors in two columns: `id` and
+ `features`
+ """
+ return self._call_java("userFactors")
+
+ @property
+ def itemFactors(self):
+ """
+ a DataFrame that stores item factors in two columns: `id` and
+ `features`
+ """
+ return self._call_java("itemFactors")
+
if __name__ == "__main__":
import doctest
@@ -272,8 +300,6 @@ if __name__ == "__main__":
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
- globs['df'] = sqlContext.createDataFrame([(0, 0, 4.0), (0, 1, 2.0), (1, 1, 3.0), (1, 2, 4.0),
- (2, 1, 1.0), (2, 2, 5.0)], ["user", "item", "rating"])
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
if failure_count:
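Because ALS scores a (user, item) pair as the inner product of the corresponding factor vectors, the newly exposed DataFrames make it possible to reproduce a prediction by hand. A minimal sketch, assuming a `model` fitted as in the doctest above (only `userFactors`/`itemFactors` and their `id`/`features` columns come from this patch; the rest is illustrative):

    # Collect the factors into dicts keyed by id.
    user_factors = {row.id: row.features for row in model.userFactors.collect()}
    item_factors = {row.id: row.features for row in model.itemFactors.collect()}

    # The ALS prediction for (user=0, item=2) is the dot product of the
    # two factor vectors; it should match the `prediction` column that
    # model.transform() produces for the same pair.
    score = sum(u * i for u, i in zip(user_factors[0], item_factors[2]))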
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index ba60589788..855e85f571 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -27,7 +27,7 @@ from py4j.java_collections import ListConverter, JavaArray, JavaList
from pyspark import RDD, SparkContext
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
-
+from pyspark.sql import DataFrame, SQLContext
# Hack for support float('inf') in Py4j
_old_smart_decode = py4j.protocol.smart_decode
@@ -99,6 +99,9 @@ def _java2py(sc, r, encoding="bytes"):
jrdd = sc._jvm.SerDe.javaToPython(r)
return RDD(jrdd, sc)
+ if clsName == 'DataFrame':
+ return DataFrame(r, SQLContext(sc))
+
if clsName in _picklable_classes:
r = sc._jvm.SerDe.dumps(r)
elif isinstance(r, (JavaArray, JavaList)):
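The `_java2py` branch above is what lets `_call_java("userFactors")` in recommendation.py return a live `pyspark.sql.DataFrame`: a JVM result whose class name is `DataFrame` is wrapped directly instead of being routed through the SerDe/pickling paths. A standalone sketch of that dispatch (the function name `_wrap_dataframe_result` is hypothetical; in the real code this logic lives inside `_java2py`):

    from pyspark.sql import DataFrame, SQLContext

    def _wrap_dataframe_result(sc, r, cls_name):
        # JVM DataFrames are wrapped in the Python DataFrame class,
        # keeping the data on the JVM side; nothing is serialized.
        if cls_name == 'DataFrame':
            return DataFrame(r, SQLContext(sc))
        # Other class names fall through to the existing SerDe paths.
        return r

With this in place, `model.userFactors.orderBy("id").collect()` operates directly on the JVM-backed DataFrame, with no copy through the Python pickler.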