| field     | value                                                | date                      |
|-----------|------------------------------------------------------|---------------------------|
| author    | Yanbo Liang <ybliang8@gmail.com>                     | 2016-01-04 13:30:17 -0800 |
| committer | Joseph K. Bradley <joseph@databricks.com>            | 2016-01-04 13:30:17 -0800 |
| commit    | ba5f81859d6ba37a228a1c43d26c47e64c0382cd (patch)     |                           |
| tree      | 375e9042ebf42c3a915186e1e21de6c650436b83 /mllib/src  |                           |
| parent    | 0171b71e9511cef512e96a759e407207037f3c49 (diff)      |                           |
[SPARK-11259][ML] Params.validateParams() should be called automatically
See JIRA: https://issues.apache.org/jira/browse/SPARK-11259
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9224 from yanboliang/spark-11259.
Diffstat (limited to 'mllib/src')
30 files changed, 63 insertions(+), 1 deletion(-)
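The change itself is mechanical: every `transformSchema` / `validateAndTransformSchema` implementation in `ml/` now begins by calling `validateParams()`, so parameter validation happens automatically whenever a schema is validated. For writers of custom stages, the same convention looks roughly like the sketch below — a hypothetical transformer against the Spark 1.6 `ml` API, not code from this commit; the class, params, and column names are made up.

```scala
// Hypothetical custom stage (not part of this patch) written against the Spark 1.6 ml API.
// It follows the convention the patch enforces across mllib: override validateParams() for
// cross-parameter checks and call it at the top of transformSchema(), so that an invalid
// parameter combination fails during schema validation (e.g. at Pipeline.fit()) instead of
// partway through a job.
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{DoubleParam, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, greatest, least, lit}
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

class ClampTransformer(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("clamp"))

  val lower = new DoubleParam(this, "lower", "lower clamp bound")
  val upper = new DoubleParam(this, "upper", "upper clamp bound")
  setDefault(lower -> 0.0, upper -> 1.0)

  def setLower(value: Double): this.type = set(lower, value)
  def setUpper(value: Double): this.type = set(upper, value)

  // Cross-parameter check; cheap, so it is safe to run on every schema validation.
  override def validateParams(): Unit = {
    super.validateParams()
    require($(lower) < $(upper),
      s"lower (${$(lower)}) must be strictly less than upper (${$(upper)})")
  }

  override def transformSchema(schema: StructType): StructType = {
    validateParams()  // the call this patch adds to every transformSchema in ml/
    StructType(schema.fields :+ StructField("clamped", DoubleType, nullable = false))
  }

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema)  // params are re-validated before any data is touched
    dataset.withColumn("clamped",
      least(greatest(col("value"), lit($(lower))), lit($(upper))))
  }

  override def copy(extra: ParamMap): ClampTransformer = defaultCopy(extra)
}
```

With this in place, something like `new ClampTransformer().setLower(5.0).setUpper(1.0)` used in a Pipeline fails with an `IllegalArgumentException` at `fit()` time, which mirrors what the new `pipeline validateParams` test at the end of the diff checks for `MinMaxScaler`.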
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 3acc60d6c6..32570a16e6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -165,6 +165,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val theStages = $(stages)
     require(theStages.toSet.size == theStages.length,
       "Cannot have duplicate components in a pipeline.")
@@ -296,6 +297,7 @@ class PipelineModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 6aacffd4f2..d1388b5e2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -46,6 +46,7 @@ private[ml] trait PredictorParams extends Params
       schema: StructType,
       fitting: Boolean,
       featuresDataType: DataType): StructType = {
+    validateParams()
     // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
     SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
     if (fitting) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 1f3325ad09..fdce273193 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -103,6 +103,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
   protected def validateInputType(inputType: DataType): Unit = {}
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     validateInputType(inputType)
     if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 6e5abb29ff..dc6d5d9280 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -80,6 +80,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index af0b3e1835..99383e77f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -263,6 +263,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 5b17d3483b..544cf05a30 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -72,6 +72,7 @@ final class Binarizer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
 
     val inputFields = schema.fields
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 33abc7c99d..0c75317d82 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -86,6 +86,7 @@ final class Bucketizer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     SchemaUtils.appendColumn(schema, prepOutputField(schema))
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index dfec03828f..7b565ef3ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -88,6 +88,7 @@ final class ChiSqSelector(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
@@ -135,6 +136,7 @@ final class ChiSqSelectorModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     val newField = prepOutputField(schema)
     val outputFields = schema.fields :+ newField
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 1268c87908..10dcda2382 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -70,6 +70,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 61a78d73c4..8af00581f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -69,6 +69,7 @@ class HashingTF(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[ArrayType],
       s"The input column must be ArrayType, but got $inputType.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index f7b0f29a27..9e7eee4f29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -52,6 +52,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 559a025265..ad0458d0d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -59,6 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index c01e29af47..342540418f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,6 +66,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputColName = $(inputCol)
     val outputColName = $(outputCol)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index f653798b46..7020397f3b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -77,6 +77,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -130,6 +131,7 @@ class PCAModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 39de8461dc..8fd0ce2f2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -74,6 +74,7 @@ final class QuantileDiscretizer(override val uid: String)
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
     val inputFields = schema.fields
     require(inputFields.forall(_.name != $(outputCol)),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 2b578c2a95..f9952434d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -146,6 +146,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
 
   // optimistic schema; does not contain any ML attributes
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     if (hasLabelCol(schema)) {
       StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
     } else {
@@ -178,6 +179,7 @@ class RFormulaModel private[feature](
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     checkCanTransform(schema)
     val withFeatures = pipelineModel.transformSchema(schema)
     if (hasLabelCol(withFeatures)) {
@@ -240,6 +242,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
   }
@@ -288,6 +291,7 @@ private class VectorAttributeRewriter(
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     StructType(
       schema.fields.filter(_.name != vectorCol) ++
       schema.fields.filter(_.name == vectorCol))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index e0ca45b9a6..af6494b234 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -74,6 +74,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
 
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val sc = SparkContext.getOrCreate()
     val sqlContext = SQLContext.getOrCreate(sc)
     val dummyRDD = sc.parallelize(Seq(Row.empty))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index d76a9c6275..6a0b6c240e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -94,6 +94,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
@@ -143,6 +144,7 @@ class StandardScalerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.isInstanceOf[VectorUDT],
       s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 5d6936dce2..b93c9ed382 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -145,6 +145,7 @@ class StopWordsRemover(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputType = schema($(inputCol)).dataType
     require(inputType.sameType(ArrayType(StringType)),
       s"Input type must be ArrayType(StringType) but got $inputType.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 5c40c35eea..912bd95a2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,6 +39,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
   /** Validates and transforms the input schema. */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputColName = $(inputCol)
     val inputDataType = schema(inputColName).dataType
     require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
@@ -272,6 +273,7 @@ class IndexToString private[ml] (override val uid: String)
   final def getLabels: Array[String] = $(labels)
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputColName = $(inputCol)
     val inputDataType = schema(inputColName).dataType
     require(inputDataType.isInstanceOf[NumericType],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index e9d1b57b91..0b215659b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -106,6 +106,7 @@ class VectorAssembler(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val inputColNames = $(inputCols)
     val outputColName = $(outputCol)
     val inputDataTypes = inputColNames.map(name => schema(name).dataType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index a637a6f288..2a5268406d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -126,6 +126,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     // We do not transfer feature metadata since we do not know what types of features we will
     // produce in transform().
     val dataType = new VectorUDT
@@ -354,6 +355,7 @@ class VectorIndexerModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     val dataType = new VectorUDT
     require(isDefined(inputCol),
       s"VectorIndexerModel requires input column parameter: $inputCol")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 4813d8a5b5..300d63bd3a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -139,6 +139,7 @@ final class VectorSlicer(override val uid: String)
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
 
     if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 59c34cd170..2b6b3c3a0f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -92,6 +92,7 @@ private[feature] trait Word2VecBase extends Params
    * Validate and transform the input schema.
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
     SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 14a28b8d5b..472c1854d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -162,6 +162,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
    * @return output schema
    */
   protected def validateAndTransformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     val ratingType = schema($(ratingCol)).dataType
@@ -213,6 +214,7 @@ class ALSModel private[ml] (
   }
 
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
     SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
     SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 3787ca45d5..e8a1ff2278 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -99,6 +99,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
   protected def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean): StructType = {
+    validateParams()
     SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index e8d361b1a2..1573bb4c1b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -105,6 +105,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
   protected[ml] def validateAndTransformSchema(
       schema: StructType,
       fitting: Boolean): StructType = {
+    validateParams()
     if (fitting) {
       SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
       if (hasWeightCol) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 477675cad1..3eac616aea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -131,6 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     $(estimator).transformSchema(schema)
   }
@@ -345,6 +346,7 @@ class CrossValidatorModel private[ml] (
 
   @Since("1.4.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     bestModel.transformSchema(schema)
   }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index f346ea655a..4f67e8c219 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -118,6 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     $(estimator).transformSchema(schema)
   }
@@ -172,6 +173,7 @@ class TrainValidationSplitModel private[ml] (
   @Since("1.5.0")
   override def transformSchema(schema: StructType): StructType = {
+    validateParams()
     bestModel.transformSchema(schema)
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 8c86767456..f3321fb5a1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -26,9 +26,10 @@ import org.scalatest.mock.MockitoSugar.mock
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.Pipeline.SharedReadWrite
-import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
 import org.apache.spark.ml.param.{IntParam, ParamMap}
 import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg.Vectors
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.sql.DataFrame
 import org.apache.spark.sql.types.StructType
@@ -174,6 +175,26 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
       }
     }
   }
+
+  test("pipeline validateParams") {
+    val df = sqlContext.createDataFrame(
+      Seq(
+        (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+        (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
+        (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
+        (4, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+    ).toDF("id", "features", "label")
+
+    intercept[IllegalArgumentException] {
+      val scaler = new MinMaxScaler()
+        .setInputCol("features")
+        .setOutputCol("features_scaled")
+        .setMin(10)
+        .setMax(0)
+      val pipeline = new Pipeline().setStages(Array(scaler))
+      pipeline.fit(df)
+    }
+  }
 }