diff options
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java index 07936eb79b..45101f286c 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaNaiveBayesSuite.java @@ -29,7 +29,7 @@ import static org.junit.Assert.assertEquals; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; -import org.apache.spark.sql.DataFrame; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; import org.apache.spark.sql.SQLContext; @@ -55,8 +55,8 @@ public class JavaNaiveBayesSuite implements Serializable { jsc = null; } - public void validatePrediction(DataFrame predictionAndLabels) { - for (Row r : predictionAndLabels.collect()) { + public void validatePrediction(Dataset<Row> predictionAndLabels) { + for (Row r : predictionAndLabels.collectAsList()) { double prediction = r.getAs(0); double label = r.getAs(1); assertEquals(label, prediction, 1E-5); @@ -88,11 +88,11 @@ public class JavaNaiveBayesSuite implements Serializable { new StructField("features", new VectorUDT(), false, Metadata.empty()) }); - DataFrame dataset = jsql.createDataFrame(data, schema); + Dataset<Row> dataset = jsql.createDataFrame(data, schema); NaiveBayes nb = new NaiveBayes().setSmoothing(0.5).setModelType("multinomial"); NaiveBayesModel model = nb.fit(dataset); - DataFrame predictionAndLabels = model.transform(dataset).select("prediction", "label"); + Dataset<Row> predictionAndLabels = model.transform(dataset).select("prediction", "label"); validatePrediction(predictionAndLabels); } } |