diff options
author | lewuathe <lewuathe@me.com> | 2015-08-13 09:17:19 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-13 09:17:19 -0700 |
commit | 2932e25da4532de9e86b01d08bce0cb680874e70 (patch) | |
tree | 07acdab578dc3c7a543babfcef1a6e3a8227807d /examples | |
parent | 69930310115501f0de094fe6f5c6c60dade342bd (diff) | |
download | spark-2932e25da4532de9e86b01d08bce0cb680874e70.tar.gz spark-2932e25da4532de9e86b01d08bce0cb680874e70.tar.bz2 spark-2932e25da4532de9e86b01d08bce0cb680874e70.zip |
[SPARK-9073] [ML] spark.ml Models copy() should call setParent when there is a parent
Copied ML models must have the same parent of original ones
Author: lewuathe <lewuathe@me.com>
Author: Lewuathe <lewuathe@me.com>
Closes #7447 from Lewuathe/SPARK-9073.
Diffstat (limited to 'examples')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java | 3 | ||||
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala | 2 |
2 files changed, 3 insertions, 2 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index 9df26ffca5..3f1fe900b0 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -230,6 +230,7 @@ class MyJavaLogisticRegressionModel */ @Override public MyJavaLogisticRegressionModel copy(ParamMap extra) { - return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra); + return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra) + .setParent(parent()); } } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 78f31b4ffe..340c3559b1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -179,7 +179,7 @@ private class MyLogisticRegressionModel( * This is used for the default implementation of [[transform()]]. */ override def copy(extra: ParamMap): MyLogisticRegressionModel = { - copyValues(new MyLogisticRegressionModel(uid, weights), extra) + copyValues(new MyLogisticRegressionModel(uid, weights), extra).setParent(parent) } } // scalastyle:on println |