aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-01-29 10:24:23 -0800
committerDavies Liu <davies.liu@gmail.com>2016-01-29 10:24:23 -0800
commitc5f745ede01831b59c57effa7de88c648b82c13d (patch)
treeb8119933725897c711ac94da671a4dad0522b517
parente4c1162b6b3dbc8fc95cfe75c6e0bc2915575fb2 (diff)
downloadspark-c5f745ede01831b59c57effa7de88c648b82c13d.tar.gz
spark-c5f745ede01831b59c57effa7de88c648b82c13d.tar.bz2
spark-c5f745ede01831b59c57effa7de88c648b82c13d.zip
[SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen
Simplify the generated code of the hash expression (remove several unnecessary local variables), and avoid the null check when possible. Generated code comparison for `hash(int, double, string, array<int>)`: **before:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); } /* input[1, double] */ double value5 = i.getDouble(1); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); } /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array<int>] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { int result10 = value1; for (int index11 = 0; index11 < value9.numElements(); index11++) { if (!value9.isNullAt(index11)) { final int element12 = value9.getInt(index11); result10 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element12, result10); } } value1 = result10; } } ``` **after:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); /* input[1, double] */ double value5 = i.getDouble(1); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? 
null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array<int>] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { for (int index10 = 0; index10 < value9.numElements(); index10++) { final int element11 = value9.getInt(index10); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element11, value1); } } rowWriter14.write(0, value1); return result12; } ``` Author: Wenchen Fan <wenchen@databricks.com> Closes #10974 from cloud-fan/codegen.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala155
1 files changed, 69 insertions, 86 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 493e0aae01..8480c3f9a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
ev.isNull = "false"
- val childrenHash = children.zipWithIndex.map {
- case (child, dt) =>
- val childGen = child.gen(ctx)
- val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
- s"""
- ${childGen.code}
- if (!${childGen.isNull}) {
- ${childHash.code}
- ${ev.value} = ${childHash.value};
- }
- """
+ val childrenHash = children.map { child =>
+ val childGen = child.gen(ctx)
+ childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
+ computeHash(childGen.value, child.dataType, ev.value, ctx)
+ }
}.mkString("\n")
+
s"""
int ${ev.value} = $seed;
$childrenHash
"""
}
+ private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = {
+ if (nullable) {
+ s"""
+ if (!$isNull) {
+ $execution
+ }
+ """
+ } else {
+ "\n" + execution
+ }
+ }
+
+ private def nullSafeElementHash(
+ input: String,
+ index: String,
+ nullable: Boolean,
+ elementType: DataType,
+ result: String,
+ ctx: CodegenContext): String = {
+ val element = ctx.freshName("element")
+
+ generateNullCheck(nullable, s"$input.isNullAt($index)") {
+ s"""
+ final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
+ ${computeHash(element, elementType, result, ctx)}
+ """
+ }
+ }
+
private def computeHash(
input: String,
dataType: DataType,
- seed: String,
- ctx: CodegenContext): ExprCode = {
+ result: String,
+ ctx: CodegenContext): String = {
val hasher = classOf[Murmur3_x86_32].getName
- def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)")
- def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)")
- def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v)
+
+ def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);"
+ def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);"
+ def hashBytes(b: String): String =
+ s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);"
dataType match {
- case NullType => inlineValue(seed)
+ case NullType => ""
case BooleanType => hashInt(s"$input ? 1 : 0")
case ByteType | ShortType | IntegerType | DateType => hashInt(input)
case LongType | TimestampType => hashLong(input)
@@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
hashLong(s"$input.toUnscaledLong()")
} else {
val bytes = ctx.freshName("bytes")
- val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();"
- val offset = "Platform.BYTE_ARRAY_OFFSET"
- val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)"
- ExprCode(code, "false", result)
+ s"""
+ final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();
+ ${hashBytes(bytes)}
+ """
}
case CalendarIntervalType =>
- val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)"
- val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)"
- inlineValue(monthsHash)
- case BinaryType =>
- val offset = "Platform.BYTE_ARRAY_OFFSET"
- inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)")
+ val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)"
+ s"$result = $hasher.hashInt($input.months, $microsecondsHash);"
+ case BinaryType => hashBytes(input)
case StringType =>
val baseObject = s"$input.getBaseObject()"
val baseOffset = s"$input.getBaseOffset()"
val numBytes = s"$input.numBytes()"
- inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)")
+ s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);"
- case ArrayType(et, _) =>
- val result = ctx.freshName("result")
+ case ArrayType(et, containsNull) =>
val index = ctx.freshName("index")
- val element = ctx.freshName("element")
- val elementHash = computeHash(element, et, result, ctx)
- val code =
- s"""
- int $result = $seed;
- for (int $index = 0; $index < $input.numElements(); $index++) {
- if (!$input.isNullAt($index)) {
- final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)};
- ${elementHash.code}
- $result = ${elementHash.value};
- }
- }
- """
- ExprCode(code, "false", result)
+ s"""
+ for (int $index = 0; $index < $input.numElements(); $index++) {
+ ${nullSafeElementHash(input, index, containsNull, et, result, ctx)}
+ }
+ """
- case MapType(kt, vt, _) =>
- val result = ctx.freshName("result")
+ case MapType(kt, vt, valueContainsNull) =>
val index = ctx.freshName("index")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
- val key = ctx.freshName("key")
- val value = ctx.freshName("value")
- val keyHash = computeHash(key, kt, result, ctx)
- val valueHash = computeHash(value, vt, result, ctx)
- val code =
- s"""
- int $result = $seed;
- final ArrayData $keys = $input.keyArray();
- final ArrayData $values = $input.valueArray();
- for (int $index = 0; $index < $input.numElements(); $index++) {
- final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)};
- ${keyHash.code}
- $result = ${keyHash.value};
- if (!$values.isNullAt($index)) {
- final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)};
- ${valueHash.code}
- $result = ${valueHash.value};
- }
- }
- """
- ExprCode(code, "false", result)
+ s"""
+ final ArrayData $keys = $input.keyArray();
+ final ArrayData $values = $input.valueArray();
+ for (int $index = 0; $index < $input.numElements(); $index++) {
+ ${nullSafeElementHash(keys, index, false, kt, result, ctx)}
+ ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)}
+ }
+ """
case StructType(fields) =>
- val result = ctx.freshName("result")
- val fieldsHash = fields.map(_.dataType).zipWithIndex.map {
- case (dt, index) =>
- val field = ctx.freshName("field")
- val fieldHash = computeHash(field, dt, result, ctx)
- s"""
- if (!$input.isNullAt($index)) {
- final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)};
- ${fieldHash.code}
- $result = ${fieldHash.value};
- }
- """
+ fields.zipWithIndex.map { case (field, index) =>
+ nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx)
}.mkString("\n")
- val code =
- s"""
- int $result = $seed;
- $fieldsHash
- """
- ExprCode(code, "false", result)
- case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx)
+ case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx)
}
}
}