aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-04-13 21:18:05 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-13 21:18:05 -0700
commit971b95b0c9002bd541bcbe0da54a9967ba22588f (patch)
treeb2a79cf00c1d2290e7e4024df27c0ee9b203c09a /examples
parent0ba3fdd5992cf09bd38303ebff34d2ed19e5e09b (diff)
downloadspark-971b95b0c9002bd541bcbe0da54a9967ba22588f.tar.gz
spark-971b95b0c9002bd541bcbe0da54a9967ba22588f.tar.bz2
spark-971b95b0c9002bd541bcbe0da54a9967ba22588f.zip
[SPARK-5957][ML] better handling of parameters
The design doc was posted on the JIRA page. Python changes will be in a follow-up PR. jkbradley 1. Use codegen for shared params. 1. Move shared params to package `ml.param.shared`. 1. Set default values in `Params` instead of in `Param`. 1. Add a few methods to `Params` and `ParamMap`. 1. Move schema handling to `SchemaUtils` from `Params`. - [x] check visibility of the methods added Author: Xiangrui Meng <meng@databricks.com> Closes #5431 from mengxr/SPARK-5957 and squashes the following commits: d19236d [Xiangrui Meng] fix test 26ae2d7 [Xiangrui Meng] re-gen code and mark clear protected 38b78c7 [Xiangrui Meng] update Param.toString and remove Params.explain() 409e2d5 [Xiangrui Meng] address comments 2d637bd [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957 eec2264 [Xiangrui Meng] make get* public in Params 4090d95 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-5957 4fee9e7 [Xiangrui Meng] re-gen shared params 2737c2d [Xiangrui Meng] rename SharedParamCodeGen to SharedParamsCodeGen e938f81 [Xiangrui Meng] update code to set default parameter values 28ed322 [Xiangrui Meng] merge master 55be1f3 [Xiangrui Meng] merge master d63b5cc [Xiangrui Meng] fix examples 29b004c [Xiangrui Meng] update ParamsSuite 94fd98e [Xiangrui Meng] fix explain params 48d0e84 [Xiangrui Meng] add remove and update explainParams 4ac6348 [Xiangrui Meng] move schema utils to SchemaUtils add a few methods to Params 0d9594e [Xiangrui Meng] add getOrElse to ParamMap eeeffe8 [Xiangrui Meng] map ++ paramMap => extractValues 0d3fc5b [Xiangrui Meng] setDefault after param a9dbf59 [Xiangrui Meng] minor updates d9302b8 [Xiangrui Meng] generate default values 1c72579 [Xiangrui Meng] pass test compile abb7a3b [Xiangrui Meng] update default values handling dcab97a [Xiangrui Meng] add codegen for shared params
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java4
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala6
2 files changed, 5 insertions, 5 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 19d0eb2168..eaf00d09f5 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -116,7 +116,7 @@ class MyJavaLogisticRegression
*/
IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");
- int getMaxIter() { return (Integer) get(maxIter); }
+ int getMaxIter() { return (Integer) getOrDefault(maxIter); }
public MyJavaLogisticRegression() {
setMaxIter(100);
@@ -211,7 +211,7 @@ class MyJavaLogisticRegressionModel
public MyJavaLogisticRegressionModel copy() {
MyJavaLogisticRegressionModel m =
new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
- Params$.MODULE$.inheritValues(this.paramMap(), this, m);
+ Params$.MODULE$.inheritValues(this.extractParamMap(), this, m);
return m;
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index df26798e41..2245fa429f 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -99,7 +99,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
* class since the maxIter parameter is only used during training (not in the Model).
*/
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
- def getMaxIter: Int = get(maxIter)
+ def getMaxIter: Int = getOrDefault(maxIter)
}
/**
@@ -174,11 +174,11 @@ private class MyLogisticRegressionModel(
* Create a copy of the model.
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
*
- * This is used for the defaul implementation of [[transform()]].
+ * This is used for the default implementation of [[transform()]].
*/
override protected def copy(): MyLogisticRegressionModel = {
val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
- Params.inheritValues(this.paramMap, this, m)
+ Params.inheritValues(extractParamMap(), this, m)
m
}
}