aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala8
1 files changed, 5 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index cdea90ec1a..995780bf64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -457,10 +457,12 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel]
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}
+
val instr = Instrumentation.create(this, ratings)
- instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha,
- userCol, itemCol, ratingCol, predictionCol, maxIter,
- regParam, nonnegative, checkpointInterval, seed)
+ instr.logParams(rank, numUserBlocks, numItemBlocks, implicitPrefs, alpha, userCol,
+ itemCol, ratingCol, predictionCol, maxIter, regParam, nonnegative, checkpointInterval,
+ seed, intermediateStorageLevel, finalStorageLevel)
+
val (userFactors, itemFactors) = ALS.train(ratings, rank = $(rank),
numUserBlocks = $(numUserBlocks), numItemBlocks = $(numItemBlocks),
maxIter = $(maxIter), regParam = $(regParam), implicitPrefs = $(implicitPrefs),