diff options
author | jiangxingbo <jiangxb1987@gmail.com> | 2016-11-28 11:05:58 -0800 |
---|---|---|
committer | Herman van Hovell <hvanhovell@databricks.com> | 2016-11-28 11:05:58 -0800 |
commit | 0f5f52a3d1e5dcf5b970c49e324e322b9deb00f3 (patch) | |
tree | 3324b94c8e4275c2bd2ac0bab7804835c7ccab59 | |
parent | 185642846e25fa812f9c7f398ab20bffc1e25273 (diff) | |
download | spark-0f5f52a3d1e5dcf5b970c49e324e322b9deb00f3.tar.gz spark-0f5f52a3d1e5dcf5b970c49e324e322b9deb00f3.tar.bz2 spark-0f5f52a3d1e5dcf5b970c49e324e322b9deb00f3.zip |
[SPARK-16282][SQL] Implement percentile SQL function.
## What changes were proposed in this pull request?
Implement the `percentile` SQL function. It computes the exact percentile(s) of a numeric expression `expr` at the given percentage(s) `pc`, where each percentage must be in the range [0, 1].
## How was this patch tested?
Added a new test suite, `PercentileSuite`, to test `Percentile` directly.
Updated the related test cases in `ExpressionToSQLSuite`.
Author: jiangxingbo <jiangxb1987@gmail.com>
Author: 蒋星博 <jiangxingbo@meituan.com>
Author: jiangxingbo <jiangxingbo@meituan.com>
Closes #14136 from jiangxb1987/percentile.
5 files changed, 518 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 007cdc1ccb..2636afe620 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -249,6 +249,7 @@ object FunctionRegistry { expression[Max]("max"), expression[Average]("mean"), expression[Min]("min"), + expression[Percentile]("percentile"), expression[Skewness]("skewness"), expression[ApproximatePercentile]("percentile_approx"), expression[StddevSamp]("std"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala new file mode 100644 index 0000000000..356e088d1d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala @@ -0,0 +1,269 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream} +import java.util + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util._ +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashMap + +/** + * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at + * the given percentage(s) with value range in [0.0, 1.0]. + * + * The operator is bound to the slower sort based aggregation path because the number of elements + * and their partial order cannot be determined in advance. Therefore we have to store all the + * elements in memory, and that too many elements can cause GC paused and eventually OutOfMemory + * Errors. + * + * @param child child expression that produce numeric column value with `child.eval(inputRow)` + * @param percentageExpression Expression that represents a single percentage value or an array of + * percentage values. Each percentage value must be in the range + * [0.0, 1.0]. + */ +@ExpressionDescription( + usage = + """ + _FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the + given percentage. The value of percentage must be between 0.0 and 1.0. + + _FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array + of numeric column `col` at the given percentage(s). Each value of the percentage array must + be between 0.0 and 1.0. 
+ """) +case class Percentile( + child: Expression, + percentageExpression: Expression, + mutableAggBufferOffset: Int = 0, + inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] { + + def this(child: Expression, percentageExpression: Expression) = { + this(child, percentageExpression, 0, 0) + } + + override def prettyName: String = "percentile" + + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Percentile = + copy(mutableAggBufferOffset = newMutableAggBufferOffset) + + override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Percentile = + copy(inputAggBufferOffset = newInputAggBufferOffset) + + // Mark as lazy so that percentageExpression is not evaluated during tree transformation. + @transient + private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType] + + @transient + private lazy val percentages = + (percentageExpression.dataType, percentageExpression.eval()) match { + case (_, num: Double) => Seq(num) + case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) => + val numericArray = arrayData.toObjectArray(baseType) + numericArray.map { x => + baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])}.toSeq + case other => + throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages") + } + + override def children: Seq[Expression] = child :: percentageExpression :: Nil + + // Returns null for empty inputs + override def nullable: Boolean = true + + override lazy val dataType: DataType = percentageExpression.dataType match { + case _: ArrayType => ArrayType(DoubleType, false) + case _ => DoubleType + } + + override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match { + case _: ArrayType => Seq(NumericType, ArrayType) + case _ => Seq(NumericType, DoubleType) + } + + // Check the inputTypes are valid, and the percentageExpression satisfies: + // 1. 
percentageExpression must be foldable; + // 2. percentages(s) must be in the range [0.0, 1.0]. + override def checkInputDataTypes(): TypeCheckResult = { + // Validate the inputTypes + val defaultCheck = super.checkInputDataTypes() + if (defaultCheck.isFailure) { + defaultCheck + } else if (!percentageExpression.foldable) { + // percentageExpression must be foldable + TypeCheckFailure("The percentage(s) must be a constant literal, " + + s"but got $percentageExpression") + } else if (percentages.exists(percentage => percentage < 0.0 || percentage > 1.0)) { + // percentages(s) must be in the range [0.0, 1.0] + TypeCheckFailure("Percentage(s) must be between 0.0 and 1.0, " + + s"but got $percentageExpression") + } else { + TypeCheckSuccess + } + } + + override def createAggregationBuffer(): OpenHashMap[Number, Long] = { + // Initialize new counts map instance here. + new OpenHashMap[Number, Long]() + } + + override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = { + val key = child.eval(input).asInstanceOf[Number] + + // Null values are ignored in counts map. 
+ if (key != null) { + buffer.changeValue(key, 1L, _ + 1L) + } + } + + override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = { + other.foreach { case (key, count) => + buffer.changeValue(key, count, _ + count) + } + } + + override def eval(buffer: OpenHashMap[Number, Long]): Any = { + generateOutput(getPercentiles(buffer)) + } + + private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = { + if (buffer.isEmpty) { + return Seq.empty + } + + val sortedCounts = buffer.toSeq.sortBy(_._1)( + child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]]) + val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) { + case ((key1, count1), (key2, count2)) => (key2, count1 + count2) + }.tail + val maxPosition = accumlatedCounts.last._2 - 1 + + percentages.map { percentile => + getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue() + } + } + + private def generateOutput(results: Seq[Double]): Any = { + if (results.isEmpty) { + null + } else if (returnPercentileArray) { + new GenericArrayData(results) + } else { + results.head + } + } + + /** + * Get the percentile value. + * + * This function has been based upon similar function from HIVE + * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`. + */ + private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = { + // We may need to do linear interpolation to get the exact percentile + val lower = position.floor.toLong + val higher = position.ceil.toLong + + // Use binary search to find the lower and the higher position. 
+ val countsArray = aggreCounts.map(_._2).toArray[Long] + val lowerIndex = binarySearchCount(countsArray, 0, aggreCounts.size, lower + 1) + val higherIndex = binarySearchCount(countsArray, 0, aggreCounts.size, higher + 1) + + val lowerKey = aggreCounts(lowerIndex)._1 + if (higher == lower) { + // no interpolation needed because position does not have a fraction + return lowerKey + } + + val higherKey = aggreCounts(higherIndex)._1 + if (higherKey == lowerKey) { + // no interpolation needed because lower position and higher position has the same key + return lowerKey + } + + // Linear interpolation to get the exact percentile + return (higher - position) * lowerKey.doubleValue() + + (position - lower) * higherKey.doubleValue() + } + + /** + * use a binary search to find the index of the position closest to the current value. + */ + private def binarySearchCount( + countsArray: Array[Long], start: Int, end: Int, value: Long): Int = { + util.Arrays.binarySearch(countsArray, 0, end, value) match { + case ix if ix < 0 => -(ix + 1) + case ix => ix + } + } + + override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = { + val buffer = new Array[Byte](4 << 10) // 4K + val bos = new ByteArrayOutputStream() + val out = new DataOutputStream(bos) + try { + val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType)) + // Write pairs in counts map to byte buffer. + obj.foreach { case (key, count) => + val row = InternalRow.apply(key, count) + val unsafeRow = projection.apply(row) + out.writeInt(unsafeRow.getSizeInBytes) + unsafeRow.writeToStream(out, buffer) + } + out.writeInt(-1) + out.flush() + + bos.toByteArray + } finally { + out.close() + bos.close() + } + } + + override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = { + val bis = new ByteArrayInputStream(bytes) + val ins = new DataInputStream(bis) + try { + val counts = new OpenHashMap[Number, Long] + // Read unsafeRow size and content in bytes. 
+ var sizeOfNextRow = ins.readInt() + while (sizeOfNextRow >= 0) { + val bs = new Array[Byte](sizeOfNextRow) + ins.readFully(bs) + val row = new UnsafeRow(2) + row.pointTo(bs, sizeOfNextRow) + // Insert the pairs into counts map. + val key = row.get(0, child.dataType).asInstanceOf[Number] + val count = row.get(1, LongType).asInstanceOf[Long] + counts.update(key, count) + sizeOfNextRow = ins.readInt() + } + + counts + } finally { + ins.close() + bis.close() + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala new file mode 100644 index 0000000000..f060ecc184 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala @@ -0,0 +1,245 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions.aggregate + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.types._ +import org.apache.spark.util.collection.OpenHashMap + +class PercentileSuite extends SparkFunSuite { + + private val random = new java.util.Random() + + private val data = (0 until 10000).map { _ => + random.nextInt(10000) + } + + test("serialize and de-serialize") { + val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5)) + + // Check empty serialize and deserialize + val buffer = new OpenHashMap[Number, Long]() + assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) + + // Check non-empty buffer serializa and deserialize. + data.foreach { key => + buffer.changeValue(key, 1L, _ + 1L) + } + assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer)) + } + + test("class Percentile, high level interface, update, merge, eval...") { + val count = 10000 + val data = (1 to count) + val percentages = Seq(0, 0.25, 0.5, 0.75, 1) + val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000) + val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) + val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_))) + val agg = new Percentile(childExpression, percentageExpression) + + assert(agg.nullable) + val group1 = (0 until data.length / 2) + val group1Buffer = agg.createAggregationBuffer() + group1.foreach { index => + val input = InternalRow(data(index)) + agg.update(group1Buffer, input) + } + + val group2 = (data.length / 2 until data.length) + val group2Buffer = agg.createAggregationBuffer() + group2.foreach { index => + val input = InternalRow(data(index)) + 
agg.update(group2Buffer, input) + } + + val mergeBuffer = agg.createAggregationBuffer() + agg.merge(mergeBuffer, group1Buffer) + agg.merge(mergeBuffer, group2Buffer) + + agg.eval(mergeBuffer) match { + case arrayData: ArrayData => + val percentiles = arrayData.toDoubleArray() + assert(percentiles.zip(expectedPercentiles) + .forall(pair => pair._1 == pair._2)) + } + } + + test("class Percentile, low level interface, update, merge, eval...") { + val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) + val inputAggregationBufferOffset = 1 + val mutableAggregationBufferOffset = 2 + val percentage = 0.5 + + // Phase one, partial mode aggregation + val agg = new Percentile(childExpression, Literal(percentage)) + .withNewInputAggBufferOffset(inputAggregationBufferOffset) + .withNewMutableAggBufferOffset(mutableAggregationBufferOffset) + + val mutableAggBuffer = new GenericInternalRow( + new Array[Any](mutableAggregationBufferOffset + 1)) + agg.initialize(mutableAggBuffer) + val dataCount = 10 + (1 to dataCount).foreach { data => + agg.update(mutableAggBuffer, InternalRow(data)) + } + agg.serializeAggregateBufferInPlace(mutableAggBuffer) + + // Serialize the aggregation buffer + val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset) + val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized)) + + // Phase 2: final mode aggregation + // Re-initialize the aggregation buffer + agg.initialize(mutableAggBuffer) + agg.merge(mutableAggBuffer, inputAggBuffer) + val expectedPercentile = 5.5 + assert(agg.eval(mutableAggBuffer).asInstanceOf[Double] == expectedPercentile) + } + + test("call from sql query") { + // sql, single percentile + assertEqual( + s"percentile(`a`, 0.5D)", + new Percentile("a".attr, Literal(0.5)).sql: String) + + // sql, array of percentile + assertEqual( + s"percentile(`a`, array(0.25D, 0.5D, 0.75D))", + new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))).sql: String) + 
+ // sql(isDistinct = false), single percentile + assertEqual( + s"percentile(`a`, 0.5D)", + new Percentile("a".attr, Literal(0.5)).sql(isDistinct = false)) + + // sql(isDistinct = false), array of percentile + assertEqual( + s"percentile(`a`, array(0.25D, 0.5D, 0.75D))", + new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))) + .sql(isDistinct = false)) + + // sql(isDistinct = true), single percentile + assertEqual( + s"percentile(DISTINCT `a`, 0.5D)", + new Percentile("a".attr, Literal(0.5)).sql(isDistinct = true)) + + // sql(isDistinct = true), array of percentile + assertEqual( + s"percentile(DISTINCT `a`, array(0.25D, 0.5D, 0.75D))", + new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))) + .sql(isDistinct = true)) + } + + test("fail analysis if childExpression is invalid") { + val validDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType) + val percentage = Literal(0.5) + + validDataTypes.foreach { dataType => + val child = AttributeReference("a", dataType)() + val percentile = new Percentile(child, percentage) + assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess) + } + + val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType, + CalendarIntervalType, NullType) + + invalidDataTypes.foreach { dataType => + val child = AttributeReference("a", dataType)() + val percentile = new Percentile(child, percentage) + assertEqual(percentile.checkInputDataTypes(), + TypeCheckFailure(s"argument 1 requires numeric type, however, " + + s"'`a`' is of ${dataType.simpleString} type.")) + } + } + + test("fails analysis if percentage(s) are invalid") { + val child = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType) + val input = InternalRow(1) + + val validPercentages = Seq(Literal(0D), Literal(0.5), Literal(1D), + CreateArray(Seq(0, 0.5, 1).map(Literal(_)))) + + validPercentages.foreach { percentage => + val percentile1 = new Percentile(child, percentage) + 
assertEqual(percentile1.checkInputDataTypes(), TypeCheckSuccess) + } + + val invalidPercentages = Seq(Literal(-0.5), Literal(1.5), Literal(2D), + CreateArray(Seq(-0.5, 0, 2).map(Literal(_)))) + + invalidPercentages.foreach { percentage => + val percentile2 = new Percentile(child, percentage) + assertEqual(percentile2.checkInputDataTypes(), + TypeCheckFailure(s"Percentage(s) must be between 0.0 and 1.0, " + + s"but got ${percentage.simpleString}")) + } + + val nonFoldablePercentage = Seq(NonFoldableLiteral(0.5), + CreateArray(Seq(0, 0.5, 1).map(NonFoldableLiteral(_)))) + + nonFoldablePercentage.foreach { percentage => + val percentile3 = new Percentile(child, percentage) + assertEqual(percentile3.checkInputDataTypes(), + TypeCheckFailure(s"The percentage(s) must be a constant literal, " + + s"but got ${percentage}")) + } + + val invalidDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, + BooleanType, StringType, DateType, TimestampType, CalendarIntervalType, NullType) + + invalidDataTypes.foreach { dataType => + val percentage = Literal(0.5, dataType) + val percentile4 = new Percentile(child, percentage) + assertEqual(percentile4.checkInputDataTypes(), + TypeCheckFailure(s"argument 2 requires double type, however, " + + s"'0.5' is of ${dataType.simpleString} type.")) + } + } + + test("null handling") { + val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType) + val agg = new Percentile(childExpression, Literal(0.5)) + val buffer = new GenericInternalRow(new Array[Any](1)) + agg.initialize(buffer) + // Empty aggregation buffer + assert(agg.eval(buffer) == null) + // Empty input row + agg.update(buffer, InternalRow(null)) + assert(agg.eval(buffer) == null) + + // Add some non-empty row + agg.update(buffer, InternalRow(0)) + assert(agg.eval(buffer) != null) + } + + private def compareEquals( + left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = { + left.size == right.size && left.forall { 
case (key, count) => + right.apply(key) == count + } + } + + private def assertEqual[T](left: T, right: T): Unit = { + assert(left == right) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 4a9b28a455..08bf1cd0ef 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -234,7 +234,6 @@ private[sql] class HiveSessionCatalog( // noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction. // Note: don't forget to update SessionCatalog.isTemporaryFunction private val hiveFunctions = Seq( - "histogram_numeric", - "percentile" + "histogram_numeric" ) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala index fdd02821df..27ea167b90 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala @@ -173,6 +173,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT percentile(value, 0.25) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT percentile(value, array(0.25, 0.75)) FROM t1 GROUP BY key") checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key") checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key") |