author     Davies Liu <davies@databricks.com>   2016-01-20 15:24:01 -0800
committer  Davies Liu <davies.liu@gmail.com>    2016-01-20 15:24:01 -0800
commit     b362239df566bc949283f2ac195ee89af105605a (patch)
tree       54f85250f801c12a3102aa4bcfacaee9d5a08c4c
parent     10173279305a0e8a62bfbfe7a9d5d1fd558dd8e1 (diff)
[SPARK-12797] [SQL] Generated TungstenAggregate (without grouping keys)
As discussed in #10786, the generated TungstenAggregate does not support imperative functions. For a query

```
sqlContext.range(10).filter("id > 1").groupBy().count()
```

the generated code looks like:

```
/* 032 */   if (!initAgg0) {
/* 033 */     initAgg0 = true;
/* 034 */
/* 035 */     // initialize aggregation buffer
/* 037 */     long bufValue2 = 0L;
/* 038 */
/* 039 */
/* 040 */     // initialize Range
/* 041 */     if (!range_initRange5) {
/* 042 */       range_initRange5 = true;
...
/* 071 */     }
/* 072 */
/* 073 */     while (!range_overflow8 && range_number7 < range_partitionEnd6) {
/* 074 */       long range_value9 = range_number7;
/* 075 */       range_number7 += 1L;
/* 076 */       if (range_number7 < range_value9 ^ 1L < 0) {
/* 077 */         range_overflow8 = true;
/* 078 */       }
/* 079 */
/* 085 */       boolean primitive11 = false;
/* 086 */       primitive11 = range_value9 > 1L;
/* 087 */       if (!false && primitive11) {
/* 092 */         // do aggregate and update aggregation buffer
/* 099 */         long primitive17 = -1L;
/* 100 */         primitive17 = bufValue2 + 1L;
/* 101 */         bufValue2 = primitive17;
/* 105 */       }
/* 107 */     }
/* 109 */
/* 110 */     // output the result
/* 112 */     bufferHolder25.reset();
/* 114 */     rowWriter26.initialize(bufferHolder25, 1);
/* 118 */     rowWriter26.write(0, bufValue2);
/* 120 */     result24.pointTo(bufferHolder25.buffer, bufferHolder25.totalSize());
/* 121 */     currentRow = result24;
/* 122 */     return;
/* 124 */   }
/* 125 */
```

cc nongli

Author: Davies Liu <davies@databricks.com>

Closes #10840 from davies/gen_agg.
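To see this path in action, a minimal sketch (assuming a `SQLContext` named `sqlContext` with whole-stage codegen enabled, the default on this branch):

```
// Build the query from the commit message and inspect its physical plan.
val df = sqlContext.range(10).filter("id > 1").groupBy().count()

// The aggregate should now appear inside a WholeStageCodegen node.
df.explain()

// ids 2..9 pass the filter, so the single-row result is a count of 8.
df.collect()  // Array([8])
```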
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala            12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala  87
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala    8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala       12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala        4
5 files changed, 111 insertions(+), 12 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index c15fabab80..57f4945de9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -264,12 +264,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
*/
private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
+ private def supportCodegen(e: Expression): Boolean = e match {
+ case e: LeafExpression => true
+ // CodegenFallback requires the input to be an InternalRow
+ case e: CodegenFallback => false
+ case _ => true
+ }
+
private def supportCodegen(plan: SparkPlan): Boolean = plan match {
case plan: CodegenSupport if plan.supportCodegen =>
- // Non-leaf with CodegenFallback does not work with whole stage codegen
- val willFallback = plan.expressions.exists(
- _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined
- )
+ val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
// the generated code will be huge if there are too many columns
val haveManyColumns = plan.output.length > 200
!willFallback && !haveManyColumns
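The tightened rule above can be checked end to end by looking for the collapsed node in an executed plan; a small sketch, assuming `df` is any DataFrame whose operators all support codegen:

```
import org.apache.spark.sql.execution.WholeStageCodegen

// CollapseCodegenStages should have wrapped the supported operators
// in a single WholeStageCodegen node.
val collapsed = df.queryExecution.executedPlan.find(_.isInstanceOf[WholeStageCodegen])
assert(collapsed.isDefined, "expected a WholeStageCodegen node in the plan")
```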
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 8dcbab4c8c..23e54f344d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
@@ -35,7 +36,7 @@ case class TungstenAggregate(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends UnaryNode {
+ extends UnaryNode with CodegenSupport {
private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -113,6 +114,86 @@ case class TungstenAggregate(
}
}
+ override def supportCodegen: Boolean = {
+ groupingExpressions.isEmpty &&
+ // ImperativeAggregate is not supported right now
+ !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
+ // final aggregation only have one row, do not need to codegen
+ !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+ }
+
+ // The variables used as aggregation buffer
+ private var bufVars: Seq[ExprCode] = _
+
+ private val modes = aggregateExpressions.map(_.mode).distinct
+
+ protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ val initAgg = ctx.freshName("initAgg")
+ ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+ // generate variables for aggregation buffer
+ val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+ val initExpr = functions.flatMap(f => f.initialValues)
+ bufVars = initExpr.map { e =>
+ val isNull = ctx.freshName("bufIsNull")
+ val value = ctx.freshName("bufValue")
+ // The initial expression should not access any column
+ val ev = e.gen(ctx)
+ val initVars = s"""
+ | boolean $isNull = ${ev.isNull};
+ | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+ """.stripMargin
+ ExprCode(ev.code + initVars, isNull, value)
+ }
+
+ val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ val source =
+ s"""
+ | if (!$initAgg) {
+ | $initAgg = true;
+ |
+ | // initialize aggregation buffer
+ | ${bufVars.map(_.code).mkString("\n")}
+ |
+ | $childSource
+ |
+ | // output the result
+ | ${consume(ctx, bufVars)}
+ | }
+ """.stripMargin
+
+ (rdd, source)
+ }
+
+ override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ // only have DeclarativeAggregate
+ val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+ // the mode could be only Partial or PartialMerge
+ val updateExpr = if (modes.contains(Partial)) {
+ functions.flatMap(_.updateExpressions)
+ } else {
+ functions.flatMap(_.mergeExpressions)
+ }
+
+ val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
+ val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
+ ctx.currentVars = bufVars ++ input
+ // TODO: support subexpression elimination
+ val codes = boundExpr.zipWithIndex.map { case (e, i) =>
+ val ev = e.gen(ctx)
+ s"""
+ | ${ev.code}
+ | ${bufVars(i).isNull} = ${ev.isNull};
+ | ${bufVars(i).value} = ${ev.value};
+ """.stripMargin
+ }
+
+ s"""
+ | // do aggregate and update aggregation buffer
+ | ${codes.mkString("")}
+ """.stripMargin
+ }
+
override def simpleString: String = {
val allAggregateExpressions = aggregateExpressions
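The `supportCodegen` override above restricts compilation to grouping-free, declarative, partial-mode aggregates. A hedged sketch of which user queries qualify under that contract (using only the public DataFrame API):

```
import org.apache.spark.sql.functions.{col, max}

// Eligible: no grouping keys, and max is a DeclarativeAggregate.
sqlContext.range(100).groupBy().agg(max(col("id")))

// Not eligible: grouping expressions are present, so this still runs
// through the interpreted TungstenAggregate path.
sqlContext.range(100).groupBy(col("id") % 2).count()
```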
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 788b04fcf8..c4aad398bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -46,10 +46,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
/*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
- -------------------------------------------------------------------------
- Without whole stage codegen 6725.52 31.18 1.00 X
- With whole stage codegen 2233.05 93.91 3.01 X
+ Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------------
+ Without whole stage codegen 7775.53 26.97 1.00 X
+ With whole stage codegen 342.15 612.94 22.73 X
*/
benchmark.run()
}
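The updated numbers come from re-running the suite's micro-benchmark. A sketch of how such a case is wired up, approximating the shape of Spark's internal `Benchmark` test helper from this era (an internal API, not a stable interface):

```
import org.apache.spark.util.Benchmark

// Assumes `sqlContext` is in scope, as in the suite.
val N = 200 << 20
val benchmark = new Benchmark("Single Int Column Scan", N)

benchmark.addCase("Without whole stage codegen") { _ =>
  sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
  sqlContext.range(N).filter("id > 1").groupBy().count().collect()
}

benchmark.addCase("With whole stage codegen") { _ =>
  sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
  sqlContext.range(N).filter("id > 1").groupBy().count().collect()
}

benchmark.run()
```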
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index c54fc6ba2d..300788c88a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.functions.{avg, col, max}
import org.apache.spark.sql.test.SharedSQLContext
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
@@ -35,4 +38,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
sortAnswers = false
)
}
+
+ test("Aggregate should be included in WholeStageCodegen") {
+ val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+ val plan = df.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+ assert(df.collect() === Array(Row(9, 4.5)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 4339f7260d..51285431a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -71,7 +71,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
expectedNumOfJobs: Int,
expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
- df.collect()
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ df.collect()
+ }
sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
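The suite disables whole-stage codegen here because collapsing operators into a single compiled stage changes the plan shape and hides the per-operator SQLMetrics these tests assert on. Outside of `SQLTestUtils.withSQLConf`, the same switch can be flipped and restored by hand; a minimal sketch:

```
// Turn off whole-stage codegen so each operator (and its metrics)
// stays visible in the executed plan, then restore the old setting.
val key = "spark.sql.codegen.wholeStage"
val previous = sqlContext.getConf(key, "true")
sqlContext.setConf(key, "false")
try {
  df.collect()
} finally {
  sqlContext.setConf(key, previous)
}
```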