aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-02-02 19:47:44 -0800
committerDavies Liu <davies.liu@gmail.com>2016-02-02 19:47:44 -0800
commit99a6e3c1e8d580ce1cc497bd9362eaf16c597f77 (patch)
tree7c2ea61d3e26b4d9c518426e23de7d5ef059a17e
parentff71261b651a7b289ea2312abd6075da8b838ed9 (diff)
downloadspark-99a6e3c1e8d580ce1cc497bd9362eaf16c597f77.tar.gz
spark-99a6e3c1e8d580ce1cc497bd9362eaf16c597f77.tar.bz2
spark-99a6e3c1e8d580ce1cc497bd9362eaf16c597f77.zip
[SPARK-12951] [SQL] support spilling in generated aggregate
This PR adds spilling support for the generated TungstenAggregate. If spilling happens, the aggregate falls back to the iterator-based sort-merge aggregate (not generated), which is an acceptable cost in that case. The changes are covered by TungstenAggregationQueryWithControlledFallbackSuite. Author: Davies Liu <davies@databricks.com> Closes #10998 from davies/gen_spilling.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala172
1 files changed, 142 insertions, 30 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a8a81d6d65..f61db8594d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -25,9 +25,9 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{DecimalType, StructType}
+import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.KVIterator
case class TungstenAggregate(
@@ -258,6 +258,7 @@ case class TungstenAggregate(
// The name for HashMap
private var hashMapTerm: String = _
+ private var sorterTerm: String = _
/**
* This is called by generated Java class, should be public.
@@ -286,39 +287,98 @@ case class TungstenAggregate(
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
}
-
/**
- * Update peak execution memory, called in generated Java class.
+ * Called by generated Java class to finish the aggregate and return a KVIterator.
*/
- def updatePeakMemory(hashMap: UnsafeFixedWidthAggregationMap): Unit = {
+ def finishAggregate(
+ hashMap: UnsafeFixedWidthAggregationMap,
+ sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = {
+
+ // update peak execution memory
val mapMemory = hashMap.getPeakMemoryUsedBytes
+ val sorterMemory = Option(sorter).map(_.getPeakMemoryUsedBytes).getOrElse(0L)
+ val peakMemory = Math.max(mapMemory, sorterMemory)
val metrics = TaskContext.get().taskMetrics()
- metrics.incPeakExecutionMemory(mapMemory)
- }
+ metrics.incPeakExecutionMemory(peakMemory)
- private def doProduceWithKeys(ctx: CodegenContext): String = {
- val initAgg = ctx.freshName("initAgg")
- ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+ if (sorter == null) {
+ // not spilled
+ return hashMap.iterator()
+ }
- // create hashMap
- val thisPlan = ctx.addReferenceObj("plan", this)
- hashMapTerm = ctx.freshName("hashMap")
- val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
- ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+ // merge the final hashMap into sorter
+ sorter.merge(hashMap.destructAndCreateExternalSorter())
+ hashMap.free()
+ val sortedIter = sorter.sortedIterator()
+
+ // Create a KVIterator based on the sorted iterator.
+ new KVIterator[UnsafeRow, UnsafeRow] {
+
+ // Create a MutableProjection to merge the rows of same key together
+ val mergeExpr = declFunctions.flatMap(_.mergeExpressions)
+ val mergeProjection = newMutableProjection(
+ mergeExpr,
+ bufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes),
+ subexpressionEliminationEnabled)()
+ val joinedRow = new JoinedRow()
+
+ var currentKey: UnsafeRow = null
+ var currentRow: UnsafeRow = null
+ var nextKey: UnsafeRow = if (sortedIter.next()) {
+ sortedIter.getKey
+ } else {
+ null
+ }
- // Create a name for iterator from HashMap
- val iterTerm = ctx.freshName("mapIter")
- ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
+ override def next(): Boolean = {
+ if (nextKey != null) {
+ currentKey = nextKey.copy()
+ currentRow = sortedIter.getValue.copy()
+ nextKey = null
+ // use the first row as aggregate buffer
+ mergeProjection.target(currentRow)
+
+ // merge the following rows with same key together
+ var findNextGroup = false
+ while (!findNextGroup && sortedIter.next()) {
+ val key = sortedIter.getKey
+ if (currentKey.equals(key)) {
+ mergeProjection(joinedRow(currentRow, sortedIter.getValue))
+ } else {
+ // We find a new group.
+ findNextGroup = true
+ nextKey = key
+ }
+ }
+
+ true
+ } else {
+ false
+ }
+ }
- // generate code for output
- val keyTerm = ctx.freshName("aggKey")
- val bufferTerm = ctx.freshName("aggBuffer")
- val outputCode = if (modes.contains(Final) || modes.contains(Complete)) {
+ override def getKey: UnsafeRow = currentKey
+ override def getValue: UnsafeRow = currentRow
+ override def close(): Unit = {
+ sortedIter.close()
+ }
+ }
+ }
+
+ /**
+ * Generate the code for output.
+ */
+ private def generateResultCode(
+ ctx: CodegenContext,
+ keyTerm: String,
+ bufferTerm: String,
+ plan: String): String = {
+ if (modes.contains(Final) || modes.contains(Complete)) {
// generate output using resultExpressions
ctx.currentVars = null
ctx.INPUT_ROW = keyTerm
val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) =>
- BoundReference(i, e.dataType, e.nullable).gen(ctx)
+ BoundReference(i, e.dataType, e.nullable).gen(ctx)
}
ctx.INPUT_ROW = bufferTerm
val bufferVars = bufferAttributes.zipWithIndex.map { case (e, i) =>
@@ -348,7 +408,7 @@ case class TungstenAggregate(
// This should be the last operator in a stage, we should output UnsafeRow directly
val joinerTerm = ctx.freshName("unsafeRowJoiner")
ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm,
- s"$joinerTerm = $thisPlan.createUnsafeJoiner();")
+ s"$joinerTerm = $plan.createUnsafeJoiner();")
val resultRow = ctx.freshName("resultRow")
s"""
UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm);
@@ -367,6 +427,23 @@ case class TungstenAggregate(
${consume(ctx, eval)}
"""
}
+ }
+
+ private def doProduceWithKeys(ctx: CodegenContext): String = {
+ val initAgg = ctx.freshName("initAgg")
+ ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+
+ // create hashMap
+ val thisPlan = ctx.addReferenceObj("plan", this)
+ hashMapTerm = ctx.freshName("hashMap")
+ val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
+ ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+ sorterTerm = ctx.freshName("sorter")
+ ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
+
+ // Create a name for iterator from HashMap
+ val iterTerm = ctx.freshName("mapIter")
+ ctx.addMutableState(classOf[KVIterator[UnsafeRow, UnsafeRow]].getName, iterTerm, "")
val doAgg = ctx.freshName("doAggregateWithKeys")
ctx.addNewFunction(doAgg,
@@ -374,10 +451,15 @@ case class TungstenAggregate(
private void $doAgg() throws java.io.IOException {
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
- $iterTerm = $hashMapTerm.iterator();
+ $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);
}
""")
+ // generate code for output
+ val keyTerm = ctx.freshName("aggKey")
+ val bufferTerm = ctx.freshName("aggBuffer")
+ val outputCode = generateResultCode(ctx, keyTerm, bufferTerm, thisPlan)
+
s"""
if (!$initAgg) {
$initAgg = true;
@@ -391,8 +473,10 @@ case class TungstenAggregate(
$outputCode
}
- $thisPlan.updatePeakMemory($hashMapTerm);
- $hashMapTerm.free();
+ $iterTerm.close();
+ if ($sorterTerm == null) {
+ $hashMapTerm.free();
+ }
"""
}
@@ -425,14 +509,42 @@ case class TungstenAggregate(
ctx.updateColumn(buffer, dt, i, ev, updateExpr(i).nullable)
}
+ val (checkFallback, resetCoulter, incCounter) = if (testFallbackStartsAt.isDefined) {
+ val countTerm = ctx.freshName("fallbackCounter")
+ ctx.addMutableState("int", countTerm, s"$countTerm = 0;")
+ (s"$countTerm < ${testFallbackStartsAt.get}", s"$countTerm = 0;", s"$countTerm += 1;")
+ } else {
+ ("true", "", "")
+ }
+
+ // We try to do hash map based in-memory aggregation first. If there is not enough memory (the
+ // hash map will return null for a new key), we spill the hash map to disk to free memory, then
+ // continue to do in-memory aggregation and spilling until all the rows have been processed.
+ // Finally, sort the spilled aggregate buffers by key, and merge them together for the same key.
s"""
// generate grouping key
${keyCode.code}
- UnsafeRow $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ UnsafeRow $buffer = null;
+ if ($checkFallback) {
+ // try to get the buffer from hash map
+ $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ }
if ($buffer == null) {
- // failed to allocate the first page
- throw new OutOfMemoryError("No enough memory for aggregation");
+ if ($sorterTerm == null) {
+ $sorterTerm = $hashMapTerm.destructAndCreateExternalSorter();
+ } else {
+ $sorterTerm.merge($hashMapTerm.destructAndCreateExternalSorter());
+ }
+ $resetCoulter
+ // the hash map has been spilled, it should have enough memory now,
+ // try to allocate buffer again.
+ $buffer = $hashMapTerm.getAggregationBufferFromUnsafeRow($key);
+ if ($buffer == null) {
+ // failed to allocate the first page
+ throw new OutOfMemoryError("No enough memory for aggregation");
+ }
}
+ $incCounter
// evaluate aggregate function
${evals.map(_.code).mkString("\n")}