aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala10
1 files changed, 5 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index a455341a1f..8eddf79cdf 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -84,7 +84,7 @@ class Pipeline extends Estimator[PipelineModel] {
/** param for pipeline stages */
val stages: Param[Array[PipelineStage]] = new Param(this, "stages", "stages of the pipeline")
def setStages(value: Array[PipelineStage]): this.type = { set(stages, value); this }
- def getStages: Array[PipelineStage] = get(stages)
+ def getStages: Array[PipelineStage] = getOrDefault(stages)
/**
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
@@ -101,7 +101,7 @@ class Pipeline extends Estimator[PipelineModel] {
*/
override def fit(dataset: DataFrame, paramMap: ParamMap): PipelineModel = {
transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val theStages = map(stages)
// Search for the last estimator.
var indexOfLastEstimator = -1
@@ -138,7 +138,7 @@ class Pipeline extends Estimator[PipelineModel] {
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- val map = this.paramMap ++ paramMap
+ val map = extractParamMap(paramMap)
val theStages = map(stages)
require(theStages.toSet.size == theStages.size,
"Cannot have duplicate components in a pipeline.")
@@ -177,14 +177,14 @@ class PipelineModel private[ml] (
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+ val map = fittingParamMap ++ extractParamMap(paramMap)
transformSchema(dataset.schema, map, logging = true)
stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}
override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
// Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
- val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+ val map = fittingParamMap ++ extractParamMap(paramMap)
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
}
}