aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java/org
diff options
context:
space:
mode:
author: Xiangrui Meng <meng@databricks.com> 2015-05-14 01:22:15 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-05-14 01:22:15 -0700
commit: 1b8625f4258d6d1a049d0ba60e39e9757f5a568b (patch)
tree: cb6c44497bc20939bad4fa30e8b59ab17f64a9bf /examples/src/main/java/org
parent: 13e652b61a81b2d2e94088006fbd5fd4ed383e3d (diff)
download: spark-1b8625f4258d6d1a049d0ba60e39e9757f5a568b.tar.gz
spark-1b8625f4258d6d1a049d0ba60e39e9757f5a568b.tar.bz2
spark-1b8625f4258d6d1a049d0ba60e39e9757f5a568b.zip
[SPARK-7407] [MLLIB] use uid + name to identify parameters
A param instance is strongly attached to a parent in the current implementation. So if we make a copy of an estimator or a transformer in pipelines and other meta-algorithms, it becomes error-prone to copy the params to the copied instances. In this PR, a param is identified by its parent's UID and the param name. So it becomes loosely attached to its parent and all its derivatives. The UID is preserved during copying or fitting. All components now have a default constructor and a constructor that takes a UID as input. I kept the constructors for Param in this PR to reduce the amount of diff and made `parent` a mutable field. This PR still needs some clean-ups, and there are several spark.ml PRs pending. I'll try to get them merged first and then update this PR. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #6019 from mengxr/SPARK-7407 and squashes the following commits: c4c8120 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7407 520f0a2 [Xiangrui Meng] address comments 2569168 [Xiangrui Meng] fix tests 873caca [Xiangrui Meng] fix tests in OneVsRest; fix a racing condition in shouldOwn 409ea08 [Xiangrui Meng] minor updates 83a163c [Xiangrui Meng] update JavaDeveloperApiExample 5db5325 [Xiangrui Meng] update OneVsRest 7bde7ae [Xiangrui Meng] merge master 697fdf9 [Xiangrui Meng] update Bucketizer 7b4f6c2 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7407 629d402 [Xiangrui Meng] fix LRSuite 154516f [Xiangrui Meng] merge master aa4a611 [Xiangrui Meng] fix examples/compile a4794dd [Xiangrui Meng] change Param to use to reduce the size of diff fdbc415 [Xiangrui Meng] all tests passed c255f17 [Xiangrui Meng] fix tests in ParamsSuite 818e1db [Xiangrui Meng] merge master e1160cf [Xiangrui Meng] fix tests fbc39f0 [Xiangrui Meng] pass test:compile 108937e [Xiangrui Meng] pass compile 8726d39 [Xiangrui Meng] use parent uid in Param eaeed35 [Xiangrui Meng] update Identifiable
Diffstat (limited to 'examples/src/main/java/org')
-rw-r--r-- examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java | 43
1 files changed, 32 insertions, 11 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index eac4f898a4..ec533d174e 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -28,6 +28,7 @@ import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -103,7 +104,23 @@ public class JavaDeveloperApiExample {
* However, this should still compile and run successfully.
*/
class MyJavaLogisticRegression
- extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
+ extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
+
+ public MyJavaLogisticRegression() {
+ init();
+ }
+
+ public MyJavaLogisticRegression(String uid) {
+ this.uid_ = uid;
+ init();
+ }
+
+ private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");
+
+ @Override
+ public String uid() {
+ return uid_;
+ }
/**
* Param for max number of iterations
@@ -117,7 +134,7 @@ class MyJavaLogisticRegression
int getMaxIter() { return (Integer) getOrDefault(maxIter); }
- public MyJavaLogisticRegression() {
+ private void init() {
setMaxIter(100);
}
@@ -137,7 +154,7 @@ class MyJavaLogisticRegression
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
// Create a model, and return it.
- return new MyJavaLogisticRegressionModel(this, weights);
+ return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
}
}
@@ -149,17 +166,21 @@ class MyJavaLogisticRegression
* However, this should still compile and run successfully.
*/
class MyJavaLogisticRegressionModel
- extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
-
- private MyJavaLogisticRegression parent_;
- public MyJavaLogisticRegression parent() { return parent_; }
+ extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
private Vector weights_;
public Vector weights() { return weights_; }
- public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) {
- this.parent_ = parent_;
- this.weights_ = weights_;
+ public MyJavaLogisticRegressionModel(String uid, Vector weights) {
+ this.uid_ = uid;
+ this.weights_ = weights;
+ }
+
+ private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");
+
+ @Override
+ public String uid() {
+ return uid_;
}
// This uses the default implementation of transform(), which reads column "features" and outputs
@@ -204,6 +225,6 @@ class MyJavaLogisticRegressionModel
*/
@Override
public MyJavaLogisticRegressionModel copy(ParamMap extra) {
- return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra);
+ return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra);
}
}