diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-06-03 15:16:24 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-06-03 15:16:36 -0700 |
commit | b2a22a651f9d86ba85c78058c42402e7fdb3c4f1 (patch) | |
tree | 0da8a2c2b554b2a9d5173488cbb0b904b9c27357 | |
parent | ca21fff7dad14da9549dfdfcb35de627dad99ff8 (diff) | |
download | spark-b2a22a651f9d86ba85c78058c42402e7fdb3c4f1.tar.gz spark-b2a22a651f9d86ba85c78058c42402e7fdb3c4f1.tar.bz2 spark-b2a22a651f9d86ba85c78058c42402e7fdb3c4f1.zip |
[SPARK-8051] [MLLIB] make StringIndexerModel silent if input column does not exist
This is just a workaround to a bigger problem. Some pipeline stages may not be effective during prediction, and they should not complain about missing required columns, e.g. `StringIndexerModel`. jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #6595 from mengxr/SPARK-8051 and squashes the following commits:
b6a36b9 [Xiangrui Meng] add doc
f143fd4 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-8051
8ee7c7e [Xiangrui Meng] use SparkFunSuite
e112394 [Xiangrui Meng] make StringIndexerModel silent if input column does not exist
(cherry picked from commit 26c9d7a0f975009e22ec91e5c0b5cfcada79b35e)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 16 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 8 |
2 files changed, 23 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index a2dc8a8b96..f4e2507575 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -88,6 +88,9 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod /** * :: Experimental :: * Model fitted by [[StringIndexer]]. + * NOTE: During transformation, if the input column does not exist, + * [[StringIndexerModel.transform]] would return the input dataset unmodified. + * This is a temporary fix for the case when target labels do not exist during prediction. */ @Experimental class StringIndexerModel private[ml] ( @@ -112,6 +115,12 @@ class StringIndexerModel private[ml] ( def setOutputCol(value: String): this.type = set(outputCol, value) override def transform(dataset: DataFrame): DataFrame = { + if (!dataset.schema.fieldNames.contains($(inputCol))) { + logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + + "Skip StringIndexerModel.") + return dataset + } + val indexer = udf { label: String => if (labelToIndex.contains(label)) { labelToIndex(label) @@ -128,6 +137,11 @@ class StringIndexerModel private[ml] ( } override def transformSchema(schema: StructType): StructType = { - validateAndTransformSchema(schema) + if (schema.fieldNames.contains($(inputCol))) { + validateAndTransformSchema(schema) + } else { + // If the input column does not exist during transformation, we skip StringIndexerModel. + schema + } } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 89c2fe4557..5184863058 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -61,4 +61,12 @@ class StringIndexerSuite extends FunSuite with MLlibTestSparkContext { val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)) assert(output === expected) } + + test("StringIndexerModel should keep silent if the input column does not exist.") { + val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) + .setInputCol("label") + .setOutputCol("labelIndex") + val df = sqlContext.range(0L, 10L) + assert(indexerModel.transform(df).eq(df)) + } } |