diff options
author | Burak Yavuz <brkyvz@gmail.com> | 2015-05-08 17:24:32 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-08 17:24:39 -0700 |
commit | 85cab34828b92eba51f8ca43e5cc4ba87752b169 (patch) | |
tree | f13c0cba1707581ca602788b6943f0d31f5dc6b1 /mllib/src/main | |
parent | 45b62151da13c165a82d64c24bf69f242690bc5d (diff) | |
download | spark-85cab34828b92eba51f8ca43e5cc4ba87752b169.tar.gz spark-85cab34828b92eba51f8ca43e5cc4ba87752b169.tar.bz2 spark-85cab34828b92eba51f8ca43e5cc4ba87752b169.zip |
[SPARK-7488] [ML] Feature Parity in PySpark for ml.recommendation
Adds a Python API for `ALS` under `ml.recommendation` in PySpark. Also adds `seed` as a settable parameter in the Scala implementation of ALS.
Author: Burak Yavuz <brkyvz@gmail.com>
Closes #6015 from brkyvz/ml-rec and squashes the following commits:
be6e931 [Burak Yavuz] addressed comments
eaed879 [Burak Yavuz] readd numFeatures
0bd66b1 [Burak Yavuz] fixed seed
7f6d964 [Burak Yavuz] merged master
52e2bda [Burak Yavuz] added ALS
(cherry picked from commit 84bf931f36edf1f319c9116f7f326959a6118991)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 12 |
1 file changed, 8 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 6cf4b40075..d7cbffc3be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -49,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom * Common params for ALS. */ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam - with HasPredictionCol with HasCheckpointInterval { + with HasPredictionCol with HasCheckpointInterval with HasSeed { /** * Param for rank of the matrix factorization (>= 1). @@ -147,7 +147,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10, implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item", - ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10) + ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, seed -> 0L) /** * Validates and transforms the input schema. @@ -278,6 +278,9 @@ class ALS extends Estimator[ALSModel] with ALSParams { /** @group setParam */ def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + /** * Sets both numUserBlocks and numItemBlocks to the specific value. 
* @group setParam @@ -290,7 +293,8 @@ class ALS extends Estimator[ALSModel] with ALSParams { override def fit(dataset: DataFrame): ALSModel = { val ratings = dataset - .select(col($(userCol)), col($(itemCol)), col($(ratingCol)).cast(FloatType)) + .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), + col($(ratingCol)).cast(FloatType)) .map { row => Rating(row.getInt(0), row.getInt(1), row.getFloat(2)) } @@ -298,7 +302,7 @@ class ALS extends Estimator[ALSModel] with ALSParams { numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks), maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs), alpha = $(alpha), nonnegative = $(nonnegative), - checkpointInterval = $(checkpointInterval)) + checkpointInterval = $(checkpointInterval), seed = $(seed)) copyValues(new ALSModel(this, $(rank), userFactors, itemFactors)) } |