path: root/mllib
author     Reynold Xin <rxin@databricks.com>  2015-01-16 21:09:06 -0800
committer  Reynold Xin <rxin@databricks.com>  2015-01-16 21:09:06 -0800
commit     61b427d4b1c4934bd70ed4da844b64f0e9a377aa (patch)
tree       5068b31119fa7e2256422d4fdf18703ae64d7ab2 /mllib
parent     ee1c1f3a04dfe80843432e349f01178e47f02443 (diff)
download   spark-61b427d4b1c4934bd70ed4da844b64f0e9a377aa.tar.gz
           spark-61b427d4b1c4934bd70ed4da844b64f0e9a377aa.tar.bz2
           spark-61b427d4b1c4934bd70ed4da844b64f0e9a377aa.zip
[SPARK-5193][SQL] Remove Spark SQL Java-specific API.
After the following patches, the main (Scala) API is now usable for Java users directly.

https://github.com/apache/spark/pull/4056
https://github.com/apache/spark/pull/4054
https://github.com/apache/spark/pull/4049
https://github.com/apache/spark/pull/4030
https://github.com/apache/spark/pull/3965
https://github.com/apache/spark/pull/3958

Author: Reynold Xin <rxin@databricks.com>

Closes #4065 from rxin/sql-java-api and squashes the following commits:

b1fd860 [Reynold Xin] Fix Mima
6d86578 [Reynold Xin] Ok one more attempt in fixing Python...
e8f1455 [Reynold Xin] Fix Python again...
3e53f91 [Reynold Xin] Fixed Python.
83735da [Reynold Xin] Fix BigDecimal test.
e9f1de3 [Reynold Xin] Use scala BigDecimal.
500d2c4 [Reynold Xin] Fix Decimal.
ba3bfa2 [Reynold Xin] Updated javadoc for RowFactory.
c4ae1c5 [Reynold Xin] [SPARK-5193][SQL] Remove Spark SQL Java-specific API.
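In practical terms for the mllib tests below, Java code stops using the Java-specific wrappers and calls the Scala classes directly. A minimal before/after sketch in Java, reusing the identifiers from the test changes in this diff (jsc, points); the surrounding setup is illustrative, not part of the patch:

    // Before this patch: Java callers went through Java-specific wrapper classes.
    // JavaSQLContext jsql = new JavaSQLContext(jsc);
    // JavaSchemaRDD dataset = jsql.applySchema(points, LabeledPoint.class);

    // After this patch: the Scala API is usable directly from Java.
    SQLContext jsql = new SQLContext(jsc);                              // org.apache.spark.sql.SQLContext
    SchemaRDD dataset = jsql.applySchema(points, LabeledPoint.class);   // org.apache.spark.sql.SchemaRDD
    SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
    predictions.collectAsList();   // collectAsList() returns a java.util.List rather than a Scala Seq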
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Estimator.scala                                  38
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala                                24
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java                            17
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java   21
-rw-r--r--  mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java               13
5 files changed, 24 insertions(+), 89 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index fdbee743e8..77d230eb4a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -18,12 +18,10 @@
package org.apache.spark.ml
import scala.annotation.varargs
-import scala.collection.JavaConverters._
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.api.java.JavaSchemaRDD
/**
* :: AlphaComponent ::
@@ -66,40 +64,4 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
paramMaps.map(fit(dataset, _))
}
-
- // Java-friendly versions of fit.
-
- /**
- * Fits a single model to the input data with optional parameters.
- *
- * @param dataset input dataset
- * @param paramPairs optional list of param pairs (overwrite embedded params)
- * @return fitted model
- */
- @varargs
- def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
- fit(dataset.schemaRDD, paramPairs: _*)
- }
-
- /**
- * Fits a single model to the input data with provided parameter map.
- *
- * @param dataset input dataset
- * @param paramMap parameter map
- * @return fitted model
- */
- def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
- fit(dataset.schemaRDD, paramMap)
- }
-
- /**
- * Fits multiple models to the input data with multiple sets of parameters.
- *
- * @param dataset input dataset
- * @param paramMaps an array of parameter maps
- * @return fitted models, matching the input parameter maps
- */
- def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
- fit(dataset.schemaRDD, paramMaps).asJava
- }
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 1331b91240..af56f9c435 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.api.java.JavaSchemaRDD
import org.apache.spark.sql.catalyst.analysis.Star
import org.apache.spark.sql.catalyst.expressions.ScalaUdf
import org.apache.spark.sql.types._
@@ -55,29 +54,6 @@ abstract class Transformer extends PipelineStage with Params {
* @return transformed dataset
*/
def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
-
- // Java-friendly versions of transform.
-
- /**
- * Transforms the dataset with optional parameters.
- * @param dataset input datset
- * @param paramPairs optional list of param pairs, overwrite embedded params
- * @return transformed dataset
- */
- @varargs
- def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = {
- transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD
- }
-
- /**
- * Transforms the dataset with provided parameter map as additional parameters.
- * @param dataset input dataset
- * @param paramMap additional parameters, overwrite embedded params
- * @return transformed dataset
- */
- def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = {
- transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD
- }
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 42846677ed..47f1f46c6c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -26,10 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
+import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
/**
* Test Pipeline construction and fitting in Java.
@@ -37,13 +36,13 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite
public class JavaPipelineSuite {
private transient JavaSparkContext jsc;
- private transient JavaSQLContext jsql;
- private transient JavaSchemaRDD dataset;
+ private transient SQLContext jsql;
+ private transient SchemaRDD dataset;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaPipelineSuite");
- jsql = new JavaSQLContext(jsc);
+ jsql = new SQLContext(jsc);
JavaRDD<LabeledPoint> points =
jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
dataset = jsql.applySchema(points, LabeledPoint.class);
@@ -66,7 +65,7 @@ public class JavaPipelineSuite {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collect();
+ SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 76eb7f0032..2eba83335b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -26,21 +26,20 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
+import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaLogisticRegressionSuite implements Serializable {
private transient JavaSparkContext jsc;
- private transient JavaSQLContext jsql;
- private transient JavaSchemaRDD dataset;
+ private transient SQLContext jsql;
+ private transient SchemaRDD dataset;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
- jsql = new JavaSQLContext(jsc);
+ jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
}
@@ -56,8 +55,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegression lr = new LogisticRegression();
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collect();
+ SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collectAsList();
}
@Test
@@ -68,8 +67,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
.registerTempTable("prediction");
- JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collect();
+ SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ predictions.collectAsList();
}
@Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index a266ebd207..a9f1c4a2c3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -30,21 +30,20 @@ import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
import org.apache.spark.ml.param.ParamMap;
import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
- .generateLogisticInputAsList;
+import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
public class JavaCrossValidatorSuite implements Serializable {
private transient JavaSparkContext jsc;
- private transient JavaSQLContext jsql;
- private transient JavaSchemaRDD dataset;
+ private transient SQLContext jsql;
+ private transient SchemaRDD dataset;
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
- jsql = new JavaSQLContext(jsc);
+ jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
}