about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala | 54
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala | 45
2 files changed, 77 insertions(+), 22 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 810b02febb..99321bcc7c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap
private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
/**
- * Param for how to handle unseen labels. Options are 'skip' (filter out rows with
- * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional
- * bucket, at index numLabels.
+ * Param for how to handle invalid data (unseen labels or NULL values).
+ * Options are 'skip' (filter out rows with invalid data),
+ * 'error' (throw an error), or 'keep' (put invalid data in a special additional
+ * bucket, at index numLabels).
* Default: "error"
* @group param
*/
@Since("1.6.0")
val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
- "unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
- "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " +
- "at index numLabels).",
+ "invalid data (unseen labels or NULL values). " +
+ "Options are 'skip' (filter out rows with invalid data), error (throw an error), " +
+ "or 'keep' (put invalid data in a special additional bucket, at index numLabels).",
ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
- setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
+ setDefault(handleInvalid, StringIndexer.ERROR_INVALID)
/** @group getParam */
@Since("1.6.0")
@@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") (
@Since("2.0.0")
override def fit(dataset: Dataset[_]): StringIndexerModel = {
transformSchema(dataset.schema, logging = true)
- val counts = dataset.select(col($(inputCol)).cast(StringType))
+ val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType))
.rdd
.map(_.getString(0))
.countByValue()
@@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") (
@Since("1.6.0")
object StringIndexer extends DefaultParamsReadable[StringIndexer] {
- private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
- private[feature] val ERROR_UNSEEN_LABEL: String = "error"
- private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
+ private[feature] val SKIP_INVALID: String = "skip"
+ private[feature] val ERROR_INVALID: String = "error"
+ private[feature] val KEEP_INVALID: String = "keep"
private[feature] val supportedHandleInvalids: Array[String] =
- Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
+ Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
@Since("1.6.0")
override def load(path: String): StringIndexer = super.load(path)
@@ -188,7 +189,7 @@ class StringIndexerModel (
transformSchema(dataset.schema, logging = true)
val filteredLabels = getHandleInvalid match {
- case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
+ case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
case _ => labels
}
@@ -196,22 +197,31 @@ class StringIndexerModel (
.withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
val (filteredDataset, keepInvalid) = getHandleInvalid match {
- case StringIndexer.SKIP_UNSEEN_LABEL =>
+ case StringIndexer.SKIP_INVALID =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
- (dataset.where(filterer(dataset($(inputCol)))), false)
- case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
+ (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
+ case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
}
val indexer = udf { label: String =>
- if (labelToIndex.contains(label)) {
- labelToIndex(label)
- } else if (keepInvalid) {
- labels.length
+ if (label == null) {
+ if (keepInvalid) {
+ labels.length
+ } else {
+ throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
+ "NULLS, try setting StringIndexer.handleInvalid.")
+ }
} else {
- throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
- s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
+ if (labelToIndex.contains(label)) {
+ labelToIndex(label)
+ } else if (keepInvalid) {
+ labels.length
+ } else {
+ throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
+ s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
index 188dffb3dd..8d9042b31e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala
@@ -122,6 +122,51 @@ class StringIndexerSuite
assert(output === expected)
}
+ test("StringIndexer with NULLs") {
+ val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null))
+ val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null))
+ val df = data.toDF("id", "label")
+ val df2 = data2.toDF("id", "label")
+
+ val indexer = new StringIndexer()
+ .setInputCol("label")
+ .setOutputCol("labelIndex")
+
+ withClue("StringIndexer should throw error when setHandleInvalid=error " +
+ "when given NULL values") {
+ intercept[SparkException] {
+ indexer.setHandleInvalid("error")
+ indexer.fit(df).transform(df2).collect()
+ }
+ }
+
+ indexer.setHandleInvalid("skip")
+ val transformedSkip = indexer.fit(df).transform(df2)
+ val attrSkip = Attribute
+ .fromStructField(transformedSkip.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrSkip.values.get === Array("b", "a"))
+ val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // a -> 1, b -> 0
+ val expectedSkip = Set((0, 1.0), (1, 0.0))
+ assert(outputSkip === expectedSkip)
+
+ indexer.setHandleInvalid("keep")
+ val transformedKeep = indexer.fit(df).transform(df2)
+ val attrKeep = Attribute
+ .fromStructField(transformedKeep.schema("labelIndex"))
+ .asInstanceOf[NominalAttribute]
+ assert(attrKeep.values.get === Array("b", "a", "__unknown"))
+ val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r =>
+ (r.getInt(0), r.getDouble(1))
+ }.collect().toSet
+ // a -> 1, b -> 0, null -> 2
+ val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0))
+ assert(outputKeep === expectedKeep)
+ }
+
test("StringIndexerModel should keep silent if the input column does not exist.") {
val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c"))
.setInputCol("label")