authorjiangxingbo <jiangxb1987@gmail.com>2016-11-28 11:05:58 -0800
committerHerman van Hovell <hvanhovell@databricks.com>2016-11-28 11:05:58 -0800
commit0f5f52a3d1e5dcf5b970c49e324e322b9deb00f3 (patch)
tree3324b94c8e4275c2bd2ac0bab7804835c7ccab59
parent185642846e25fa812f9c7f398ab20bffc1e25273 (diff)
[SPARK-16282][SQL] Implement percentile SQL function.
## What changes were proposed in this pull request?

Implement the `percentile` SQL function. It computes the exact percentile(s) of numeric column `expr` at the given percentage(s), each of which must be in the range [0.0, 1.0].

## How was this patch tested?

Added a new test suite `PercentileSuite` to test percentile directly. Updated related test cases in `ExpressionToSQLSuite`.

Author: jiangxingbo <jiangxb1987@gmail.com>
Author: 蒋星博 <jiangxingbo@meituan.com>
Author: jiangxingbo <jiangxingbo@meituan.com>

Closes #14136 from jiangxb1987/percentile.
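For context, a minimal usage sketch of the new function (assuming a `SparkSession` named `spark` and a table `t1` with a numeric column `value`; both names are illustrative, not part of the patch):

```scala
// Exact median of `value`.
spark.sql("SELECT percentile(value, 0.5) FROM t1").show()

// Several percentages at once; the result is an array<double>.
spark.sql("SELECT percentile(value, array(0.25, 0.5, 0.75)) FROM t1").show()
```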
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala269
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala245
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala3
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala2
5 files changed, 518 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 007cdc1ccb..2636afe620 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -249,6 +249,7 @@ object FunctionRegistry {
expression[Max]("max"),
expression[Average]("mean"),
expression[Min]("min"),
+ expression[Percentile]("percentile"),
expression[Skewness]("skewness"),
expression[ApproximatePercentile]("percentile_approx"),
expression[StddevSamp]("std"),
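A single registry entry is all the analyzer needs to resolve the function name: `expression[Percentile]("percentile")` reflectively picks a `Percentile` constructor matching the parsed argument list and stores the resulting builder under the name. Conceptually (a simplified sketch, not the actual `FunctionRegistry` internals):

```scala
// Hypothetical, simplified shape of what the registration produces: the
// function name mapped to a builder over the parsed argument expressions.
val builder: Seq[Expression] => Expression = {
  case Seq(child, percentage) => new Percentile(child, percentage)
}
```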
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
new file mode 100644
index 0000000000..356e088d1d
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -0,0 +1,269 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.util
+
+import org.apache.spark.sql.AnalysisException
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
+ * the given percentage(s), each of which must be in the range [0.0, 1.0].
+ *
+ * The operator is bound to the slower sort-based aggregation path because the number of elements
+ * and their partial order cannot be determined in advance. Therefore we have to store all the
+ * elements in memory, and holding too many elements can cause GC pauses and eventually an
+ * OutOfMemoryError.
+ *
+ * @param child child expression that produces a numeric column value with `child.eval(inputRow)`
+ * @param percentageExpression Expression that represents a single percentage value or an array of
+ * percentage values. Each percentage value must be in the range
+ * [0.0, 1.0].
+ */
+@ExpressionDescription(
+ usage =
+ """
+ _FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the
+ given percentage. The value of percentage must be between 0.0 and 1.0.
+
+ _FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array
+ of numeric column `col` at the given percentage(s). Each value of the percentage array must
+ be between 0.0 and 1.0.
+ """)
+case class Percentile(
+ child: Expression,
+ percentageExpression: Expression,
+ mutableAggBufferOffset: Int = 0,
+ inputAggBufferOffset: Int = 0) extends TypedImperativeAggregate[OpenHashMap[Number, Long]] {
+
+ def this(child: Expression, percentageExpression: Expression) = {
+ this(child, percentageExpression, 0, 0)
+ }
+
+ override def prettyName: String = "percentile"
+
+ override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): Percentile =
+ copy(mutableAggBufferOffset = newMutableAggBufferOffset)
+
+ override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): Percentile =
+ copy(inputAggBufferOffset = newInputAggBufferOffset)
+
+ // Mark as lazy so that percentageExpression is not evaluated during tree transformation.
+ @transient
+ private lazy val returnPercentileArray = percentageExpression.dataType.isInstanceOf[ArrayType]
+
+ @transient
+ private lazy val percentages =
+ (percentageExpression.dataType, percentageExpression.eval()) match {
+ case (_, num: Double) => Seq(num)
+ case (ArrayType(baseType: NumericType, _), arrayData: ArrayData) =>
+ val numericArray = arrayData.toObjectArray(baseType)
+ numericArray.map { x =>
+ baseType.numeric.toDouble(x.asInstanceOf[baseType.InternalType])
+ }.toSeq
+ case other =>
+ throw new AnalysisException(s"Invalid data type ${other._1} for parameter percentages")
+ }
+
+ override def children: Seq[Expression] = child :: percentageExpression :: Nil
+
+ // Returns null for empty inputs
+ override def nullable: Boolean = true
+
+ override lazy val dataType: DataType = percentageExpression.dataType match {
+ case _: ArrayType => ArrayType(DoubleType, false)
+ case _ => DoubleType
+ }
+
+ override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
+ case _: ArrayType => Seq(NumericType, ArrayType)
+ case _ => Seq(NumericType, DoubleType)
+ }
+
+ // Check that the input types are valid and that percentageExpression satisfies:
+ // 1. percentageExpression must be foldable;
+ // 2. percentage(s) must be in the range [0.0, 1.0].
+ override def checkInputDataTypes(): TypeCheckResult = {
+ // Validate the inputTypes
+ val defaultCheck = super.checkInputDataTypes()
+ if (defaultCheck.isFailure) {
+ defaultCheck
+ } else if (!percentageExpression.foldable) {
+ // percentageExpression must be foldable
+ TypeCheckFailure("The percentage(s) must be a constant literal, " +
+ s"but got $percentageExpression")
+ } else if (percentages.exists(percentage => percentage < 0.0 || percentage > 1.0)) {
+ // percentage(s) must be in the range [0.0, 1.0]
+ TypeCheckFailure("Percentage(s) must be between 0.0 and 1.0, " +
+ s"but got $percentageExpression")
+ } else {
+ TypeCheckSuccess
+ }
+ }
+
+ override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+ // Initialize new counts map instance here.
+ new OpenHashMap[Number, Long]()
+ }
+
+ override def update(buffer: OpenHashMap[Number, Long], input: InternalRow): Unit = {
+ val key = child.eval(input).asInstanceOf[Number]
+
+ // Null values are ignored in counts map.
+ if (key != null) {
+ buffer.changeValue(key, 1L, _ + 1L)
+ }
+ }
+
+ override def merge(buffer: OpenHashMap[Number, Long], other: OpenHashMap[Number, Long]): Unit = {
+ other.foreach { case (key, count) =>
+ buffer.changeValue(key, count, _ + count)
+ }
+ }
+
+ override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+ generateOutput(getPercentiles(buffer))
+ }
+
+ private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+ if (buffer.isEmpty) {
+ return Seq.empty
+ }
+
+ val sortedCounts = buffer.toSeq.sortBy(_._1)(
+ child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+ val accumulatedCounts = sortedCounts.scanLeft((sortedCounts.head._1, 0L)) {
+ case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
+ }.tail
+ val maxPosition = accumulatedCounts.last._2 - 1
+
+ percentages.map { percentage =>
+ getPercentile(accumulatedCounts, maxPosition * percentage).doubleValue()
+ }
+ }
+
+ private def generateOutput(results: Seq[Double]): Any = {
+ if (results.isEmpty) {
+ null
+ } else if (returnPercentileArray) {
+ new GenericArrayData(results)
+ } else {
+ results.head
+ }
+ }
+
+ /**
+ * Get the percentile value.
+ *
+ * This function is based on the similar function from Hive,
+ * `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
+ */
+ private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+ // We may need to do linear interpolation to get the exact percentile
+ val lower = position.floor.toLong
+ val higher = position.ceil.toLong
+
+ // Use binary search to find the lower and the higher position.
+ val countsArray = aggreCounts.map(_._2).toArray[Long]
+ val lowerIndex = binarySearchCount(countsArray, 0, aggreCounts.size, lower + 1)
+ val higherIndex = binarySearchCount(countsArray, 0, aggreCounts.size, higher + 1)
+
+ val lowerKey = aggreCounts(lowerIndex)._1
+ if (higher == lower) {
+ // no interpolation needed because position does not have a fraction
+ return lowerKey
+ }
+
+ val higherKey = aggreCounts(higherIndex)._1
+ if (higherKey == lowerKey) {
+ // no interpolation needed because the lower and higher positions have the same key
+ return lowerKey
+ }
+
+ // Linear interpolation to get the exact percentile
+ return (higher - position) * lowerKey.doubleValue() +
+ (position - lower) * higherKey.doubleValue()
+ }
+
+ /**
+ * Use binary search to find the index of the first cumulative count that is greater than
+ * or equal to `value`.
+ */
+ private def binarySearchCount(
+ countsArray: Array[Long], start: Int, end: Int, value: Long): Int = {
+ util.Arrays.binarySearch(countsArray, start, end, value) match {
+ case ix if ix < 0 => -(ix + 1)
+ case ix => ix
+ }
+ }
+
+ override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+ val buffer = new Array[Byte](4 << 10) // 4K
+ val bos = new ByteArrayOutputStream()
+ val out = new DataOutputStream(bos)
+ try {
+ val projection = UnsafeProjection.create(Array[DataType](child.dataType, LongType))
+ // Write pairs in counts map to byte buffer.
+ obj.foreach { case (key, count) =>
+ val row = InternalRow.apply(key, count)
+ val unsafeRow = projection.apply(row)
+ out.writeInt(unsafeRow.getSizeInBytes)
+ unsafeRow.writeToStream(out, buffer)
+ }
+ out.writeInt(-1)
+ out.flush()
+
+ bos.toByteArray
+ } finally {
+ out.close()
+ bos.close()
+ }
+ }
+
+ override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+ val bis = new ByteArrayInputStream(bytes)
+ val ins = new DataInputStream(bis)
+ try {
+ val counts = new OpenHashMap[Number, Long]
+ // Read unsafeRow size and content in bytes.
+ var sizeOfNextRow = ins.readInt()
+ while (sizeOfNextRow >= 0) {
+ val bs = new Array[Byte](sizeOfNextRow)
+ ins.readFully(bs)
+ val row = new UnsafeRow(2)
+ row.pointTo(bs, sizeOfNextRow)
+ // Insert the pairs into counts map.
+ val key = row.get(0, child.dataType).asInstanceOf[Number]
+ val count = row.get(1, LongType).asInstanceOf[Long]
+ counts.update(key, count)
+ sizeOfNextRow = ins.readInt()
+ }
+
+ counts
+ } finally {
+ ins.close()
+ bis.close()
+ }
+ }
+}
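To make the interpolation in `getPercentile` concrete: for `N` accumulated elements the maximum 0-based position is `N - 1`, the target position is `percentage * (N - 1)`, and the result is a weighted average of the two neighboring keys. A minimal standalone sketch over a plain sorted array (the name `percentileOf` and the simplified input are illustrative; the aggregate itself works on (value, cumulative count) pairs):

```scala
// Linear interpolation between the two values surrounding the target position.
def percentileOf(sorted: Array[Double], percentage: Double): Double = {
  val position = (sorted.length - 1) * percentage
  val lower = position.floor.toInt
  val higher = position.ceil.toInt
  if (lower == higher) {
    // The position is integral, so no interpolation is needed.
    sorted(lower)
  } else {
    // Weight each neighbor by its distance from the fractional position.
    (higher - position) * sorted(lower) + (position - lower) * sorted(higher)
  }
}
```

For the data 1.0 to 10000.0 at percentage 0.25 this gives position 9999 * 0.25 = 2499.75, hence 0.25 * 2500.0 + 0.75 * 2501.0 = 2500.75, matching `expectedPercentiles` in the `PercentileSuite` below.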
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
new file mode 100644
index 0000000000..f060ecc184
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -0,0 +1,245 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.expressions.aggregate
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.util.ArrayData
+import org.apache.spark.sql.types._
+import org.apache.spark.util.collection.OpenHashMap
+
+class PercentileSuite extends SparkFunSuite {
+
+ private val random = new java.util.Random()
+
+ private val data = (0 until 10000).map { _ =>
+ random.nextInt(10000)
+ }
+
+ test("serialize and de-serialize") {
+ val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
+
+ // Check serialize and deserialize of an empty buffer.
+ val buffer = new OpenHashMap[Number, Long]()
+ assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
+
+ // Check serialize and deserialize of a non-empty buffer.
+ data.foreach { key =>
+ buffer.changeValue(key, 1L, _ + 1L)
+ }
+ assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
+ }
+
+ test("class Percentile, high level interface, update, merge, eval...") {
+ val count = 10000
+ val data = (1 to count)
+ val percentages = Seq(0, 0.25, 0.5, 0.75, 1)
+ val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000)
+ val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
+ val percentageExpression = CreateArray(percentages.map(Literal(_)))
+ val agg = new Percentile(childExpression, percentageExpression)
+
+ assert(agg.nullable)
+ val group1 = (0 until data.length / 2)
+ val group1Buffer = agg.createAggregationBuffer()
+ group1.foreach { index =>
+ val input = InternalRow(data(index))
+ agg.update(group1Buffer, input)
+ }
+
+ val group2 = (data.length / 2 until data.length)
+ val group2Buffer = agg.createAggregationBuffer()
+ group2.foreach { index =>
+ val input = InternalRow(data(index))
+ agg.update(group2Buffer, input)
+ }
+
+ val mergeBuffer = agg.createAggregationBuffer()
+ agg.merge(mergeBuffer, group1Buffer)
+ agg.merge(mergeBuffer, group2Buffer)
+
+ agg.eval(mergeBuffer) match {
+ case arrayData: ArrayData =>
+ val percentiles = arrayData.toDoubleArray()
+ assert(percentiles.zip(expectedPercentiles)
+ .forall(pair => pair._1 == pair._2))
+ }
+ }
+
+ test("class Percentile, low level interface, update, merge, eval...") {
+ val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+ val inputAggregationBufferOffset = 1
+ val mutableAggregationBufferOffset = 2
+ val percentage = 0.5
+
+ // Phase 1: partial mode aggregation
+ val agg = new Percentile(childExpression, Literal(percentage))
+ .withNewInputAggBufferOffset(inputAggregationBufferOffset)
+ .withNewMutableAggBufferOffset(mutableAggregationBufferOffset)
+
+ val mutableAggBuffer = new GenericInternalRow(
+ new Array[Any](mutableAggregationBufferOffset + 1))
+ agg.initialize(mutableAggBuffer)
+ val dataCount = 10
+ (1 to dataCount).foreach { data =>
+ agg.update(mutableAggBuffer, InternalRow(data))
+ }
+ agg.serializeAggregateBufferInPlace(mutableAggBuffer)
+
+ // Read back the serialized aggregation buffer
+ val serialized = mutableAggBuffer.getBinary(mutableAggregationBufferOffset)
+ val inputAggBuffer = new GenericInternalRow(Array[Any](null, serialized))
+
+ // Phase 2: final mode aggregation
+ // Re-initialize the aggregation buffer
+ agg.initialize(mutableAggBuffer)
+ agg.merge(mutableAggBuffer, inputAggBuffer)
+ val expectedPercentile = 5.5
+ assert(agg.eval(mutableAggBuffer).asInstanceOf[Double] == expectedPercentile)
+ }
+
+ test("call from sql query") {
+ // sql, single percentile
+ assertEqual(
+ s"percentile(`a`, 0.5D)",
+ new Percentile("a".attr, Literal(0.5)).sql: String)
+
+ // sql, array of percentile
+ assertEqual(
+ s"percentile(`a`, array(0.25D, 0.5D, 0.75D))",
+ new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))).sql: String)
+
+ // sql(isDistinct = false), single percentile
+ assertEqual(
+ s"percentile(`a`, 0.5D)",
+ new Percentile("a".attr, Literal(0.5)).sql(isDistinct = false))
+
+ // sql(isDistinct = false), array of percentile
+ assertEqual(
+ s"percentile(`a`, array(0.25D, 0.5D, 0.75D))",
+ new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
+ .sql(isDistinct = false))
+
+ // sql(isDistinct = true), single percentile
+ assertEqual(
+ s"percentile(DISTINCT `a`, 0.5D)",
+ new Percentile("a".attr, Literal(0.5)).sql(isDistinct = true))
+
+ // sql(isDistinct = true), array of percentile
+ assertEqual(
+ s"percentile(DISTINCT `a`, array(0.25D, 0.5D, 0.75D))",
+ new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
+ .sql(isDistinct = true))
+ }
+
+ test("fail analysis if childExpression is invalid") {
+ val validDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
+ val percentage = Literal(0.5)
+
+ validDataTypes.foreach { dataType =>
+ val child = AttributeReference("a", dataType)()
+ val percentile = new Percentile(child, percentage)
+ assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess)
+ }
+
+ val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType,
+ CalendarIntervalType, NullType)
+
+ invalidDataTypes.foreach { dataType =>
+ val child = AttributeReference("a", dataType)()
+ val percentile = new Percentile(child, percentage)
+ assertEqual(percentile.checkInputDataTypes(),
+ TypeCheckFailure(s"argument 1 requires numeric type, however, " +
+ s"'`a`' is of ${dataType.simpleString} type."))
+ }
+ }
+
+ test("fails analysis if percentage(s) are invalid") {
+ val child = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
+ val input = InternalRow(1)
+
+ val validPercentages = Seq(Literal(0D), Literal(0.5), Literal(1D),
+ CreateArray(Seq(0, 0.5, 1).map(Literal(_))))
+
+ validPercentages.foreach { percentage =>
+ val percentile1 = new Percentile(child, percentage)
+ assertEqual(percentile1.checkInputDataTypes(), TypeCheckSuccess)
+ }
+
+ val invalidPercentages = Seq(Literal(-0.5), Literal(1.5), Literal(2D),
+ CreateArray(Seq(-0.5, 0, 2).map(Literal(_))))
+
+ invalidPercentages.foreach { percentage =>
+ val percentile2 = new Percentile(child, percentage)
+ assertEqual(percentile2.checkInputDataTypes(),
+ TypeCheckFailure(s"Percentage(s) must be between 0.0 and 1.0, " +
+ s"but got ${percentage.simpleString}"))
+ }
+
+ val nonFoldablePercentage = Seq(NonFoldableLiteral(0.5),
+ CreateArray(Seq(0, 0.5, 1).map(NonFoldableLiteral(_))))
+
+ nonFoldablePercentage.foreach { percentage =>
+ val percentile3 = new Percentile(child, percentage)
+ assertEqual(percentile3.checkInputDataTypes(),
+ TypeCheckFailure(s"The percentage(s) must be a constant literal, " +
+ s"but got ${percentage}"))
+ }
+
+ val invalidDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType,
+ BooleanType, StringType, DateType, TimestampType, CalendarIntervalType, NullType)
+
+ invalidDataTypes.foreach { dataType =>
+ val percentage = Literal(0.5, dataType)
+ val percentile4 = new Percentile(child, percentage)
+ assertEqual(percentile4.checkInputDataTypes(),
+ TypeCheckFailure(s"argument 2 requires double type, however, " +
+ s"'0.5' is of ${dataType.simpleString} type."))
+ }
+ }
+
+ test("null handling") {
+ val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+ val agg = new Percentile(childExpression, Literal(0.5))
+ val buffer = new GenericInternalRow(new Array[Any](1))
+ agg.initialize(buffer)
+ // Empty aggregation buffer
+ assert(agg.eval(buffer) == null)
+ // A row whose value is null is ignored
+ agg.update(buffer, InternalRow(null))
+ assert(agg.eval(buffer) == null)
+
+ // Add a row with a non-null value
+ agg.update(buffer, InternalRow(0))
+ assert(agg.eval(buffer) != null)
+ }
+
+ private def compareEquals(
+ left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+ left.size == right.size && left.forall { case (key, count) =>
+ right.apply(key) == count
+ }
+ }
+
+ private def assertEqual[T](left: T, right: T): Unit = {
+ assert(left == right)
+ }
+}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
index 4a9b28a455..08bf1cd0ef 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala
@@ -234,7 +234,6 @@ private[sql] class HiveSessionCatalog(
// noopwithmapstreaming, parse_url_tuple, reflect2, windowingtablefunction.
// Note: don't forget to update SessionCatalog.isTemporaryFunction
private val hiveFunctions = Seq(
- "histogram_numeric",
- "percentile"
+ "histogram_numeric"
)
}
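With `percentile` removed from the Hive fallback list, the name now resolves to the native Catalyst implementation registered in `FunctionRegistry` instead of Hive's UDAF. A quick, hypothetical way to confirm this from a Spark shell:

```scala
// Should describe the native expression; the exact output is version-dependent
// and shown here only as an assumption.
spark.sql("DESCRIBE FUNCTION EXTENDED percentile").show(truncate = false)
```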
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
index fdd02821df..27ea167b90 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
@@ -173,6 +173,8 @@ class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key")
checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key")
checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT percentile(value, 0.25) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT percentile(value, array(0.25, 0.75)) FROM t1 GROUP BY key")
checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key")
checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key")
checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key")