aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYuhao Yang <yuhao.yang@intel.com>2016-06-02 16:37:01 -0700
committerNick Pentreath <nickp@za.ibm.com>2016-06-02 16:37:01 -0700
commit5855e0057defeab8006ca4f7b0196003bbc9e899 (patch)
treea5a52a9e18335198c768db29cbd70b83dc364593 /mllib
parentccd298eb6794cbcb220ac9889db60d745231e0fe (diff)
downloadspark-5855e0057defeab8006ca4f7b0196003bbc9e899.tar.gz
spark-5855e0057defeab8006ca4f7b0196003bbc9e899.tar.bz2
spark-5855e0057defeab8006ca4f7b0196003bbc9e899.zip
[SPARK-15668][ML] ml.feature: update check schema to avoid confusion when users use MLlib.Vector as input type
## What changes were proposed in this pull request? ml.feature: update check schema to avoid confusion when users use MLlib.Vector as input type ## How was this patch tested? existing ut Author: Yuhao Yang <yuhao.yang@intel.com> Closes #13411 from hhbyyh/schemaCheck.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala28
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala25
4 files changed, 25 insertions, 36 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
index 1b5159902e..7298a18ff8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MaxAbsScaler.scala
@@ -39,9 +39,7 @@ private[feature] trait MaxAbsScalerParams extends Params with HasInputCol with H
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
- val inputType = schema($(inputCol)).dataType
- require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${$(inputCol)} must be a vector column")
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
require(!schema.fieldNames.contains($(outputCol)),
s"Output column ${$(outputCol)} already exists.")
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
index d15f1b8563..a27bed5333 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinMaxScaler.scala
@@ -63,9 +63,7 @@ private[feature] trait MinMaxScalerParams extends Params with HasInputCol with H
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
require($(min) < $(max), s"The specified min(${$(min)}) is larger or equal to max(${$(max)})")
- val inputType = schema($(inputCol)).dataType
- require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${$(inputCol)} must be a vector column")
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
require(!schema.fieldNames.contains($(outputCol)),
s"Output column ${$(outputCol)} already exists.")
val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
index dbbaa5aa46..2f667af9d1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/PCA.scala
@@ -49,8 +49,16 @@ private[feature] trait PCAParams extends Params with HasInputCol with HasOutputC
/** @group getParam */
def getK: Int = $(k)
-}
+ /** Validates and transforms the input schema. */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+ require(!schema.fieldNames.contains($(outputCol)),
+ s"Output column ${$(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+ StructType(outputFields)
+ }
+}
/**
* :: Experimental ::
* PCA trains a model to project vectors to a lower dimensional space of the top [[PCA!.k]]
@@ -86,13 +94,7 @@ class PCA (override val uid: String) extends Estimator[PCAModel] with PCAParams
}
override def transformSchema(schema: StructType): StructType = {
- val inputType = schema($(inputCol)).dataType
- require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${$(inputCol)} must be a vector column")
- require(!schema.fieldNames.contains($(outputCol)),
- s"Output column ${$(outputCol)} already exists.")
- val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
- StructType(outputFields)
+ validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): PCA = defaultCopy(extra)
@@ -148,13 +150,7 @@ class PCAModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
- val inputType = schema($(inputCol)).dataType
- require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${$(inputCol)} must be a vector column")
- require(!schema.fieldNames.contains($(outputCol)),
- s"Output column ${$(outputCol)} already exists.")
- val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
- StructType(outputFields)
+ validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): PCAModel = {
@@ -201,7 +197,7 @@ object PCAModel extends MLReadable[PCAModel] {
val versionRegex = "([0-9]+)\\.([0-9]+).*".r
val hasExplainedVariance = metadata.sparkVersion match {
case versionRegex(major, minor) =>
- (major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6))
+ major.toInt >= 2 || (major.toInt == 1 && minor.toInt > 6)
case _ => false
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index 9d084b520c..7cec369c23 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -62,6 +62,15 @@ private[feature] trait StandardScalerParams extends Params with HasInputCol with
/** @group getParam */
def getWithStd: Boolean = $(withStd)
+ /** Validates and transforms the input schema. */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+ require(!schema.fieldNames.contains($(outputCol)),
+ s"Output column ${$(outputCol)} already exists.")
+ val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
+ StructType(outputFields)
+ }
+
setDefault(withMean -> false, withStd -> true)
}
@@ -105,13 +114,7 @@ class StandardScaler(override val uid: String) extends Estimator[StandardScalerM
}
override def transformSchema(schema: StructType): StructType = {
- val inputType = schema($(inputCol)).dataType
- require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${$(inputCol)} must be a vector column")
- require(!schema.fieldNames.contains($(outputCol)),
- s"Output column ${$(outputCol)} already exists.")
- val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
- StructType(outputFields)
+ validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): StandardScaler = defaultCopy(extra)
@@ -159,13 +162,7 @@ class StandardScalerModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
- val inputType = schema($(inputCol)).dataType
- require(inputType.isInstanceOf[VectorUDT],
- s"Input column ${$(inputCol)} must be a vector column")
- require(!schema.fieldNames.contains($(outputCol)),
- s"Output column ${$(outputCol)} already exists.")
- val outputFields = schema.fields :+ StructField($(outputCol), new VectorUDT, false)
- StructType(outputFields)
+ validateAndTransformSchema(schema)
}
override def copy(extra: ParamMap): StandardScalerModel = {