Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala | 71
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala | 47
2 files changed, 104 insertions, 14 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index ec0ea05f9e..1143f0f565 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -27,6 +27,7 @@ import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.sql._
+import org.apache.spark.sql.expressions.UserDefinedFunction
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 
@@ -46,6 +47,9 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
    * also includes y. Splits should be of length >= 3 and strictly increasing.
    * Values at -inf, inf must be explicitly provided to cover all Double values;
    * otherwise, values outside the splits specified will be treated as errors.
+   *
+   * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+   *
    * @group param
    */
   @Since("1.4.0")
@@ -73,15 +77,47 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /**
+   * Param for how to handle invalid entries. Options are skip (filter out rows with
+   * invalid values), error (throw an error), or keep (keep invalid values in a special additional
+   * bucket).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.1.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+    "invalid entries. Options are skip (filter out rows with invalid values), " +
+    "error (throw an error), or keep (keep invalid values in a special additional bucket).",
+    ParamValidators.inArray(Bucketizer.supportedHandleInvalid))
+
+  /** @group getParam */
+  @Since("2.1.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+  setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema)
-    val bucketizer = udf { feature: Double =>
-      Bucketizer.binarySearchForBuckets($(splits), feature)
+    val (filteredDataset, keepInvalid) = {
+      if (getHandleInvalid == Bucketizer.SKIP_INVALID) {
+        // "skip" NaN option is set, will filter out NaN values in the dataset
+        (dataset.na.drop().toDF(), false)
+      } else {
+        (dataset.toDF(), getHandleInvalid == Bucketizer.KEEP_INVALID)
+      }
+    }
+
+    val bucketizer: UserDefinedFunction = udf { (feature: Double) =>
+      Bucketizer.binarySearchForBuckets($(splits), feature, keepInvalid)
     }
-    val newCol = bucketizer(dataset($(inputCol)))
-    val newField = prepOutputField(dataset.schema)
-    dataset.withColumn($(outputCol), newCol, newField.metadata)
+
+    val newCol = bucketizer(filteredDataset($(inputCol)))
+    val newField = prepOutputField(filteredDataset.schema)
+    filteredDataset.withColumn($(outputCol), newCol, newField.metadata)
   }
 
   private def prepOutputField(schema: StructType): StructField = {
@@ -106,6 +142,12 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String
 @Since("1.6.0")
 object Bucketizer extends DefaultParamsReadable[Bucketizer] {
 
+  private[feature] val SKIP_INVALID: String = "skip"
+  private[feature] val ERROR_INVALID: String = "error"
+  private[feature] val KEEP_INVALID: String = "keep"
+  private[feature] val supportedHandleInvalid: Array[String] =
+    Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID)
+
   /**
    * We require splits to be of length >= 3 and to be in strictly increasing order.
    * No NaN split should be accepted.
@@ -126,11 +168,26 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] {
 
   /**
    * Binary searching in several buckets to place each data point.
+   * @param splits array of split points
+   * @param feature data point
+   * @param keepInvalid NaN flag.
+   *                    Set "true" to make an extra bucket for NaN values;
+   *                    Set "false" to report an error for NaN values
+   * @return bucket for each data point
    * @throws SparkException if a feature is < splits.head or > splits.last
    */
-  private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = {
+
+  private[feature] def binarySearchForBuckets(
+      splits: Array[Double],
+      feature: Double,
+      keepInvalid: Boolean): Double = {
     if (feature.isNaN) {
-      splits.length - 1
+      if (keepInvalid) {
+        splits.length - 1
+      } else {
+        throw new SparkException("Bucketizer encountered NaN value. To handle or skip NaNs," +
+          " try setting Bucketizer.handleInvalid.")
+      }
     } else if (feature == splits.last) {
       splits.length - 2
     } else {
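
For context, here is a minimal usage sketch (not part of the patch) showing how the new handleInvalid param changes Bucketizer's behavior on NaN input. It assumes Spark 2.1+ and a local SparkSession; the object name is illustrative only.

// Minimal sketch (not from the patch): exercising Bucketizer.handleInvalid.
import org.apache.spark.ml.feature.Bucketizer
import org.apache.spark.sql.SparkSession

object BucketizerNaNExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("BucketizerNaNExample")
      .getOrCreate()
    import spark.implicits._

    // Three regular buckets: (-inf, 0), [0, 1), [1, +inf)
    val splits = Array(Double.NegativeInfinity, 0.0, 1.0, Double.PositiveInfinity)
    val df = Seq(-0.5, 0.2, 2.0, Double.NaN).toDF("feature")

    val bucketizer = new Bucketizer()
      .setInputCol("feature")
      .setOutputCol("bucket")
      .setSplits(splits)
      .setHandleInvalid("keep") // NaN lands in the extra bucket splits.length - 1, i.e. 3.0

    bucketizer.transform(df).show()
    // "skip" would instead drop the NaN row (via dataset.na.drop());
    // the default "error" throws a SparkException on the first NaN.

    spark.stop()
  }
}

With "keep", the outputs here are 0.0, 1.0, 2.0 for the finite values and 3.0 for NaN, matching the splits.length - 1 branch in binarySearchForBuckets above.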
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index 05e034d90f..b9e01dde70 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -36,6 +36,9 @@ private[feature] trait QuantileDiscretizerBase extends Params
   /**
    * Number of buckets (quantiles, or categories) into which data points are grouped. Must
    * be >= 2.
+   *
+   * See also [[handleInvalid]], which can optionally create an additional bucket for NaN values.
+   *
    * default: 2
    * @group param
    */
@@ -61,17 +64,41 @@ private[feature] trait QuantileDiscretizerBase extends Params
 
   /** @group getParam */
   def getRelativeError: Double = getOrDefault(relativeError)
+
+  /**
+   * Param for how to handle invalid entries. Options are skip (filter out rows with
+   * invalid values), error (throw an error), or keep (keep invalid values in a special additional
+   * bucket).
+   * Default: "error"
+   * @group param
+   */
+  @Since("2.1.0")
+  val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " +
+    "invalid entries. Options are skip (filter out rows with invalid values), " +
+    "error (throw an error), or keep (keep invalid values in a special additional bucket).",
+    ParamValidators.inArray(Bucketizer.supportedHandleInvalid))
+  setDefault(handleInvalid, Bucketizer.ERROR_INVALID)
+
+  /** @group getParam */
+  @Since("2.1.0")
+  def getHandleInvalid: String = $(handleInvalid)
+
 }
 
 /**
  * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
  * categorical features. The number of bins can be set using the `numBuckets` parameter. It is
- * possible that the number of buckets used will be less than this value, for example, if there
- * are too few distinct values of the input to create enough distinct quantiles. Note also that
- * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets
- * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special
- * bucket(4).
- * The bin ranges are chosen using an approximate algorithm (see the documentation for
+ * possible that the number of buckets used will be smaller than this value, for example, if there
+ * are too few distinct values of the input to create enough distinct quantiles.
+ *
+ * NaN handling: Note also that
+ * QuantileDiscretizer will raise an error when it finds NaN values in the dataset, but the user can
+ * also choose to either keep or remove NaN values within the dataset by setting `handleInvalid`.
+ * If the user chooses to keep NaN values, they will be handled specially and placed into their own
+ * bucket, for example, if 4 buckets are used, then non-NaN data will be put into buckets[0-3],
+ * but NaNs will be counted in a special bucket[4].
+ *
+ * Algorithm: The bin ranges are chosen using an approximate algorithm (see the documentation for
  * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]]
  * for a detailed description). The precision of the approximation can be controlled with the
  * `relativeError` parameter. The lower and upper bin bounds will be `-Infinity` and `+Infinity`,
@@ -100,6 +127,10 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
   @Since("1.6.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
+  /** @group setParam */
+  @Since("2.1.0")
+  def setHandleInvalid(value: String): this.type = set(handleInvalid, value)
+
   @Since("1.6.0")
   override def transformSchema(schema: StructType): StructType = {
     SchemaUtils.checkNumericType(schema, $(inputCol))
@@ -124,7 +155,9 @@ final class QuantileDiscretizer @Since("1.6.0") (@Since("1.6.0") override val ui
       log.warn(s"Some quantiles were identical. Bucketing to ${distinctSplits.length - 1}" +
         s" buckets as a result.")
     }
-    val bucketizer = new Bucketizer(uid).setSplits(distinctSplits.sorted)
+    val bucketizer = new Bucketizer(uid)
+      .setSplits(distinctSplits.sorted)
+      .setHandleInvalid($(handleInvalid))
     copyValues(bucketizer.setParent(this))
   }
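
And a companion sketch (again, not part of the patch) for the QuantileDiscretizer side: the last hunk shows that fit() now forwards handleInvalid to the Bucketizer it returns, so the fitted model handles NaN the same way. Assumes Spark 2.1+ and a local SparkSession; names are illustrative.

// Minimal sketch (not from the patch): QuantileDiscretizer forwards handleInvalid
// to the Bucketizer model produced by fit().
import org.apache.spark.ml.feature.QuantileDiscretizer
import org.apache.spark.sql.SparkSession

object QuantileDiscretizerNaNExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("QuantileDiscretizerNaNExample")
      .getOrCreate()
    import spark.implicits._

    // Fit on clean data; transform data that contains a NaN.
    val train = Seq(1.0, 2.0, 3.0, 4.0, 5.0, 6.0).toDF("hour")
    val test = Seq(2.5, Double.NaN).toDF("hour")

    val discretizer = new QuantileDiscretizer()
      .setInputCol("hour")
      .setOutputCol("bucket")
      .setNumBuckets(3)
      .setHandleInvalid("keep") // NaN rows go to the extra bucket(3)

    // fit() computes approximate quantile splits and returns a Bucketizer
    // carrying the same handleInvalid setting (see the last hunk above).
    val model = discretizer.fit(train)
    model.transform(test).show()

    spark.stop()
  }
}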