diff options
7 files changed, 85 insertions, 12 deletions
diff --git a/docs/ml-features.md b/docs/ml-features.md index 746593fb9e..a39b31c8f7 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -1102,7 +1102,11 @@ for more details on the API. ## QuantileDiscretizer `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned -categorical features. The number of bins is set by the `numBuckets` parameter. +categorical features. The number of bins is set by the `numBuckets` parameter. It is possible +that the number of buckets used will be less than this value, for example, if there are too few +distinct values of the input to create enough distinct quantiles. Note also that NaN values are +handled specially and placed into their own bucket. For example, if 4 buckets are used, then +non-NaN data will be put into buckets[0-3], but NaNs will be counted in a special bucket[4]. The bin ranges are chosen using an approximate algorithm (see the documentation for [approxQuantile](api/scala/index.html#org.apache.spark.sql.DataFrameStatFunctions) for a detailed description). The precision of the approximation can be controlled with the diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index 100d9e7f6c..ec0ea05f9e 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -106,7 +106,10 @@ final class Bucketizer @Since("1.4.0") (@Since("1.4.0") override val uid: String @Since("1.6.0") object Bucketizer extends DefaultParamsReadable[Bucketizer] { - /** We require splits to be of length >= 3 and to be in strictly increasing order. */ + /** + * We require splits to be of length >= 3 and to be in strictly increasing order. + * No NaN split should be accepted. + */ private[feature] def checkSplits(splits: Array[Double]): Boolean = { if (splits.length < 3) { false @@ -114,10 +117,10 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { var i = 0 val n = splits.length - 1 while (i < n) { - if (splits(i) >= splits(i + 1)) return false + if (splits(i) >= splits(i + 1) || splits(i).isNaN) return false i += 1 } - true + !splits(n).isNaN } } @@ -126,7 +129,9 @@ object Bucketizer extends DefaultParamsReadable[Bucketizer] { * @throws SparkException if a feature is < splits.head or > splits.last */ private[feature] def binarySearchForBuckets(splits: Array[Double], feature: Double): Double = { - if (feature == splits.last) { + if (feature.isNaN) { + splits.length - 1 + } else if (feature == splits.last) { splits.length - 2 } else { val idx = ju.Arrays.binarySearch(splits, feature) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index e09800877c..1e59d71a70 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -39,7 +39,7 @@ private[feature] trait QuantileDiscretizerBase extends Params * default: 2 * @group param */ - val numBuckets = new IntParam(this, "numBuckets", "Maximum number of buckets (quantiles, or " + + val numBuckets = new IntParam(this, "numBuckets", "Number of buckets (quantiles, or " + "categories) into which data points are grouped. Must be >= 2.", ParamValidators.gtEq(2)) setDefault(numBuckets -> 2) @@ -65,7 +65,12 @@ private[feature] trait QuantileDiscretizerBase extends Params /** * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned - * categorical features. The number of bins can be set using the `numBuckets` parameter. + * categorical features. The number of bins can be set using the `numBuckets` parameter. It is + * possible that the number of buckets used will be less than this value, for example, if there + * are too few distinct values of the input to create enough distinct quantiles. Note also that + * NaN values are handled specially and placed into their own bucket. For example, if 4 buckets + * are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in a special + * bucket(4). * The bin ranges are chosen using an approximate algorithm (see the documentation for * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] * for a detailed description). The precision of the approximation can be controlled with the diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala index cd10c78311..c7f5093e74 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketizerSuite.scala @@ -88,6 +88,37 @@ class BucketizerSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } } + test("Bucket continuous features, with NaN data but non-NaN splits") { + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity) + val validData = Array(-0.9, -0.5, -0.3, 0.0, 0.2, 0.5, 0.9, Double.NaN, Double.NaN, Double.NaN) + val expectedBuckets = Array(0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 4.0) + val dataFrame: DataFrame = + spark.createDataFrame(validData.zip(expectedBuckets)).toDF("feature", "expected") + + val bucketizer: Bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits) + + bucketizer.transform(dataFrame).select("result", "expected").collect().foreach { + case Row(x: Double, y: Double) => + assert(x === y, + s"The feature value is not correct after bucketing. Expected $y but found $x") + } + } + + test("Bucket continuous features, with NaN splits") { + val splits = Array(Double.NegativeInfinity, -0.5, 0.0, 0.5, Double.PositiveInfinity, Double.NaN) + withClue("Invalid NaN split was not caught as an invalid split!") { + intercept[IllegalArgumentException] { + val bucketizer: Bucketizer = new Bucketizer() + .setInputCol("feature") + .setOutputCol("result") + .setSplits(splits) + } + } + } + test("Binary search correctness on hand-picked examples") { import BucketizerSuite.checkBinarySearch // length 3, with -inf diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala index 18f1e89ee8..6822594044 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/QuantileDiscretizerSuite.scala @@ -52,12 +52,12 @@ class QuantileDiscretizerSuite "Bucket sizes are not within expected relative error tolerance.") } - test("Test Bucketizer on duplicated splits") { + test("Test on data with high proportion of duplicated values") { val spark = this.spark import spark.implicits._ - val datasetSize = 12 val numBuckets = 5 + val expectedNumBuckets = 3 val df = sc.parallelize(Array(1.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 2.0, 2.0, 2.0, 1.0, 3.0)) .map(Tuple1.apply).toDF("input") val discretizer = new QuantileDiscretizer() @@ -65,10 +65,31 @@ class QuantileDiscretizerSuite .setOutputCol("result") .setNumBuckets(numBuckets) val result = discretizer.fit(df).transform(df) + val observedNumBuckets = result.select("result").distinct.count + assert(observedNumBuckets == expectedNumBuckets, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBuckets but found $observedNumBuckets") + } + + test("Test transform on data with NaN value") { + val spark = this.spark + import spark.implicits._ + + val numBuckets = 3 + val df = sc.parallelize(Array(1.0, 1.0, 1.0, Double.NaN)) + .map(Tuple1.apply).toDF("input") + val discretizer = new QuantileDiscretizer() + .setInputCol("input") + .setOutputCol("result") + .setNumBuckets(numBuckets) + // Reserve extra one bucket for NaN + val expectedNumBuckets = discretizer.fit(df).getSplits.length - 1 + val result = discretizer.fit(df).transform(df) val observedNumBuckets = result.select("result").distinct.count - assert(2 <= observedNumBuckets && observedNumBuckets <= numBuckets, - "Observed number of buckets are not within expected range.") + assert(observedNumBuckets == expectedNumBuckets, + s"Observed number of buckets are not correct." + + s" Expected $expectedNumBuckets but found $observedNumBuckets") } test("Test transform method on unseen data") { diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 2881380152..c45434f1a5 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -1155,6 +1155,11 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadab `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned categorical features. The number of bins can be set using the :py:attr:`numBuckets` parameter. + It is possible that the number of buckets used will be less than this value, for example, if + there are too few distinct values of the input to create enough distinct quantiles. Note also + that NaN values are handled specially and placed into their own bucket. For example, if 4 + buckets are used, then non-NaN data will be put into buckets(0-3), but NaNs will be counted in + a special bucket(4). The bin ranges are chosen using an approximate algorithm (see the documentation for :py:meth:`~.DataFrameStatFunctions.approxQuantile` for a detailed description). The precision of the approximation can be controlled with the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala index 1855eab96e..d69be36917 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala @@ -52,6 +52,7 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { * The algorithm was first present in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient * Online Computation of Quantile Summaries]] by Greenwald and Khanna. * + * Note that NaN values will be removed from the numerical column before calculation * @param col the name of the numerical column * @param probabilities a list of quantile probabilities * Each number must belong to [0, 1]. @@ -67,7 +68,8 @@ final class DataFrameStatFunctions private[sql](df: DataFrame) { col: String, probabilities: Array[Double], relativeError: Double): Array[Double] = { - StatFunctions.multipleApproxQuantiles(df, Seq(col), probabilities, relativeError).head.toArray + StatFunctions.multipleApproxQuantiles(df.select(col).na.drop(), + Seq(col), probabilities, relativeError).head.toArray } /** |