aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorOliver Pierson <ocp@gatech.edu>2016-04-11 12:02:48 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-11 12:02:48 -0700
commit89a41c5b7a3f727b44a7f615a1352ca006d12f73 (patch)
tree1c59e13c4fe03bbb0c5717f6c08311a2d2648da2 /mllib/src/main
parent2dacc81ec31233e558855a26340ad4662d470387 (diff)
downloadspark-89a41c5b7a3f727b44a7f615a1352ca006d12f73.tar.gz
spark-89a41c5b7a3f727b44a7f615a1352ca006d12f73.tar.bz2
spark-89a41c5b7a3f727b44a7f615a1352ca006d12f73.zip
[SPARK-13600][MLLIB] Use approxQuantile from DataFrame stats in QuantileDiscretizer
## What changes were proposed in this pull request? QuantileDiscretizer can return an unexpected number of buckets in certain cases. This PR proposes to fix this issue and also refactor QuantileDiscretizer to use approxQuantiles from DataFrame stats functions. ## How was this patch tested? QuantileDiscretizerSuite unit tests (some existing tests will change or even be removed in this PR) Author: Oliver Pierson <ocp@gatech.edu> Closes #11553 from oliverpierson/SPARK-13600.
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala119
1 files changed, 25 insertions, 94 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
index efe8b93d82..5c7993af64 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/QuantileDiscretizer.scala
@@ -37,7 +37,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
with HasInputCol with HasOutputCol with HasSeed {
/**
- * Maximum number of buckets (quantiles, or categories) into which data points are grouped. Must
+ * Number of buckets (quantiles, or categories) into which data points are grouped. Must
* be >= 2.
* default: 2
* @group param
@@ -49,6 +49,21 @@ private[feature] trait QuantileDiscretizerBase extends Params
/** @group getParam */
def getNumBuckets: Int = getOrDefault(numBuckets)
+
+ /**
+ * Relative error (see documentation for
+ * [[org.apache.spark.sql.DataFrameStatFunctions.approxQuantile approxQuantile]] for description)
+ * Must be a number in [0, 1].
+ * default: 0.001
+ * @group param
+ */
+ val relativeError = new DoubleParam(this, "relativeError", "The relative target precision " +
+ "for approxQuantile",
+ ParamValidators.inRange(0.0, 1.0))
+ setDefault(relativeError -> 0.001)
+
+ /** @group getParam */
+ def getRelativeError: Double = getOrDefault(relativeError)
}
/**
@@ -56,8 +71,7 @@ private[feature] trait QuantileDiscretizerBase extends Params
* `QuantileDiscretizer` takes a column with continuous features and outputs a column with binned
* categorical features. The bin ranges are chosen by taking a sample of the data and dividing it
* into roughly equal parts. The lower and upper bin bounds will be -Infinity and +Infinity,
- * covering all real values. This attempts to find numBuckets partitions based on a sample of data,
- * but it may find fewer depending on the data sample values.
+ * covering all real values.
*/
@Experimental
final class QuantileDiscretizer(override val uid: String)
@@ -66,6 +80,9 @@ final class QuantileDiscretizer(override val uid: String)
def this() = this(Identifiable.randomUID("quantileDiscretizer"))
/** @group setParam */
+ def setRelativeError(value: Double): this.type = set(relativeError, value)
+
+ /** @group setParam */
def setNumBuckets(value: Int): this.type = set(numBuckets, value)
/** @group setParam */
@@ -89,11 +106,11 @@ final class QuantileDiscretizer(override val uid: String)
@Since("2.0.0")
override def fit(dataset: Dataset[_]): Bucketizer = {
- val samples = QuantileDiscretizer
- .getSampledInput(dataset.select($(inputCol)), $(numBuckets), $(seed))
- .map { case Row(feature: Double) => feature }
- val candidates = QuantileDiscretizer.findSplitCandidates(samples, $(numBuckets) - 1)
- val splits = QuantileDiscretizer.getSplits(candidates)
+ val splits = dataset.stat.approxQuantile($(inputCol),
+ (0.0 to 1.0 by 1.0/$(numBuckets)).toArray, $(relativeError))
+ splits(0) = Double.NegativeInfinity
+ splits(splits.length - 1) = Double.PositiveInfinity
+
val bucketizer = new Bucketizer(uid).setSplits(splits)
copyValues(bucketizer.setParent(this))
}
@@ -104,92 +121,6 @@ final class QuantileDiscretizer(override val uid: String)
@Since("1.6.0")
object QuantileDiscretizer extends DefaultParamsReadable[QuantileDiscretizer] with Logging {
- /**
- * Minimum number of samples required for finding splits, regardless of number of bins. If
- * the dataset has fewer rows than this value, the entire dataset will be used.
- */
- private[spark] val minSamplesRequired: Int = 10000
-
- /**
- * Sampling from the given dataset to collect quantile statistics.
- */
- private[feature]
- def getSampledInput(dataset: Dataset[_], numBins: Int, seed: Long): Array[Row] = {
- val totalSamples = dataset.count()
- require(totalSamples > 0,
- "QuantileDiscretizer requires non-empty input dataset but was given an empty input.")
- val requiredSamples = math.max(numBins * numBins, minSamplesRequired)
- val fraction = math.min(requiredSamples.toDouble / totalSamples, 1.0)
- dataset.toDF.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
- .collect()
- }
-
- /**
- * Compute split points with respect to the sample distribution.
- */
- private[feature]
- def findSplitCandidates(samples: Array[Double], numSplits: Int): Array[Double] = {
- val valueCountMap = samples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
- m + ((x, m.getOrElse(x, 0) + 1))
- }
- val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray ++ Array((Double.MaxValue, 1))
- val possibleSplits = valueCounts.length - 1
- if (possibleSplits <= numSplits) {
- valueCounts.dropRight(1).map(_._1)
- } else {
- val stride: Double = math.ceil(samples.length.toDouble / (numSplits + 1))
- val splitsBuilder = mutable.ArrayBuilder.make[Double]
- var index = 1
- // currentCount: sum of counts of values that have been visited
- var currentCount = valueCounts(0)._2
- // targetCount: target value for `currentCount`. If `currentCount` is closest value to
- // `targetCount`, then current value is a split threshold. After finding a split threshold,
- // `targetCount` is added by stride.
- var targetCount = stride
- while (index < valueCounts.length) {
- val previousCount = currentCount
- currentCount += valueCounts(index)._2
- val previousGap = math.abs(previousCount - targetCount)
- val currentGap = math.abs(currentCount - targetCount)
- // If adding count of current value to currentCount makes the gap between currentCount and
- // targetCount smaller, previous value is a split threshold.
- if (previousGap < currentGap) {
- splitsBuilder += valueCounts(index - 1)._1
- targetCount += stride
- }
- index += 1
- }
- splitsBuilder.result()
- }
- }
-
- /**
- * Adjust split candidates to proper splits by: adding positive/negative infinity to both sides as
- * needed, and adding a default split value of 0 if no good candidates are found.
- */
- private[feature] def getSplits(candidates: Array[Double]): Array[Double] = {
- val effectiveValues = if (candidates.nonEmpty) {
- if (candidates.head == Double.NegativeInfinity
- && candidates.last == Double.PositiveInfinity) {
- candidates.drop(1).dropRight(1)
- } else if (candidates.head == Double.NegativeInfinity) {
- candidates.drop(1)
- } else if (candidates.last == Double.PositiveInfinity) {
- candidates.dropRight(1)
- } else {
- candidates
- }
- } else {
- candidates
- }
-
- if (effectiveValues.isEmpty) {
- Array(Double.NegativeInfinity, 0, Double.PositiveInfinity)
- } else {
- Array(Double.NegativeInfinity) ++ effectiveValues ++ Array(Double.PositiveInfinity)
- }
- }
-
@Since("1.6.0")
override def load(path: String): QuantileDiscretizer = super.load(path)
}