author     Timothy Hunter <timhunter@databricks.com>   2016-02-23 15:31:17 -0800
committer  Xiangrui Meng <meng@databricks.com>         2016-02-23 15:31:17 -0800
commit     15e30155631d52e35ab8522584027ab350e5acb3 (patch)
tree       f7678f8ba2c87de7900806fdb8531b474b3f0a2d /sql
parent     9cdd867da978629ea2f61f94e3c346fa0bfecf0e (diff)
[SPARK-6761][SQL][ML] Fixes to API and documentation of approximate quantiles
## What changes were proposed in this pull request?

This continues thunterdb's work on the `approxQuantile` API. It changes the signature of `approxQuantile` from `(col: String, quantile: Double, epsilon: Double): Double` to `(col: String, probabilities: Array[Double], relativeError: Double): Array[Double]` and updates the API doc. It also improves the error messages in tests and simplifies the merge algorithm for summaries.

## How was this patch tested?

Use the same unit tests as before.

Closes #11325

Author: Timothy Hunter <timhunter@databricks.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #11332 from mengxr/SPARK-6761.
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala             |  36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala       | 180
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala                 |  42
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala |  12
4 files changed, 150 insertions, 120 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
index 7f110c4e7f..39a31ab028 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameStatFunctions.scala
@@ -37,13 +37,37 @@ import org.apache.spark.util.sketch.{BloomFilter, CountMinSketch}
final class DataFrameStatFunctions private[sql](df: DataFrame) {
/**
- * Calculate the approximate quantile of numerical column of a DataFrame.
- * @param col the name of the column
- * @param quantile the quantile number
- * @return the approximate quantile
+ * Calculates the approximate quantiles of a numerical column of a DataFrame.
+ *
+ * The result of this algorithm has the following deterministic bound:
+ * If the DataFrame has N elements and if we request the quantile at probability `p` up to error
+ * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank
+ * of `x` is close to (p * N).
+ * More precisely,
+ *
+ * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
+ *
+ * This method implements a variation of the Greenwald-Khanna algorithm (with some speed
+ * optimizations).
+ * The algorithm was first presented in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient
+ * Online Computation of Quantile Summaries]] by Greenwald and Khanna.
+ *
+ * @param col the name of the numerical column
+ * @param probabilities a list of quantile probabilities
+ * Each number must belong to [0, 1].
+ * For example, 0 is the minimum, 0.5 is the median, 1 is the maximum.
+ * @param relativeError The relative target precision to achieve (>= 0).
+ * If set to zero, the exact quantiles are computed, which could be very expensive.
+ * Note that values greater than 1 are accepted but give the same result as 1.
+ * @return the approximate quantiles at the given probabilities
+ *
+ * @since 2.0.0
*/
- def approxQuantile(col: String, quantile: Double, epsilon: Double): Double = {
- StatFunctions.approxQuantile(df, col, quantile, epsilon)
+ def approxQuantile(
+ col: String,
+ probabilities: Array[Double],
+ relativeError: Double): Array[Double] = {
+ StatFunctions.multipleApproxQuantiles(df, Seq(col), probabilities, relativeError).head.toArray
}
/**
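For context, a minimal usage sketch of the reworked API (hedged: `sqlContext`, the column name, and the probabilities here are illustrative placeholders, not part of this patch):

```scala
// Build a small numeric DataFrame to query.
val df = sqlContext.range(0, 1000).toDF("value")

// One pass over the data returns one quantile per requested probability.
val Array(q25, median, q75) =
  df.stat.approxQuantile("value", Array(0.25, 0.5, 0.75), 0.01)
```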
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
index eb056d555b..26e4eda542 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql.execution.stat
-import scala.annotation.tailrec
import scala.collection.mutable.ArrayBuffer
import org.apache.spark.Logging
@@ -33,59 +32,37 @@ private[sql] object StatFunctions extends Logging {
import QuantileSummaries.Stats
/**
- * Calculates the approximate quantile for the given column.
- *
- * If you need to compute multiple quantiles at once, you should use [[multipleApproxQuantiles]]
- *
- * Note on the target error.
+ * Calculates the approximate quantiles of multiple numerical columns of a DataFrame in one pass.
*
* The result of this algorithm has the following deterministic bound:
- * if the DataFrame has N elements and if we request the quantile `phi` up to error `epsi`,
- * then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank
- * of `x` close to (phi * N). More precisely:
- *
- * floor((phi - epsi) * N) <= rank(x) <= ceil((phi + epsi) * N)
- *
- * Note on the algorithm used.
+ * If the DataFrame has N elements and if we request the quantile at probability `p` up to error
+ * `err`, then the algorithm will return a sample `x` from the DataFrame so that the *exact* rank
+ * of `x` is close to (p * N).
+ * More precisely,
*
- * This method implements a variation of the Greenwald-Khanna algorithm
- * (with some speed optimizations). The algorithm was first present in the following article:
- * "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael
- * and Khanna, Sanjeev. (http://dl.acm.org/citation.cfm?id=375670)
+ * floor((p - err) * N) <= rank(x) <= ceil((p + err) * N).
*
- * The performance optimizations are detailed in the comments of the implementation.
- *
- * @param df the dataframe to estimate quantiles on
- * @param col the name of the column
- * @param quantile the target quantile of interest
- * @param epsilon the target error. Should be >= 0.
- * */
- def approxQuantile(
- df: DataFrame,
- col: String,
- quantile: Double,
- epsilon: Double = QuantileSummaries.defaultEpsilon): Double = {
- require(quantile >= 0.0 && quantile <= 1.0, "Quantile must be in the range of (0.0, 1.0).")
- val Seq(Seq(res)) = multipleApproxQuantiles(df, Seq(col), Seq(quantile), epsilon)
- res
- }
-
- /**
- * Runs multiple quantile computations in a single pass, with the same target error.
- *
- * See [[approxQuantile)]] for more details on the approximation guarantees.
+ * This method implements a variation of the Greenwald-Khanna algorithm (with some speed
+ * optimizations).
+ * The algorithm was first presented in [[http://dx.doi.org/10.1145/375663.375670 Space-efficient
+ * Online Computation of Quantile Summaries]] by Greenwald and Khanna.
*
* @param df the dataframe
- * @param cols columns of the dataframe
- * @param quantiles target quantiles to compute
- * @param epsilon the precision to achieve
+ * @param cols numerical columns of the dataframe
+ * @param probabilities a list of quantile probabilities
+ * Each number must belong to [0, 1].
+ * For example, 0 is the minimum, 0.5 is the median, 1 is the maximum.
+ * @param relativeError The relative target precision to achieve (>= 0).
+ * If set to zero, the exact quantiles are computed, which could be very expensive.
+ * Note that values greater than 1 are accepted but give the same result as 1.
+ *
* @return for each column, returns the requested approximations
*/
def multipleApproxQuantiles(
df: DataFrame,
cols: Seq[String],
- quantiles: Seq[Double],
- epsilon: Double): Seq[Seq[Double]] = {
+ probabilities: Seq[Double],
+ relativeError: Double): Seq[Seq[Double]] = {
val columns: Seq[Column] = cols.map { colName =>
val field = df.schema(colName)
require(field.dataType.isInstanceOf[NumericType],
@@ -94,7 +71,7 @@ private[sql] object StatFunctions extends Logging {
Column(Cast(Column(colName).expr, DoubleType))
}
val emptySummaries = Array.fill(cols.size)(
- new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, epsilon))
+ new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError))
// Note that it works more or less by accident as `rdd.aggregate` is not a pure function:
// this function returns the same array as given in the input (because `aggregate` reuses
@@ -115,40 +92,49 @@ private[sql] object StatFunctions extends Logging {
}
val summaries = df.select(columns: _*).rdd.aggregate(emptySummaries)(apply, merge)
- summaries.map { summary => quantiles.map(summary.query) }
+ summaries.map { summary => probabilities.map(summary.query) }
}
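Concretely, the guarantee above can be read as a rank window. A worked sketch (not code from the patch; `data` is an illustrative placeholder): with N = 1000, p = 0.5 and relativeError = 0.01, any returned sample must have exact rank between 490 and 510:

```scala
val n = 1000
val p = 0.5
val err = 0.01
val lower = math.floor((p - err) * n) // 490.0
val upper = math.ceil((p + err) * n)  // 510.0
// For a returned sample x, data.count(_ < x) must lie in [lower, upper],
// which is exactly the property ApproxQuantileSuite asserts below.
```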
/**
* Helper class to compute approximate quantile summary.
* This implementation is based on the algorithm proposed in the paper:
* "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael
- * and Khanna, Sanjeev. (http://dl.acm.org/citation.cfm?id=375670)
+ * and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670)
*
* In order to optimize for speed, it maintains an internal buffer of the last seen samples,
* and only inserts them after crossing a certain size threshold. This guarantees a near-constant
* runtime complexity compared to the original algorithm.
*
- * @param compressThreshold the compression threshold: after the internal buffer of statistics
- * crosses this size, it attempts to compress the statistics together
- * @param epsilon the target precision
- * @param sampled a buffer of quantile statistics. See the G-K article for more details
+ * @param compressThreshold the compression threshold.
+ * After the internal buffer of statistics crosses this size, it attempts to compress the
+ * statistics together.
+ * @param relativeError the target relative error.
+ * It is uniform across the complete range of values.
+ * @param sampled a buffer of quantile statistics.
+ * See the G-K article for more details.
* @param count the count of all the elements *inserted in the sampled buffer*
* (excluding the head buffer)
* @param headSampled a buffer of latest samples seen so far
*/
class QuantileSummaries(
val compressThreshold: Int,
- val epsilon: Double,
+ val relativeError: Double,
val sampled: ArrayBuffer[Stats] = ArrayBuffer.empty,
private[stat] var count: Long = 0L,
val headSampled: ArrayBuffer[Double] = ArrayBuffer.empty) extends Serializable {
import QuantileSummaries._
+ /**
+ * Returns a summary with the given observation inserted into the summary.
+ * This method may either modify the current summary in place (and return that same
+ * summary), or it may create a new summary from scratch if necessary.
+ * @param x the new observation to insert into the summary
+ */
def insert(x: Double): QuantileSummaries = {
headSampled.append(x)
if (headSampled.size >= defaultHeadSize) {
- this.withHeadInserted
+ this.withHeadBufferInserted
} else {
this
}
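Because `insert` may hand back either the receiver or a fresh object, callers should always rebind to the returned summary. A minimal sketch, assuming a local `Seq[Double]` named `data`:

```scala
var summary = new QuantileSummaries(
  QuantileSummaries.defaultCompressThreshold, relativeError = 0.01)
data.foreach { x =>
  // insert() may return `this` or a new summary; keep whichever comes back.
  summary = summary.insert(x)
}
```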
@@ -162,7 +148,7 @@ private[sql] object StatFunctions extends Logging {
*
* @return a new quantile summary object.
*/
- private def withHeadInserted: QuantileSummaries = {
+ private def withHeadBufferInserted: QuantileSummaries = {
if (headSampled.isEmpty) {
return this
}
@@ -187,7 +173,7 @@ private[sql] object StatFunctions extends Logging {
if (newSamples.isEmpty || (sampleIdx == sampled.size && opsIdx == sorted.length - 1)) {
0
} else {
- math.floor(2 * epsilon * currentCount).toInt
+ math.floor(2 * relativeError * currentCount).toInt
}
val tuple = Stats(currentSample, 1, delta)
@@ -200,67 +186,80 @@ private[sql] object StatFunctions extends Logging {
newSamples.append(sampled(sampleIdx))
sampleIdx += 1
}
- new QuantileSummaries(compressThreshold, epsilon, newSamples, currentCount)
+ new QuantileSummaries(compressThreshold, relativeError, newSamples, currentCount)
}
+ /**
+ * Returns a new summary that compresses the summary statistics and the head buffer.
+ *
+ * This implements the COMPRESS function of the GK algorithm. It does not modify the object.
+ *
+ * @return a new summary object with compressed statistics
+ */
def compress(): QuantileSummaries = {
// Inserts all the elements first
- val inserted = this.withHeadInserted
+ val inserted = this.withHeadBufferInserted
assert(inserted.headSampled.isEmpty)
assert(inserted.count == count + headSampled.size)
val compressed =
- compressImmut(inserted.sampled, mergeThreshold = 2 * epsilon * inserted.count)
- new QuantileSummaries(compressThreshold, epsilon, compressed, inserted.count)
+ compressImmut(inserted.sampled, mergeThreshold = 2 * relativeError * inserted.count)
+ new QuantileSummaries(compressThreshold, relativeError, compressed, inserted.count)
}
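Queries must be preceded by `compress()`, since `query` now requires an empty head buffer (see the `require` added below). Continuing the sketch above:

```scala
// compress() flushes the head buffer into the sampled statistics,
// after which the summary can be queried.
val compressed = summary.compress()
val median = compressed.query(0.5)
```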
+ private def shallowCopy: QuantileSummaries = {
+ new QuantileSummaries(compressThreshold, relativeError, sampled, count, headSampled)
+ }
+
+ /**
+ * Merges two (compressed) summaries together.
+ *
+ * Returns a new summary.
+ */
def merge(other: QuantileSummaries): QuantileSummaries = {
+ require(headSampled.isEmpty, "Current buffer needs to be compressed before merge")
+ require(other.headSampled.isEmpty, "Other buffer needs to be compressed before merge")
if (other.count == 0) {
- this
+ this.shallowCopy
} else if (count == 0) {
- other
+ other.shallowCopy
} else {
- // We rely on the fact that they are ordered to efficiently interleave them.
- val thisSampled = sampled.toList
- val otherSampled = other.sampled.toList
- val res: ArrayBuffer[Stats] = ArrayBuffer.empty
-
- @tailrec
- def mergeCurrent(
- thisList: List[Stats],
- otherList: List[Stats]): Unit = (thisList, otherList) match {
- case (Nil, l) =>
- res.appendAll(l)
- case (l, Nil) =>
- res.appendAll(l)
- case (h1 :: t1, h2 :: t2) if h1.value > h2.value =>
- mergeCurrent(otherList, thisList)
- case (h1 :: t1, l) =>
- // We know that h1.value <= all values in l
- // TODO(thunterdb) do we need to adjust g and delta?
- res.append(h1)
- mergeCurrent(t1, l)
- }
-
- mergeCurrent(thisSampled, otherSampled)
- val comp = compressImmut(res, mergeThreshold = 2 * epsilon * count)
- new QuantileSummaries(other.compressThreshold, other.epsilon, comp, other.count + count)
+ // Merge the two buffers.
+ // The GK algorithm is a bit unclear about it, but it seems there is no need to adjust the
+ // statistics during the merging: the invariants are still respected after the merge.
+ // TODO: could replace full sort by ordered merge, the two lists are known to be sorted
+ // already.
+ val res = (sampled ++ other.sampled).sortBy(_.value)
+ val comp = compressImmut(res, mergeThreshold = 2 * relativeError * count)
+ new QuantileSummaries(
+ other.compressThreshold, other.relativeError, comp, other.count + count)
}
}
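The merge path mirrors how `multipleApproxQuantiles` combines per-partition summaries through `rdd.aggregate`. A simplified two-partition sketch (the partition collections are illustrative placeholders):

```scala
def buildSummary(part: Seq[Double]): QuantileSummaries = {
  val empty = new QuantileSummaries(
    QuantileSummaries.defaultCompressThreshold, relativeError = 0.01)
  // merge() requires empty head buffers on both sides, hence the compress().
  part.foldLeft(empty)((s, x) => s.insert(x)).compress()
}

val combined = buildSummary(partition1).merge(buildSummary(partition2))
val p90 = combined.query(0.9)
```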
+ /**
+ * Runs a query for a given quantile.
+ * The result follows the approximation guarantees detailed above.
+ * The query can only be run on a compressed summary: you need to call compress() before using
+ * it.
+ *
+ * @param quantile the target quantile
+ * @return the approximate value at the given quantile
+ */
def query(quantile: Double): Double = {
require(quantile >= 0 && quantile <= 1.0, "quantile should be in the range [0.0, 1.0]")
+ require(headSampled.isEmpty,
+ "Cannot operate on an uncompressed summary, call compress() first")
- if (quantile <= epsilon) {
+ if (quantile <= relativeError) {
return sampled.head.value
}
- if (quantile >= 1 - epsilon) {
+ if (quantile >= 1 - relativeError) {
return sampled.last.value
}
// Target rank
val rank = math.ceil(quantile * count).toInt
- val targetError = math.ceil(epsilon * count)
+ val targetError = math.ceil(relativeError * count)
// Minimum rank at current sample
var minRank = 0
var i = 1
@@ -291,9 +290,10 @@ private[sql] object StatFunctions extends Logging {
val defaultHeadSize: Int = 50000
/**
- * The default value for epsilon.
+ * The default value for the relative error (1%).
+ * With this value, the best extreme percentiles that can be approximated are 1% and 99%.
*/
- val defaultEpsilon: Double = 0.01
+ val defaultRelativeError: Double = 0.01
/**
* Statistics from the Greenwald-Khanna paper.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 7f9229244b..e865dbe6b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -126,22 +126,29 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
}
test("approximate quantile") {
- val df = Seq.tabulate(1000)(i => (i, 2.0 * i)).toDF("singles", "doubles")
-
- val expected_1 = 500.0
- val expected_2 = 1600.0
+ val n = 1000
+ val df = Seq.tabulate(n)(i => (i, 2.0 * i)).toDF("singles", "doubles")
+ val q1 = 0.5
+ val q2 = 0.8
val epsilons = List(0.1, 0.05, 0.001)
for (epsilon <- epsilons) {
- val result1 = df.stat.approxQuantile("singles", 0.5, epsilon)
- val result2 = df.stat.approxQuantile("doubles", 0.8, epsilon)
-
- val error_1 = 2 * 1000 * epsilon
- val error_2 = 2 * 2000 * epsilon
-
- assert(math.abs(result1 - expected_1) < error_1)
- assert(math.abs(result2 - expected_2) < error_2)
+ val Array(single1) = df.stat.approxQuantile("singles", Array(q1), epsilon)
+ val Array(double2) = df.stat.approxQuantile("doubles", Array(q2), epsilon)
+ // Also make sure there is no regression by computing multiple quantiles at once.
+ val Array(d1, d2) = df.stat.approxQuantile("doubles", Array(q1, q2), epsilon)
+ val Array(s1, s2) = df.stat.approxQuantile("singles", Array(q1, q2), epsilon)
+
+ val error_single = 2 * 1000 * epsilon
+ val error_double = 2 * 2000 * epsilon
+
+ assert(math.abs(single1 - q1 * n) < error_single)
+ assert(math.abs(double2 - 2 * q2 * n) < error_double)
+ assert(math.abs(s1 - q1 * n) < error_single)
+ assert(math.abs(s2 - q2 * n) < error_single)
+ assert(math.abs(d1 - 2 * q1 * n) < error_double)
+ assert(math.abs(d2 - 2 * q2 * n) < error_double)
}
}
@@ -296,9 +303,9 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext {
class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Logging {
// Turn on this test if you want to test the performance of approximate quantiles.
- ignore("describe() should not be slowed down too much by quantiles") {
+ ignore("computing quantiles should not take much longer than describe()") {
val df = sqlContext.range(5000000L).toDF("col1").cache()
- def millis(f: => Any): Double = {
+ def seconds(f: => Any): Double = {
// Do some warmup
logDebug("warmup...")
for (i <- 1 to 10) {
@@ -314,15 +321,14 @@ class DataFrameStatPerfSuite extends QueryTest with SharedSQLContext with Loggin
(end - start) / 1e9
}
logDebug("execute done")
- times.sum.toDouble / times.length.toDouble
-
+ times.sum / times.length.toDouble
}
logDebug("*** Normal describe ***")
- val t1 = millis { df.describe() }
+ val t1 = seconds { df.describe() }
logDebug(s"T1 = $t1")
logDebug("*** Just quantiles ***")
- val t2 = millis {
+ val t2 = seconds {
StatFunctions.multipleApproxQuantiles(df, Seq("col1"), Seq(0.1, 0.25, 0.5, 0.75, 0.9), 0.01)
}
logDebug(s"T1 = $t1, T2 = $t2")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala
index 6992b4c723..0a989d026c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/stat/ApproxQuantileSuite.scala
@@ -46,12 +46,12 @@ class ApproxQuantileSuite extends SparkFunSuite {
val approx = summary.query(quant)
// The rank of the approximation.
val rank = data.count(_ < approx) // has to be <, not <= to be exact
- val lower = math.floor((quant - summary.epsilon) * data.size)
- assert(rank >= lower,
- s"approx_rank: $rank ! >= $lower, requested quantile = $quant")
- val upper = math.ceil((quant + summary.epsilon) * data.size)
- assert(rank <= upper,
- s"approx_rank: $rank ! <= $upper, requested quantile = $quant")
+ val lower = math.floor((quant - summary.relativeError) * data.size)
+ val upper = math.ceil((quant + summary.relativeError) * data.size)
+ val msg =
+ s"$rank not in [$lower $upper], requested quantile: $quant, approx returned: $approx"
+ assert(rank >= lower, msg)
+ assert(rank <= upper, msg)
}
for {