aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorMichelangelo D'Agostino <mdagostino@civisanalytics.com>2014-11-07 22:53:01 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-07 22:53:01 -0800
commit7e9d975676d56ace0e84c2200137e4cd4eba074a (patch)
tree59f03936200f7a6a5502a5bf70e94631c03c63c7 /mllib
parent7779109796c90d789464ab0be35917f963bbe867 (diff)
downloadspark-7e9d975676d56ace0e84c2200137e4cd4eba074a.tar.gz
spark-7e9d975676d56ace0e84c2200137e4cd4eba074a.tar.bz2
spark-7e9d975676d56ace0e84c2200137e4cd4eba074a.zip
[MLLIB] [PYTHON] SPARK-4221: Expose nonnegative ALS in the python API
SPARK-1553 added alternating nonnegative least squares to MLLib, however it's not possible to access it via the python API. This pull request resolves that. Author: Michelangelo D'Agostino <mdagostino@civisanalytics.com> Closes #3095 from mdagost/python_nmf and squashes the following commits: a6743ad [Michelangelo D'Agostino] Use setters instead of static methods in PythonMLLibAPI. Remove the new static methods I added. Set seed in tests. Change ratings to ratingsRDD in both train and trainImplicit for consistency. 7cffd39 [Michelangelo D'Agostino] Swapped nonnegative and seed in a few more places. 3fdc851 [Michelangelo D'Agostino] Moved seed to the end of the python parameter list. bdcc154 [Michelangelo D'Agostino] Change seed type to java.lang.Long so that it can handle null. cedf043 [Michelangelo D'Agostino] Added in ability to set the seed from python and made that play nice with the nonnegative changes. Also made the python ALS tests more exact. a72fdc9 [Michelangelo D'Agostino] Expose nonnegative ALS in the python API.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala39
1 files changed, 33 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index d832ae34b5..70d7138e30 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -275,12 +275,25 @@ class PythonMLLibAPI extends Serializable {
* the Py4J documentation.
*/
def trainALSModel(
- ratings: JavaRDD[Rating],
+ ratingsJRDD: JavaRDD[Rating],
rank: Int,
iterations: Int,
lambda: Double,
- blocks: Int): MatrixFactorizationModel = {
- new MatrixFactorizationModelWrapper(ALS.train(ratings.rdd, rank, iterations, lambda, blocks))
+ blocks: Int,
+ nonnegative: Boolean,
+ seed: java.lang.Long): MatrixFactorizationModel = {
+
+ val als = new ALS()
+ .setRank(rank)
+ .setIterations(iterations)
+ .setLambda(lambda)
+ .setBlocks(blocks)
+ .setNonnegative(nonnegative)
+
+ if (seed != null) als.setSeed(seed)
+
+ val model = als.run(ratingsJRDD.rdd)
+ new MatrixFactorizationModelWrapper(model)
}
/**
@@ -295,9 +308,23 @@ class PythonMLLibAPI extends Serializable {
iterations: Int,
lambda: Double,
blocks: Int,
- alpha: Double): MatrixFactorizationModel = {
- new MatrixFactorizationModelWrapper(
- ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha))
+ alpha: Double,
+ nonnegative: Boolean,
+ seed: java.lang.Long): MatrixFactorizationModel = {
+
+ val als = new ALS()
+ .setImplicitPrefs(true)
+ .setRank(rank)
+ .setIterations(iterations)
+ .setLambda(lambda)
+ .setBlocks(blocks)
+ .setAlpha(alpha)
+ .setNonnegative(nonnegative)
+
+ if (seed != null) als.setSeed(seed)
+
+ val model = als.run(ratingsJRDD.rdd)
+ new MatrixFactorizationModelWrapper(model)
}
/**