diff options
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java')
-rw-r--r-- | mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java index fd22eb6dca..536f0dc58f 100644 --- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java @@ -31,16 +31,16 @@ import org.apache.spark.api.java.JavaSparkContext; import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList; import org.apache.spark.mllib.linalg.Vector; import org.apache.spark.mllib.regression.LabeledPoint; -import org.apache.spark.sql.DataFrame; -import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; +import org.apache.spark.sql.SQLContext; public class JavaLogisticRegressionSuite implements Serializable { private transient JavaSparkContext jsc; private transient SQLContext jsql; - private transient DataFrame dataset; + private transient Dataset<Row> dataset; private transient JavaRDD<LabeledPoint> datasetRDD; private double eps = 1e-5; @@ -67,7 +67,7 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(lr.getLabelCol(), "label"); LogisticRegressionModel model = lr.fit(dataset); model.transform(dataset).registerTempTable("prediction"); - DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); + Dataset<Row> predictions = jsql.sql("SELECT label, probability, prediction FROM prediction"); predictions.collectAsList(); // Check defaults Assert.assertEquals(0.5, model.getThreshold(), eps); @@ -96,14 +96,14 @@ public class JavaLogisticRegressionSuite implements Serializable { // Modify model params, and check that the params worked. model.setThreshold(1.0); model.transform(dataset).registerTempTable("predAllZero"); - DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); + Dataset<Row> predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero"); for (Row r: predAllZero.collectAsList()) { Assert.assertEquals(0.0, r.getDouble(0), eps); } // Call transform with params, and check that the params worked. model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb")) .registerTempTable("predNotAllZero"); - DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); + Dataset<Row> predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero"); boolean foundNonZero = false; for (Row r: predNotAllZero.collectAsList()) { if (r.getDouble(0) != 0.0) foundNonZero = true; @@ -129,8 +129,8 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(2, model.numClasses()); model.transform(dataset).registerTempTable("transformed"); - DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); - for (Row row: trans1.collect()) { + Dataset<Row> trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed"); + for (Row row: trans1.collectAsList()) { Vector raw = (Vector)row.get(0); Vector prob = (Vector)row.get(1); Assert.assertEquals(raw.size(), 2); @@ -140,8 +140,8 @@ public class JavaLogisticRegressionSuite implements Serializable { Assert.assertEquals(0, Math.abs(prob.apply(0) - (1.0 - probFromRaw1)), eps); } - DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); - for (Row row: trans2.collect()) { + Dataset<Row> trans2 = jsql.sql("SELECT prediction, probability FROM transformed"); + for (Row row: trans2.collectAsList()) { double pred = row.getDouble(0); Vector prob = (Vector)row.get(1); double probOfPred = prob.apply((int)pred); |