author     Cheng Lian <lian@databricks.com>    2016-03-10 17:00:17 -0800
committer  Yin Huai <yhuai@databricks.com>     2016-03-10 17:00:17 -0800
commit     1d542785b9949e7f92025e6754973a779cc37c52 (patch)
tree       ceda7492e40c9d9a9231a5011c91e30bf0b1f390 /mllib
parent     27fe6bacc532184ef6e8a2a24cd07f2c9188004e (diff)
[SPARK-13244][SQL] Migrates DataFrame to Dataset
## What changes were proposed in this pull request?

This PR unifies DataFrame and Dataset by migrating existing DataFrame operations to Dataset and making `DataFrame` a type alias of `Dataset[Row]`.

Most Scala code changes are source compatible, but the Java API is broken, since Java knows nothing about Scala type aliases (the fix is mostly replacing `DataFrame` with `Dataset<Row>`).

There are several noticeable API changes related to methods returning arrays:

1. `collect`/`take`

   - Old APIs in class `DataFrame`:

     ```scala
     def collect(): Array[Row]
     def take(n: Int): Array[Row]
     ```

   - New APIs in class `Dataset[T]`:

     ```scala
     def collect(): Array[T]
     def take(n: Int): Array[T]

     def collectRows(): Array[Row]
     def takeRows(n: Int): Array[Row]
     ```

   The two specialized methods `collectRows` and `takeRows` are added because Java doesn't support returning generic arrays: from the Java side, `Dataset.collect(): Array[T]` actually returns `Object` instead of `Array<T>`. Normally, Java users may fall back to `collectAsList` and `takeAsList`. The two specialized versions are added to avoid a performance regression in ML-related code (but maybe I'm wrong and they are not necessary here).

2. `randomSplit`

   - Old APIs in class `DataFrame`:

     ```scala
     def randomSplit(weights: Array[Double], seed: Long): Array[DataFrame]
     def randomSplit(weights: Array[Double]): Array[DataFrame]
     ```

   - New APIs in class `Dataset[T]`:

     ```scala
     def randomSplit(weights: Array[Double], seed: Long): Array[Dataset[T]]
     def randomSplit(weights: Array[Double]): Array[Dataset[T]]
     ```

   This has the same problem as above, but it hasn't been addressed for the Java API yet. We can probably add `randomSplitAsList` to fix it.

3. `groupBy`

   Some of the original `DataFrame.groupBy` methods have signatures that conflict with the original `Dataset.groupBy` methods. To distinguish the two, the typed `Dataset.groupBy` methods are renamed to `groupByKey`.

Other noticeable changes:

1. Datasets always perform eager analysis now.

   We used to support disabling DataFrame eager analysis to help report a partially analyzed malformed logical plan on analysis failure. However, Dataset encoders require eager analysis during Dataset construction. To preserve the error-reporting feature, `AnalysisException` now takes an extra `Option[LogicalPlan]` argument to hold the partially analyzed plan, so that we can check the plan tree when reporting test failures. This plan is passed in by `QueryExecution.assertAnalyzed`.

## How was this patch tested?

Existing tests do the work.

## TODO

- [ ] Fix all tests
- [ ] Re-enable MiMA check
- [ ] Update ScalaDoc (`since`, `group`, and example code)

Author: Cheng Lian <lian@databricks.com>
Author: Yin Huai <yhuai@databricks.com>
Author: Wenchen Fan <wenchen@databricks.com>
Author: Cheng Lian <liancheng@users.noreply.github.com>

Closes #11443 from liancheng/ds-to-df.
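The Java-side pattern behind most of the test changes below is mechanical: replace `DataFrame` with `Dataset<Row>`, and swap the array-returning `collect()`/`take(n)` for `collectAsList()`/`takeAsList(n)`. Here is a minimal before/after sketch; the `prediction` temp table and the enclosing class are illustrative assumptions, not part of this patch:

```java
import java.util.List;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;

public class DatasetMigrationSketch {
  public static void run(SQLContext jsql) {
    // Before this patch (Java API):
    //   DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
    //   Row[] rows = predictions.collect();

    // After this patch: `DataFrame` survives only as a Scala type alias, so
    // Java code names the underlying type, Dataset<Row>.
    Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");

    // Java cannot receive the generic Array[T] returned by collect()/take(n),
    // so the List-returning variants are the natural replacements.
    List<Row> all = predictions.collectAsList();
    List<Row> firstThree = predictions.takeAsList(3);

    System.out.println(all.size() + " rows collected");
    for (Row r : firstThree) {
      // Columns are accessed positionally, as in the updated suites.
      System.out.println(r.getDouble(0) + " -> " + r.getDouble(1));
    }
  }
}
```

Where a `Row[]` array is still required, the new `collectRows()`/`takeRows(n)` specializations apply; `JavaVectorSlicerSuite` below uses `takeRows(2)` for exactly that reason.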
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java | 7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java | 5
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java | 5
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 20
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java | 9
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java | 10
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java | 7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java | 5
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java | 7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java | 7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java | 9
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java | 12
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java | 9
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java | 4
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java | 8
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java | 7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java | 4
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java | 12
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java | 9
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java | 6
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java | 7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java | 9
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java | 8
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java | 5
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java | 5
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java | 7
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java | 5
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java | 4
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java | 5
29 files changed, 118 insertions, 99 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 0a8c9e5954..60a4a1d2ea 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -17,6 +17,8 @@
package org.apache.spark.ml;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -26,7 +28,6 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -37,7 +38,7 @@ public class JavaPipelineSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
@Before
public void setUp() {
@@ -65,7 +66,7 @@ public class JavaPipelineSuite {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 40b9c35adc..0d923dfeff 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -21,6 +21,8 @@ import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -30,7 +32,6 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
public class JavaDecisionTreeClassifierSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeClassifier dt = new DecisionTreeClassifier()
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index 59b6fba7a9..f470f4ada6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaGBTClassifierSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaGBTClassifierSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
GBTClassifier rf = new GBTClassifier()
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index fd22eb6dca..536f0dc58f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -31,16 +31,16 @@ import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
private double eps = 1e-5;
@@ -67,7 +67,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(lr.getLabelCol(), "label");
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
Assert.assertEquals(0.5, model.getThreshold(), eps);
@@ -96,14 +96,14 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
- DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
+ Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
for (Row r: predAllZero.collectAsList()) {
Assert.assertEquals(0.0, r.getDouble(0), eps);
}
// Call transform with params, and check that the params worked.
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
- DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+ Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
for (Row r: predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true;
@@ -129,8 +129,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(2, model.numClasses());
model.transform(dataset).registerTempTable("transformed");
- DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
- for (Row row: trans1.collect()) {
+ Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
+ for (Row row: trans1.collectAsList()) {
Vector raw = (Vector)row.get(0);
Vector prob = (Vector)row.get(1);
Assert.assertEquals(raw.size(), 2);
@@ -140,8 +140,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
}
- DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
- for (Row row: trans2.collect()) {
+ Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
+ for (Row row: trans2.collectAsList()) {
double pred = row.getDouble(0);
Vector prob = (Vector)row.get(1);
double probOfPred = prob.apply((int)pred);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index ec6b4bf3c0..d499d363f1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -28,7 +29,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -52,7 +53,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
@Test
public void testMLPC() {
- DataFrame dataFrame = sqlContext.createDataFrame(
+ Dataset<Row> dataFrame = sqlContext.createDataFrame(
jsc.parallelize(Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
@@ -65,8 +66,8 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
.setSeed(11L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
- DataFrame result = model.transform(dataFrame);
- Row[] predictionAndLabels = result.select("prediction", "label").collect();
+ Dataset<Row> result = model.transform(dataFrame);
+ List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
for (Row r: predictionAndLabels) {
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 07936eb79b..45101f286c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -55,8 +55,8 @@ public class JavaNaiveBayesSuite implements Serializable {
jsc = null;
}
- public void validatePrediction(DataFrame predictionAndLabels) {
- for (Row r : predictionAndLabels.collect()) {
+ public void validatePrediction(Dataset<Row> predictionAndLabels) {
+ for (Row r : predictionAndLabels.collectAsList()) {
double prediction = r.getAs(0);
double label = r.getAs(1);
assertEquals(label, prediction, 1E-5);
@@ -88,11 +88,11 @@ public class JavaNaiveBayesSuite implements Serializable {
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);
- DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
+ Dataset<Row> predictionAndLabels = model.transform(dataset).select("prediction", "label");
validatePrediction(predictionAndLabels);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index cbabafe1b5..d493a7fcec 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.List;
+import org.apache.spark.sql.Row;
import scala.collection.JavaConverters;
import org.junit.After;
@@ -31,14 +32,14 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
public class JavaOneVsRestSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
@@ -75,7 +76,7 @@ public class JavaOneVsRestSuite implements Serializable {
Assert.assertEquals(ova.getLabelCol() , "label");
Assert.assertEquals(ova.getPredictionCol() , "prediction");
OneVsRestModel ovaModel = ova.fit(dataset);
- DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+ Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
predictions.collectAsList();
Assert.assertEquals(ovaModel.getLabelCol(), "label");
Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 5485fcbf01..9a63cef2a8 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -31,7 +31,8 @@ import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaRandomForestClassifierSuite implements Serializable {
@@ -58,7 +59,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
RandomForestClassifier rf = new RandomForestClassifier()
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index cc5a4ef4c2..a3fcdb54ee 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -29,14 +29,15 @@ import static org.junit.Assert.assertTrue;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaKMeansSuite implements Serializable {
private transient int k = 5;
private transient JavaSparkContext sc;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient SQLContext sql;
@Before
@@ -61,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
Vector[] centers = model.clusterCenters();
assertEquals(k, centers.length);
- DataFrame transformed = model.transform(dataset);
+ Dataset<Row> transformed = model.transform(dataset);
List<String> columns = Arrays.asList(transformed.columns());
List<String> expectedColumns = Arrays.asList("features", "prediction");
for (String column: expectedColumns) {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index d707bdee99..77e3a489a9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -25,7 +26,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -57,7 +58,7 @@ public class JavaBucketizerSuite {
StructType schema = new StructType(new StructField[] {
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(
+ Dataset<Row> dataset = jsql.createDataFrame(
Arrays.asList(
RowFactory.create(-0.5),
RowFactory.create(-0.3),
@@ -70,7 +71,7 @@ public class JavaBucketizerSuite {
.setOutputCol("result")
.setSplits(splits);
- Row[] result = bucketizer.transform(dataset).select("result").collect();
+ List<Row> result = bucketizer.transform(dataset).select("result").collectAsList();
for (Row r : result) {
double index = r.getDouble(0);
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index 63e5c93798..ed1ad4c3a3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
import org.junit.After;
@@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -56,7 +57,7 @@ public class JavaDCTSuite {
@Test
public void javaCompatibilityTest() {
double[] input = new double[] {1D, 2D, 3D, 4D};
- DataFrame dataset = jsql.createDataFrame(
+ Dataset<Row> dataset = jsql.createDataFrame(
Arrays.asList(RowFactory.create(Vectors.dense(input))),
new StructType(new StructField[]{
new StructField("vec", (new VectorUDT()), false, Metadata.empty())
@@ -69,8 +70,8 @@ public class JavaDCTSuite {
.setInputCol("vec")
.setOutputCol("resultVec");
- Row[] result = dct.transform(dataset).select("resultVec").collect();
- Vector resultVec = result[0].getAs("resultVec");
+ List<Row> result = dct.transform(dataset).select("resultVec").collectAsList();
+ Vector resultVec = result.get(0).getAs("resultVec");
Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 5932017f8f..6e2cc7e887 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -65,21 +65,21 @@ public class JavaHashingTFSuite {
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame sentenceData = jsql.createDataFrame(data, schema);
+ Dataset<Row> sentenceData = jsql.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer()
.setInputCol("sentence")
.setOutputCol("words");
- DataFrame wordsData = tokenizer.transform(sentenceData);
+ Dataset<Row> wordsData = tokenizer.transform(sentenceData);
int numFeatures = 20;
HashingTF hashingTF = new HashingTF()
.setInputCol("words")
.setOutputCol("rawFeatures")
.setNumFeatures(numFeatures);
- DataFrame featurizedData = hashingTF.transform(wordsData);
+ Dataset<Row> featurizedData = hashingTF.transform(wordsData);
IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
IDFModel idfModel = idf.fit(featurizedData);
- DataFrame rescaledData = idfModel.transform(featurizedData);
- for (Row r : rescaledData.select("features", "label").take(3)) {
+ Dataset<Row> rescaledData = idfModel.transform(featurizedData);
+ for (Row r : rescaledData.select("features", "label").takeAsList(3)) {
Vector features = r.getAs(0);
Assert.assertEquals(features.size(), numFeatures);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
index e17d549c50..5bbd9634b2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -26,7 +26,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaNormalizerSuite {
@@ -53,17 +54,17 @@ public class JavaNormalizerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
));
- DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
+ Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
Normalizer normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normFeatures");
// Normalize each Vector using $L^2$ norm.
- DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
+ Dataset<Row> l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
l2NormData.count();
// Normalize each Vector using $L^\infty$ norm.
- DataFrame lInfNormData =
+ Dataset<Row> lInfNormData =
normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
lInfNormData.count();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
index e8f329f9cf..1389d17e7e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -100,7 +100,7 @@ public class JavaPCASuite implements Serializable {
}
);
- DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
+ Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
PCAModel pca = new PCA()
.setInputCol("features")
.setOutputCol("pca_features")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index e22d117032..6a8bb64801 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -77,11 +77,11 @@ public class JavaPolynomialExpansionSuite {
new StructField("expected", new VectorUDT(), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
- Row[] pairs = polyExpansion.transform(dataset)
+ List<Row> pairs = polyExpansion.transform(dataset)
.select("polyFeatures", "expected")
- .collect();
+ .collectAsList();
for (Row r : pairs) {
double[] polyFeatures = ((Vector)r.get(0)).toArray();
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
index ed74363f59..3f6fc333e4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -26,7 +26,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaStandardScalerSuite {
@@ -53,7 +54,7 @@ public class JavaStandardScalerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
);
- DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+ Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
VectorIndexerSuite.FeatureData.class);
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
@@ -65,7 +66,7 @@ public class JavaStandardScalerSuite {
StandardScalerModel scalerModel = scaler.fit(dataFrame);
// Normalize each feature to have unit standard deviation.
- DataFrame scaledData = scalerModel.transform(dataFrame);
+ Dataset<Row> scaledData = scalerModel.transform(dataFrame);
scaledData.count();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index 139d1d005a..5812037dee 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -25,7 +25,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -65,7 +65,7 @@ public class JavaStopWordsRemoverSuite {
StructType schema = new StructType(new StructField[] {
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
remover.transform(dataset).collect();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 153a08a4cd..431779cd2e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -26,7 +26,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -58,16 +58,16 @@ public class JavaStringIndexerSuite {
});
List<Row> data = Arrays.asList(
cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
- DataFrame dataset = sqlContext.createDataFrame(data, schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(data, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex");
- DataFrame output = indexer.fit(dataset).transform(dataset);
+ Dataset<Row> output = indexer.fit(dataset).transform(dataset);
- Assert.assertArrayEquals(
- new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) },
- output.orderBy("id").select("id", "labelIndex").collect());
+ Assert.assertEquals(
+ Arrays.asList(cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0)),
+ output.orderBy("id").select("id", "labelIndex").collectAsList());
}
/** An alias for RowFactory.create. */
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index c407d98f1b..83d16cbd0e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -26,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -61,11 +62,11 @@ public class JavaTokenizerSuite {
new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
));
- DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
+ Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
- Row[] pairs = myRegExTokenizer.transform(dataset)
+ List<Row> pairs = myRegExTokenizer.transform(dataset)
.select("tokens", "wantedTokens")
- .collect();
+ .collectAsList();
for (Row r : pairs) {
Assert.assertEquals(r.get(0), r.get(1));
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index f8ba84ef77..e45e198043 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -64,11 +64,11 @@ public class JavaVectorAssemblerSuite {
Row row = RowFactory.create(
0, 0.0, Vectors.dense(1.0, 2.0), "a",
Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
- DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[] {"x", "y", "z", "n"})
.setOutputCol("features");
- DataFrame output = assembler.transform(dataset);
+ Dataset<Row> output = assembler.transform(dataset);
Assert.assertEquals(
Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
output.select("features").first().<Vector>getAs(0));
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
index bfcca62fa1..fec6cac8be 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -30,7 +30,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -57,7 +58,7 @@ public class JavaVectorIndexerSuite implements Serializable {
new FeatureData(Vectors.dense(1.0, 4.0))
);
SQLContext sqlContext = new SQLContext(sc);
- DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
+ Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
VectorIndexer indexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexed")
@@ -66,6 +67,6 @@ public class JavaVectorIndexerSuite implements Serializable {
Assert.assertEquals(model.numFeatures(), 2);
Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps();
Assert.assertEquals(categoryMaps.size(), 1);
- DataFrame indexedData = model.transform(data);
+ Dataset<Row> indexedData = model.transform(data);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index 786c11c412..b87605ebfd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -31,7 +31,7 @@ import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -68,16 +68,17 @@ public class JavaVectorSlicerSuite {
RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
);
- DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
+ Dataset<Row> dataset =
+ jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
- DataFrame output = vectorSlicer.transform(dataset);
+ Dataset<Row> output = vectorSlicer.transform(dataset);
- for (Row r : output.select("userFeatures", "features").take(2)) {
+ for (Row r : output.select("userFeatures", "features").takeRows(2)) {
Vector features = r.getAs(1);
Assert.assertEquals(features.size(), 2);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index b292b1b06d..7517b70cc9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -26,7 +26,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -53,7 +53,7 @@ public class JavaWord2VecSuite {
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
- DataFrame documentDF = sqlContext.createDataFrame(
+ Dataset<Row> documentDF = sqlContext.createDataFrame(
Arrays.asList(
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
@@ -66,9 +66,9 @@ public class JavaWord2VecSuite {
.setVectorSize(3)
.setMinCount(0);
Word2VecModel model = word2Vec.fit(documentDF);
- DataFrame result = model.transform(documentDF);
+ Dataset<Row> result = model.transform(documentDF);
- for (Row r: result.select("result").collect()) {
+ for (Row r: result.select("result").collectAsList()) {
double[] polyFeatures = ((Vector)r.get(0)).toArray();
Assert.assertEquals(polyFeatures.length, 3);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index d5c9d120c5..a1575300a8 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaDecisionTreeRegressorSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 38d15dc2b7..9477e8d2bf 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaGBTRegressorSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaGBTRegressorSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
GBTRegressor rf = new GBTRegressor()
.setMaxDepth(2)
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 4fb0b0d109..9f817515eb 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -28,7 +28,8 @@ import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite
.generateLogisticInputAsList;
@@ -38,7 +39,7 @@ public class JavaLinearRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
@@ -64,7 +65,7 @@ public class JavaLinearRegressionSuite implements Serializable {
assertEquals("auto", lr.getSolver());
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
+ Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");
predictions.collect();
// Check defaults
assertEquals("features", model.getFeaturesCol());
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index 31be8880c2..a90535d11a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -31,7 +31,8 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaRandomForestRegressorSuite implements Serializable {
@@ -58,7 +59,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
// This tests setters. Training with various options is tested in Scala.
RandomForestRegressor rf = new RandomForestRegressor()
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 2976b38e45..b8ddf907d0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -31,7 +31,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils;
@@ -68,7 +68,7 @@ public class JavaLibSVMRelationSuite {
@Test
public void verifyLibSVMDF() {
- DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
+ Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
.load(path);
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 08eeca53f0..24b0097454 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -38,7 +39,7 @@ public class JavaCrossValidatorSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
@Before
public void setUp() {