From a18c169fd050e71fdb07b153ae0fa5c410d8de27 Mon Sep 17 00:00:00 2001 From: Sean Zhong Date: Thu, 1 Sep 2016 16:31:13 +0800 Subject: [SPARK-16283][SQL] Implements percentile_approx aggregation function which supports partial aggregation. ## What changes were proposed in this pull request? This PR implements aggregation function `percentile_approx`. Function `percentile_approx` returns the approximate percentile(s) of a column at the given percentage(s). A percentile is a watermark value below which a given percentage of the column values fall. For example, the percentile of column `col` at percentage 50% is the median value of column `col`. ### Syntax: ``` # Returns percentile at a given percentage value. The approximation error can be reduced by increasing parameter accuracy, at the cost of memory. percentile_approx(col, percentage [, accuracy]) # Returns percentile value array at given percentage value array percentile_approx(col, array(percentage1 [, percentage2]...) [, accuracy]) ``` ### Features: 1. This function supports partial aggregation. 2. The memory consumption is bounded. The larger `accuracy` parameter we choose, we smaller error we get. The default accuracy value is 10000, to match with Hive default setting. Choose a smaller value for smaller memory footprint. 3. This function supports window function aggregation. ### Example usages: ``` ## Returns the 25th percentile value, with default accuracy SELECT percentile_approx(col, 0.25) FROM table ## Returns an array of percentile value (25th, 50th, 75th), with default accuracy SELECT percentile_approx(col, array(0.25, 0.5, 0.75)) FROM table ## Returns 25th percentile value, with custom accuracy value 100, larger accuracy parameter yields smaller approximation error SELECT percentile_approx(col, 0.25, 100) FROM table ## Returns the 25th, and 50th percentile values, with custom accuracy value 100 SELECT percentile_approx(col, array(0.25, 0.5), 100) FROM table ``` ### NOTE: 1. The `percentile_approx` implementation is different from Hive, so the result returned on same query maybe slightly different with Hive. This implementation uses `QuantileSummaries` as the underlying probabilistic data structure, and mainly follows paper `Space-efficient Online Computation of Quantile Summaries` by Greenwald, Michael and Khanna, Sanjeev. (http://dx.doi.org/10.1145/375663.375670)` 2. The current implementation of `QuantileSummaries` doesn't support automatic compression. This PR has a rule to do compression automatically at the caller side, but it may not be optimal. ## How was this patch tested? Unit test, and Sql query test. ## Acknowledgement 1. This PR's work in based on lw-lin's PR https://github.com/apache/spark/pull/14298, with improvements like supporting partial aggregation, fixing out of memory issue. Author: Sean Zhong Closes #14868 from clockfly/appro_percentile_try_2. --- .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../aggregate/ApproximatePercentile.scala | 321 +++++++++++++++++++ .../aggregate/ApproximatePercentileSuite.scala | 339 +++++++++++++++++++++ 3 files changed, 661 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 35fd800df4..b05f4f61f6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -250,6 +250,7 @@ object FunctionRegistry { expression[Average]("mean"), expression[Min]("min"), expression[Skewness]("skewness"), + expression[ApproximatePercentile]("percentile_approx"), expression[StddevSamp]("std"), expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala new file mode 100644 index 0000000000..f91ff87fc1 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -0,0 +1,321 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.nio.ByteBuffer + +import com.google.common.primitives.{Doubles, Ints, Longs} + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.{InternalRow} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest} +import org.apache.spark.sql.catalyst.util.{ArrayData, GenericArrayData} +import org.apache.spark.sql.catalyst.util.QuantileSummaries +import org.apache.spark.sql.catalyst.util.QuantileSummaries.{defaultCompressThreshold, Stats} +import org.apache.spark.sql.types._ + +/** + * The ApproximatePercentile function returns the approximate percentile(s) of a column at the given + * percentage(s). A percentile is a watermark value below which a given percentage of the column + * values fall. For example, the percentile of column `col` at percentage 50% is the median of + * column `col`. + * + * This function supports partial aggregation. + * + * @param child child expression that can produce column value with `child.eval(inputRow)` + * @param percentageExpression Expression that represents a single percentage value or + * an array of percentage values. Each percentage value must be between + * 0.0 and 1.0. + * @param accuracyExpression Integer literal expression of approximation accuracy. Higher value + * yields better accuracy, the default value is + * DEFAULT_PERCENTILE_ACCURACY. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(col, percentage [, accuracy]) - Returns the approximate percentile value of numeric + column `col` at the given percentage. The value of percentage must be between 0.0 + and 1.0. The `accuracy` parameter (default: 10000) is a positive integer literal which + controls approximation accuracy at the cost of memory. Higher value of `accuracy` yields + better accuracy, `1.0/accuracy` is the relative error of the approximation. + + _FUNC_(col, array(percentage1 [, percentage2]...) [, accuracy]) - Returns the approximate + percentile array of column `col` at the given percentage array. Each value of the + percentage array must be between 0.0 and 1.0. The `accuracy` parameter (default: 10000) is + a positive integer literal which controls approximation accuracy at the cost of memory. + Higher value of `accuracy` yields better accuracy, `1.0/accuracy` is the relative error of + the approximation. + """) +case class ApproximatePercentile( + child: Expression, + percentageExpression: Expression, + accuracyExpression: Expression, + override val mutableAggBufferOffset: Int, + override val inputAggBufferOffset: Int) extends TypedImperativeAggregate[PercentileDigest] { + + def this(child: Expression, percentageExpression: Expression, accuracyExpression: Expression) = { + this(child, percentageExpression, accuracyExpression, 0, 0) + } + + def this(child: Expression, percentageExpression: Expression) = { + this(child, percentageExpression, Literal(ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY)) + } + + // Mark as lazy so that accuracyExpression is not evaluated during tree transformation. + private lazy val accuracy: Int = accuracyExpression.eval().asInstanceOf[Int] + + override def inputTypes: Seq[AbstractDataType] = { + Seq(DoubleType, TypeCollection(DoubleType, ArrayType), IntegerType) + } + + // Mark as lazy so that percentageExpression is not evaluated during tree transformation. + private lazy val (returnPercentileArray: Boolean, percentages: Array[Double]) = { + (percentageExpression.dataType, percentageExpression.eval()) match { + // Rule ImplicitTypeCasts can cast other numeric types to double + case (_, num: Double) => (false, Array(num)) + case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => + val numericArray = arrayData.toObjectArray(baseType) + (true, numericArray.map { x => + baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType]) + }) + case other => + throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentage") + } + } + + override def checkInputDataTypes(): TypeCheckResult = { + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!percentageExpression.foldable || !accuracyExpression.foldable) { + TypeCheckFailure(s"The accuracy or percentage provided must be a constant literal") + } else if (accuracy <= 0) { + TypeCheckFailure( + s"The accuracy provided must be a positive integer literal (current value = $accuracy)") + } else if (percentages.exists(percentage => percentage < 0.0D || percentage > 1.0D)) { + TypeCheckFailure( + s"All percentage values must be between 0.0 and 1.0 " + + s"(current = ${percentages.mkString(", ")})") + } else { + TypeCheckSuccess + } + } + + override def createAggregationBuffer(): PercentileDigest = { + val relativeError = 1.0D / accuracy + new PercentileDigest(relativeError) + } + + override def update(buffer: PercentileDigest, inputRow: InternalRow): Unit = { + val value = child.eval(inputRow) + // Ignore empty rows, for example: percentile_approx(null) + if (value != null) { + buffer.add(value.asInstanceOf[Double]) + } + } + + override def merge(buffer: PercentileDigest, other: PercentileDigest): Unit = { + buffer.merge(other) + } + + override def eval(buffer: PercentileDigest): Any = { + val result = buffer.getPercentiles(percentages) + if (result.length == 0) { + null + } else if (returnPercentileArray) { + new GenericArrayData(result) + } else { + result(0) + } + } + + override def withNewMutableAggBufferOffset(newOffset: Int): ApproximatePercentile = + copy(mutableAggBufferOffset = newOffset) + + override def withNewInputAggBufferOffset(newOffset: Int): ApproximatePercentile = + copy(inputAggBufferOffset = newOffset) + + override def children: Seq[Expression] = Seq(child, percentageExpression, accuracyExpression) + + // Returns null for empty inputs + override def nullable: Boolean = true + + override def dataType: DataType = { + if (returnPercentileArray) ArrayType(DoubleType) else DoubleType + } + + override def prettyName: String = "percentile_approx" + + override def serialize(obj: PercentileDigest): Array[Byte] = { + ApproximatePercentile.serializer.serialize(obj) + } + + override def deserialize(bytes: Array[Byte]): PercentileDigest = { + ApproximatePercentile.serializer.deserialize(bytes) + } +} + +object ApproximatePercentile { + + // Default accuracy of Percentile approximation. Larger value means better accuracy. + // The default relative error can be deduced by defaultError = 1.0 / DEFAULT_PERCENTILE_ACCURACY + val DEFAULT_PERCENTILE_ACCURACY: Int = 10000 + + /** + * PercentileDigest is a probabilistic data structure used for approximating percentiles + * with limited memory. PercentileDigest is backed by [[QuantileSummaries]]. + * + * @param summaries underlying probabilistic data structure [[QuantileSummaries]]. + * @param isCompressed An internal flag from class [[QuantileSummaries]] to indicate whether the + * underlying quantileSummaries is compressed. + */ + class PercentileDigest( + private var summaries: QuantileSummaries, + private var isCompressed: Boolean) { + + // Trigger compression if the QuantileSummaries's buffer length exceeds + // compressThresHoldBufferLength. The buffer length can be get by + // quantileSummaries.sampled.length + private[this] final val compressThresHoldBufferLength: Int = { + // Max buffer length after compression. + val maxBufferLengthAfterCompression: Int = (1 / summaries.relativeError).toInt * 2 + // A safe upper bound for buffer length before compression + maxBufferLengthAfterCompression * 2 + } + + def this(relativeError: Double) = { + this(new QuantileSummaries(defaultCompressThreshold, relativeError), isCompressed = true) + } + + /** Returns compressed object of [[QuantileSummaries]] */ + def quantileSummaries: QuantileSummaries = { + if (!isCompressed) compress() + summaries + } + + /** Insert an observation value into the PercentileDigest data structure. */ + def add(value: Double): Unit = { + summaries = summaries.insert(value) + // The result of QuantileSummaries.insert is un-compressed + isCompressed = false + + // Currently, QuantileSummaries ignores the construction parameter compressThresHold, + // which may cause QuantileSummaries to occupy unbounded memory. We have to hack around here + // to make sure QuantileSummaries doesn't occupy infinite memory. + // TODO: Figure out why QuantileSummaries ignores construction parameter compressThresHold + if (summaries.sampled.length >= compressThresHoldBufferLength) compress() + } + + /** In-place merges in another PercentileDigest. */ + def merge(other: PercentileDigest): Unit = { + if (!isCompressed) compress() + summaries = summaries.merge(other.quantileSummaries) + } + + /** + * Returns the approximate percentiles of all observation values at the given percentages. + * A percentile is a watermark value below which a given percentage of observation values fall. + * For example, the following code returns the 25th, median, and 75th percentiles of + * all observation values: + * + * {{{ + * val Array(p25, median, p75) = percentileDigest.getPercentiles(Array(0.25, 0.5, 0.75)) + * }}} + */ + def getPercentiles(percentages: Array[Double]): Array[Double] = { + if (!isCompressed) compress() + if (summaries.count == 0 || percentages.length == 0) { + Array.empty[Double] + } else { + val result = new Array[Double](percentages.length) + var i = 0 + while (i < percentages.length) { + result(i) = summaries.query(percentages(i)) + i += 1 + } + result + } + } + + private final def compress(): Unit = { + summaries = summaries.compress() + isCompressed = true + } + } + + /** + * Serializer for class [[PercentileDigest]] + * + * This class is thread safe. + */ + class PercentileDigestSerializer { + + private final def length(summaries: QuantileSummaries): Int = { + // summaries.compressThreshold, summary.relativeError, summary.count + Ints.BYTES + Doubles.BYTES + Longs.BYTES + + // length of summary.sampled + Ints.BYTES + + // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] + summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + } + + final def serialize(obj: PercentileDigest): Array[Byte] = { + val summary = obj.quantileSummaries + val buffer = ByteBuffer.wrap(new Array(length(summary))) + buffer.putInt(summary.compressThreshold) + buffer.putDouble(summary.relativeError) + buffer.putLong(summary.count) + buffer.putInt(summary.sampled.length) + + var i = 0 + while (i < summary.sampled.length) { + val stat = summary.sampled(i) + buffer.putDouble(stat.value) + buffer.putInt(stat.g) + buffer.putInt(stat.delta) + i += 1 + } + buffer.array() + } + + final def deserialize(bytes: Array[Byte]): PercentileDigest = { + val buffer = ByteBuffer.wrap(bytes) + val compressThreshold = buffer.getInt() + val relativeError = buffer.getDouble() + val count = buffer.getLong() + val sampledLength = buffer.getInt() + val sampled = new Array[Stats](sampledLength) + + var i = 0 + while (i < sampledLength) { + val value = buffer.getDouble() + val g = buffer.getInt() + val delta = buffer.getInt() + sampled(i) = Stats(value, g, delta) + i += 1 + } + val summary = new QuantileSummaries(compressThreshold, relativeError, sampled, count) + new PercentileDigest(summary, isCompressed = true) + } + } + + val serializer: PercentileDigestSerializer = new PercentileDigestSerializer +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala new file mode 100644 index 0000000000..61298a1b72 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentileSuite.scala @@ -0,0 +1,339 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, BoundReference, Cast, CreateArray, DecimalLiteral, GenericMutableRow, Literal} +import org.apache.spark.sql.catalyst.expressions.aggregate.ApproximatePercentile.{PercentileDigest, PercentileDigestSerializer} +import org.apache.spark.sql.catalyst.plans.logical.LocalRelation +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.catalyst.util.QuantileSummaries +import org.apache.spark.sql.catalyst.util.QuantileSummaries.Stats +import org.apache.spark.sql.types.{ArrayType, DoubleType, IntegerType} +import org.apache.spark.util.SizeEstimator + +class ApproximatePercentileSuite extends SparkFunSuite { + + private val random = new java.util.Random() + + private val data = (0 until 10000).map { _ => + random.nextInt(10000) + } + + test("serialize and de-serialize") { + val serializer = new PercentileDigestSerializer + + // Check empty serialize and de-serialize + val emptyBuffer = new PercentileDigest(relativeError = 0.01) + assert(compareEquals(emptyBuffer, serializer.deserialize(serializer.serialize(emptyBuffer)))) + + val buffer = new PercentileDigest(relativeError = 0.01) + data.foreach { value => + buffer.add(value) + } + assert(compareEquals(buffer, serializer.deserialize(serializer.serialize(buffer)))) + + val agg = new ApproximatePercentile(BoundReference(0, DoubleType, true), Literal(0.5)) + assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) + } + + test("class PercentileDigest, basic operations") { + val valueCount = 10000 + val percentages = Array(0.25, 0.5, 0.75) + Seq(0.0001, 0.001, 0.01, 0.1).foreach { relativeError => + val buffer = new PercentileDigest(relativeError) + (1 to valueCount).grouped(10).foreach { group => + val partialBuffer = new PercentileDigest(relativeError) + group.foreach(x => partialBuffer.add(x)) + buffer.merge(partialBuffer) + } + val expectedPercentiles = percentages.map(_ * valueCount) + val approxPercentiles = buffer.getPercentiles(Array(0.25, 0.5, 0.75)) + expectedPercentiles.zip(approxPercentiles).foreach { pair => + val (expected, estimate) = pair + assert((estimate - expected) / valueCount <= relativeError) + } + } + } + + test("class PercentileDigest, makes sure the memory foot print is bounded") { + val relativeError = 0.01 + val memoryFootPrintUpperBound = { + val headBufferSize = + SizeEstimator.estimate(new Array[Double](QuantileSummaries.defaultHeadSize)) + val bufferSize = SizeEstimator.estimate(new Stats(0, 0, 0)) * (1 / relativeError) * 2 + // A safe upper bound + (headBufferSize + bufferSize) * 2 + } + + val sizePerInputs = Seq(100, 1000, 10000, 100000, 1000000, 10000000).map { count => + val buffer = new PercentileDigest(relativeError) + // Worst case, data is linear sorted + (0 until count).foreach(buffer.add(_)) + assert(SizeEstimator.estimate(buffer) < memoryFootPrintUpperBound) + } + } + + test("class ApproximatePercentile, high level interface, update, merge, eval...") { + val count = 10000 + val data = (1 until 10000).toSeq + val percentages = Array(0.25D, 0.5D, 0.75D) + val accuracy = 10000 + val expectedPercentiles = percentages.map(count * _) + val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) + val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_))) + val accuracyExpression = Literal(10000) + val agg = new ApproximatePercentile(childExpression, percentageExpression, accuracyExpression) + + assert(agg.nullable) + val group1 = (0 until data.length / 2) + val group1Buffer = agg.createAggregationBuffer() + group1.foreach { index => + val input = InternalRow(data(index)) + agg.update(group1Buffer, input) + } + + val group2 = (data.length / 2 until data.length) + val group2Buffer = agg.createAggregationBuffer() + group2.foreach { index => + val input = InternalRow(data(index)) + agg.update(group2Buffer, input) + } + + val mergeBuffer = agg.createAggregationBuffer() + agg.merge(mergeBuffer, group1Buffer) + agg.merge(mergeBuffer, group2Buffer) + + agg.eval(mergeBuffer) match { + case arrayData: ArrayData => + val error = count / accuracy + val percentiles = arrayData.toDoubleArray() + assert(percentiles.zip(expectedPercentiles) + .forall(pair => Math.abs(pair._1 - pair._2) < error)) + } + } + + test("class ApproximatePercentile, low level interface, update, merge, eval...") { + val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) + val inputAggregationBufferOffset = 1 + val mutableAggregationBufferOffset = 2 + val percentage = 0.5D + + // Phase one, partial mode aggregation + val agg = new ApproximatePercentile(childExpression, Literal(percentage)) + .withNewInputAggBufferOffset(inputAggregationBufferOffset) + .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) + + val mutableAggBuffer = new GenericMutableRow(new Array[Any](mutableAggregationBufferOffset + 1)) + agg.initialize(mutableAggBuffer) + val dataCount = 10 + (1 to dataCount).foreach { data => + agg.update(mutableAggBuffer, InternalRow(data)) + } + agg.serializeAggregateBufferInPlace(mutableAggBuffer) + + // Serialize the aggregation buffer + val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset) + val inputAggBuffer = new GenericMutableRow(Array[Any](null, serialized)) + + // Phase 2: final mode aggregation + // Re-initialize the aggregation buffer + agg.initialize(mutableAggBuffer) + agg.merge(mutableAggBuffer, inputAggBuffer) + val expectedPercentile = dataCount * percentage + assert(Math.abs(agg.eval(mutableAggBuffer).asInstanceOf[Double] - expectedPercentile) < 0.1) + } + + test("class ApproximatePercentile, sql string") { + val defaultAccuracy = ApproximatePercentile.DEFAULT_PERCENTILE_ACCURACY + // sql, single percentile + assertEqual( + s"percentile_approx(`a`, 0.5D, $defaultAccuracy)", + new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)).sql: String) + + // sql, array of percentile + assertEqual( + s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", + new ApproximatePercentile( + "a".attr, + percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) + ).sql: String) + + // sql(isDistinct = false), single percentile + assertEqual( + s"percentile_approx(`a`, 0.5D, $defaultAccuracy)", + new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) + .sql(isDistinct = false)) + + // sql(isDistinct = false), array of percentile + assertEqual( + s"percentile_approx(`a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", + new ApproximatePercentile( + "a".attr, + percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) + ).sql(isDistinct = false)) + + // sql(isDistinct = true), single percentile + assertEqual( + s"percentile_approx(DISTINCT `a`, 0.5D, $defaultAccuracy)", + new ApproximatePercentile("a".attr, percentageExpression = Literal(0.5D)) + .sql(isDistinct = true)) + + // sql(isDistinct = true), array of percentile + assertEqual( + s"percentile_approx(DISTINCT `a`, array(0.25D, 0.5D, 0.75D), $defaultAccuracy)", + new ApproximatePercentile( + "a".attr, + percentageExpression = CreateArray(Seq(0.25D, 0.5D, 0.75D).map(Literal(_))) + ).sql(isDistinct = true)) + } + + test("class ApproximatePercentile, fails analysis if percentage or accuracy is not a constant") { + val attribute = AttributeReference("a", DoubleType)() + val wrongAccuracy = new ApproximatePercentile( + attribute, + percentageExpression = Literal(0.5D), + accuracyExpression = AttributeReference("b", IntegerType)()) + + assertEqual( + wrongAccuracy.checkInputDataTypes(), + TypeCheckFailure("The accuracy or percentage provided must be a constant literal") + ) + + val wrongPercentage = new ApproximatePercentile( + attribute, + percentageExpression = attribute, + accuracyExpression = Literal(10000)) + + assertEqual( + wrongPercentage.checkInputDataTypes(), + TypeCheckFailure("The accuracy or percentage provided must be a constant literal") + ) + } + + test("class ApproximatePercentile, fails analysis if parameters are invalid") { + val wrongAccuracy = new ApproximatePercentile( + AttributeReference("a", DoubleType)(), + percentageExpression = Literal(0.5D), + accuracyExpression = Literal(-1)) + assertEqual( + wrongAccuracy.checkInputDataTypes(), + TypeCheckFailure( + "The accuracy provided must be a positive integer literal (current value = -1)")) + + val correctPercentageExpresions = Seq( + Literal(0D), + Literal(1D), + Literal(0.5D), + CreateArray(Seq(0D, 1D, 0.5D).map(Literal(_))) + ) + correctPercentageExpresions.foreach { percentageExpression => + val correctPercentage = new ApproximatePercentile( + AttributeReference("a", DoubleType)(), + percentageExpression = percentageExpression, + accuracyExpression = Literal(100)) + + // no exception should be thrown + correctPercentage.checkInputDataTypes() + } + + val wrongPercentageExpressions = Seq( + Literal(1.1D), + Literal(-0.5D), + CreateArray(Seq(0D, 0.5D, 1.1D).map(Literal(_))) + ) + + wrongPercentageExpressions.foreach { percentageExpression => + val wrongPercentage = new ApproximatePercentile( + AttributeReference("a", DoubleType)(), + percentageExpression = percentageExpression, + accuracyExpression = Literal(100)) + + val result = wrongPercentage.checkInputDataTypes() + assert( + wrongPercentage.checkInputDataTypes() match { + case TypeCheckFailure(msg) if msg.contains("must be between 0.0 and 1.0") => true + case _ => false + }) + } + } + + test("class ApproximatePercentile, automatically add type casting for parameters") { + val testRelation = LocalRelation('a.int) + val analyzer = SimpleAnalyzer + + // Compatible accuracy types: Long type and decimal type + val accuracyExpressions = Seq(Literal(1000L), DecimalLiteral(10000), Literal(123.0D)) + // Compatible percentage types: float, decimal + val percentageExpressions = Seq(Literal(0.3f), DecimalLiteral(0.5), + CreateArray(Seq(Literal(0.3f), Literal(0.5D), DecimalLiteral(0.7)))) + + accuracyExpressions.foreach { accuracyExpression => + percentageExpressions.foreach { percentageExpression => + val agg = new ApproximatePercentile( + UnresolvedAttribute("a"), + percentageExpression, + accuracyExpression) + val analyzed = testRelation.select(agg).analyze.expressions.head + analyzed match { + case Alias(agg: ApproximatePercentile, _) => + assert(agg.resolved) + assert(agg.child.dataType == DoubleType) + assert(agg.percentageExpression.dataType == DoubleType || + agg.percentageExpression.dataType == ArrayType(DoubleType, containsNull = false)) + assert(agg.accuracyExpression.dataType == IntegerType) + case _ => fail() + } + } + } + } + + test("class ApproximatePercentile, null handling") { + val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) + val agg = new ApproximatePercentile(childExpression, Literal(0.5D)) + val buffer = new GenericMutableRow(new Array[Any](1)) + agg.initialize(buffer) + // Empty aggregation buffer + assert(agg.eval(buffer) == null) + // Empty input row + agg.update(buffer, InternalRow(null)) + assert(agg.eval(buffer) == null) + + // Add some non-empty row + agg.update(buffer, InternalRow(0)) + assert(agg.eval(buffer) != null) + } + + private def compareEquals(left: PercentileDigest, right: PercentileDigest): Boolean = { + val leftSummary = left.quantileSummaries + val rightSummary = right.quantileSummaries + leftSummary.compressThreshold == rightSummary.compressThreshold && + leftSummary.relativeError == rightSummary.relativeError && + leftSummary.count == rightSummary.count && + leftSummary.sampled.sameElements(rightSummary.sampled) + } + + private def assertEqual[T](left: T, right: T): Unit = { + assert(left == right) + } +} -- cgit v1.2.3