diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2015-08-12 10:48:52 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-08-12 10:48:52 -0700 |
commit | 70fe558867ccb4bcff6ec673438b03608bb02252 (patch) | |
tree | f4f02935c3e5964ca7b00f068aab2ed6c2276bf8 /mllib/src/test | |
parent | 57ec27dd7784ce15a2ece8a6c8ac7bd5fd25aea2 (diff) | |
download | spark-70fe558867ccb4bcff6ec673438b03608bb02252.tar.gz spark-70fe558867ccb4bcff6ec673438b03608bb02252.tar.bz2 spark-70fe558867ccb4bcff6ec673438b03608bb02252.zip |
[SPARK-9847] [ML] Modified copyValues to distinguish between default, explicit param values
From JIRA: Currently, Params.copyValues copies default parameter values to the paramMap of the target instance, rather than the defaultParamMap. It should copy to the defaultParamMap because explicitly setting a parameter can change the semantics.
This issue arose in SPARK-9789, where 2 params "threshold" and "thresholds" for LogisticRegression can have mutually exclusive values. If thresholds is set, then fit() will copy the default value of threshold as well, easily resulting in inconsistent settings for the 2 params.
CC: mengxr
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #8115 from jkbradley/copyvalues-fix.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 8 |
1 files changed, 8 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala index 050d4170ea..be95638d81 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala @@ -200,6 +200,14 @@ class ParamsSuite extends SparkFunSuite { val inArray = ParamValidators.inArray[Int](Array(1, 2)) assert(inArray(1) && inArray(2) && !inArray(0)) } + + test("Params.copyValues") { + val t = new TestParams() + val t2 = t.copy(ParamMap.empty) + assert(!t2.isSet(t2.maxIter)) + val t3 = t.copy(ParamMap(t.maxIter -> 20)) + assert(t3.isSet(t3.maxIter)) + } } object ParamsSuite extends SparkFunSuite { |