From c2ea79f96acd076351b48162644ed1cff4c8e090 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 13 Jan 2016 12:29:02 -0800 Subject: [SPARK-12642][SQL] improve the hash expression to be decoupled from unsafe row https://issues.apache.org/jira/browse/SPARK-12642 Author: Wenchen Fan Closes #10694 from cloud-fan/hash-expr. --- python/pyspark/sql/functions.py | 2 +- .../spark/sql/catalyst/expressions/UnsafeRow.java | 4 - .../spark/sql/catalyst/expressions/misc.scala | 251 ++++++++++++++++++++- .../catalyst/expressions/MiscFunctionsSuite.scala | 6 +- .../spark/sql/sources/BucketedWriteSuite.scala | 26 ++- .../apache/spark/unsafe/hash/Murmur3_x86_32.java | 28 ++- 6 files changed, 288 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index b0390cb994..719eca8f55 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1023,7 +1023,7 @@ def hash(*cols): """Calculates the hash code of given columns, and returns the result as a int column. >>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect() - [Row(hash=1358996357)] + [Row(hash=-757602832)] """ sc = SparkContext._active_spark_context jc = sc._jvm.functions.hash(_to_seq(sc, cols, _to_java_column)) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index b8d3c49100..1a351933a3 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -566,10 +566,6 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); } - public int hashCode(int seed) { - return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed); - } - @Override public boolean equals(Object other) { if (other instanceof UnsafeRow) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index cc406a39f0..4751fbe414 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -25,8 +25,11 @@ import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData} import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.unsafe.hash.Murmur3_x86_32 +import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.unsafe.Platform /** * A function that calculates an MD5 128-bit checksum and returns it as a hex string @@ -184,8 +187,31 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp * A function that calculates hash value for a group of expressions. Note that the `seed` argument * is not exposed to users and should only be set inside spark SQL. * - * Internally this function will write arguments into an [[UnsafeRow]], and calculate hash code of - * the unsafe row using murmur3 hasher with a seed. 
+ * The hash value for an expression depends on its type and seed:
+ * - null: seed
+ * - boolean: turn the boolean into an int, 1 for true, 0 for false, and then use
+ * murmur3 to hash this int with seed.
+ * - byte, short, int: use murmur3 to hash the input as an int with seed.
+ * - long: use murmur3 to hash the long input with seed.
+ * - float: turn it into an int with java.lang.Float.floatToIntBits(input), and hash it.
+ * - double: turn it into a long with java.lang.Double.doubleToLongBits(input), and hash it.
+ * - decimal: if it's a small decimal, i.e. precision <= 18, turn it into a long and
+ * hash it. Otherwise, turn it into bytes and hash them.
+ * - calendar interval: hash `microseconds` first, and use the result as seed to hash `months`.
+ * - binary: use murmur3 to hash the bytes with seed.
+ * - string: get the UTF-8 bytes of the string and hash them.
+ * - array: `result` starts with seed; then, using `result` as the seed, recursively
+ * calculate the hash value of each element and assign that hash value back
+ * to `result`.
+ * - map: `result` starts with seed; then, using `result` as the seed, recursively
+ * calculate the hash value of each key-value pair and assign that hash
+ * value back to `result`.
+ * - struct: `result` starts with seed; then, using `result` as the seed, recursively
+ * calculate the hash value of each field and assign that hash value back
+ * to `result`.
+ *
+ * Finally, we aggregate the hash values of all expressions in the same way as for struct.
+ *
 * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle
 * and bucketing have same data distribution.
 */
@@ -206,22 +232,225 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
 }
 }
 
- private lazy val unsafeProjection = UnsafeProjection.create(children)
+ override def prettyName: String = "hash"
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
 
 override def eval(input: InternalRow): Any = {
- unsafeProjection(input).hashCode(seed)
+ var hash = seed
+ var i = 0
+ val len = children.length
+ while (i < len) {
+ hash = computeHash(children(i).eval(input), children(i).dataType, hash)
+ i += 1
+ }
+ hash
 }
 
+ private def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
+ def hashInt(i: Int): Int = Murmur3_x86_32.hashInt(i, seed)
+ def hashLong(l: Long): Int = Murmur3_x86_32.hashLong(l, seed)
+
+ value match {
+ case null => seed
+ case b: Boolean => hashInt(if (b) 1 else 0)
+ case b: Byte => hashInt(b)
+ case s: Short => hashInt(s)
+ case i: Int => hashInt(i)
+ case l: Long => hashLong(l)
+ case f: Float => hashInt(java.lang.Float.floatToIntBits(f))
+ case d: Double => hashLong(java.lang.Double.doubleToLongBits(d))
+ case d: Decimal =>
+ val precision = dataType.asInstanceOf[DecimalType].precision
+ if (precision <= Decimal.MAX_LONG_DIGITS) {
+ hashLong(d.toUnscaledLong)
+ } else {
+ val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray
+ Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed)
+ }
+ case c: CalendarInterval => Murmur3_x86_32.hashInt(c.months, hashLong(c.microseconds))
+ case a: Array[Byte] =>
+ Murmur3_x86_32.hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
+ case s: UTF8String =>
+ Murmur3_x86_32.hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
+
+ case array: ArrayData =>
+ val elementType = dataType match {
+ case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
+ case ArrayType(et, _) => et
+ }
+ var result = seed
+ var i = 0
+ while (i < array.numElements()) {
+ result = computeHash(array.get(i, elementType), elementType, result)
+ i += 1
+ }
+ result
+
+ case map: MapData =>
+ val (kt, vt) = dataType match {
+ case udt: UserDefinedType[_] =>
+ val mapType = udt.sqlType.asInstanceOf[MapType]
+ mapType.keyType -> mapType.valueType
+ case MapType(kt, vt, _) => kt -> vt
+ }
+ val keys = map.keyArray()
+ val values = map.valueArray()
+ var result = seed
+ var i = 0
+ while (i < map.numElements()) {
+ result = computeHash(keys.get(i, kt), kt, result)
+ result = computeHash(values.get(i, vt), vt, result)
+ i += 1
+ }
+ result
+
+ case struct: InternalRow =>
+ val types: Array[DataType] = dataType match {
+ case udt: UserDefinedType[_] =>
+ udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
+ case StructType(fields) => fields.map(_.dataType)
+ }
+ var result = seed
+ var i = 0
+ val len = struct.numFields
+ while (i < len) {
+ result = computeHash(struct.get(i, types(i)), types(i), result)
+ i += 1
+ }
+ result
+ }
+ }
+
+
 override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children)
 ev.isNull = "false"
+ val childrenHash = children.zipWithIndex.map {
+ case (child, _) =>
+ val childGen = child.gen(ctx)
+ val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
+ s"""
+ ${childGen.code}
+ if (!${childGen.isNull}) {
+ ${childHash.code}
+ ${ev.value} = ${childHash.value};
+ }
+ """
+ }.mkString("\n")
 s"""
- ${unsafeRow.code}
- final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
+ int ${ev.value} = $seed;
+ $childrenHash
 """
 }
 
- override def prettyName: String = "hash"
-
- override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
+ private def computeHash(
+ input: String,
+ dataType: DataType,
+ seed: String,
+ ctx: CodeGenContext): GeneratedExpressionCode = {
+ val hasher = classOf[Murmur3_x86_32].getName
+ def hashInt(i: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashInt($i, $seed)")
+ def hashLong(l: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashLong($l, $seed)")
+ def inlineValue(v: String): GeneratedExpressionCode =
+ GeneratedExpressionCode(code = "", isNull = "false", value = v)
+
+ dataType match {
+ case NullType => inlineValue(seed)
+ case BooleanType => hashInt(s"$input ? 
1 : 0") + case ByteType | ShortType | IntegerType | DateType => hashInt(input) + case LongType | TimestampType => hashLong(input) + case FloatType => hashInt(s"Float.floatToIntBits($input)") + case DoubleType => hashLong(s"Double.doubleToLongBits($input)") + case d: DecimalType => + if (d.precision <= Decimal.MAX_LONG_DIGITS) { + hashLong(s"$input.toUnscaledLong()") + } else { + val bytes = ctx.freshName("bytes") + val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();" + val offset = "Platform.BYTE_ARRAY_OFFSET" + val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)" + GeneratedExpressionCode(code, "false", result) + } + case CalendarIntervalType => + val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)" + val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)" + inlineValue(monthsHash) + case BinaryType => + val offset = "Platform.BYTE_ARRAY_OFFSET" + inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)") + case StringType => + val baseObject = s"$input.getBaseObject()" + val baseOffset = s"$input.getBaseOffset()" + val numBytes = s"$input.numBytes()" + inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)") + + case ArrayType(et, _) => + val result = ctx.freshName("result") + val index = ctx.freshName("index") + val element = ctx.freshName("element") + val elementHash = computeHash(element, et, result, ctx) + val code = + s""" + int $result = $seed; + for (int $index = 0; $index < $input.numElements(); $index++) { + if (!$input.isNullAt($index)) { + final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)}; + ${elementHash.code} + $result = ${elementHash.value}; + } + } + """ + GeneratedExpressionCode(code, "false", result) + + case MapType(kt, vt, _) => + val result = ctx.freshName("result") + val index = ctx.freshName("index") + val keys = ctx.freshName("keys") + val values = ctx.freshName("values") + val key = ctx.freshName("key") + val value = ctx.freshName("value") + val keyHash = computeHash(key, kt, result, ctx) + val valueHash = computeHash(value, vt, result, ctx) + val code = + s""" + int $result = $seed; + final ArrayData $keys = $input.keyArray(); + final ArrayData $values = $input.valueArray(); + for (int $index = 0; $index < $input.numElements(); $index++) { + final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)}; + ${keyHash.code} + $result = ${keyHash.value}; + if (!$values.isNullAt($index)) { + final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)}; + ${valueHash.code} + $result = ${valueHash.value}; + } + } + """ + GeneratedExpressionCode(code, "false", result) + + case StructType(fields) => + val result = ctx.freshName("result") + val fieldsHash = fields.map(_.dataType).zipWithIndex.map { + case (dt, index) => + val field = ctx.freshName("field") + val fieldHash = computeHash(field, dt, result, ctx) + s""" + if (!$input.isNullAt($index)) { + final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)}; + ${fieldHash.code} + $result = ${fieldHash.value}; + } + """ + }.mkString("\n") + val code = + s""" + int $result = $seed; + $fieldsHash + """ + GeneratedExpressionCode(code, "false", result) + + case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx) + } + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 64161bebdc..75131a6170 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -79,7 +79,8 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 .add("long", LongType)
 .add("float", FloatType)
 .add("double", DoubleType)
- .add("decimal", DecimalType.SYSTEM_DEFAULT)
+ .add("bigDecimal", DecimalType.SYSTEM_DEFAULT)
+ .add("smallDecimal", DecimalType.USER_DEFAULT)
 .add("string", StringType)
 .add("binary", BinaryType)
 .add("date", DateType)
@@ -126,7 +127,8 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
 val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
 case (value, dt) => Literal.create(value, dt)
 }
- checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed))
+ // Only test that the interpreted version gives the same result as the codegen version.
+ checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval())
 }
 }
 }
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 7f1745705a..b718b7cefb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.sources
 import java.io.File
 
 import org.apache.spark.sql.{AnalysisException, QueryTest}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
 
 class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
 import testImplicits._
@@ -70,6 +71,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
 }
 }
 
+ private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+
 private def testBucketing(
 dataDir: File,
 source: String,
@@ -82,27 +85,30 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
 assert(groupedBucketFiles.size <= 8)
 
 for ((bucketId, bucketFiles) <- groupedBucketFiles) {
- for (bucketFile <- bucketFiles) {
- val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath)
- .select((bucketCols ++ sortCols).map(col): _*)
+ for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
+ val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+ val columns = (bucketCols ++ sortCols).zip(types).map {
+ case (colName, dt) => col(colName).cast(dt)
+ }
+ val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
 
 if (sortCols.nonEmpty) {
- checkAnswer(df.sort(sortCols.map(col): _*), df.collect())
+ checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
 }
 
- val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect()
+ val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+ val rows = qe.toRdd.map(_.copy()).collect()
+ val getHashCode =
+ UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, 
qe.analyzed.output) for (row <- rows) { - assert(row.isInstanceOf[UnsafeRow]) - val actualBucketId = (row.hashCode() % 8 + 8) % 8 + val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8) assert(actualBucketId == bucketId) } } } } - private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k") - test("write bucketed data") { for (source <- Seq("parquet", "json", "orc")) { withTable("bucketed_table") { diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java index 4276f25c21..5e7ee480ca 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java @@ -38,6 +38,10 @@ public final class Murmur3_x86_32 { } public int hashInt(int input) { + return hashInt(input, seed); + } + + public static int hashInt(int input, int seed) { int k1 = mixK1(input); int h1 = mixH1(seed, k1); @@ -51,16 +55,38 @@ public final class Murmur3_x86_32 { public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) { // This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method. assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)"; + int h1 = hashBytesByInt(base, offset, lengthInBytes, seed); + return fmix(h1, lengthInBytes); + } + + public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes >= 0): "lengthInBytes cannot be negative"; + int lengthAligned = lengthInBytes - lengthInBytes % 4; + int h1 = hashBytesByInt(base, offset, lengthAligned, seed); + for (int i = lengthAligned; i < lengthInBytes; i++) { + int halfWord = Platform.getByte(base, offset + i); + int k1 = mixK1(halfWord); + h1 = mixH1(h1, k1); + } + return fmix(h1, lengthInBytes); + } + + private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) { + assert (lengthInBytes % 4 == 0); int h1 = seed; for (int i = 0; i < lengthInBytes; i += 4) { int halfWord = Platform.getInt(base, offset + i); int k1 = mixK1(halfWord); h1 = mixH1(h1, k1); } - return fmix(h1, lengthInBytes); + return h1; } public int hashLong(long input) { + return hashLong(input, seed); + } + + public static int hashLong(long input, int seed) { int low = (int) input; int high = (int) (input >>> 32); -- cgit v1.2.3
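
Illustration (not part of the patch itself): the per-type recipe documented on Murmur3Hash
can be reproduced in a few lines of plain Scala. This is a minimal sketch under two
assumptions: the mix constants follow the standard 32-bit Murmur3 spec that
Murmur3_x86_32 is based on, and byte arrays are read as little-endian ints, matching
Platform.getInt on x86. The object name Murmur3Sketch and its helpers are made up for
the example.

    import java.nio.{ByteBuffer, ByteOrder}

    object Murmur3Sketch {
      private def mixK1(k: Int): Int = Integer.rotateLeft(k * 0xcc9e2d51, 15) * 0x1b873593
      private def mixH1(h1: Int, k1: Int): Int = Integer.rotateLeft(h1 ^ k1, 13) * 5 + 0xe6546b64
      private def fmix(h: Int, length: Int): Int = {
        var h1 = h ^ length
        h1 ^= h1 >>> 16; h1 *= 0x85ebca6b
        h1 ^= h1 >>> 13; h1 *= 0xc2b2ae35
        h1 ^ (h1 >>> 16)
      }

      // An int-sized input takes one mix round, then is finalized with length 4.
      def hashInt(input: Int, seed: Int): Int = fmix(mixH1(seed, mixK1(input)), 4)

      // A long is mixed as its low word followed by its high word, finalized with length 8.
      def hashLong(input: Long, seed: Int): Int = {
        val low = input.toInt
        val high = (input >>> 32).toInt
        fmix(mixH1(mixH1(seed, mixK1(low)), mixK1(high)), 8)
      }

      // Port of hashUnsafeBytes above: mix 4-byte words first, then trailing bytes one at a time.
      def hashBytes(bytes: Array[Byte], seed: Int): Int = {
        val buf = ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN)
        val lengthAligned = bytes.length - bytes.length % 4
        var h1 = seed
        var i = 0
        while (i < lengthAligned) { h1 = mixH1(h1, mixK1(buf.getInt(i))); i += 4 }
        while (i < bytes.length) { h1 = mixH1(h1, mixK1(bytes(i))); i += 1 } // sign-extends, like Platform.getByte
        fmix(h1, bytes.length)
      }
    }

    // Hashing a row (true, 42L, "abc") with seed 42: each field's hash becomes the seed for
    // the next field, which is exactly the struct-style aggregation the doc comment describes.
    var hash = 42
    hash = Murmur3Sketch.hashInt(1, hash)                           // boolean true -> 1
    hash = Murmur3Sketch.hashLong(42L, hash)                        // long
    hash = Murmur3Sketch.hashBytes("abc".getBytes("UTF-8"), hash)   // string -> UTF-8 bytes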
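
On the bucketing side, BucketedWriteSuite recomputes each row's bucket id as
Utils.nonNegativeMod(hash, 8). A sketch of that mapping, where the nonNegativeMod body is
assumed to match org.apache.spark.util.Utils, and 42 is assumed to be the default seed
(it is the seed UnsafeRow.hashCode uses above):

    // Map a possibly negative hash into [0, numBuckets).
    def nonNegativeMod(x: Int, mod: Int): Int = {
      val rawMod = x % mod
      rawMod + (if (rawMod < 0) mod else 0)
    }

    val numBuckets = 8
    // e.g. a single int bucket column with value 5:
    val bucketId = nonNegativeMod(Murmur3Sketch.hashInt(5, 42), numBuckets)
    assert(bucketId >= 0 && bucketId < numBuckets)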