aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java10
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
index 07936eb79b..45101f286c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java
@@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
-import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
@@ -55,8 +55,8 @@ public class JavaNaiveBayesSuite implements Serializable {
jsc = null;
}
- public void validatePrediction(DataFrame predictionAndLabels) {
- for (Row r : predictionAndLabels.collect()) {
+ public void validatePrediction(Dataset<Row> predictionAndLabels) {
+ for (Row r : predictionAndLabels.collectAsList()) {
double prediction = r.getAs(0);
double label = r.getAs(1);
assertEquals(label, prediction, 1E-5);
@@ -88,11 +88,11 @@ public class JavaNaiveBayesSuite implements Serializable {
new StructField("features", new VectorUDT(), false, Metadata.empty())
});
- DataFrame dataset = jsql.createDataFrame(data, schema);
+ Dataset<Row> dataset = jsql.createDataFrame(data, schema);
NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial");
NaiveBayesModel model = nb.fit(dataset);
- DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label");
+ Dataset<Row> predictionAndLabels = model.transform(dataset).select("prediction", "label");
validatePrediction(predictionAndLabels);
}
}