author    Xiangrui Meng <meng@databricks.com>    2015-05-04 11:28:59 -0700
committer Xiangrui Meng <meng@databricks.com>    2015-05-04 11:28:59 -0700
commit    e0833c5958bbd73ff27cfe6865648d7b6e5a99bc (patch)
tree      373883fa46f206ffcd34c4d0b67ce246b61bbc93 /examples/src/main/scala
parent    5a1a1075a607be683f008ef92fa227803370c45f (diff)
[SPARK-5956] [MLLIB] Pipeline components should be copyable.
This PR added `copy(extra: ParamMap): Params` to `Params`, which makes a copy of the current instance with a randomly generated uid and some extra param values. With this change, we only need to implement `fit` and `transform` without extra param values, given the default implementation of `fit(dataset, extra)`:

~~~scala
def fit(dataset: DataFrame, extra: ParamMap): Model = {
  copy(extra).fit(dataset)
}
~~~

Inside `fit` and `transform`, since only the embedded values are used, I added `$` as an alias for `getOrDefault` to make the code easier to read. For example, in `LinearRegression.fit` we have:

~~~scala
val effectiveRegParam = $(regParam) / yStd
val effectiveL1RegParam = $(elasticNetParam) * effectiveRegParam
val effectiveL2RegParam = (1.0 - $(elasticNetParam)) * effectiveRegParam
~~~

Meta-algorithms like `Pipeline` implement their own `copy(extra)`, so the fitted pipeline model stores copies of all its stages (no matter whether each stage is a transformer or a model).

Other changes:

* `Params$.inheritValues` is moved to `Params#copyValues` and returns the target instance.
* `fittingParamMap` was removed because the `parent` reference carries this information.
* `validate` was renamed to `validateParams` to be more precise.

TODOs:

* [x] add tests for newly added methods
* [ ] update documentation

jkbradley dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #5820 from mengxr/SPARK-5956 and squashes the following commits:

7bef88d [Xiangrui Meng] address comments
05229c3 [Xiangrui Meng] assert -> assertEquals
b2927b1 [Xiangrui Meng] organize imports
f14456b [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956
93e7924 [Xiangrui Meng] add tests for hasParam & copy
463ecae [Xiangrui Meng] merge master
2b954c3 [Xiangrui Meng] update Binarizer
465dd12 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5956
282a1a8 [Xiangrui Meng] fix test
819dd2d [Xiangrui Meng] merge master
b642872 [Xiangrui Meng] example code runs
5a67779 [Xiangrui Meng] examples compile
c76b4d1 [Xiangrui Meng] fix all unit tests
0f4fd64 [Xiangrui Meng] fix some tests
9286a22 [Xiangrui Meng] copyValues to trained models
53e0973 [Xiangrui Meng] move inheritValues to Params and rename it to copyValues
9ee004e [Xiangrui Meng] merge copy and copyWith; rename validate to validateParams
d882afc [Xiangrui Meng] test compile
f082a31 [Xiangrui Meng] make Params copyable and simply handling of extra params in all spark.ml components
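As a rough, self-contained illustration of the copy/copyValues mechanics described above, here is a minimal sketch using toy stand-in types; `Param`, `ParamMap`, `Params`, and `Estimator` below are simplified assumptions, not the actual spark.ml classes:

~~~scala
// Toy stand-ins for the spark.ml types, to show the copy/copyValues flow.
final case class Param[T](name: String)

final class ParamMap(val values: Map[Param[_], Any] = Map.empty) {
  def put[T](p: Param[T], v: T): ParamMap = new ParamMap(values + (p -> v))
}

trait Params {
  val paramValues = scala.collection.mutable.Map.empty[Param[_], Any]
  def set[T](p: Param[T], v: T): this.type = { paramValues(p) = v; this }
  // `$` reads an embedded param value (the alias for getOrDefault).
  protected def $[T](p: Param[T]): T = paramValues(p).asInstanceOf[T]
  // Copy this instance's param values into `to`, then overlay `extra`.
  def copyValues[T <: Params](to: T, extra: ParamMap): T = {
    paramValues.foreach { case (p, v) => to.paramValues(p) = v }
    extra.values.foreach { case (p, v) => to.paramValues(p) = v }
    to
  }
  def copy(extra: ParamMap): Params
}

// With copy(extra) available, fit(dataset, extra) needs no per-algorithm code:
abstract class Estimator[M] extends Params {
  def copy(extra: ParamMap): Estimator[M]
  def fit(dataset: Seq[Double]): M
  final def fit(dataset: Seq[Double], extra: ParamMap): M = copy(extra).fit(dataset)
}
~~~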
Diffstat (limited to 'examples/src/main/scala')

-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala  |  6
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala  | 22
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala           |  4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala  |  6
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala  |  4

5 files changed, 16 insertions(+), 26 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 9002e99d82..8340d91101 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -276,16 +276,14 @@ object DecisionTreeExample {
// Get the trained Decision Tree from the fitted PipelineModel
algo match {
case "classification" =>
- val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
- dt.asInstanceOf[DecisionTreeClassifier])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}
case "regression" =>
- val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
- dt.asInstanceOf[DecisionTreeRegressor])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
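With `getModel` removed, the trained stage now comes straight off the fitted model's `stages` array, as in the hunk above. A hedged sketch of that retrieval pattern, modeled on this example file (the tiny training set, the label-indexing step, and the object name `StagesLastSketch` are illustrative assumptions, not part of this diff):

~~~scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.{Pipeline, PipelineStage}
import org.apache.spark.ml.classification.{DecisionTreeClassificationModel, DecisionTreeClassifier}
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types.StringType

object StagesLastSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("StagesLastSketch"))
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    // Tiny illustrative training set with "label" and "features" columns.
    val training = sc.parallelize(Seq(
      LabeledPoint(0.0, Vectors.dense(0.0, 1.0)),
      LabeledPoint(1.0, Vectors.dense(1.0, 0.0)))).toDF()

    // Index the label (as DecisionTreeExample does) so the classifier gets
    // the label metadata it needs.
    val df = training.withColumn("labelString", training("label").cast(StringType))
    val labelIndexer = new StringIndexer()
      .setInputCol("labelString")
      .setOutputCol("indexedLabel")
    val dt = new DecisionTreeClassifier().setLabelCol("indexedLabel")
    val pipelineModel = new Pipeline()
      .setStages(Array[PipelineStage](labelIndexer, dt))
      .fit(df)

    // The fitted PipelineModel holds the trained copies of its stages in
    // order, so the classifier's model is the last entry.
    val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
    println(treeModel) // Print model summary.

    sc.stop()
  }
}
~~~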
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 2245fa429f..2a2d067727 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -18,13 +18,12 @@
package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
-import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
+import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams}
+import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-
/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
* Transformer, and other abstractions.
@@ -99,7 +98,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
* class since the maxIter parameter is only used during training (not in the Model).
*/
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
- def getMaxIter: Int = getOrDefault(maxIter)
+ def getMaxIter: Int = $(maxIter)
}
/**
@@ -117,18 +116,16 @@ private class MyLogisticRegression
def setMaxIter(value: Int): this.type = set(maxIter, value)
// This method is used by fit()
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): MyLogisticRegressionModel = {
+ override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
// Extract columns from data using helper method.
- val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val oldDataset = extractLabeledPoints(dataset)
// Do learning to estimate the weight vector.
val numFeatures = oldDataset.take(1)(0).features.size
val weights = Vectors.zeros(numFeatures) // Learning would happen here.
// Create a model, and return it.
- new MyLogisticRegressionModel(this, paramMap, weights)
+ new MyLogisticRegressionModel(this, weights)
}
}
@@ -139,7 +136,6 @@ private class MyLogisticRegression
*/
private class MyLogisticRegressionModel(
override val parent: MyLogisticRegression,
- override val fittingParamMap: ParamMap,
val weights: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
@@ -176,9 +172,7 @@ private class MyLogisticRegressionModel(
*
* This is used for the default implementation of [[transform()]].
*/
- override protected def copy(): MyLogisticRegressionModel = {
- val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
- Params.inheritValues(extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): MyLogisticRegressionModel = {
+ copyValues(new MyLogisticRegressionModel(parent, weights), extra)
}
}
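With this `copy(extra)` override in place, extra param values flow through copies instead of being threaded through `train`. A hedged usage sketch for the classes above (the helper `trainWithOverride` and its DataFrame argument are illustrative assumptions, not part of the example):

~~~scala
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.DataFrame

// Illustrative helper: `training` is any DataFrame with "label" and
// "features" columns, as produced elsewhere in this example.
def trainWithOverride(training: DataFrame): MyLogisticRegressionModel = {
  val lr = new MyLogisticRegression().setMaxIter(5)
  // fit(dataset, extra) now defaults to copy(extra).fit(dataset): the extra
  // map overrides embedded values on a copy, leaving `lr` itself unchanged.
  val model = lr.fit(training, ParamMap(lr.maxIter -> 10))
  // A model copy carries the parent's param values over via copyValues.
  model.copy(ParamMap.empty)
}
~~~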
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
index 5fccb142d4..c5899b6683 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -201,14 +201,14 @@ object GBTExample {
// Get the trained GBT from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
index 9b909324ec..7f88d2681b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -209,16 +209,14 @@ object RandomForestExample {
// Get the trained Random Forest from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
- dt.asInstanceOf[RandomForestClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
- dt.asInstanceOf[RandomForestRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index bf805149d0..e8a991f50e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -63,7 +63,7 @@ object SimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
- println("Model 1 was fit using parameters: " + model1.fittingParamMap)
+ println("Model 1 was fit using parameters: " + model1.parent.extractParamMap())
// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
@@ -78,7 +78,7 @@ object SimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training.toDF(), paramMapCombined)
- println("Model 2 was fit using parameters: " + model2.fittingParamMap)
+ println("Model 2 was fit using parameters: " + model2.parent.extractParamMap())
// Prepare test data.
val test = sc.parallelize(Seq(