aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala16
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala8
2 files changed, 23 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index a2dc8a8b96..f4e2507575 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -88,6 +88,9 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
/**
* :: Experimental ::
* Model fitted by [[StringIndexer]].
+ * NOTE: During transformation, if the input column does not exist,
+ * [[StringIndexerModel.transform]] would return the input dataset unmodified.
+ * This is a temporary fix for the case when target labels do not exist during prediction.
*/
@Experimental
class StringIndexerModel private[ml] (
@@ -112,6 +115,12 @@ class StringIndexerModel private[ml] (
def setOutputCol(value: String): this.type = set(outputCol, value)
override def transform(dataset: DataFrame): DataFrame = {
+ if (!dataset.schema.fieldNames.contains($(inputCol))) {
+ logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
+ "Skip StringIndexerModel.")
+ return dataset
+ }
+
val indexer = udf { label: String =>
if (labelToIndex.contains(label)) {
labelToIndex(label)
@@ -128,6 +137,11 @@ class StringIndexerModel private[ml] (
}
override def transformSchema(schema: StructType): StructType = {
- validateAndTransformSchema(schema)
+ if (schema.fieldNames.contains($(inputCol))) {
+ validateAndTransformSchema(schema)
+ } else {
+ // If the input column does not exist during transformation, we skip StringIndexerModel.
+ schema
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index cbf1e8ddcb..5f557e16e5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -60,4 +60,12 @@ class StringIndexerSuite extends SparkFunSuite with MLlibTestSparkContext {
val expected = Set((0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0))
assert(output === expected)
}
+
+ test("StringIndexerModel should keep silent if the input column does not exist.") {
+ val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+ val df = sqlContext.range(0L, 10L)
+ assert(indexerModel.transform(df).eq(df))
+ }
}