aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorVinceShieh <vincent.xie@intel.com>2017-03-07 11:24:20 -0800
committerJoseph K. Bradley <joseph@databricks.com>2017-03-07 11:24:20 -0800
commit4a9034b17374cf19c77cb74e36c86cd085d59602 (patch)
tree1bb589f4efd445295897173aff2146ecfec425f8 /mllib/src/main
parentc05baabf10dd4c808929b4ae7a6d118aba6dd665 (diff)
downloadspark-4a9034b17374cf19c77cb74e36c86cd085d59602.tar.gz
spark-4a9034b17374cf19c77cb74e36c86cd085d59602.tar.bz2
spark-4a9034b17374cf19c77cb74e36c86cd085d59602.zip
[SPARK-17498][ML] StringIndexer enhancement for handling unseen labels
## What changes were proposed in this pull request? This PR is an enhancement to ML StringIndexer. Before this PR, String Indexer only supports "skip"/"error" options to deal with unseen records. But those unseen records might still be useful and user would like to keep the unseen labels in certain use cases, This PR enables StringIndexer to support keeping unseen labels as indices [numLabels]. '''Before StringIndexer().setHandleInvalid("skip") StringIndexer().setHandleInvalid("error") '''After support the third option "keep" StringIndexer().setHandleInvalid("keep") ## How was this patch tested? Test added in StringIndexerSuite Signed-off-by: VinceShieh <vincent.xieintel.com> (Please fill in changes proposed in this fix) Author: VinceShieh <vincent.xie@intel.com> Closes #16883 from VinceShieh/spark-17498.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala65
1 files changed, 49 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index a503411b63..810b02febb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -17,6 +17,8 @@
package org.apache.spark.ml.feature
+import scala.language.existentials
+
import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
@@ -24,7 +26,7 @@ import org.apache.spark.annotation.Since
import org.apache.spark.ml.{Estimator, Model, Transformer}
import org.apache.spark.ml.attribute.{Attribute, NominalAttribute}
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
@@ -34,8 +36,27 @@ import org.apache.spark.util.collection.OpenHashMap
/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
-private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
- with HasHandleInvalid {
+private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
+
+ /**
+ * Param for how to handle unseen labels. Options are 'skip' (filter out rows with
+ * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional
+ * bucket, at index numLabels.
+ * Default: "error"
+ * @group param
+ */
+ @Since("1.6.0")
+ val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+ "unseen labels. Options are 'skip' (filter out rows with unseen labels), " +
+ "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " +
+ "at index numLabels).",
+ ParamValidators.inArray(StringIndexer.supportedHandleInvalids))
+
+ setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL)
+
+ /** @group getParam */
+ @Since("1.6.0")
+ def getHandleInvalid: String = $(handleInvalid)
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -73,7 +94,6 @@ class StringIndexer @Since("1.4.0") (
/** @group setParam */
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
- setDefault(handleInvalid, "error")
/** @group setParam */
@Since("1.4.0")
@@ -105,6 +125,11 @@ class StringIndexer @Since("1.4.0") (
@Since("1.6.0")
object StringIndexer extends DefaultParamsReadable[StringIndexer] {
+ private[feature] val SKIP_UNSEEN_LABEL: String = "skip"
+ private[feature] val ERROR_UNSEEN_LABEL: String = "error"
+ private[feature] val KEEP_UNSEEN_LABEL: String = "keep"
+ private[feature] val supportedHandleInvalids: Array[String] =
+ Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL)
@Since("1.6.0")
override def load(path: String): StringIndexer = super.load(path)
@@ -144,7 +169,6 @@ class StringIndexerModel (
/** @group setParam */
@Since("1.6.0")
def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
- setDefault(handleInvalid, "error")
/** @group setParam */
@Since("1.4.0")
@@ -163,25 +187,34 @@ class StringIndexerModel (
}
transformSchema(dataset.schema, logging = true)
- val indexer = udf { label: String =>
- if (labelToIndex.contains(label)) {
- labelToIndex(label)
- } else {
- throw new SparkException(s"Unseen label: $label.")
- }
+ val filteredLabels = getHandleInvalid match {
+ case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown"
+ case _ => labels
}
val metadata = NominalAttribute.defaultAttr
- .withName($(outputCol)).withValues(labels).toMetadata()
+ .withName($(outputCol)).withValues(filteredLabels).toMetadata()
// If we are skipping invalid records, filter them out.
- val filteredDataset = getHandleInvalid match {
- case "skip" =>
+ val (filteredDataset, keepInvalid) = getHandleInvalid match {
+ case StringIndexer.SKIP_UNSEEN_LABEL =>
val filterer = udf { label: String =>
labelToIndex.contains(label)
}
- dataset.where(filterer(dataset($(inputCol))))
- case _ => dataset
+ (dataset.where(filterer(dataset($(inputCol)))), false)
+ case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL)
}
+
+ val indexer = udf { label: String =>
+ if (labelToIndex.contains(label)) {
+ labelToIndex(label)
+ } else if (keepInvalid) {
+ labels.length
+ } else {
+ throw new SparkException(s"Unseen label: $label. To handle unseen labels, " +
+ s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.")
+ }
+ }
+
filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
}