From e6a02c66d53f59ba2d5c1548494ae80a385f9f5c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 29 Jan 2016 20:16:11 -0800 Subject: [SPARK-12914] [SQL] generate aggregation with grouping keys This PR add support for grouping keys for generated TungstenAggregate. Spilling and performance improvements for BytesToBytesMap will be done by followup PR. Author: Davies Liu Closes #10855 from davies/gen_keys. --- .../expressions/codegen/CodeGenerator.scala | 47 ++++ .../codegen/GenerateMutableProjection.scala | 27 +-- .../spark/sql/execution/BufferedRowIterator.java | 6 +- .../execution/aggregate/TungstenAggregate.scala | 238 +++++++++++++++++++-- .../sql/execution/BenchmarkWholeStageCodegen.scala | 119 +++++++++-- .../sql/execution/WholeStageCodegenSuite.scala | 9 + 6 files changed, 393 insertions(+), 53 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e6704cf8bb..21f9198073 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -55,6 +55,20 @@ class CodegenContext { */ val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() + /** + * Add an object to `references`, create a class member to access it. + * + * Returns the name of class member. + */ + def addReferenceObj(name: String, obj: Any, className: String = null): String = { + val term = freshName(name) + val idx = references.length + references += obj + val clsName = Option(className).getOrElse(obj.getClass.getName) + addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];") + term + } + /** * Holding a list of generated columns as input of current operator, will be used by * BoundReference to generate code. @@ -198,6 +212,39 @@ class CodegenContext { } } + /** + * Update a column in MutableRow from ExprCode. + */ + def updateColumn( + row: String, + dataType: DataType, + ordinal: Int, + ev: ExprCode, + nullable: Boolean): String = { + if (nullable) { + // Can't call setNullAt on DecimalType, because we need to keep the offset + if (dataType.isInstanceOf[DecimalType]) { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + ${setColumn(row, dataType, ordinal, "null")}; + } + """ + } else { + s""" + if (!${ev.isNull}) { + ${setColumn(row, dataType, ordinal, ev.value)}; + } else { + $row.setNullAt($ordinal); + } + """ + } + } else { + s"""${setColumn(row, dataType, ordinal, ev.value)};""" + } + } + /** * Returns the name used in accessor and setter for a Java primitive type. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index ec31db19b9..5b4dc8df86 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -88,31 +88,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu val updates = validExpr.zip(index).map { case (e, i) => - if (e.nullable) { - if (e.dataType.isInstanceOf[DecimalType]) { - // Can't call setNullAt on DecimalType, because we need to keep the offset - s""" - if (this.isNull_$i) { - ${ctx.setColumn("mutableRow", e.dataType, i, "null")}; - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } else { - s""" - if (this.isNull_$i) { - mutableRow.setNullAt($i); - } else { - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - } - """ - } - } else { - s""" - ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; - """ - } - + val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i") + ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable) } val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java index b1bbb1da10..6acf70dbba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution; +import java.io.IOException; + import scala.collection.Iterator; import org.apache.spark.sql.catalyst.InternalRow; @@ -34,7 +36,7 @@ public class BufferedRowIterator { // used when there is no column in output protected UnsafeRow unsafeRow = new UnsafeRow(0); - public boolean hasNext() { + public boolean hasNext() throws IOException { if (currentRow == null) { processNext(); } @@ -56,7 +58,7 @@ public class BufferedRowIterator { * * After it's called, if currentRow is still null, it means no more rows left. */ - protected void processNext() { + protected void processNext() throws IOException { if (input.hasNext()) { currentRow = input.next(); } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index ff2f38bfd9..57db7262fd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -17,16 +17,18 @@ package org.apache.spark.sql.execution.aggregate +import org.apache.spark.TaskContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DecimalType, StructType} +import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], @@ -114,22 +116,38 @@ case class TungstenAggregate( } } + // all the mode of aggregate expressions + private val modes = aggregateExpressions.map(_.mode).distinct + override def supportCodegen: Boolean = { - groupingExpressions.isEmpty && - // ImperativeAggregate is not supported right now - !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + // ImperativeAggregate is not supported right now + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) } - // The variables used as aggregation buffer - private var bufVars: Seq[ExprCode] = _ - - private val modes = aggregateExpressions.map(_.mode).distinct - override def upstream(): RDD[InternalRow] = { child.asInstanceOf[CodegenSupport].upstream() } protected override def doProduce(ctx: CodegenContext): String = { + if (groupingExpressions.isEmpty) { + doProduceWithoutKeys(ctx) + } else { + doProduceWithKeys(ctx) + } + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + if (groupingExpressions.isEmpty) { + doConsumeWithoutKeys(ctx, input) + } else { + doConsumeWithKeys(ctx, input) + } + } + + // The variables used as aggregation buffer + private var bufVars: Seq[ExprCode] = _ + + private def doProduceWithoutKeys(ctx: CodegenContext): String = { val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") @@ -176,10 +194,10 @@ case class TungstenAggregate( (resultVars, resultVars.map(_.code).mkString("\n")) } - val doAgg = ctx.freshName("doAgg") + val doAgg = ctx.freshName("doAggregateWithoutKey") ctx.addNewFunction(doAgg, s""" - | private void $doAgg() { + | private void $doAgg() throws java.io.IOException { | // initialize aggregation buffer | ${bufVars.map(_.code).mkString("\n")} | @@ -200,7 +218,7 @@ case class TungstenAggregate( """.stripMargin } - override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = { + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { // only have DeclarativeAggregate val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output @@ -212,7 +230,6 @@ case class TungstenAggregate( e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions } } - ctx.currentVars = bufVars ++ input // TODO: support subexpression elimination val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx)) @@ -232,6 +249,199 @@ case class TungstenAggregate( """.stripMargin } + private val groupingAttributes = groupingExpressions.map(_.toAttribute) + private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + private val declFunctions = aggregateExpressions.map(_.aggregateFunction) + .filter(_.isInstanceOf[DeclarativeAggregate]) + .map(_.asInstanceOf[DeclarativeAggregate]) + private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes) + private val bufferSchema = StructType.fromAttributes(bufferAttributes) + + // The name for HashMap + private var hashMapTerm: String = _ + + /** + * This is called by generated Java class, should be public. + */ + def createHashMap(): UnsafeFixedWidthAggregationMap = { + // create initialized aggregate buffer + val initExpr = declFunctions.flatMap(f => f.initialValues) + val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow) + + // create hashMap + new UnsafeFixedWidthAggregationMap( + initialBuffer, + bufferSchema, + groupingKeySchema, + TaskContext.get().taskMemoryManager(), + 1024 * 16, // initial capacity + TaskContext.get().taskMemoryManager().pageSizeBytes, + false // disable tracking of performance metrics + ) + } + + /** + * This is called by generated Java class, should be public. + */ + def createUnsafeJoiner(): UnsafeRowJoiner = { + GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + } + + + /** + * Update peak execution memory, called in generated Java class. + */ + def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = { + val mapMemory = hashMap.getPeakMemoryUsedBytes + val metrics = TaskContext.get().taskMetrics() + metrics.incPeakExecutionMemory(mapMemory) + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + // create hashMap + val thisPlan = ctx.addReferenceObj("plan", this) + hashMapTerm = ctx.freshName("hashMap") + val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName + ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + + // Create a name for iterator from HashMap + val iterTerm = ctx.freshName("mapIter") + ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "") + + // generate code for output + val keyTerm = ctx.freshName("aggKey") + val bufferTerm = ctx.freshName("aggBuffer") + val outputCode = if (modes.contains(Final) || modes.contains(Complete)) { + // generate output using resultExpressions + ctx.currentVars = null + ctx.INPUT_ROW = keyTerm + val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + ctx.INPUT_ROW = bufferTerm + val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).gen(ctx) + } + // evaluate the aggregation result + ctx.currentVars = bufferVars + val aggResults = declFunctions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, bufferAttributes).gen(ctx) + } + // generate the final result + ctx.currentVars = keyVars ++ aggResults + val inputAttrs = groupingAttributes ++ aggregateAttributes + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, inputAttrs).gen(ctx) + } + s""" + ${keyVars.map(_.code).mkString("\n")} + ${bufferVars.map(_.code).mkString("\n")} + ${aggResults.map(_.code).mkString("\n")} + ${resultVars.map(_.code).mkString("\n")} + + ${consume(ctx, resultVars)} + """ + + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // This should be the last operator in a stage, we should output UnsafeRow directly + val joinerTerm = ctx.freshName("unsafeRowJoiner") + ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, + s"$joinerTerm = $thisPlan.createUnsafeJoiner();") + val resultRow = ctx.freshName("resultRow") + s""" + UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); + ${consume(ctx, null, resultRow)} + """ + + } else { + // generate result based on grouping key + ctx.INPUT_ROW = keyTerm + ctx.currentVars = null + val eval = resultExpressions.map{ e => + BindReferences.bindReference(e, groupingAttributes).gen(ctx) + } + s""" + ${eval.map(_.code).mkString("\n")} + ${consume(ctx, eval)} + """ + } + + val doAgg = ctx.freshName("doAggregateWithKeys") + ctx.addNewFunction(doAgg, + s""" + private void $doAgg() throws java.io.IOException { + ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + + $iterTerm = $hashMapTerm.iterator(); + } + """) + + s""" + if (!$initAgg) { + $initAgg = true; + $doAgg(); + } + + // output the result + while ($iterTerm.next()) { + UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey(); + UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue(); + $outputCode + } + + $thisPlan.updatePeakMemory($hashMapTerm); + $hashMapTerm.free(); + """ + } + + private def doConsumeWithKeys( ctx: CodegenContext, input: Seq[ExprCode]): String = { + + // create grouping key + ctx.currentVars = input + val keyCode = GenerateUnsafeProjection.createCode( + ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output))) + val key = keyCode.value + val buffer = ctx.freshName("aggBuffer") + + // only have DeclarativeAggregate + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } + } + + val inputAttr = bufferAttributes ++ child.output + ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input + ctx.INPUT_ROW = buffer + // TODO: support subexpression elimination + val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx)) + val updates = evals.zipWithIndex.map { case (ev, i) => + val dt = updateExpr(i).dataType + ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable) + } + + s""" + // generate grouping key + ${keyCode.code} + UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key); + if ($buffer == null) { + // failed to allocate the first page + throw new OutOfMemoryError("No enough memory for aggregation"); + } + + // evaluate aggregate function + ${evals.map(_.code).mkString("\n")} + // update aggregate buffer + ${updates.mkString("\n")} + """ + } + override def simpleString: String = { val allAggregateExpressions = aggregateExpressions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index c4aad398bf..2f09c8a114 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -18,7 +18,12 @@ package org.apache.spark.sql.execution import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager} import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.unsafe.Platform +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.Benchmark /** @@ -27,34 +32,124 @@ import org.apache.spark.util.Benchmark * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" */ class BenchmarkWholeStageCodegen extends SparkFunSuite { - def testWholeStage(values: Int): Unit = { - val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") - val sc = SparkContext.getOrCreate(conf) - val sqlContext = SQLContext.getOrCreate(sc) + lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") + lazy val sc = SparkContext.getOrCreate(conf) + lazy val sqlContext = SQLContext.getOrCreate(sc) - val benchmark = new Benchmark("Single Int Column Scan", values) + def testWholeStage(values: Int): Unit = { + val benchmark = new Benchmark("rang/filter/aggregate", values) - benchmark.addCase("Without whole stage codegen") { iter => + benchmark.addCase("Without codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "false") sqlContext.range(values).filter("(id & 1) = 1").count() } - benchmark.addCase("With whole stage codegen") { iter => + benchmark.addCase("With codegen") { iter => sqlContext.setConf("spark.sql.codegen.wholeStage", "true") sqlContext.range(values).filter("(id & 1) = 1").count() } /* Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz - Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate + rang/filter/aggregate: Avg Time(ms) Avg Rate(M/s) Relative Rate ------------------------------------------------------------------------------- - Without whole stage codegen 7775.53 26.97 1.00 X - With whole stage codegen 342.15 612.94 22.73 X + Without codegen 7775.53 26.97 1.00 X + With codegen 342.15 612.94 22.73 X */ benchmark.run() } - ignore("benchmark") { - testWholeStage(1024 * 1024 * 200) + def testAggregateWithKey(values: Int): Unit = { + val benchmark = new Benchmark("Aggregate with keys", values) + + benchmark.addCase("Aggregate w/o codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "false") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } + benchmark.addCase(s"Aggregate w codegen") { iter => + sqlContext.setConf("spark.sql.codegen.wholeStage", "true") + sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect() + } + + /* + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + Aggregate w/o codegen 4254.38 4.93 1.00 X + Aggregate w codegen 2661.45 7.88 1.60 X + */ + benchmark.run() + } + + def testBytesToBytesMap(values: Int): Unit = { + val benchmark = new Benchmark("BytesToBytesMap", values) + + benchmark.addCase("hash") { iter => + var i = 0 + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var s = 0 + while (i < values) { + key.setInt(0, i % 1000) + val h = Murmur3_x86_32.hashUnsafeWords( + key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0) + s += h + i += 1 + } + } + + Seq("off", "on").foreach { heap => + benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter => + val taskMemoryManager = new TaskMemoryManager( + new StaticMemoryManager( + new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}") + .set("spark.memory.offHeap.size", "102400000"), + Long.MaxValue, + Long.MaxValue, + 1), + 0) + val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L<<20) + val keyBytes = new Array[Byte](16) + val valueBytes = new Array[Byte](16) + val key = new UnsafeRow(1) + key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16) + val value = new UnsafeRow(2) + value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16) + var i = 0 + while (i < values) { + key.setInt(0, i % 65536) + val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes) + if (loc.isDefined) { + value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset, + loc.getValueLength) + value.setInt(0, value.getInt(0) + 1) + i += 1 + } else { + loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, + value.getBaseObject, value.getBaseOffset, value.getSizeInBytes) + } + } + } + } + + /** + Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz + Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate + ------------------------------------------------------------------------------- + hash 662.06 79.19 1.00 X + BytesToBytesMap (off Heap) 2209.42 23.73 0.30 X + BytesToBytesMap (on Heap) 2957.68 17.73 0.22 X + */ + benchmark.run() + } + + test("benchmark") { + // testWholeStage(1024 * 1024 * 200) + // testAggregateWithKey(20 << 20) + // testBytesToBytesMap(1024 * 1024 * 50) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 300788c88a..c2516509df 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -47,4 +47,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) assert(df.collect() === Array(Row(9, 4.5))) } + + test("Aggregate with grouping keys should be included in WholeStageCodegen") { + val df = sqlContext.range(3).groupBy("id").count().orderBy("id") + val plan = df.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined) + assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1))) + } } -- cgit v1.2.3