diff options
author | Xiangrui Meng <meng@databricks.com> | 2016-04-11 09:28:28 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-11 09:28:28 -0700 |
commit | 1c751fcf488189e5176546fe0d00f560ffcf1cec (patch) | |
tree | 40863cdb5ac52b6fdc74a22d64853ea07826e6be /examples/src | |
parent | e82d95bf63f57cefa02dc545ceb451ecdeedce28 (diff) | |
download | spark-1c751fcf488189e5176546fe0d00f560ffcf1cec.tar.gz spark-1c751fcf488189e5176546fe0d00f560ffcf1cec.tar.bz2 spark-1c751fcf488189e5176546fe0d00f560ffcf1cec.zip |
[SPARK-14500] [ML] Accept Dataset[_] instead of DataFrame in MLlib APIs
## What changes were proposed in this pull request?
This PR updates MLlib APIs to accept `Dataset[_]` as input where `DataFrame` was the input type. This PR doesn't change the output type. In Java, `Dataset[_]` maps to `Dataset<?>`, which includes `Dataset<Row>`. Some implementations were changed in order to return `DataFrame`. Tests and examples were updated. Note that this is a breaking change for subclasses of Transformer/Estimator.
Lol, we don't have to rename the input argument, which has been `dataset` since Spark 1.2.
TODOs:
- [x] update MiMaExcludes (seems all covered by explicit filters from SPARK-13920)
- [x] Python
- [x] add a new test to accept Dataset[LabeledPoint]
- [x] remove unused imports of Dataset
## How was this patch tested?
Exiting unit tests with some modifications.
cc: rxin jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #12274 from mengxr/SPARK-14500.
Diffstat (limited to 'examples/src')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java | 2 | ||||
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala | 4 |
2 files changed, 3 insertions, 3 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index fbd8817669..0ba94786d4 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -146,7 +146,7 @@ class MyJavaLogisticRegression // This method is used by fit(). // In Java, we have to make it public since Java does not understand Scala's protected modifier. - public MyJavaLogisticRegressionModel train(Dataset<Row> dataset) { + public MyJavaLogisticRegressionModel train(Dataset<?> dataset) { // Extract columns from data using helper method. JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD(); diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index c1f63c6a1d..8d127f9b35 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -24,7 +24,7 @@ import org.apache.spark.ml.param.{IntParam, ParamMap} import org.apache.spark.ml.util.Identifiable import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} /** * A simple example demonstrating how to write your own learning algorithm using Estimator, @@ -120,7 +120,7 @@ private class MyLogisticRegression(override val uid: String) def setMaxIter(value: Int): this.type = set(maxIter, value) // This method is used by fit() - override protected def train(dataset: DataFrame): MyLogisticRegressionModel = { + override protected def train(dataset: Dataset[_]): MyLogisticRegressionModel = { // Extract columns from data using helper method. val oldDataset = extractLabeledPoints(dataset) |