author     Sean Zhong <seanzhong@databricks.com>   2016-09-01 16:31:13 +0800
committer  Wenchen Fan <wenchen@databricks.com>    2016-09-01 16:31:13 +0800
commit     a18c169fd050e71fdb07b153ae0fa5c410d8de27 (patch)
tree       be693af0f087329fd457ad0c238ffb2a3a9f8bab /sql/catalyst/src
parent     536fa911c181958d84f14156f7d57ef5fd68df48 (diff)
[SPARK-16283][SQL] Implements percentile_approx aggregation function which supports partial aggregation.
## What changes were proposed in this pull request?

This PR implements the aggregation function `percentile_approx`. Function `percentile_approx` returns the approximate percentile(s) of a column at the given percentage(s). A percentile is a watermark value below which a given percentage of the column values fall. For example, the percentile of column `col` at percentage 50% is the median value of column `col`.

### Syntax:
```
# Returns the percentile at a given percentage value. The approximation error can be reduced by increasing the accuracy parameter, at the cost of memory.
percentile_approx(col, percentage [, accuracy])

# Returns an array of percentile values at the given array of percentage values
percentile_approx(col, array(percentage1 [, percentage2]...) [, accuracy])
```

### Features:
1. This function supports partial aggregation.
2. The memory consumption is bounded. The larger the `accuracy` parameter we choose, the smaller the error we get. The default accuracy value is 10000, to match the Hive default setting. Choose a smaller value for a smaller memory footprint.
3. This function supports window function aggregation.

### Example usages:
```
## Returns the 25th percentile value, with default accuracy
SELECT percentile_approx(col, 0.25) FROM table

## Returns an array of percentile values (25th, 50th, 75th), with default accuracy
SELECT percentile_approx(col, array(0.25, 0.5, 0.75)) FROM table

## Returns the 25th percentile value with custom accuracy 100; a larger accuracy parameter yields a smaller approximation error
SELECT percentile_approx(col, 0.25, 100) FROM table

## Returns the 25th and 50th percentile values, with custom accuracy 100
SELECT percentile_approx(col, array(0.25, 0.5), 100) FROM table
```

### NOTE:
1. The `percentile_approx` implementation is different from Hive's, so the result returned for the same query may differ slightly from Hive's. This implementation uses `QuantileSummaries` as the underlying probabilistic data structure, and mainly follows the paper "Space-efficient Online Computation of Quantile Summaries" by Greenwald, Michael and Khanna, Sanjeev (http://dx.doi.org/10.1145/375663.375670).
2. The current implementation of `QuantileSummaries` doesn't support automatic compression. This PR adds a rule to trigger compression automatically on the caller side, but it may not be optimal.

## How was this patch tested?

Unit tests and SQL query tests.

## Acknowledgement

1. This PR's work is based on lw-lin's PR https://github.com/apache/spark/pull/14298, with improvements such as supporting partial aggregation and fixing an out-of-memory issue.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14868 from clockfly/appro_percentile_try_2.
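To make the accuracy/error trade-off described above concrete, here is a small worked example; the row count and accuracy value are illustrative assumptions, not numbers from the patch:

```
// Sketch only. The description above says 1.0/accuracy is the relative error,
// so for a hypothetical 10,000-row column queried with
// percentile_approx(col, 0.25, 100):
val accuracy = 100
val relativeError = 1.0 / accuracy                   // 0.01
val rowCount = 10000L
val maxRankError = (relativeError * rowCount).toLong
// maxRankError == 100: the returned value's rank lies within 100 positions
// of the exact 25th-percentile rank.
```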
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala                      |   1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala    | 321
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala | 339
3 files changed, 661 insertions(+), 0 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 35fd800df4..b05f4f61f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -250,6 +250,7 @@ object FunctionRegistry {
expression[Average]("mean"),
expression[Min]("min"),
expression[Skewness]("skewness"),
+ expression[ApproximatePercentile]("percentile_approx"),
expression[StddevSamp]("std"),
expression[StddevSamp]("stddev"),
expression[StddevPop]("stddev_pop"),
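The single registration line above is what exposes the new expression to SQL under the name `percentile_approx`. As a rough illustration of what `expression[T](name)` provides (a simplified sketch, not Spark's actual `FunctionRegistry` internals), the registered builder picks a constructor of `T` whose arity matches the number of argument expressions:

```
import scala.reflect.ClassTag

// Simplified sketch: resolve a constructor of T by arity and invoke it with
// the parsed argument expressions. Spark's real registry does more validation.
def buildByArity[T <: AnyRef](args: Seq[AnyRef])(implicit tag: ClassTag[T]): T = {
  val ctor = tag.runtimeClass.getConstructors
    .find(_.getParameterCount == args.length)
    .getOrElse(throw new IllegalArgumentException(
      s"Invalid number of arguments (${args.length}) for ${tag.runtimeClass.getSimpleName}"))
  ctor.newInstance(args: _*).asInstanceOf[T]
}
```

This is why `ApproximatePercentile` below defines both a two-argument and a three-argument auxiliary constructor: `percentile_approx(col, 0.5)` and `percentile_approx(col, 0.5, 100)` each resolve to the matching arity.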
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
new file mode 100644
index 0000000000..f91ff87fc1
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala
@@ -0,0 +1,321 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.nio.ByteBuffer
+
+import com.google.common.primitives.{Doubles, Ints, Longs}
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest
+import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData}
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
+import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats}
+import org.apache.spark.sql.types._
+
+/**
+ * The ApproximatePercentile function returns the approximate percentile(s) of a column at the given
+ * percentage(s). A percentile is a watermark value below which a given percentage of the column
+ * values fall. For example, the percentile of column `col` at percentage 50% is the median of
+ * column `col`.
+ *
+ * This function supports partial aggregation.
+ *
+ * @param child child expression that can produce column value with `child.eval(inputRow)`
+ * @param percentageExpression Expression that represents a single percentage value or
+ * an array of percentage values. Each percentage value must be between
+ * 0.0 and 1.0.
+ * @param accuracyExpression Integer literal expression of approximation accuracy. A higher
+ *                           value yields better accuracy; the default value is
+ *                           DEFAULT_PERCENTILE_ACCURACY.
+ */
+@ExpressionDescription(
+ usage =
+ """
+ _FUNC_(col, percentage [, accuracy]) - Returns the approximate percentile value of numeric
+ column `col` at the given percentage. The value of percentage must be between 0.0
+ and 1.0. The `accuracy` parameter (default: 10000) is a positive integer literal which
+      controls approximation accuracy at the cost of memory. A higher value of `accuracy`
+      yields better accuracy; `1.0/accuracy` is the relative error of the approximation.
+
+ _FUNC_(col, array(percentage1 [, percentage2]...) [, accuracy]) - Returns the approximate
+ percentile array of column `col` at the given percentage array. Each value of the
+ percentage array must be between 0.0 and 1.0. The `accuracy` parameter (default: 10000) is
+ a positive integer literal which controls approximation accuracy at the cost of memory.
+      A higher value of `accuracy` yields better accuracy; `1.0/accuracy` is the relative
+      error of the approximation.
+ """)
+case class ApproximatePercentile(
+ child: Expression,
+ percentageExpression: Expression,
+ accuracyExpression: Expression,
+ override val mutableAggBufferOffset: Int,
+ override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] {
+
+ def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = {
+ this(child, percentageExpression, accuracyExpression, 0, 0)
+ }
+
+ def this(child: Expression, percentageExpression: Expression) = {
+ this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY))
+ }
+
+ // Mark as lazy so that accuracyExpression is not evaluated during tree transformation.
+ private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int]
+
+ override def inputTypes: Seq[AbstractDataType] = {
+ Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType)
+ }
+
+ // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
+ private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = {
+ (percentageExpression.dataType, percentageExpression.eval()) match {
+ // Rule ImplicitTypeCasts can cast other numeric types to double
+ case (_, num: Double) => (false, Array(num))
+ case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
+ val numericArray = arrayData.toObjectArray(baseType)
+ (true, numericArray.map { x =>
+ baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
+ })
+ case other =>
+ throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage")
+ }
+ }
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!percentageExpression.foldable || !accuracyExpression.foldable) {
+      TypeCheckFailure("The accuracy or percentage provided must be a constant literal")
+ } else if (accuracy <= 0) {
+ TypeCheckFailure(
+ s"The accuracy provided must be a positive integer literal (current value = $accuracy)")
+ } else if (percentages.exists(percentage => percentage < 0.0D || percentage > 1.0D)) {
+ TypeCheckFailure(
+ s"All percentage values must be between 0.0 and 1.0 " +
+ s"(current = ${percentages.mkString(", ")})")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
+ override def createAggregationBuffer(): PercentileDigest = {
+ val relativeError = 1.0D / accuracy
+ new PercentileDigest(relativeError)
+ }
+
+ override def update(buffer: PercentileDigest, inputRow: InternalRow): Unit = {
+ val value = child.eval(inputRow)
+    // Ignore null input values, for example: percentile_approx(null)
+ if (value != null) {
+ buffer.add(value.asInstanceOf[Double])
+ }
+ }
+
+ override def merge(buffer: PercentileDigest, other: PercentileDigest): Unit = {
+ buffer.merge(other)
+ }
+
+ override def eval(buffer: PercentileDigest): Any = {
+ val result = buffer.getPercentiles(percentages)
+ if (result.length == 0) {
+ null
+ } else if (returnPercentileArray) {
+ new GenericArrayData(result)
+ } else {
+ result(0)
+ }
+ }
+
+ override def withNewMutableAggBufferOffset(newOffset: Int): ApproximatePercentile =
+ copy(mutableAggBufferOffset = newOffset)
+
+ override def withNewInputAggBufferOffset(newOffset: Int): ApproximatePercentile =
+ copy(inputAggBufferOffset = newOffset)
+
+ override def children: Seq[Expression] = Seq(child, percentageExpression, accuracyExpression)
+
+ // Returns null for empty inputs
+ override def nullable: Boolean = true
+
+ override def dataType: DataType = {
+ if (returnPercentileArray) ArrayType(DoubleType) else DoubleType
+ }
+
+ override def prettyName: String = "percentile_approx"
+
+ override def serialize(obj: PercentileDigest): Array[Byte] = {
+ ApproximatePercentile.serializer.serialize(obj)
+ }
+
+ override def deserialize(bytes: Array[Byte]): PercentileDigest = {
+ ApproximatePercentile.serializer.deserialize(bytes)
+ }
+}
+
+object ApproximatePercentile {
+
+  // Default accuracy of the percentile approximation. A larger value means better accuracy.
+  // The corresponding default relative error is 1.0 / DEFAULT_PERCENTILE_ACCURACY.
+ val DEFAULT_PERCENTILE_ACCURACY: Int = 10000
+
+ /**
+ * PercentileDigest is a probabilistic data structure used for approximating percentiles
+ * with limited memory. PercentileDigest is backed by [[QuantileSummaries]].
+ *
+ * @param summaries underlying probabilistic data structure [[QuantileSummaries]].
+   * @param isCompressed An internal flag indicating whether the underlying
+   *                     [[QuantileSummaries]] is compressed.
+ */
+ class PercentileDigest(
+ private var summaries: QuantileSummaries,
+ private var isCompressed: Boolean) {
+
+    // Trigger compression if the QuantileSummaries buffer length exceeds
+    // compressThresHoldBufferLength. The buffer length can be obtained via
+    // quantileSummaries.sampled.length
+ private[this] final val compressThresHoldBufferLength: Int = {
+ // Max buffer length after compression.
+ val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2
+ // A safe upper bound for buffer length before compression
+ maxBufferLengthAfterCompression * 2
+ }
+
+ def this(relativeError: Double) = {
+ this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true)
+ }
+
+ /** Returns compressed object of [[QuantileSummaries]] */
+ def quantileSummaries: QuantileSummaries = {
+ if (!isCompressed) compress()
+ summaries
+ }
+
+ /** Insert an observation value into the PercentileDigest data structure. */
+ def add(value: Double): Unit = {
+ summaries = summaries.insert(value)
+ // The result of QuantileSummaries.insert is un-compressed
+ isCompressed = false
+
+    // Currently, QuantileSummaries ignores its construction parameter compressThreshold,
+    // which may cause it to occupy unbounded memory. We work around that here to make
+    // sure QuantileSummaries doesn't occupy unbounded memory.
+    // TODO: Figure out why QuantileSummaries ignores the compressThreshold parameter
+ if (summaries.sampled.length >= compressThresHoldBufferLength) compress()
+ }
+
+ /** In-place merges in another PercentileDigest. */
+ def merge(other: PercentileDigest): Unit = {
+ if (!isCompressed) compress()
+ summaries = summaries.merge(other.quantileSummaries)
+ }
+
+ /**
+ * Returns the approximate percentiles of all observation values at the given percentages.
+ * A percentile is a watermark value below which a given percentage of observation values fall.
+ * For example, the following code returns the 25th, median, and 75th percentiles of
+ * all observation values:
+ *
+ * {{{
+ * val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75))
+ * }}}
+ */
+ def getPercentiles(percentages: Array[Double]): Array[Double] = {
+ if (!isCompressed) compress()
+ if (summaries.count == 0 || percentages.length == 0) {
+ Array.empty[Double]
+ } else {
+ val result = new Array[Double](percentages.length)
+ var i = 0
+ while (i < percentages.length) {
+ result(i) = summaries.query(percentages(i))
+ i += 1
+ }
+ result
+ }
+ }
+
+ private final def compress(): Unit = {
+ summaries = summaries.compress()
+ isCompressed = true
+ }
+ }
+
+ /**
+ * Serializer for class [[PercentileDigest]]
+ *
+ * This class is thread safe.
+ */
+ class PercentileDigestSerializer {
+
+ private final def length(summaries: QuantileSummaries): Int = {
+      // summaries.compressThreshold, summaries.relativeError, summaries.count
+      Ints.BYTES + Doubles.BYTES + Longs.BYTES +
+      // length of summaries.sampled
+      Ints.BYTES +
+      // summaries.sampled, Array[Stats(value: Double, g: Int, delta: Int)]
+      summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES)
+ }
+
+ final def serialize(obj: PercentileDigest): Array[Byte] = {
+ val summary = obj.quantileSummaries
+ val buffer = ByteBuffer.wrap(new Array(length(summary)))
+ buffer.putInt(summary.compressThreshold)
+ buffer.putDouble(summary.relativeError)
+ buffer.putLong(summary.count)
+ buffer.putInt(summary.sampled.length)
+
+ var i = 0
+ while (i < summary.sampled.length) {
+ val stat = summary.sampled(i)
+ buffer.putDouble(stat.value)
+ buffer.putInt(stat.g)
+ buffer.putInt(stat.delta)
+ i += 1
+ }
+ buffer.array()
+ }
+
+ final def deserialize(bytes: Array[Byte]): PercentileDigest = {
+ val buffer = ByteBuffer.wrap(bytes)
+ val compressThreshold = buffer.getInt()
+ val relativeError = buffer.getDouble()
+ val count = buffer.getLong()
+ val sampledLength = buffer.getInt()
+ val sampled = new Array[Stats](sampledLength)
+
+ var i = 0
+ while (i < sampledLength) {
+ val value = buffer.getDouble()
+ val g = buffer.getInt()
+ val delta = buffer.getInt()
+ sampled(i) = Stats(value, g, delta)
+ i += 1
+ }
+ val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count)
+ new PercentileDigest(summary, isCompressed = true)
+ }
+ }
+
+ val serializer: PercentileDigestSerializer = new PercentileDigestSerializer
+}
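Before the test file, here is a minimal usage sketch of the `PercentileDigest` class added above, assuming spark-catalyst is on the classpath; the input size is an illustrative choice:

```
import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.PercentileDigest

val digest = new PercentileDigest(relativeError = 1.0 / 10000)
// add() compresses the underlying QuantileSummaries once its buffer grows
// past compressThresHoldBufferLength, keeping memory bounded.
(1 to 100000).foreach(v => digest.add(v.toDouble))
val Array(p25, median, p75) = digest.getPercentiles(Array(0.25, 0.5, 0.75))
// Each estimate lies within relativeError * count = 10 rank positions of the
// exact percentile over these 100,000 values.
```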
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
new file mode 100644
index 0000000000..61298a1b72
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala
@@ -0,0 +1,339 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute}
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericMutableRow, Literal}
+import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest, PercentileDigestSerializer}
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.catalyst.util.QuantileSummaries
+import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats
+import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType}
+import org.apache.spark.util.SizeEstimator
+
+class ApproximatePercentileSuite extends SparkFunSuite {
+
+ private val random = new java.util.Random()
+
+ private val data = (0 until 10000).map { _ =>
+ random.nextInt(10000)
+ }
+
+ test("serialize and de-serialize") {
+ val serializer = new PercentileDigestSerializer
+
+ // Check empty serialize and de-serialize
+ val emptyBuffer = new PercentileDigest(relativeError = 0.01)
+ assert(compareEquals(emptyBuffer, serializer.deserialize(serializer.serialize(emptyBuffer))))
+
+ val buffer = new PercentileDigest(relativeError = 0.01)
+ data.foreach { value =>
+ buffer.add(value)
+ }
+ assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer))))
+
+ val agg = new ApproximatePercentile(BoundReference(0, DoubleType, true), Literal(0.5))
+ assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
+ }
+
+ test("class PercentileDigest, basic operations") {
+ val valueCount = 10000
+ val percentages = Array(0.25, 0.5, 0.75)
+ Seq(0.0001, 0.001, 0.01, 0.1).foreach { relativeError =>
+ val buffer = new PercentileDigest(relativeError)
+ (1 to valueCount).grouped(10).foreach { group =>
+ val partialBuffer = new PercentileDigest(relativeError)
+ group.foreach(x => partialBuffer.add(x))
+ buffer.merge(partialBuffer)
+ }
+ val expectedPercentiles = percentages.map(_ * valueCount)
+ val approxPercentiles = buffer.getPercentiles(Array(0.25, 0.5, 0.75))
+ expectedPercentiles.zip(approxPercentiles).foreach { pair =>
+ val (expected, estimate) = pair
+ assert((estimate - expected) / valueCount <= relativeError)
+ }
+ }
+ }
+
+ test("class PercentileDigest, makes sure the memory foot print is bounded") {
+ val relativeError = 0.01
+ val memoryFootPrintUpperBound = {
+ val headBufferSize =
+ SizeEstimator.estimate(new Array[Double](QuantileSummaries.defaultHeadSize))
+ val bufferSize = SizeEstimator.estimate(new Stats(0, 0, 0)) * (1 / relativeError) * 2
+ // A safe upper bound
+ (headBufferSize + bufferSize) * 2
+ }
+
+    Seq(100, 1000, 10000, 100000, 1000000, 10000000).foreach { count =>
+      val buffer = new PercentileDigest(relativeError)
+      // Worst case: the input data is sorted
+ (0 until count).foreach(buffer.add(_))
+ assert(SizeEstimator.estimate(buffer) < memoryFootPrintUpperBound)
+ }
+ }
+
+ test("class ApproximatePercentile, high level interface, update, merge, eval...") {
+ val count = 10000
+    val data = (1 to count).toSeq
+ val percentages = Array(0.25D, 0.5D, 0.75D)
+ val accuracy = 10000
+ val expectedPercentiles = percentages.map(count * _)
+ val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
+ val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_)))
+ val accuracyExpression = Literal(10000)
+ val agg = new ApproximatePercentile(childExpression, percentageExpression, accuracyExpression)
+
+ assert(agg.nullable)
+ val group1 = (0 until data.length / 2)
+ val group1Buffer = agg.createAggregationBuffer()
+ group1.foreach { index =>
+ val input = InternalRow(data(index))
+ agg.update(group1Buffer, input)
+ }
+
+ val group2 = (data.length / 2 until data.length)
+ val group2Buffer = agg.createAggregationBuffer()
+ group2.foreach { index =>
+ val input = InternalRow(data(index))
+ agg.update(group2Buffer, input)
+ }
+
+ val mergeBuffer = agg.createAggregationBuffer()
+ agg.merge(mergeBuffer, group1Buffer)
+ agg.merge(mergeBuffer, group2Buffer)
+
+ agg.eval(mergeBuffer) match {
+ case arrayData: ArrayData =>
+ val error = count / accuracy
+ val percentiles = arrayData.toDoubleArray()
+ assert(percentiles.zip(expectedPercentiles)
+ .forall(pair => Math.abs(pair._1 - pair._2) < error))
+ }
+ }
+
+ test("class ApproximatePercentile, low level interface, update, merge, eval...") {
+ val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+ val inputAggregationBufferOffset = 1
+ val mutableAggregationBufferOffset = 2
+ val percentage = 0.5D
+
+ // Phase one, partial mode aggregation
+ val agg = new ApproximatePercentile(childExpression, Literal(percentage))
+ .withNewInputAggBufferOffset(inputAggregationBufferOffset)
+ .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
+
+ val mutableAggBuffer = new GenericMutableRow(new Array[Any](mutableAggregationBufferOffset + 1))
+ agg.initialize(mutableAggBuffer)
+ val dataCount = 10
+ (1 to dataCount).foreach { data =>
+ agg.update(mutableAggBuffer, InternalRow(data))
+ }
+ agg.serializeAggregateBufferInPlace(mutableAggBuffer)
+
+ // Serialize the aggregation buffer
+ val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
+ val inputAggBuffer = new GenericMutableRow(Array[Any](null, serialized))
+
+ // Phase 2: final mode aggregation
+ // Re-initialize the aggregation buffer
+ agg.initialize(mutableAggBuffer)
+ agg.merge(mutableAggBuffer, inputAggBuffer)
+ val expectedPercentile = dataCount * percentage
+ assert(Math.abs(agg.eval(mutableAggBuffer).asInstanceOf[Double] - expectedPercentile) < 0.1)
+ }
+
+ test("class ApproximatePercentile, sql string") {
+ val defaultAccuracy = ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY
+ // sql, single percentile
+ assertEqual(
+ s"percentile_approx(`a`, 0.5D, $defaultAccuracy)",
+ new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)).sql: String)
+
+ // sql, array of percentile
+ assertEqual(
+ s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)",
+ new ApproximatePercentile(
+ "a".attr,
+ percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_)))
+ ).sql: String)
+
+ // sql(isDistinct = false), single percentile
+ assertEqual(
+ s"percentile_approx(`a`, 0.5D, $defaultAccuracy)",
+ new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D))
+ .sql(isDistinct = false))
+
+ // sql(isDistinct = false), array of percentile
+ assertEqual(
+ s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)",
+ new ApproximatePercentile(
+ "a".attr,
+ percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_)))
+ ).sql(isDistinct = false))
+
+ // sql(isDistinct = true), single percentile
+ assertEqual(
+ s"percentile_approx(DISTINCT `a`, 0.5D, $defaultAccuracy)",
+ new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D))
+ .sql(isDistinct = true))
+
+ // sql(isDistinct = true), array of percentile
+ assertEqual(
+ s"percentile_approx(DISTINCT `a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)",
+ new ApproximatePercentile(
+ "a".attr,
+ percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_)))
+ ).sql(isDistinct = true))
+ }
+
+ test("class ApproximatePercentile, fails analysis if percentage or accuracy is not a constant") {
+ val attribute = AttributeReference("a", DoubleType)()
+ val wrongAccuracy = new ApproximatePercentile(
+ attribute,
+ percentageExpression = Literal(0.5D),
+ accuracyExpression = AttributeReference("b", IntegerType)())
+
+ assertEqual(
+ wrongAccuracy.checkInputDataTypes(),
+ TypeCheckFailure("The accuracy or percentage provided must be a constant literal")
+ )
+
+ val wrongPercentage = new ApproximatePercentile(
+ attribute,
+ percentageExpression = attribute,
+ accuracyExpression = Literal(10000))
+
+ assertEqual(
+ wrongPercentage.checkInputDataTypes(),
+ TypeCheckFailure("The accuracy or percentage provided must be a constant literal")
+ )
+ }
+
+ test("class ApproximatePercentile, fails analysis if parameters are invalid") {
+ val wrongAccuracy = new ApproximatePercentile(
+ AttributeReference("a", DoubleType)(),
+ percentageExpression = Literal(0.5D),
+ accuracyExpression = Literal(-1))
+ assertEqual(
+ wrongAccuracy.checkInputDataTypes(),
+ TypeCheckFailure(
+ "The accuracy provided must be a positive integer literal (current value = -1)"))
+
+    val correctPercentageExpressions = Seq(
+ Literal(0D),
+ Literal(1D),
+ Literal(0.5D),
+ CreateArray(Seq(0D, 1D, 0.5D).map(Literal(_)))
+ )
+    correctPercentageExpressions.foreach { percentageExpression =>
+ val correctPercentage = new ApproximatePercentile(
+ AttributeReference("a", DoubleType)(),
+ percentageExpression = percentageExpression,
+ accuracyExpression = Literal(100))
+
+ // no exception should be thrown
+ correctPercentage.checkInputDataTypes()
+ }
+
+ val wrongPercentageExpressions = Seq(
+ Literal(1.1D),
+ Literal(-0.5D),
+ CreateArray(Seq(0D, 0.5D, 1.1D).map(Literal(_)))
+ )
+
+ wrongPercentageExpressions.foreach { percentageExpression =>
+ val wrongPercentage = new ApproximatePercentile(
+ AttributeReference("a", DoubleType)(),
+ percentageExpression = percentageExpression,
+ accuracyExpression = Literal(100))
+
+      val result = wrongPercentage.checkInputDataTypes()
+      assert(
+        result match {
+          case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true
+          case _ => false
+        })
+ }
+ }
+
+ test("class ApproximatePercentile, automatically add type casting for parameters") {
+ val testRelation = LocalRelation('a.int)
+ val analyzer = SimpleAnalyzer
+
+    // Compatible accuracy types: long, decimal, and double
+ val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D))
+    // Compatible percentage types: float and decimal, plus arrays mixing them with double
+ val percentageExpressions = Seq(Literal(0.3f), DecimalLiteral(0.5),
+ CreateArray(Seq(Literal(0.3f), Literal(0.5D), DecimalLiteral(0.7))))
+
+ accuracyExpressions.foreach { accuracyExpression =>
+ percentageExpressions.foreach { percentageExpression =>
+ val agg = new ApproximatePercentile(
+ UnresolvedAttribute("a"),
+ percentageExpression,
+ accuracyExpression)
+ val analyzed = testRelation.select(agg).analyze.expressions.head
+ analyzed match {
+ case Alias(agg: ApproximatePercentile, _) =>
+ assert(agg.resolved)
+ assert(agg.child.dataType == DoubleType)
+ assert(agg.percentageExpression.dataType == DoubleType ||
+ agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false))
+ assert(agg.accuracyExpression.dataType == IntegerType)
+ case _ => fail()
+ }
+ }
+ }
+ }
+
+ test("class ApproximatePercentile, null handling") {
+ val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+ val agg = new ApproximatePercentile(childExpression, Literal(0.5D))
+ val buffer = new GenericMutableRow(new Array[Any](1))
+ agg.initialize(buffer)
+ // Empty aggregation buffer
+ assert(agg.eval(buffer) == null)
+    // A null input value is ignored, so the buffer stays empty
+ agg.update(buffer, InternalRow(null))
+ assert(agg.eval(buffer) == null)
+
+    // Add a non-null value
+ agg.update(buffer, InternalRow(0))
+ assert(agg.eval(buffer) != null)
+ }
+
+ private def compareEquals(left: PercentileDigest, right: PercentileDigest): Boolean = {
+ val leftSummary = left.quantileSummaries
+ val rightSummary = right.quantileSummaries
+ leftSummary.compressThreshold == rightSummary.compressThreshold &&
+ leftSummary.relativeError == rightSummary.relativeError &&
+ leftSummary.count == rightSummary.count &&
+ leftSummary.sampled.sameElements(rightSummary.sampled)
+ }
+
+ private def assertEqual[T](left: T, right: T): Unit = {
+ assert(left == right)
+ }
+}
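One closing observation on the serialization format exercised by the round-trip test above: the fixed layout written by `PercentileDigestSerializer.serialize` implies a simple size bound. The following is a sketch for intuition, not code from the patch:

```
// Layout per serialize() above: Int compressThreshold, Double relativeError,
// Long count, Int sampledLength, then sampledLength stats of
// (Double value, Int g, Int delta).
def serializedBytes(sampledLength: Int): Int =
  4 + 8 + 8 + 4 + sampledLength * (8 + 4 + 4)

// After compression the buffer holds at most about (1 / relativeError) * 2
// samples (maxBufferLengthAfterCompression in the main file), so with the
// default accuracy of 10000 the serialized digest stays around
// serializedBytes(20000) bytes, roughly 320 KB.
```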