author     Yanbo Liang <ybliang8@gmail.com>    2016-08-05 22:07:59 +0100
committer  Sean Owen <sowen@cloudera.com>      2016-08-05 22:07:59 +0100
commit     6cbde337a539e5bb170d0eb81f715a95ee9c9af3 (patch)
tree       bb169670d130594c9130a2c1b600b2975301e664 /mllib/src/main
parent     1f96c97f2374a95140a0c72b1f4eae50ac21d84a (diff)
[SPARK-16750][FOLLOW-UP][ML] Add transformSchema for StringIndexer/VectorAssembler and fix failed tests.
## What changes were proposed in this pull request?
This is a follow-up to #14378. While adding ```transformSchema``` for all estimators and transformers, I found tests failing for ```StringIndexer``` and ```VectorAssembler```, so I moved that part of the work into this separate PR to make it clearer to review. After ```transformSchema``` is added, the corresponding tests should throw ```IllegalArgumentException``` during schema validation. It is more efficient to throw the exception at the start of ```fit``` or ```transform``` than partway through processing.

## How was this patch tested?
Modified unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14455 from yanboliang/transformSchema.
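To make the intended fail-fast behavior concrete, here is a minimal, hypothetical sketch (not part of this patch; the object name, toy DataFrame, and column names are invented for illustration): once `fit` calls `transformSchema` first, a `StringIndexer` pointed at a column that is not in the schema fails with `IllegalArgumentException` during schema validation, before any job runs over the data.

```scala
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.sql.SparkSession

object StringIndexerFailFastExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("StringIndexerFailFastExample")
      .getOrCreate()
    import spark.implicits._

    // A toy DataFrame; "category" exists, "missing" does not.
    val df = Seq(("a", 1.0), ("b", 2.0), ("a", 3.0)).toDF("category", "value")

    val indexer = new StringIndexer()
      .setInputCol("missing")        // column not present in df.schema
      .setOutputCol("categoryIndex")

    try {
      // With transformSchema called at the start of fit(), the bad input
      // column is rejected during schema validation, before any data is scanned.
      indexer.fit(df)
    } catch {
      case e: IllegalArgumentException =>
        println(s"Schema validation failed as expected: ${e.getMessage}")
    }

    spark.stop()
  }
}
```

The same early-validation pattern applies to the `StringIndexerModel.transform`, `IndexToString.transform`, and `VectorAssembler.transform` changes in the diff below.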
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala   | 4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala | 1
2 files changed, 4 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index fe79e2ec80..80fe46796f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -85,6 +85,7 @@ class StringIndexer @Since("1.4.0") (

   @Since("2.0.0")
   override def fit(dataset: Dataset[_]): StringIndexerModel = {
+    transformSchema(dataset.schema, logging = true)
     val counts = dataset.select(col($(inputCol)).cast(StringType))
       .rdd
       .map(_.getString(0))
@@ -160,7 +161,7 @@ class StringIndexerModel (
         "Skip StringIndexerModel.")
       return dataset.toDF
     }
-    validateAndTransformSchema(dataset.schema)
+    transformSchema(dataset.schema, logging = true)

     val indexer = udf { label: String =>
       if (labelToIndex.contains(label)) {
@@ -305,6 +306,7 @@ class IndexToString private[ml] (@Since("1.5.0") override val uid: String)

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     val inputColSchema = dataset.schema($(inputCol))
     // If the labels array is empty use column metadata
     val values = if (!isDefined(labels) || $(labels).isEmpty) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
index 142a2ae44c..ca900536bc 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala
@@ -51,6 +51,7 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String)

   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
+    transformSchema(dataset.schema, logging = true)
     // Schema transformation.
     val schema = dataset.schema
     lazy val first = dataset.toDF.first()
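Similarly, a hedged usage sketch for the `VectorAssembler` change (illustrative only; the object name, data, and column names are made up): with `transformSchema` invoked at the top of `transform`, an input column of an unsupported type such as `StringType` is rejected with `IllegalArgumentException` before any rows are assembled.

```scala
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.sql.SparkSession

object VectorAssemblerFailFastExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("VectorAssemblerFailFastExample")
      .getOrCreate()
    import spark.implicits._

    // "label" is a StringType column, which VectorAssembler does not support.
    val df = Seq((1.0, 2.0, "x"), (3.0, 4.0, "y")).toDF("a", "b", "label")

    val assembler = new VectorAssembler()
      .setInputCols(Array("a", "b", "label"))
      .setOutputCol("features")

    try {
      // Schema validation at the start of transform() rejects the String
      // column up front instead of failing mid-assembly.
      assembler.transform(df)
    } catch {
      case e: IllegalArgumentException =>
        println(s"Schema validation failed as expected: ${e.getMessage}")
    }

    spark.stop()
  }
}
```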