author	Xiangrui Meng <meng@databricks.com>	2015-05-14 01:22:15 -0700
committer	Xiangrui Meng <meng@databricks.com>	2015-05-14 01:22:15 -0700
commit	1b8625f4258d6d1a049d0ba60e39e9757f5a568b (patch)
tree	cb6c44497bc20939bad4fa30e8b59ab17f64a9bf /examples
parent	13e652b61a81b2d2e94088006fbd5fd4ed383e3d (diff)
[SPARK-7407] [MLLIB] use uid + name to identify parameters
A param instance is strongly attached to a parent in the current implementation, so if we make a copy of an estimator or a transformer in pipelines and other meta-algorithms, it becomes error-prone to copy the params to the copied instances. In this PR, a param is identified by its parent's UID and the param name, so it becomes loosely attached to its parent and all of its derivatives. The UID is preserved during copying and fitting. All components now have a default constructor and a constructor that takes a UID as input. I kept the constructors for Param in this PR to reduce the size of the diff, and moved `parent` to a mutable field. This PR still needs some clean-ups, and there are several spark.ml PRs pending. I'll try to get them merged first and then update this PR. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #6019 from mengxr/SPARK-7407 and squashes the following commits:

c4c8120 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7407
520f0a2 [Xiangrui Meng] address comments
2569168 [Xiangrui Meng] fix tests
873caca [Xiangrui Meng] fix tests in OneVsRest; fix a racing condition in shouldOwn
409ea08 [Xiangrui Meng] minor updates
83a163c [Xiangrui Meng] update JavaDeveloperApiExample
5db5325 [Xiangrui Meng] update OneVsRest
7bde7ae [Xiangrui Meng] merge master
697fdf9 [Xiangrui Meng] update Bucketizer
7b4f6c2 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7407
629d402 [Xiangrui Meng] fix LRSuite
154516f [Xiangrui Meng] merge master
aa4a611 [Xiangrui Meng] fix examples/compile
a4794dd [Xiangrui Meng] change Param to use to reduce the size of diff
fdbc415 [Xiangrui Meng] all tests passed
c255f17 [Xiangrui Meng] fix tests in ParamsSuite
818e1db [Xiangrui Meng] merge master
e1160cf [Xiangrui Meng] fix tests
fbc39f0 [Xiangrui Meng] pass test:compile
108937e [Xiangrui Meng] pass compile
8726d39 [Xiangrui Meng] use parent uid in Param
eaeed35 [Xiangrui Meng] update Identifiable
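The convention described above is easiest to see in a small sketch. The snippet below is illustrative only (the class name, uid prefix, and param are made up, not taken from the patch); it shows the two constructors every spark.ml component now carries and how a param records its owner's uid rather than an object reference, following the same pattern as the updated examples in the diff below.

import org.apache.spark.ml.param.IntParam
import org.apache.spark.ml.util.Identifiable

// Minimal sketch of the uid convention (class, prefix, and param names are illustrative).
class MyComponent(override val uid: String) extends Identifiable {

  // The default constructor draws a random, prefixed uid, e.g. "myComp_a1b2c3d4e5f6",
  // so callers that do not care about the uid can still write `new MyComponent()`.
  def this() = this(Identifiable.randomUID("myComp"))

  // The param is identified by (parent uid, name), not by the owning object itself,
  // so an original component and its copies refer to the same logical parameter.
  val maxIter = new IntParam(this, "maxIter", "maximum number of iterations")
}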
Diffstat (limited to 'examples')
-rw-r--r--	examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java	43
-rw-r--r--	examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala	11
2 files changed, 39 insertions, 15 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index eac4f898a4..ec533d174e 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -28,6 +28,7 @@ import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -103,7 +104,23 @@ public class JavaDeveloperApiExample {
* However, this should still compile and run successfully.
*/
class MyJavaLogisticRegression
- extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
+ extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
+
+ public MyJavaLogisticRegression() {
+ init();
+ }
+
+ public MyJavaLogisticRegression(String uid) {
+ this.uid_ = uid;
+ init();
+ }
+
+ private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");
+
+ @Override
+ public String uid() {
+ return uid_;
+ }
/**
* Param for max number of iterations
@@ -117,7 +134,7 @@ class MyJavaLogisticRegression
int getMaxIter() { return (Integer) getOrDefault(maxIter); }
- public MyJavaLogisticRegression() {
+ private void init() {
setMaxIter(100);
}
@@ -137,7 +154,7 @@ class MyJavaLogisticRegression
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
// Create a model, and return it.
- return new MyJavaLogisticRegressionModel(this, weights);
+ return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
}
}
@@ -149,17 +166,21 @@ class MyJavaLogisticRegression
* However, this should still compile and run successfully.
*/
class MyJavaLogisticRegressionModel
- extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
-
- private MyJavaLogisticRegression parent_;
- public MyJavaLogisticRegression parent() { return parent_; }
+ extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
private Vector weights_;
public Vector weights() { return weights_; }
- public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) {
- this.parent_ = parent_;
- this.weights_ = weights_;
+ public MyJavaLogisticRegressionModel(String uid, Vector weights) {
+ this.uid_ = uid;
+ this.weights_ = weights;
+ }
+
+ private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");
+
+ @Override
+ public String uid() {
+ return uid_;
}
// This uses the default implementation of transform(), which reads column "features" and outputs
@@ -204,6 +225,6 @@ class MyJavaLogisticRegressionModel
*/
@Override
public MyJavaLogisticRegressionModel copy(ParamMap extra) {
- return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra);
+ return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra);
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 2a2d067727..3ee456edbe 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams}
import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
@@ -106,10 +107,12 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
*
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
*/
-private class MyLogisticRegression
+private class MyLogisticRegression(override val uid: String)
extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
+ def this() = this(Identifiable.randomUID("myLogReg"))
+
setMaxIter(100) // Initialize
// The parameter setter is in this class since it should return type MyLogisticRegression.
@@ -125,7 +128,7 @@ private class MyLogisticRegression
val weights = Vectors.zeros(numFeatures) // Learning would happen here.
// Create a model, and return it.
- new MyLogisticRegressionModel(this, weights)
+ new MyLogisticRegressionModel(uid, weights).setParent(this)
}
}
@@ -135,7 +138,7 @@ private class MyLogisticRegression
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
*/
private class MyLogisticRegressionModel(
- override val parent: MyLogisticRegression,
+ override val uid: String,
val weights: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
@@ -173,6 +176,6 @@ private class MyLogisticRegressionModel(
* This is used for the default implementation of [[transform()]].
*/
override def copy(extra: ParamMap): MyLogisticRegressionModel = {
- copyValues(new MyLogisticRegressionModel(parent, weights), extra)
+ copyValues(new MyLogisticRegressionModel(uid, weights), extra)
}
}
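
As a usage note for the Scala example above (a sketch, not part of the patch; `training` is a placeholder for a DataFrame of labeled points such as the one built in the example's main method), the uid assigned at construction travels to the fitted model, and the parent estimator is now attached to the model after construction via setParent:

val lr = new MyLogisticRegression().setMaxIter(10)  // draws a random uid, e.g. "myLogReg_..."
val model = lr.fit(training)  // train() builds the model with lr's uid and calls setParent(this)

assert(model.uid == lr.uid)   // the uid is preserved from estimator to fitted model
assert(model.parent eq lr)    // parent is set as a mutable field, not passed to the constructor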