aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-29 23:58:32 -0700
committerXiangrui Meng <meng@databricks.com>2015-09-29 23:58:32 -0700
commit2931e89f0c54248d87f1f84c81137a5a91e142e9 (patch)
tree3fc009b833c65073d2e9c74247a4809a539cdeba /mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
parent4d5a005b0d2591d4d57a19be48c4954b9f1434a9 (diff)
downloadspark-2931e89f0c54248d87f1f84c81137a5a91e142e9.tar.gz
spark-2931e89f0c54248d87f1f84c81137a5a91e142e9.tar.bz2
spark-2931e89f0c54248d87f1f84c81137a5a91e142e9.zip
[SPARK-10736] [ML] Use 1 for all ratings if $(ratingCol) = ""
For some implicit datasets, ratings may not exist in the training data. In this case, we can assume that all observed pairs are positive and treat their ratings as 1. This behavior applies when users set `ratingCol` to an empty string. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8937 from yanboliang/spark-10736.
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala4
1 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 9a56a75b69..f6f5281f71 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -315,9 +315,9 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
override def fit(dataset: DataFrame): ALSModel = {
import dataset.sqlContext.implicits._
+ val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f)
val ratings = dataset
- .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType),
- col($(ratingCol)).cast(FloatType))
+ .select(col($(userCol)).cast(IntegerType), col($(itemCol)).cast(IntegerType), r)
.map { row =>
Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
}