diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-01-29 10:24:23 -0800 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-01-29 10:24:23 -0800 |
commit | c5f745ede01831b59c57effa7de88c648b82c13d (patch) | |
tree | b8119933725897c711ac94da671a4dad0522b517 /sql | |
parent | e4c1162b6b3dbc8fc95cfe75c6e0bc2915575fb2 (diff) | |
download | spark-c5f745ede01831b59c57effa7de88c648b82c13d.tar.gz spark-c5f745ede01831b59c57effa7de88c648b82c13d.tar.bz2 spark-c5f745ede01831b59c57effa7de88c648b82c13d.zip |
[SPARK-13072] [SQL] simplify and improve murmur3 hash expression codegen
simplify(remove several unnecessary local variables) the generated code of hash expression, and avoid null check if possible.
generated code comparison for `hash(int, double, string, array<string>)`:
**before:**
```
public UnsafeRow apply(InternalRow i) {
/* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */
int value1 = 42;
/* input[0, int] */
int value3 = i.getInt(0);
if (!false) {
value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1);
}
/* input[1, double] */
double value5 = i.getDouble(1);
if (!false) {
value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1);
}
/* input[2, string] */
boolean isNull6 = i.isNullAt(2);
UTF8String value7 = isNull6 ? null : (i.getUTF8String(2));
if (!isNull6) {
value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1);
}
/* input[3, array<int>] */
boolean isNull8 = i.isNullAt(3);
ArrayData value9 = isNull8 ? null : (i.getArray(3));
if (!isNull8) {
int result10 = value1;
for (int index11 = 0; index11 < value9.numElements(); index11++) {
if (!value9.isNullAt(index11)) {
final int element12 = value9.getInt(index11);
result10 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element12, result10);
}
}
value1 = result10;
}
}
```
**after:**
```
public UnsafeRow apply(InternalRow i) {
/* hash(input[0, int],input[1, double],input[2, string],input[3, array<int>],42) */
int value1 = 42;
/* input[0, int] */
int value3 = i.getInt(0);
value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(value3, value1);
/* input[1, double] */
double value5 = i.getDouble(1);
value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashLong(Double.doubleToLongBits(value5), value1);
/* input[2, string] */
boolean isNull6 = i.isNullAt(2);
UTF8String value7 = isNull6 ? null : (i.getUTF8String(2));
if (!isNull6) {
value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashUnsafeBytes(value7.getBaseObject(), value7.getBaseOffset(), value7.numBytes(), value1);
}
/* input[3, array<int>] */
boolean isNull8 = i.isNullAt(3);
ArrayData value9 = isNull8 ? null : (i.getArray(3));
if (!isNull8) {
for (int index10 = 0; index10 < value9.numElements(); index10++) {
final int element11 = value9.getInt(index10);
value1 = org.apache.spark.unsafe.hash.Murmur3_x86_32.hashInt(element11, value1);
}
}
rowWriter14.write(0, value1);
return result12;
}
```
Author: Wenchen Fan <wenchen@databricks.com>
Closes #10974 from cloud-fan/codegen.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala | 155 |
1 files changed, 69 insertions, 86 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 493e0aae01..8480c3f9a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -325,36 +325,62 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" - val childrenHash = children.zipWithIndex.map { - case (child, dt) => - val childGen = child.gen(ctx) - val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx) - s""" - ${childGen.code} - if (!${childGen.isNull}) { - ${childHash.code} - ${ev.value} = ${childHash.value}; - } - """ + val childrenHash = children.map { child => + val childGen = child.gen(ctx) + childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + computeHash(childGen.value, child.dataType, ev.value, ctx) + } }.mkString("\n") + s""" int ${ev.value} = $seed; $childrenHash """ } + private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execution + } + """ + } else { + "\n" + execution + } + } + + private def nullSafeElementHash( + input: String, + index: String, + nullable: Boolean, + elementType: DataType, + result: String, + ctx: CodegenContext): String = { + val element = ctx.freshName("element") + + generateNullCheck(nullable, s"$input.isNullAt($index)") { + s""" + final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; + ${computeHash(element, elementType, result, ctx)} + """ + } + } + private def computeHash( input: String, dataType: DataType, - seed: String, - ctx: CodegenContext): ExprCode = { + result: String, + ctx: CodegenContext): String = { val hasher = classOf[Murmur3_x86_32].getName - def hashInt(i: String): ExprCode = inlineValue(s"$hasher.hashInt($i, $seed)") - def hashLong(l: String): ExprCode = inlineValue(s"$hasher.hashLong($l, $seed)") - def inlineValue(v: String): ExprCode = ExprCode(code = "", isNull = "false", value = v) + + def hashInt(i: String): String = s"$result = $hasher.hashInt($i, $result);" + def hashLong(l: String): String = s"$result = $hasher.hashLong($l, $result);" + def hashBytes(b: String): String = + s"$result = $hasher.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length, $result);" dataType match { - case NullType => inlineValue(seed) + case NullType => "" case BooleanType => hashInt(s"$input ? 1 : 0") case ByteType | ShortType | IntegerType | DateType => hashInt(input) case LongType | TimestampType => hashLong(input) @@ -365,91 +391,48 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression hashLong(s"$input.toUnscaledLong()") } else { val bytes = ctx.freshName("bytes") - val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" - val offset = "Platform.BYTE_ARRAY_OFFSET" - val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" - ExprCode(code, "false", result) + s""" + final byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray(); + ${hashBytes(bytes)} + """ } case CalendarIntervalType => - val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" - val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)" - inlineValue(monthsHash) - case BinaryType => - val offset = "Platform.BYTE_ARRAY_OFFSET" - inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)") + val microsecondsHash = s"$hasher.hashLong($input.microseconds, $result)" + s"$result = $hasher.hashInt($input.months, $microsecondsHash);" + case BinaryType => hashBytes(input) case StringType => val baseObject = s"$input.getBaseObject()" val baseOffset = s"$input.getBaseOffset()" val numBytes = s"$input.numBytes()" - inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)") + s"$result = $hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $result);" - case ArrayType(et, _) => - val result = ctx.freshName("result") + case ArrayType(et, containsNull) => val index = ctx.freshName("index") - val element = ctx.freshName("element") - val elementHash = computeHash(element, et, result, ctx) - val code = - s""" - int $result = $seed; - for (int $index = 0; $index < $input.numElements(); $index++) { - if (!$input.isNullAt($index)) { - final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; - ${elementHash.code} - $result = ${elementHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(input, index, containsNull, et, result, ctx)} + } + """ - case MapType(kt, vt, _) => - val result = ctx.freshName("result") + case MapType(kt, vt, valueContainsNull) => val index = ctx.freshName("index") val keys = ctx.freshName("keys") val values = ctx.freshName("values") - val key = ctx.freshName("key") - val value = ctx.freshName("value") - val keyHash = computeHash(key, kt, result, ctx) - val valueHash = computeHash(value, vt, result, ctx) - val code = - s""" - int $result = $seed; - final ArrayData $keys = $input.keyArray(); - final ArrayData $values = $input.valueArray(); - for (int $index = 0; $index < $input.numElements(); $index++) { - final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; - ${keyHash.code} - $result = ${keyHash.value}; - if (!$values.isNullAt($index)) { - final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; - ${valueHash.code} - $result = ${valueHash.value}; - } - } - """ - ExprCode(code, "false", result) + s""" + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + ${nullSafeElementHash(keys, index, false, kt, result, ctx)} + ${nullSafeElementHash(values, index, valueContainsNull, vt, result, ctx)} + } + """ case StructType(fields) => - val result = ctx.freshName("result") - val fieldsHash = fields.map(_.dataType).zipWithIndex.map { - case (dt, index) => - val field = ctx.freshName("field") - val fieldHash = computeHash(field, dt, result, ctx) - s""" - if (!$input.isNullAt($index)) { - final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)}; - ${fieldHash.code} - $result = ${fieldHash.value}; - } - """ + fields.zipWithIndex.map { case (field, index) => + nullSafeElementHash(input, index.toString, field.nullable, field.dataType, result, ctx) }.mkString("\n") - val code = - s""" - int $result = $seed; - $fieldsHash - """ - ExprCode(code, "false", result) - case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) + case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, result, ctx) } } } |