author     Sameer Agarwal <sameer@databricks.com>  2016-04-08 13:52:28 -0700
committer  Yin Huai <yhuai@databricks.com>         2016-04-08 13:52:28 -0700
commit     f8c9beca38f1f396eb3220b23db6d77112a50293 (patch)
tree       aa5a5e8867443c98fb07aa373f2a0900ee2cd0bc
parent     02757535b58069ce8258108d89d8172a53c358e5 (diff)
[SPARK-14394][SQL] Generate AggregateHashMap class for LongTypes during TungstenAggregate codegen
## What changes were proposed in this pull request?

This PR adds support for generating the `AggregateHashMap` class in `TungstenAggregate` when the aggregate group-by keys and values are all of `LongType`. Note that this generated aggregate map is not actually used yet.

NB: This currently only supports `LongType` keys/values (please see `isAggregateHashMapSupported` in `TungstenAggregate`) and will be generalized to other data types in a subsequent PR.

## How was this patch tested?

Manually inspected the generated code. This is what the generated map looks like for 2 keys:

```java
public class agg_GeneratedAggregateHashMap {
  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
  private int[] buckets;
  private int numBuckets;
  private int maxSteps;
  private int numRows = 0;
  private org.apache.spark.sql.types.StructType schema =
    new org.apache.spark.sql.types.StructType()
      .add("k1", org.apache.spark.sql.types.DataTypes.LongType)
      .add("k2", org.apache.spark.sql.types.DataTypes.LongType)
      .add("sum", org.apache.spark.sql.types.DataTypes.LongType);

  public agg_GeneratedAggregateHashMap(int capacity, double loadFactor, int maxSteps) {
    assert (capacity > 0 && ((capacity & (capacity - 1)) == 0));
    this.maxSteps = maxSteps;
    numBuckets = (int) (capacity / loadFactor);
    batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
      org.apache.spark.memory.MemoryMode.ON_HEAP, capacity);
    buckets = new int[numBuckets];
    java.util.Arrays.fill(buckets, -1);
  }

  public agg_GeneratedAggregateHashMap() {
    this(1 << 16, 0.25, 5);
  }

  public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(long agg_key, long agg_key1) {
    long h = hash(agg_key, agg_key1);
    int step = 0;
    int idx = (int) h & (numBuckets - 1);
    while (step < maxSteps) {
      // Return bucket index if it's either an empty slot or already contains the key
      if (buckets[idx] == -1) {
        batch.column(0).putLong(numRows, agg_key);
        batch.column(1).putLong(numRows, agg_key1);
        batch.column(2).putLong(numRows, 0);
        buckets[idx] = numRows++;
        return batch.getRow(buckets[idx]);
      } else if (equals(idx, agg_key, agg_key1)) {
        return batch.getRow(buckets[idx]);
      }
      idx = (idx + 1) & (numBuckets - 1);
      step++;
    }
    // Didn't find it
    return null;
  }

  private boolean equals(int idx, long agg_key, long agg_key1) {
    return batch.column(0).getLong(buckets[idx]) == agg_key &&
      batch.column(1).getLong(buckets[idx]) == agg_key1;
  }

  // TODO: Improve this hash function
  private long hash(long agg_key, long agg_key1) {
    return agg_key ^ agg_key1;
  }

}
```

Author: Sameer Agarwal <sameer@databricks.com>

Closes #12161 from sameeragarwal/tungsten-aggregate.
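For orientation, here is a minimal sketch of how the new generator might be driven to produce Java source like the block above. The schemas and the hard-coded class name are illustrative; inside `TungstenAggregate` the class name comes from `ctx.freshName` and the emitted string is spliced into the whole-stage-generated code rather than printed.

```scala
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.execution.aggregate.ColumnarAggMapCodeGenerator
import org.apache.spark.sql.types.{LongType, StructType}

// Illustrative schemas matching the 2-key example: two long grouping keys and a
// single long "sum" aggregation buffer.
val groupingKeySchema = new StructType().add("k1", LongType).add("k2", LongType)
val bufferSchema = new StructType().add("sum", LongType)

val ctx = new CodegenContext
val generator = new ColumnarAggMapCodeGenerator(
  ctx, "agg_GeneratedAggregateHashMap", groupingKeySchema, bufferSchema)

// Returns the Java source for the aggregate hash map class as a String.
println(generator.generate())
```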
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala | 193
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala           |  20
2 files changed, 210 insertions(+), 3 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala
new file mode 100644
index 0000000000..e415dd8e6a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/ColumnarAggMapCodeGenerator.scala
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.aggregate
+
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.types.StructType
+
+/**
+ * This is a helper class that generates an append-only aggregate hash map which can act as a
+ * 'cache' for extremely fast key-value lookups while evaluating aggregates (falling back to the
+ * `BytesToBytesMap` if a given key isn't found). This is 'codegened' in TungstenAggregate to
+ * speed up aggregates with grouping keys.
+ *
+ * It is backed by a power-of-2-sized array for index lookups and a columnar batch that stores the
+ * key-value pairs. The index lookups in the array rely on linear probing (with a small maximum
+ * number of tries) and use an inexpensive hash function, which makes it very efficient for the
+ * majority of lookups. However, linear probing and the cheap hash function also make it less
+ * robust than the `BytesToBytesMap` (especially for a large number of keys or for certain key
+ * distributions), so we must fall back on the latter for correctness.
+ */
+class ColumnarAggMapCodeGenerator(
+ ctx: CodegenContext,
+ generatedClassName: String,
+ groupingKeySchema: StructType,
+ bufferSchema: StructType) {
+ val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key")))
+ val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value")))
+ val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ")
+
+ def generate(): String = {
+ s"""
+ |public class $generatedClassName {
+ |${initializeAggregateHashMap()}
+ |
+ |${generateFindOrInsert()}
+ |
+ |${generateEquals()}
+ |
+ |${generateHashFunction()}
+ |}
+ """.stripMargin
+ }
+
+ private def initializeAggregateHashMap(): String = {
+ val generatedSchema: String =
+ s"""
+ |new org.apache.spark.sql.types.StructType()
+ |${(groupingKeySchema ++ bufferSchema).map(key =>
+ s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
+ .mkString("\n")};
+ """.stripMargin
+
+ s"""
+ | private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
+ | private int[] buckets;
+ | private int numBuckets;
+ | private int maxSteps;
+ | private int numRows = 0;
+ | private org.apache.spark.sql.types.StructType schema = $generatedSchema
+ |
+ | public $generatedClassName(int capacity, double loadFactor, int maxSteps) {
+ | assert (capacity > 0 && ((capacity & (capacity - 1)) == 0));
+ | this.maxSteps = maxSteps;
+ | numBuckets = (int) (capacity / loadFactor);
+ | batch = org.apache.spark.sql.execution.vectorized.ColumnarBatch.allocate(schema,
+ | org.apache.spark.memory.MemoryMode.ON_HEAP, capacity);
+ | buckets = new int[numBuckets];
+ | java.util.Arrays.fill(buckets, -1);
+ | }
+ |
+ | public $generatedClassName() {
+ |  this(1 << 16, 0.25, 5);
+ | }
+ """.stripMargin
+ }
+
+ /**
+ * Generates a method that currently computes the hash by simply xor-ing all the individual
+ * group-by keys. For instance, if we have 2 long group-by keys, the generated function would be
+ * of the form:
+ *
+ * {{{
+ * private long hash(long agg_key, long agg_key1) {
+ * return agg_key ^ agg_key1;
+ * }
+ * }}}
+ */
+ private def generateHashFunction(): String = {
+ s"""
+ |// TODO: Improve this hash function
+ |private long hash($groupingKeySignature) {
+ | return ${groupingKeys.map(_._2).mkString(" ^ ")};
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generates a method that returns true if the group-by keys exist at a given index in the
+ * associated [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+ * have 2 long group-by keys, the generated function would be of the form:
+ *
+ * {{{
+ * private boolean equals(int idx, long agg_key, long agg_key1) {
+ * return batch.column(0).getLong(buckets[idx]) == agg_key &&
+ * batch.column(1).getLong(buckets[idx]) == agg_key1;
+ * }
+ * }}}
+ */
+ private def generateEquals(): String = {
+ s"""
+ |private boolean equals(int idx, $groupingKeySignature) {
+ | return ${groupingKeys.zipWithIndex.map(k =>
+ s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")};
+ |}
+ """.stripMargin
+ }
+
+ /**
+ * Generates a method that returns a mutable
+ * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row]] which keeps track of the
+ * aggregate value(s) for a given set of keys. If the corresponding row doesn't yet exist, the
+ * generated method adds it to the associated
+ * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]. For instance, if we
+ * have 2 long group-by keys, the generated function would be of the form:
+ *
+ * {{{
+ * public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(
+ * long agg_key, long agg_key1) {
+ * long h = hash(agg_key, agg_key1);
+ * int step = 0;
+ * int idx = (int) h & (numBuckets - 1);
+ * while (step < maxSteps) {
+ * // Return bucket index if it's either an empty slot or already contains the key
+ * if (buckets[idx] == -1) {
+ * batch.column(0).putLong(numRows, agg_key);
+ * batch.column(1).putLong(numRows, agg_key1);
+ * batch.column(2).putLong(numRows, 0);
+ * buckets[idx] = numRows++;
+ * return batch.getRow(buckets[idx]);
+ * } else if (equals(idx, agg_key, agg_key1)) {
+ * return batch.getRow(buckets[idx]);
+ * }
+ * idx = (idx + 1) & (numBuckets - 1);
+ * step++;
+ * }
+ * // Didn't find it
+ * return null;
+ * }
+ * }}}
+ */
+ private def generateFindOrInsert(): String = {
+ s"""
+ |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
+ groupingKeySignature}) {
+ | long h = hash(${groupingKeys.map(_._2).mkString(", ")});
+ | int step = 0;
+ | int idx = (int) h & (numBuckets - 1);
+ | while (step < maxSteps) {
+ | // Return bucket index if it's either an empty slot or already contains the key
+ | if (buckets[idx] == -1) {
+ | ${groupingKeys.zipWithIndex.map(k =>
+ s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")}
+ | ${bufferValues.zipWithIndex.map(k =>
+ s"batch.column(${groupingKeys.length + k._2}).putLong(numRows, 0);")
+ .mkString("\n")}
+ | buckets[idx] = numRows++;
+ | return batch.getRow(buckets[idx]);
+ | } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) {
+ | return batch.getRow(buckets[idx]);
+ | }
+ | idx = (idx + 1) & (numBuckets - 1);
+ | step++;
+ | }
+ | // Didn't find it
+ | return null;
+ |}
+ """.stripMargin
+ }
+}
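Since the class above only assembles Java source strings, the probing scheme it encodes may be easier to follow in ordinary Scala. The following is a hedged, standalone sketch rather than the generated code: flat arrays stand in for the `ColumnarBatch`, and the class and method names are invented.

```scala
// Sketch of the generated map's algorithm: XOR hash, power-of-2 bucket array,
// linear probing bounded by maxSteps, and None as the signal to fall back to the
// regular BytesToBytesMap-backed path.
class LongKeyMapSketch(capacity: Int = 1 << 16, loadFactor: Double = 0.25, maxSteps: Int = 5) {
  require(capacity > 0 && (capacity & (capacity - 1)) == 0, "capacity must be a power of 2")
  private val numBuckets = (capacity / loadFactor).toInt
  private val buckets = Array.fill(numBuckets)(-1)  // bucket -> row index, -1 if empty
  private val key1s = new Array[Long](capacity)
  private val key2s = new Array[Long](capacity)
  val sums = new Array[Long](capacity)              // the "sum" aggregation buffer column

  private var numRows = 0

  // Deliberately cheap hash, mirroring the generated hash() method.
  private def hash(k1: Long, k2: Long): Long = k1 ^ k2

  /**
   * Returns the row index holding the running aggregate for (k1, k2), inserting a zeroed
   * row on first sight, or None once maxSteps probes fail (the caller must then fall back).
   * Like the generated class, this assumes fewer than `capacity` distinct keys are inserted.
   */
  def findOrInsert(k1: Long, k2: Long): Option[Int] = {
    var idx = (hash(k1, k2) & (numBuckets - 1)).toInt
    var step = 0
    while (step < maxSteps) {
      if (buckets(idx) == -1) {
        key1s(numRows) = k1
        key2s(numRows) = k2
        sums(numRows) = 0L
        buckets(idx) = numRows
        numRows += 1
        return Some(buckets(idx))
      } else if (key1s(buckets(idx)) == k1 && key2s(buckets(idx)) == k2) {
        return Some(buckets(idx))
      }
      idx = (idx + 1) & (numBuckets - 1)
      step += 1
    }
    None
  }
}
```

A caller would route `Some(row)` to the fast columnar update (e.g. `map.sums(row) += value`) and the `None` case to the `BytesToBytesMap`-backed path, which is the fallback contract described in the class comment above.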
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index 60027edc7c..0a5a72c52a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.types.{LongType, StructType}
import org.apache.spark.unsafe.KVIterator
case class TungstenAggregate(
@@ -64,8 +64,8 @@ case class TungstenAggregate(
override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
- case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
- case Some(exprs) if exprs.length > 0 => ClusteredDistribution(exprs) :: Nil
+ case Some(exprs) if exprs.isEmpty => AllTuples :: Nil
+ case Some(exprs) if exprs.nonEmpty => ClusteredDistribution(exprs) :: Nil
case None => UnspecifiedDistribution :: Nil
}
}
@@ -437,6 +437,19 @@ case class TungstenAggregate(
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+ // create AggregateHashMap
+ val isAggregateHashMapEnabled: Boolean = false
+ val isAggregateHashMapSupported: Boolean =
+ (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType)
+ val aggregateHashMapTerm = ctx.freshName("aggregateHashMap")
+ val aggregateHashMapClassName = ctx.freshName("GeneratedAggregateHashMap")
+ val aggregateHashMapGenerator = new ColumnarAggMapCodeGenerator(ctx, aggregateHashMapClassName,
+ groupingKeySchema, bufferSchema)
+ if (isAggregateHashMapEnabled && isAggregateHashMapSupported) {
+ ctx.addMutableState(aggregateHashMapClassName, aggregateHashMapTerm,
+ s"$aggregateHashMapTerm = new $aggregateHashMapClassName();")
+ }
+
// create hashMap
val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
@@ -452,6 +465,7 @@ case class TungstenAggregate(
val doAgg = ctx.freshName("doAggregateWithKeys")
ctx.addNewFunction(doAgg,
s"""
+ ${if (isAggregateHashMapSupported) aggregateHashMapGenerator.generate() else ""}
private void $doAgg() throws java.io.IOException {
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
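To make the eligibility rule concrete, here is a small sketch of the same predicate that `isAggregateHashMapSupported` applies above; the schemas are invented for illustration. Note that this patch also hard-codes `isAggregateHashMapEnabled` to `false`, so even eligible plans do not use the generated map yet.

```scala
import org.apache.spark.sql.types.{LongType, StringType, StructType}

// Every grouping key and every aggregation buffer field must be a LongType for the
// AggregateHashMap class to be generated (the general case comes in a follow-up PR).
def aggregateHashMapSupported(groupingKeySchema: StructType, bufferSchema: StructType): Boolean =
  (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType)

val longKeys   = new StructType().add("k1", LongType).add("k2", LongType)
val longBuffer = new StructType().add("sum", LongType)
val mixedKeys  = new StructType().add("k1", LongType).add("name", StringType)

assert(aggregateHashMapSupported(longKeys, longBuffer))   // eligible: the map class is generated
assert(!aggregateHashMapSupported(mixedKeys, longBuffer)) // not supported yet: no map is generated
```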