Diffstat (limited to 'mllib')
 mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java | 7
 mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java | 5
 mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java | 5
 mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 20
 mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java | 9
 mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java | 10
 mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java | 7
 mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java | 5
 mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java | 7
 mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java | 7
 mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java | 9
 mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java | 12
 mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java | 9
 mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java | 4
 mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java | 8
 mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java | 7
 mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java | 4
 mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java | 12
 mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java | 9
 mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java | 6
 mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java | 7
 mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java | 9
 mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java | 8
 mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java | 5
 mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java | 5
 mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java | 7
 mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java | 5
 mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java | 4
 mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java | 5
 29 files changed, 118 insertions(+), 99 deletions(-)
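
Every hunk below applies the same mechanical migration: in Spark 2.0, DataFrame became a Scala type alias for Dataset[Row], which Java code cannot reference, so the Java suites switch declared types to Dataset<Row> and replace the Row[]-returning collect()/take() calls with the List<Row>-returning collectAsList()/takeAsList(). A minimal before/after sketch of the pattern follows; the SQLContext named jsql and the temp table "prediction" mirror the suites below, while the class and method names are illustrative only, not part of this diff.

    // Sketch of the migration pattern applied throughout this commit.
    // Assumes a live SQLContext `jsql` with a registered temp table
    // "prediction"; `MigrationSketch` and `queryAndCollect` are hypothetical.
    import java.util.List;

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SQLContext;

    public class MigrationSketch {
      static void queryAndCollect(SQLContext jsql) {
        // Before (Spark 1.x Java API):
        //   DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
        //   Row[] rows = predictions.collect();
        // After: the Java-facing type is Dataset<Row>, and results come back
        // as java.util.List<Row> via the *AsList variants.
        Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");
        List<Row> rows = predictions.collectAsList();  // replaces collect()
        List<Row> first = predictions.takeAsList(3);   // replaces take(3)
      }
    }

Returning List<Row> rather than Row[] also lets the assertion-heavy suites compare results with Assert.assertEquals against Arrays.asList(...), as the JavaStringIndexerSuite hunk below does.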
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 0a8c9e5954..60a4a1d2ea 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -17,6 +17,8 @@
package org.apache.spark.ml;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -26,7 +28,6 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -37,7 +38,7 @@ public class JavaPipelineSuite {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
@Before
public void setUp() {
@@ -65,7 +66,7 @@ public class JavaPipelineSuite {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
index 40b9c35adc..0d923dfeff 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java
@@ -21,6 +21,8 @@ import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
@@ -30,7 +32,6 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
public class JavaDecisionTreeClassifierSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeClassifier dt = new DecisionTreeClassifier()
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
index 59b6fba7a9..f470f4ada6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaGBTClassifierSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaGBTClassifierSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
GBTClassifier rf = new GBTClassifier()
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index fd22eb6dca..536f0dc58f 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -31,16 +31,16 @@ import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
-import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
private double eps = 1e-5;
@@ -67,7 +67,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(lr.getLabelCol(), "label");
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
+ Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
// Check defaults
Assert.assertEquals(0.5, model.getThreshold(), eps);
@@ -96,14 +96,14 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Modify model params, and check that the params worked.
model.setThreshold(1.0);
model.transform(dataset).registerTempTable("predAllZero");
- DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
+ Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
for (Row r: predAllZero.collectAsList()) {
Assert.assertEquals(0.0, r.getDouble(0), eps);
}
// Call transform with params, and check that the params worked.
model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
.registerTempTable("predNotAllZero");
- DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+ Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
boolean foundNonZero = false;
for (Row r: predNotAllZero.collectAsList()) {
if (r.getDouble(0) != 0.0) foundNonZero = true;
@@ -129,8 +129,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(2, model.numClasses());
model.transform(dataset).registerTempTable("transformed");
- DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
- for (Row row: trans1.collect()) {
+ Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
+ for (Row row: trans1.collectAsList()) {
Vector raw = (Vector)row.get(0);
Vector prob = (Vector)row.get(1);
Assert.assertEquals(raw.size(), 2);
@@ -140,8 +140,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps);
}
- DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
- for (Row row: trans2.collect()) {
+ Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
+ for (Row row: trans2.collectAsList()) {
double pred = row.getDouble(0);
Vector prob = (Vector)row.get(1);
double probOfPred = prob.apply((int)pred);
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index ec6b4bf3c0..d499d363f1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -28,7 +29,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -52,7 +53,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
@Test
public void testMLPC() {
- DataFrame dataFrame = sqlContext.createDataFrame(
+ Dataset<Row> dataFrame = sqlContext.createDataFrame(
jsc.parallelize(Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
@@ -65,8 +66,8 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
.setSeed(11L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
- DataFrame result = model.transform(dataFrame);
- Row[] predictionAndLabels = result.select("prediction", "label").collect();
+ Dataset<Row> result = model.transform(dataFrame);
+ List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
for (Row r: predictionAndLabels) {
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 07936eb79b..45101f286c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -55,8 +55,8 @@ public class JavaNaiveBayesSuite implements Serializable {
jsc = null;
}
- public void validatePrediction(DataFrame predictionAndLabels) {
- for (Row r : predictionAndLabels.collect()) {
+ public void validatePrediction(Dataset<Row> predictionAndLabels) {
+ for (Row r : predictionAndLabels.collectAsList()) {
double prediction = r.getAs(0);
double label = r.getAs(1);
assertEquals(label, prediction, 1E-5);
@@ -88,11 +88,11 @@ public class JavaNaiveBayesSuite implements Serializable {
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);
- DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
+ Dataset<Row> predictionAndLabels = model.transform(dataset).select("prediction", "label");
validatePrediction(predictionAndLabels);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
index cbabafe1b5..d493a7fcec 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.List;
+import org.apache.spark.sql.Row;
import scala.collection.JavaConverters;
import org.junit.After;
@@ -31,14 +32,14 @@ import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.SQLContext;
public class JavaOneVsRestSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
@@ -75,7 +76,7 @@ public class JavaOneVsRestSuite implements Serializable {
Assert.assertEquals(ova.getLabelCol() , "label");
Assert.assertEquals(ova.getPredictionCol() , "prediction");
OneVsRestModel ovaModel = ova.fit(dataset);
- DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction");
+ Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction");
predictions.collectAsList();
Assert.assertEquals(ovaModel.getLabelCol(), "label");
Assert.assertEquals(ovaModel.getPredictionCol() , "prediction");
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
index 5485fcbf01..9a63cef2a8 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java
@@ -31,7 +31,8 @@ import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaRandomForestClassifierSuite implements Serializable {
@@ -58,7 +59,7 @@ public class JavaRandomForestClassifierSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2);
// This tests setters. Training with various options is tested in Scala.
RandomForestClassifier rf = new RandomForestClassifier()
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
index cc5a4ef4c2..a3fcdb54ee 100644
--- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -29,14 +29,15 @@ import static org.junit.Assert.assertTrue;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaKMeansSuite implements Serializable {
private transient int k = 5;
private transient JavaSparkContext sc;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient SQLContext sql;
@Before
@@ -61,7 +62,7 @@ public class JavaKMeansSuite implements Serializable {
Vector[] centers = model.clusterCenters();
assertEquals(k, centers.length);
- DataFrame transformed = model.transform(dataset);
+ Dataset<Row> transformed = model.transform(dataset);
List<String> columns = Arrays.asList(transformed.columns());
List<String> expectedColumns = Arrays.asList("features", "prediction");
for (String column: expectedColumns) {
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
index d707bdee99..77e3a489a9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -25,7 +26,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -57,7 +58,7 @@ public class JavaBucketizerSuite {
StructType schema = new StructType(new StructField[] {
new StructField("feature", DataTypes.DoubleType, false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(
+ Dataset<Row> dataset = jsql.createDataFrame(
Arrays.asList(
RowFactory.create(-0.5),
RowFactory.create(-0.3),
@@ -70,7 +71,7 @@ public class JavaBucketizerSuite {
.setOutputCol("result")
.setSplits(splits);
- Row[] result = bucketizer.transform(dataset).select("result").collect();
+ List<Row> result = bucketizer.transform(dataset).select("result").collectAsList();
for (Row r : result) {
double index = r.getDouble(0);
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
index 63e5c93798..ed1ad4c3a3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D;
import org.junit.After;
@@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -56,7 +57,7 @@ public class JavaDCTSuite {
@Test
public void javaCompatibilityTest() {
double[] input = new double[] {1D, 2D, 3D, 4D};
- DataFrame dataset = jsql.createDataFrame(
+ Dataset<Row> dataset = jsql.createDataFrame(
Arrays.asList(RowFactory.create(Vectors.dense(input))),
new StructType(new StructField[]{
new StructField("vec", (new VectorUDT()), false, Metadata.empty())
@@ -69,8 +70,8 @@ public class JavaDCTSuite {
.setInputCol("vec")
.setOutputCol("resultVec");
- Row[] result = dct.transform(dataset).select("resultVec").collect();
- Vector resultVec = result[0].getAs("resultVec");
+ List<Row> result = dct.transform(dataset).select("resultVec").collectAsList();
+ Vector resultVec = result.get(0).getAs("resultVec");
Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
index 5932017f8f..6e2cc7e887 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java
@@ -27,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -65,21 +65,21 @@ public class JavaHashingTFSuite {
new StructField("sentence", DataTypes.StringType, false, Metadata.empty())
});
- DataFrame sentenceData = jsql.createDataFrame(data, schema);
+ Dataset<Row> sentenceData = jsql.createDataFrame(data, schema);
Tokenizer tokenizer = new Tokenizer()
.setInputCol("sentence")
.setOutputCol("words");
- DataFrame wordsData = tokenizer.transform(sentenceData);
+ Dataset<Row> wordsData = tokenizer.transform(sentenceData);
int numFeatures = 20;
HashingTF hashingTF = new HashingTF()
.setInputCol("words")
.setOutputCol("rawFeatures")
.setNumFeatures(numFeatures);
- DataFrame featurizedData = hashingTF.transform(wordsData);
+ Dataset<Row> featurizedData = hashingTF.transform(wordsData);
IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features");
IDFModel idfModel = idf.fit(featurizedData);
- DataFrame rescaledData = idfModel.transform(featurizedData);
- for (Row r : rescaledData.select("features", "label").take(3)) {
+ Dataset<Row> rescaledData = idfModel.transform(featurizedData);
+ for (Row r : rescaledData.select("features", "label").takeAsList(3)) {
Vector features = r.getAs(0);
Assert.assertEquals(features.size(), numFeatures);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
index e17d549c50..5bbd9634b2 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java
@@ -26,7 +26,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaNormalizerSuite {
@@ -53,17 +54,17 @@ public class JavaNormalizerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
));
- DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
+ Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class);
Normalizer normalizer = new Normalizer()
.setInputCol("features")
.setOutputCol("normFeatures");
// Normalize each Vector using $L^2$ norm.
- DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
+ Dataset<Row> l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2));
l2NormData.count();
// Normalize each Vector using $L^\infty$ norm.
- DataFrame lInfNormData =
+ Dataset<Row> lInfNormData =
normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY));
lInfNormData.count();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
index e8f329f9cf..1389d17e7e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java
@@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -100,7 +100,7 @@ public class JavaPCASuite implements Serializable {
}
);
- DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
+ Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class);
PCAModel pca = new PCA()
.setInputCol("features")
.setOutputCol("pca_features")
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
index e22d117032..6a8bb64801 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java
@@ -29,7 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -77,11 +77,11 @@ public class JavaPolynomialExpansionSuite {
new StructField("expected", new VectorUDT(), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
- Row[] pairs = polyExpansion.transform(dataset)
+ List<Row> pairs = polyExpansion.transform(dataset)
.select("polyFeatures", "expected")
- .collect();
+ .collectAsList();
for (Row r : pairs) {
double[] polyFeatures = ((Vector)r.get(0)).toArray();
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
index ed74363f59..3f6fc333e4 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java
@@ -26,7 +26,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
public class JavaStandardScalerSuite {
@@ -53,7 +54,7 @@ public class JavaStandardScalerSuite {
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)),
new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0))
);
- DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
+ Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2),
VectorIndexerSuite.FeatureData.class);
StandardScaler scaler = new StandardScaler()
.setInputCol("features")
@@ -65,7 +66,7 @@ public class JavaStandardScalerSuite {
StandardScalerModel scalerModel = scaler.fit(dataFrame);
// Normalize each feature to have unit standard deviation.
- DataFrame scaledData = scalerModel.transform(dataFrame);
+ Dataset<Row> scaledData = scalerModel.transform(dataFrame);
scaledData.count();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
index 139d1d005a..5812037dee 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java
@@ -25,7 +25,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -65,7 +65,7 @@ public class JavaStopWordsRemoverSuite {
StructType schema = new StructType(new StructField[] {
new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
remover.transform(dataset).collect();
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
index 153a08a4cd..431779cd2e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java
@@ -26,7 +26,7 @@ import org.junit.Before;
import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -58,16 +58,16 @@ public class JavaStringIndexerSuite {
});
List<Row> data = Arrays.asList(
cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c"));
- DataFrame dataset = sqlContext.createDataFrame(data, schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(data, schema);
StringIndexer indexer = new StringIndexer()
.setInputCol("label")
.setOutputCol("labelIndex");
- DataFrame output = indexer.fit(dataset).transform(dataset);
+ Dataset<Row> output = indexer.fit(dataset).transform(dataset);
- Assert.assertArrayEquals(
- new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) },
- output.orderBy("id").select("id", "labelIndex").collect());
+ Assert.assertEquals(
+ Arrays.asList(cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0)),
+ output.orderBy("id").select("id", "labelIndex").collectAsList());
}
/** An alias for RowFactory.create. */
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
index c407d98f1b..83d16cbd0e 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java
@@ -18,6 +18,7 @@
package org.apache.spark.ml.feature;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -26,7 +27,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -61,11 +62,11 @@ public class JavaTokenizerSuite {
new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}),
new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"})
));
- DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
+ Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class);
- Row[] pairs = myRegExTokenizer.transform(dataset)
+ List<Row> pairs = myRegExTokenizer.transform(dataset)
.select("tokens", "wantedTokens")
- .collect();
+ .collectAsList();
for (Row r : pairs) {
Assert.assertEquals(r.get(0), r.get(1));
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
index f8ba84ef77..e45e198043 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java
@@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -64,11 +64,11 @@ public class JavaVectorAssemblerSuite {
Row row = RowFactory.create(
0, 0.0, Vectors.dense(1.0, 2.0), "a",
Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L);
- DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
+ Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema);
VectorAssembler assembler = new VectorAssembler()
.setInputCols(new String[] {"x", "y", "z", "n"})
.setOutputCol("features");
- DataFrame output = assembler.transform(dataset);
+ Dataset<Row> output = assembler.transform(dataset);
Assert.assertEquals(
Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}),
output.select("features").first().<Vector>getAs(0));
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
index bfcca62fa1..fec6cac8be 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java
@@ -30,7 +30,8 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -57,7 +58,7 @@ public class JavaVectorIndexerSuite implements Serializable {
new FeatureData(Vectors.dense(1.0, 4.0))
);
SQLContext sqlContext = new SQLContext(sc);
- DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
+ Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class);
VectorIndexer indexer = new VectorIndexer()
.setInputCol("features")
.setOutputCol("indexed")
@@ -66,6 +67,6 @@ public class JavaVectorIndexerSuite implements Serializable {
Assert.assertEquals(model.numFeatures(), 2);
Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps();
Assert.assertEquals(categoryMaps.size(), 1);
- DataFrame indexedData = model.transform(data);
+ Dataset<Row> indexedData = model.transform(data);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
index 786c11c412..b87605ebfd 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java
@@ -31,7 +31,7 @@ import org.apache.spark.ml.attribute.AttributeGroup;
import org.apache.spark.ml.attribute.NumericAttribute;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -68,16 +68,17 @@ public class JavaVectorSlicerSuite {
RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0))
);
- DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
+ Dataset<Row> dataset =
+ jsql.createDataFrame(data, (new StructType()).add(group.toStructField()));
VectorSlicer vectorSlicer = new VectorSlicer()
.setInputCol("userFeatures").setOutputCol("features");
vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"});
- DataFrame output = vectorSlicer.transform(dataset);
+ Dataset<Row> output = vectorSlicer.transform(dataset);
- for (Row r : output.select("userFeatures", "features").take(2)) {
+ for (Row r : output.select("userFeatures", "features").takeRows(2)) {
Vector features = r.getAs(1);
Assert.assertEquals(features.size(), 2);
}
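
Unlike the other suites, the hunk above keeps a Row[] result via takeRows(2) instead of switching to takeAsList. If takeRows is not available in a given build (it appears to be a transitional Row[]-returning method from the Dataset migration window; an assumption, not stated in this diff), the equivalent Java-friendly loop would be:

    // Hypothetical equivalent using the List-returning variant:
    for (Row r : output.select("userFeatures", "features").takeAsList(2)) {
      Vector features = r.getAs(1);
      Assert.assertEquals(features.size(), 2);
    }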
diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
index b292b1b06d..7517b70cc9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java
@@ -26,7 +26,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -53,7 +53,7 @@ public class JavaWord2VecSuite {
StructType schema = new StructType(new StructField[]{
new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty())
});
- DataFrame documentDF = sqlContext.createDataFrame(
+ Dataset<Row> documentDF = sqlContext.createDataFrame(
Arrays.asList(
RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))),
RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))),
@@ -66,9 +66,9 @@ public class JavaWord2VecSuite {
.setVectorSize(3)
.setMinCount(0);
Word2VecModel model = word2Vec.fit(documentDF);
- DataFrame result = model.transform(documentDF);
+ Dataset<Row> result = model.transform(documentDF);
- for (Row r: result.select("result").collect()) {
+ for (Row r: result.select("result").collectAsList()) {
double[] polyFeatures = ((Vector)r.get(0)).toArray();
Assert.assertEquals(polyFeatures.length, 3);
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
index d5c9d120c5..a1575300a8 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaDecisionTreeRegressorSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
// This tests setters. Training with various options is tested in Scala.
DecisionTreeRegressor dt = new DecisionTreeRegressor()
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
index 38d15dc2b7..9477e8d2bf 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaGBTRegressorSuite implements Serializable {
@@ -57,7 +58,7 @@ public class JavaGBTRegressorSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
GBTRegressor rf = new GBTRegressor()
.setMaxDepth(2)
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index 4fb0b0d109..9f817515eb 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -28,7 +28,8 @@ import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite
.generateLogisticInputAsList;
@@ -38,7 +39,7 @@ public class JavaLinearRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
private transient JavaRDD<LabeledPoint> datasetRDD;
@Before
@@ -64,7 +65,7 @@ public class JavaLinearRegressionSuite implements Serializable {
assertEquals("auto", lr.getSolver());
LinearRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
+ Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction");
predictions.collect();
// Check defaults
assertEquals("features", model.getFeaturesCol());
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
index 31be8880c2..a90535d11a 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java
@@ -31,7 +31,8 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite;
import org.apache.spark.ml.impl.TreeTests;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
public class JavaRandomForestRegressorSuite implements Serializable {
@@ -58,7 +59,7 @@ public class JavaRandomForestRegressorSuite implements Serializable {
JavaRDD<LabeledPoint> data = sc.parallelize(
LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
Map<Integer, Integer> categoricalFeatures = new HashMap<>();
- DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
+ Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0);
// This tests setters. Training with various options is tested in Scala.
RandomForestRegressor rf = new RandomForestRegressor()
diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
index 2976b38e45..b8ddf907d0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java
@@ -31,7 +31,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.util.Utils;
@@ -68,7 +68,7 @@ public class JavaLibSVMRelationSuite {
@Test
public void verifyLibSVMDF() {
- DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
+ Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense")
.load(path);
Assert.assertEquals("label", dataset.columns()[0]);
Assert.assertEquals("features", dataset.columns()[1]);
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index 08eeca53f0..24b0097454 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -30,7 +30,8 @@ import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
@@ -38,7 +39,7 @@ public class JavaCrossValidatorSuite implements Serializable {
private transient JavaSparkContext jsc;
private transient SQLContext jsql;
- private transient DataFrame dataset;
+ private transient Dataset<Row> dataset;
@Before
public void setUp() {