diff options
author | Sameer Agarwal <sameer@databricks.com> | 2016-04-08 13:52:28 -0700 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2016-04-08 13:52:28 -0700 |
commit | f8c9beca38f1f396eb3220b23db6d77112a50293 (patch) | |
tree | aa5a5e8867443c98fb07aa373f2a0900ee2cd0bc | |
parent | 02757535b58069ce8258108d89d8172a53c358e5 (diff) | |
download | spark-f8c9beca38f1f396eb3220b23db6d77112a50293.tar.gz spark-f8c9beca38f1f396eb3220b23db6d77112a50293.tar.bz2 spark-f8c9beca38f1f396eb3220b23db6d77112a50293.zip |
[SPARK-14394][SQL] Generate AggregateHashMap class for LongTypes during TungstenAggregate codegen
## What changes were proposed in this pull request?
This PR adds support for generating the `AggregateHashMap` class in `TungstenAggregate` if the aggregate group by keys/value are of `LongType`. Note that currently this generate aggregate is not actually used.
NB: This currently only supports `LongType` keys/values (please see `isAggregateHashMapSupported` in `TungstenAggregate`) and will be generalized to other data types in a subsequent PR.
## How was this patch tested?
Manually inspected the generated code. This is what the generated map looks like for 2 keys:
```java
/* 068 */ public class agg_GeneratedAggregateHashMap {
/* 069 */ private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
/* 070 */ private int[] buckets;
/* 071 */ private int numBuckets;
/* 072 */ private int maxSteps;
/* 073 */ private int numRows = 0;
/* 074 */ private org.apache.spark.sql.types.StructType schema =
/* 075 */ new org.apache.spark.sql.types.StructType()
/* 076 */ .add("k1", org.apache.spark.sql.types.DataTypes.LongType)
/* 077 */ .add("k2", org.apache.spark.sql.types.DataTypes.LongType)
/* 078 */ .add("sum", org.apache.spark.sql.types.DataTypes.LongType);
/* 079 */
/* 080 */ public agg_GeneratedAggregateHashMap(int capacity, double loadFactor, int maxSteps) {
/* 081 */ assert (capacity > 0 && ((capacity & (capacity - 1)) == 0));
/* 082 */ this.maxSteps = maxSteps;
/* 083 */ numBuckets = (int) (capacity / loadFactor);
/* 084 */ batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
/* 085 */ org.apache.spark.memory.MemoryMode.ON_HEAP, capacity);
/* 086 */ buckets = new int[numBuckets];
/* 087 */ java.util.Arrays.fill(buckets, -1);
/* 088 */ }
/* 089 */
/* 090 */ public agg_GeneratedAggregateHashMap() {
/* 091 */ new agg_GeneratedAggregateHashMap(1 << 16, 0.25, 5);
/* 092 */ }
/* 093 */
/* 094 */ public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(long agg_key, long agg_key1) {
/* 095 */ long h = hash(agg_key, agg_key1);
/* 096 */ int step = 0;
/* 097 */ int idx = (int) h & (numBuckets - 1);
/* 098 */ while (step < maxSteps) {
/* 099 */ // Return bucket index if it's either an empty slot or already contains the key
/* 100 */ if (buckets[idx] == -1) {
/* 101 */ batch.column(0).putLong(numRows, agg_key);
/* 102 */ batch.column(1).putLong(numRows, agg_key1);
/* 103 */ batch.column(2).putLong(numRows, 0);
/* 104 */ buckets[idx] = numRows++;
/* 105 */ return batch.getRow(buckets[idx]);
/* 106 */ } else if (equals(idx, agg_key, agg_key1)) {
/* 107 */ return batch.getRow(buckets[idx]);
/* 108 */ }
/* 109 */ idx = (idx + 1) & (numBuckets - 1);
/* 110 */ step++;
/* 111 */ }
/* 112 */ // Didn't find it
/* 113 */ return null;
/* 114 */ }
/* 115 */
/* 116 */ private boolean equals(int idx, long agg_key, long agg_key1) {
/* 117 */ return batch.column(0).getLong(buckets[idx]) == agg_key && batch.column(1).getLong(buckets[idx]) == agg_key1;
/* 118 */ }
/* 119 */
/* 120 */ // TODO: Improve this Hash Function
/* 121 */ private long hash(long agg_key, long agg_key1) {
/* 122 */ return agg_key ^ agg_key1;
/* 123 */ }
/* 124 */
/* 125 */ }
```
Author: Sameer Agarwal <sameer@databricks.com>
Closes #12161 from sameeragarwal/tungsten-aggregate.
2 files changed, 210 insertions, 3 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala new file mode 100644 index 0000000000..e415dd8e6a --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala @@ -0,0 +1,193 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.types.StructType + +/** + * This is a helper object to generate an append-only single-key/single value aggregate hash + * map that can act as a 'cache' for extremely fast key-value lookups while evaluating aggregates + * (and fall back to the `BytesToBytesMap` if a given key isn't found). This is 'codegened' in + * TungstenAggregate to speed up aggregates w/ key. + * + * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the + * key-value pairs. The index lookups in the array rely on linear probing (with a small number of + * maximum tries) and use an inexpensive hash function which makes it really efficient for a + * majority of lookups. However, using linear probing and an inexpensive hash function also makes it + * less robust as compared to the `BytesToBytesMap` (especially for a large number of keys or even + * for certain distribution of keys) and requires us to fall back on the latter for correctness. + */ +class ColumnarAggMapCodeGenerator( + ctx: CodegenContext, + generatedClassName: String, + groupingKeySchema: StructType, + bufferSchema: StructType) { + val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key"))) + val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value"))) + val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ") + + def generate(): String = { + s""" + |public class $generatedClassName { + |${initializeAggregateHashMap()} + | + |${generateFindOrInsert()} + | + |${generateEquals()} + | + |${generateHashFunction()} + |} + """.stripMargin + } + + private def initializeAggregateHashMap(): String = { + val generatedSchema: String = + s""" + |new org.apache.spark.sql.types.StructType() + |${(groupingKeySchema ++ bufferSchema).map(key => + s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""") + .mkString("\n")}; + """.stripMargin + + s""" + | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch; + | private int[] buckets; + | private int numBuckets; + | private int maxSteps; + | private int numRows = 0; + | private org.apache.spark.sql.types.StructType schema = $generatedSchema + | + | public $generatedClassName(int capacity, double loadFactor, int maxSteps) { + | assert (capacity > 0 && ((capacity & (capacity - 1)) == 0)); + | this.maxSteps = maxSteps; + | numBuckets = (int) (capacity / loadFactor); + | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema, + | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity); + | buckets = new int[numBuckets]; + | java.util.Arrays.fill(buckets, -1); + | } + | + | public $generatedClassName() { + | new $generatedClassName(1 << 16, 0.25, 5); + | } + """.stripMargin + } + + /** + * Generates a method that computes a hash by currently xor-ing all individual group-by keys. For + * instance, if we have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * private long hash(long agg_key, long agg_key1) { + * return agg_key ^ agg_key1; + * } + * }}} + */ + private def generateHashFunction(): String = { + s""" + |// TODO: Improve this hash function + |private long hash($groupingKeySignature) { + | return ${groupingKeys.map(_._2).mkString(" ^ ")}; + |} + """.stripMargin + } + + /** + * Generates a method that returns true if the group-by keys exist at a given index in the + * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we + * have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * private boolean equals(int idx, long agg_key, long agg_key1) { + * return batch.column(0).getLong(buckets[idx]) == agg_key && + * batch.column(1).getLong(buckets[idx]) == agg_key1; + * } + * }}} + */ + private def generateEquals(): String = { + s""" + |private boolean equals(int idx, $groupingKeySignature) { + | return ${groupingKeys.zipWithIndex.map(k => + s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")}; + |} + """.stripMargin + } + + /** + * Generates a method that returns a mutable + * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the + * aggregate value(s) for a given set of keys. If the corresponding row doesn't exist, the + * generated method adds the corresponding row in the associated + * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we + * have 2 long group-by keys, the generated function would be of the form: + * + * {{{ + * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert( + * long agg_key, long agg_key1) { + * long h = hash(agg_key, agg_key1); + * int step = 0; + * int idx = (int) h & (numBuckets - 1); + * while (step < maxSteps) { + * // Return bucket index if it's either an empty slot or already contains the key + * if (buckets[idx] == -1) { + * batch.column(0).putLong(numRows, agg_key); + * batch.column(1).putLong(numRows, agg_key1); + * batch.column(2).putLong(numRows, 0); + * buckets[idx] = numRows++; + * return batch.getRow(buckets[idx]); + * } else if (equals(idx, agg_key, agg_key1)) { + * return batch.getRow(buckets[idx]); + * } + * idx = (idx + 1) & (numBuckets - 1); + * step++; + * } + * // Didn't find it + * return null; + * } + * }}} + */ + private def generateFindOrInsert(): String = { + s""" + |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${ + groupingKeySignature}) { + | long h = hash(${groupingKeys.map(_._2).mkString(", ")}); + | int step = 0; + | int idx = (int) h & (numBuckets - 1); + | while (step < maxSteps) { + | // Return bucket index if it's either an empty slot or already contains the key + | if (buckets[idx] == -1) { + | ${groupingKeys.zipWithIndex.map(k => + s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")} + | ${bufferValues.zipWithIndex.map(k => + s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);") + .mkString("\n")} + | buckets[idx] = numRows++; + | return batch.getRow(buckets[idx]); + | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) { + | return batch.getRow(buckets[idx]); + | } + | idx = (idx + 1) & (numBuckets - 1); + | step++; + | } + | // Didn't find it + | return null; + |} + """.stripMargin + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 60027edc7c..0a5a72c52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{LongType, StructType} import org.apache.spark.unsafe.KVIterator case class TungstenAggregate( @@ -64,8 +64,8 @@ case class TungstenAggregate( override def requiredChildDistribution: List[Distribution] = { requiredChildDistributionExpressions match { - case Some(exprs) if exprs.length == 0 => AllTuples :: Nil - case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil + case Some(exprs) if exprs.isEmpty => AllTuples :: Nil + case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil case None => UnspecifiedDistribution :: Nil } } @@ -437,6 +437,19 @@ case class TungstenAggregate( val initAgg = ctx.freshName("initAgg") ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + // create AggregateHashMap + val isAggregateHashMapEnabled: Boolean = false + val isAggregateHashMapSupported: Boolean = + (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) + val aggregateHashMapTerm = ctx.freshName("aggregateHashMap") + val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap") + val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName, + groupingKeySchema, bufferSchema) + if (isAggregateHashMapEnabled && isAggregateHashMapSupported) { + ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm, + s"$aggregateHashMapTerm = new $aggregateHashMapClassName();") + } + // create hashMap val thisPlan = ctx.addReferenceObj("plan", this) hashMapTerm = ctx.freshName("hashMap") @@ -452,6 +465,7 @@ case class TungstenAggregate( val doAgg = ctx.freshName("doAggregateWithKeys") ctx.addNewFunction(doAgg, s""" + ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""} private void $doAgg() throws java.io.IOException { ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} |