aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala12
1 files changed, 8 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 6cf4b40075..d7cbffc3be 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -49,7 +49,7 @@ import org.apache.spark.util.random.XORShiftRandom
* Common params for ALS.
*/
private[recommendation] trait ALSParams extends Params with HasMaxIter with HasRegParam
- with HasPredictionCol with HasCheckpointInterval {
+ with HasPredictionCol with HasCheckpointInterval with HasSeed {
/**
* Param for rank of the matrix factorization (>= 1).
@@ -147,7 +147,7 @@ private[recommendation] trait ALSParams extends Params with HasMaxIter with HasR
setDefault(rank -> 10, maxIter -> 10, regParam -> 0.1, numUserBlocks -> 10, numItemBlocks -> 10,
implicitPrefs -> false, alpha -> 1.0, userCol -> "user", itemCol -> "item",
- ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10)
+ ratingCol -> "rating", nonnegative -> false, checkpointInterval -> 10, seed -> 0L)
/**
* Validates and transforms the input schema.
@@ -278,6 +278,9 @@ class ALS extends Estimator[ALSModel] with ALSParams {
/** @group setParam */
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
/**
* Sets both numUserBlocks and numItemBlocks to the specific value.
* @group setParam
@@ -290,7 +293,8 @@ class ALS extends Estimator[ALSModel] with ALSParams {
override def fit(dataset: DataFrame): ALSModel = {
val ratings = dataset
- .select(col($(userCol)), col($(itemCol)), col($(ratingCol)).cast(FloatType))
+ .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
+ col($(ratingCol)).cast(FloatType))
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
@@ -298,7 +302,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),
alpha = $(alpha), nonnegative = $(nonnegative),
- checkpointInterval = $(checkpointInterval))
+ checkpointInterval = $(checkpointInterval), seed = $(seed))
copyValues(new ALSModel(this, $(rank), userFactors, itemFactors))
}