From b29bc3f51518806ef7827b35df7c8aada329f961 Mon Sep 17 00:00:00 2001
From: Sameer Agarwal
Date: Thu, 21 Apr 2016 21:31:01 -0700
Subject: [SPARK-14680] [SQL] Support all datatypes to use VectorizedHashmap in TungstenAggregate

## What changes were proposed in this pull request?

This PR adds support for all primitive datatypes, decimal types and string types in the VectorizedHashMap during aggregation.

## How was this patch tested?

Existing tests for group-by aggregates should already cover all these datatypes. Additionally, the generated code for all supported datatypes was manually inspected (details below).

Author: Sameer Agarwal

Closes #12440 from sameeragarwal/all-datatypes.
---
 .../execution/aggregate/TungstenAggregate.scala    |  18 ++-
 .../aggregate/VectorizedHashMapGenerator.scala     | 163 +++++++++++++++++----
 .../org/apache/spark/sql/internal/SQLConf.scala    |   2 +-
 3 files changed, 144 insertions(+), 39 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index d4cef8f310..5c0fc02861 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution._
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{LongType, StructType}
+import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
 import org.apache.spark.unsafe.KVIterator
 
 case class TungstenAggregate(
@@ -265,11 +265,7 @@ case class TungstenAggregate(
 
   // The name for Vectorized HashMap
   private var vectorizedHashMapTerm: String = _
-
-  // We currently only enable vectorized hashmap for long key/value types and partial aggregates
-  private val isVectorizedHashMapEnabled: Boolean = sqlContext.conf.columnarAggregateMapEnabled &&
-    (groupingKeySchema ++ bufferSchema).forall(_.dataType == LongType) &&
-    modes.forall(mode => mode == Partial || mode == PartialMerge)
+  private var isVectorizedHashMapEnabled: Boolean = _
 
   // The name for UnsafeRow HashMap
   private var hashMapTerm: String = _
@@ -447,10 +443,16 @@ case class TungstenAggregate(
     val initAgg = ctx.freshName("initAgg")
     ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
 
+    // Enable vectorized hash map for all primitive data types during partial aggregation
+    isVectorizedHashMapEnabled = sqlContext.conf.columnarAggregateMapEnabled &&
+      (groupingKeySchema ++ bufferSchema).forall(f => ctx.isPrimitiveType(f.dataType) ||
+        f.dataType.isInstanceOf[DecimalType] || f.dataType.isInstanceOf[StringType]) &&
+      bufferSchema.forall(!_.dataType.isInstanceOf[StringType]) && bufferSchema.nonEmpty &&
+      modes.forall(mode => mode == Partial || mode == PartialMerge)
     vectorizedHashMapTerm = ctx.freshName("vectorizedHashMap")
     val vectorizedHashMapClassName = ctx.freshName("VectorizedHashMap")
-    val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, vectorizedHashMapClassName,
-      groupingKeySchema, bufferSchema)
+    val vectorizedHashMapGenerator = new VectorizedHashMapGenerator(ctx, aggregateExpressions,
+      vectorizedHashMapClassName, groupingKeySchema, bufferSchema)
     // Create a name for iterator from vectorized HashMap
     val iterTermForVectorizedHashMap = ctx.freshName("vectorizedHashMapIter")
     if (isVectorizedHashMapEnabled) {
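For context, the new enablement check above admits plans whose grouping keys are primitive, decimal, or string typed, whose aggregation buffer is non-empty and non-string, and that run in partial-aggregation mode. The sketch below is illustrative only and is not part of the patch: a hypothetical job whose group-by satisfies those conditions, assuming the columnar aggregate map configuration stays at its new default of true (see the SQLConf change at the end of this diff). The object, column names, and data are made up.

```scala
// Hypothetical example, not part of this patch: grouping keys of IntegerType and
// StringType with a LongType sum buffer satisfy the new isVectorizedHashMapEnabled
// condition (string grouping keys are allowed; string buffer values are not).
import org.apache.spark.sql.SparkSession

object VectorizedHashMapExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("vectorized-hashmap-example").getOrCreate()
    import spark.implicits._

    val sales = Seq((1, "north", 10L), (2, "south", 20L), (1, "north", 30L))
      .toDF("storeId", "region", "amount")

    // Partial aggregation over (storeId, region) with a numeric buffer.
    sales.groupBy($"storeId", $"region").sum("amount").show()

    spark.stop()
  }
}
```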
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index dd9b2f097e..61bd6eb3cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -17,8 +17,9 @@
 
 package org.apache.spark.sql.execution.aggregate
 
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
-import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, DeclarativeAggregate}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
+import org.apache.spark.sql.types._
 
 /**
  * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache'
@@ -40,12 +41,32 @@ import org.apache.spark.sql.types.StructType
  */
 class VectorizedHashMapGenerator(
     ctx: CodegenContext,
+    aggregateExpressions: Seq[AggregateExpression],
     generatedClassName: String,
     groupingKeySchema: StructType,
     bufferSchema: StructType) {
-  val groupingKeys = groupingKeySchema.map(k => (k.dataType.typeName, ctx.freshName("key")))
-  val bufferValues = bufferSchema.map(k => (k.dataType.typeName, ctx.freshName("value")))
-  val groupingKeySignature = groupingKeys.map(_.productIterator.toList.mkString(" ")).mkString(", ")
+  case class Buffer(dataType: DataType, name: String)
+  val groupingKeys = groupingKeySchema.map(k => Buffer(k.dataType, ctx.freshName("key")))
+  val bufferValues = bufferSchema.map(k => Buffer(k.dataType, ctx.freshName("value")))
+  val groupingKeySignature =
+    groupingKeys.map(key => s"${ctx.javaType(key.dataType)} ${key.name}").mkString(", ")
+  val buffVars: Seq[ExprCode] = {
+    val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
+    val initExpr = functions.flatMap(f => f.initialValues)
+    initExpr.map { e =>
+      val isNull = ctx.freshName("bufIsNull")
+      val value = ctx.freshName("bufValue")
+      ctx.addMutableState("boolean", isNull, "")
+      ctx.addMutableState(ctx.javaType(e.dataType), value, "")
+      val ev = e.genCode(ctx)
+      val initVars =
+        s"""
+           | $isNull = ${ev.isNull};
+           | $value = ${ev.value};
+         """.stripMargin
+      ExprCode(ev.code + initVars, isNull, value)
+    }
+  }
 
   def generate(): String = {
     s"""
@@ -67,20 +88,28 @@ class VectorizedHashMapGenerator(
   private def initializeAggregateHashMap(): String = {
     val generatedSchema: String =
-      s"""
-         |new org.apache.spark.sql.types.StructType()
-         |${(groupingKeySchema ++ bufferSchema).map(key =>
-      s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
-        .mkString("\n")};
-       """.stripMargin
+      s"new org.apache.spark.sql.types.StructType()" +
+        (groupingKeySchema ++ bufferSchema).map { key =>
+          key.dataType match {
+            case d: DecimalType =>
+              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+                 |${d.precision}, ${d.scale}))""".stripMargin
+            case _ =>
+              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+          }
+        }.mkString("\n").concat(";")
 
     val generatedAggBufferSchema: String =
-      s"""
-         |new org.apache.spark.sql.types.StructType()
-         |${bufferSchema.map(key =>
-      s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})""")
-        .mkString("\n")};
-       """.stripMargin
+      s"new org.apache.spark.sql.types.StructType()" +
+        bufferSchema.map { key =>
+          key.dataType match {
+            case d: DecimalType =>
+              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+                 |${d.precision}, ${d.scale}))""".stripMargin
+            case _ =>
+              s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+          }
+        }.mkString("\n").concat(";")
 
     s"""
       |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch batch;
      |  private org.apache.spark.sql.execution.vectorized.ColumnarBatch aggregateBufferBatch;
      |  private int[] buckets;
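The rewritten schema generation above emits a Java expression that rebuilds the batch schema, now special-casing DecimalType so that precision and scale survive code generation. The following standalone sketch reproduces that mapping outside the code generator; the object name and fields are invented for illustration, and the non-decimal branch relies on the data type's toString matching a DataTypes field name, exactly as in the patch.

```scala
// Standalone sketch (not Spark's actual generator) of how the patch builds the Java
// expression that reconstructs the batch schema, special-casing DecimalType.
import org.apache.spark.sql.types._

object SchemaCodegenSketch {
  def schemaExpr(schema: StructType): String =
    "new org.apache.spark.sql.types.StructType()" +
      schema.map { field =>
        field.dataType match {
          case d: DecimalType =>
            // Preserve precision and scale explicitly.
            s""".add("${field.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(${d.precision}, ${d.scale}))"""
          case _ =>
            // e.g. IntegerType -> org.apache.spark.sql.types.DataTypes.IntegerType
            s""".add("${field.name}", org.apache.spark.sql.types.DataTypes.${field.dataType})"""
        }
      }.mkString("\n").concat(";")

  def main(args: Array[String]): Unit = {
    val schema = new StructType()
      .add("key", IntegerType)
      .add("price", DecimalType(20, 2))
    println(schemaExpr(schema))
  }
}
```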
@@ -122,12 +151,23 @@ class VectorizedHashMapGenerator(
    *   }}}
    */
   private def generateHashFunction(): String = {
+    val hash = ctx.freshName("hash")
+
+    def genHashForKeys(groupingKeys: Seq[Buffer]): String = {
+      groupingKeys.map { key =>
+        val result = ctx.freshName("result")
+        s"""
+           |${genComputeHash(ctx, key.name, key.dataType, result)}
+           |$hash = ($hash ^ (0x9e3779b9)) + $result + ($hash << 6) + ($hash >>> 2);
+         """.stripMargin
+      }.mkString("\n")
+    }
+
     s"""
       |private long hash($groupingKeySignature) {
-      |  long h = 0;
-      |  ${groupingKeys.map(key => s"h = (h ^ (0x9e3779b9)) + ${key._2} + (h << 6) + (h >>> 2);")
-        .mkString("\n")}
-      |  return h;
+      |  long $hash = 0;
+      |  ${genHashForKeys(groupingKeys)}
+      |  return $hash;
       |}
     """.stripMargin
   }
@@ -145,10 +185,17 @@ class VectorizedHashMapGenerator(
    *   }}}
    */
   private def generateEquals(): String = {
+
+    def genEqualsForKeys(groupingKeys: Seq[Buffer]): String = {
+      groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
+        s"""(${ctx.genEqual(key.dataType, ctx.getValue("batch", "buckets[idx]",
+          key.dataType, ordinal), key.name)})"""
+      }.mkString(" && ")
+    }
+
     s"""
       |private boolean equals(int idx, $groupingKeySignature) {
-      |  return ${groupingKeys.zipWithIndex.map(k =>
-        s"batch.column(${k._2}).getLong(buckets[idx]) == ${k._1._2}").mkString(" && ")};
+      |  return ${genEqualsForKeys(groupingKeys)};
       |}
     """.stripMargin
   }
@@ -187,21 +234,39 @@ class VectorizedHashMapGenerator(
    *   }}}
    */
   private def generateFindOrInsert(): String = {
+
+    def genCodeToSetKeys(groupingKeys: Seq[Buffer]): Seq[String] = {
+      groupingKeys.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
+        ctx.setValue("batch", "numRows", key.dataType, ordinal, key.name)
+      }
+    }
+
+    def genCodeToSetAggBuffers(bufferValues: Seq[Buffer]): Seq[String] = {
+      bufferValues.zipWithIndex.map { case (key: Buffer, ordinal: Int) =>
+        ctx.updateColumn("batch", "numRows", key.dataType, groupingKeys.length + ordinal,
+          buffVars(ordinal), nullable = true)
+      }
+    }
+
     s"""
       |public org.apache.spark.sql.execution.vectorized.ColumnarBatch.Row findOrInsert(${
          groupingKeySignature}) {
-      |  long h = hash(${groupingKeys.map(_._2).mkString(", ")});
+      |  long h = hash(${groupingKeys.map(_.name).mkString(", ")});
       |  int step = 0;
       |  int idx = (int) h & (numBuckets - 1);
       |  while (step < maxSteps) {
       |    // Return bucket index if it's either an empty slot or already contains the key
       |    if (buckets[idx] == -1) {
       |      if (numRows < capacity) {
-      |        ${groupingKeys.zipWithIndex.map(k =>
-          s"batch.column(${k._2}).putLong(numRows, ${k._1._2});").mkString("\n")}
-      |        ${bufferValues.zipWithIndex.map(k =>
-          s"batch.column(${groupingKeys.length + k._2}).putNull(numRows);")
-          .mkString("\n")}
+      |
+      |        // Initialize aggregate keys
+      |        ${genCodeToSetKeys(groupingKeys).mkString("\n")}
+      |
+      |        ${buffVars.map(_.code).mkString("\n")}
+      |
+      |        // Initialize aggregate values
+      |        ${genCodeToSetAggBuffers(bufferValues).mkString("\n")}
+      |
       |        buckets[idx] = numRows++;
       |        batch.setNumRows(numRows);
       |        aggregateBufferBatch.setNumRows(numRows);
@@ -210,7 +275,7 @@ class VectorizedHashMapGenerator(
       |      // No more space
       |      return null;
       |    }
-      |  } else if (equals(idx, ${groupingKeys.map(_._2).mkString(", ")})) {
+      |  } else if (equals(idx, ${groupingKeys.map(_.name).mkString(", ")})) {
       |    return aggregateBufferBatch.getRow(buckets[idx]);
       |  }
       |  idx = (idx + 1) & (numBuckets - 1);
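The generated hash() method above folds each grouping key into a running value with the (h ^ 0x9e3779b9) + k + (h << 6) + (h >>> 2) combiner, where k comes from genComputeHash for that key's type. Below is a small standalone sketch of that combining step applied directly to sample values; it is not part of the patch, and the key values are arbitrary.

```scala
// Sketch of the hash-combining scheme used by the generated hash(...) method.
// For integral keys, genComputeHash hashes the value to itself, so combining the raw
// values here mirrors what the generated Java does for a (long, int) key pair.
object HashCombineSketch {
  def combine(h: Long, k: Long): Long = (h ^ 0x9e3779b9) + k + (h << 6) + (h >>> 2)

  def main(args: Array[String]): Unit = {
    val longKey = 42L // a LongType grouping key
    val intKey = 7    // an IntegerType grouping key
    var h = 0L
    h = combine(h, longKey)
    h = combine(h, intKey.toLong)
    println(h)
  }
}
```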
@@ -238,4 +303,42 @@ class VectorizedHashMapGenerator(
       |}
     """.stripMargin
   }
+
+  private def genComputeHash(
+      ctx: CodegenContext,
+      input: String,
+      dataType: DataType,
+      result: String): String = {
+    def hashInt(i: String): String = s"int $result = $i;"
+    def hashLong(l: String): String = s"long $result = $l;"
+    def hashBytes(b: String): String = {
+      val hash = ctx.freshName("hash")
+      s"""
+         |int $result = 0;
+         |for (int i = 0; i < $b.length; i++) {
+         |  ${genComputeHash(ctx, s"$b[i]", ByteType, hash)}
+         |  $result = ($result ^ (0x9e3779b9)) + $hash + ($result << 6) + ($result >>> 2);
+         |}
+       """.stripMargin
+    }
+
+    dataType match {
+      case BooleanType => hashInt(s"$input ? 1 : 0")
+      case ByteType | ShortType | IntegerType | DateType => hashInt(input)
+      case LongType | TimestampType => hashLong(input)
+      case FloatType => hashInt(s"Float.floatToIntBits($input)")
+      case DoubleType => hashLong(s"Double.doubleToLongBits($input)")
+      case d: DecimalType =>
+        if (d.precision <= Decimal.MAX_LONG_DIGITS) {
+          hashLong(s"$input.toUnscaledLong()")
+        } else {
+          val bytes = ctx.freshName("bytes")
+          s"""
+            final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
+            ${hashBytes(bytes)}
+          """
+        }
+      case StringType => hashBytes(s"$input.getBytes()")
+    }
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index a4e82d80f5..eb976fbaad 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -482,7 +482,7 @@ object SQLConf {
     .internal()
     .doc("When true, aggregate with keys use an in-memory columnar map to speed up execution.")
     .booleanConf
-    .createWithDefault(false)
+    .createWithDefault(true)
 
   val FILE_SINK_LOG_DELETION = SQLConfigBuilder("spark.sql.streaming.fileSink.log.deletion")
     .internal()
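As a closing illustration of the decimal handling in genComputeHash above: decimals whose precision fits in a long (Decimal.MAX_LONG_DIGITS, i.e. 18 digits) are hashed from their unscaled long value, while wider decimals are hashed from the bytes of the unscaled BigInteger. The standalone sketch below, not part of the patch and using arbitrary sample values, shows which representation each branch reads.

```scala
// Standalone sketch: the two inputs the decimal hashing branches in genComputeHash read from.
import org.apache.spark.sql.types.Decimal

object DecimalHashInputSketch {
  def main(args: Array[String]): Unit = {
    // precision 7 <= Decimal.MAX_LONG_DIGITS (18): the generated code calls toUnscaledLong()
    val narrow = Decimal(BigDecimal("12345.67"))
    println(narrow.toUnscaledLong) // 1234567

    // precision 23 > 18: the generated code hashes the unscaled BigInteger's bytes
    val wide = Decimal(BigDecimal("12345678901234567890.123"))
    println(wide.toJavaBigDecimal.unscaledValue.toByteArray.length)
  }
}
```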