diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 4a3ad662a0..36dce01590 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -40,7 +40,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.CholeskyDecomposition import org.apache.spark.mllib.optimization.NNLS import org.apache.spark.rdd.RDD -import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructType} import org.apache.spark.storage.StorageLevel @@ -200,8 +200,8 @@ class ALSModel private[ml] ( @Since("1.3.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) - @Since("1.3.0") - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { // Register a UDF for DataFrame, and then // create a new column named map(predictionCol) by running the predict UDF. val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) => @@ -385,8 +385,8 @@ class ALS(@Since("1.4.0") override val uid: String) extends Estimator[ALSModel] this } - @Since("1.3.0") - override def fit(dataset: DataFrame): ALSModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): ALSModel = { import dataset.sqlContext.implicits._ val r = if ($(ratingCol) != "") col($(ratingCol)).cast(FloatType) else lit(1.0f) val ratings = dataset |