aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorVenkata Ramana Gollamudi <ramana.gollamudi@huawei.com>2015-04-08 18:42:34 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-08 18:42:34 -0700
commit7d7384c781ea72e1eabab3daca2e237e3b0fc666 (patch)
tree5b0494d92c48592baca16cc53066174f6a4590ce /sql
parent9418280547f962eaf309bfff9986cdd848409643 (diff)
downloadspark-7d7384c781ea72e1eabab3daca2e237e3b0fc666.tar.gz
spark-7d7384c781ea72e1eabab3daca2e237e3b0fc666.tar.bz2
spark-7d7384c781ea72e1eabab3daca2e237e3b0fc666.zip
[SPARK-6451][SQL] supported code generation for CombineSum
Author: Venkata Ramana Gollamudi <ramana.gollamudi@huawei.com> Closes #5138 from gvramana/sum_fix_codegen and squashes the following commits: 95f5fe4 [Venkata Ramana Gollamudi] rebase merge changes 12f45a5 [Venkata Ramana Gollamudi] Combined and added code generations tests as per comment d6a76ac [Venkata Ramana Gollamudi] added support for codegeneration for CombineSum and tests
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala44
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala92
3 files changed, 133 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index a8018b9213..861a2c21ad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -99,7 +99,10 @@ case class GeneratedAggregate(
// but really, common sub expression elimination would be better....
val zero = Cast(Literal(0), calcType)
val updateFunction = Coalesce(
- Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil)
+ Add(
+ Coalesce(currentSum :: zero :: Nil),
+ Cast(expr, calcType)
+ ) :: currentSum :: zero :: Nil)
val result =
expr.dataType match {
case DecimalType.Fixed(_, _) =>
@@ -109,6 +112,45 @@ case class GeneratedAggregate(
AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+ case cs @ CombineSum(expr) =>
+ // Merges partial sums; fixed-precision decimals are accumulated as unlimited decimals.
+ val calcType =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ DecimalType.Unlimited
+ case _ =>
+ expr.dataType
+ }
+ val currentSum = AttributeReference("currentSum", calcType, nullable = true)()
+ val initialValue = Literal.create(null, calcType)
+
+ // Coalesce avoids double calculation...
+ // but really, common sub expression elimination would be better....
+ val zero = Cast(Literal(0), calcType)
+ // If we're evaluating UnscaledValue(x), we can null-check x directly, since its
+ // UnscaledValue will be null if and only if x is null; helps with Average on decimals
+ val actualExpr = expr match {
+ case UnscaledValue(e) => e
+ case _ => expr
+ }
+ // partial sum result can be null only when no input rows present
+ val updateFunction = If(
+ IsNotNull(actualExpr),
+ Coalesce(
+ Add(
+ Coalesce(currentSum :: zero :: Nil),
+ Cast(expr, calcType)) :: currentSum :: zero :: Nil),
+ currentSum)
+
+ val result =
+ expr.dataType match {
+ case DecimalType.Fixed(_, _) =>
+ Cast(currentSum, cs.dataType)
+ case _ => currentSum
+ }
+
+ AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result)
+
case a @ Average(expr) =>
val calcType =
expr.dataType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f754fa770d..23f7e56094 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -155,7 +155,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
- case _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
+ case _: CombineSum | _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
Seq(IntegerType, LongType).contains(exprs.head.dataType) => false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 87e7cf8c8a..1ad92a3941 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import org.apache.spark.sql.test.TestSQLContext
import org.scalatest.BeforeAndAfterAll
+import org.apache.spark.sql.execution.GeneratedAggregate
import org.apache.spark.sql.functions._
import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
@@ -102,14 +103,99 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
sql("SELECT ABS(2.5)"),
Row(2.5))
}
-
+
test("aggregation with codegen") {
val originalValue = conf.codegenEnabled
setConf(SQLConf.CODEGEN_ENABLED, "true")
- sql("SELECT key FROM testData GROUP BY key").collect()
+ // Undo the temp table and the conf change in `finally`, even if a check below fails.
+ try {
+ // Prepare a table that we can group some rows.
+ table("testData")
+ .unionAll(table("testData"))
+ .unionAll(table("testData"))
+ .registerTempTable("testData3x")
+ // Runs sqlText, fails unless the executed plan has a GeneratedAggregate, then checks rows.
+ def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = {
+ val df = sql(sqlText)
+ // First, check if we have GeneratedAggregate.
+ val hasGeneratedAgg = df.queryExecution.executedPlan
+ .collect { case generatedAgg: GeneratedAggregate => generatedAgg }.nonEmpty
+ if (!hasGeneratedAgg) {
+ fail(
+ s"""
+ |Codegen is enabled, but query $sqlText does not have GeneratedAggregate in the plan.
+ |${df.queryExecution.simpleString}
+ """.stripMargin)
+ }
+ // Then, check results.
+ checkAnswer(df, expectedResults)
+ }
+
+ // Just to group rows.
+ testCodeGen(
+ "SELECT key FROM testData3x GROUP BY key",
+ (1 to 100).map(Row(_)))
+ // COUNT
+ testCodeGen(
+ "SELECT key, count(value) FROM testData3x GROUP BY key",
+ (1 to 100).map(i => Row(i, 3)))
+ testCodeGen(
+ "SELECT count(key) FROM testData3x",
+ Row(300) :: Nil)
+ // COUNT DISTINCT ON int
+ testCodeGen(
+ "SELECT value, count(distinct key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 1)))
+ testCodeGen(
+ "SELECT count(distinct key) FROM testData3x",
+ Row(100) :: Nil)
+ // SUM
+ testCodeGen(
+ "SELECT value, sum(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, 3 * i)))
+ testCodeGen(
+ "SELECT sum(key), SUM(CAST(key as Double)) FROM testData3x",
+ Row(5050 * 3, 5050 * 3.0) :: Nil)
+ // AVERAGE
+ testCodeGen(
+ "SELECT value, avg(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT avg(key) FROM testData3x",
+ Row(50.5) :: Nil)
+ // MAX
+ testCodeGen(
+ "SELECT value, max(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT max(key) FROM testData3x",
+ Row(100) :: Nil)
+ // Some combinations.
+ testCodeGen(
+ """
+ |SELECT
+ | value,
+ | sum(key),
+ | max(key),
+ | avg(key),
+ | count(key),
+ | count(distinct key)
+ |FROM testData3x
+ |GROUP BY value
+ """.stripMargin,
+ (1 to 100).map(i => Row(i.toString, i*3, i, i, 3, 1)))
+ testCodeGen(
+ "SELECT max(key), avg(key), count(key), count(distinct key) FROM testData3x",
+ Row(100, 50.5, 300, 100) :: Nil)
+ // Aggregate with Code generation handling all null values
+ testCodeGen(
+ "SELECT sum('a'), avg('a'), count(null) FROM testData",
+ Row(0, null, 0) :: Nil)
+ } finally {
+ dropTempTable("testData3x")
setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ }
}
-
test("Add Parser of SQL COALESCE()") {
checkAnswer(
sql("""SELECT COALESCE(1, 2)"""),