diff options
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java | 9 |
1 files changed, 5 insertions, 4 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java index ec6b4bf3c0..d499d363f1 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java @@ -19,6 +19,7 @@ package org.apache.spark.ml.classification; import java.io.Serializable; import java.util.Arrays; +import java.util.List; import org.junit.After; import org.junit.Assert; @@ -28,7 +29,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.SQLContext; @@ -52,7 +53,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable { @Test public void testMLPC() { - DataFrame dataFrame = sqlContext.createDataFrame( + Dataset<Row> dataFrame = sqlContext.createDataFrame( jsc.parallelize(Arrays.asList( new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)), new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)), @@ -65,8 +66,8 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable { .setSeed(11L) .setMaxIter(100); MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame); - DataFrame result = model.transform(dataFrame); - Row[] predictionAndLabels = result.select("prediction", "label").collect(); + Dataset<Row> result = model.transform(dataFrame); + List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList(); for (Row r: predictionAndLabels) { Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1)); } |