diff options
author | Sean Zhong <seanzhong@databricks.com> | 2016-09-01 16:31:13 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-09-01 16:31:13 +0800 |
commit | a18c169fd050e71fdb07b153ae0fa5c410d8de27 (patch) | |
tree | be693af0f087329fd457ad0c238ffb2a3a9f8bab /sql/catalyst/src/main | |
parent | 536fa911c181958d84f14156f7d57ef5fd68df48 (diff) | |
download | spark-a18c169fd050e71fdb07b153ae0fa5c410d8de27.tar.gz spark-a18c169fd050e71fdb07b153ae0fa5c410d8de27.tar.bz2 spark-a18c169fd050e71fdb07b153ae0fa5c410d8de27.zip |
[SPARK-16283][SQL] Implements percentile_approx aggregation function which supports partial aggregation.
## What changes were proposed in this pull request?
This PR implements aggregation function `percentile_approx`. Function `percentile_approx` returns the approximate percentile(s) of a column at the given percentage(s). A percentile is a watermark value below which a given percentage of the column values fall. For example, the percentile of column `col` at percentage 50% is the median value of column `col`.
### Syntax:
```
# Returns percentile at a given percentage value. The approximation error can be reduced by increasing parameter accuracy, at the cost of memory.
percentile_approx(col, percentage [, accuracy])
# Returns percentile value array at given percentage value array
percentile_approx(col, array(percentage1 [, percentage2]...) [, accuracy])
```
### Features:
1. This function supports partial aggregation.
2. The memory consumption is bounded. The larger `accuracy` parameter we choose, we smaller error we get. The default accuracy value is 10000, to match with Hive default setting. Choose a smaller value for smaller memory footprint.
3. This function supports window function aggregation.
### Example usages:
```
## Returns the 25th percentile value, with default accuracy
SELECT percentile_approx(col, 0.25) FROM table
## Returns an array of percentile value (25th, 50th, 75th), with default accuracy
SELECT percentile_approx(col, array(0.25, 0.5, 0.75)) FROM table
## Returns 25th percentile value, with custom accuracy value 100, larger accuracy parameter yields smaller approximation error
SELECT percentile_approx(col, 0.25, 100) FROM table
## Returns the 25th, and 50th percentile values, with custom accuracy value 100
SELECT percentile_approx(col, array(0.25, 0.5), 100) FROM table
```
### NOTE:
1. The `percentile_approx` implementation is different from Hive, so the result returned on same query maybe slightly different with Hive. This implementation uses `QuantileSummaries` as the underlying probabilistic data structure, and mainly follows paper `Space-efficient Online Computation of Quantile Summaries` by Greenwald, Michael and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670)`
2. The current implementation of `QuantileSummaries` doesn't support automatic compression. This PR has a rule to do compression automatically at the caller side, but it may not be optimal.
## How was this patch tested?
Unit test, and Sql query test.
## Acknowledgement
1. This PR's work in based on lw-lin's PR https://github.com/apache/spark/pull/14298, with improvements like supporting partial aggregation, fixing out of memory issue.
Author: Sean Zhong <seanzhong@databricks.com>
Closes #14868 from clockfly/appro_percentile_try_2.
Diffstat (limited to 'sql/catalyst/src/main')
2 files changed, 322 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 35fd800df4..b05f4f61f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -250,6 +250,7 @@ object FunctionRegistry { expression[Average]("mean"), expression[Min]("min"), expression[Skewness]("skewness"), + expression[ApproximatePercentile]("percentile_approx"), expression[StddevSamp]("std"), expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala new file mode 100644 index 0000000000..f91ff87fc1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.nio.ByteBuffer + +import com.google.common.primitives.{Doubles, Ints, Longs} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{InternalRow} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.QuantileSummaries +import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats} +import org.apache.spark.sql.types._ + +/** + * The ApproximatePercentile function returns the approximate percentile(s) of a column at the given + * percentage(s). A percentile is a watermark value below which a given percentage of the column + * values fall. For example, the percentile of column `col` at percentage 50% is the median of + * column `col`. + * + * This function supports partial aggregation. + * + * @param child child expression that can produce column value with `child.eval(inputRow)` + * @param percentageExpression Expression that represents a single percentage value or + * an array of percentage values. Each percentage value must be between + * 0.0 and 1.0. + * @param accuracyExpression Integer literal expression of approximation accuracy. Higher value + * yields better accuracy, the default value is + * DEFAULT_PERCENTILE_ACCURACY. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(col, percentage [, accuracy]) - Returns the approximate percentile value of numeric + column `col` at the given percentage. The value of percentage must be between 0.0 + and 1.0. The `accuracy` parameter (default: 10000) is a positive integer literal which + controls approximation accuracy at the cost of memory. Higher value of `accuracy` yields + better accuracy, `1.0/accuracy` is the relative error of the approximation. + + _FUNC_(col, array(percentage1 [, percentage2]...) [, accuracy]) - Returns the approximate + percentile array of column `col` at the given percentage array. Each value of the + percentage array must be between 0.0 and 1.0. The `accuracy` parameter (default: 10000) is + a positive integer literal which controls approximation accuracy at the cost of memory. + Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is the relative error of + the approximation. + """) +case class ApproximatePercentile( + child: Expression, + percentageExpression: Expression, + accuracyExpression: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] { + + def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = { + this(child, percentageExpression, accuracyExpression, 0, 0) + } + + def this(child: Expression, percentageExpression: Expression) = { + this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) + } + + // Mark as lazy so that accuracyExpression is not evaluated during tree transformation. + private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int] + + override def inputTypes: Seq[AbstractDataType] = { + Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType) + } + + // Mark as lazy so that percentageExpression is not evaluated during tree transformation. + private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = { + (percentageExpression.dataType, percentageExpression.eval()) match { + // Rule ImplicitTypeCasts can cast other numeric types to double + case (_, num: Double) => (false, Array(num)) + case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => + val numericArray = arrayData.toObjectArray(baseType) + (true, numericArray.map { x => + baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]) + }) + case other => + throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!percentageExpression.foldable || !accuracyExpression.foldable) { + TypeCheckFailure(s"The accuracy or percentage provided must be a constant literal") + } else if (accuracy <= 0) { + TypeCheckFailure( + s"The accuracy provided must be a positive integer literal (current value = $accuracy)") + } else if (percentages.exists(percentage => percentage < 0.0D || percentage > 1.0D)) { + TypeCheckFailure( + s"All percentage values must be between 0.0 and 1.0 " + + s"(current = ${percentages.mkString(", ")})") + } else { + TypeCheckSuccess + } + } + + override def createAggregationBuffer(): PercentileDigest = { + val relativeError = 1.0D / accuracy + new PercentileDigest(relativeError) + } + + override def update(buffer: PercentileDigest, inputRow: InternalRow): Unit = { + val value = child.eval(inputRow) + // Ignore empty rows, for example: percentile_approx(null) + if (value != null) { + buffer.add(value.asInstanceOf[Double]) + } + } + + override def merge(buffer: PercentileDigest, other: PercentileDigest): Unit = { + buffer.merge(other) + } + + override def eval(buffer: PercentileDigest): Any = { + val result = buffer.getPercentiles(percentages) + if (result.length == 0) { + null + } else if (returnPercentileArray) { + new GenericArrayData(result) + } else { + result(0) + } + } + + override def withNewMutableAggBufferOffset(newOffset: Int): ApproximatePercentile = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): ApproximatePercentile = + copy(inputAggBufferOffset = newOffset) + + override def children: Seq[Expression] = Seq(child, percentageExpression, accuracyExpression) + + // Returns null for empty inputs + override def nullable: Boolean = true + + override def dataType: DataType = { + if (returnPercentileArray) ArrayType(DoubleType) else DoubleType + } + + override def prettyName: String = "percentile_approx" + + override def serialize(obj: PercentileDigest): Array[Byte] = { + ApproximatePercentile.serializer.serialize(obj) + } + + override def deserialize(bytes: Array[Byte]): PercentileDigest = { + ApproximatePercentile.serializer.deserialize(bytes) + } +} + +object ApproximatePercentile { + + // Default accuracy of Percentile approximation. Larger value means better accuracy. + // The default relative error can be deduced by defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY + val DEFAULT_PERCENTILE_ACCURACY: Int = 10000 + + /** + * PercentileDigest is a probabilistic data structure used for approximating percentiles + * with limited memory. PercentileDigest is backed by [[QuantileSummaries]]. + * + * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. + * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the + * underlying quantileSummaries is compressed. + */ + class PercentileDigest( + private var summaries: QuantileSummaries, + private var isCompressed: Boolean) { + + // Trigger compression if the QuantileSummaries's buffer length exceeds + // compressThresHoldBufferLength. The buffer length can be get by + // quantileSummaries.sampled.length + private[this] final val compressThresHoldBufferLength: Int = { + // Max buffer length after compression. + val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 + // A safe upper bound for buffer length before compression + maxBufferLengthAfterCompression * 2 + } + + def this(relativeError: Double) = { + this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + } + + /** Returns compressed object of [[QuantileSummaries]] */ + def quantileSummaries: QuantileSummaries = { + if (!isCompressed) compress() + summaries + } + + /** Insert an observation value into the PercentileDigest data structure. */ + def add(value: Double): Unit = { + summaries = summaries.insert(value) + // The result of QuantileSummaries.insert is un-compressed + isCompressed = false + + // Currently, QuantileSummaries ignores the construction parameter compressThresHold, + // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here + // to make sure QuantileSummaries doesn't occupy infinite memory. + // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold + if (summaries.sampled.length >= compressThresHoldBufferLength) compress() + } + + /** In-place merges in another PercentileDigest. */ + def merge(other: PercentileDigest): Unit = { + if (!isCompressed) compress() + summaries = summaries.merge(other.quantileSummaries) + } + + /** + * Returns the approximate percentiles of all observation values at the given percentages. + * A percentile is a watermark value below which a given percentage of observation values fall. + * For example, the following code returns the 25th, median, and 75th percentiles of + * all observation values: + * + * {{{ + * val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75)) + * }}} + */ + def getPercentiles(percentages: Array[Double]): Array[Double] = { + if (!isCompressed) compress() + if (summaries.count == 0 || percentages.length == 0) { + Array.empty[Double] + } else { + val result = new Array[Double](percentages.length) + var i = 0 + while (i < percentages.length) { + result(i) = summaries.query(percentages(i)) + i += 1 + } + result + } + } + + private final def compress(): Unit = { + summaries = summaries.compress() + isCompressed = true + } + } + + /** + * Serializer for class [[PercentileDigest]] + * + * This class is thread safe. + */ + class PercentileDigestSerializer { + + private final def length(summaries: QuantileSummaries): Int = { + // summaries.compressThreshold, summary.relativeError, summary.count + Ints.BYTES + Doubles.BYTES + Longs.BYTES + + // length of summary.sampled + Ints.BYTES + + // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] + summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + } + + final def serialize(obj: PercentileDigest): Array[Byte] = { + val summary = obj.quantileSummaries + val buffer = ByteBuffer.wrap(new Array(length(summary))) + buffer.putInt(summary.compressThreshold) + buffer.putDouble(summary.relativeError) + buffer.putLong(summary.count) + buffer.putInt(summary.sampled.length) + + var i = 0 + while (i < summary.sampled.length) { + val stat = summary.sampled(i) + buffer.putDouble(stat.value) + buffer.putInt(stat.g) + buffer.putInt(stat.delta) + i += 1 + } + buffer.array() + } + + final def deserialize(bytes: Array[Byte]): PercentileDigest = { + val buffer = ByteBuffer.wrap(bytes) + val compressThreshold = buffer.getInt() + val relativeError = buffer.getDouble() + val count = buffer.getLong() + val sampledLength = buffer.getInt() + val sampled = new Array[Stats](sampledLength) + + var i = 0 + while (i < sampledLength) { + val value = buffer.getDouble() + val g = buffer.getInt() + val delta = buffer.getInt() + sampled(i) = Stats(value, g, delta) + i += 1 + } + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new PercentileDigest(summary, isCompressed = true) + } + } + + val serializer: PercentileDigestSerializer = new PercentileDigestSerializer +} |