aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-04-21 16:44:52 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-21 16:44:52 -0700
commit686dd742e11f6ad0078b7ff9b30b83a118fd8109 (patch)
treed64a986edf001ddea1bc189f48ae1420edfdb145 /python/pyspark
parent7fe6142cd3c39ec79899878c3deca9d5130d05b1 (diff)
downloadspark-686dd742e11f6ad0078b7ff9b30b83a118fd8109.tar.gz
spark-686dd742e11f6ad0078b7ff9b30b83a118fd8109.tar.bz2
spark-686dd742e11f6ad0078b7ff9b30b83a118fd8109.zip
[SPARK-7036][MLLIB] ALS.train should support DataFrames in PySpark
SchemaRDD works with ALS.train in 1.2, so we should continue support DataFrames for compatibility. coderxiang Author: Xiangrui Meng <meng@databricks.com> Closes #5619 from mengxr/SPARK-7036 and squashes the following commits: dfcaf5a [Xiangrui Meng] ALS.train should support DataFrames in PySpark
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/mllib/recommendation.py36
1 files changed, 26 insertions, 10 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 80e0a356bb..4b7d17d64e 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -22,6 +22,7 @@ from pyspark import SparkContext
from pyspark.rdd import RDD
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
from pyspark.mllib.util import JavaLoader, JavaSaveable
+from pyspark.sql import DataFrame
__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
@@ -78,18 +79,23 @@ class MatrixFactorizationModel(JavaModelWrapper, JavaSaveable, JavaLoader):
True
>>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
- >>> model.predict(2,2)
+ >>> model.predict(2, 2)
+ 3.8...
+
+ >>> df = sqlContext.createDataFrame([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
+ >>> model = ALS.train(df, 1, nonnegative=True, seed=10)
+ >>> model.predict(2, 2)
3.8...
>>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
- >>> model.predict(2,2)
+ >>> model.predict(2, 2)
0.4...
>>> import os, tempfile
>>> path = tempfile.mkdtemp()
>>> model.save(sc, path)
>>> sameModel = MatrixFactorizationModel.load(sc, path)
- >>> sameModel.predict(2,2)
+ >>> sameModel.predict(2, 2)
0.4...
>>> sameModel.predictAll(testset).collect()
[Rating(...
@@ -125,13 +131,20 @@ class ALS(object):
@classmethod
def _prepare(cls, ratings):
- assert isinstance(ratings, RDD), "ratings should be RDD"
+ if isinstance(ratings, RDD):
+ pass
+ elif isinstance(ratings, DataFrame):
+ ratings = ratings.rdd
+ else:
+ raise TypeError("Ratings should be represented by either an RDD or a DataFrame, "
+ "but got %s." % type(ratings))
first = ratings.first()
- if not isinstance(first, Rating):
- if isinstance(first, (tuple, list)):
- ratings = ratings.map(lambda x: Rating(*x))
- else:
- raise ValueError("rating should be RDD of Rating or tuple/list")
+ if isinstance(first, Rating):
+ pass
+ elif isinstance(first, (tuple, list)):
+ ratings = ratings.map(lambda x: Rating(*x))
+ else:
+ raise TypeError("Expect a Rating or a tuple/list, but got %s." % type(first))
return ratings
@classmethod
@@ -152,8 +165,11 @@ class ALS(object):
def _test():
import doctest
import pyspark.mllib.recommendation
+ from pyspark.sql import SQLContext
globs = pyspark.mllib.recommendation.__dict__.copy()
- globs['sc'] = SparkContext('local[4]', 'PythonTest')
+ sc = SparkContext('local[4]', 'PythonTest')
+ globs['sc'] = sc
+ globs['sqlContext'] = SQLContext(sc)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count: