aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java9
1 files changed, 5 insertions, 4 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index ec6b4bf3c0..d499d363f1 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification;
import java.io.Serializable;
import java.util.Arrays;
+import java.util.List;
import org.junit.After;
import org.junit.Assert;
@@ -28,7 +29,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
@@ -52,7 +53,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
@Test
public void testMLPC() {
- DataFrame dataFrame = sqlContext.createDataFrame(
+ Dataset<Row> dataFrame = sqlContext.createDataFrame(
jsc.parallelize(Arrays.asList(
new LabeledPoint(0.0, Vectors.dense(0.0, 0.0)),
new LabeledPoint(1.0, Vectors.dense(0.0, 1.0)),
@@ -65,8 +66,8 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
.setSeed(11L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
- DataFrame result = model.transform(dataFrame);
- Row[] predictionAndLabels = result.select("prediction", "label").collect();
+ Dataset<Row> result = model.transform(dataFrame);
+ List<Row> predictionAndLabels = result.select("prediction", "label").collectAsList();
for (Row r: predictionAndLabels) {
Assert.assertEquals((int) r.getDouble(0), (int) r.getDouble(1));
}