diff options
author | Cheng Lian <lian@databricks.com> | 2016-03-10 17:00:17 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2016-03-10 17:00:17 -0800 |
commit | 1d542785b9949e7f92025e6754973a779cc37c52 (patch) | |
tree | ceda7492e40c9d9a9231a5011c91e30bf0b1f390 /mllib | |
parent | 27fe6bacc532184ef6e8a2a24cd07f2c9188004e (diff) | |
download | spark-1d542785b9949e7f92025e6754973a779cc37c52.tar.gz spark-1d542785b9949e7f92025e6754973a779cc37c52.tar.bz2 spark-1d542785b9949e7f92025e6754973a779cc37c52.zip |
[SPARK-13244][SQL] Migrates DataFrame to Dataset
## What changes were proposed in this pull request?
This PR unifies DataFrame and Dataset by migrating existing DataFrame operations to Dataset and making `DataFrame` a type alias of `Dataset[Row]`.
Most Scala code changes are source compatible, but the Java API is broken because Java knows nothing about Scala type aliases (Java callers mostly need to replace `DataFrame` with `Dataset<Row>`).
There are several noticeable API changes related to those returning arrays:
1. `collect`/`take`
- Old APIs in class `DataFrame`:
```scala
def collect(): Array[Row]
def take(n: Int): Array[Row]
```
- New APIs in class `Dataset[T]`:
```scala
def collect(): Array[T]
def take(n: Int): Array[T]
def collectRows(): Array[Row]
def takeRows(n: Int): Array[Row]
```
Two specialized methods, `collectRows` and `takeRows`, are added because Java doesn't support returning generic arrays. Thus, for example, `Dataset[T].collect(): Array[T]` actually returns `Object` instead of `Array<T>` from the Java side.
Normally, Java users may fall back to `collectAsList` and `takeAsList`. The two new specialized versions are added to avoid performance regression in ML related code (but maybe I'm wrong and they are not necessary here).
1. `randomSplit`
- Old APIs in class `DataFrame`:
```scala
def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame]
def randomSplit(weights: Array[Double]): Array[DataFrame]
```
- New APIs in class `Dataset[T]`:
```scala
def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
def randomSplit(weights: Array[Double]): Array[Dataset[T]]
```
This has the same problem as above, but it hasn't been addressed for the Java API yet. We can probably add `randomSplitAsList` to fix this one.
1. `groupBy`
Some original `DataFrame.groupBy` methods have signatures that conflict with the original `Dataset.groupBy` methods. To distinguish the two, the typed `Dataset.groupBy` methods are renamed to `groupByKey`.
Other noticeable changes:
1. Dataset always does eager analysis now
We used to support disabling DataFrame eager analysis to help report partially analyzed, malformed logical plans on analysis failure. However, Dataset encoders require eager analysis during Dataset construction. To preserve the error reporting feature, `AnalysisException` now takes an extra `Option[LogicalPlan]` argument to hold the partially analyzed plan, so that we can check the plan tree when reporting test failures. This plan is passed by `QueryExecution.assertAnalyzed`.
## How was this patch tested?
Existing tests do the work.
## TODO
- [ ] Fix all tests
- [ ] Re-enable MiMA check
- [ ] Update ScalaDoc (`since`, `group`, and example code)
Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <liancheng@users.noreply.github.com>
Closes #11443 from liancheng/ds-to-df.
Diffstat (limited to 'mllib')
29 files changed, 118 insertions, 99 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 0a8c9e5954..60a4a1d2ea 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.ml; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -26,7 +28,6 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -37,7 +38,7 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; @Before public void setUp() { @@ -65,7 +66,7 @@ public class JavaPipelineSuite { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 40b9c35adc..0d923dfeff 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -21,6 +21,8 @@ import java.io.Serializable; import java.util.HashMap; import java.util.Map; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -30,7 +32,6 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; public class JavaDecisionTreeClassifierSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. 
DecisionTreeClassifier dt = new DecisionTreeClassifier() diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java index 59b6fba7a9..f470f4ada6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaGBTClassifierSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaGBTClassifierSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. 
GBTClassifier rf = new GBTClassifier() diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index fd22eb6dca..536f0dc58f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -31,16 +31,16 @@ import org.apache.spark.api.java.JavaSparkContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; private double eps = 1e-5; @@ -67,7 +67,7 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults Assert.assertEquals(0.5, model.getThreshold(), eps); @@ -96,14 +96,14 @@ public class JavaLogisticRegressionSuite implements Serializable { // Modify model params, and check that the params worked. 
model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); - DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); + Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); for (Row r: predAllZero.collectAsList()) { Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); - DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; for (Row r: predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; @@ -129,8 +129,8 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(2, model.numClasses()); model.transform(dataset).registerTempTable("transformed"); - DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); - for (Row row: trans1.collect()) { + Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row: trans1.collectAsList()) { Vector raw = (Vector)row.get(0); Vector prob = (Vector)row.get(1); Assert.assertEquals(raw.size(), 2); @@ -140,8 +140,8 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } - DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); - for (Row row: trans2.collect()) { + Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); + for (Row row: trans2.collectAsList()) { double pred = row.getDouble(0); Vector prob = (Vector)row.get(1); double probOfPred = prob.apply((int)pred); diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java index ec6b4bf3c0..d499d363f1 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification; import java.io.Serializable; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -28,7 +29,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -52,7 +53,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable { @Test public void testMLPC() { - DataFrame dataFrame = sqlContext.createDataFrame( + Dataset<Row> dataFrame = sqlContext.createDataFrame( jsc.parallelize(Arrays.asList( new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), @@ -65,8 +66,8 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable { .setSeed(11L) .setMaxIter(100); MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); - DataFrame result = model.transform(dataFrame); - Row[] predictionAndLabels = result.select("prediction", "label").collect(); + Dataset<Row> result = model.transform(dataFrame); + List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList(); for (Row r: predictionAndLabels) { Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); } diff --git 
a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 07936eb79b..45101f286c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -55,8 +55,8 @@ public class JavaNaiveBayesSuite implements Serializable { jsc = null; } - public void validatePrediction(DataFrame predictionAndLabels) { - for (Row r : predictionAndLabels.collect()) { + public void validatePrediction(Dataset<Row> predictionAndLabels) { + for (Row r : predictionAndLabels.collectAsList()) { double prediction = r.getAs(0); double label = r.getAs(1); assertEquals(label, prediction, 1E-5); @@ -88,11 +88,11 @@ public class JavaNaiveBayesSuite implements Serializable { new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = jsql.createDataFrame(data, schema); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); - DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); + Dataset<Row> predictionAndLabels = model.transform(dataset).select("prediction", "label"); validatePrediction(predictionAndLabels); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index 
cbabafe1b5..d493a7fcec 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification; import java.io.Serializable; import java.util.List; +import org.apache.spark.sql.Row; import scala.collection.JavaConverters; import org.junit.After; @@ -31,14 +32,14 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.SQLContext; public class JavaOneVsRestSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; @Before @@ -75,7 +76,7 @@ public class JavaOneVsRestSuite implements Serializable { Assert.assertEquals(ova.getLabelCol() , "label"); Assert.assertEquals(ova.getPredictionCol() , "prediction"); OneVsRestModel ovaModel = ova.fit(dataset); - DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction"); + Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction"); predictions.collectAsList(); Assert.assertEquals(ovaModel.getLabelCol(), "label"); Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 5485fcbf01..9a63cef2a8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ 
b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -31,7 +31,8 @@ import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaRandomForestClassifierSuite implements Serializable { @@ -58,7 +59,7 @@ public class JavaRandomForestClassifierSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. 
RandomForestClassifier rf = new RandomForestClassifier() diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java index cc5a4ef4c2..a3fcdb54ee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -29,14 +29,15 @@ import static org.junit.Assert.assertTrue; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class JavaKMeansSuite implements Serializable { private transient int k = 5; private transient JavaSparkContext sc; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient SQLContext sql; @Before @@ -61,7 +62,7 @@ public class JavaKMeansSuite implements Serializable { Vector[] centers = model.clusterCenters(); assertEquals(k, centers.length); - DataFrame transformed = model.transform(dataset); + Dataset<Row> transformed = model.transform(dataset); List<String> columns = Arrays.asList(transformed.columns()); List<String> expectedColumns = Arrays.asList("features", "prediction"); for (String column: expectedColumns) { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index d707bdee99..77e3a489a9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -25,7 +26,7 @@ import org.junit.Before; import org.junit.Test; import 
org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -57,7 +58,7 @@ public class JavaBucketizerSuite { StructType schema = new StructType(new StructField[] { new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame( + Dataset<Row> dataset = jsql.createDataFrame( Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), @@ -70,7 +71,7 @@ public class JavaBucketizerSuite { .setOutputCol("result") .setSplits(splits); - Row[] result = bucketizer.transform(dataset).select("result").collect(); + List<Row> result = bucketizer.transform(dataset).select("result").collectAsList(); for (Row r : result) { double index = r.getDouble(0); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 63e5c93798..ed1ad4c3a3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; import org.junit.After; @@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -56,7 +57,7 @@ public class JavaDCTSuite { @Test public void javaCompatibilityTest() { double[] input = new double[] {1D, 2D, 3D, 4D}; - DataFrame dataset = jsql.createDataFrame( + Dataset<Row> 
dataset = jsql.createDataFrame( Arrays.asList(RowFactory.create(Vectors.dense(input))), new StructType(new StructField[]{ new StructField("vec", (new VectorUDT()), false, Metadata.empty()) @@ -69,8 +70,8 @@ public class JavaDCTSuite { .setInputCol("vec") .setOutputCol("resultVec"); - Row[] result = dct.transform(dataset).select("resultVec").collect(); - Vector resultVec = result[0].getAs("resultVec"); + List<Row> result = dct.transform(dataset).select("resultVec").collectAsList(); + Vector resultVec = result.get(0).getAs("resultVec"); Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 5932017f8f..6e2cc7e887 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -27,7 +27,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -65,21 +65,21 @@ public class JavaHashingTFSuite { new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceData = jsql.createDataFrame(data, schema); + Dataset<Row> sentenceData = jsql.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer() .setInputCol("sentence") .setOutputCol("words"); - DataFrame wordsData = tokenizer.transform(sentenceData); + Dataset<Row> wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; HashingTF hashingTF = new HashingTF() .setInputCol("words") .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); - DataFrame featurizedData = hashingTF.transform(wordsData); + Dataset<Row> featurizedData = 
hashingTF.transform(wordsData); IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); IDFModel idfModel = idf.fit(featurizedData); - DataFrame rescaledData = idfModel.transform(featurizedData); - for (Row r : rescaledData.select("features", "label").take(3)) { + Dataset<Row> rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").takeAsList(3)) { Vector features = r.getAs(0); Assert.assertEquals(features.size(), numFeatures); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index e17d549c50..5bbd9634b2 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -26,7 +26,8 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class JavaNormalizerSuite { @@ -53,17 +54,17 @@ public class JavaNormalizerSuite { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) )); - DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); + Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Normalizer normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normFeatures"); // Normalize each Vector using $L^2$ norm. - DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); + Dataset<Row> l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); l2NormData.count(); // Normalize each Vector using $L^\infty$ norm. 
- DataFrame lInfNormData = + Dataset<Row> lInfNormData = normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); lInfNormData.count(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index e8f329f9cf..1389d17e7e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -100,7 +100,7 @@ public class JavaPCASuite implements Serializable { } ); - DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); + Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); PCAModel pca = new PCA() .setInputCol("features") .setOutputCol("pca_features") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index e22d117032..6a8bb64801 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -29,7 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -77,11 +77,11 @@ public class 
JavaPolynomialExpansionSuite { new StructField("expected", new VectorUDT(), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = jsql.createDataFrame(data, schema); - Row[] pairs = polyExpansion.transform(dataset) + List<Row> pairs = polyExpansion.transform(dataset) .select("polyFeatures", "expected") - .collect(); + .collectAsList(); for (Row r : pairs) { double[] polyFeatures = ((Vector)r.get(0)).toArray(); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index ed74363f59..3f6fc333e4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -26,7 +26,8 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class JavaStandardScalerSuite { @@ -53,7 +54,7 @@ public class JavaStandardScalerSuite { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) ); - DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), VectorIndexerSuite.FeatureData.class); StandardScaler scaler = new StandardScaler() .setInputCol("features") @@ -65,7 +66,7 @@ public class JavaStandardScalerSuite { StandardScalerModel scalerModel = scaler.fit(dataFrame); // Normalize each feature to have unit standard deviation. 
- DataFrame scaledData = scalerModel.transform(dataFrame); + Dataset<Row> scaledData = scalerModel.transform(dataFrame); scaledData.count(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java index 139d1d005a..5812037dee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -25,7 +25,7 @@ import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -65,7 +65,7 @@ public class JavaStopWordsRemoverSuite { StructType schema = new StructType(new StructField[] { new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = jsql.createDataFrame(data, schema); remover.transform(dataset).collect(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 153a08a4cd..431779cd2e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -26,7 +26,7 @@ import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -58,16 +58,16 @@ public class JavaStringIndexerSuite { }); List<Row> data = Arrays.asList( cr(0, "a"), cr(1, "b"), 
cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); - DataFrame dataset = sqlContext.createDataFrame(data, schema); + Dataset<Row> dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex"); - DataFrame output = indexer.fit(dataset).transform(dataset); + Dataset<Row> output = indexer.fit(dataset).transform(dataset); - Assert.assertArrayEquals( - new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) }, - output.orderBy("id").select("id", "labelIndex").collect()); + Assert.assertEquals( + Arrays.asList(cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0)), + output.orderBy("id").select("id", "labelIndex").collectAsList()); } /** An alias for RowFactory.create. */ diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index c407d98f1b..83d16cbd0e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -26,7 +27,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -61,11 +62,11 @@ public class JavaTokenizerSuite { new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), new TokenizerTestData("Te,st. 
punct", new String[] {"Te,st.", "punct"}) )); - DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); - Row[] pairs = myRegExTokenizer.transform(dataset) + List<Row> pairs = myRegExTokenizer.transform(dataset) .select("tokens", "wantedTokens") - .collect(); + .collectAsList(); for (Row r : pairs) { Assert.assertEquals(r.get(0), r.get(1)); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java index f8ba84ef77..e45e198043 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -64,11 +64,11 @@ public class JavaVectorAssemblerSuite { Row row = RowFactory.create( 0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); - DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); + Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); VectorAssembler assembler = new VectorAssembler() .setInputCols(new String[] {"x", "y", "z", "n"}) .setOutputCol("features"); - DataFrame output = assembler.transform(dataset); + Dataset<Row> output = assembler.transform(dataset); Assert.assertEquals( Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), output.select("features").first().<Vector>getAs(0)); diff --git 
a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index bfcca62fa1..fec6cac8be 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -30,7 +30,8 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -57,7 +58,7 @@ public class JavaVectorIndexerSuite implements Serializable { new FeatureData(Vectors.dense(1.0, 4.0)) ); SQLContext sqlContext = new SQLContext(sc); - DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -66,6 +67,6 @@ public class JavaVectorIndexerSuite implements Serializable { Assert.assertEquals(model.numFeatures(), 2); Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps(); Assert.assertEquals(categoryMaps.size(), 1); - DataFrame indexedData = model.transform(data); + Dataset<Row> indexedData = model.transform(data); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index 786c11c412..b87605ebfd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.ml.attribute.AttributeGroup; import 
org.apache.spark.ml.attribute.NumericAttribute; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -68,16 +68,17 @@ public class JavaVectorSlicerSuite { RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) ); - DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); + Dataset<Row> dataset = + jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); - DataFrame output = vectorSlicer.transform(dataset); + Dataset<Row> output = vectorSlicer.transform(dataset); - for (Row r : output.select("userFeatures", "features").take(2)) { + for (Row r : output.select("userFeatures", "features").takeRows(2)) { Vector features = r.getAs(1); Assert.assertEquals(features.size(), 2); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index b292b1b06d..7517b70cc9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -26,7 +26,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -53,7 +53,7 @@ public class JavaWord2VecSuite { StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, 
Metadata.empty()) }); - DataFrame documentDF = sqlContext.createDataFrame( + Dataset<Row> documentDF = sqlContext.createDataFrame( Arrays.asList( RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), @@ -66,9 +66,9 @@ public class JavaWord2VecSuite { .setVectorSize(3) .setMinCount(0); Word2VecModel model = word2Vec.fit(documentDF); - DataFrame result = model.transform(documentDF); + Dataset<Row> result = model.transform(documentDF); - for (Row r: result.select("result").collect()) { + for (Row r: result.select("result").collectAsList()) { double[] polyFeatures = ((Vector)r.get(0)).toArray(); Assert.assertEquals(polyFeatures.length, 3); } diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index d5c9d120c5..a1575300a8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaDecisionTreeRegressorSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, 
categoricalFeatures, 0); // This tests setters. Training with various options is tested in Scala. DecisionTreeRegressor dt = new DecisionTreeRegressor() diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java index 38d15dc2b7..9477e8d2bf 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaGBTRegressorSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaGBTRegressorSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); GBTRegressor rf = new GBTRegressor() .setMaxDepth(2) diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 4fb0b0d109..9f817515eb 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -28,7 +28,8 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import 
org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite .generateLogisticInputAsList; @@ -38,7 +39,7 @@ public class JavaLinearRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; @Before @@ -64,7 +65,7 @@ public class JavaLinearRegressionSuite implements Serializable { assertEquals("auto", lr.getSolver()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); + Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults assertEquals("features", model.getFeaturesCol()); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index 31be8880c2..a90535d11a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -31,7 +31,8 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaRandomForestRegressorSuite implements Serializable { @@ -58,7 +59,7 @@ public class JavaRandomForestRegressorSuite implements Serializable { 
JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); // This tests setters. Training with various options is tested in Scala. RandomForestRegressor rf = new RandomForestRegressor() diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 2976b38e45..b8ddf907d0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -31,7 +31,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.util.Utils; @@ -68,7 +68,7 @@ public class JavaLibSVMRelationSuite { @Test public void verifyLibSVMDF() { - DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 08eeca53f0..24b0097454 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ 
b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,7 +39,7 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; @Before public void setUp() { |