From c5f745ede01831b59c57effa7de88c648b82c13d Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 29 Jan 2016 10:24:23 -0800 Subject: [SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen simplify(remove several unnecessary local variables) the generated code of hash expression, and avoid null check if possible. generated code comparison for `hash(int, double, string, array)`: **before:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); } /* input[1, double] */ double value5 = i.getDouble(1); if (!false) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); } /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { int result10 = value1; for (int index11 = 0; index11 < value9.numElements(); index11++) { if (!value9.isNullAt(index11)) { final int element12 = value9.getInt(index11); result10 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element12, result10); } } value1 = result10; } } ``` **after:** ``` public UnsafeRow apply(InternalRow i) { /* hash(input[0, int],input[1, double],input[2, string],input[3, array],42) */ int value1 = 42; /* input[0, int] */ int value3 = i.getInt(0); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1); /* input[1, double] */ double value5 = i.getDouble(1); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1); /* input[2, string] */ boolean isNull6 = i.isNullAt(2); UTF8String value7 = isNull6 ? null : (i.getUTF8String(2)); if (!isNull6) { value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1); } /* input[3, array] */ boolean isNull8 = i.isNullAt(3); ArrayData value9 = isNull8 ? null : (i.getArray(3)); if (!isNull8) { for (int index10 = 0; index10 < value9.numElements(); index10++) { final int element11 = value9.getInt(index10); value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element11, value1); } } rowWriter14.write(0, value1); return result12; } ``` Author: Wenchen Fan Closes #10974 from cloud-fan/codegen. --- .../spark/sql/catalyst/expressions/misc.scala | 155 +++++++++------------ 1 file changed, 69 insertions(+), 86 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 493e0aae01..8480c3f9a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" - val childrenHash = children.zipWithIndex.map { - case (child, dt) => - val childGen = child.gen(ctx) - val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx) - s""" - ${childGen.code} - if (!${childGen.isNull}) { - ${childHash.code} - ${ev.value} = ${childHash.value}; - } - """ + val childrenHash = children.map { child => + val childGen = child.gen(ctx) + childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + computeHash(childGen.value, child.dataType, ev.value, ctx) + } }.mkString("\n") + s""" int ${ev.value} = $seed; $childrenHash """ } + private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execution + } + """ + } else { + "\n" + execution + } + } + + private def nullSafeElementHash( + input: String, + index: String, + nullable: Boolean, + elementType: DataType, + result: String, + ctx: CodegenContext): String = { + val element = ctx.freshName("element") + + generateNullCheck(nullable, s"$input.isNullAt($index)") { + s""" + final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + ${computeHash(element, elementType, result, ctx)} + """ + } + } + private def computeHash( input: String, dataType: DataType, - seed: String, - ctx: CodegenContext): ExprCode = { + result: String, + ctx: CodegenContext): String = { val hasher = classOf[Murmur3_x86_32].getName - def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") - def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") - def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v) + + def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" + def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);" + def hashBytes(b: String): String = + s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);" dataType match { - case NullType => inlineValue(seed) + case NullType => "" case BooleanType => hashInt(s"$input ? 1 : 0") case ByteType | ShortType | IntegerType | DateType => hashInt(input) case LongType | TimestampType => hashLong(input) @@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression hashLong(s"$input.toUnscaledLong()") } else { val bytes = ctx.freshName("bytes") - val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" - val offset = "Platform.BYTE_ARRAY_OFFSET" - val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" - ExprCode(code, "false", result) + s""" + final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); + ${hashBytes(bytes)} + """ } case CalendarIntervalType => - val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" - val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)" - inlineValue(monthsHash) - case BinaryType => - val offset = "Platform.BYTE_ARRAY_OFFSET" - inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)") + val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)" + s"$result = $hasher.hashInt($input.months, $microsecondsHash);" + case BinaryType => hashBytes(input) case StringType => val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" - inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)") + s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" - case ArrayType(et, _) => - val result = ctx.freshName("result") + case ArrayType(et, containsNull) => val index = ctx.freshName("index") - val element = ctx.freshName("element") - val elementHash = computeHash(element, et, result, ctx) - val code = - s""" - int $result = $seed; - for (int $index = 0; $index < $input.numElements(); $index++) { - if (!$input.isNullAt($index)) { - final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; - ${elementHash.code} - $result = ${elementHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(input, index, containsNull, et, result, ctx)} + } + """ - case MapType(kt, vt, _) => - val result = ctx.freshName("result") + case MapType(kt, vt, valueContainsNull) => val index = ctx.freshName("index") val keys = ctx.freshName("keys") val values = ctx.freshName("values") - val key = ctx.freshName("key") - val value = ctx.freshName("value") - val keyHash = computeHash(key, kt, result, ctx) - val valueHash = computeHash(value, vt, result, ctx) - val code = - s""" - int $result = $seed; - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - for (int $index = 0; $index < $input.numElements(); $index++) { - final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; - ${keyHash.code} - $result = ${keyHash.value}; - if (!$values.isNullAt($index)) { - final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; - ${valueHash.code} - $result = ${valueHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(keys, index, false, kt, result, ctx)} + ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)} + } + """ case StructType(fields) => - val result = ctx.freshName("result") - val fieldsHash = fields.map(_.dataType).zipWithIndex.map { - case (dt, index) => - val field = ctx.freshName("field") - val fieldHash = computeHash(field, dt, result, ctx) - s""" - if (!$input.isNullAt($index)) { - final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)}; - ${fieldHash.code} - $result = ${fieldHash.value}; - } - """ + fields.zipWithIndex.map { case (field, index) => + nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) }.mkString("\n") - val code = - s""" - int $result = $seed; - $fieldsHash - """ - ExprCode(code, "false", result) - case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) + case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) } } } -- cgit v1.2.3