author    Davies Liu <davies@databricks.com>    2016-01-29 20:16:11 -0800
committer Davies Liu <davies.liu@gmail.com>    2016-01-29 20:16:11 -0800
commit    e6a02c66d53f59ba2d5c1548494ae80a385f9f5c (patch)
tree      54be709e37f53a6bd7a4768568d4e416e6cc5d7a
parent    12252d1da90fa7d2dffa3a7c249ecc8821dee130 (diff)
[SPARK-12914] [SQL] generate aggregation with grouping keys
This PR adds support for grouping keys in the generated TungstenAggregate. Spilling and performance improvements for BytesToBytesMap will be done in a follow-up PR.

Author: Davies Liu <davies@databricks.com>

Closes #10855 from davies/gen_keys.
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala            |  47
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala |  27
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java                               |   6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala                      | 238
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala                       | 119
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala                           |   9
6 files changed, 393 insertions(+), 53 deletions(-)
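A minimal sketch of the kind of query this patch unlocks, using only APIs exercised by the new benchmark and suite test in the diff below (the local-mode session setup is an assumption for illustration):

  import org.apache.spark.{SparkConf, SparkContext}
  import org.apache.spark.sql.SQLContext

  // Hypothetical local session, mirroring the benchmark's setup below.
  val sc = SparkContext.getOrCreate(new SparkConf().setMaster("local[1]").setAppName("demo"))
  val sqlContext = SQLContext.getOrCreate(sc)

  sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
  // Before this patch, supportCodegen required groupingExpressions.isEmpty;
  // now a keyed aggregation like this runs inside WholeStageCodegen as well.
  sqlContext.range(1000).selectExpr("(id & 15) as k").groupBy("k").count().collect()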
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index e6704cf8bb..21f9198073 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -56,6 +56,20 @@ class CodegenContext {
val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]()
/**
+ * Add an object to `references` and create a class member to access it.
+ *
+ * Returns the name of the class member.
+ */
+ def addReferenceObj(name: String, obj: Any, className: String = null): String = {
+ val term = freshName(name)
+ val idx = references.length
+ references += obj
+ val clsName = Option(className).getOrElse(obj.getClass.getName)
+ addMutableState(clsName, term, s"this.$term = ($clsName) references[$idx];")
+ term
+ }
+
+ /**
* Holding a list of generated columns as input of current operator, will be used by
* BoundReference to generate code.
*/
@@ -199,6 +213,39 @@ class CodegenContext {
}
/**
+ * Update a column in MutableRow from ExprCode.
+ */
+ def updateColumn(
+ row: String,
+ dataType: DataType,
+ ordinal: Int,
+ ev: ExprCode,
+ nullable: Boolean): String = {
+ if (nullable) {
+ // Can't call setNullAt on DecimalType, because we need to keep the offset
+ if (dataType.isInstanceOf[DecimalType]) {
+ s"""
+ if (!${ev.isNull}) {
+ ${setColumn(row, dataType, ordinal, ev.value)};
+ } else {
+ ${setColumn(row, dataType, ordinal, "null")};
+ }
+ """
+ } else {
+ s"""
+ if (!${ev.isNull}) {
+ ${setColumn(row, dataType, ordinal, ev.value)};
+ } else {
+ $row.setNullAt($ordinal);
+ }
+ """
+ }
+ } else {
+ s"""${setColumn(row, dataType, ordinal, ev.value)};"""
+ }
+ }
+
+ /**
* Returns the name used in accessor and setter for a Java primitive type.
*/
def primitiveTypeName(jt: String): String = jt match {
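The new addReferenceObj helper is what lets generated Java call back into JVM objects such as the plan itself (TungstenAggregate uses it further down to reach createHashMap and createUnsafeJoiner). A standalone sketch of the mechanics, with a made-up object standing in for a plan:

  import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext

  val ctx = new CodegenContext
  // Any object can be stashed; a hypothetical config map stands in for `this` plan.
  val conf = Map("initialCapacity" -> 16 * 1024)
  val term = ctx.addReferenceObj("conf", conf, classOf[Map[String, Int]].getName)
  // `term` is a fresh name such as "conf1"; the generated class gains a member of
  // that name, initialized from references[idx], so generated Java code can use
  // the object directly without reflection.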
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index ec31db19b9..5b4dc8df86 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -88,31 +88,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
val updates = validExpr.zip(index).map {
case (e, i) =>
- if (e.nullable) {
- if (e.dataType.isInstanceOf[DecimalType]) {
- // Can't call setNullAt on DecimalType, because we need to keep the offset
- s"""
- if (this.isNull_$i) {
- ${ctx.setColumn("mutableRow", e.dataType, i, "null")};
- } else {
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- }
- """
- } else {
- s"""
- if (this.isNull_$i) {
- mutableRow.setNullAt($i);
- } else {
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- }
- """
- }
- } else {
- s"""
- ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
- """
- }
-
+ val ev = ExprCode("", s"this.isNull_$i", s"this.value_$i")
+ ctx.updateColumn("mutableRow", e.dataType, i, ev, e.nullable)
}
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
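The refactor above is behavior-preserving: ExprCode("", s"this.isNull_$i", s"this.value_$i") simply points updateColumn at the same fields the old inline branches read. A sketch of what the helper returns for one nullable int column (the generated Java shown in comments is approximate):

  import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
  import org.apache.spark.sql.types.IntegerType

  val ctx = new CodegenContext
  val ev = ExprCode("", "this.isNull_0", "this.value_0")
  val code = ctx.updateColumn("mutableRow", IntegerType, 0, ev, nullable = true)
  // `code` is roughly:
  //   if (!this.isNull_0) { mutableRow.setInt(0, this.value_0); }
  //   else { mutableRow.setNullAt(0); }
  println(code)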
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index b1bbb1da10..6acf70dbba 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution;
+import java.io.IOException;
+
import scala.collection.Iterator;
import org.apache.spark.sql.catalyst.InternalRow;
@@ -34,7 +36,7 @@ public class BufferedRowIterator {
// used when there is no column in output
protected UnsafeRow unsafeRow = new UnsafeRow(0);
- public boolean hasNext() {
+ public boolean hasNext() throws IOException {
if (currentRow == null) {
processNext();
}
@@ -56,7 +58,7 @@ public class BufferedRowIterator {
*
* After it's called, if currentRow is still null, it means no more rows left.
*/
- protected void processNext() {
+ protected void processNext() throws IOException {
if (input.hasNext()) {
currentRow = input.next();
}
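The added `throws IOException` matters because generated processNext() bodies now drive hash-map iterators that may touch disk once spilling lands. A hand-written subclass, shown only to illustrate the contract (the class name and body are hypothetical):

  import java.io.IOException
  import org.apache.spark.sql.execution.BufferedRowIterator

  class SpillAwareIterator extends BufferedRowIterator {
    @throws[IOException]
    override protected def processNext(): Unit = {
      // e.g. set currentRow from a KVIterator that may read spilled pages
    }
  }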
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index ff2f38bfd9..57db7262fd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -17,16 +17,18 @@
package org.apache.spark.sql.execution.aggregate
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{DecimalType, StructType}
+import org.apache.spark.unsafe.KVIterator
case class TungstenAggregate(
requiredChildDistributionExpressions: Option[Seq[Expression]],
@@ -114,22 +116,38 @@ case class TungstenAggregate(
}
}
+ // all the modes of the aggregate expressions
+ private val modes = aggregateExpressions.map(_.mode).distinct
+
override def supportCodegen: Boolean = {
- groupingExpressions.isEmpty &&
- // ImperativeAggregate is not supported right now
- !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
+ // ImperativeAggregate is not supported right now
+ !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
}
- // The variables used as aggregation buffer
- private var bufVars: Seq[ExprCode] = _
-
- private val modes = aggregateExpressions.map(_.mode).distinct
-
override def upstream(): RDD[InternalRow] = {
child.asInstanceOf[CodegenSupport].upstream()
}
protected override def doProduce(ctx: CodegenContext): String = {
+ if (groupingExpressions.isEmpty) {
+ doProduceWithoutKeys(ctx)
+ } else {
+ doProduceWithKeys(ctx)
+ }
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ if (groupingExpressions.isEmpty) {
+ doConsumeWithoutKeys(ctx, input)
+ } else {
+ doConsumeWithKeys(ctx, input)
+ }
+ }
+
+ // The variables used as aggregation buffer
+ private var bufVars: Seq[ExprCode] = _
+
+ private def doProduceWithoutKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
@@ -176,10 +194,10 @@ case class TungstenAggregate(
(resultVars, resultVars.map(_.code).mkString("\n"))
}
- val doAgg = ctx.freshName("doAgg")
+ val doAgg = ctx.freshName("doAggregateWithoutKey")
ctx.addNewFunction(doAgg,
s"""
- | private void $doAgg() {
+ | private void $doAgg() throws java.io.IOException {
| // initialize aggregation buffer
| ${bufVars.map(_.code).mkString("\n")}
|
@@ -200,7 +218,7 @@ case class TungstenAggregate(
""".stripMargin
}
- override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
// only have DeclarativeAggregate
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output
@@ -212,7 +230,6 @@ case class TungstenAggregate(
e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
}
}
-
ctx.currentVars = bufVars ++ input
// TODO: support subexpression elimination
val aggVals = updateExpr.map(BindReferences.bindReference(_, inputAttrs).gen(ctx))
@@ -232,6 +249,199 @@ case class TungstenAggregate(
""".stripMargin
}
+ private val groupingAttributes = groupingExpressions.map(_.toAttribute)
+ private val groupingKeySchema = StructType.fromAttributes(groupingAttributes)
+ private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
+ .filter(_.isInstanceOf[DeclarativeAggregate])
+ .map(_.asInstanceOf[DeclarativeAggregate])
+ private val bufferAttributes = declFunctions.flatMap(_.aggBufferAttributes)
+ private val bufferSchema = StructType.fromAttributes(bufferAttributes)
+
+ // The name of the generated HashMap variable
+ private var hashMapTerm: String = _
+
+ /**
+ * This is called by the generated Java class, so it must be public.
+ */
+ def createHashMap(): UnsafeFixedWidthAggregationMap = {
+ // create initialized aggregate buffer
+ val initExpr = declFunctions.flatMap(f => f.initialValues)
+ val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
+
+ // create hashMap
+ new UnsafeFixedWidthAggregationMap(
+ initialBuffer,
+ bufferSchema,
+ groupingKeySchema,
+ TaskContext.get().taskMemoryManager(),
+ 1024 * 16, // initial capacity
+ TaskContext.get().taskMemoryManager().pageSizeBytes,
+ false // disable tracking of performance metrics
+ )
+ }
+
+ /**
+ * This is called by the generated Java class, so it must be public.
+ */
+ def createUnsafeJoiner(): UnsafeRowJoiner = {
+ GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
+ }
+
+
+ /**
+ * Update peak execution memory; called from the generated Java class.
+ */
+ def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = {
+ val mapMemory = hashMap.getPeakMemoryUsedBytes
+ val metrics = TaskContext.get().taskMetrics()
+ metrics.incPeakExecutionMemory(mapMemory)
+ }
+
+ private def doProduceWithKeys(ctx: CodegenContext): String = {
+ val initAgg = ctx.freshName("initAgg")
+ ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+ // create hashMap
+ val thisPlan = ctx.addReferenceObj("plan", this)
+ hashMapTerm = ctx.freshName("hashMap")
+ val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
+ ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+
+ // Create a name for the iterator over the HashMap
+ val iterTerm = ctx.freshName("mapIter")
+ ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
+
+ // generate code for output
+ val keyTerm = ctx.freshName("aggKey")
+ val bufferTerm = ctx.freshName("aggBuffer")
+ val outputCode = if (modes.contains(Final) || modes.contains(Complete)) {
+ // generate output using resultExpressions
+ ctx.currentVars = null
+ ctx.INPUT_ROW = keyTerm
+ val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
+ BoundReference(i, e.dataType, e.nullable).gen(ctx)
+ }
+ ctx.INPUT_ROW = bufferTerm
+ val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
+ BoundReference(i, e.dataType, e.nullable).gen(ctx)
+ }
+ // evaluate the aggregation result
+ ctx.currentVars = bufferVars
+ val aggResults = declFunctions.map(_.evaluateExpression).map { e =>
+ BindReferences.bindReference(e, bufferAttributes).gen(ctx)
+ }
+ // generate the final result
+ ctx.currentVars = keyVars ++ aggResults
+ val inputAttrs = groupingAttributes ++ aggregateAttributes
+ val resultVars = resultExpressions.map { e =>
+ BindReferences.bindReference(e, inputAttrs).gen(ctx)
+ }
+ s"""
+ ${keyVars.map(_.code).mkString("\n")}
+ ${bufferVars.map(_.code).mkString("\n")}
+ ${aggResults.map(_.code).mkString("\n")}
+ ${resultVars.map(_.code).mkString("\n")}
+
+ ${consume(ctx, resultVars)}
+ """
+
+ } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
+ // This should be the last operator in the stage, so output UnsafeRow directly
+ val joinerTerm = ctx.freshName("unsafeRowJoiner")
+ ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
+ s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
+ val resultRow = ctx.freshName("resultRow")
+ s"""
+ UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
+ ${consume(ctx, null, resultRow)}
+ """
+
+ } else {
+ // generate result based on grouping key
+ ctx.INPUT_ROW = keyTerm
+ ctx.currentVars = null
+ val eval = resultExpressions.map { e =>
+ BindReferences.bindReference(e, groupingAttributes).gen(ctx)
+ }
+ s"""
+ ${eval.map(_.code).mkString("\n")}
+ ${consume(ctx, eval)}
+ """
+ }
+
+ val doAgg = ctx.freshName("doAggregateWithKeys")
+ ctx.addNewFunction(doAgg,
+ s"""
+ private void $doAgg() throws java.io.IOException {
+ ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+
+ $iterTerm = $hashMapTerm.iterator();
+ }
+ """)
+
+ s"""
+ if (!$initAgg) {
+ $initAgg = true;
+ $doAgg();
+ }
+
+ // output the result
+ while ($iterTerm.next()) {
+ UnsafeRow $keyTerm = (UnsafeRow) $iterTerm.getKey();
+ UnsafeRow $bufferTerm = (UnsafeRow) $iterTerm.getValue();
+ $outputCode
+ }
+
+ $thisPlan.updatePeakMemory($hashMapTerm);
+ $hashMapTerm.free();
+ """
+ }
+
+ private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+
+ // create grouping key
+ ctx.currentVars = input
+ val keyCode = GenerateUnsafeProjection.createCode(
+ ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+ val key = keyCode.value
+ val buffer = ctx.freshName("aggBuffer")
+
+ // only have DeclarativeAggregate
+ val updateExpr = aggregateExpressions.flatMap { e =>
+ e.mode match {
+ case Partial | Complete =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions
+ case PartialMerge | Final =>
+ e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions
+ }
+ }
+
+ val inputAttr = bufferAttributes ++ child.output
+ ctx.currentVars = new Array[ExprCode](bufferAttributes.length) ++ input
+ ctx.INPUT_ROW = buffer
+ // TODO: support subexpression elimination
+ val evals = updateExpr.map(BindReferences.bindReference(_, inputAttr).gen(ctx))
+ val updates = evals.zipWithIndex.map { case (ev, i) =>
+ val dt = updateExpr(i).dataType
+ ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
+ }
+
+ s"""
+ // generate grouping key
+ ${keyCode.code}
+ UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ if ($buffer == null) {
+ // failed to allocate the first page
+ throw new OutOfMemoryError("Not enough memory for aggregation");
+ }
+
+ // evaluate aggregate function
+ ${evals.map(_.code).mkString("\n")}
+ // update aggregate buffer
+ ${updates.mkString("\n")}
+ """
+ }
+
override def simpleString: String = {
val allAggregateExpressions = aggregateExpressions
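Taken together, doProduceWithKeys emits a build-then-drain loop: doAggregateWithKeys consumes all input into the hash map, then the outer while loop walks the KVIterator and runs one of three output paths per (key, buffer) pair, chosen by the aggregate modes. A small self-contained distillation of that dispatch (the strings are descriptions, not generated code):

  import org.apache.spark.sql.catalyst.expressions.aggregate._

  // Mirrors the three branches of outputCode in doProduceWithKeys above.
  def outputPath(modes: Seq[AggregateMode]): String =
    if (modes.contains(Final) || modes.contains(Complete)) {
      "evaluate resultExpressions from the grouping key and aggregation buffer"
    } else if (modes.contains(Partial) || modes.contains(PartialMerge)) {
      "join key and buffer into one UnsafeRow via UnsafeRowJoiner and emit it"
    } else {
      "no aggregate functions: build the result from the grouping key alone"
    }

  // A partial aggregate feeding a shuffle takes the UnsafeRowJoiner path:
  println(outputPath(Seq(Partial)))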
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index c4aad398bf..2f09c8a114 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -18,7 +18,12 @@
package org.apache.spark.sql.execution
import org.apache.spark.{SparkConf, SparkContext, SparkFunSuite}
+import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.unsafe.Platform
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
+import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.util.Benchmark
/**
@@ -27,34 +32,124 @@ import org.apache.spark.util.Benchmark
* build/sbt "sql/test-only *BenchmarkWholeStageCodegen"
*/
class BenchmarkWholeStageCodegen extends SparkFunSuite {
- def testWholeStage(values: Int): Unit = {
- val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
- val sc = SparkContext.getOrCreate(conf)
- val sqlContext = SQLContext.getOrCreate(sc)
+ lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
+ lazy val sc = SparkContext.getOrCreate(conf)
+ lazy val sqlContext = SQLContext.getOrCreate(sc)
- val benchmark = new Benchmark("Single Int Column Scan", values)
+ def testWholeStage(values: Int): Unit = {
+ val benchmark = new Benchmark("rang/filter/aggregate", values)
- benchmark.addCase("Without whole stage codegen") { iter =>
+ benchmark.addCase("Without codegen") { iter =>
sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
sqlContext.range(values).filter("(id & 1) = 1").count()
}
- benchmark.addCase("With whole stage codegen") { iter =>
+ benchmark.addCase("With codegen") { iter =>
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
sqlContext.range(values).filter("(id & 1) = 1").count()
}
/*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
- Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ range/filter/aggregate: Avg Time(ms) Avg Rate(M/s) Relative Rate
-------------------------------------------------------------------------------
- Without whole stage codegen 7775.53 26.97 1.00 X
- With whole stage codegen 342.15 612.94 22.73 X
+ Without codegen 7775.53 26.97 1.00 X
+ With codegen 342.15 612.94 22.73 X
*/
benchmark.run()
}
- ignore("benchmark") {
- testWholeStage(1024 * 1024 * 200)
+ def testAggregateWithKey(values: Int): Unit = {
+ val benchmark = new Benchmark("Aggregate with keys", values)
+
+ benchmark.addCase("Aggregate w/o codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+ sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
+ }
+ benchmark.addCase(s"Aggregate w codegen") { iter =>
+ sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+ sqlContext.range(values).selectExpr("(id & 65535) as k").groupBy("k").sum().collect()
+ }
+
+ /*
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ Aggregate with keys: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------------
+ Aggregate w/o codegen 4254.38 4.93 1.00 X
+ Aggregate w/ codegen 2661.45 7.88 1.60 X
+ */
+ benchmark.run()
+ }
+
+ def testBytesToBytesMap(values: Int): Unit = {
+ val benchmark = new Benchmark("BytesToBytesMap", values)
+
+ benchmark.addCase("hash") { iter =>
+ var i = 0
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ val value = new UnsafeRow(2)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var s = 0
+ while (i < values) {
+ key.setInt(0, i % 1000)
+ val h = Murmur3_x86_32.hashUnsafeWords(
+ key.getBaseObject, key.getBaseOffset, key.getSizeInBytes, 0)
+ s += h
+ i += 1
+ }
+ }
+
+ Seq("off", "on").foreach { heap =>
+ benchmark.addCase(s"BytesToBytesMap ($heap Heap)") { iter =>
+ val taskMemoryManager = new TaskMemoryManager(
+ new StaticMemoryManager(
+ new SparkConf().set("spark.memory.offHeap.enabled", s"${heap == "off"}")
+ .set("spark.memory.offHeap.size", "102400000"),
+ Long.MaxValue,
+ Long.MaxValue,
+ 1),
+ 0)
+ val map = new BytesToBytesMap(taskMemoryManager, 1024, 64L << 20)
+ val keyBytes = new Array[Byte](16)
+ val valueBytes = new Array[Byte](16)
+ val key = new UnsafeRow(1)
+ key.pointTo(keyBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ val value = new UnsafeRow(2)
+ value.pointTo(valueBytes, Platform.BYTE_ARRAY_OFFSET, 16)
+ var i = 0
+ while (i < values) {
+ key.setInt(0, i % 65536)
+ val loc = map.lookup(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes)
+ if (loc.isDefined) {
+ value.pointTo(loc.getValueAddress.getBaseObject, loc.getValueAddress.getBaseOffset,
+ loc.getValueLength)
+ value.setInt(0, value.getInt(0) + 1)
+ i += 1
+ } else {
+ loc.putNewKey(key.getBaseObject, key.getBaseOffset, key.getSizeInBytes,
+ value.getBaseObject, value.getBaseOffset, value.getSizeInBytes)
+ }
+ }
+ }
+ }
+
+ /*
+ Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+ BytesToBytesMap: Avg Time(ms) Avg Rate(M/s) Relative Rate
+ -------------------------------------------------------------------------------
+ hash 662.06 79.19 1.00 X
+ BytesToBytesMap (off Heap) 2209.42 23.73 0.30 X
+ BytesToBytesMap (on Heap) 2957.68 17.73 0.22 X
+ */
+ benchmark.run()
+ }
+
+ test("benchmark") {
+ // testWholeStage(1024 * 1024 * 200)
+ // testAggregateWithKey(20 << 20)
+ // testBytesToBytesMap(1024 * 1024 * 50)
}
}
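The committed test("benchmark") keeps every call commented out so CI stays fast. To reproduce the numbers, un-comment a line and run the sbt command from the header comment, or drive a helper directly (a sketch; sizes and results will vary by machine):

  // Assumes BenchmarkWholeStageCodegen is on the test classpath.
  val bench = new BenchmarkWholeStageCodegen
  bench.testAggregateWithKey(1 << 20)  // 1M rows instead of the 20M used above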
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 300788c88a..c2516509df 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -47,4 +47,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(9, 4.5)))
}
+
+ test("Aggregate with grouping keys should be included in WholeStageCodegen") {
+ val df = sqlContext.range(3).groupBy("id").count().orderBy("id")
+ val plan = df.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
+ assert(df.collect() === Array(Row(0, 1), Row(1, 1), Row(2, 1)))
+ }
}