aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala39
1 files changed, 27 insertions, 12 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 7c215d1b96..253592028c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.unsafe.KVIterator
case class TungstenAggregate(
@@ -64,8 +64,8 @@ case class TungstenAggregate(
override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
- case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+ case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}
@@ -266,8 +266,8 @@ case class TungstenAggregate(
private var sorterTerm: String = _
/**
- * This is called by generated Java class, should be public.
- */
+ * This is called by generated Java class, should be public.
+ */
def createHashMap(): UnsafeFixedWidthAggregationMap = {
// create initialized aggregate buffer
val initExpr = declFunctions.flatMap(f => f.initialValues)
@@ -286,15 +286,15 @@ case class TungstenAggregate(
}
/**
- * This is called by generated Java class, should be public.
- */
+ * This is called by generated Java class, should be public.
+ */
def createUnsafeJoiner(): UnsafeRowJoiner = {
GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema)
}
/**
- * Called by generated Java class to finish the aggregate and return a KVIterator.
- */
+ * Called by generated Java class to finish the aggregate and return a KVIterator.
+ */
def finishAggregate(
hashMap: UnsafeFixedWidthAggregationMap,
sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = {
@@ -372,8 +372,8 @@ case class TungstenAggregate(
}
/**
- * Generate the code for output.
- */
+ * Generate the code for output.
+ */
private def generateResultCode(
ctx: CodegenContext,
keyTerm: String,
@@ -437,11 +437,24 @@ case class TungstenAggregate(
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+ // create AggregateHashMap
+ val isAggregateHashMapEnabled: Boolean = false
+ val isAggregateHashMapSupported: Boolean =
+ (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType)
+ val aggregateHashMapTerm = ctx.freshName("aggregateHashMap")
+ val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap")
+ val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName,
+ groupingKeySchema, bufferSchema)
+ if (isAggregateHashMapEnabled && isAggregateHashMapSupported) {
+ ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm,
+ s"$aggregateHashMapTerm = new $aggregateHashMapClassName();")
+ }
+
// create hashMap
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
- ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();")
+ ctx.addMutableState(hashMapClassName, hashMapTerm, "")
sorterTerm = ctx.freshName("sorter")
ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "")
@@ -452,7 +465,9 @@ case class TungstenAggregate(
val doAgg = ctx.freshName("doAggregateWithKeys")
ctx.addNewFunction(doAgg,
s"""
+ ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
private void $doAgg() throws java.io.IOException {
+ $hashMapTerm = $thisPlan.createHashMap();
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm);