aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src
diff options
context:
space:
mode:
authorgagan taneja <tanejagagan@gagans-MacBook-Pro.local>2017-02-07 14:05:22 +0100
committerHerman van Hovell <hvanhovell@databricks.com>2017-02-07 14:05:22 +0100
commite99e34d0f370211a7c7b96d144cc932b2fc71d10 (patch)
tree06cd312cf7437f0b221937664ea34c983a0faf3b /sql/catalyst/src
parent3d314d08c9420e74b4bb687603cdd11394eccab5 (diff)
downloadspark-e99e34d0f370211a7c7b96d144cc932b2fc71d10.tar.gz
spark-e99e34d0f370211a7c7b96d144cc932b2fc71d10.tar.bz2
spark-e99e34d0f370211a7c7b96d144cc932b2fc71d10.zip
[SPARK-19118][SQL] Percentile support for frequency distribution table
## What changes were proposed in this pull request? I have a frequency distribution table with following entries Age, No of person 21, 10 22, 15 23, 18 .. .. 30, 14 Moreover it is common to have data in frequency distribution format to further calculate Percentile, Median. With current implementation It would be very difficult and complex to find the percentile. Therefore i am proposing enhancement to current Percentile and Approx Percentile implementation to take frequency distribution column into consideration ## How was this patch tested? 1) Enhanced /sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala to cover the additional functionality 2) Run some performance benchmark test with 20 million row in local environment and did not see any performance degradation Please review http://spark.apache.org/contributing.html before opening a pull request. Author: gagan taneja <tanejagagan@gagans-MacBook-Pro.local> Closes #16497 from tanejagagan/branch-18940.
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala47
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala149
2 files changed, 141 insertions, 55 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
index 5b4ce47fd5..6b7cf7991d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Percentile.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
+import org.apache.spark.SparkException
/**
* The Percentile aggregate function returns the exact percentile(s) of numeric column `expr` at
@@ -44,22 +45,30 @@ import org.apache.spark.util.collection.OpenHashMap
@ExpressionDescription(
usage =
"""
- _FUNC_(col, percentage) - Returns the exact percentile value of numeric column `col` at the
- given percentage. The value of percentage must be between 0.0 and 1.0.
+ _FUNC_(col, percentage [, frequency]) - Returns the exact percentile value of numeric column
+ `col` at the given percentage. The value of percentage must be between 0.0 and 1.0. The
+ value of frequency should be positive integral
- _FUNC_(col, array(percentage1 [, percentage2]...)) - Returns the exact percentile value array
- of numeric column `col` at the given percentage(s). Each value of the percentage array must
- be between 0.0 and 1.0.
- """)
+ _FUNC_(col, array(percentage1 [, percentage2]...) [, frequency]) - Returns the exact
+ percentile value array of numeric column `col` at the given percentage(s). Each value
+ of the percentage array must be between 0.0 and 1.0. The value of frequency should be
+ positive integral
+
+ """)
case class Percentile(
child: Expression,
percentageExpression: Expression,
+ frequencyExpression : Expression,
mutableAggBufferOffset: Int = 0,
inputAggBufferOffset: Int = 0)
extends TypedImperativeAggregate[OpenHashMap[Number, Long]] with ImplicitCastInputTypes {
def this(child: Expression, percentageExpression: Expression) = {
- this(child, percentageExpression, 0, 0)
+ this(child, percentageExpression, Literal(1L), 0, 0)
+ }
+
+ def this(child: Expression, percentageExpression: Expression, frequency: Expression) = {
+ this(child, percentageExpression, frequency, 0, 0)
}
override def prettyName: String = "percentile"
@@ -80,7 +89,9 @@ case class Percentile(
case arrayData: ArrayData => arrayData.toDoubleArray().toSeq
}
- override def children: Seq[Expression] = child :: percentageExpression :: Nil
+ override def children: Seq[Expression] = {
+ child :: percentageExpression ::frequencyExpression :: Nil
+ }
// Returns null for empty inputs
override def nullable: Boolean = true
@@ -90,9 +101,12 @@ case class Percentile(
case _ => DoubleType
}
- override def inputTypes: Seq[AbstractDataType] = percentageExpression.dataType match {
- case _: ArrayType => Seq(NumericType, ArrayType(DoubleType))
- case _ => Seq(NumericType, DoubleType)
+ override def inputTypes: Seq[AbstractDataType] = {
+ val percentageExpType = percentageExpression.dataType match {
+ case _: ArrayType => ArrayType(DoubleType)
+ case _ => DoubleType
+ }
+ Seq(NumericType, percentageExpType, IntegralType)
}
// Check the inputTypes are valid, and the percentageExpression satisfies:
@@ -125,10 +139,17 @@ case class Percentile(
buffer: OpenHashMap[Number, Long],
input: InternalRow): OpenHashMap[Number, Long] = {
val key = child.eval(input).asInstanceOf[Number]
+ val frqValue = frequencyExpression.eval(input)
// Null values are ignored in counts map.
- if (key != null) {
- buffer.changeValue(key, 1L, _ + 1L)
+ if (key != null && frqValue != null) {
+ val frqLong = frqValue.asInstanceOf[Number].longValue()
+ // add only when frequency is positive
+ if (frqLong > 0) {
+ buffer.changeValue(key, frqLong, _ + frqLong)
+ } else if (frqLong < 0) {
+ throw new SparkException(s"Negative values found in ${frequencyExpression.sql}")
+ }
}
buffer
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
index f060ecc184..1533fe5f90 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/aggregate/PercentileSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions.aggregate
+import org.apache.spark.SparkException
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult._
@@ -50,25 +51,50 @@ class PercentileSuite extends SparkFunSuite {
test("class Percentile, high level interface, update, merge, eval...") {
val count = 10000
- val data = (1 to count)
val percentages = Seq(0, 0.25, 0.5, 0.75, 1)
val expectedPercentiles = Seq(1, 2500.75, 5000.5, 7500.25, 10000)
val childExpression = Cast(BoundReference(0, IntegerType, nullable = false), DoubleType)
val percentageExpression = CreateArray(percentages.toSeq.map(Literal(_)))
val agg = new Percentile(childExpression, percentageExpression)
+ // Test with rows without frequency
+ val rows = (1 to count).map( x => Seq(x))
+ runTest( agg, rows, expectedPercentiles)
+
+ // Test with row with frequency. Second and third columns are frequency in Int and Long
+ val countForFrequencyTest = 1000
+ val rowsWithFrequency = (1 to countForFrequencyTest).map( x => Seq(x, x):+ x.toLong)
+ val expectedPercentilesWithFrquency = Seq(1.0, 500.0, 707.0, 866.0, 1000.0)
+
+ val frequencyExpressionInt = BoundReference(1, IntegerType, nullable = false)
+ val aggInt = new Percentile(childExpression, percentageExpression, frequencyExpressionInt)
+ runTest( aggInt, rowsWithFrequency, expectedPercentilesWithFrquency)
+
+ val frequencyExpressionLong = BoundReference(2, LongType, nullable = false)
+ val aggLong = new Percentile(childExpression, percentageExpression, frequencyExpressionLong)
+ runTest( aggLong, rowsWithFrequency, expectedPercentilesWithFrquency)
+
+ // Run test with Flatten data
+ val flattenRows = (1 to countForFrequencyTest).flatMap( current =>
+ (1 to current).map( y => current )).map( Seq(_))
+ runTest(agg, flattenRows, expectedPercentilesWithFrquency)
+ }
+
+ private def runTest(agg: Percentile,
+ rows : Seq[Seq[Any]],
+ expectedPercentiles : Seq[Double]) {
assert(agg.nullable)
- val group1 = (0 until data.length / 2)
+ val group1 = (0 until rows.length / 2)
val group1Buffer = agg.createAggregationBuffer()
group1.foreach { index =>
- val input = InternalRow(data(index))
+ val input = InternalRow(rows(index): _*)
agg.update(group1Buffer, input)
}
- val group2 = (data.length / 2 until data.length)
+ val group2 = (rows.length / 2 until rows.length)
val group2Buffer = agg.createAggregationBuffer()
group2.foreach { index =>
- val input = InternalRow(data(index))
+ val input = InternalRow(rows(index): _*)
agg.update(group2Buffer, input)
}
@@ -116,40 +142,6 @@ class PercentileSuite extends SparkFunSuite {
assert(agg.eval(mutableAggBuffer).asInstanceOf[Double] == expectedPercentile)
}
- test("call from sql query") {
- // sql, single percentile
- assertEqual(
- s"percentile(`a`, 0.5D)",
- new Percentile("a".attr, Literal(0.5)).sql: String)
-
- // sql, array of percentile
- assertEqual(
- s"percentile(`a`, array(0.25D, 0.5D, 0.75D))",
- new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_)))).sql: String)
-
- // sql(isDistinct = false), single percentile
- assertEqual(
- s"percentile(`a`, 0.5D)",
- new Percentile("a".attr, Literal(0.5)).sql(isDistinct = false))
-
- // sql(isDistinct = false), array of percentile
- assertEqual(
- s"percentile(`a`, array(0.25D, 0.5D, 0.75D))",
- new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
- .sql(isDistinct = false))
-
- // sql(isDistinct = true), single percentile
- assertEqual(
- s"percentile(DISTINCT `a`, 0.5D)",
- new Percentile("a".attr, Literal(0.5)).sql(isDistinct = true))
-
- // sql(isDistinct = true), array of percentile
- assertEqual(
- s"percentile(DISTINCT `a`, array(0.25D, 0.5D, 0.75D))",
- new Percentile("a".attr, CreateArray(Seq(0.25, 0.5, 0.75).map(Literal(_))))
- .sql(isDistinct = true))
- }
-
test("fail analysis if childExpression is invalid") {
val validDataTypes = Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType)
val percentage = Literal(0.5)
@@ -160,6 +152,15 @@ class PercentileSuite extends SparkFunSuite {
assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess)
}
+ val validFrequencyTypes = Seq(ByteType, ShortType, IntegerType, LongType)
+ for ( dataType <- validDataTypes;
+ frequencyType <- validFrequencyTypes) {
+ val child = AttributeReference("a", dataType)()
+ val frq = AttributeReference("frq", frequencyType)()
+ val percentile = new Percentile(child, percentage, frq)
+ assertEqual(percentile.checkInputDataTypes(), TypeCheckSuccess)
+ }
+
val invalidDataTypes = Seq(BooleanType, StringType, DateType, TimestampType,
CalendarIntervalType, NullType)
@@ -170,6 +171,30 @@ class PercentileSuite extends SparkFunSuite {
TypeCheckFailure(s"argument 1 requires numeric type, however, " +
s"'`a`' is of ${dataType.simpleString} type."))
}
+
+ val invalidFrequencyDataTypes = Seq(FloatType, DoubleType, BooleanType,
+ StringType, DateType, TimestampType,
+ CalendarIntervalType, NullType)
+
+ for( dataType <- invalidDataTypes;
+ frequencyType <- validFrequencyTypes) {
+ val child = AttributeReference("a", dataType)()
+ val frq = AttributeReference("frq", frequencyType)()
+ val percentile = new Percentile(child, percentage, frq)
+ assertEqual(percentile.checkInputDataTypes(),
+ TypeCheckFailure(s"argument 1 requires numeric type, however, " +
+ s"'`a`' is of ${dataType.simpleString} type."))
+ }
+
+ for( dataType <- validDataTypes;
+ frequencyType <- invalidFrequencyDataTypes) {
+ val child = AttributeReference("a", dataType)()
+ val frq = AttributeReference("frq", frequencyType)()
+ val percentile = new Percentile(child, percentage, frq)
+ assertEqual(percentile.checkInputDataTypes(),
+ TypeCheckFailure(s"argument 3 requires integral type, however, " +
+ s"'`frq`' is of ${frequencyType.simpleString} type."))
+ }
}
test("fails analysis if percentage(s) are invalid") {
@@ -217,19 +242,59 @@ class PercentileSuite extends SparkFunSuite {
}
test("null handling") {
+
+ // Percentile without frequency column
val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
val agg = new Percentile(childExpression, Literal(0.5))
val buffer = new GenericInternalRow(new Array[Any](1))
agg.initialize(buffer)
+
// Empty aggregation buffer
assert(agg.eval(buffer) == null)
+
// Empty input row
agg.update(buffer, InternalRow(null))
assert(agg.eval(buffer) == null)
- // Add some non-empty row
- agg.update(buffer, InternalRow(0))
- assert(agg.eval(buffer) != null)
+ // Percentile with Frequency column
+ val frequencyExpression = Cast(BoundReference(1, IntegerType, nullable = true), IntegerType)
+ val aggWithFrequency = new Percentile(childExpression, Literal(0.5), frequencyExpression)
+ val bufferWithFrequency = new GenericInternalRow(new Array[Any](2))
+ aggWithFrequency.initialize(bufferWithFrequency)
+
+ // Empty aggregation buffer
+ assert(aggWithFrequency.eval(bufferWithFrequency) == null)
+ // Empty input row
+ aggWithFrequency.update(bufferWithFrequency, InternalRow(null, null))
+ assert(aggWithFrequency.eval(bufferWithFrequency) == null)
+
+ // Add some non-empty row with empty frequency column
+ aggWithFrequency.update(bufferWithFrequency, InternalRow(0, null))
+ assert(aggWithFrequency.eval(bufferWithFrequency) == null)
+
+ // Add some non-empty row with zero frequency
+ aggWithFrequency.update(bufferWithFrequency, InternalRow(1, 0))
+ assert(aggWithFrequency.eval(bufferWithFrequency) == null)
+
+ // Add some non-empty row with positive frequency
+ aggWithFrequency.update(bufferWithFrequency, InternalRow(0, 1))
+ assert(aggWithFrequency.eval(bufferWithFrequency) != null)
+ }
+
+ test("negatives frequency column handling") {
+ val childExpression = Cast(BoundReference(0, IntegerType, nullable = true), DoubleType)
+ val freqExpression = Cast(BoundReference(1, IntegerType, nullable = true), IntegerType)
+ val agg = new Percentile(childExpression, Literal(0.5), freqExpression)
+ val buffer = new GenericInternalRow(new Array[Any](2))
+ agg.initialize(buffer)
+
+ val caught =
+ intercept[SparkException]{
+ // Add some non-empty row with negative frequency
+ agg.update(buffer, InternalRow(1, -5))
+ agg.eval(buffer)
+ }
+ assert( caught.getMessage.startsWith("Negative values found in "))
}
private def compareEquals(