diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala | 124 |
1 files changed, 29 insertions, 95 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala index e486e92c12..5c7993af64 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala @@ -23,10 +23,10 @@ import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.attribute.NominalAttribute -import org.apache.spark.ml.param.{IntParam, _} +import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed} import org.apache.spark.ml.util._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{Dataset, Row} import org.apache.spark.sql.types.{DoubleType, StructType} import org.apache.spark.util.random.XORShiftRandom @@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params with HasInputCol with HasOutputCol with HasSeed { /** - * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must + * Number of buckets (quantiles, or categories) into which data points are grouped. Must * be >= 2. * default: 2 * @group param @@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params /** @group getParam */ def getNumBuckets: Int = getOrDefault(numBuckets) + + /** + * Relative error (see documentation for + * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description) + * Must be a number in [0, 1]. + * default: 0.001 + * @group param + */ + val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " + + "for approxQuantile", + ParamValidators.inRange(0.0, 1.0)) + setDefault(relativeError -> 0.001) + + /** @group getParam */ + def getRelativeError: Double = getOrDefault(relativeError) } /** @@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params * `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned * categorical features. The bin ranges are chosen by taking a sample of the data and dividing it * into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity, - * covering all real values. This attempts to find numBuckets partitions based on a sample of data, - * but it may find fewer depending on the data sample values. + * covering all real values. */ @Experimental final class QuantileDiscretizer(override val uid: String) @@ -66,6 +80,9 @@ final class QuantileDiscretizer(override val uid: String) def this() = this(Identifiable.randomUID("quantileDiscretizer")) /** @group setParam */ + def setRelativeError(value: Double): this.type = set(relativeError, value) + + /** @group setParam */ def setNumBuckets(value: Int): this.type = set(numBuckets, value) /** @group setParam */ @@ -87,12 +104,13 @@ final class QuantileDiscretizer(override val uid: String) StructType(outputFields) } - override def fit(dataset: DataFrame): Bucketizer = { - val samples = QuantileDiscretizer - .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed)) - .map { case Row(feature: Double) => feature } - val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1) - val splits = QuantileDiscretizer.getSplits(candidates) + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Bucketizer = { + val splits = dataset.stat.approxQuantile($(inputCol), + (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError)) + splits(0) = Double.NegativeInfinity + splits(splits.length - 1) = Double.PositiveInfinity + val bucketizer = new Bucketizer(uid).setSplits(splits) copyValues(bucketizer.setParent(this)) } @@ -103,90 +121,6 @@ final class QuantileDiscretizer(override val uid: String) @Since("1.6.0") object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging { - /** - * Minimum number of samples required for finding splits, regardless of number of bins. If - * the dataset has fewer rows than this value, the entire dataset will be used. - */ - private[spark] val minSamplesRequired: Int = 10000 - - /** - * Sampling from the given dataset to collect quantile statistics. - */ - private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = { - val totalSamples = dataset.count() - require(totalSamples > 0, - "QuantileDiscretizer requires non-empty input dataset but was given an empty input.") - val requiredSamples = math.max(numBins * numBins, minSamplesRequired) - val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0) - dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect() - } - - /** - * Compute split points with respect to the sample distribution. - */ - private[feature] - def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = { - val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) => - m + ((x, m.getOrElse(x, 0) + 1)) - } - val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1)) - val possibleSplits = valueCounts.length - 1 - if (possibleSplits <= numSplits) { - valueCounts.dropRight(1).map(_._1) - } else { - val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1)) - val splitsBuilder = mutable.ArrayBuilder.make[Double] - var index = 1 - // currentCount: sum of counts of values that have been visited - var currentCount = valueCounts(0)._2 - // targetCount: target value for `currentCount`. If `currentCount` is closest value to - // `targetCount`, then current value is a split threshold. After finding a split threshold, - // `targetCount` is added by stride. - var targetCount = stride - while (index < valueCounts.length) { - val previousCount = currentCount - currentCount += valueCounts(index)._2 - val previousGap = math.abs(previousCount - targetCount) - val currentGap = math.abs(currentCount - targetCount) - // If adding count of current value to currentCount makes the gap between currentCount and - // targetCount smaller, previous value is a split threshold. - if (previousGap < currentGap) { - splitsBuilder += valueCounts(index - 1)._1 - targetCount += stride - } - index += 1 - } - splitsBuilder.result() - } - } - - /** - * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as - * needed, and adding a default split value of 0 if no good candidates are found. - */ - private[feature] def getSplits(candidates: Array[Double]): Array[Double] = { - val effectiveValues = if (candidates.nonEmpty) { - if (candidates.head == Double.NegativeInfinity - && candidates.last == Double.PositiveInfinity) { - candidates.drop(1).dropRight(1) - } else if (candidates.head == Double.NegativeInfinity) { - candidates.drop(1) - } else if (candidates.last == Double.PositiveInfinity) { - candidates.dropRight(1) - } else { - candidates - } - } else { - candidates - } - - if (effectiveValues.isEmpty) { - Array(Double.NegativeInfinity, 0, Double.PositiveInfinity) - } else { - Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity) - } - } - @Since("1.6.0") override def load(path: String): QuantileDiscretizer = super.load(path) } |