diff options
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 2 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala | 22 |
2 files changed, 21 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 195a93e086..f406f8c426 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -169,7 +169,7 @@ class Pipeline @Since("1.4.0") ( override def copy(extra: ParamMap): Pipeline = { val map = extractParamMap(extra) val newStages = map(stages).map(_.copy(extra)) - new Pipeline().setStages(newStages) + new Pipeline(uid).setStages(newStages) } @Since("1.2.0") diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala index 6413ca1f8b..dafc6c200f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala @@ -101,13 +101,31 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul } } + test("Pipeline.copy") { + val hashingTF = new HashingTF() + .setNumFeatures(100) + val pipeline = new Pipeline("pipeline").setStages(Array[Transformer](hashingTF)) + val copied = pipeline.copy(ParamMap(hashingTF.numFeatures -> 10)) + + assert(copied.uid === pipeline.uid, + "copy should create an instance with the same UID") + assert(copied.getStages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + "copy should handle extra stage params") + } + test("PipelineModel.copy") { val hashingTF = new HashingTF() .setNumFeatures(100) - val model = new PipelineModel("pipeline", Array[Transformer](hashingTF)) + val model = new PipelineModel("pipelineModel", Array[Transformer](hashingTF)) + .setParent(new Pipeline()) val copied = model.copy(ParamMap(hashingTF.numFeatures -> 10)) - require(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, + + assert(copied.uid === model.uid, + "copy should create an instance with the same UID") + assert(copied.stages(0).asInstanceOf[HashingTF].getNumFeatures === 10, "copy should handle extra stage params") + assert(copied.parent === model.parent, + "copy should create an instance with the same parent") } test("pipeline model constructors") { |