author     Reynold Xin <rxin@databricks.com>    2015-02-13 23:03:22 -0800
committer  Reynold Xin <rxin@databricks.com>    2015-02-13 23:03:22 -0800
commit     e98dfe627c5d0201464cdd0f363f391ea84c389a (patch)
tree       794beea739eb04bf2e0926f9b0e19ffacb94ba08 /mllib
parent     0ce4e430a81532dc317136f968f28742e087d840 (diff)
[SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
- The old implicit would convert RDDs directly to DataFrames, and that added too many methods.
- toDataFrame -> toDF
- Dsl -> functions
- implicits moved into SQLContext.implicits
- addColumn -> withColumn
- renameColumn -> withColumnRenamed

Python changes:
- toDataFrame -> toDF
- Dsl -> functions package
- addColumn -> withColumn
- renameColumn -> withColumnRenamed
- add toDF functions to RDD on SQLContext init
- add flatMap to DataFrame

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4556 from rxin/SPARK-5752 and squashes the following commits:

5ef9910 [Reynold Xin] More fix
61d3fca [Reynold Xin] Merge branch 'df5' of github.com:davies/spark into SPARK-5752
ff5832c [Reynold Xin] Fix python
749c675 [Reynold Xin] count(*) fixes.
5806df0 [Reynold Xin] Fix build break again.
d941f3d [Reynold Xin] Fixed explode compilation break.
fe1267a [Davies Liu] flatMap
c4afb8e [Reynold Xin] style
d9de47f [Davies Liu] add comment
b783994 [Davies Liu] add comment for toDF
e2154e5 [Davies Liu] schema() -> schema
3a1004f [Davies Liu] Dsl -> functions, toDF()
fb256af [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
0dd74eb [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
97dd47c [Davies Liu] fix mistake
6168f74 [Davies Liu] fix test
1fc0199 [Davies Liu] fix test
a075cd5 [Davies Liu] clean up, toPandas
663d314 [Davies Liu] add test for agg('*')
9e214d5 [Reynold Xin] count(*) fixes.
1ed7136 [Reynold Xin] Fix build break again.
921b2e3 [Reynold Xin] Fixed explode compilation break.
14698d4 [Davies Liu] flatMap
ba3e12d [Reynold Xin] style
d08c92d [Davies Liu] add comment
5c8b524 [Davies Liu] add comment for toDF
a4e5e66 [Davies Liu] schema() -> schema
d377fc9 [Davies Liu] Dsl -> functions, toDF()
6b3086c [Reynold Xin] - toDataFrame -> toDF - Dsl -> functions - implicits moved into SQLContext.implicits - addColumn -> withColumn - renameColumn -> withColumnRenamed
807e8b1 [Reynold Xin] [SPARK-5752][SQL] Don't implicitly convert RDDs directly to DataFrames
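For reference, a minimal before/after sketch of the renamed Scala API. This is not code from the patch: it assumes a live SparkContext sc and SQLContext sqlContext, and Person and the column names are illustrative.

    import sqlContext.implicits._            // implicits now live on SQLContext
    import org.apache.spark.sql.functions._  // formerly org.apache.spark.sql.Dsl

    case class Person(name: String, age: Int)
    val people = sc.parallelize(Seq(Person("a", 30)))

    val df  = people.toDF                                  // formerly toDataFrame
    val df2 = df.withColumn("age2", col("age") + 1)        // formerly addColumn
    val df3 = df2.withColumnRenamed("age2", "agePlus1")    // formerly renameColumn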
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala | 16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 33
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 4
14 files changed, 43 insertions(+), 50 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 2ec2ccdb8c..9a5848684b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,7 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
/**
@@ -100,7 +100,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
- dataset.select($"*", callUDF(
- this.createTransformFunc(map), outputDataType, dataset(map(inputCol))).as(map(outputCol)))
+ dataset.withColumn(map(outputCol),
+ callUDF(this.createTransformFunc(map), outputDataType, dataset(map(inputCol))))
}
}
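The hunk above is the pattern applied throughout this patch: a single derived column that used to be appended by re-selecting all columns plus an aliased expression is now appended with withColumn. A minimal sketch of the equivalence, assuming imports as above; df, doubleIt, and the column names are hypothetical:

    import sqlContext.implicits._            // for the $"..." column syntax
    import org.apache.spark.sql.functions._

    val doubleIt = udf { (x: Double) => x * 2 }
    val before = df.select($"*", doubleIt($"x").as("x2"))  // old spelling
    val after  = df.withColumn("x2", doubleIt($"x"))       // new spelling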
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 124ab30f27..c5fc89f935 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -182,24 +182,22 @@ private[ml] object ClassificationModel {
if (map(model.rawPredictionCol) != "") {
// output raw prediction
val features2raw: FeaturesType => Vector = model.predictRaw
- tmpData = tmpData.select($"*",
- callUDF(features2raw, new VectorUDT,
- col(map(model.featuresCol))).as(map(model.rawPredictionCol)))
+ tmpData = tmpData.withColumn(map(model.rawPredictionCol),
+ callUDF(features2raw, new VectorUDT, col(map(model.featuresCol))))
numColsOutput += 1
if (map(model.predictionCol) != "") {
val raw2pred: Vector => Double = (rawPred) => {
rawPred.toArray.zipWithIndex.maxBy(_._1)._2
}
- tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType,
- col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
+ tmpData = tmpData.withColumn(map(model.predictionCol),
+ callUDF(raw2pred, DoubleType, col(map(model.rawPredictionCol))))
numColsOutput += 1
}
} else if (map(model.predictionCol) != "") {
// output prediction
val features2pred: FeaturesType => Double = model.predict
- tmpData = tmpData.select($"*",
- callUDF(features2pred, DoubleType,
- col(map(model.featuresCol))).as(map(model.predictionCol)))
+ tmpData = tmpData.withColumn(map(model.predictionCol),
+ callUDF(features2pred, DoubleType, col(map(model.featuresCol))))
numColsOutput += 1
}
(numColsOutput, tmpData)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index a9a5af5f0f..21f61d80dd 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -22,7 +22,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
@@ -130,44 +130,39 @@ class LogisticRegressionModel private[ml] (
var numColsOutput = 0
if (map(rawPredictionCol) != "") {
val features2raw: Vector => Vector = (features) => predictRaw(features)
- tmpData = tmpData.select($"*",
- callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol)))
+ tmpData = tmpData.withColumn(map(rawPredictionCol),
+ callUDF(features2raw, new VectorUDT, col(map(featuresCol))))
numColsOutput += 1
}
if (map(probabilityCol) != "") {
if (map(rawPredictionCol) != "") {
- val raw2prob: Vector => Vector = { (rawPreds: Vector) =>
+ val raw2prob = udf { (rawPreds: Vector) =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
- Vectors.dense(1.0 - prob1, prob1)
+ Vectors.dense(1.0 - prob1, prob1): Vector
}
- tmpData = tmpData.select($"*",
- callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol)))
+ tmpData = tmpData.withColumn(map(probabilityCol), raw2prob(col(map(rawPredictionCol))))
} else {
- val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features)
- tmpData = tmpData.select($"*",
- callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ val features2prob = udf { (features: Vector) => predictProbabilities(features) : Vector }
+ tmpData = tmpData.withColumn(map(probabilityCol), features2prob(col(map(featuresCol))))
}
numColsOutput += 1
}
if (map(predictionCol) != "") {
val t = map(threshold)
if (map(probabilityCol) != "") {
- val predict: Vector => Double = { probs: Vector =>
+ val predict = udf { probs: Vector =>
if (probs(1) > t) 1.0 else 0.0
}
- tmpData = tmpData.select($"*",
- callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol)))
+ tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(probabilityCol))))
} else if (map(rawPredictionCol) != "") {
- val predict: Vector => Double = { rawPreds: Vector =>
+ val predict = udf { rawPreds: Vector =>
val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
if (prob1 > t) 1.0 else 0.0
}
- tmpData = tmpData.select($"*",
- callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol)))
+ tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(rawPredictionCol))))
} else {
- val predict: Vector => Double = (features: Vector) => this.predict(features)
- tmpData = tmpData.select($"*",
- callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ val predict = udf { features: Vector => this.predict(features) }
+ tmpData = tmpData.withColumn(map(predictionCol), predict(col(map(featuresCol))))
}
numColsOutput += 1
}
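Note the two UDF spellings left in play above: callUDF(fn, dataType, col) passes the SQL DataType explicitly, while the udf { } helper derives it from the lambda's static types, which is presumably why the bodies are ascribed : Vector (pinning the inferred type to the UDT-registered Vector trait). A hedged sketch of the contrast; df and the column names are illustrative:

    import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
    import org.apache.spark.sql.functions._

    // Explicit result type, as in the callUDF spelling:
    val identityCol = callUDF((v: Vector) => v, new VectorUDT, col("features"))
    // Inferred result type; the ascription keeps it at the Vector trait:
    val scale  = udf { (v: Vector) => Vectors.dense(v.toArray.map(_ * 2)): Vector }
    val scaled = df.withColumn("scaled", scale(col("features")))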
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 38518785dc..bd8caac855 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, StructType}
@@ -122,8 +122,8 @@ private[spark] abstract class ProbabilisticClassificationModel[
val features2probs: FeaturesType => Vector = (features) => {
tmpModel.predictProbabilities(features)
}
- outputData.select($"*",
- callUDF(features2probs, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ outputData.withColumn(map(probabilityCol),
+ callUDF(features2probs, new VectorUDT, col(map(featuresCol))))
} else {
if (numColsOutput == 0) {
this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 7623ec59ae..ddbd648d64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{StructField, StructType}
/**
@@ -88,7 +88,7 @@ class StandardScalerModel private[ml] (
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val scale = udf((v: Vector) => { scaler.transform(v) } : Vector)
- dataset.select($"*", scale(col(map(inputCol))).as(map(outputCol)))
+ dataset.withColumn(map(outputCol), scale(col(map(inputCol))))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
index e416c1eb58..7daeff980f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
@@ -216,7 +216,7 @@ private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel
val pred: FeaturesType => Double = (features) => {
tmpModel.predict(features)
}
- dataset.select($"*", callUDF(pred, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ dataset.withColumn(map(predictionCol), callUDF(pred, DoubleType, col(map(featuresCol))))
} else {
this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
" since no output columns were set.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index aac487745f..8d70e4347c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -36,7 +36,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -170,8 +170,8 @@ class ALSModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
import dataset.sqlContext.implicits._
val map = this.paramMap ++ paramMap
- val users = userFactors.toDataFrame("id", "features")
- val items = itemFactors.toDataFrame("id", "features")
+ val users = userFactors.toDF("id", "features")
+ val items = itemFactors.toDF("id", "features")
// Register a UDF for DataFrame, and then
// create a new column named map(predictionCol) by running the predict UDF.
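A hedged sketch of the shape that comment describes, not the patch's exact code; the user/item column names and the dot-product udf are illustrative:

    val predict = udf { (userFeatures: Seq[Float], itemFeatures: Seq[Float]) =>
      if (userFeatures != null && itemFeatures != null) {
        userFeatures.zip(itemFeatures).map { case (u, i) => u * i }.sum
      } else Float.NaN
    }
    dataset
      .join(users, dataset("user") === users("id"), "left")
      .join(items, dataset("item") === items("id"), "left")
      .withColumn("prediction", predict(users("features"), items("features")))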
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f9142bc226..dd7a9469d5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -102,7 +102,7 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
// Create Parquet data.
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF
dataRDD.saveAsParquetFile(dataPath(path))
}
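With the blanket RDD-to-DataFrame implicit removed, an RDD of a case class no longer converts on its own where a DataFrame is expected; the .toDF call must be explicit, as above. A minimal sketch, where the Payload case class and the path are illustrative and saveAsParquetFile is the 1.3-era writer:

    import sqlContext.implicits._

    case class Payload(labels: Array[Double], pi: Array[Double])
    val dataRDD = sc.parallelize(Seq(Payload(Array(0.0, 1.0), Array(0.4, 0.6))), 1).toDF
    dataRDD.saveAsParquetFile("/tmp/naive-bayes/data")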
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
index 1d118963b4..0a358f2e4f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/impl/GLMClassificationModel.scala
@@ -62,7 +62,7 @@ private[classification] object GLMClassificationModel {
// Create Parquet data.
val data = Data(weights, intercept, threshold)
- sc.parallelize(Seq(data), 1).saveAsParquetFile(Loader.dataPath(path))
+ sc.parallelize(Seq(data), 1).toDF.saveAsParquetFile(Loader.dataPath(path))
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index a3a3b5d418..c399496568 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -187,8 +187,8 @@ object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
val metadata = compact(render(
("class" -> thisClassName) ~ ("version" -> thisFormatVersion) ~ ("rank" -> model.rank)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(metadataPath(path))
- model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
- model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
+ model.userFeatures.toDF("id", "features").saveAsParquetFile(userPath(path))
+ model.productFeatures.toDF("id", "features").saveAsParquetFile(productPath(path))
}
def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
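The toDF rename also carries the column-naming overload used here, so an RDD of (id, features) pairs picks up a named schema directly. A minimal sketch; factors and the output path are hypothetical:

    import sqlContext.implicits._

    val factors = sc.parallelize(Seq((1, Array(0.1, 0.2))))
    factors.toDF("id", "features").saveAsParquetFile("/tmp/model/userFactors")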
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
index f75de6f637..7b27aaa322 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/impl/GLMRegressionModel.scala
@@ -58,7 +58,7 @@ private[regression] object GLMRegressionModel {
// Create Parquet data.
val data = Data(weights, intercept)
- val dataRDD: DataFrame = sc.parallelize(Seq(data), 1)
+ val dataRDD: DataFrame = sc.parallelize(Seq(data), 1).toDF
// TODO: repartition with 1 partition after SPARK-5532 gets fixed
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
index 373192a20c..5dac62b0c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/DecisionTreeModel.scala
@@ -197,7 +197,7 @@ object DecisionTreeModel extends Loader[DecisionTreeModel] {
val nodes = model.topNode.subtreeIterator.toSeq
val dataRDD: DataFrame = sc.parallelize(nodes)
.map(NodeData.apply(0, _))
- .toDataFrame
+ .toDF
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
index dbd69dca60..e507f247cc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/treeEnsembleModels.scala
@@ -289,7 +289,7 @@ private[tree] object TreeEnsembleModel {
// Create Parquet data.
val dataRDD = sc.parallelize(model.trees.zipWithIndex).flatMap { case (tree, treeId) =>
tree.topNode.subtreeIterator.toSeq.map(node => NodeData(treeId, node))
- }.toDataFrame
+ }.toDF
dataRDD.saveAsParquetFile(Loader.dataPath(path))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index cb7d57de35..b118a8dcf1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -358,8 +358,8 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
.setNumUserBlocks(numUserBlocks)
.setNumItemBlocks(numItemBlocks)
val alpha = als.getAlpha
- val model = als.fit(training)
- val predictions = model.transform(test)
+ val model = als.fit(training.toDF)
+ val predictions = model.transform(test.toDF)
.select("rating", "prediction")
.map { case Row(rating: Float, prediction: Float) =>
(rating.toDouble, prediction.toDouble)