about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
author Burak Yavuz <brkyvz@gmail.com> 2015-05-08 17:24:32 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-05-08 17:24:32 -0700
commit 84bf931f36edf1f319c9116f7f326959a6118991 (patch)
tree ed7930e9cc5fe6855026689f5d0e4ebdfebb4832 /mllib
parent 54e6fa0563ffa8788ec2fd1b8740445ef3c2ce5a (diff)
download spark-84bf931f36edf1f319c9116f7f326959a6118991.tar.gz
spark-84bf931f36edf1f319c9116f7f326959a6118991.tar.bz2
spark-84bf931f36edf1f319c9116f7f326959a6118991.zip
[SPARK-7488] [ML] Feature Parity in PySpark for ml.recommendation
Adds Python API for `ALS` under `ml.recommendation` in PySpark. Also adds seed as a settable parameter in the Scala implementation of ALS.

Author: Burak Yavuz <brkyvz@gmail.com>

Closes #6015 from brkyvz/ml-rec and squashes the following commits:

be6e931 [Burak Yavuz] addressed comments
eaed879 [Burak Yavuz] readd numFeatures
0bd66b1 [Burak Yavuz] fixed seed
7f6d964 [Burak Yavuz] merged master
52e2bda [Burak Yavuz] added ALS
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 12
1 file changed, 8 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 6cf4b40075..d7cbffc3be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -49,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom
* Common params for ALS.
*/
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
- with HasPredictionCol with HasCheckpointInterval {
+ with HasPredictionCol with HasCheckpointInterval with HasSeed {
/**
* Param for rank of the matrix factorization (>= 1).
@@ -147,7 +147,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
- ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
+ ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, seed -> 0L)
/**
* Validates and transforms the input schema.
@@ -278,6 +278,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
/** @group setParam */
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
* @group setParam
@@ -290,7 +293,8 @@ class ALS extends Estimator[ALSModel] with ALSParams {
override def fit(dataset: DataFrame): ALSModel = {
val ratings = dataset
- .select(col($(userCol)), col($(itemCol)), col($(ratingCol)).cast(FloatType))
+ .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
+ col($(ratingCol)).cast(FloatType))
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
@@ -298,7 +302,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
- checkpointInterval = $(checkpointInterval))
+ checkpointInterval = $(checkpointInterval), seed = $(seed))
copyValues(new ALSModel(this, $(rank), userFactors, itemFactors))
}