aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-04 11:28:59 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-04 11:28:59 -0700
commite0833c5958bbd73ff27cfe6865648d7b6e5a99bc (patch)
tree373883fa46f206ffcd34c4d0b67ce246b61bbc93 /examples
parent5a1a1075a607be683f008ef92fa227803370c45f (diff)
downloadspark-e0833c5958bbd73ff27cfe6865648d7b6e5a99bc.tar.gz
spark-e0833c5958bbd73ff27cfe6865648d7b6e5a99bc.tar.bz2
spark-e0833c5958bbd73ff27cfe6865648d7b6e5a99bc.zip
[SPARK-5956] [MLLIB] Pipeline components should be copyable.
This PR added `copy(extra: ParamMap): Params` to `Params`, which makes a copy of the current instance with a randomly generated uid and some extra param values. With this change, we only need to implement `fit` and `transform` without extra param values given the default implementation of `fit(dataset, extra)`: ~~~scala def fit(dataset: DataFrame, extra: ParamMap): Model = { copy(extra).fit(dataset) } ~~~ Inside `fit` and `transform`, since only the embedded values are used, I added `$` as an alias for `getOrDefault` to make the code easier to read. For example, in `LinearRegression.fit` we have: ~~~scala val effectiveRegParam = $(regParam) / yStd val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam ~~~ Meta-algorithm like `Pipeline` implements its own `copy(extra)`. So the fitted pipeline model stored all copied stages (no matter whether it is a transformer or a model). Other changes: * `Params$.inheritValues` is moved to `Params!.copyValues` and returns the target instance. * `fittingParamMap` was removed because the `parent` carries this information. * `validate` was renamed to `validateParams` to be more precise. TODOs: * [x] add tests for newly added methods * [ ] update documentation jkbradley dbtsai Author: Xiangrui Meng <meng@databricks.com> Closes #5820 from mengxr/SPARK-5956 and squashes the following commits: 7bef88d [Xiangrui Meng] address comments 05229c3 [Xiangrui Meng] assert -> assertEquals b2927b1 [Xiangrui Meng] organize imports f14456b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956 93e7924 [Xiangrui Meng] add tests for hasParam & copy 463ecae [Xiangrui Meng] merge master 2b954c3 [Xiangrui Meng] update Binarizer 465dd12 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956 282a1a8 [Xiangrui Meng] fix test 819dd2d [Xiangrui Meng] merge master b642872 [Xiangrui Meng] example code runs 5a67779 [Xiangrui Meng] examples compile c76b4d1 [Xiangrui Meng] fix all unit tests 0f4fd64 [Xiangrui Meng] fix some tests 9286a22 [Xiangrui Meng] copyValues to trained models 53e0973 [Xiangrui Meng] move inheritValues to Params and rename it to copyValues 9ee004e [Xiangrui Meng] merge copy and copyWith; rename validate to validateParams d882afc [Xiangrui Meng] test compile f082a31 [Xiangrui Meng] make Params copyable and simply handling of extra params in all spark.ml components
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java24
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala22
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala6
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala4
7 files changed, 25 insertions, 45 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 46377a99c4..eac4f898a4 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -28,7 +28,6 @@ import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.param.Params$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -129,16 +128,16 @@ class MyJavaLogisticRegression
// This method is used by fit().
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
- public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
+ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
// Extract columns from data using helper method.
- JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
+ JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
// Do learning to estimate the weight vector.
int numFeatures = oldDataset.take(1).get(0).features().size();
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
// Create a model, and return it.
- return new MyJavaLogisticRegressionModel(this, paramMap, weights);
+ return new MyJavaLogisticRegressionModel(this, weights);
}
}
@@ -155,18 +154,11 @@ class MyJavaLogisticRegressionModel
private MyJavaLogisticRegression parent_;
public MyJavaLogisticRegression parent() { return parent_; }
- private ParamMap fittingParamMap_;
- public ParamMap fittingParamMap() { return fittingParamMap_; }
-
private Vector weights_;
public Vector weights() { return weights_; }
- public MyJavaLogisticRegressionModel(
- MyJavaLogisticRegression parent_,
- ParamMap fittingParamMap_,
- Vector weights_) {
+ public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) {
this.parent_ = parent_;
- this.fittingParamMap_ = fittingParamMap_;
this.weights_ = weights_;
}
@@ -210,10 +202,8 @@ class MyJavaLogisticRegressionModel
* In Java, we have to make this method public since Java does not understand Scala's protected
* modifier.
*/
- public MyJavaLogisticRegressionModel copy() {
- MyJavaLogisticRegressionModel m =
- new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
- Params$.MODULE$.inheritValues(this.extractParamMap(), this, m);
- return m;
+ @Override
+ public MyJavaLogisticRegressionModel copy(ParamMap extra) {
+ return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra);
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 4e02acce69..29158d5c85 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -71,7 +71,7 @@ public class JavaSimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
- System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());
+ System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());
// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
@@ -87,7 +87,7 @@ public class JavaSimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
- System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());
+ System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
// Prepare test documents.
List<LabeledPoint> localTest = Lists.newArrayList(
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 9002e99d82..8340d91101 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -276,16 +276,14 @@ object DecisionTreeExample {
// Get the trained Decision Tree from the fitted PipelineModel
algo match {
case "classification" =>
- val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
- dt.asInstanceOf[DecisionTreeClassifier])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}
case "regression" =>
- val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
- dt.asInstanceOf[DecisionTreeRegressor])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 2245fa429f..2a2d067727 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -18,13 +18,12 @@
package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
-import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
+import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams}
+import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-
/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
* Transformer, and other abstractions.
@@ -99,7 +98,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
* class since the maxIter parameter is only used during training (not in the Model).
*/
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
- def getMaxIter: Int = getOrDefault(maxIter)
+ def getMaxIter: Int = $(maxIter)
}
/**
@@ -117,18 +116,16 @@ private class MyLogisticRegression
def setMaxIter(value: Int): this.type = set(maxIter, value)
// This method is used by fit()
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): MyLogisticRegressionModel = {
+ override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
// Extract columns from data using helper method.
- val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val oldDataset = extractLabeledPoints(dataset)
// Do learning to estimate the weight vector.
val numFeatures = oldDataset.take(1)(0).features.size
val weights = Vectors.zeros(numFeatures) // Learning would happen here.
// Create a model, and return it.
- new MyLogisticRegressionModel(this, paramMap, weights)
+ new MyLogisticRegressionModel(this, weights)
}
}
@@ -139,7 +136,6 @@ private class MyLogisticRegression
*/
private class MyLogisticRegressionModel(
override val parent: MyLogisticRegression,
- override val fittingParamMap: ParamMap,
val weights: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
@@ -176,9 +172,7 @@ private class MyLogisticRegressionModel(
*
* This is used for the default implementation of [[transform()]].
*/
- override protected def copy(): MyLogisticRegressionModel = {
- val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
- Params.inheritValues(extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): MyLogisticRegressionModel = {
+ copyValues(new MyLogisticRegressionModel(parent, weights), extra)
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
index 5fccb142d4..c5899b6683 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -201,14 +201,14 @@ object GBTExample {
// Get the trained GBT from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
index 9b909324ec..7f88d2681b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -209,16 +209,14 @@ object RandomForestExample {
// Get the trained Random Forest from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
- dt.asInstanceOf[RandomForestClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
- dt.asInstanceOf[RandomForestRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index bf805149d0..e8a991f50e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -63,7 +63,7 @@ object SimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
- println("Model 1 was fit using parameters: " + model1.fittingParamMap)
+ println("Model 1 was fit using parameters: " + model1.parent.extractParamMap())
// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
@@ -78,7 +78,7 @@ object SimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training.toDF(), paramMapCombined)
- println("Model 2 was fit using parameters: " + model2.fittingParamMap)
+ println("Model 2 was fit using parameters: " + model2.parent.extractParamMap())
// Prepare test data.
val test = sc.parallelize(Seq(