author     Reynold Xin <rxin@databricks.com>    2015-01-27 16:08:24 -0800
committer  Reynold Xin <rxin@databricks.com>    2015-01-27 16:08:24 -0800
commit     119f45d61d7b48d376cca05e1b4f0c7fcf65bfa8
tree       714df6362313e93bee0e9dba2f84b3ba1697e555
parent     b1b35ca2e440df40b253bf967bb93705d355c1c0
[SPARK-5097][SQL] DataFrame
This pull request redesigns the existing Spark SQL DSL, which already provides data frame-like functionality.

TODOs: With the exception of Python support, the other tasks can be done in separate, follow-up PRs.
- [ ] Audit of the API
- [ ] Documentation
- [ ] More test cases to cover the new API
- [x] Python support
- [ ] Type alias SchemaRDD

Author: Reynold Xin <rxin@databricks.com>
Author: Davies Liu <davies@databricks.com>

Closes #4173 from rxin/df1 and squashes the following commits:

0a1a73b [Reynold Xin] Merge branch 'df1' of github.com:rxin/spark into df1
23b4427 [Reynold Xin] Mima.
828f70d [Reynold Xin] Merge pull request #7 from davies/df
257b9e6 [Davies Liu] add repartition
6bf2b73 [Davies Liu] fix collect with UDT and tests
e971078 [Reynold Xin] Missing quotes.
b9306b4 [Reynold Xin] Remove removeColumn/updateColumn for now.
a728bf2 [Reynold Xin] Example rename.
e8aa3d3 [Reynold Xin] groupby -> groupBy.
9662c9e [Davies Liu] improve DataFrame Python API
4ae51ea [Davies Liu] python API for dataframe
1e5e454 [Reynold Xin] Fixed a bug with symbol conversion.
2ca74db [Reynold Xin] Couple minor fixes.
ea98ea1 [Reynold Xin] Documentation & literal expressions.
2b22684 [Reynold Xin] Got rid of IntelliJ problems.
02bbfbc [Reynold Xin] Tightening imports.
ffbce66 [Reynold Xin] Fixed compilation error.
59b6d8b [Reynold Xin] Style violation.
b85edfb [Reynold Xin] ALS.
8c37f0a [Reynold Xin] Made MLlib and examples compile
6d53134 [Reynold Xin] Hive module.
d35efd5 [Reynold Xin] Fixed compilation error.
ce4a5d2 [Reynold Xin] Fixed test cases in SQL except ParquetIOSuite.
66d5ef1 [Reynold Xin] SQLContext minor patch.
c9bcdc0 [Reynold Xin] Checkpoint: SQL module compiles!
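Note: the mllib changes in this diff follow one pattern: SchemaRDD becomes DataFrame, and the Symbol-based Catalyst DSL ('col, Star(None), ScalaUdf, catalyst joins/casts) becomes string/Column-based selection with UDF helpers. The following is a minimal sketch of the new style, written against the DataFrame API as it stabilized in released Spark (1.3+, org.apache.spark.sql.functions) rather than the interim org.apache.spark.sql.dsl._ package imported in this patch; the object, method, and column names are illustrative only.

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.functions.{col, udf}

    object DataFrameStyleSketch {
      // Old Catalyst DSL:  dataset.select('label, 'prediction)    // Symbol-based columns
      // New DataFrame API: dataset.select("label", "prediction")  // String/Column-based columns
      //
      // Appends a score column by applying a Scala function to an existing column,
      // replacing the old select(Star(None), ScalaUdf(...) as 'score) pattern.
      def addScore(dataset: DataFrame, marginCol: String, scoreCol: String): DataFrame = {
        val sigmoid = udf((margin: Double) => 1.0 / (1.0 + math.exp(-margin)))
        dataset.select(col("*"), sigmoid(col(marginCol)).as(scoreCol))
      }
    }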
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Estimator.scala                                 |  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala                                 |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala                                  |  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala                               | 17
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala         | 14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala  |  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala                    | 15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala                        | 37
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala                     |  8
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java                           |  6
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java  |  8
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java              |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala                             | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala    | 16
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala                   |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala                |  4
16 files changed, 77 insertions(+), 95 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index 77d230eb4a..bc3defe968 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -21,7 +21,7 @@ import scala.annotation.varargs
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
/**
* :: AlphaComponent ::
@@ -38,7 +38,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* @return fitted model
*/
@varargs
- def fit(dataset: SchemaRDD, paramPairs: ParamPair[_]*): M = {
+ def fit(dataset: DataFrame, paramPairs: ParamPair[_]*): M = {
val map = new ParamMap().put(paramPairs: _*)
fit(dataset, map)
}
@@ -50,7 +50,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* @param paramMap parameter map
* @return fitted model
*/
- def fit(dataset: SchemaRDD, paramMap: ParamMap): M
+ def fit(dataset: DataFrame, paramMap: ParamMap): M
/**
* Fits multiple models to the input data with multiple sets of parameters.
@@ -61,7 +61,7 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* @param paramMaps an array of parameter maps
* @return fitted models, matching the input parameter maps
*/
- def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
+ def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
index db563dd550..d2ca2e6871 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Evaluator.scala
@@ -19,7 +19,7 @@ package org.apache.spark.ml
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
/**
* :: AlphaComponent ::
@@ -35,5 +35,5 @@ abstract class Evaluator extends Identifiable {
* @param paramMap parameter map that specifies the input columns and output metrics
* @return metric
*/
- def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double
+ def evaluate(dataset: DataFrame, paramMap: ParamMap): Double
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index ad6fed178f..fe39cd1bc0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -22,7 +22,7 @@ import scala.collection.mutable.ListBuffer
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{Param, ParamMap}
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
/**
@@ -88,7 +88,7 @@ class Pipeline extends Estimator[PipelineModel] {
* @param paramMap parameter map
* @return fitted pipeline
*/
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): PipelineModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
transformSchema(dataset.schema, paramMap, logging = true)
val map = this.paramMap ++ paramMap
val theStages = map(stages)
@@ -162,7 +162,7 @@ class PipelineModel private[ml] (
}
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
val map = (fittingParamMap ++ this.paramMap) ++ paramMap
transformSchema(dataset.schema, map, logging = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index af56f9c435..b233bff083 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -22,9 +22,9 @@ import scala.annotation.varargs
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.expressions.ScalaUdf
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql._
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.types._
/**
@@ -41,7 +41,7 @@ abstract class Transformer extends PipelineStage with Params {
* @return transformed dataset
*/
@varargs
- def transform(dataset: SchemaRDD, paramPairs: ParamPair[_]*): SchemaRDD = {
+ def transform(dataset: DataFrame, paramPairs: ParamPair[_]*): DataFrame = {
val map = new ParamMap()
paramPairs.foreach(map.put(_))
transform(dataset, map)
@@ -53,7 +53,7 @@ abstract class Transformer extends PipelineStage with Params {
* @param paramMap additional parameters, overwrite embedded params
* @return transformed dataset
*/
- def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
+ def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame
}
/**
@@ -95,11 +95,10 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, O
StructType(outputFields)
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
- dataset.select(Star(None), udf as map(outputCol))
+ dataset.select($"*", callUDF(
+ this.createTransformFunc(map), outputDataType, Column(map(inputCol))).as(map(outputCol)))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 8c570812f8..eeb6301c3f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,7 +24,7 @@ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.storage.StorageLevel
@@ -87,11 +87,10 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
def setScoreCol(value: String): this.type = set(scoreCol, value)
def setPredictionCol(value: String): this.type = set(predictionCol, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): LogisticRegressionModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val instances = dataset.select(map(labelCol).attr, map(featuresCol).attr)
+ val instances = dataset.select(map(labelCol), map(featuresCol))
.map { case Row(label: Double, features: Vector) =>
LabeledPoint(label, features)
}.persist(StorageLevel.MEMORY_AND_DISK)
@@ -131,9 +130,8 @@ class LogisticRegressionModel private[ml] (
validateAndTransformSchema(schema, paramMap, fitting = false)
}
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val score: Vector => Double = (v) => {
val margin = BLAS.dot(v, weights)
@@ -143,7 +141,7 @@ class LogisticRegressionModel private[ml] (
val predict: Double => Double = (score) => {
if (score > t) 1.0 else 0.0
}
- dataset.select(Star(None), score.call(map(featuresCol).attr) as map(scoreCol))
- .select(Star(None), predict.call(map(scoreCol).attr) as map(predictionCol))
+ dataset.select($"*", callUDF(score, Column(map(featuresCol))).as(map(scoreCol)))
+ .select($"*", callUDF(predict, Column(map(scoreCol))).as(map(predictionCol)))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 12473cb2b5..1979ab9eb6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -21,7 +21,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
-import org.apache.spark.sql.{Row, SchemaRDD}
+import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType
/**
@@ -41,7 +41,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
def setScoreCol(value: String): this.type = set(scoreCol, value)
def setLabelCol(value: String): this.type = set(labelCol, value)
- override def evaluate(dataset: SchemaRDD, paramMap: ParamMap): Double = {
+ override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
val map = this.paramMap ++ paramMap
val schema = dataset.schema
@@ -52,8 +52,7 @@ class BinaryClassificationEvaluator extends Evaluator with Params
require(labelType == DoubleType,
s"Label column ${map(labelCol)} must be double type but found $labelType")
- import dataset.sqlContext._
- val scoreAndLabels = dataset.select(map(scoreCol).attr, map(labelCol).attr)
+ val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol))
.map { case Row(score: Double, label: Double) =>
(score, label)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 72825f6e02..e7bdb070c8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -23,7 +23,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.mllib.feature
import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql._
-import org.apache.spark.sql.catalyst.analysis.Star
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.catalyst.dsl._
import org.apache.spark.sql.types.{StructField, StructType}
@@ -43,14 +43,10 @@ class StandardScaler extends Estimator[StandardScalerModel] with StandardScalerP
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): StandardScalerModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): StandardScalerModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val input = dataset.select(map(inputCol).attr)
- .map { case Row(v: Vector) =>
- v
- }
+ val input = dataset.select(map(inputCol)).map { case Row(v: Vector) => v }
val scaler = new feature.StandardScaler().fit(input)
val model = new StandardScalerModel(this, map, scaler)
Params.inheritValues(map, this, model)
@@ -83,14 +79,13 @@ class StandardScalerModel private[ml] (
def setInputCol(value: String): this.type = set(inputCol, value)
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
transformSchema(dataset.schema, paramMap, logging = true)
- import dataset.sqlContext._
val map = this.paramMap ++ paramMap
val scale: (Vector) => Vector = (v) => {
scaler.transform(v)
}
- dataset.select(Star(None), scale.call(map(inputCol).attr) as map(outputCol))
+ dataset.select($"*", callUDF(scale, Column(map(inputCol))).as(map(outputCol)))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 2d89e76a4c..f6437c7fbc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -29,10 +29,8 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.catalyst.dsl._
-import org.apache.spark.sql.catalyst.expressions.Cast
-import org.apache.spark.sql.catalyst.plans.LeftOuter
+import org.apache.spark.sql.{Column, DataFrame}
+import org.apache.spark.sql.dsl._
import org.apache.spark.sql.types.{DoubleType, FloatType, IntegerType, StructField, StructType}
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.{OpenHashMap, OpenHashSet, SortDataFormat, Sorter}
@@ -112,7 +110,7 @@ class ALSModel private[ml] (
def setPredictionCol(value: String): this.type = set(predictionCol, value)
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
import dataset.sqlContext._
import org.apache.spark.ml.recommendation.ALSModel.Factor
val map = this.paramMap ++ paramMap
@@ -120,13 +118,13 @@ class ALSModel private[ml] (
val instanceTable = s"instance_$uid"
val userTable = s"user_$uid"
val itemTable = s"item_$uid"
- val instances = dataset.as(Symbol(instanceTable))
+ val instances = dataset.as(instanceTable)
val users = userFactors.map { case (id, features) =>
Factor(id, features)
- }.as(Symbol(userTable))
+ }.as(userTable)
val items = itemFactors.map { case (id, features) =>
Factor(id, features)
- }.as(Symbol(itemTable))
+ }.as(itemTable)
val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, itemFeatures) => {
if (userFeatures != null && itemFeatures != null) {
blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -135,12 +133,12 @@ class ALSModel private[ml] (
}
}
val inputColumns = dataset.schema.fieldNames
- val prediction =
- predict.call(s"$userTable.features".attr, s"$itemTable.features".attr) as map(predictionCol)
- val outputColumns = inputColumns.map(f => s"$instanceTable.$f".attr as f) :+ prediction
+ val prediction = callUDF(predict, $"$userTable.features", $"$itemTable.features")
+ .as(map(predictionCol))
+ val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ prediction
instances
- .join(users, LeftOuter, Some(map(userCol).attr === s"$userTable.id".attr))
- .join(items, LeftOuter, Some(map(itemCol).attr === s"$itemTable.id".attr))
+ .join(users, Column(map(userCol)) === $"$userTable.id", "left")
+ .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
.select(outputColumns: _*)
}
@@ -209,14 +207,13 @@ class ALS extends Estimator[ALSModel] with ALSParams {
setMaxIter(20)
setRegParam(1.0)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): ALSModel = {
- import dataset.sqlContext._
+ override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
val map = this.paramMap ++ paramMap
- val ratings =
- dataset.select(map(userCol).attr, map(itemCol).attr, Cast(map(ratingCol).attr, FloatType))
- .map { row =>
- new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
- }
+ val ratings = dataset
+ .select(Column(map(userCol)), Column(map(itemCol)), Column(map(ratingCol)).cast(FloatType))
+ .map { row =>
+ new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
+ }
val (userFactors, itemFactors) = ALS.train(ratings, rank = map(rank),
numUserBlocks = map(numUserBlocks), numItemBlocks = map(numItemBlocks),
maxIter = map(maxIter), regParam = map(regParam), implicitPrefs = map(implicitPrefs),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 08fe991764..5d51c51346 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -24,7 +24,7 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml._
import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
/**
@@ -64,7 +64,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
def setEvaluator(value: Evaluator): this.type = set(evaluator, value)
def setNumFolds(value: Int): this.type = set(numFolds, value)
- override def fit(dataset: SchemaRDD, paramMap: ParamMap): CrossValidatorModel = {
+ override def fit(dataset: DataFrame, paramMap: ParamMap): CrossValidatorModel = {
val map = this.paramMap ++ paramMap
val schema = dataset.schema
transformSchema(dataset.schema, paramMap, logging = true)
@@ -74,7 +74,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
val epm = map(estimatorParamMaps)
val numModels = epm.size
val metrics = new Array[Double](epm.size)
- val splits = MLUtils.kFold(dataset, map(numFolds), 0)
+ val splits = MLUtils.kFold(dataset.rdd, map(numFolds), 0)
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.applySchema(training, schema).cache()
val validationDataset = sqlCtx.applySchema(validation, schema).cache()
@@ -117,7 +117,7 @@ class CrossValidatorModel private[ml] (
val bestModel: Model[_])
extends Model[CrossValidatorModel] with CrossValidatorParams {
- override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
bestModel.transform(dataset, paramMap)
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 47f1f46c6c..56a9dbdd58 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -26,7 +26,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -37,7 +37,7 @@ public class JavaPipelineSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient SchemaRDD dataset;
+ private transient DataFrame dataset;
@Before
public void setUp() {
@@ -65,7 +65,7 @@ public class JavaPipelineSuite {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 2eba83335b..f4ba23c445 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -26,7 +26,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -34,7 +34,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient SchemaRDD dataset;
+ private transient DataFrame dataset;
@Before
public void setUp() {
@@ -55,7 +55,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
predictions.collectAsList();
}
@@ -67,7 +67,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
.registerTempTable("prediction");
- SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
predictions.collectAsList();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index a9f1c4a2c3..074b58c07d 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -30,7 +30,7 @@ import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -38,7 +38,7 @@ public class JavaCrossValidatorSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient SchemaRDD dataset;
+ private transient DataFrame dataset;
@Before
public void setUp() {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 4515084bc7..2f175fb117 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -23,7 +23,7 @@ import org.scalatest.FunSuite
import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.SchemaRDD
+import org.apache.spark.sql.DataFrame
class PipelineSuite extends FunSuite {
@@ -36,11 +36,11 @@ class PipelineSuite extends FunSuite {
val estimator2 = mock[Estimator[MyModel]]
val model2 = mock[MyModel]
val transformer3 = mock[Transformer]
- val dataset0 = mock[SchemaRDD]
- val dataset1 = mock[SchemaRDD]
- val dataset2 = mock[SchemaRDD]
- val dataset3 = mock[SchemaRDD]
- val dataset4 = mock[SchemaRDD]
+ val dataset0 = mock[DataFrame]
+ val dataset1 = mock[DataFrame]
+ val dataset2 = mock[DataFrame]
+ val dataset3 = mock[DataFrame]
+ val dataset4 = mock[DataFrame]
when(estimator0.fit(meq(dataset0), any[ParamMap]())).thenReturn(model0)
when(model0.transform(meq(dataset0), any[ParamMap]())).thenReturn(dataset1)
@@ -74,7 +74,7 @@ class PipelineSuite extends FunSuite {
val estimator = mock[Estimator[MyModel]]
val pipeline = new Pipeline()
.setStages(Array(estimator, estimator))
- val dataset = mock[SchemaRDD]
+ val dataset = mock[DataFrame]
intercept[IllegalArgumentException] {
pipeline.fit(dataset)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index e8030fef55..1912afce93 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -21,12 +21,12 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{SQLContext, DataFrame}
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
@transient var sqlContext: SQLContext = _
- @transient var dataset: SchemaRDD = _
+ @transient var dataset: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -36,34 +36,28 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
}
test("logistic regression") {
- val sqlContext = this.sqlContext
- import sqlContext._
val lr = new LogisticRegression
val model = lr.fit(dataset)
model.transform(dataset)
- .select('label, 'prediction)
+ .select("label", "prediction")
.collect()
}
test("logistic regression with setters") {
- val sqlContext = this.sqlContext
- import sqlContext._
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
val model = lr.fit(dataset)
model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
- .select('label, 'score, 'prediction)
+ .select("label", "score", "prediction")
.collect()
}
test("logistic regression fit and transform with varargs") {
- val sqlContext = this.sqlContext
- import sqlContext._
val lr = new LogisticRegression
val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
- .select('label, 'probability, 'prediction)
+ .select("label", "probability", "prediction")
.collect()
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index cdd4db1b5b..58289acdbc 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -350,7 +350,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
numItemBlocks: Int = 3,
targetRMSE: Double = 0.05): Unit = {
val sqlContext = this.sqlContext
- import sqlContext.{createSchemaRDD, symbolToUnresolvedAttribute}
+ import sqlContext.createSchemaRDD
val als = new ALS()
.setRank(rank)
.setRegParam(regParam)
@@ -360,7 +360,7 @@ class ALSSuite extends FunSuite with MLlibTestSparkContext with Logging {
val alpha = als.getAlpha
val model = als.fit(training)
val predictions = model.transform(test)
- .select('rating, 'prediction)
+ .select("rating", "prediction")
.map { case Row(rating: Float, prediction: Float) =>
(rating.toDouble, prediction.toDouble)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 41cc13da4d..74104fa7a6 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -23,11 +23,11 @@ import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, SchemaRDD}
+import org.apache.spark.sql.{SQLContext, DataFrame}
class CrossValidatorSuite extends FunSuite with MLlibTestSparkContext {
- @transient var dataset: SchemaRDD = _
+ @transient var dataset: DataFrame = _
override def beforeAll(): Unit = {
super.beforeAll()