diff options
Diffstat (limited to 'sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala | 39 |
1 files changed, 27 insertions, 12 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 7c215d1b96..253592028c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( @@ -64,8 +64,8 @@ case class TungstenAggregate( override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } @@ -266,8 +266,8 @@ case class TungstenAggregate( private var sorterTerm: String = _ /** - * This is called by generated Java class, should be public. - */ + * This is called by generated Java class, should be public. + */ def createHashMap(): UnsafeFixedWidthAggregationMap = { // create initialized aggregate buffer val initExpr = declFunctions.flatMap(f => f.initialValues) @@ -286,15 +286,15 @@ case class TungstenAggregate( } /** - * This is called by generated Java class, should be public. - */ + * This is called by generated Java class, should be public. + */ def createUnsafeJoiner(): UnsafeRowJoiner = { GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) } /** - * Called by generated Java class to finish the aggregate and return a KVIterator. - */ + * Called by generated Java class to finish the aggregate and return a KVIterator. + */ def finishAggregate( hashMap: UnsafeFixedWidthAggregationMap, sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = { @@ -372,8 +372,8 @@ case class TungstenAggregate( } /** - * Generate the code for output. - */ + * Generate the code for output. + */ private def generateResultCode( ctx: CodegenContext, keyTerm: String, @@ -437,11 +437,24 @@ case class TungstenAggregate( val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + // create AggregateHashMap + val isAggregateHashMapEnabled: Boolean = false + val isAggregateHashMapSupported: Boolean = + (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) + val aggregateHashMapTerm = ctx.freshName("aggregateHashMap") + val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap") + val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName, + groupingKeySchema, bufferSchema) + if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { + ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, + s"$aggregateHashMapTerm = new $aggregateHashMapClassName();") + } + // create hashMap val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName - ctx.addMutableState(hashMapClassName, hashMapTerm, s"$hashMapTerm = $thisPlan.createHashMap();") + ctx.addMutableState(hashMapClassName, hashMapTerm, "") sorterTerm = ctx.freshName("sorter") ctx.addMutableState(classOf[UnsafeKVExternalSorter].getName, sorterTerm, "") @@ -452,7 +465,9 @@ case class TungstenAggregate( val doAgg = ctx.freshName("doAggregateWithKeys") ctx.addNewFunction(doAgg, s""" + ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { + $hashMapTerm = $thisPlan.createHashMap(); ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} $iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm); |