author    Sean Zhong <seanzhong@databricks.com>  2016-09-01 16:31:13 +0800
committer Wenchen Fan <wenchen@databricks.com>  2016-09-01 16:31:13 +0800
commit    a18c169fd050e71fdb07b153ae0fa5c410d8de27 (patch)
tree      be693af0f087329fd457ad0c238ffb2a3a9f8bab /sql/catalyst/src/main/scala/org
parent    536fa911c181958d84f14156f7d57ef5fd68df48 (diff)
download  spark-a18c169fd050e71fdb07b153ae0fa5c410d8de27.tar.gz
          spark-a18c169fd050e71fdb07b153ae0fa5c410d8de27.tar.bz2
          spark-a18c169fd050e71fdb07b153ae0fa5c410d8de27.zip
[SPARK-16283][SQL] Implements percentile_approx aggregation function which supports partial aggregation.
## What changes were proposed in this pull request?

This PR implements the aggregation function `percentile_approx`. `percentile_approx` returns the approximate percentile(s) of a column at the given percentage(s). A percentile is a watermark value below which a given percentage of the column values fall. For example, the percentile of column `col` at percentage 50% is the median value of column `col`.

### Syntax:
```
# Returns the percentile at the given percentage value. The approximation error can be
# reduced by increasing the accuracy parameter, at the cost of memory.
percentile_approx(col, percentage [, accuracy])

# Returns an array of percentile values at the given array of percentage values.
percentile_approx(col, array(percentage1 [, percentage2]...) [, accuracy])
```

### Features:
1. This function supports partial aggregation.
2. The memory consumption is bounded. The larger the `accuracy` parameter we choose, the smaller the error we get. The default accuracy value is 10000, to match the Hive default setting. Choose a smaller value for a smaller memory footprint.
3. This function supports window function aggregation.

### Example usages:
```
## Returns the 25th percentile value, with default accuracy
SELECT percentile_approx(col, 0.25) FROM table

## Returns an array of percentile values (25th, 50th, 75th), with default accuracy
SELECT percentile_approx(col, array(0.25, 0.5, 0.75)) FROM table

## Returns the 25th percentile value, with custom accuracy 100; a larger accuracy
## parameter yields a smaller approximation error
SELECT percentile_approx(col, 0.25, 100) FROM table

## Returns the 25th and 50th percentile values, with custom accuracy 100
SELECT percentile_approx(col, array(0.25, 0.5), 100) FROM table
```

### NOTE:
1. This `percentile_approx` implementation differs from Hive's, so the result returned for the same query may differ slightly from Hive's. The implementation uses `QuantileSummaries` as the underlying probabilistic data structure, and mainly follows the paper "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael and Khanna, Sanjeev (http://dx.doi.org/10.1145/375663.375670).
2. The current implementation of `QuantileSummaries` doesn't support automatic compression. This PR adds a rule to compress automatically on the caller side, but it may not be optimal.

## How was this patch tested?

Unit tests, and SQL query tests.

## Acknowledgement

1. This PR's work is based on lw-lin's PR https://github.com/apache/spark/pull/14298, with improvements such as supporting partial aggregation and fixing an out-of-memory issue.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14868 from clockfly/appro_percentile_try_2.
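After this patch the function is callable from SQL, and the same queries can be driven programmatically. Below is a minimal Scala sketch, assuming a Spark build that includes this patch, a local SparkSession, and a hypothetical temp view `table` with a numeric column `col`:

```
// Minimal sketch: assumes a Spark build that includes this patch.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder()
  .appName("percentile-approx-demo")
  .master("local[*]")
  .getOrCreate()
import spark.implicits._

// Hypothetical data: a temp view `table` with a numeric column `col`.
(1 to 100).map(_.toDouble).toDF("col").createOrReplaceTempView("table")

// 25th percentile with the default accuracy (10000).
spark.sql("SELECT percentile_approx(col, 0.25) FROM table").show()

// 25th/50th/75th percentiles in one pass, with a custom accuracy of 100.
spark.sql("SELECT percentile_approx(col, array(0.25, 0.5, 0.75), 100) FROM table").show()
```

With the default accuracy of 10000 the relative error is 1.0/10000 = 0.0001, i.e. the rank of the returned value lies within ±0.01% of the exact percentile rank.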
Diffstat (limited to 'sql/catalyst/src/main/scala/org')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala |   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala | 321
2 files changed, 322 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 35fd800df4..b05f4f61f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -250,6 +250,7 @@ object FunctionRegistry {
expression[Average]("mean"),
expression[Min]("min"),
expression[Skewness]("skewness"),
+ expression[ApproximatePercentile]("percentile_approx"),
expression[StddevSamp]("std"),
expression[StddevSamp]("stddev"),
expression[StddevPop]("stddev_pop"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
new file mode 100644
index 0000000000..f91ff87fc1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.nio.ByteBuffer
+
+import com.google.common.primitives.{Doubles, Ints, Longs}
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
+import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
+import org.apache.spark.sql.types._
+
+/**
+ * The ApproximatePercentile function returns the approximate percentile(s) of a column at the given
+ * percentage(s). A percentile is a watermark value below which a given percentage of the column
+ * values fall. For example, the percentile of column `col` at percentage 50% is the median of
+ * column `col`.
+ *
+ * This function supports partial aggregation.
+ *
+ * @param child child expression that can produce a column value with `child.eval(inputRow)`
+ * @param percentageExpression Expression that represents a single percentage value or
+ * an array of percentage values. Each percentage value must be between
+ * 0.0 and 1.0.
+ * @param accuracyExpression Integer literal expression of approximation accuracy. A higher
+ *                           value yields better accuracy; the default value is
+ *                           DEFAULT_PERCENTILE_ACCURACY.
+ */
+@ExpressionDescription(
+ usage =
+ """
+ _FUNC_(col, percentage [, accuracy]) - Returns the approximate percentile value of numeric
+ column `col` at the given percentage. The value of percentage must be between 0.0
+ and 1.0. The `accuracy` parameter (default: 10000) is a positive integer literal which
+ controls approximation accuracy at the cost of memory. A higher value of `accuracy` yields
+ better accuracy; `1.0/accuracy` is the relative error of the approximation.
+
+ _FUNC_(col, array(percentage1 [, percentage2]...) [, accuracy]) - Returns the approximate
+ percentile array of column `col` at the given percentage array. Each value of the
+ percentage array must be between 0.0 and 1.0. The `accuracy` parameter (default: 10000) is
+ a positive integer literal which controls approximation accuracy at the cost of memory.
+ A higher value of `accuracy` yields better accuracy; `1.0/accuracy` is the relative error
+ of the approximation.
+ """)
+case class ApproximatePercentile(
+ child: Expression,
+ percentageExpression: Expression,
+ accuracyExpression: Expression,
+ override val mutableAggBufferOffset: Int,
+ override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] {
+
+ def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = {
+ this(child, percentageExpression, accuracyExpression, 0, 0)
+ }
+
+ def this(child: Expression, percentageExpression: Expression) = {
+ this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
+ }
+
+ // Mark as lazy so that accuracyExpression is not evaluated during tree transformation.
+ private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType)
+ }
+
+ // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
+ private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = {
+ (percentageExpression.dataType, percentageExpression.eval()) match {
+ // Rule ImplicitTypeCasts can cast other numeric types to double
+ case (_, num: Double) => (false, Array(num))
+ case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
+ val numericArray = arrayData.toObjectArray(baseType)
+ (true, numericArray.map { x =>
+ baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
+ })
+ case other =>
+ throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage")
+ }
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!percentageExpression.foldable || !accuracyExpression.foldable) {
+ TypeCheckFailure(s"The accuracy or percentage provided must be a constant literal")
+ } else if (accuracy <= 0) {
+ TypeCheckFailure(
+ s"The accuracy provided must be a positive integer literal (current value = $accuracy)")
+ } else if (percentages.exists(percentage => percentage < 0.0D || percentage > 1.0D)) {
+ TypeCheckFailure(
+ s"All percentage values must be between 0.0 and 1.0 " +
+ s"(current = ${percentages.mkString(", ")})")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
+ override def createAggregationBuffer(): PercentileDigest = {
+ val relativeError = 1.0D / accuracy
+ new PercentileDigest(relativeError)
+ }
+
+ override def update(buffer: PercentileDigest, inputRow: InternalRow): Unit = {
+ val value = child.eval(inputRow)
+ // Ignore rows with a null input value, for example: percentile_approx(null)
+ if (value != null) {
+ buffer.add(value.asInstanceOf[Double])
+ }
+ }
+
+ override def merge(buffer: PercentileDigest, other: PercentileDigest): Unit = {
+ buffer.merge(other)
+ }
+
+ override def eval(buffer: PercentileDigest): Any = {
+ val result = buffer.getPercentiles(percentages)
+ if (result.length == 0) {
+ null
+ } else if (returnPercentileArray) {
+ new GenericArrayData(result)
+ } else {
+ result(0)
+ }
+ }
+
+ override def withNewMutableAggBufferOffset(newOffset: Int): ApproximatePercentile =
+ copy(mutableAggBufferOffset = newOffset)
+
+ override def withNewInputAggBufferOffset(newOffset: Int): ApproximatePercentile =
+ copy(inputAggBufferOffset = newOffset)
+
+ override def children: Seq[Expression] = Seq(child, percentageExpression, accuracyExpression)
+
+ // Returns null for empty inputs
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = {
+ if (returnPercentileArray) ArrayType(DoubleType) else DoubleType
+ }
+
+ override def prettyName: String = "percentile_approx"
+
+ override def serialize(obj: PercentileDigest): Array[Byte] = {
+ ApproximatePercentile.serializer.serialize(obj)
+ }
+
+ override def deserialize(bytes: Array[Byte]): PercentileDigest = {
+ ApproximatePercentile.serializer.deserialize(bytes)
+ }
+}
+
+object ApproximatePercentile {
+
+ // Default accuracy of percentile approximation. A larger value means better accuracy.
+ // The default relative error can be computed as defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY.
+ val DEFAULT_PERCENTILE_ACCURACY: Int = 10000
+
+ /**
+ * PercentileDigest is a probabilistic data structure used for approximating percentiles
+ * with limited memory. PercentileDigest is backed by [[QuantileSummaries]].
+ *
+ * @param summaries underlying probabilistic data structure [[QuantileSummaries]].
+ * @param isCompressed An internal flag indicating whether the underlying
+ *                     [[QuantileSummaries]] is compressed.
+ */
+ class PercentileDigest(
+ private var summaries: QuantileSummaries,
+ private var isCompressed: Boolean) {
+
+ // Trigger compression if the QuantileSummaries' buffer length exceeds
+ // compressThresHoldBufferLength. The buffer length can be obtained via
+ // quantileSummaries.sampled.length.
+ private[this] final val compressThresHoldBufferLength: Int = {
+ // Max buffer length after compression.
+ val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2
+ // A safe upper bound for buffer length before compression
+ maxBufferLengthAfterCompression * 2
+ }
+
+ def this(relativeError: Double) = {
+ this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true)
+ }
+
+ /** Returns the compressed [[QuantileSummaries]]. */
+ def quantileSummaries: QuantileSummaries = {
+ if (!isCompressed) compress()
+ summaries
+ }
+
+ /** Inserts an observation value into the PercentileDigest data structure. */
+ def add(value: Double): Unit = {
+ summaries = summaries.insert(value)
+ // The result of QuantileSummaries.insert is un-compressed
+ isCompressed = false
+
+ // Currently, QuantileSummaries ignores the construction parameter compressThreshold,
+ // which may cause QuantileSummaries to occupy unbounded memory. We work around this at
+ // the caller side to make sure QuantileSummaries doesn't occupy unbounded memory.
+ // TODO: Figure out why QuantileSummaries ignores the construction parameter compressThreshold.
+ if (summaries.sampled.length >= compressThresHoldBufferLength) compress()
+ }
+
+ /** Merges another PercentileDigest into this one, in place. */
+ def merge(other: PercentileDigest): Unit = {
+ if (!isCompressed) compress()
+ summaries = summaries.merge(other.quantileSummaries)
+ }
+
+ /**
+ * Returns the approximate percentiles of all observation values at the given percentages.
+ * A percentile is a watermark value below which a given percentage of observation values fall.
+ * For example, the following code returns the 25th, median, and 75th percentiles of
+ * all observation values:
+ *
+ * {{{
+ * val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
+ * }}}
+ */
+ def getPercentiles(percentages: Array[Double]): Array[Double] = {
+ if (!isCompressed) compress()
+ if (summaries.count == 0 || percentages.length == 0) {
+ Array.empty[Double]
+ } else {
+ val result = new Array[Double](percentages.length)
+ var i = 0
+ while (i < percentages.length) {
+ result(i) = summaries.query(percentages(i))
+ i += 1
+ }
+ result
+ }
+ }
+
+ private final def compress(): Unit = {
+ summaries = summaries.compress()
+ isCompressed = true
+ }
+ }
+
+ /**
+ * Serializer for class [[PercentileDigest]]
+ *
+ * This class is thread safe.
+ */
+ class PercentileDigestSerializer {
+
+ private final def length(summaries: QuantileSummaries): Int = {
+ // summaries.compressThreshold, summaries.relativeError, summaries.count
+ Ints.BYTES + Doubles.BYTES + Longs.BYTES +
+ // length of summaries.sampled
+ Ints.BYTES +
+ // summaries.sampled: Array[Stats(value: Double, g: Int, delta: Int)]
+ summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES)
+ }
+
+ final def serialize(obj: PercentileDigest): Array[Byte] = {
+ val summary = obj.quantileSummaries
+ val buffer = ByteBuffer.wrap(new Array(length(summary)))
+ buffer.putInt(summary.compressThreshold)
+ buffer.putDouble(summary.relativeError)
+ buffer.putLong(summary.count)
+ buffer.putInt(summary.sampled.length)
+
+ var i = 0
+ while (i < summary.sampled.length) {
+ val stat = summary.sampled(i)
+ buffer.putDouble(stat.value)
+ buffer.putInt(stat.g)
+ buffer.putInt(stat.delta)
+ i += 1
+ }
+ buffer.array()
+ }
+
+ final def deserialize(bytes: Array[Byte]): PercentileDigest = {
+ val buffer = ByteBuffer.wrap(bytes)
+ val compressThreshold = buffer.getInt()
+ val relativeError = buffer.getDouble()
+ val count = buffer.getLong()
+ val sampledLength = buffer.getInt()
+ val sampled = new Array[Stats](sampledLength)
+
+ var i = 0
+ while (i < sampledLength) {
+ val value = buffer.getDouble()
+ val g = buffer.getInt()
+ val delta = buffer.getInt()
+ sampled(i) = Stats(value, g, delta)
+ i += 1
+ }
+ val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count)
+ new PercentileDigest(summary, isCompressed = true)
+ }
+ }
+
+ val serializer: PercentileDigestSerializer = new PercentileDigestSerializer
+}
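
To illustrate how the pieces above fit together, here is a minimal sketch (not part of the patch) that drives `PercentileDigest` and the serializer directly, mirroring the `createAggregationBuffer`/`update`/`eval`/`serialize` life cycle of the aggregate; it assumes the calling code can access these catalyst-internal classes:

```
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest

// createAggregationBuffer(): the relative error is 1.0 / accuracy.
val digest = new PercentileDigest(1.0 / ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)

// update(): one add() per non-null input value.
(1 to 1000).foreach(i => digest.add(i.toDouble))

// eval(): query several percentages in one pass; for 1..1000 the median is near 500.
val Array(p25, median, p75) = digest.getPercentiles(Array(0.25, 0.5, 0.75))

// serialize()/deserialize(): the round trip used to ship partial aggregates.
val bytes    = ApproximatePercentile.serializer.serialize(digest)
val restored = ApproximatePercentile.serializer.deserialize(bytes)
assert(restored.getPercentiles(Array(0.5)).head == median)
```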