about | summary | refs | log | tree | commit | diff
path: root/mllib/src
diff options
context:
space:
mode:
author: Joseph K. Bradley <joseph@databricks.com> 2015-08-12 10:48:52 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-08-12 10:50:11 -0700
commit: b515f890defe96149e56580e8ed2c00febf7dc8e (patch)
tree: cdae6ef81a1416b7bd5c24c403935839e330795d /mllib/src
parent: e9641f192dc6a949cfb8fa1614d446026c7bf4b3 (diff)
download: spark-b515f890defe96149e56580e8ed2c00febf7dc8e.tar.gz
spark-b515f890defe96149e56580e8ed2c00febf7dc8e.tar.bz2
spark-b515f890defe96149e56580e8ed2c00febf7dc8e.zip
[SPARK-9847] [ML] Modified copyValues to distinguish between default, explicit param values
From JIRA:

Currently, Params.copyValues copies default parameter values to the paramMap of the target instance, rather than the defaultParamMap. It should copy to the defaultParamMap because explicitly setting a parameter can change the semantics.

This issue arose in SPARK-9789, where 2 params "threshold" and "thresholds" for LogisticRegression can have mutually exclusive values. If thresholds is set, then fit() will copy the default value of threshold as well, easily resulting in inconsistent settings for the 2 params.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8115 from jkbradley/copyvalues-fix.

(cherry picked from commit 70fe558867ccb4bcff6ec673438b03608bb02252)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 19
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala | 8
2 files changed, 24 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index d68f5ff005..91c0a56313 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -559,13 +559,26 @@ trait Params extends Identifiable with Serializable {
/**
* Copies param values from this instance to another instance for params shared by them.
- * @param to the target instance
- * @param extra extra params to be copied
+ *
+ * This handles default Params and explicitly set Params separately.
+ * Default Params are copied from and to [[defaultParamMap]], and explicitly set Params are
+ * copied from and to [[paramMap]].
+ * Warning: This implicitly assumes that this [[Params]] instance and the target instance
+ * share the same set of default Params.
+ *
+ * @param to the target instance, which should work with the same set of default Params as this
+ * source instance
+ * @param extra extra params to be copied to the target's [[paramMap]]
* @return the target instance with param values copied
*/
protected def copyValues[T <: Params](to: T, extra: ParamMap = ParamMap.empty): T = {
- val map = extractParamMap(extra)
+ val map = paramMap ++ extra
params.foreach { param =>
+ // copy default Params
+ if (defaultParamMap.contains(param) && to.hasParam(param.name)) {
+ to.defaultParamMap.put(to.getParam(param.name), defaultParamMap(param))
+ }
+ // copy explicitly set Params
if (map.contains(param) && to.hasParam(param.name)) {
to.set(param.name, map(param))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
index 050d4170ea..be95638d81 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/param/ParamsSuite.scala
@@ -200,6 +200,14 @@ class ParamsSuite extends SparkFunSuite {
val inArray = ParamValidators.inArray[Int](Array(1, 2))
assert(inArray(1) && inArray(2) && !inArray(0))
}
+
+ test("Params.copyValues") {
+ val t = new TestParams()
+ val t2 = t.copy(ParamMap.empty)
+ assert(!t2.isSet(t2.maxIter))
+ val t3 = t.copy(ParamMap(t.maxIter -> 20))
+ assert(t3.isSet(t3.maxIter))
+ }
}
object ParamsSuite extends SparkFunSuite {