diff options
author | Michelangelo D'Agostino <mdagostino@civisanalytics.com> | 2014-11-07 22:53:01 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-07 22:53:01 -0800 |
commit | 7e9d975676d56ace0e84c2200137e4cd4eba074a (patch) | |
tree | 59f03936200f7a6a5502a5bf70e94631c03c63c7 /mllib | |
parent | 7779109796c90d789464ab0be35917f963bbe867 (diff) | |
download | spark-7e9d975676d56ace0e84c2200137e4cd4eba074a.tar.gz spark-7e9d975676d56ace0e84c2200137e4cd4eba074a.tar.bz2 spark-7e9d975676d56ace0e84c2200137e4cd4eba074a.zip |
[MLLIB] [PYTHON] SPARK-4221: Expose nonnegative ALS in the python API
SPARK-1553 added alternating nonnegative least squares to MLLib, however it's not possible to access it via the python API. This pull request resolves that.
Author: Michelangelo D'Agostino <mdagostino@civisanalytics.com>
Closes #3095 from mdagost/python_nmf and squashes the following commits:
a6743ad [Michelangelo D'Agostino] Use setters instead of static methods in PythonMLLibAPI. Remove the new static methods I added. Set seed in tests. Change ratings to ratingsRDD in both train and trainImplicit for consistency.
7cffd39 [Michelangelo D'Agostino] Swapped nonnegative and seed in a few more places.
3fdc851 [Michelangelo D'Agostino] Moved seed to the end of the python parameter list.
bdcc154 [Michelangelo D'Agostino] Change seed type to java.lang.Long so that it can handle null.
cedf043 [Michelangelo D'Agostino] Added in ability to set the seed from python and made that play nice with the nonnegative changes. Also made the python ALS tests more exact.
a72fdc9 [Michelangelo D'Agostino] Expose nonnegative ALS in the python API.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 39 |
1 files changed, 33 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index d832ae34b5..70d7138e30 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -275,12 +275,25 @@ class PythonMLLibAPI extends Serializable { * the Py4J documentation. */ def trainALSModel( - ratings: JavaRDD[Rating], + ratingsJRDD: JavaRDD[Rating], rank: Int, iterations: Int, lambda: Double, - blocks: Int): MatrixFactorizationModel = { - new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks)) + blocks: Int, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) } /** @@ -295,9 +308,23 @@ class PythonMLLibAPI extends Serializable { iterations: Int, lambda: Double, blocks: Int, - alpha: Double): MatrixFactorizationModel = { - new MatrixFactorizationModelWrapper( - ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)) + alpha: Double, + nonnegative: Boolean, + seed: java.lang.Long): MatrixFactorizationModel = { + + val als = new ALS() + .setImplicitPrefs(true) + .setRank(rank) + .setIterations(iterations) + .setLambda(lambda) + .setBlocks(blocks) + .setAlpha(alpha) + .setNonnegative(nonnegative) + + if (seed != null) als.setSeed(seed) + + val model = als.run(ratingsJRDD.rdd) + new MatrixFactorizationModelWrapper(model) } /** |