diff options
Diffstat (limited to 'mllib/src/test/java/org')
29 files changed, 118 insertions, 99 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java index 0a8c9e5954..60a4a1d2ea 100644 --- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java @@ -17,6 +17,8 @@ package org.apache.spark.ml; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -26,7 +28,6 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.feature.StandardScaler; -import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -37,7 +38,7 @@ public class JavaPipelineSuite { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; @Before public void setUp() { @@ -65,7 +66,7 @@ public class JavaPipelineSuite { .setStages(new PipelineStage[] {scaler, lr}); PipelineModel model = pipeline.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java index 40b9c35adc..0d923dfeff 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaDecisionTreeClassifierSuite.java @@ -21,6 +21,8 @@ import java.io.Serializable; import java.util.HashMap; import java.util.Map; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -30,7 +32,6 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; public class JavaDecisionTreeClassifierSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaDecisionTreeClassifierSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. DecisionTreeClassifier dt = new DecisionTreeClassifier() diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java index 59b6fba7a9..f470f4ada6 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaGBTClassifierSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaGBTClassifierSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaGBTClassifierSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. GBTClassifier rf = new GBTClassifier() diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index fd22eb6dca..536f0dc58f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -31,16 +31,16 @@ import org.apache.spark.api.java.JavaSparkContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; private double eps = 1e-5; @@ -67,7 +67,7 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults Assert.assertEquals(0.5, model.getThreshold(), eps); @@ -96,14 +96,14 @@ public class JavaLogisticRegressionSuite implements Serializable { // Modify model params, and check that the params worked. model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); - DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); + Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); for (Row r: predAllZero.collectAsList()) { Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); - DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; for (Row r: predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; @@ -129,8 +129,8 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(2, model.numClasses()); model.transform(dataset).registerTempTable("transformed"); - DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); - for (Row row: trans1.collect()) { + Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row: trans1.collectAsList()) { Vector raw = (Vector)row.get(0); Vector prob = (Vector)row.get(1); Assert.assertEquals(raw.size(), 2); @@ -140,8 +140,8 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } - DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); - for (Row row: trans2.collect()) { + Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); + for (Row row: trans2.collectAsList()) { double pred = row.getDouble(0); Vector prob = (Vector)row.get(1); double probOfPred = prob.apply((int)pred); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java index ec6b4bf3c0..d499d363f1 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification; import java.io.Serializable; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -28,7 +29,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -52,7 +53,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable { @Test public void testMLPC() { - DataFrame dataFrame = sqlContext.createDataFrame( + Dataset<Row> dataFrame = sqlContext.createDataFrame( jsc.parallelize(Arrays.asList( new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), @@ -65,8 +66,8 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable { .setSeed(11L) .setMaxIter(100); MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); - DataFrame result = model.transform(dataFrame); - Row[] predictionAndLabels = result.select("prediction", "label").collect(); + Dataset<Row> result = model.transform(dataFrame); + List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList(); for (Row r: predictionAndLabels) { Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 07936eb79b..45101f286c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -55,8 +55,8 @@ public class JavaNaiveBayesSuite implements Serializable { jsc = null; } - public void validatePrediction(DataFrame predictionAndLabels) { - for (Row r : predictionAndLabels.collect()) { + public void validatePrediction(Dataset<Row> predictionAndLabels) { + for (Row r : predictionAndLabels.collectAsList()) { double prediction = r.getAs(0); double label = r.getAs(1); assertEquals(label, prediction, 1E-5); @@ -88,11 +88,11 @@ public class JavaNaiveBayesSuite implements Serializable { new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = jsql.createDataFrame(data, schema); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); - DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); + Dataset<Row> predictionAndLabels = model.transform(dataset).select("prediction", "label"); validatePrediction(predictionAndLabels); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java index cbabafe1b5..d493a7fcec 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaOneVsRestSuite.java @@ -20,6 +20,7 @@ package org.apache.spark.ml.classification; import java.io.Serializable; import java.util.List; +import org.apache.spark.sql.Row; import scala.collection.JavaConverters; import org.junit.After; @@ -31,14 +32,14 @@ import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateMultinomialLogisticInput; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.SQLContext; public class JavaOneVsRestSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; @Before @@ -75,7 +76,7 @@ public class JavaOneVsRestSuite implements Serializable { Assert.assertEquals(ova.getLabelCol() , "label"); Assert.assertEquals(ova.getPredictionCol() , "prediction"); OneVsRestModel ovaModel = ova.fit(dataset); - DataFrame predictions = ovaModel.transform(dataset).select("label", "prediction"); + Dataset<Row> predictions = ovaModel.transform(dataset).select("label", "prediction"); predictions.collectAsList(); Assert.assertEquals(ovaModel.getLabelCol(), "label"); Assert.assertEquals(ovaModel.getPredictionCol() , "prediction"); diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java index 5485fcbf01..9a63cef2a8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaRandomForestClassifierSuite.java @@ -31,7 +31,8 @@ import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaRandomForestClassifierSuite implements Serializable { @@ -58,7 +59,7 @@ public class JavaRandomForestClassifierSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 2); // This tests setters. Training with various options is tested in Scala. RandomForestClassifier rf = new RandomForestClassifier() diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java index cc5a4ef4c2..a3fcdb54ee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java @@ -29,14 +29,15 @@ import static org.junit.Assert.assertTrue; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class JavaKMeansSuite implements Serializable { private transient int k = 5; private transient JavaSparkContext sc; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient SQLContext sql; @Before @@ -61,7 +62,7 @@ public class JavaKMeansSuite implements Serializable { Vector[] centers = model.clusterCenters(); assertEquals(k, centers.length); - DataFrame transformed = model.transform(dataset); + Dataset<Row> transformed = model.transform(dataset); List<String> columns = Arrays.asList(transformed.columns()); List<String> expectedColumns = Arrays.asList("features", "prediction"); for (String column: expectedColumns) { diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java index d707bdee99..77e3a489a9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaBucketizerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -25,7 +26,7 @@ import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -57,7 +58,7 @@ public class JavaBucketizerSuite { StructType schema = new StructType(new StructField[] { new StructField("feature", DataTypes.DoubleType, false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame( + Dataset<Row> dataset = jsql.createDataFrame( Arrays.asList( RowFactory.create(-0.5), RowFactory.create(-0.3), @@ -70,7 +71,7 @@ public class JavaBucketizerSuite { .setOutputCol("result") .setSplits(splits); - Row[] result = bucketizer.transform(dataset).select("result").collect(); + List<Row> result = bucketizer.transform(dataset).select("result").collectAsList(); for (Row r : result) { double index = r.getDouble(0); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java index 63e5c93798..ed1ad4c3a3 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaDCTSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import edu.emory.mathcs.jtransforms.dct.DoubleDCT_1D; import org.junit.After; @@ -29,7 +30,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -56,7 +57,7 @@ public class JavaDCTSuite { @Test public void javaCompatibilityTest() { double[] input = new double[] {1D, 2D, 3D, 4D}; - DataFrame dataset = jsql.createDataFrame( + Dataset<Row> dataset = jsql.createDataFrame( Arrays.asList(RowFactory.create(Vectors.dense(input))), new StructType(new StructField[]{ new StructField("vec", (new VectorUDT()), false, Metadata.empty()) @@ -69,8 +70,8 @@ public class JavaDCTSuite { .setInputCol("vec") .setOutputCol("resultVec"); - Row[] result = dct.transform(dataset).select("resultVec").collect(); - Vector resultVec = result[0].getAs("resultVec"); + List<Row> result = dct.transform(dataset).select("resultVec").collectAsList(); + Vector resultVec = result.get(0).getAs("resultVec"); Assert.assertArrayEquals(expectedResult, resultVec.toArray(), 1e-6); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java index 5932017f8f..6e2cc7e887 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaHashingTFSuite.java @@ -27,7 +27,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -65,21 +65,21 @@ public class JavaHashingTFSuite { new StructField("sentence", DataTypes.StringType, false, Metadata.empty()) }); - DataFrame sentenceData = jsql.createDataFrame(data, schema); + Dataset<Row> sentenceData = jsql.createDataFrame(data, schema); Tokenizer tokenizer = new Tokenizer() .setInputCol("sentence") .setOutputCol("words"); - DataFrame wordsData = tokenizer.transform(sentenceData); + Dataset<Row> wordsData = tokenizer.transform(sentenceData); int numFeatures = 20; HashingTF hashingTF = new HashingTF() .setInputCol("words") .setOutputCol("rawFeatures") .setNumFeatures(numFeatures); - DataFrame featurizedData = hashingTF.transform(wordsData); + Dataset<Row> featurizedData = hashingTF.transform(wordsData); IDF idf = new IDF().setInputCol("rawFeatures").setOutputCol("features"); IDFModel idfModel = idf.fit(featurizedData); - DataFrame rescaledData = idfModel.transform(featurizedData); - for (Row r : rescaledData.select("features", "label").take(3)) { + Dataset<Row> rescaledData = idfModel.transform(featurizedData); + for (Row r : rescaledData.select("features", "label").takeAsList(3)) { Vector features = r.getAs(0); Assert.assertEquals(features.size(), numFeatures); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java index e17d549c50..5bbd9634b2 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaNormalizerSuite.java @@ -26,7 +26,8 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class JavaNormalizerSuite { @@ -53,17 +54,17 @@ public class JavaNormalizerSuite { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) )); - DataFrame dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); + Dataset<Row> dataFrame = jsql.createDataFrame(points, VectorIndexerSuite.FeatureData.class); Normalizer normalizer = new Normalizer() .setInputCol("features") .setOutputCol("normFeatures"); // Normalize each Vector using $L^2$ norm. - DataFrame l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); + Dataset<Row> l2NormData = normalizer.transform(dataFrame, normalizer.p().w(2)); l2NormData.count(); // Normalize each Vector using $L^\infty$ norm. - DataFrame lInfNormData = + Dataset<Row> lInfNormData = normalizer.transform(dataFrame, normalizer.p().w(Double.POSITIVE_INFINITY)); lInfNormData.count(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java index e8f329f9cf..1389d17e7e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPCASuite.java @@ -35,7 +35,7 @@ import org.apache.spark.mllib.linalg.distributed.RowMatrix; import org.apache.spark.mllib.linalg.Matrix; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -100,7 +100,7 @@ public class JavaPCASuite implements Serializable { } ); - DataFrame df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); + Dataset<Row> df = sqlContext.createDataFrame(featuresExpected, VectorPair.class); PCAModel pca = new PCA() .setInputCol("features") .setOutputCol("pca_features") diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java index e22d117032..6a8bb64801 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaPolynomialExpansionSuite.java @@ -29,7 +29,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -77,11 +77,11 @@ public class JavaPolynomialExpansionSuite { new StructField("expected", new VectorUDT(), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = jsql.createDataFrame(data, schema); - Row[] pairs = polyExpansion.transform(dataset) + List<Row> pairs = polyExpansion.transform(dataset) .select("polyFeatures", "expected") - .collect(); + .collectAsList(); for (Row r : pairs) { double[] polyFeatures = ((Vector)r.get(0)).toArray(); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java index ed74363f59..3f6fc333e4 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStandardScalerSuite.java @@ -26,7 +26,8 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; public class JavaStandardScalerSuite { @@ -53,7 +54,7 @@ public class JavaStandardScalerSuite { new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 3.0)), new VectorIndexerSuite.FeatureData(Vectors.dense(1.0, 4.0)) ); - DataFrame dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), + Dataset<Row> dataFrame = jsql.createDataFrame(jsc.parallelize(points, 2), VectorIndexerSuite.FeatureData.class); StandardScaler scaler = new StandardScaler() .setInputCol("features") @@ -65,7 +66,7 @@ public class JavaStandardScalerSuite { StandardScalerModel scalerModel = scaler.fit(dataFrame); // Normalize each feature to have unit standard deviation. - DataFrame scaledData = scalerModel.transform(dataFrame); + Dataset<Row> scaledData = scalerModel.transform(dataFrame); scaledData.count(); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java index 139d1d005a..5812037dee 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStopWordsRemoverSuite.java @@ -25,7 +25,7 @@ import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -65,7 +65,7 @@ public class JavaStopWordsRemoverSuite { StructType schema = new StructType(new StructField[] { new StructField("raw", DataTypes.createArrayType(DataTypes.StringType), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = jsql.createDataFrame(data, schema); remover.transform(dataset).collect(); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java index 153a08a4cd..431779cd2e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaStringIndexerSuite.java @@ -26,7 +26,7 @@ import org.junit.Before; import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -58,16 +58,16 @@ public class JavaStringIndexerSuite { }); List<Row> data = Arrays.asList( cr(0, "a"), cr(1, "b"), cr(2, "c"), cr(3, "a"), cr(4, "a"), cr(5, "c")); - DataFrame dataset = sqlContext.createDataFrame(data, schema); + Dataset<Row> dataset = sqlContext.createDataFrame(data, schema); StringIndexer indexer = new StringIndexer() .setInputCol("label") .setOutputCol("labelIndex"); - DataFrame output = indexer.fit(dataset).transform(dataset); + Dataset<Row> output = indexer.fit(dataset).transform(dataset); - Assert.assertArrayEquals( - new Row[] { cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0) }, - output.orderBy("id").select("id", "labelIndex").collect()); + Assert.assertEquals( + Arrays.asList(cr(0, 0.0), cr(1, 2.0), cr(2, 1.0), cr(3, 0.0), cr(4, 0.0), cr(5, 1.0)), + output.orderBy("id").select("id", "labelIndex").collectAsList()); } /** An alias for RowFactory.create. */ diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java index c407d98f1b..83d16cbd0e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaTokenizerSuite.java @@ -18,6 +18,7 @@ package org.apache.spark.ml.feature; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -26,7 +27,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -61,11 +62,11 @@ public class JavaTokenizerSuite { new TokenizerTestData("Test of tok.", new String[] {"Test", "tok."}), new TokenizerTestData("Te,st. punct", new String[] {"Te,st.", "punct"}) )); - DataFrame dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); + Dataset<Row> dataset = jsql.createDataFrame(rdd, TokenizerTestData.class); - Row[] pairs = myRegExTokenizer.transform(dataset) + List<Row> pairs = myRegExTokenizer.transform(dataset) .select("tokens", "wantedTokens") - .collect(); + .collectAsList(); for (Row r : pairs) { Assert.assertEquals(r.get(0), r.get(1)); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java index f8ba84ef77..e45e198043 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorAssemblerSuite.java @@ -28,7 +28,7 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -64,11 +64,11 @@ public class JavaVectorAssemblerSuite { Row row = RowFactory.create( 0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, new int[] {1}, new double[] {3.0}), 10L); - DataFrame dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); + Dataset<Row> dataset = sqlContext.createDataFrame(Arrays.asList(row), schema); VectorAssembler assembler = new VectorAssembler() .setInputCols(new String[] {"x", "y", "z", "n"}) .setOutputCol("features"); - DataFrame output = assembler.transform(dataset); + Dataset<Row> output = assembler.transform(dataset); Assert.assertEquals( Vectors.sparse(6, new int[] {1, 2, 4, 5}, new double[] {1.0, 2.0, 3.0, 10.0}), output.select("features").first().<Vector>getAs(0)); diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java index bfcca62fa1..fec6cac8be 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorIndexerSuite.java @@ -30,7 +30,8 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.feature.VectorIndexerSuite.FeatureData; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -57,7 +58,7 @@ public class JavaVectorIndexerSuite implements Serializable { new FeatureData(Vectors.dense(1.0, 4.0)) ); SQLContext sqlContext = new SQLContext(sc); - DataFrame data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); + Dataset<Row> data = sqlContext.createDataFrame(sc.parallelize(points, 2), FeatureData.class); VectorIndexer indexer = new VectorIndexer() .setInputCol("features") .setOutputCol("indexed") @@ -66,6 +67,6 @@ public class JavaVectorIndexerSuite implements Serializable { Assert.assertEquals(model.numFeatures(), 2); Map<Integer, Map<Double, Integer>> categoryMaps = model.javaCategoryMaps(); Assert.assertEquals(categoryMaps.size(), 1); - DataFrame indexedData = model.transform(data); + Dataset<Row> indexedData = model.transform(data); } } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java index 786c11c412..b87605ebfd 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaVectorSlicerSuite.java @@ -31,7 +31,7 @@ import org.apache.spark.ml.attribute.AttributeGroup; import org.apache.spark.ml.attribute.NumericAttribute; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -68,16 +68,17 @@ public class JavaVectorSlicerSuite { RowFactory.create(Vectors.dense(-2.0, 2.3, 0.0)) ); - DataFrame dataset = jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); + Dataset<Row> dataset = + jsql.createDataFrame(data, (new StructType()).add(group.toStructField())); VectorSlicer vectorSlicer = new VectorSlicer() .setInputCol("userFeatures").setOutputCol("features"); vectorSlicer.setIndices(new int[]{1}).setNames(new String[]{"f3"}); - DataFrame output = vectorSlicer.transform(dataset); + Dataset<Row> output = vectorSlicer.transform(dataset); - for (Row r : output.select("userFeatures", "features").take(2)) { + for (Row r : output.select("userFeatures", "features").takeRows(2)) { Vector features = r.getAs(1); Assert.assertEquals(features.size(), 2); } diff --git a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java index b292b1b06d..7517b70cc9 100644 --- a/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/feature/JavaWord2VecSuite.java @@ -26,7 +26,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -53,7 +53,7 @@ public class JavaWord2VecSuite { StructType schema = new StructType(new StructField[]{ new StructField("text", new ArrayType(DataTypes.StringType, true), false, Metadata.empty()) }); - DataFrame documentDF = sqlContext.createDataFrame( + Dataset<Row> documentDF = sqlContext.createDataFrame( Arrays.asList( RowFactory.create(Arrays.asList("Hi I heard about Spark".split(" "))), RowFactory.create(Arrays.asList("I wish Java could use case classes".split(" "))), @@ -66,9 +66,9 @@ public class JavaWord2VecSuite { .setVectorSize(3) .setMinCount(0); Word2VecModel model = word2Vec.fit(documentDF); - DataFrame result = model.transform(documentDF); + Dataset<Row> result = model.transform(documentDF); - for (Row r: result.select("result").collect()) { + for (Row r: result.select("result").collectAsList()) { double[] polyFeatures = ((Vector)r.get(0)).toArray(); Assert.assertEquals(polyFeatures.length, 3); } diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java index d5c9d120c5..a1575300a8 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaDecisionTreeRegressorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaDecisionTreeRegressorSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaDecisionTreeRegressorSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); // This tests setters. Training with various options is tested in Scala. DecisionTreeRegressor dt = new DecisionTreeRegressor() diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java index 38d15dc2b7..9477e8d2bf 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaGBTRegressorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaGBTRegressorSuite implements Serializable { @@ -57,7 +58,7 @@ public class JavaGBTRegressorSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); GBTRegressor rf = new GBTRegressor() .setMaxDepth(2) diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java index 4fb0b0d109..9f817515eb 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java @@ -28,7 +28,8 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite .generateLogisticInputAsList; @@ -38,7 +39,7 @@ public class JavaLinearRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; @Before @@ -64,7 +65,7 @@ public class JavaLinearRegressionSuite implements Serializable { assertEquals("auto", lr.getSolver()); LinearRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction"); + Dataset<Row> predictions = jsql.sql("SELECT label, prediction FROM prediction"); predictions.collect(); // Check defaults assertEquals("features", model.getFeaturesCol()); diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java index 31be8880c2..a90535d11a 100644 --- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaRandomForestRegressorSuite.java @@ -31,7 +31,8 @@ import org.apache.spark.mllib.classification.LogisticRegressionSuite; import org.apache.spark.ml.impl.TreeTests; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; public class JavaRandomForestRegressorSuite implements Serializable { @@ -58,7 +59,7 @@ public class JavaRandomForestRegressorSuite implements Serializable { JavaRDD<LabeledPoint> data = sc.parallelize( LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache(); Map<Integer, Integer> categoricalFeatures = new HashMap<>(); - DataFrame dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); + Dataset<Row> dataFrame = TreeTests.setMetadata(data, categoricalFeatures, 0); // This tests setters. Training with various options is tested in Scala. RandomForestRegressor rf = new RandomForestRegressor() diff --git a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java index 2976b38e45..b8ddf907d0 100644 --- a/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/source/libsvm/JavaLibSVMRelationSuite.java @@ -31,7 +31,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.DenseVector; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import org.apache.spark.util.Utils; @@ -68,7 +68,7 @@ public class JavaLibSVMRelationSuite { @Test public void verifyLibSVMDF() { - DataFrame dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") + Dataset<Row> dataset = sqlContext.read().format("libsvm").option("vectorType", "dense") .load(path); Assert.assertEquals("label", dataset.columns()[0]); Assert.assertEquals("features", dataset.columns()[1]); diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java index 08eeca53f0..24b0097454 100644 --- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java @@ -30,7 +30,8 @@ import org.apache.spark.ml.classification.LogisticRegression; import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator; import org.apache.spark.ml.param.ParamMap; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; +import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; @@ -38,7 +39,7 @@ public class JavaCrossValidatorSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; @Before public void setUp() { |