 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala     | 48
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala | 29
 sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                                     |  5
 3 files changed, 45 insertions(+), 37 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 6b7cf7991d..8433a93ea3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util
-import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
import org.apache.spark.sql.catalyst.expressions._
@@ -61,7 +61,7 @@ case class Percentile(
frequencyExpression : Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
- extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
+ extends TypedImperativeAggregate[OpenHashMap[AnyRef, Long]] with ImplicitCastInputTypes {
def this(child: Expression, percentageExpression: Expression) = {
this(child, percentageExpression, Literal(1L), 0, 0)
@@ -130,15 +130,20 @@ case class Percentile(
}
}
- override def createAggregationBuffer(): OpenHashMap[Number, Long] = {
+ private def toDoubleValue(d: Any): Double = d match {
+ case d: Decimal => d.toDouble
+ case n: Number => n.doubleValue
+ }
+
+ override def createAggregationBuffer(): OpenHashMap[AnyRef, Long] = {
// Initialize new counts map instance here.
- new OpenHashMap[Number, Long]()
+ new OpenHashMap[AnyRef, Long]()
}
override def update(
- buffer: OpenHashMap[Number, Long],
- input: InternalRow): OpenHashMap[Number, Long] = {
- val key = child.eval(input).asInstanceOf[Number]
+ buffer: OpenHashMap[AnyRef, Long],
+ input: InternalRow): OpenHashMap[AnyRef, Long] = {
+ val key = child.eval(input).asInstanceOf[AnyRef]
val frqValue = frequencyExpression.eval(input)
// Null values are ignored in counts map.
@@ -155,32 +160,32 @@ case class Percentile(
}
override def merge(
- buffer: OpenHashMap[Number, Long],
- other: OpenHashMap[Number, Long]): OpenHashMap[Number, Long] = {
+ buffer: OpenHashMap[AnyRef, Long],
+ other: OpenHashMap[AnyRef, Long]): OpenHashMap[AnyRef, Long] = {
other.foreach { case (key, count) =>
buffer.changeValue(key, count, _ + count)
}
buffer
}
- override def eval(buffer: OpenHashMap[Number, Long]): Any = {
+ override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
generateOutput(getPercentiles(buffer))
}
- private def getPercentiles(buffer: OpenHashMap[Number, Long]): Seq[Double] = {
+ private def getPercentiles(buffer: OpenHashMap[AnyRef, Long]): Seq[Double] = {
if (buffer.isEmpty) {
return Seq.empty
}
val sortedCounts = buffer.toSeq.sortBy(_._1)(
- child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[Number]])
+ child.dataType.asInstanceOf[NumericType].ordering.asInstanceOf[Ordering[AnyRef]])
val accumlatedCounts = sortedCounts.scanLeft(sortedCounts.head._1, 0L) {
case ((key1, count1), (key2, count2)) => (key2, count1 + count2)
}.tail
val maxPosition = accumlatedCounts.last._2 - 1
percentages.map { percentile =>
- getPercentile(accumlatedCounts, maxPosition * percentile).doubleValue()
+ getPercentile(accumlatedCounts, maxPosition * percentile)
}
}
@@ -200,7 +205,7 @@ case class Percentile(
* This function has been based upon similar function from HIVE
* `org.apache.hadoop.hive.ql.udf.UDAFPercentile.getPercentile()`.
*/
- private def getPercentile(aggreCounts: Seq[(Number, Long)], position: Double): Number = {
+ private def getPercentile(aggreCounts: Seq[(AnyRef, Long)], position: Double): Double = {
// We may need to do linear interpolation to get the exact percentile
val lower = position.floor.toLong
val higher = position.ceil.toLong
@@ -213,18 +218,17 @@ case class Percentile(
val lowerKey = aggreCounts(lowerIndex)._1
if (higher == lower) {
// no interpolation needed because position does not have a fraction
- return lowerKey
+ return toDoubleValue(lowerKey)
}
val higherKey = aggreCounts(higherIndex)._1
if (higherKey == lowerKey) {
// no interpolation needed because the lower and higher positions have the same key
- return lowerKey
+ return toDoubleValue(lowerKey)
}
// Linear interpolation to get the exact percentile
- return (higher - position) * lowerKey.doubleValue() +
- (position - lower) * higherKey.doubleValue()
+ (higher - position) * toDoubleValue(lowerKey) + (position - lower) * toDoubleValue(higherKey)
}
/**
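
The hunk above rewrites getPercentile to return Double, converting map keys only at the two early-return sites and in the interpolation step. As a standalone sketch of that interpolation (hypothetical names percentileAt, keyAt and accumCounts; keys assumed already converted to Double; a linear scan stands in for the lookup the real code performs over the accumulated counts):

// Sketch of the interpolation now done on Doubles in getPercentile.
// accumCounts pairs each distinct value with its cumulative frequency,
// sorted by value.
def percentileAt(accumCounts: Seq[(Double, Long)], position: Double): Double = {
  val lower = position.floor.toLong
  val higher = position.ceil.toLong
  // The element at 0-based position p is the first value whose
  // cumulative count exceeds p.
  def keyAt(pos: Long): Double =
    accumCounts(accumCounts.indexWhere(_._2 > pos))._1
  val lowerKey = keyAt(lower)
  val higherKey = keyAt(higher)
  if (higher == lower || higherKey == lowerKey) lowerKey
  else (higher - position) * lowerKey + (position - lower) * higherKey
}

// Example: values 1..4, each seen once, so the cumulative counts are
// (1,1), (2,2), (3,3), (4,4) and the median position is (4 - 1) * 0.5 = 1.5.
// The result interpolates between the values at positions 1 and 2:
// (2 - 1.5) * 2.0 + (1.5 - 1) * 3.0 = 2.5
percentileAt(Seq((1.0, 1L), (2.0, 2L), (3.0, 3L), (4.0, 4L)), 1.5)
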
@@ -238,7 +242,7 @@ case class Percentile(
}
}
- override def serialize(obj: OpenHashMap[Number, Long]): Array[Byte] = {
+ override def serialize(obj: OpenHashMap[AnyRef, Long]): Array[Byte] = {
val buffer = new Array[Byte](4 << 10) // 4K
val bos = new ByteArrayOutputStream()
val out = new DataOutputStream(bos)
@@ -261,11 +265,11 @@ case class Percentile(
}
}
- override def deserialize(bytes: Array[Byte]): OpenHashMap[Number, Long] = {
+ override def deserialize(bytes: Array[Byte]): OpenHashMap[AnyRef, Long] = {
val bis = new ByteArrayInputStream(bytes)
val ins = new DataInputStream(bis)
try {
- val counts = new OpenHashMap[Number, Long]
+ val counts = new OpenHashMap[AnyRef, Long]
// Read unsafeRow size and content in bytes.
var sizeOfNextRow = ins.readInt()
while (sizeOfNextRow >= 0) {
@@ -274,7 +278,7 @@ case class Percentile(
val row = new UnsafeRow(2)
row.pointTo(bs, sizeOfNextRow)
// Insert the pairs into counts map.
- val key = row.get(0, child.dataType).asInstanceOf[Number]
+ val key = row.get(0, child.dataType)
val count = row.get(1, LongType).asInstanceOf[Long]
counts.update(key, count)
sizeOfNextRow = ins.readInt()
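
Taken together, the Percentile.scala changes widen the aggregation buffer key from java.lang.Number to AnyRef because Spark's internal value for DecimalType columns, org.apache.spark.sql.types.Decimal, does not extend java.lang.Number; the old asInstanceOf[Number] casts in update() and deserialize() are what raised the ClassCastException of SPARK-19691. Conversion to Double now happens only at evaluation time. A minimal standalone sketch of that conversion (only the Decimal API from the patch is assumed):

import org.apache.spark.sql.types.Decimal

// Mirrors the new Percentile.toDoubleValue: Decimal needs its own case
// because it is not a java.lang.Number.
def toDoubleValue(d: Any): Double = d match {
  case d: Decimal => d.toDouble
  case n: java.lang.Number => n.doubleValue
}

toDoubleValue(Decimal(BigDecimal("2.5")))  // 2.5
toDoubleValue(java.lang.Long.valueOf(7L))  // 7.0
// Before this patch, a Decimal key hit asInstanceOf[Number] and threw
// java.lang.ClassCastException at runtime.
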
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index 1533fe5f90..2420ba513f 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -21,7 +21,6 @@ import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
-import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util.ArrayData
import org.apache.spark.sql.types._
@@ -39,12 +38,12 @@ class PercentileSuite extends SparkFunSuite {
val agg = new Percentile(BoundReference(0, IntegerType, true), Literal(0.5))
// Check empty serialize and deserialize
- val buffer = new OpenHashMap[Number, Long]()
+ val buffer = new OpenHashMap[AnyRef, Long]()
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
// Check non-empty buffer serialize and deserialize.
data.foreach { key =>
- buffer.changeValue(key, 1L, _ + 1L)
+ buffer.changeValue(new Integer(key), 1L, _ + 1L)
}
assert(compareEquals(agg.deserialize(agg.serialize(buffer)), buffer))
}
@@ -58,25 +57,25 @@ class PercentileSuite extends SparkFunSuite {
val agg = new Percentile(childExpression, percentageExpression)
// Test with rows without frequency
- val rows = (1 to count).map( x => Seq(x))
- runTest( agg, rows, expectedPercentiles)
+ val rows = (1 to count).map(x => Seq(x))
+ runTest(agg, rows, expectedPercentiles)
// Test with row with frequency. Second and third columns are frequency in Int and Long
val countForFrequencyTest = 1000
- val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong)
+ val rowsWithFrequency = (1 to countForFrequencyTest).map(x => Seq(x, x):+ x.toLong)
val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)
val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
- runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
+ runTest(aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
- runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
+ runTest(aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
// Run test with Flatten data
- val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
- (1 to current).map( y => current )).map( Seq(_))
+ val flattenRows = (1 to countForFrequencyTest).flatMap(current =>
+ (1 to current).map(y => current )).map(Seq(_))
runTest(agg, flattenRows, expectedPercentilesWithFrquency)
}
@@ -153,7 +152,7 @@ class PercentileSuite extends SparkFunSuite {
}
val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
- for ( dataType <- validDataTypes;
+ for (dataType <- validDataTypes;
frequencyType <- validFrequencyTypes) {
val child = AttributeReference("a", dataType)()
val frq = AttributeReference("frq", frequencyType)()
@@ -176,7 +175,7 @@ class PercentileSuite extends SparkFunSuite {
StringType, DateType, TimestampType,
CalendarIntervalType, NullType)
- for( dataType <- invalidDataTypes;
+ for(dataType <- invalidDataTypes;
frequencyType <- validFrequencyTypes) {
val child = AttributeReference("a", dataType)()
val frq = AttributeReference("frq", frequencyType)()
@@ -186,7 +185,7 @@ class PercentileSuite extends SparkFunSuite {
s"'`a`' is of ${dataType.simpleString} type."))
}
- for( dataType <- validDataTypes;
+ for(dataType <- validDataTypes;
frequencyType <- invalidFrequencyDataTypes) {
val child = AttributeReference("a", dataType)()
val frq = AttributeReference("frq", frequencyType)()
@@ -294,11 +293,11 @@ class PercentileSuite extends SparkFunSuite {
agg.update(buffer, InternalRow(1, -5))
agg.eval(buffer)
}
- assert( caught.getMessage.startsWith("Negative values found in "))
+ assert(caught.getMessage.startsWith("Negative values found in "))
}
private def compareEquals(
- left: OpenHashMap[Number, Long], right: OpenHashMap[Number, Long]): Boolean = {
+ left: OpenHashMap[AnyRef, Long], right: OpenHashMap[AnyRef, Long]): Boolean = {
left.size == right.size && left.forall { case (key, count) =>
right.apply(key) == count
}
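
The test changes follow from the new buffer type: with keys typed as AnyRef, primitive test keys are boxed explicitly before insertion, hence the `new Integer(key)` above. A minimal sketch of the buffer manipulation the serialization test performs (only OpenHashMap from Spark's collection utilities, as used in the patch, is assumed):

import org.apache.spark.util.collection.OpenHashMap

// Buffer keyed by AnyRef so it can hold boxed numbers and Decimal alike.
val buffer = new OpenHashMap[AnyRef, Long]()
Seq(1, 2, 2, 3).foreach { key =>
  // Box explicitly; Int.box(key) is equivalent to the test's new Integer(key).
  buffer.changeValue(Int.box(key), 1L, _ + 1L)
}
assert(buffer(Int.box(2)) == 2L)  // the value 2 was seen twice
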
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index e6338ab7cd..5e65436079 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1702,4 +1702,9 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
val df = Seq(123L -> "123", 19157170390056973L -> "19157170390056971").toDF("i", "j")
checkAnswer(df.select($"i" === $"j"), Row(true) :: Row(false) :: Nil)
}
+
+ test("SPARK-19691 Calculating percentile of decimal column fails with ClassCastException") {
+ val df = spark.range(1).selectExpr("CAST(id as DECIMAL) as x").selectExpr("percentile(x, 0.5)")
+ checkAnswer(df, Row(BigDecimal(0.0)) :: Nil)
+ }
}
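
As a usage check (a sketch assuming a spark-shell session, where `spark` is the ambient SparkSession), the regression covered by the new DataFrameSuite test can be exercised directly:

// Before this patch, this query failed with:
//   java.lang.ClassCastException: org.apache.spark.sql.types.Decimal
//   cannot be cast to java.lang.Number
// With the fix it evaluates normally.
val df = spark.range(1)
  .selectExpr("CAST(id AS DECIMAL) AS x")
  .selectExpr("percentile(x, 0.5)")
df.show()
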