author    Yuhao Yang <hhbyyh@gmail.com>  2016-03-16 17:31:55 -0700
committer Joseph K. Bradley <joseph@databricks.com>  2016-03-16 17:31:55 -0700
commit    92b70576eabf8ff94ac476e2b3c66f8b3d28e79e (patch)
tree      b15286aade54722a14fa9325e965803316a531f7 /mllib
parent    d4d84936fb82bee91f4b04608de9f75c293ccc9e (diff)
[SPARK-13761][ML] Deprecate validateParams
## What changes were proposed in this pull request?

Deprecate validateParams() method here: https://github.com/apache/spark/blob/035d3acdf3c1be5b309a861d5c5beb803b946b5e/mllib/src/main/scala/org/apache/spark/ml/param/params.scala#L553

Move all functionality in overridden methods to transformSchema(). Check docs to make sure they indicate complex Param interaction checks should be done in transformSchema.

## How was this patch tested?

Unit tests.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #11620 from hhbyyh/depreValid.
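For context, here is a minimal sketch (not part of the patch) of the convention this change establishes for PipelineStage implementations: multi-Param interaction checks and input/output column checks both live in transformSchema(), which transform() triggers before doing any work. The `ClampScaler` transformer and its `lower`/`upper` Params are hypothetical names invented for illustration.

```scala
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.{DoubleParam, Param, ParamMap}
import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{Vector, VectorUDT, Vectors}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf
import org.apache.spark.sql.types.{StructField, StructType}

// Hypothetical transformer, for illustration only.
class ClampScaler(override val uid: String) extends Transformer {

  def this() = this(Identifiable.randomUID("clampScaler"))

  final val inputCol = new Param[String](this, "inputCol", "input column name")
  final val outputCol = new Param[String](this, "outputCol", "output column name")
  final val lower = new DoubleParam(this, "lower", "lower clamp bound")
  final val upper = new DoubleParam(this, "upper", "upper clamp bound")
  setDefault(lower -> 0.0, upper -> 1.0)

  def setInputCol(value: String): this.type = set(inputCol, value)
  def setOutputCol(value: String): this.type = set(outputCol, value)
  def setLower(value: Double): this.type = set(lower, value)
  def setUpper(value: Double): this.type = set(upper, value)

  override def transformSchema(schema: StructType): StructType = {
    // Multi-Param interaction check: the kind of logic this patch moves
    // out of validateParams() overrides.
    require($(lower) < $(upper),
      s"lower (${$(lower)}) must be strictly less than upper (${$(upper)})")
    // Input/output column checks also belong here, per the updated docs.
    require(schema($(inputCol)).dataType.isInstanceOf[VectorUDT],
      s"Input column ${$(inputCol)} must be a vector column")
    require(!schema.fieldNames.contains($(outputCol)),
      s"Output column ${$(outputCol)} already exists.")
    StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
  }

  override def transform(dataset: DataFrame): DataFrame = {
    // Validates all Params and the input schema before touching any data.
    transformSchema(dataset.schema, logging = true)
    val (lo, hi) = ($(lower), $(upper))
    val clamp = udf { v: Vector =>
      Vectors.dense(v.toArray.map(x => math.max(lo, math.min(x, hi))))
    }
    dataset.withColumn($(outputCol), clamp(dataset($(inputCol))))
  }

  override def copy(extra: ParamMap): ClampScaler = defaultCopy(extra)
}
```

With this layout, a bad lower/upper pair surfaces as an IllegalArgumentException during schema validation, before any data is processed, which is exactly what the updated test suites below assert by calling transformSchema() instead of validateParams().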
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala | 14
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Predictor.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala | 9
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala | 13
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala | 5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala | 8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala | 7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala | 1
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala | 12
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala | 10
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala | 8
33 files changed, 36 insertions(+), 89 deletions(-)
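Before the per-file diffs, a short sketch of the caller-side migration that the test-suite hunks below perform (MinMaxScalerSuite shown; assumes an active `sqlContext` such as the one provided by MLlibTestSparkContext):

```scala
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.mllib.linalg.Vectors

val df = sqlContext.createDataFrame(Seq(
  (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")

val scaler = new MinMaxScaler()
  .setMin(0.0).setMax(1.0)
  .setInputCol("features").setOutputCol("scaled")

// Before this patch: scaler.validateParams()
// After: the same min/max interaction check runs during schema validation.
scaler.transformSchema(df.schema)
```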
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index cbac7bbf49..f4c6214a56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -110,12 +110,6 @@ class Pipeline @Since("1.4.0") (
@Since("1.2.0")
def getStages: Array[PipelineStage] = $(stages).clone()
- @Since("1.4.0")
- override def validateParams(): Unit = {
- super.validateParams()
- $(stages).foreach(_.validateParams())
- }
-
/**
* Fits the pipeline to the input dataset with additional parameters. If a stage is an
* [[Estimator]], its [[Estimator#fit]] method will be called on the input dataset to fit a model.
@@ -175,7 +169,6 @@ class Pipeline @Since("1.4.0") (
@Since("1.2.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val theStages = $(stages)
require(theStages.toSet.size == theStages.length,
"Cannot have duplicate components in a pipeline.")
@@ -297,12 +290,6 @@ class PipelineModel private[ml] (
this(uid, stages.asScala.toArray)
}
- @Since("1.4.0")
- override def validateParams(): Unit = {
- super.validateParams()
- stages.foreach(_.validateParams())
- }
-
@Since("1.2.0")
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
@@ -311,7 +298,6 @@ class PipelineModel private[ml] (
@Since("1.2.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
index 4b27ee6c5a..ebe48700f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Predictor.scala
@@ -46,7 +46,6 @@ private[ml] trait PredictorParams extends Params
schema: StructType,
fitting: Boolean,
featuresDataType: DataType): StructType = {
- validateParams()
// TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
SchemaUtils.checkColumnType(schema, $(featuresCol), featuresDataType)
if (fitting) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index fdce273193..1f3325ad09 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -103,7 +103,6 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
protected def validateInputType(inputType: DataType): Unit = {}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
validateInputType(inputType)
if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 79332b0d02..ab00127899 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -81,7 +81,6 @@ private[clustering] trait KMeansParams extends Params with HasMaxIter with HasFe
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 6304b20d54..0de82b49ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -263,13 +263,6 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
- SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
- SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
- }
-
- @Since("1.6.0")
- override def validateParams(): Unit = {
if (isSet(docConcentration)) {
if (getDocConcentration.length != 1) {
require(getDocConcentration.length == getK, s"LDA docConcentration was of length" +
@@ -297,6 +290,8 @@ private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasM
s" must be >= 1. Found value: $getTopicConcentration")
}
}
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(topicDistributionCol), new VectorUDT)
}
private[clustering] def getOldOptimizer: OldLDAOptimizer = getOptimizer match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index 0c75317d82..33abc7c99d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -86,7 +86,6 @@ final class Bucketizer(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
SchemaUtils.appendColumn(schema, prepOutputField(schema))
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
index 4abc459f53..b9e9d56853 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala
@@ -88,7 +88,6 @@ final class ChiSqSelector(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
@@ -136,7 +135,6 @@ final class ChiSqSelectorModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
val newField = prepOutputField(schema)
val outputFields = schema.fields :+ newField
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index cf151458f0..f7d08b39a9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -70,7 +70,6 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
val typeCandidates = List(new ArrayType(StringType, true), new ArrayType(StringType, false))
SchemaUtils.checkColumnTypes(schema, $(inputCol), typeCandidates)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 8af00581f7..61a78d73c4 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -69,7 +69,6 @@ class HashingTF(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[ArrayType],
s"The input column must be ArrayType, but got $inputType.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
index cebbe5c162..f36cf503a0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/IDF.scala
@@ -52,7 +52,6 @@ private[feature] trait IDFBase extends Params with HasInputCol with HasOutputCol
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
index 7d2a1da990..d3fe6e528f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Interaction.scala
@@ -61,13 +61,15 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
// optimistic schema; does not contain any ML attributes
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
+ require(get(inputCols).isDefined, "Input cols must be defined first.")
+ require(get(outputCol).isDefined, "Output col must be defined first.")
+ require($(inputCols).length > 0, "Input cols must have non-zero length.")
+ require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
StructType(schema.fields :+ StructField($(outputCol), new VectorUDT, false))
}
@Since("1.6.0")
override def transform(dataset: DataFrame): DataFrame = {
- validateParams()
val inputFeatures = $(inputCols).map(c => dataset.schema(c))
val featureEncoders = getFeatureEncoders(inputFeatures)
val featureAttrs = getFeatureAttrs(inputFeatures)
@@ -217,13 +219,6 @@ class Interaction @Since("1.6.0") (override val uid: String) extends Transformer
@Since("1.6.0")
override def copy(extra: ParamMap): Interaction = defaultCopy(extra)
- @Since("1.6.0")
- override def validateParams(): Unit = {
- require(get(inputCols).isDefined, "Input cols must be defined first.")
- require(get(outputCol).isDefined, "Output col must be defined first.")
- require($(inputCols).length > 0, "Input cols must have non-zero length.")
- require($(inputCols).distinct.length == $(inputCols).length, "Input cols must be distinct.")
- }
}
@Since("1.6.0")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 09fad23642..7de5a4d5d3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -37,7 +37,6 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index 3b4209bbc4..b13684a1cb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -59,7 +59,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
+ require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
@@ -69,9 +69,6 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
StructType(outputFields)
}
- override def validateParams(): Unit = {
- require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
- }
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
index fa5013d3c9..4f67042629 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoder.scala
@@ -66,7 +66,6 @@ class OneHotEncoder(override val uid: String) extends Transformer
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputColName = $(inputCol)
val outputColName = $(outputCol)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index 80b124f747..305c3d187f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -77,7 +77,6 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
@@ -133,7 +132,6 @@ class PCAModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 18896fcc4d..e830d2a9ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -78,7 +78,6 @@ final class QuantileDiscretizer(override val uid: String)
def setSeed(value: Long): this.type = set(seed, value)
override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), DoubleType)
val inputFields = schema.fields
require(inputFields.forall(_.name != $(outputCol)),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
index c21da218b3..ab5f4a1a9a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala
@@ -167,7 +167,6 @@ class RFormula(override val uid: String) extends Estimator[RFormulaModel] with R
// optimistic schema; does not contain any ML attributes
override def transformSchema(schema: StructType): StructType = {
- validateParams()
if (hasLabelCol(schema)) {
StructType(schema.fields :+ StructField($(featuresCol), new VectorUDT, true))
} else {
@@ -200,7 +199,6 @@ class RFormulaModel private[feature](
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
checkCanTransform(schema)
val withFeatures = pipelineModel.transformSchema(schema)
if (hasLabelCol(withFeatures)) {
@@ -263,7 +261,6 @@ private class ColumnPruner(columnsToPrune: Set[String]) extends Transformer {
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
StructType(schema.fields.filter(col => !columnsToPrune.contains(col.name)))
}
@@ -312,7 +309,6 @@ private class VectorAttributeRewriter(
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
StructType(
schema.fields.filter(_.name != vectorCol) ++
schema.fields.filter(_.name == vectorCol))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
index af6494b234..e0ca45b9a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/SQLTransformer.scala
@@ -74,7 +74,6 @@ class SQLTransformer @Since("1.6.0") (override val uid: String) extends Transfor
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val sc = SparkContext.getOrCreate()
val sqlContext = SQLContext.getOrCreate(sc)
val dummyRDD = sc.parallelize(Seq(Row.empty))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 9952d3bc9f..26ee8e1bf1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -94,7 +94,6 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
@@ -144,7 +143,6 @@ class StandardScalerModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.isInstanceOf[VectorUDT],
s"Input column ${$(inputCol)} must be a vector column")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
index 0d4c968633..0a0e0b0960 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StopWordsRemover.scala
@@ -145,7 +145,6 @@ class StopWordsRemover(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputType = schema($(inputCol)).dataType
require(inputType.sameType(ArrayType(StringType)),
s"Input type must be ArrayType(StringType) but got $inputType.")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 7dd794b9d7..c579a0d68e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,7 +39,6 @@ private[feature] trait StringIndexerBase extends Params with HasInputCol with Ha
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType == StringType || inputDataType.isInstanceOf[NumericType],
@@ -275,7 +274,6 @@ class IndexToString private[ml] (override val uid: String)
final def getLabels: Array[String] = $(labels)
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputColName = $(inputCol)
val inputDataType = schema(inputColName).dataType
require(inputDataType.isInstanceOf[NumericType],
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 7ff5ad143f..957e8e7a59 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -106,7 +106,6 @@ class VectorAssembler(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val inputColNames = $(inputCols)
val outputColName = $(outputCol)
val inputDataTypes = inputColNames.map(name => schema(name).dataType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 5c11760fab..bf4aef2a74 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -126,7 +126,6 @@ class VectorIndexer(override val uid: String) extends Estimator[VectorIndexerMod
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
// We do not transfer feature metadata since we do not know what types of features we will
// produce in transform().
val dataType = new VectorUDT
@@ -355,7 +354,6 @@ class VectorIndexerModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
val dataType = new VectorUDT
require(isDefined(inputCol),
s"VectorIndexerModel requires input column parameter: $inputCol")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 300d63bd3a..b60e82de00 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -89,11 +89,6 @@ final class VectorSlicer(override val uid: String)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- override def validateParams(): Unit = {
- require($(indices).length > 0 || $(names).length > 0,
- s"VectorSlicer requires that at least one feature be selected.")
- }
-
override def transform(dataset: DataFrame): DataFrame = {
// Validity checks
transformSchema(dataset.schema)
@@ -139,7 +134,8 @@ final class VectorSlicer(override val uid: String)
}
override def transformSchema(schema: StructType): StructType = {
- validateParams()
+ require($(indices).length > 0 || $(names).length > 0,
+ s"VectorSlicer requires that at least one feature be selected.")
SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
if (schema.fieldNames.contains($(outputCol))) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 3d3c7bdc2f..95bae1c8a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -92,7 +92,6 @@ private[feature] trait Word2VecBase extends Params
* Validate and transform the input schema.
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 42411d2d8a..d7837b6730 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -58,9 +58,8 @@ class Param[T](val parent: String, val name: String, val doc: String, val isVali
/**
* Assert that the given value is valid for this parameter.
*
- * Note: Parameter checks involving interactions between multiple parameters should be
- * implemented in [[Params.validateParams()]]. Checks for input/output columns should be
- * implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
+ * Note: Parameter checks involving interactions between multiple parameters and input/output
+ * columns should be implemented in [[org.apache.spark.ml.PipelineStage.transformSchema()]].
*
* DEVELOPERS: This method is only called by [[ParamPair]], which means that all parameters
* should be specified via [[ParamPair]].
@@ -555,7 +554,9 @@ trait Params extends Identifiable with Serializable {
* Parameter value checks which do not depend on other parameters are handled by
* [[Param.validate()]]. This method does not handle input/output column parameters;
* those are checked during schema validation.
+ * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema
*/
+ @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0")
def validateParams(): Unit = {
// Do nothing by default. Override to handle Param interactions.
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index dacdac9a1d..f3bc9f095a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -162,7 +162,6 @@ private[recommendation] trait ALSParams extends ALSModelParams with HasMaxIter w
* @return output schema
*/
protected def validateAndTransformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
val ratingType = schema($(ratingCol)).dataType
@@ -220,7 +219,6 @@ class ALSModel private[ml] (
@Since("1.3.0")
override def transformSchema(schema: StructType): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(userCol), IntegerType)
SchemaUtils.checkColumnType(schema, $(itemCol), IntegerType)
SchemaUtils.appendColumn(schema, $(predictionCol), FloatType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
index e4339d67b9..0901642d39 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/AFTSurvivalRegression.scala
@@ -99,7 +99,6 @@ private[regression] trait AFTSurvivalRegressionParams extends Params
protected def validateAndTransformSchema(
schema: StructType,
fitting: Boolean): StructType = {
- validateParams()
SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
if (fitting) {
SchemaUtils.checkColumnType(schema, $(censorCol), DoubleType)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index b4e47c8073..46ba5589ff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -32,6 +32,7 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* Params for Generalized Linear Regression.
@@ -77,7 +78,10 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
import GeneralizedLinearRegression._
@Since("2.0.0")
- override def validateParams(): Unit = {
+ override def validateAndTransformSchema(
+ schema: StructType,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
if ($(solver) == "irls") {
setDefault(maxIter -> 25)
}
@@ -86,6 +90,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam
Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " +
s"with ${$(family)} family does not support ${$(link)} link function.")
}
+ super.validateAndTransformSchema(schema, fitting, featuresDataType)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
index 36b006c10e..20a0998201 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/IsotonicRegression.scala
@@ -105,7 +105,6 @@ private[regression] trait IsotonicRegressionBase extends Params with HasFeatures
protected[ml] def validateAndTransformSchema(
schema: StructType,
fitting: Boolean): StructType = {
- validateParams()
if (fitting) {
SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
if (hasWeightCol) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
index a3a8f65eac..dd3f4c6e53 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -138,16 +138,18 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead
new LDA().setTopicConcentration(-1.1)
}
- // validateParams()
- lda.validateParams()
+ val dummyDF = sqlContext.createDataFrame(Seq(
+ (1, Vectors.dense(1.0, 2.0)))).toDF("id", "features")
+ // validate parameters
+ lda.transformSchema(dummyDF.schema)
lda.setDocConcentration(1.1)
- lda.validateParams()
+ lda.transformSchema(dummyDF.schema)
lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray)
- lda.validateParams()
+ lda.transformSchema(dummyDF.schema)
lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray)
withClue("LDA docConcentration validity check failed for bad array length") {
intercept[IllegalArgumentException] {
- lda.validateParams()
+ lda.transformSchema(dummyDF.schema)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
index 035bfc07b6..87206c777e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinMaxScalerSuite.scala
@@ -57,13 +57,15 @@ class MinMaxScalerSuite extends SparkFunSuite with MLlibTestSparkContext with De
test("MinMaxScaler arguments max must be larger than min") {
withClue("arguments max must be larger than min") {
+ val dummyDF = sqlContext.createDataFrame(Seq(
+ (1, Vectors.dense(1.0, 2.0)))).toDF("id", "feature")
intercept[IllegalArgumentException] {
- val scaler = new MinMaxScaler().setMin(10).setMax(0)
- scaler.validateParams()
+ val scaler = new MinMaxScaler().setMin(10).setMax(0).setInputCol("feature")
+ scaler.transformSchema(dummyDF.schema)
}
intercept[IllegalArgumentException] {
- val scaler = new MinMaxScaler().setMin(0).setMax(0)
- scaler.validateParams()
+ val scaler = new MinMaxScaler().setMin(0).setMax(0).setInputCol("feature")
+ scaler.transformSchema(dummyDF.schema)
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
index 94191e5df3..6bb4678dc5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorSlicerSuite.scala
@@ -21,21 +21,21 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.attribute.{Attribute, AttributeGroup, NumericAttribute}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector, Vectors, VectorUDT}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{StructField, StructType}
class VectorSlicerSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
test("params") {
- val slicer = new VectorSlicer
+ val slicer = new VectorSlicer().setInputCol("feature")
ParamsSuite.checkParams(slicer)
assert(slicer.getIndices.length === 0)
assert(slicer.getNames.length === 0)
withClue("VectorSlicer should not have any features selected by default") {
intercept[IllegalArgumentException] {
- slicer.validateParams()
+ slicer.transformSchema(StructType(Seq(StructField("feature", new VectorUDT, true))))
}
}
}