From b362239df566bc949283f2ac195ee89af105605a Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Wed, 20 Jan 2016 15:24:01 -0800
Subject: [SPARK-12797] [SQL] Generated TungstenAggregate (without grouping keys)

As discussed in #10786, the generated TungstenAggregate does not support
imperative aggregate functions. For the query

```
sqlContext.range(10).filter("id > 1").groupBy().count()
```

the generated code will look like:

```
/* 032 */   if (!initAgg0) {
/* 033 */     initAgg0 = true;
/* 034 */
/* 035 */     // initialize aggregation buffer
/* 037 */     long bufValue2 = 0L;
/* 038 */
/* 039 */
/* 040 */     // initialize Range
/* 041 */     if (!range_initRange5) {
/* 042 */       range_initRange5 = true;
...
/* 071 */     }
/* 072 */
/* 073 */     while (!range_overflow8 && range_number7 < range_partitionEnd6) {
/* 074 */       long range_value9 = range_number7;
/* 075 */       range_number7 += 1L;
/* 076 */       if (range_number7 < range_value9 ^ 1L < 0) {
/* 077 */         range_overflow8 = true;
/* 078 */       }
/* 079 */
/* 085 */       boolean primitive11 = false;
/* 086 */       primitive11 = range_value9 > 1L;
/* 087 */       if (!false && primitive11) {
/* 092 */         // do aggregate and update aggregation buffer
/* 099 */         long primitive17 = -1L;
/* 100 */         primitive17 = bufValue2 + 1L;
/* 101 */         bufValue2 = primitive17;
/* 105 */       }
/* 107 */     }
/* 109 */
/* 110 */     // output the result
/* 112 */     bufferHolder25.reset();
/* 114 */     rowWriter26.initialize(bufferHolder25, 1);
/* 118 */     rowWriter26.write(0, bufValue2);
/* 120 */     result24.pointTo(bufferHolder25.buffer, bufferHolder25.totalSize());
/* 121 */     currentRow = result24;
/* 122 */     return;
/* 124 */   }
/* 125 */
```

cc nongli

Author: Davies Liu

Closes #10840 from davies/gen_agg.
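For intuition, the fused loop above is equivalent to the following plain Scala. This is a semantic sketch only, not Spark code: the generated Java keeps the running count in the local `bufValue2` instead of allocating per-row objects.

```
// What the generated code computes, written as ordinary Scala.
var bufValue: Long = 0L              // initialize aggregation buffer
for (id <- 0L until 10L) {           // the Range child produces ids 0..9
  if (id > 1L) {                     // the Filter condition, inlined into the loop
    bufValue = bufValue + 1L         // do aggregate and update aggregation buffer
  }
}
println(bufValue)                    // output the result: 8
```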
---
 .../spark/sql/execution/WholeStageCodegen.scala    | 12 ++-
 .../execution/aggregate/TungstenAggregate.scala    | 87 +++++++++++++++++++++-
 .../sql/execution/BenchmarkWholeStageCodegen.scala |  8 +-
 .../sql/execution/WholeStageCodegenSuite.scala     | 12 +++
 .../sql/execution/metric/SQLMetricsSuite.scala     |  4 +-
 5 files changed, 111 insertions(+), 12 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index c15fabab80..57f4945de9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -264,12 +264,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
  */
 private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {

+  private def supportCodegen(e: Expression): Boolean = e match {
+    case e: LeafExpression => true
+    // CodegenFallback requires the input to be an InternalRow
+    case e: CodegenFallback => false
+    case _ => true
+  }
+
   private def supportCodegen(plan: SparkPlan): Boolean = plan match {
     case plan: CodegenSupport if plan.supportCodegen =>
-      // Non-leaf with CodegenFallback does not work with whole stage codegen
-      val willFallback = plan.expressions.exists(
-        _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined
-      )
+      val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
       // the generated code will be huge if there are too many columns
       val haveManyColumns = plan.output.length > 200
       !willFallback && !haveManyColumns
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 8dcbab4c8c..23e54f344d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType

@@ -35,7 +36,7 @@ case class TungstenAggregate(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {

   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -113,6 +114,86 @@ case class TungstenAggregate(
     }
   }

+  override def supportCodegen: Boolean = {
+    groupingExpressions.isEmpty &&
+      // ImperativeAggregate is not supported right now
+      !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
+      // final aggregation only have one row, do not need to codegen
+      !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+  }
+
+  // The variables used as aggregation buffer
+  private var bufVars: Seq[ExprCode] = _
+
+  private val modes = aggregateExpressions.map(_.mode).distinct
+
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+    val initAgg = ctx.freshName("initAgg")
+    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+    // generate variables for aggregation buffer
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val initExpr = functions.flatMap(f => f.initialValues)
+    bufVars = initExpr.map { e =>
+      val isNull = ctx.freshName("bufIsNull")
+      val value = ctx.freshName("bufValue")
+      // The initial expression should not access any column
+      val ev = e.gen(ctx)
+      val initVars = s"""
+         | boolean $isNull = ${ev.isNull};
+         | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+       """.stripMargin
+      ExprCode(ev.code + initVars, isNull, value)
+    }
+
+    val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+    val source =
+      s"""
+       | if (!$initAgg) {
+       |   $initAgg = true;
+       |
+       |   // initialize aggregation buffer
+       |   ${bufVars.map(_.code).mkString("\n")}
+       |
+       |   $childSource
+       |
+       |   // output the result
+       |   ${consume(ctx, bufVars)}
+       | }
+     """.stripMargin
+
+    (rdd, source)
+  }
+
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+    // only have DeclarativeAggregate
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    // the mode could be only Partial or PartialMerge
+    val updateExpr = if (modes.contains(Partial)) {
+      functions.flatMap(_.updateExpressions)
+    } else {
+      functions.flatMap(_.mergeExpressions)
+    }
+
+    val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
+    val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
+    ctx.currentVars = bufVars ++ input
+    // TODO: support subexpression elimination
+    val codes = boundExpr.zipWithIndex.map { case (e, i) =>
+      val ev = e.gen(ctx)
+      s"""
+       | ${ev.code}
+       | ${bufVars(i).isNull} = ${ev.isNull};
+       | ${bufVars(i).value} = ${ev.value};
+     """.stripMargin
+    }
+
+    s"""
+     | // do aggregate and update aggregation buffer
+     | ${codes.mkString("")}
+   """.stripMargin
+  }
+
   override def simpleString: String = {
     val allAggregateExpressions = aggregateExpressions
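The doProduce/doConsume pair above leans entirely on the DeclarativeAggregate contract: doProduce seeds one local buffer variable per initialValues entry, and doConsume folds either updateExpressions (Partial) or mergeExpressions (PartialMerge) into those variables for every row. As a rough illustration, here is a toy model of that contract; ToyDeclarativeAgg and everything below are invented stand-ins, since the real lists are Catalyst expression trees compiled into the buffer assignments, not Scala closures:

```
// Toy stand-in (not Spark's API) for the three expression lists used above.
case class ToyDeclarativeAgg[B](
    initialValue: B,            // initialValues: seeds the bufVars in doProduce
    update: (B, Long) => B,     // updateExpressions: applied per row in Partial mode
    merge: (B, B) => B)         // mergeExpressions: applied in PartialMerge mode

val count = ToyDeclarativeAgg[Long](0L, (buf, _) => buf + 1L, _ + _)

// Partial mode folds `update` over one partition's rows ...
val partial = (0L until 10L).filter(_ > 1L).foldLeft(count.initialValue)(count.update)
// ... and PartialMerge combines the buffers produced by different partitions.
val total = Seq(partial, partial).reduce(count.merge)   // 16 for two such partitions
```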
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 788b04fcf8..c4aad398bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -46,10 +46,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {

     /*
       Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-      Single Int Column Scan:      Avg Time(ms)    Avg Rate(M/s)  Relative Rate
-      -------------------------------------------------------------------------
-      Without whole stage codegen       6725.52            31.18         1.00 X
-      With whole stage codegen          2233.05            93.91         3.01 X
+      Single Int Column Scan:            Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      Without whole stage codegen             7775.53            26.97         1.00 X
+      With whole stage codegen                 342.15           612.94        22.73 X
     */
     benchmark.run()
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index c54fc6ba2d..300788c88a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,10 @@

 package org.apache.spark.sql.execution

+import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.functions.{avg, col, max}
 import org.apache.spark.sql.test.SharedSQLContext

 class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
@@ -35,4 +38,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       sortAnswers = false
     )
   }
+
+  test("Aggregate should be included in WholeStageCodegen") {
+    val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+    assert(df.collect() === Array(Row(9, 4.5)))
+  }
 }
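The new suite test can also be checked interactively. A hypothetical spark-shell session against this build (assuming whole-stage codegen is enabled, i.e. spark.sql.codegen.wholeStage has not been set to false):

```
val df = sqlContext.range(10).filter("id > 1").groupBy().count()
df.explain()    // the physical plan should show the partial TungstenAggregate
                // wrapped in a WholeStageCodegen node
df.collect()    // Array([8])
```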
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 4339f7260d..51285431a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -71,7 +71,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
       expectedNumOfJobs: Int,
       expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
     val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
-    df.collect()
+    withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+      df.collect()
+    }
     sparkContext.listenerBus.waitUntilEmpty(10000)
     val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
     assert(executionIds.size === 1)
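Finally, a self-contained sketch of the expression check that CollapseCodegenStages performs in the first hunk of this patch. Expr, Leaf, Fallback, and Add below are invented stand-ins for Catalyst's Expression, LeafExpression, and CodegenFallback; the point is that `find` locates any non-leaf CodegenFallback, which would need a materialized InternalRow that the fused pipeline never builds:

```
// Simplified stand-ins for Expression, LeafExpression and CodegenFallback.
sealed trait Expr {
  def children: Seq[Expr]
  def find(p: Expr => Boolean): Option[Expr] =
    if (p(this)) Some(this) else children.flatMap(_.find(p)).headOption
}
case class Leaf(name: String) extends Expr { val children = Seq.empty }
case class Fallback(children: Seq[Expr]) extends Expr
case class Add(left: Expr, right: Expr) extends Expr { val children = Seq(left, right) }

// Mirrors the new supportCodegen(e: Expression): leaf expressions are fine,
// any other CodegenFallback forces the stage to fall back.
def supportCodegen(e: Expr): Boolean = e match {
  case _: Leaf => true
  case _: Fallback => false
  case _ => true
}

val exprs = Seq(Add(Leaf("a"), Fallback(Seq(Leaf("b")))))
val willFallback = exprs.exists(_.find(e => !supportCodegen(e)).isDefined)
// willFallback == true: this plan would be excluded from WholeStageCodegen
```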