diff options
author | Reynold Xin <rxin@databricks.com> | 2015-01-16 21:09:06 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-01-16 21:09:06 -0800 |
commit | 61b427d4b1c4934bd70ed4da844b64f0e9a377aa (patch) | |
tree | 5068b31119fa7e2256422d4fdf18703ae64d7ab2 /mllib | |
parent | ee1c1f3a04dfe80843432e349f01178e47f02443 (diff) | |
download | spark-61b427d4b1c4934bd70ed4da844b64f0e9a377aa.tar.gz spark-61b427d4b1c4934bd70ed4da844b64f0e9a377aa.tar.bz2 spark-61b427d4b1c4934bd70ed4da844b64f0e9a377aa.zip |
[SPARK-5193][SQL] Remove Spark SQL Java-specific API.
After the following patches, the main (Scala) API is now usable for Java users directly.
https://github.com/apache/spark/pull/4056
https://github.com/apache/spark/pull/4054
https://github.com/apache/spark/pull/4049
https://github.com/apache/spark/pull/4030
https://github.com/apache/spark/pull/3965
https://github.com/apache/spark/pull/3958
Author: Reynold Xin <rxin@databricks.com>
Closes #4065 from rxin/sql-java-api and squashes the following commits:
b1fd860 [Reynold Xin] Fix Mima
6d86578 [Reynold Xin] Ok one more attempt in fixing Python...
e8f1455 [Reynold Xin] Fix Python again...
3e53f91 [Reynold Xin] Fixed Python.
83735da [Reynold Xin] Fix BigDecimal test.
e9f1de3 [Reynold Xin] Use scala BigDecimal.
500d2c4 [Reynold Xin] Fix Decimal.
ba3bfa2 [Reynold Xin] Updated javadoc for RowFactory.
c4ae1c5 [Reynold Xin] [SPARK-5193][SQL] Remove Spark SQL Java-specific API.
Diffstat (limited to 'mllib')
5 files changed, 24 insertions, 89 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index fdbee743e8..77d230eb4a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -18,12 +18,10 @@
 package org.apache.spark.ml
 
 import scala.annotation.varargs
-import scala.collection.JavaConverters._
 
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param.{ParamMap, ParamPair, Params}
 import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.api.java.JavaSchemaRDD
 
 /**
  * :: AlphaComponent ::
@@ -66,40 +64,4 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
   def fit(dataset: SchemaRDD, paramMaps: Array[ParamMap]): Seq[M] = {
     paramMaps.map(fit(dataset, _))
   }
-
-  // Java-friendly versions of fit.
-
-  /**
-   * Fits a single model to the input data with optional parameters.
-   *
-   * @param dataset input dataset
-   * @param paramPairs optional list of param pairs (overwrite embedded params)
-   * @return fitted model
-   */
-  @varargs
-  def fit(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): M = {
-    fit(dataset.schemaRDD, paramPairs: _*)
-  }
-
-  /**
-   * Fits a single model to the input data with provided parameter map.
-   *
-   * @param dataset input dataset
-   * @param paramMap parameter map
-   * @return fitted model
-   */
-  def fit(dataset: JavaSchemaRDD, paramMap: ParamMap): M = {
-    fit(dataset.schemaRDD, paramMap)
-  }
-
-  /**
-   * Fits multiple models to the input data with multiple sets of parameters.
-   *
-   * @param dataset input dataset
-   * @param paramMaps an array of parameter maps
-   * @return fitted models, matching the input parameter maps
-   */
-  def fit(dataset: JavaSchemaRDD, paramMaps: Array[ParamMap]): java.util.List[M] = {
-    fit(dataset.schemaRDD, paramMaps).asJava
-  }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 1331b91240..af56f9c435 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.SchemaRDD
-import org.apache.spark.sql.api.java.JavaSchemaRDD
 import org.apache.spark.sql.catalyst.analysis.Star
 import org.apache.spark.sql.catalyst.expressions.ScalaUdf
 import org.apache.spark.sql.types._
@@ -55,29 +54,6 @@ abstract class Transformer extends PipelineStage with Params {
    * @return transformed dataset
    */
   def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD
-
-  // Java-friendly versions of transform.
-
-  /**
-   * Transforms the dataset with optional parameters.
-   * @param dataset input datset
-   * @param paramPairs optional list of param pairs, overwrite embedded params
-   * @return transformed dataset
-   */
-  @varargs
-  def transform(dataset: JavaSchemaRDD, paramPairs: ParamPair[_]*): JavaSchemaRDD = {
-    transform(dataset.schemaRDD, paramPairs: _*).toJavaSchemaRDD
-  }
-
-  /**
-   * Transforms the dataset with provided parameter map as additional parameters.
-   * @param dataset input dataset
-   * @param paramMap additional parameters, overwrite embedded params
-   * @return transformed dataset
-   */
-  def transform(dataset: JavaSchemaRDD, paramMap: ParamMap): JavaSchemaRDD = {
-    transform(dataset.schemaRDD, paramMap).toJavaSchemaRDD
-  }
 }
 
 /**
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 42846677ed..47f1f46c6c 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -26,10 +26,9 @@ import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.regression.LabeledPoint;
 import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.feature.StandardScaler;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
-  .generateLogisticInputAsList;
+import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 
 /**
  * Test Pipeline construction and fitting in Java.
@@ -37,13 +36,13 @@ import static org.apache.spark.mllib.classification.LogisticRegressionSuite
 public class JavaPipelineSuite {
 
   private transient JavaSparkContext jsc;
-  private transient JavaSQLContext jsql;
-  private transient JavaSchemaRDD dataset;
+  private transient SQLContext jsql;
+  private transient SchemaRDD dataset;
 
   @Before
   public void setUp() {
     jsc = new JavaSparkContext("local", "JavaPipelineSuite");
-    jsql = new JavaSQLContext(jsc);
+    jsql = new SQLContext(jsc);
     JavaRDD<LabeledPoint> points =
       jsc.parallelize(generateLogisticInputAsList(1.0, 1.0, 100, 42), 2);
     dataset = jsql.applySchema(points, LabeledPoint.class);
@@ -66,7 +65,7 @@ public class JavaPipelineSuite {
       .setStages(new PipelineStage[] {scaler, lr});
     PipelineModel model = pipeline.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
-    predictions.collect();
+    SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    predictions.collectAsList();
   }
 }
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 76eb7f0032..2eba83335b 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -26,21 +26,20 @@ import org.junit.Test;
 
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
-  .generateLogisticInputAsList;
+import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 
 public class JavaLogisticRegressionSuite implements Serializable {
 
   private transient JavaSparkContext jsc;
-  private transient JavaSQLContext jsql;
-  private transient JavaSchemaRDD dataset;
+  private transient SQLContext jsql;
+  private transient SchemaRDD dataset;
 
   @Before
   public void setUp() {
     jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
-    jsql = new JavaSQLContext(jsc);
+    jsql = new SQLContext(jsc);
     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
     dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
   }
@@ -56,8 +55,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
     LogisticRegression lr = new LogisticRegression();
     LogisticRegressionModel model = lr.fit(dataset);
     model.transform(dataset).registerTempTable("prediction");
-    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
-    predictions.collect();
+    SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    predictions.collectAsList();
   }
 
   @Test
@@ -68,8 +67,8 @@ public class JavaLogisticRegressionSuite implements Serializable {
     LogisticRegressionModel model = lr.fit(dataset);
     model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
       .registerTempTable("prediction");
-    JavaSchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
-    predictions.collect();
+    SchemaRDD predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+    predictions.collectAsList();
   }
 
   @Test
diff --git a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
index a266ebd207..a9f1c4a2c3 100644
--- a/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/tuning/JavaCrossValidatorSuite.java
@@ -30,21 +30,20 @@ import org.apache.spark.ml.classification.LogisticRegression;
 import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;
 import org.apache.spark.ml.param.ParamMap;
 import org.apache.spark.mllib.regression.LabeledPoint;
-import org.apache.spark.sql.api.java.JavaSQLContext;
-import org.apache.spark.sql.api.java.JavaSchemaRDD;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite
-  .generateLogisticInputAsList;
+import org.apache.spark.sql.SchemaRDD;
+import org.apache.spark.sql.SQLContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
 
 public class JavaCrossValidatorSuite implements Serializable {
 
   private transient JavaSparkContext jsc;
-  private transient JavaSQLContext jsql;
-  private transient JavaSchemaRDD dataset;
+  private transient SQLContext jsql;
+  private transient SchemaRDD dataset;
 
   @Before
   public void setUp() {
     jsc = new JavaSparkContext("local", "JavaCrossValidatorSuite");
-    jsql = new JavaSQLContext(jsc);
+    jsql = new SQLContext(jsc);
     List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
     dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
   }
|