author     Yanbo Liang <ybliang8@gmail.com>              2016-01-04 13:30:17 -0800
committer  Joseph K. Bradley <joseph@databricks.com>     2016-01-04 13:30:17 -0800
commit     ba5f81859d6ba37a228a1c43d26c47e64c0382cd (patch)
tree       375e9042ebf42c3a915186e1e21de6c650436b83 /mllib
parent     0171b71e9511cef512e96a759e407207037f3c49 (diff)
[SPARK-11259][ML] Params.validateParams() should be called automatically
See JIRA: https://issues.apache.org/jira/browse/SPARK-11259

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #9224 from yanboliang/spark-11259.
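The patch is mechanical: every transformSchema() (or validateAndTransformSchema()) implementation below now calls validateParams() as its first step, so an invalid combination of params fails as soon as the schema is checked, before any data is touched. A minimal sketch of the pattern for a custom transformer (a hypothetical EchoTransformer against the 1.6-era spark.ml API; not part of this patch):

import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.{StringType, StructField, StructType}

// Hypothetical transformer that copies inputCol to outputCol; it only
// illustrates where validateParams() now sits in the call sequence.
class EchoTransformer(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("echo"))

  final val inputCol = new Param[String](this, "inputCol", "input column name")
  final val outputCol = new Param[String](this, "outputCol", "output column name")

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)

  // Check interactions between params; individual values are checked on set().
  override def validateParams(): Unit = {
    require($(inputCol) != $(outputCol), "inputCol and outputCol must differ")
  }

  override def transformSchema(schema: StructType): StructType = {
    validateParams() // the line this patch adds throughout spark.ml
    StructType(schema.fields :+ StructField($(outputCol), StringType))
  }

  override def transform(dataset: DataFrame): DataFrame = {
    transformSchema(dataset.schema) // schema (and now param) check before work
    dataset.withColumn($(outputCol), dataset($(inputCol)))
  }

  override def copy(extra: ParamMap): EchoTransformer = defaultCopy(extra)
}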
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Predictor.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala  23
30 files changed, 63 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 3acc60d6c6..32570a16e6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -165,6 +165,7 @@ class Pipeline(override val uid: String) extends Estimator[PipelineModel] with M
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val theStages = $(stages)
require(theStages.toSet.size == theStages.length,
"Cannot have duplicate components in a pipeline.")
@@ -296,6 +297,7 @@ class PipelineModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 6aacffd4f2..d1388b5e2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -46,6 +46,7 @@ private[ml] trait PredictorParams extends Params
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
+ validateParams()
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 1f3325ad09..fdce273193 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -103,6 +103,7 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
protected def validateInputType(inputType: DataType): Unit = {}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputType = schema($(inputCol)).dataType
validateInputType(inputType)
if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 6e5abb29ff..dc6d5d9280 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -80,6 +80,7 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index af0b3e1835..99383e77f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -263,6 +263,7 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
index 5b17d3483b..544cf05a30 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Binarizer.scala
@@ -72,6 +72,7 @@ final class Binarizer(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 33abc7c99d..0c75317d82 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -86,6 +86,7 @@ final class Bucketizer(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index dfec03828f..7b565ef3ed 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -88,6 +88,7 @@ final class ChiSqSelector(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
@@ -135,6 +136,7 @@ final class ChiSqSelectorModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
val newField = prepOutputField(schema)
val outputFields = schema.fields :+ newField
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 1268c87908..10dcda2382 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -70,6 +70,7 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 61a78d73c4..8af00581f7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -69,6 +69,7 @@ class HashingTF(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index f7b0f29a27..9e7eee4f29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -52,6 +52,7 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 559a025265..ad0458d0d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -59,6 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
+ validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index c01e29af47..342540418f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,6 +66,7 @@ class OneHotEncoder(override val uid: String) extends Transformer
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputColName = $(inputCol)
val outputColName = $(outputCol)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index f653798b46..7020397f3b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -77,6 +77,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
@@ -130,6 +131,7 @@ class PCAModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 39de8461dc..8fd0ce2f2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -74,6 +74,7 @@ final class QuantileDiscretizer(override val uid: String)
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
require(inputFields.forall(_.name != $(outputCol)),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index 2b578c2a95..f9952434d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -146,6 +146,7 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
// optimistic schema; does not contain any ML attributes
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
if (hasLabelCol(schema)) {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
} else {
@@ -178,6 +179,7 @@ class RFormulaModel private[feature](
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
checkCanTransform(schema)
val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(withFeatures)) {
@@ -240,6 +242,7 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
}
@@ -288,6 +291,7 @@ private class VectorAttributeRewriter(
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
StructType(
schema.fields.filter(_.name != vectorCol) ++
schema.fields.filter(_.name == vectorCol))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index e0ca45b9a6..af6494b234 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -74,6 +74,7 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index d76a9c6275..6a0b6c240e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -94,6 +94,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
@@ -143,6 +144,7 @@ class StandardScalerModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 5d6936dce2..b93c9ed382 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -145,6 +145,7 @@ class StopWordsRemover(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be ArrayType(StringType) but got $inputType.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 5c40c35eea..912bd95a2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,6 +39,7 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
+ validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
@@ -272,6 +273,7 @@ class IndexToString private[ml] (override val uid: String)
final def getLabels: Array[String] = $(labels)
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType.isInstanceOf[NumericType],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index e9d1b57b91..0b215659b3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -106,6 +106,7 @@ class VectorAssembler(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val inputColNames = $(inputCols)
val outputColName = $(outputCol)
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index a637a6f288..2a5268406d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -126,6 +126,7 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
// We do not transfer feature metadata since we do not know what types of features we will
// produce in transform().
val dataType = new VectorUDT
@@ -354,6 +355,7 @@ class VectorIndexerModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
val dataType = new VectorUDT
require(isDefined(inputCol),
s"VectorIndexerModel requires input column parameter: $inputCol")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 4813d8a5b5..300d63bd3a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -139,6 +139,7 @@ final class VectorSlicer(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 59c34cd170..2b6b3c3a0f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -92,6 +92,7 @@ private[feature] trait Word2VecBase extends Params
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 14a28b8d5b..472c1854d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -162,6 +162,7 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
val ratingType = schema($(ratingCol)).dataType
@@ -213,6 +214,7 @@ class ALSModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index 3787ca45d5..e8a1ff2278 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -99,6 +99,7 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean): StructType = {
+ validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index e8d361b1a2..1573bb4c1b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -105,6 +105,7 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
protected[ml] def validateAndTransformSchema(
schema: StructType,
fitting: Boolean): StructType = {
+ validateParams()
if (fitting) {
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
if (hasWeightCol) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 477675cad1..3eac616aea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -131,6 +131,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
$(estimator).transformSchema(schema)
}
@@ -345,6 +346,7 @@ class CrossValidatorModel private[ml] (
@Since("1.4.0")
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
bestModel.transformSchema(schema)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index f346ea655a..4f67e8c219 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -118,6 +118,7 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
$(estimator).transformSchema(schema)
}
@@ -172,6 +173,7 @@ class TrainValidationSplitModel private[ml] (
@Since("1.5.0")
override def transformSchema(schema: StructType): StructType = {
+ validateParams()
bestModel.transformSchema(schema)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
index 8c86767456..f3321fb5a1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/PipelineSuite.scala
@@ -26,9 +26,10 @@ import org.scalatest.mock.MockitoSugar.mock
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.Pipeline.SharedReadWrite
-import org.apache.spark.ml.feature.HashingTF
+import org.apache.spark.ml.feature.{HashingTF, MinMaxScaler}
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -174,6 +175,26 @@ class PipelineSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
}
}
+
+ test("pipeline validateParams") {
+ val df = sqlContext.createDataFrame(
+ Seq(
+ (1, Vectors.dense(0.0, 1.0, 4.0), 1.0),
+ (2, Vectors.dense(1.0, 0.0, 4.0), 2.0),
+ (3, Vectors.dense(1.0, 0.0, 5.0), 3.0),
+ (4, Vectors.dense(0.0, 0.0, 5.0), 4.0))
+ ).toDF("id", "features", "label")
+
+ intercept[IllegalArgumentException] {
+ val scaler = new MinMaxScaler()
+ .setInputCol("features")
+ .setOutputCol("features_scaled")
+ .setMin(10)
+ .setMax(0)
+ val pipeline = new Pipeline().setStages(Array(scaler))
+ pipeline.fit(df)
+ }
+ }
}
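
With this change the same misconfiguration also fails outside a Pipeline, since MinMaxScaler's own transformSchema() now calls validateParams() (which rejects min >= max, as the test above exercises) before looking at any column. A quick sketch mirroring the test:

import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.sql.types.StructType

// min > max is an invalid param combination; with this patch validateParams()
// runs at the top of transformSchema(), so the error surfaces immediately.
val scaler = new MinMaxScaler()
  .setInputCol("features")
  .setOutputCol("features_scaled")
  .setMin(10)
  .setMax(0)

scaler.transformSchema(new StructType()) // throws IllegalArgumentException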