Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala           71
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala  47
2 files changed, 104 insertions, 14 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index ec0ea05f9e..1143f0f565 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.sql._
+import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
@@ -46,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
* also includes y. Splits should be of length >= 3 and strictly increasing.
* Values at -inf, inf must be explicitly provided to cover all Double values;
* otherwise, values outside the splits specified will be treated as errors.
+ *
+ * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+ *
* @group param
*/
@Since("1.4.0")
@@ -73,15 +77,47 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.4.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /**
+ * Param for how to handle invalid entries. Options are skip (filter out rows with
+ * invalid values), error (throw an error), or keep (keep invalid values in a special additional
+ * bucket).
+ * Default: "error"
+ * @group param
+ */
+ @Since("2.1.0")
+ val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+ "invalid entries. Options are skip (filter out rows with invalid values), " +
+ "error (throw an error), or keep (keep invalid values in a special additional bucket).",
+ ParamValidators.inArray(Bucketizer.supportedHandleInvalid))
+
+ /** @group getParam */
+ @Since("2.1.0")
+ def getHandleInvalid: String = $(handleInvalid)
+
+ /** @group setParam */
+ @Since("2.1.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+ setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema)
- val bucketizer = udf { feature: Double =>
- Bucketizer.binarySearchForBuckets($(splits), feature)
+ val (filteredDataset, keepInvalid) = {
+ if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
+ // "skip" NaN option is set, will filter out NaN values in the dataset
+ (dataset.na.drop().toDF(), false)
+ } else {
+ (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
+ }
+ }
+
+ val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
+ Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
}
- val newCol = bucketizer(dataset($(inputCol)))
- val newField = prepOutputField(dataset.schema)
- dataset.withColumn($(outputCol), newCol, newField.metadata)
+
+ val newCol = bucketizer(filteredDataset($(inputCol)))
+ val newField = prepOutputField(filteredDataset.schema)
+ filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
}
private def prepOutputField(schema: StructType): StructField = {
@@ -106,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
@Since("1.6.0")
object Bucketizer extends DefaultParamsReadable[Bucketizer] {
+ private[feature] val SKIP_INVALID: String = "skip"
+ private[feature] val ERROR_INVALID: String = "error"
+ private[feature] val KEEP_INVALID: String = "keep"
+ private[feature] val supportedHandleInvalid: Array[String] =
+ Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
/**
* We require splits to be of length >= 3 and to be in strictly increasing order.
* No NaN split should be accepted.
@@ -126,11 +168,26 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
/**
* Binary searching in several buckets to place each data point.
+ * @param splits array of split points
+ * @param feature data point
+ * @param keepInvalid NaN flag:
+ *     if "true", NaN values are placed in an extra bucket;
+ *     if "false", a SparkException is thrown for NaN values
+ * @return the bucket index for the data point
* @throws SparkException if a feature is < splits.head or > splits.last
*/
- private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+ private[feature] def binarySearchForBuckets(
+ splits: Array[Double],
+ feature: Double,
+ keepInvalid: Boolean): Double = {
if (feature.isNaN) {
- splits.length - 1
+ if (keepInvalid) {
+ splits.length - 1
+ } else {
+ throw new SparkException("Bucketizer encountered a NaN value. To handle or skip NaNs," +
+ " try setting Bucketizer.handleInvalid.")
+ }
} else if (feature == splits.last) {
splits.length - 2
} else {
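For context, here is a minimal usage sketch of the new param. It is not part of the patch: the input data and the `spark` session are hypothetical, and only "skip", "error", and "keep" (the values defined in the Bucketizer companion object above) are valid settings.

import org.apache.spark.ml.feature.Bucketizer

// Three real buckets: (-inf, 0.0), [0.0, 10.0), [10.0, +inf].
val splits = Array(Double.NegativeInfinity, 0.0, 10.0, Double.PositiveInfinity)

val df = spark.createDataFrame(
  Seq(-0.5, 5.0, 99.0, Double.NaN).map(Tuple1.apply)
).toDF("features")

val bucketizer = new Bucketizer()
  .setInputCol("features")
  .setOutputCol("bucketed")
  .setSplits(splits)
  .setHandleInvalid("keep") // NaN rows go to the extra bucket splits.length - 1 = 3.0

bucketizer.transform(df).show()
// With "skip" the NaN row is filtered out before bucketing; with the default
// "error" the transform throws a SparkException on the first NaN value.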
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 05e034d90f..b9e01dde70 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params
/**
* Number of buckets (quantiles, or categories) into which data points are grouped. Must
* be >= 2.
+ *
+ * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+ *
* default: 2
* @group param
*/
@@ -61,17 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends Params
/** @group getParam */
def getRelativeError: Double = getOrDefault(relativeError)
+
+ /**
+ * Param for how to handle invalid entries. Options are skip (filter out rows with
+ * invalid values), error (throw an error), or keep (keep invalid values in a special additional
+ * bucket).
+ * Default: "error"
+ * @group param
+ */
+ @Since("2.1.0")
+ val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+ "invalid entries. Options are skip (filter out rows with invalid values), " +
+ "error (throw an error), or keep (keep invalid values in a special additional bucket).",
+ ParamValidators.inArray(Bucketizer.supportedHandleInvalid))
+ setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+
+ /** @group getParam */
+ @Since("2.1.0")
+ def getHandleInvalid: String = $(handleInvalid)
+
}
/**
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The number of bins can be set using the `numBuckets` parameter. It is
- * possible that the number of buckets used will be less than this value, for example, if there
- * are too few distinct values of the input to create enough distinct quantiles. Note also that
- * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets
- * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special
- * bucket(4).
- * The bin ranges are chosen using an approximate algorithm (see the documentation for
+ * possible that the number of buckets used will be smaller than this value, for example, if there
+ * are too few distinct values of the input to create enough distinct quantiles.
+ *
+ * NaN handling:
+ * QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user
+ * can also choose to either keep or remove NaN values within the dataset by setting
+ * `handleInvalid`. If the user chooses to keep NaN values, they will be handled specially and
+ * placed into their own bucket; for example, if 4 buckets are used, then non-NaN data will be
+ * put into buckets [0-3], and NaNs will be counted in a special bucket [4].
+ *
+ * Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for
* [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]]
* for a detailed description). The precision of the approximation can be controlled with the
* `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`,
@@ -100,6 +127,10 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
@Since("1.6.0")
def setOutputCol(value: String): this.type = set(outputCol, value)
+ /** @group setParam */
+ @Since("2.1.0")
+ def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
@Since("1.6.0")
override def transformSchema(schema: StructType): StructType = {
SchemaUtils.checkNumericType(schema, $(inputCol))
@@ -124,7 +155,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
s" buckets as a result.")
}
- val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted)
+ val bucketizer = new Bucketizer(uid)
+ .setSplits(distinctSplits.sorted)
+ .setHandleInvalid($(handleInvalid))
copyValues(bucketizer.setParent(this))
}
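End to end, the param now flows from QuantileDiscretizer into the Bucketizer it fits. A minimal sketch of that flow follows; it is not part of the patch, and the data and the `spark` session are hypothetical.

import org.apache.spark.ml.feature.QuantileDiscretizer

val training = spark.createDataFrame(
  Seq(1.0, 2.0, 3.0, 4.0, 5.0).map(Tuple1.apply)
).toDF("input")
val testing = spark.createDataFrame(
  Seq(2.5, Double.NaN).map(Tuple1.apply)
).toDF("input")

val discretizer = new QuantileDiscretizer()
  .setInputCol("input")
  .setOutputCol("result")
  .setNumBuckets(4)
  .setHandleInvalid("keep") // forwarded to the fitted Bucketizer by fit()

// fit() computes approximate quantile splits and returns a Bucketizer that
// carries the same handleInvalid setting, so the NaN row lands in bucket 4.0.
val model = discretizer.fit(training)
model.transform(testing).show()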