aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java21
1 files changed, 10 insertions, 11 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 76eb7f0032..2eba83335b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -26,21 +26,20 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
+import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
- private transient JavaSQLContext jsql;
- private transient JavaSchemaRDD dataset;
+ private transient SQLContext jsql;
+ private transient SchemaRDD dataset;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
- jsql = new JavaSQLContext(jsc);
+ jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
}
@@ -56,8 +55,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collect();
+ SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collectAsList();
}
@Test
@@ -68,8 +67,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
.registerTempTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collect();
+ SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collectAsList();
}
@Test