aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-14 01:22:15 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-14 01:22:15 -0700
commit1b8625f4258d6d1a049d0ba60e39e9757f5a568b (patch)
treecb6c44497bc20939bad4fa30e8b59ab17f64a9bf /mllib/src/test/java/org/apache
parent13e652b61a81b2d2e94088006fbd5fd4ed383e3d (diff)
downloadspark-1b8625f4258d6d1a049d0ba60e39e9757f5a568b.tar.gz
spark-1b8625f4258d6d1a049d0ba60e39e9757f5a568b.tar.bz2
spark-1b8625f4258d6d1a049d0ba60e39e9757f5a568b.zip
[SPARK-7407] [MLLIB] use uid + name to identify parameters
A param instance is strongly attached to an parent in the current implementation. So if we make a copy of an estimator or a transformer in pipelines and other meta-algorithms, it becomes error-prone to copy the params to the copied instances. In this PR, a param is identified by its parent's UID and the param name. So it becomes loosely attached to its parent and all its derivatives. The UID is preserved during copying or fitting. All components now have a default constructor and a constructor that takes a UID as input. I keep the constructors for Param in this PR to reduce the amount of diff and moved `parent` as a mutable field. This PR still needs some clean-ups, and there are several spark.ml PRs pending. I'll try to get them merged first and then update this PR. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #6019 from mengxr/SPARK-7407 and squashes the following commits: c4c8120 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7407 520f0a2 [Xiangrui Meng] address comments 2569168 [Xiangrui Meng] fix tests 873caca [Xiangrui Meng] fix tests in OneVsRest; fix a racing condition in shouldOwn 409ea08 [Xiangrui Meng] minor updates 83a163c [Xiangrui Meng] update JavaDeveloperApiExample 5db5325 [Xiangrui Meng] update OneVsRest 7bde7ae [Xiangrui Meng] merge master 697fdf9 [Xiangrui Meng] update Bucketizer 7b4f6c2 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7407 629d402 [Xiangrui Meng] fix LRSuite 154516f [Xiangrui Meng] merge master aa4a611 [Xiangrui Meng] fix examples/compile a4794dd [Xiangrui Meng] change Param to use to reduce the size of diff fdbc415 [Xiangrui Meng] all tests passed c255f17 [Xiangrui Meng] fix tests in ParamsSuite 818e1db [Xiangrui Meng] merge master e1160cf [Xiangrui Meng] fix tests fbc39f0 [Xiangrui Meng] pass test:compile 108937e [Xiangrui Meng] pass compile 8726d39 [Xiangrui Meng] use parent uid in Param eaeed35 [Xiangrui Meng] update Identifiable
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java52
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala40
4 files changed, 81 insertions, 19 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index 7e7189a2b1..f75e024a71 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -84,7 +84,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
.setThreshold(0.6)
.setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
- LogisticRegression parent = model.parent();
+ LogisticRegression parent = (LogisticRegression) model.parent();
assert(parent.getMaxIter() == 10);
assert(parent.getRegParam() == 1.0);
assert(parent.getThreshold() == 0.6);
@@ -110,7 +110,7 @@ public class JavaLogisticRegressionSuite implements Serializable {
// Call fit() with new params, and check as many params as we can.
LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
- LogisticRegression parent2 = model2.parent();
+ LogisticRegression parent2 = (LogisticRegression) model2.parent();
assert(parent2.getMaxIter() == 5);
assert(parent2.getRegParam() == 0.1);
assert(parent2.getThreshold() == 0.4);
diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
index 8abe575610..3a41890b92 100644
--- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
+++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java
@@ -21,43 +21,65 @@ import java.util.List;
import com.google.common.collect.Lists;
+import org.apache.spark.ml.util.Identifiable$;
+
/**
* A subclass of Params for testing.
*/
public class JavaTestParams extends JavaParams {
- public IntParam myIntParam;
+ public JavaTestParams() {
+ this.uid_ = Identifiable$.MODULE$.randomUID("javaTestParams");
+ init();
+ }
+
+ public JavaTestParams(String uid) {
+ this.uid_ = uid;
+ init();
+ }
+
+ private String uid_;
+
+ @Override
+ public String uid() {
+ return uid_;
+ }
- public int getMyIntParam() { return (Integer)getOrDefault(myIntParam); }
+ private IntParam myIntParam_;
+ public IntParam myIntParam() { return myIntParam_; }
+
+ public int getMyIntParam() { return (Integer)getOrDefault(myIntParam_); }
public JavaTestParams setMyIntParam(int value) {
- set(myIntParam, value); return this;
+ set(myIntParam_, value); return this;
}
- public DoubleParam myDoubleParam;
+ private DoubleParam myDoubleParam_;
+ public DoubleParam myDoubleParam() { return myDoubleParam_; }
- public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam); }
+ public double getMyDoubleParam() { return (Double)getOrDefault(myDoubleParam_); }
public JavaTestParams setMyDoubleParam(double value) {
- set(myDoubleParam, value); return this;
+ set(myDoubleParam_, value); return this;
}
- public Param<String> myStringParam;
+ private Param<String> myStringParam_;
+ public Param<String> myStringParam() { return myStringParam_; }
- public String getMyStringParam() { return (String)getOrDefault(myStringParam); }
+ public String getMyStringParam() { return getOrDefault(myStringParam_); }
public JavaTestParams setMyStringParam(String value) {
- set(myStringParam, value); return this;
+ set(myStringParam_, value); return this;
}
- public JavaTestParams() {
- myIntParam = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
- myDoubleParam = new DoubleParam(this, "myDoubleParam", "this is a double param",
+ private void init() {
+ myIntParam_ = new IntParam(this, "myIntParam", "this is an int param", ParamValidators.gt(0));
+ myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param",
ParamValidators.inRange(0.0, 1.0));
List<String> validStrings = Lists.newArrayList("a", "b");
- myStringParam = new Param<String>(this, "myStringParam", "this is a string param",
+ myStringParam_ = new Param<String>(this, "myStringParam", "this is a string param",
ParamValidators.inArray(validStrings));
- setDefault(myIntParam, 1);
- setDefault(myDoubleParam, 0.5);
+ setDefault(myIntParam_, 1);
+ setDefault(myDoubleParam_, 0.5);
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
index a82b86d560..d591a45686 100644
--- a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -77,14 +77,14 @@ public class JavaLinearRegressionSuite implements Serializable {
.setMaxIter(10)
.setRegParam(1.0);
LinearRegressionModel model = lr.fit(dataset);
- LinearRegression parent = model.parent();
+ LinearRegression parent = (LinearRegression) model.parent();
assertEquals(10, parent.getMaxIter());
assertEquals(1.0, parent.getRegParam(), 0.0);
// Call fit() with new params, and check as many params as we can.
LinearRegressionModel model2 =
lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
- LinearRegression parent2 = model2.parent();
+ LinearRegression parent2 = (LinearRegression) model2.parent();
assertEquals(5, parent2.getMaxIter());
assertEquals(0.1, parent2.getRegParam(), 0.0);
assertEquals("thePred", model2.getPredictionCol());
diff --git a/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
new file mode 100644
index 0000000000..67c262d0f9
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/util/IdentifiableSuite.scala
@@ -0,0 +1,40 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.scalatest.FunSuite
+
+class IdentifiableSuite extends FunSuite {
+
+ import IdentifiableSuite.Test
+
+ test("Identifiable") {
+ val test0 = new Test("test_0")
+ assert(test0.uid === "test_0")
+
+ val test1 = new Test
+ assert(test1.uid.startsWith("test_"))
+ }
+}
+
+object IdentifiableSuite {
+
+ class Test(override val uid: String) extends Identifiable {
+ def this() = this(Identifiable.randomUID("test"))
+ }
+}