aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala22
2 files changed, 21 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 195a93e086..f406f8c426 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -169,7 +169,7 @@ class Pipeline @Since("1.4.0") (
override def copy(extra: ParamMap): Pipeline = {
val map = extractParamMap(extra)
val newStages = map(stages).map(_.copy(extra))
- new Pipeline().setStages(newStages)
+ new Pipeline(uid).setStages(newStages)
}
@Since("1.2.0")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 6413ca1f8b..dafc6c200f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -101,13 +101,31 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}
+ test("Pipeline.copy") {
+ val hashingTF = new HashingTF()
+ .setNumFeatures(100)
+ val pipeline = new Pipeline("pipeline").setStages(Array[Transformer](hashingTF))
+ val copied = pipeline.copy(ParamMap(hashingTF.numFeatures -> 10))
+
+ assert(copied.uid === pipeline.uid,
+ "copy should create an instance with the same UID")
+ assert(copied.getStages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
+ "copy should handle extra stage params")
+ }
+
test("PipelineModel.copy") {
val hashingTF = new HashingTF()
.setNumFeatures(100)
- val model = new PipelineModel("pipeline", Array[Transformer](hashingTF))
+ val model = new PipelineModel("pipelineModel", Array[Transformer](hashingTF))
+ .setParent(new Pipeline())
val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10))
- require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
+
+ assert(copied.uid === model.uid,
+ "copy should create an instance with the same UID")
+ assert(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10,
"copy should handle extra stage params")
+ assert(copied.parent === model.parent,
+ "copy should create an instance with the same parent")
}
test("pipeline model constructors") {