Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala            | 26
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala        | 15
3 files changed, 41 insertions(+), 4 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index ebfa972532..e4485eb038 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -33,7 +33,8 @@ import org.apache.spark.util.collection.OpenHashMap
/**
* Base trait for [[StringIndexer]] and [[StringIndexerModel]].
*/
-private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol {
+private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol
+ with HasHandleInvalid {
/** Validates and transforms the input schema. */
protected def validateAndTransformSchema(schema: StructType): StructType = {
@@ -66,12 +67,15 @@ class StringIndexer(override val uid: String) extends Estimator[StringIndexerMod
def this() = this(Identifiable.randomUID("strIdx"))
/** @group setParam */
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+ setDefault(handleInvalid, "error")
+
+ /** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
/** @group setParam */
def setOutputCol(value: String): this.type = set(outputCol, value)
- // TODO: handle unseen labels
override def fit(dataset: DataFrame): StringIndexerModel = {
val counts = dataset.select(col($(inputCol)).cast(StringType))
@@ -112,6 +116,10 @@ class StringIndexerModel private[ml] (
}
/** @group setParam */
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+ setDefault(handleInvalid, "error")
+
+ /** @group setParam */
def setInputCol(value: String): this.type = set(inputCol, value)
/** @group setParam */
@@ -128,14 +136,24 @@ class StringIndexerModel private[ml] (
if (labelToIndex.contains(label)) {
labelToIndex(label)
} else {
- // TODO: handle unseen labels
throw new SparkException(s"Unseen label: $label.")
}
}
+
val outputColName = $(outputCol)
val metadata = NominalAttribute.defaultAttr
.withName(outputColName).withValues(labels).toMetadata()
- dataset.select(col("*"),
+    // If we are skipping invalid records, filter them out.
+    val filteredDataset = getHandleInvalid match {
+      case "skip" =>
+        val filterer = udf { label: String =>
+          labelToIndex.contains(label)
+        }
+        dataset.where(filterer(dataset($(inputCol))))
+      case _ => dataset
+    }
+ filteredDataset.select(col("*"),
indexer(dataset($(inputCol)).cast(StringType)).as(outputColName, metadata))
}
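
A quick usage sketch of the new parameter (the DataFrames training and test are hypothetical, and test is assumed to contain a category value that never appeared during fit):

    import org.apache.spark.ml.feature.StringIndexer

    val indexer = new StringIndexer()
      .setInputCol("category")
      .setOutputCol("categoryIndex")
      .setHandleInvalid("skip")  // default is "error", which throws a SparkException on unseen labels

    val model = indexer.fit(training)
    // With "skip", rows whose label was not seen during fit are filtered out
    // instead of failing the transform.
    val indexed = model.transform(test)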
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index a97c8059b8..da4c076830 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -59,6 +59,10 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
isValid = "ParamValidators.gtEq(1)"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
+ ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
+ "will filter out rows with bad values), or error (which will throw an errror). More " +
+ "options may be added later.",
+ isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"),
ParamDesc[Boolean]("standardization", "whether to standardize the training features" +
" before fitting the model.", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("this.getClass.getName.hashCode.toLong")),
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index f332630c32..23e2b6cc43 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -248,6 +248,21 @@ private[ml] trait HasFitIntercept extends Params {
}
/**
+ * Trait for shared param handleInvalid.
+ */
+private[ml] trait HasHandleInvalid extends Params {
+
+ /**
+ * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.
+ * @group param
+ */
+ final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later.", ParamValidators.inArray(Array("skip", "error")))
+
+ /** @group getParam */
+ final def getHandleInvalid: String = $(handleInvalid)
+}
+
+/**
* Trait for shared param standardization (default: true).
*/
private[ml] trait HasStandardization extends Params {
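
Because HasHandleInvalid is generated as a private[ml] shared trait, other estimators and transformers inside org.apache.spark.ml can mix it in exactly as StringIndexerBase does above. A minimal sketch (the trait name MyIndexerBase is hypothetical and only compiles within the ml package):

    package org.apache.spark.ml.feature

    import org.apache.spark.ml.param.Params
    import org.apache.spark.ml.param.shared.{HasHandleInvalid, HasInputCol, HasOutputCol}

    private[feature] trait MyIndexerBase extends Params
      with HasInputCol with HasOutputCol with HasHandleInvalid {
      // Concrete classes expose the setter and choose a default, mirroring StringIndexer:
      //   def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
      //   setDefault(handleInvalid, "error")
    }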