Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala            12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala  87
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala     8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala        12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala         4
5 files changed, 111 insertions, 12 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index c15fabab80..57f4945de9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -264,12 +264,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
*/
private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {
+ private def supportCodegen(e: Expression): Boolean = e match {
+ case e: LeafExpression => true
+ // CodegenFallback requires the input to be an InternalRow
+ case e: CodegenFallback => false
+ case _ => true
+ }
+
private def supportCodegen(plan: SparkPlan): Boolean = plan match {
case plan: CodegenSupport if plan.supportCodegen =>
- // Non-leaf with CodegenFallback does not work with whole stage codegen
- val willFallback = plan.expressions.exists(
- _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined
- )
+ val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
// the generated code will be huge if there are too many columns
val haveManyColumns = plan.output.length > 200
!willFallback && !haveManyColumns
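
The rewritten check relies on Expression.find (inherited from TreeNode), which does a pre-order traversal and returns the first node matching the predicate, so the scan short-circuits on the first unsupported sub-expression. Leaf CodegenFallback expressions stay eligible because they read no input columns and therefore never need the InternalRow that whole-stage codegen does not materialize. A simplified model of the traversal (illustration only, not the real TreeNode):

    trait Node { def children: Seq[Node] }

    // Pre-order search: test the node itself, then recurse into children,
    // stopping at the first match (orElse takes its argument by name).
    def find(n: Node)(p: Node => Boolean): Option[Node] =
      if (p(n)) Some(n)
      else n.children.foldLeft(Option.empty[Node])((acc, c) => acc.orElse(find(c)(p)))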
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 8dcbab4c8c..23e54f344d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType
@@ -35,7 +36,7 @@ case class TungstenAggregate(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
- extends UnaryNode {
+ extends UnaryNode with CodegenSupport {
private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -113,6 +114,86 @@ case class TungstenAggregate(
}
}
+ override def supportCodegen: Boolean = {
+ groupingExpressions.isEmpty &&
+ // ImperativeAggregate is not supported right now
+ !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) &&
+ // the final aggregation produces only one row, so codegen is not needed
+ !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+ }
+
+ // The variables used as aggregation buffer
+ private var bufVars: Seq[ExprCode] = _
+
+ private val modes = aggregateExpressions.map(_.mode).distinct
+
+ protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+ val initAgg = ctx.freshName("initAgg")
+ ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+ // generate variables for aggregation buffer
+ val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+ val initExpr = functions.flatMap(f => f.initialValues)
+ bufVars = initExpr.map { e =>
+ val isNull = ctx.freshName("bufIsNull")
+ val value = ctx.freshName("bufValue")
+ // The initial expression should not access any column
+ val ev = e.gen(ctx)
+ val initVars = s"""
+ | boolean $isNull = ${ev.isNull};
+ | ${ctx.javaType(e.dataType)} $value = ${ev.value};
+ """.stripMargin
+ ExprCode(ev.code + initVars, isNull, value)
+ }
+
+ val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ val source =
+ s"""
+ | if (!$initAgg) {
+ | $initAgg = true;
+ |
+ | // initialize aggregation buffer
+ | ${bufVars.map(_.code).mkString("\n")}
+ |
+ | $childSource
+ |
+ | // output the result
+ | ${consume(ctx, bufVars)}
+ | }
+ """.stripMargin
+
+ (rdd, source)
+ }
+
+ override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+ // only DeclarativeAggregate functions can appear here (ensured by supportCodegen)
+ val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+ // the mode can only be Partial or PartialMerge here
+ val updateExpr = if (modes.contains(Partial)) {
+ functions.flatMap(_.updateExpressions)
+ } else {
+ functions.flatMap(_.mergeExpressions)
+ }
+
+ val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
+ val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
+ ctx.currentVars = bufVars ++ input
+ // TODO: support subexpression elimination
+ val codes = boundExpr.zipWithIndex.map { case (e, i) =>
+ val ev = e.gen(ctx)
+ s"""
+ | ${ev.code}
+ | ${bufVars(i).isNull} = ${ev.isNull};
+ | ${bufVars(i).value} = ${ev.value};
+ """.stripMargin
+ }
+
+ s"""
+ // evaluate the aggregate functions and update the aggregation buffer
+ | ${codes.mkString("")}
+ """.stripMargin
+ }
+
override def simpleString: String = {
val allAggregateExpressions = aggregateExpressions
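
Conceptually, for a grouping-free partial sum(id), the doProduce/doConsume pair above emits Java whose control flow matches the hand-written Scala below: the aggregation buffer lives in local variables, is initialized once behind the $initAgg flag, is updated once per input row by the bound updateExpressions, and is emitted as a single row only after the child input is exhausted. A sketch with invented names, not the actual generated code:

    var initAgg = false
    var bufIsNull = true   // aggregation buffer kept in local variables
    var bufValue = 0L

    def processPartition(iter: Iterator[Long]): Iterator[(Boolean, Long)] = {
      if (!initAgg) {
        initAgg = true
        // initialize aggregation buffer (the function's initialValues)
        bufIsNull = true
        bufValue = 0L
        // consume all child input before producing anything
        while (iter.hasNext) {
          val id = iter.next()
          // updateExpressions for Sum: sum = sum + id (null handling elided)
          bufValue = (if (bufIsNull) 0L else bufValue) + id
          bufIsNull = false
        }
      }
      // output the single result row
      Iterator.single((bufIsNull, bufValue))
    }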
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 788b04fcf8..c4aad398bf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -46,10 +46,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
/*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
- -------------------------------------------------------------------------
- Without whole stage codegen 6725.52 31.18 1.00 X
- With whole stage codegen 2233.05 93.91 3.01 X
+ Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------------
+ Without whole stage codegen 7775.53 26.97 1.00 X
+ With whole stage codegen 342.15 612.94 22.73 X
*/
benchmark.run()
}
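
For reference, the suite follows the pattern sketched below (simplified; it assumes Spark's internal org.apache.spark.util.Benchmark helper, and the query body is illustrative rather than the exact one in the file):

    val N = 200 << 20
    val benchmark = new Benchmark("Single Int Column Scan", N)

    benchmark.addCase("Without whole stage codegen") { _ =>
      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
      sqlContext.range(N).groupBy().sum().collect()
    }
    benchmark.addCase("With whole stage codegen") { _ =>
      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
      sqlContext.range(N).groupBy().sum().collect()
    }
    benchmark.run()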
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index c54fc6ba2d..300788c88a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.execution
+import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.functions.{avg, col, max}
import org.apache.spark.sql.test.SharedSQLContext
class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
@@ -35,4 +38,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
sortAnswers = false
)
}
+
+ test("Aggregate should be included in WholeStageCodegen") {
+ val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
+ val plan = df.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+ assert(df.collect() === Array(Row(9, 4.5)))
+ }
}
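
The same check can be done interactively; a quick sketch (assumes a shell with sqlContext in scope):

    import org.apache.spark.sql.functions.{avg, col, max}

    val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
    // Walk the physical plan; a WholeStageCodegen node wrapping
    // TungstenAggregate should appear.
    df.queryExecution.executedPlan.foreach(p => println(p.getClass.getSimpleName))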
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 4339f7260d..51285431a4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -71,7 +71,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
expectedNumOfJobs: Int,
expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
- df.collect()
+ withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
+ df.collect()
+ }
sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)