aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala124
1 files changed, 29 insertions, 95 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index e486e92c12..5c7993af64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -23,10 +23,10 @@ import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.attribute.NominalAttribute
-import org.apache.spark.ml.param.{IntParam, _}
+import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol, HasSeed}
import org.apache.spark.ml.util._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{Dataset, Row}
import org.apache.spark.sql.types.{DoubleType, StructType}
import org.apache.spark.util.random.XORShiftRandom
@@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
with HasInputCol with HasOutputCol with HasSeed {
/**
- * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
+ * Number of buckets (quantiles, or categories) into which data points are grouped. Must
* be >= 2.
* default: 2
* @group param
@@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params
/** @group getParam */
def getNumBuckets: Int = getOrDefault(numBuckets)
+
+ /**
+ * Relative error (see documentation for
+ * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description)
+ * Must be a number in [0, 1].
+ * default: 0.001
+ * @group param
+ */
+ val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " +
+ "for approxQuantile",
+ ParamValidators.inRange(0.0, 1.0))
+ setDefault(relativeError -> 0.001)
+
+ /** @group getParam */
+ def getRelativeError: Double = getOrDefault(relativeError)
}
/**
@@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
* into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
- * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
- * but it may find fewer depending on the data sample values.
+ * covering all real values.
*/
@Experimental
final class QuantileDiscretizer(override val uid: String)
@@ -66,6 +80,9 @@ final class QuantileDiscretizer(override val uid: String)
def this() = this(Identifiable.randomUID("quantileDiscretizer"))
/** @group setParam */
+ def setRelativeError(value: Double): this.type = set(relativeError, value)
+
+ /** @group setParam */
def setNumBuckets(value: Int): this.type = set(numBuckets, value)
/** @group setParam */
@@ -87,12 +104,13 @@ final class QuantileDiscretizer(override val uid: String)
StructType(outputFields)
}
- override def fit(dataset: DataFrame): Bucketizer = {
- val samples = QuantileDiscretizer
- .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
- .map { case Row(feature: Double) => feature }
- val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
- val splits = QuantileDiscretizer.getSplits(candidates)
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): Bucketizer = {
+ val splits = dataset.stat.approxQuantile($(inputCol),
+ (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+ splits(0) = Double.NegativeInfinity
+ splits(splits.length - 1) = Double.PositiveInfinity
+
val bucketizer = new Bucketizer(uid).setSplits(splits)
copyValues(bucketizer.setParent(this))
}
@@ -103,90 +121,6 @@ final class QuantileDiscretizer(override val uid: String)
@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
- /**
- * Minimum number of samples required for finding splits, regardless of number of bins. If
- * the dataset has fewer rows than this value, the entire dataset will be used.
- */
- private[spark] val minSamplesRequired: Int = 10000
-
- /**
- * Sampling from the given dataset to collect quantile statistics.
- */
- private[feature] def getSampledInput(dataset: DataFrame, numBins: Int, seed: Long): Array[Row] = {
- val totalSamples = dataset.count()
- require(totalSamples > 0,
- "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
- val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
- val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
- dataset.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
- }
-
- /**
- * Compute split points with respect to the sample distribution.
- */
- private[feature]
- def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
- val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
- m + ((x, m.getOrElse(x, 0) + 1))
- }
- val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
- val possibleSplits = valueCounts.length - 1
- if (possibleSplits <= numSplits) {
- valueCounts.dropRight(1).map(_._1)
- } else {
- val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
- val splitsBuilder = mutable.ArrayBuilder.make[Double]
- var index = 1
- // currentCount: sum of counts of values that have been visited
- var currentCount = valueCounts(0)._2
- // targetCount: target value for `currentCount`. If `currentCount` is closest value to
- // `targetCount`, then current value is a split threshold. After finding a split threshold,
- // `targetCount` is added by stride.
- var targetCount = stride
- while (index < valueCounts.length) {
- val previousCount = currentCount
- currentCount += valueCounts(index)._2
- val previousGap = math.abs(previousCount - targetCount)
- val currentGap = math.abs(currentCount - targetCount)
- // If adding count of current value to currentCount makes the gap between currentCount and
- // targetCount smaller, previous value is a split threshold.
- if (previousGap < currentGap) {
- splitsBuilder += valueCounts(index - 1)._1
- targetCount += stride
- }
- index += 1
- }
- splitsBuilder.result()
- }
- }
-
- /**
- * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
- * needed, and adding a default split value of 0 if no good candidates are found.
- */
- private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
- val effectiveValues = if (candidates.nonEmpty) {
- if (candidates.head == Double.NegativeInfinity
- && candidates.last == Double.PositiveInfinity) {
- candidates.drop(1).dropRight(1)
- } else if (candidates.head == Double.NegativeInfinity) {
- candidates.drop(1)
- } else if (candidates.last == Double.PositiveInfinity) {
- candidates.dropRight(1)
- } else {
- candidates
- }
- } else {
- candidates
- }
-
- if (effectiveValues.isEmpty) {
- Array(Double.NegativeInfinity, 0, Double.PositiveInfinity)
- } else {
- Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
- }
- }
-
@Since("1.6.0")
override def load(path: String): QuantileDiscretizer = super.load(path)
}