aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala290
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala17
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala79
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala25
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala6
8 files changed, 326 insertions, 119 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java
index 85529f6a0a..a88a315bf4 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/FixedLengthRowBasedKeyValueBatch.java
@@ -165,10 +165,10 @@ public final class FixedLengthRowBasedKeyValueBatch extends RowBasedKeyValueBatc
protected FixedLengthRowBasedKeyValueBatch(StructType keySchema, StructType valueSchema,
int maxRows, TaskMemoryManager manager) {
super(keySchema, valueSchema, maxRows, manager);
- klen = keySchema.defaultSize()
- + UnsafeRow.calculateBitSetWidthInBytes(keySchema.length());
- vlen = valueSchema.defaultSize()
- + UnsafeRow.calculateBitSetWidthInBytes(valueSchema.length());
+ int keySize = keySchema.size() * 8; // each fixed-length field is stored in a 8-byte word
+ int valueSize = valueSchema.size() * 8;
+ klen = keySize + UnsafeRow.calculateBitSetWidthInBytes(keySchema.length());
+ vlen = valueSize + UnsafeRow.calculateBitSetWidthInBytes(valueSchema.length());
recordLength = klen + vlen + 8;
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
index bd7efa606e..59e132dfb2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.execution.aggregate
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
@@ -279,9 +280,14 @@ case class HashAggregateExec(
.map(_.asInstanceOf[DeclarativeAggregate])
private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes)
- // The name for Vectorized HashMap
- private var vectorizedHashMapTerm: String = _
- private var isVectorizedHashMapEnabled: Boolean = _
+ // The name for Fast HashMap
+ private var fastHashMapTerm: String = _
+ private var isFastHashMapEnabled: Boolean = false
+
+ // whether a vectorized hashmap is used instead
+ // we have decided to always use the row-based hashmap,
+ // but the vectorized hashmap can still be switched on for testing and benchmarking purposes.
+ private var isVectorizedHashMapEnabled: Boolean = false
// The name for UnsafeRow HashMap
private var hashMapTerm: String = _
@@ -307,6 +313,16 @@ case class HashAggregateExec(
)
}
+ def getTaskMemoryManager(): TaskMemoryManager = {
+ TaskContext.get().taskMemoryManager()
+ }
+
+ def getEmptyAggregationBuffer(): InternalRow = {
+ val initExpr = declFunctions.flatMap(f => f.initialValues)
+ val initialBuffer = UnsafeProjection.create(initExpr)(EmptyRow)
+ initialBuffer
+ }
+
/**
* This is called by generated Java class, should be public.
*/
@@ -459,52 +475,91 @@ case class HashAggregateExec(
}
/**
- * Using the vectorized hash map in HashAggregate is currently supported for all primitive
- * data types during partial aggregation. However, we currently only enable the hash map for a
- * subset of cases that've been verified to show performance improvements on our benchmarks
- * subject to an internal conf that sets an upper limit on the maximum length of the aggregate
- * key/value schema.
- *
+ * A required check for any fast hash map implementation (basically the common requirements
+ * for row-based and vectorized).
+ * Currently fast hash map is supported for primitive data types during partial aggregation.
* This list of supported use-cases should be expanded over time.
*/
- private def enableVectorizedHashMap(ctx: CodegenContext): Boolean = {
- val schemaLength = (groupingKeySchema ++ bufferSchema).length
+ private def checkIfFastHashMapSupported(ctx: CodegenContext): Boolean = {
val isSupported =
(groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
bufferSchema.nonEmpty && modes.forall(mode => mode == Partial || mode == PartialMerge)
- // We do not support byte array based decimal type for aggregate values as
- // ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
+ // For vectorized hash map, We do not support byte array based decimal type for aggregate values
+ // as ColumnVector.putDecimal for high-precision decimals doesn't currently support in-place
// updates. Due to this, appending the byte array in the vectorized hash map can turn out to be
// quite inefficient and can potentially OOM the executor.
+ // For row-based hash map, while decimal update is supported in UnsafeRow, we will just act
+ // conservative here, due to lack of testing and benchmarking.
val isNotByteArrayDecimalType = bufferSchema.map(_.dataType).filter(_.isInstanceOf[DecimalType])
.forall(!DecimalType.isByteArrayDecimalType(_))
- isSupported && isNotByteArrayDecimalType &&
- schemaLength <= sqlContext.conf.vectorizedAggregateMapMaxColumns
+ isSupported && isNotByteArrayDecimalType
+ }
+
+ private def enableTwoLevelHashMap(ctx: CodegenContext) = {
+ if (!checkIfFastHashMapSupported(ctx)) {
+ if (modes.forall(mode => mode == Partial || mode == PartialMerge) && !Utils.isTesting) {
+ logInfo("spark.sql.codegen.aggregate.map.twolevel.enable is set to true, but"
+ + " current version of codegened fast hashmap does not support this aggregate.")
+ }
+ } else {
+ isFastHashMapEnabled = true
+
+ // This is for testing/benchmarking only.
+ // We enforce to first level to be a vectorized hashmap, instead of the default row-based one.
+ sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match {
+ case "true" => isVectorizedHashMapEnabled = true
+ case null | "" | "false" => None }
+ }
}
private def doProduceWithKeys(ctx: CodegenContext): String = {
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
- isVectorizedHashMapEnabled = enableVectorizedHashMap(ctx)
- vectorizedHashMapTerm = ctx.freshName("vectorizedHashMap")
- val vectorizedHashMapClassName = ctx.freshName("VectorizedHashMap")
- val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
- vectorizedHashMapClassName, groupingKeySchema, bufferSchema)
+ if (sqlContext.conf.enableTwoLevelAggMap) {
+ enableTwoLevelHashMap(ctx)
+ } else {
+ sqlContext.getConf("spark.sql.codegen.aggregate.map.vectorized.enable", null) match {
+ case "true" => logWarning("Two level hashmap is disabled but vectorized hashmap is " +
+ "enabled.")
+ case null | "" | "false" => None
+ }
+ }
+ fastHashMapTerm = ctx.freshName("fastHashMap")
+ val fastHashMapClassName = ctx.freshName("FastHashMap")
+ val fastHashMapGenerator =
+ if (isVectorizedHashMapEnabled) {
+ new VectorizedHashMapGenerator(ctx, aggregateExpressions,
+ fastHashMapClassName, groupingKeySchema, bufferSchema)
+ } else {
+ new RowBasedHashMapGenerator(ctx, aggregateExpressions,
+ fastHashMapClassName, groupingKeySchema, bufferSchema)
+ }
+
+ val thisPlan = ctx.addReferenceObj("plan", this)
+
// Create a name for iterator from vectorized HashMap
- val iterTermForVectorizedHashMap = ctx.freshName("vectorizedHashMapIter")
- if (isVectorizedHashMapEnabled) {
- ctx.addMutableState(vectorizedHashMapClassName, vectorizedHashMapTerm,
- s"$vectorizedHashMapTerm = new $vectorizedHashMapClassName();")
- ctx.addMutableState(
- "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>",
- iterTermForVectorizedHashMap, "")
+ val iterTermForFastHashMap = ctx.freshName("fastHashMapIter")
+ if (isFastHashMapEnabled) {
+ if (isVectorizedHashMapEnabled) {
+ ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
+ s"$fastHashMapTerm = new $fastHashMapClassName();")
+ ctx.addMutableState(
+ "java.util.Iterator<org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row>",
+ iterTermForFastHashMap, "")
+ } else {
+ ctx.addMutableState(fastHashMapClassName, fastHashMapTerm,
+ s"$fastHashMapTerm = new $fastHashMapClassName(" +
+ s"agg_plan.getTaskMemoryManager(), agg_plan.getEmptyAggregationBuffer());")
+ ctx.addMutableState(
+ "org.apache.spark.unsafe.KVIterator",
+ iterTermForFastHashMap, "")
+ }
}
// create hashMap
- val thisPlan = ctx.addReferenceObj("plan", this)
hashMapTerm = ctx.freshName("hashMap")
val hashMapClassName = classOf[UnsafeFixedWidthAggregationMap].getName
ctx.addMutableState(hashMapClassName, hashMapTerm, "")
@@ -518,15 +573,30 @@ case class HashAggregateExec(
val doAgg = ctx.freshName("doAggregateWithKeys")
val peakMemory = metricTerm(ctx, "peakMemory")
val spillSize = metricTerm(ctx, "spillSize")
+
+ def generateGenerateCode(): String = {
+ if (isFastHashMapEnabled) {
+ if (isVectorizedHashMapEnabled) {
+ s"""
+ | ${fastHashMapGenerator.asInstanceOf[VectorizedHashMapGenerator].generate()}
+ """.stripMargin
+ } else {
+ s"""
+ | ${fastHashMapGenerator.asInstanceOf[RowBasedHashMapGenerator].generate()}
+ """.stripMargin
+ }
+ } else ""
+ }
+
ctx.addNewFunction(doAgg,
s"""
- ${if (isVectorizedHashMapEnabled) vectorizedHashMapGenerator.generate() else ""}
+ ${generateGenerateCode}
private void $doAgg() throws java.io.IOException {
$hashMapTerm = $thisPlan.createHashMap();
${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
- ${if (isVectorizedHashMapEnabled) {
- s"$iterTermForVectorizedHashMap = $vectorizedHashMapTerm.rowIterator();"} else ""}
+ ${if (isFastHashMapEnabled) {
+ s"$iterTermForFastHashMap = $fastHashMapTerm.rowIterator();"} else ""}
$iterTerm = $thisPlan.finishAggregate($hashMapTerm, $sorterTerm, $peakMemory, $spillSize);
}
@@ -542,34 +612,56 @@ case class HashAggregateExec(
// so `copyResult` should be reset to `false`.
ctx.copyResult = false
+ def outputFromGeneratedMap: String = {
+ if (isFastHashMapEnabled) {
+ if (isVectorizedHashMapEnabled) {
+ outputFromVectorizedMap
+ } else {
+ outputFromRowBasedMap
+ }
+ } else ""
+ }
+
+ def outputFromRowBasedMap: String = {
+ s"""
+ while ($iterTermForFastHashMap.next()) {
+ $numOutput.add(1);
+ UnsafeRow $keyTerm = (UnsafeRow) $iterTermForFastHashMap.getKey();
+ UnsafeRow $bufferTerm = (UnsafeRow) $iterTermForFastHashMap.getValue();
+ $outputCode
+
+ if (shouldStop()) return;
+ }
+ $fastHashMapTerm.close();
+ """
+ }
+
// Iterate over the aggregate rows and convert them from ColumnarBatch.Row to UnsafeRow
- def outputFromGeneratedMap: Option[String] = {
- if (isVectorizedHashMapEnabled) {
- val row = ctx.freshName("vectorizedHashMapRow")
+ def outputFromVectorizedMap: String = {
+ val row = ctx.freshName("fastHashMapRow")
ctx.currentVars = null
ctx.INPUT_ROW = row
var schema: StructType = groupingKeySchema
bufferSchema.foreach(i => schema = schema.add(i))
val generateRow = GenerateUnsafeProjection.createCode(ctx, schema.toAttributes.zipWithIndex
.map { case (attr, i) => BoundReference(i, attr.dataType, attr.nullable) })
- Option(
- s"""
- | while ($iterTermForVectorizedHashMap.hasNext()) {
- | $numOutput.add(1);
- | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
- | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
- | $iterTermForVectorizedHashMap.next();
- | ${generateRow.code}
- | ${consume(ctx, Seq.empty, {generateRow.value})}
- |
- | if (shouldStop()) return;
- | }
- |
- | $vectorizedHashMapTerm.close();
- """.stripMargin)
- } else None
+ s"""
+ | while ($iterTermForFastHashMap.hasNext()) {
+ | $numOutput.add(1);
+ | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $row =
+ | (org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row)
+ | $iterTermForFastHashMap.next();
+ | ${generateRow.code}
+ | ${consume(ctx, Seq.empty, {generateRow.value})}
+ |
+ | if (shouldStop()) return;
+ | }
+ |
+ | $fastHashMapTerm.close();
+ """.stripMargin
}
+
val aggTime = metricTerm(ctx, "aggTime")
val beforeAgg = ctx.freshName("beforeAgg")
s"""
@@ -581,7 +673,7 @@ case class HashAggregateExec(
}
// output the result
- ${outputFromGeneratedMap.getOrElse("")}
+ ${outputFromGeneratedMap}
while ($iterTerm.next()) {
$numOutput.add(1);
@@ -605,11 +697,11 @@ case class HashAggregateExec(
ctx.currentVars = input
val unsafeRowKeyCode = GenerateUnsafeProjection.createCode(
ctx, groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
- val vectorizedRowKeys = ctx.generateExpressions(
- groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
+ val fastRowKeys = ctx.generateExpressions(
+ groupingExpressions.map(e => BindReferences.bindReference[Expression](e, child.output)))
val unsafeRowKeys = unsafeRowKeyCode.value
val unsafeRowBuffer = ctx.freshName("unsafeRowAggBuffer")
- val vectorizedRowBuffer = ctx.freshName("vectorizedAggBuffer")
+ val fastRowBuffer = ctx.freshName("fastAggBuffer")
// only have DeclarativeAggregate
val updateExpr = aggregateExpressions.flatMap { e =>
@@ -639,17 +731,18 @@ case class HashAggregateExec(
("true", "true", "", "")
}
- // We first generate code to probe and update the vectorized hash map. If the probe is
- // successful the corresponding vectorized row buffer will hold the mutable row
- val findOrInsertInVectorizedHashMap: Option[String] = {
- if (isVectorizedHashMapEnabled) {
+ // We first generate code to probe and update the fast hash map. If the probe is
+ // successful the corresponding fast row buffer will hold the mutable row
+ val findOrInsertFastHashMap: Option[String] = {
+ if (isFastHashMapEnabled) {
Option(
s"""
+ |
|if ($checkFallbackForGeneratedHashMap) {
- | ${vectorizedRowKeys.map(_.code).mkString("\n")}
- | if (${vectorizedRowKeys.map("!" + _.isNull).mkString(" && ")}) {
- | $vectorizedRowBuffer = $vectorizedHashMapTerm.findOrInsert(
- | ${vectorizedRowKeys.map(_.value).mkString(", ")});
+ | ${fastRowKeys.map(_.code).mkString("\n")}
+ | if (${fastRowKeys.map("!" + _.isNull).mkString(" && ")}) {
+ | $fastRowBuffer = $fastHashMapTerm.findOrInsert(
+ | ${fastRowKeys.map(_.value).mkString(", ")});
| }
|}
""".stripMargin)
@@ -658,36 +751,35 @@ case class HashAggregateExec(
}
}
- val updateRowInVectorizedHashMap: Option[String] = {
- if (isVectorizedHashMapEnabled) {
- ctx.INPUT_ROW = vectorizedRowBuffer
- val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
- val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
- val effectiveCodes = subExprs.codes.mkString("\n")
- val vectorizedRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
- boundUpdateExpr.map(_.genCode(ctx))
- }
- val updateVectorizedRow = vectorizedRowEvals.zipWithIndex.map { case (ev, i) =>
- val dt = updateExpr(i).dataType
- ctx.updateColumn(vectorizedRowBuffer, dt, i, ev, updateExpr(i).nullable,
- isVectorized = true)
- }
- Option(
- s"""
- |// common sub-expressions
- |$effectiveCodes
- |// evaluate aggregate function
- |${evaluateVariables(vectorizedRowEvals)}
- |// update vectorized row
- |${updateVectorizedRow.mkString("\n").trim}
- """.stripMargin)
- } else None
+
+ def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = {
+ ctx.INPUT_ROW = fastRowBuffer
+ val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr))
+ val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr)
+ val effectiveCodes = subExprs.codes.mkString("\n")
+ val fastRowEvals = ctx.withSubExprEliminationExprs(subExprs.states) {
+ boundUpdateExpr.map(_.genCode(ctx))
+ }
+ val updateFastRow = fastRowEvals.zipWithIndex.map { case (ev, i) =>
+ val dt = updateExpr(i).dataType
+ ctx.updateColumn(fastRowBuffer, dt, i, ev, updateExpr(i).nullable, isVectorized)
+ }
+ Option(
+ s"""
+ |// common sub-expressions
+ |$effectiveCodes
+ |// evaluate aggregate function
+ |${evaluateVariables(fastRowEvals)}
+ |// update fast row
+ |${updateFastRow.mkString("\n").trim}
+ |
+ """.stripMargin)
}
// Next, we generate code to probe and update the unsafe row hash map.
val findOrInsertInUnsafeRowMap: String = {
s"""
- | if ($vectorizedRowBuffer == null) {
+ | if ($fastRowBuffer == null) {
| // generate grouping key
| ${unsafeRowKeyCode.code.trim}
| ${hashEval.code.trim}
@@ -745,17 +837,31 @@ case class HashAggregateExec(
// Finally, sort the spilled aggregate buffers by key, and merge them together for same key.
s"""
UnsafeRow $unsafeRowBuffer = null;
- org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $vectorizedRowBuffer = null;
+ ${
+ if (isVectorizedHashMapEnabled) {
+ s"""
+ | org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row $fastRowBuffer = null;
+ """.stripMargin
+ } else {
+ s"""
+ | UnsafeRow $fastRowBuffer = null;
+ """.stripMargin
+ }
+ }
- ${findOrInsertInVectorizedHashMap.getOrElse("")}
+ ${findOrInsertFastHashMap.getOrElse("")}
$findOrInsertInUnsafeRowMap
$incCounter
- if ($vectorizedRowBuffer != null) {
- // update vectorized row
- ${updateRowInVectorizedHashMap.getOrElse("")}
+ if ($fastRowBuffer != null) {
+ // update fast row
+ ${
+ if (isFastHashMapEnabled) {
+ updateRowInFastHashMap(isVectorizedHashMapEnabled).getOrElse("")
+ } else ""
+ }
} else {
// update unsafe row
$updateRowInUnsafeRowMap
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index 1dea33037c..a77e178546 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -141,8 +141,16 @@ class RowBasedHashMapGenerator(
}
val createUnsafeRowForKey = groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
- s"agg_rowWriter.write(${ordinal}, ${key.name})"}
- .mkString(";\n")
+ key.dataType match {
+ case t: DecimalType =>
+ s"agg_rowWriter.write(${ordinal}, ${key.name}, ${t.precision}, ${t.scale})"
+ case t: DataType =>
+ if (!t.isInstanceOf[StringType] && !ctx.isPrimitiveType(t)) {
+ throw new IllegalArgumentException(s"cannot generate code for unsupported type: $t")
+ }
+ s"agg_rowWriter.write(${ordinal}, ${key.name})"
+ }
+ }.mkString(";\n")
s"""
|public org.apache.spark.sql.catalyst.expressions.UnsafeRow findOrInsert(${
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 91988270ad..d3440a2644 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -509,14 +509,15 @@ object SQLConf {
.intConf
.createWithDefault(40)
- val VECTORIZED_AGG_MAP_MAX_COLUMNS =
- SQLConfigBuilder("spark.sql.codegen.aggregate.map.columns.max")
+ val ENABLE_TWOLEVEL_AGG_MAP =
+ SQLConfigBuilder("spark.sql.codegen.aggregate.map.twolevel.enable")
.internal()
- .doc("Sets the maximum width of schema (aggregate keys + values) for which aggregate with" +
- "keys uses an in-memory columnar map to speed up execution. Setting this to 0 effectively" +
- "disables the columnar map")
- .intConf
- .createWithDefault(3)
+ .doc("Enable two-level aggregate hash map. When enabled, records will first be " +
+ "inserted/looked-up at a 1st-level, small, fast map, and then fallback to a " +
+ "2nd-level, larger, slower map when 1st level is full or keys cannot be found. " +
+ "When disabled, records go directly to the 2nd level. Defaults to true.")
+ .booleanConf
+ .createWithDefault(true)
val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion")
.internal()
@@ -687,7 +688,7 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
override def runSQLonFile: Boolean = getConf(RUN_SQL_ON_FILES)
- def vectorizedAggregateMapMaxColumns: Int = getConf(VECTORIZED_AGG_MAP_MAX_COLUMNS)
+ def enableTwoLevelAggMap: Boolean = getConf(ENABLE_TWOLEVEL_AGG_MAP)
def variableSubstituteEnabled: Boolean = getConf(VARIABLE_SUBSTITUTE_ENABLED)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala
new file mode 100644
index 0000000000..3e85d95523
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/AggregateHashMapSuite.scala
@@ -0,0 +1,79 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.scalatest.BeforeAndAfter
+
+class SingleLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {
+
+ protected override def beforeAll(): Unit = {
+ sparkConf.set("spark.sql.codegen.fallback", "false")
+ sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
+ super.beforeAll()
+ }
+
+ // adding some checking after each test is run, assuring that the configs are not changed
+ // in test code
+ after {
+ assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
+ "configuration parameter changed in test body")
+ assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "false",
+ "configuration parameter changed in test body")
+ }
+}
+
+class TwoLevelAggregateHashMapSuite extends DataFrameAggregateSuite with BeforeAndAfter {
+
+ protected override def beforeAll(): Unit = {
+ sparkConf.set("spark.sql.codegen.fallback", "false")
+ sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+ super.beforeAll()
+ }
+
+ // adding some checking after each test is run, assuring that the configs are not changed
+ // in test code
+ after {
+ assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
+ "configuration parameter changed in test body")
+ assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
+ "configuration parameter changed in test body")
+ }
+}
+
+class TwoLevelAggregateHashMapWithVectorizedMapSuite extends DataFrameAggregateSuite with
+BeforeAndAfter {
+
+ protected override def beforeAll(): Unit = {
+ sparkConf.set("spark.sql.codegen.fallback", "false")
+ sparkConf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+ sparkConf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
+ super.beforeAll()
+ }
+
+ // adding some checking after each test is run, assuring that the configs are not changed
+ // in test code
+ after {
+ assert(sparkConf.get("spark.sql.codegen.fallback") == "false",
+ "configuration parameter changed in test body")
+ assert(sparkConf.get("spark.sql.codegen.aggregate.map.twolevel.enable") == "true",
+ "configuration parameter changed in test body")
+ assert(sparkConf.get("spark.sql.codegen.aggregate.map.vectorized.enable") == "true",
+ "configuration parameter changed in test body")
+ }
+}
+
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 69a3b5f278..427390a90f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -485,4 +485,12 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
spark.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"),
Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil)
}
+
+ test("SQL decimal test (used for catching certain demical handling bugs in aggregates)") {
+ checkAnswer(
+ decimalData.groupBy('a cast DecimalType(10, 2)).agg(avg('b cast DecimalType(10, 2))),
+ Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.5)),
+ Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(1.5)),
+ Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(1.5))))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
index bf3a39c84b..8a2993bdf4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/AggregateBenchmark.scala
@@ -106,13 +106,14 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
f()
}
benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "3")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
f()
}
@@ -146,13 +147,14 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true)
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", 0)
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
f()
}
benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", value = true)
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", 3)
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
f()
}
@@ -184,13 +186,14 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F", numIters = 3) { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
f()
}
benchmark.addCase(s"codegen = T hashmap = T", numIters = 5) { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "3")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
f()
}
@@ -221,13 +224,14 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F") { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
f()
}
benchmark.addCase(s"codegen = T hashmap = T") { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "3")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
f()
}
@@ -268,13 +272,14 @@ class AggregateBenchmark extends BenchmarkBase {
benchmark.addCase(s"codegen = T hashmap = F") { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "0")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "false")
f()
}
benchmark.addCase(s"codegen = T hashmap = T") { iter =>
sparkSession.conf.set("spark.sql.codegen.wholeStage", "true")
- sparkSession.conf.set("spark.sql.codegen.aggregate.map.columns.max", "10")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.twolevel.enable", "true")
+ sparkSession.conf.set("spark.sql.codegen.aggregate.map.vectorized.enable", "true")
f()
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 2dcf13c02a..4a8086d7e5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -998,9 +998,9 @@ class HashAggregationQuerySuite extends AggregationQuerySuite
class HashAggregationQueryWithControlledFallbackSuite extends AggregationQuerySuite {
override protected def checkAnswer(actual: => DataFrame, expectedAnswer: Seq[Row]): Unit = {
- Seq(0, 10).foreach { maxColumnarHashMapColumns =>
- withSQLConf("spark.sql.codegen.aggregate.map.columns.max" ->
- maxColumnarHashMapColumns.toString) {
+ Seq("true", "false").foreach { enableTwoLevelMaps =>
+ withSQLConf("spark.sql.codegen.aggregate.map.twolevel.enable" ->
+ enableTwoLevelMaps) {
(1 to 3).foreach { fallbackStartsAt =>
withSQLConf("spark.sql.TungstenAggregate.testFallbackStartsAt" ->
s"${(fallbackStartsAt - 1).toString}, ${fallbackStartsAt.toString}") {