author     Wenchen Fan <wenchen@databricks.com>  2016-01-13 12:29:02 -0800
committer  Reynold Xin <rxin@databricks.com>     2016-01-13 12:29:02 -0800
commit     c2ea79f96acd076351b48162644ed1cff4c8e090 (patch)
tree       55ca22bdd84dac3cb225cd2b9bddaf0c11c93d19
parent     e4e0b3f7b2945aae5ec7c3d68296010bbc5160cf (diff)
[SPARK-12642][SQL] improve the hash expression to be decoupled from unsafe row
https://issues.apache.org/jira/browse/SPARK-12642

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10694 from cloud-fan/hash-expr.
-rw-r--r--  python/pyspark/sql/functions.py                                                       |   2
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java   |   4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala      | 251
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala |   6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala         |  26
-rw-r--r--  unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java                 |  28
6 files changed, 288 insertions, 29 deletions
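For orientation before the per-file diffs: the old implementation hashed the byte representation of an UnsafeRow built from all children, while the new one hashes each child value according to its data type and feeds the running result back in as the seed for the next value. A minimal standalone sketch of that seed-chaining idea (not part of the patch; illustrative names only, and it uses Scala's built-in MurmurHash3 rather than Spark's Murmur3_x86_32, so the concrete values differ from Spark's):

```scala
import scala.util.hashing.MurmurHash3

object SeedChainingSketch {
  // Fold a hash over a sequence of values, feeding each intermediate result
  // back in as the seed for the next value -- the pattern the reworked
  // Murmur3Hash expression applies to its children, struct fields, arrays, and maps.
  def chainedHash(values: Seq[Any], seed: Int): Int =
    values.foldLeft(seed)((h, v) => MurmurHash3.mix(h, v.##))

  def main(args: Array[String]): Unit =
    println(chainedHash(Seq(1, "foo", 2.5), 42))
}
```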
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index b0390cb994..719eca8f55 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1023,7 +1023,7 @@ def hash(*cols):
"""Calculates the hash code of given columns, and returns the result as a int column.
>>> sqlContext.createDataFrame([('ABC',)], ['a']).select(hash('a').alias('hash')).collect()
- [Row(hash=1358996357)]
+ [Row(hash=-757602832)]
"""
sc = SparkContext._active_spark_context
jc = sc._jvm.functions.hash(_to_seq(sc, cols, _to_java_column))
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index b8d3c49100..1a351933a3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -566,10 +566,6 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
}
- public int hashCode(int seed) {
- return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed);
- }
-
@Override
public boolean equals(Object other) {
if (other instanceof UnsafeRow) {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index cc406a39f0..4751fbe414 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -25,8 +25,11 @@ import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
+import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
-import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.unsafe.hash.Murmur3_x86_32
+import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.unsafe.Platform
/**
* A function that calculates an MD5 128-bit checksum and returns it as a hex string
@@ -184,8 +187,31 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp
* A function that calculates hash value for a group of expressions. Note that the `seed` argument
* is not exposed to users and should only be set inside Spark SQL.
*
- * Internally this function will write arguments into an [[UnsafeRow]], and calculate hash code of
- * the unsafe row using murmur3 hasher with a seed.
+ * The hash value for an expression depends on its type and seed:
+ * - null: seed
+ * - boolean: turn boolean into int, 1 for true, 0 for false, and then use murmur3 to
+ * hash this int with seed.
+ * - byte, short, int: use murmur3 to hash the input as int with seed.
+ * - long: use murmur3 to hash the long input with seed.
+ * - float: turn it into int: java.lang.Float.floatToIntBits(input), and hash it.
+ * - double: turn it into long: java.lang.Double.doubleToLongBits(input), and hash it.
+ * - decimal: if it's a small decimal, i.e. precision <= 18, turn it into long and hash
+ * it. Else, turn it into bytes and hash it.
+ * - calendar interval: hash `microseconds` first, and use the result as seed to hash `months`.
+ * - binary: use murmur3 to hash the bytes with seed.
+ * - string: get the bytes of string and hash it.
+ * - array: The `result` starts with seed, then use `result` as seed, recursively
+ * calculate hash value for each element, and assign the element hash value
+ * to `result`.
+ * - map: The `result` starts with seed, then use `result` as seed, recursively
+ * calculate hash value for each key-value, and assign the key-value hash
+ * value to `result`.
+ * - struct: The `result` starts with seed, then use `result` as seed, recursively
+ * calculate hash value for each field, and assign the field hash value to
+ * `result`.
+ *
+ * Finally, we aggregate the hash values of all expressions in the same way as for a struct.
+ *
* We should use this hash function for both shuffle and bucketing, so that we can guarantee
* they have the same data distribution.
*/
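To make the seed chaining described above concrete: hashing an int field followed by a long field hashes the int with the initial seed, then hashes the long with that intermediate result as the seed. A small sketch (not part of the patch) that assumes the static Murmur3_x86_32.hashInt/hashLong helpers added at the end of this diff are on the classpath:

```scala
import org.apache.spark.unsafe.hash.Murmur3_x86_32

object StructHashSketch {
  def main(args: Array[String]): Unit = {
    val seed = 42
    // hash(intField = 1, longField = 2L): the first field's hash becomes
    // the seed for the second field, exactly the fold described above.
    val afterInt  = Murmur3_x86_32.hashInt(1, seed)
    val afterLong = Murmur3_x86_32.hashLong(2L, afterInt)
    println(afterLong)
  }
}
```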
@@ -206,22 +232,225 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
}
}
- private lazy val unsafeProjection = UnsafeProjection.create(children)
+ override def prettyName: String = "hash"
+
+ override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
override def eval(input: InternalRow): Any = {
- unsafeProjection(input).hashCode(seed)
+ var hash = seed
+ var i = 0
+ val len = children.length
+ while (i < len) {
+ hash = computeHash(children(i).eval(input), children(i).dataType, hash)
+ i += 1
+ }
+ hash
}
+ private def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
+ def hashInt(i: Int): Int = Murmur3_x86_32.hashInt(i, seed)
+ def hashLong(l: Long): Int = Murmur3_x86_32.hashLong(l, seed)
+
+ value match {
+ case null => seed
+ case b: Boolean => hashInt(if (b) 1 else 0)
+ case b: Byte => hashInt(b)
+ case s: Short => hashInt(s)
+ case i: Int => hashInt(i)
+ case l: Long => hashLong(l)
+ case f: Float => hashInt(java.lang.Float.floatToIntBits(f))
+ case d: Double => hashLong(java.lang.Double.doubleToLongBits(d))
+ case d: Decimal =>
+ val precision = dataType.asInstanceOf[DecimalType].precision
+ if (precision <= Decimal.MAX_LONG_DIGITS) {
+ hashLong(d.toUnscaledLong)
+ } else {
+ val bytes = d.toJavaBigDecimal.unscaledValue().toByteArray
+ Murmur3_x86_32.hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed)
+ }
+ case c: CalendarInterval => Murmur3_x86_32.hashInt(c.months, hashLong(c.microseconds))
+ case a: Array[Byte] =>
+ Murmur3_x86_32.hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
+ case s: UTF8String =>
+ Murmur3_x86_32.hashUnsafeBytes(s.getBaseObject, s.getBaseOffset, s.numBytes(), seed)
+
+ case array: ArrayData =>
+ val elementType = dataType match {
+ case udt: UserDefinedType[_] => udt.sqlType.asInstanceOf[ArrayType].elementType
+ case ArrayType(et, _) => et
+ }
+ var result = seed
+ var i = 0
+ while (i < array.numElements()) {
+ result = computeHash(array.get(i, elementType), elementType, result)
+ i += 1
+ }
+ result
+
+ case map: MapData =>
+ val (kt, vt) = dataType match {
+ case udt: UserDefinedType[_] =>
+ val mapType = udt.sqlType.asInstanceOf[MapType]
+ mapType.keyType -> mapType.valueType
+ case MapType(kt, vt, _) => kt -> vt
+ }
+ val keys = map.keyArray()
+ val values = map.valueArray()
+ var result = seed
+ var i = 0
+ while (i < map.numElements()) {
+ result = computeHash(keys.get(i, kt), kt, result)
+ result = computeHash(values.get(i, vt), vt, result)
+ i += 1
+ }
+ result
+
+ case struct: InternalRow =>
+ val types: Array[DataType] = dataType match {
+ case udt: UserDefinedType[_] =>
+ udt.sqlType.asInstanceOf[StructType].map(_.dataType).toArray
+ case StructType(fields) => fields.map(_.dataType)
+ }
+ var result = seed
+ var i = 0
+ val len = struct.numFields
+ while (i < len) {
+ result = computeHash(struct.get(i, types(i)), types(i), result)
+ i += 1
+ }
+ result
+ }
+ }
+
+
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children)
ev.isNull = "false"
+ val childrenHash = children.zipWithIndex.map {
+ case (child, dt) =>
+ val childGen = child.gen(ctx)
+ val childHash = computeHash(childGen.value, child.dataType, ev.value, ctx)
+ s"""
+ ${childGen.code}
+ if (!${childGen.isNull}) {
+ ${childHash.code}
+ ${ev.value} = ${childHash.value};
+ }
+ """
+ }.mkString("\n")
s"""
- ${unsafeRow.code}
- final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
+ int ${ev.value} = $seed;
+ $childrenHash
"""
}
- override def prettyName: String = "hash"
-
- override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
+ private def computeHash(
+ input: String,
+ dataType: DataType,
+ seed: String,
+ ctx: CodeGenContext): GeneratedExpressionCode = {
+ val hasher = classOf[Murmur3_x86_32].getName
+ def hashInt(i: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashInt($i, $seed)")
+ def hashLong(l: String): GeneratedExpressionCode = inlineValue(s"$hasher.hashLong($l, $seed)")
+ def inlineValue(v: String): GeneratedExpressionCode =
+ GeneratedExpressionCode(code = "", isNull = "false", value = v)
+
+ dataType match {
+ case NullType => inlineValue(seed)
+ case BooleanType => hashInt(s"$input ? 1 : 0")
+ case ByteType | ShortType | IntegerType | DateType => hashInt(input)
+ case LongType | TimestampType => hashLong(input)
+ case FloatType => hashInt(s"Float.floatToIntBits($input)")
+ case DoubleType => hashLong(s"Double.doubleToLongBits($input)")
+ case d: DecimalType =>
+ if (d.precision <= Decimal.MAX_LONG_DIGITS) {
+ hashLong(s"$input.toUnscaledLong()")
+ } else {
+ val bytes = ctx.freshName("bytes")
+ val code = s"byte[] $bytes = $input.toJavaBigDecimal().unscaledValue().toByteArray();"
+ val offset = "Platform.BYTE_ARRAY_OFFSET"
+ val result = s"$hasher.hashUnsafeBytes($bytes, $offset, $bytes.length, $seed)"
+ GeneratedExpressionCode(code, "false", result)
+ }
+ case CalendarIntervalType =>
+ val microsecondsHash = s"$hasher.hashLong($input.microseconds, $seed)"
+ val monthsHash = s"$hasher.hashInt($input.months, $microsecondsHash)"
+ inlineValue(monthsHash)
+ case BinaryType =>
+ val offset = "Platform.BYTE_ARRAY_OFFSET"
+ inlineValue(s"$hasher.hashUnsafeBytes($input, $offset, $input.length, $seed)")
+ case StringType =>
+ val baseObject = s"$input.getBaseObject()"
+ val baseOffset = s"$input.getBaseOffset()"
+ val numBytes = s"$input.numBytes()"
+ inlineValue(s"$hasher.hashUnsafeBytes($baseObject, $baseOffset, $numBytes, $seed)")
+
+ case ArrayType(et, _) =>
+ val result = ctx.freshName("result")
+ val index = ctx.freshName("index")
+ val element = ctx.freshName("element")
+ val elementHash = computeHash(element, et, result, ctx)
+ val code =
+ s"""
+ int $result = $seed;
+ for (int $index = 0; $index < $input.numElements(); $index++) {
+ if (!$input.isNullAt($index)) {
+ final ${ctx.javaType(et)} $element = ${ctx.getValue(input, et, index)};
+ ${elementHash.code}
+ $result = ${elementHash.value};
+ }
+ }
+ """
+ GeneratedExpressionCode(code, "false", result)
+
+ case MapType(kt, vt, _) =>
+ val result = ctx.freshName("result")
+ val index = ctx.freshName("index")
+ val keys = ctx.freshName("keys")
+ val values = ctx.freshName("values")
+ val key = ctx.freshName("key")
+ val value = ctx.freshName("value")
+ val keyHash = computeHash(key, kt, result, ctx)
+ val valueHash = computeHash(value, vt, result, ctx)
+ val code =
+ s"""
+ int $result = $seed;
+ final ArrayData $keys = $input.keyArray();
+ final ArrayData $values = $input.valueArray();
+ for (int $index = 0; $index < $input.numElements(); $index++) {
+ final ${ctx.javaType(kt)} $key = ${ctx.getValue(keys, kt, index)};
+ ${keyHash.code}
+ $result = ${keyHash.value};
+ if (!$values.isNullAt($index)) {
+ final ${ctx.javaType(vt)} $value = ${ctx.getValue(values, vt, index)};
+ ${valueHash.code}
+ $result = ${valueHash.value};
+ }
+ }
+ """
+ GeneratedExpressionCode(code, "false", result)
+
+ case StructType(fields) =>
+ val result = ctx.freshName("result")
+ val fieldsHash = fields.map(_.dataType).zipWithIndex.map {
+ case (dt, index) =>
+ val field = ctx.freshName("field")
+ val fieldHash = computeHash(field, dt, result, ctx)
+ s"""
+ if (!$input.isNullAt($index)) {
+ final ${ctx.javaType(dt)} $field = ${ctx.getValue(input, dt, index.toString)};
+ ${fieldHash.code}
+ $result = ${fieldHash.value};
+ }
+ """
+ }.mkString("\n")
+ val code =
+ s"""
+ int $result = $seed;
+ $fieldsHash
+ """
+ GeneratedExpressionCode(code, "false", result)
+
+ case udt: UserDefinedType[_] => computeHash(input, udt.sqlType, seed, ctx)
+ }
+ }
}
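A hedged usage sketch of the reworked expression's interpreted path (not part of the patch; it assumes a 2016-era Catalyst build where Literal and the Murmur3Hash case class above are available):

```scala
import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash}

object InterpretedHashExample {
  def main(args: Array[String]): Unit = {
    // Interpreted evaluation of hash(1, "foo") with seed 42;
    // no input row is needed because every child is a literal.
    val expr = Murmur3Hash(Seq(Literal(1), Literal("foo")), 42)
    println(expr.eval())
  }
}
```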
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 64161bebdc..75131a6170 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -79,7 +79,8 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
.add("long", LongType)
.add("float", FloatType)
.add("double", DoubleType)
- .add("decimal", DecimalType.SYSTEM_DEFAULT)
+ .add("bigDecimal", DecimalType.SYSTEM_DEFAULT)
+ .add("smallDecimal", DecimalType.USER_DEFAULT)
.add("string", StringType)
.add("binary", BinaryType)
.add("date", DateType)
@@ -126,7 +127,8 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
case (value, dt) => Literal.create(value, dt)
}
- checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed))
+ // Only check that the interpreted version produces the same result as the codegen version.
+ checkEvaluation(Murmur3Hash(literals, seed), Murmur3Hash(literals, seed).eval())
}
}
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 7f1745705a..b718b7cefb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.sources
import java.io.File
import org.apache.spark.sql.{AnalysisException, QueryTest}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
@@ -70,6 +71,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
}
}
+ private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+
private def testBucketing(
dataDir: File,
source: String,
@@ -82,27 +85,30 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
assert(groupedBucketFiles.size <= 8)
for ((bucketId, bucketFiles) <- groupedBucketFiles) {
- for (bucketFile <- bucketFiles) {
- val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath)
- .select((bucketCols ++ sortCols).map(col): _*)
+ for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
+ val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+ val columns = (bucketCols ++ sortCols).zip(types).map {
+ case (colName, dt) => col(colName).cast(dt)
+ }
+ val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
if (sortCols.nonEmpty) {
- checkAnswer(df.sort(sortCols.map(col): _*), df.collect())
+ checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
}
- val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect()
+ val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+ val rows = qe.toRdd.map(_.copy()).collect()
+ val getHashCode =
+ UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, qe.analyzed.output)
for (row <- rows) {
- assert(row.isInstanceOf[UnsafeRow])
- val actualBucketId = (row.hashCode() % 8 + 8) % 8
+ val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8)
assert(actualBucketId == bucketId)
}
}
}
}
- private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
-
test("write bucketed data") {
for (source <- Seq("parquet", "json", "orc")) {
withTable("bucketed_table") {
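For reference, the bucket-id assertion above maps a (possibly negative) 32-bit hash onto [0, numBuckets). A standalone sketch of that mapping, mirroring the behaviour of Utils.nonNegativeMod (the object name here is illustrative):

```scala
object BucketIdSketch {
  // Non-negative modulo: shift negative remainders up into [0, mod).
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

  def main(args: Array[String]): Unit = {
    // Using the hash value from the updated Python doctest above.
    println(nonNegativeMod(-757602832, 8))
  }
}
```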
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
index 4276f25c21..5e7ee480ca 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/hash/Murmur3_x86_32.java
@@ -38,6 +38,10 @@ public final class Murmur3_x86_32 {
}
public int hashInt(int input) {
+ return hashInt(input, seed);
+ }
+
+ public static int hashInt(int input, int seed) {
int k1 = mixK1(input);
int h1 = mixH1(seed, k1);
@@ -51,16 +55,38 @@ public final class Murmur3_x86_32 {
public static int hashUnsafeWords(Object base, long offset, int lengthInBytes, int seed) {
// This is based on Guava's `Murmur32_Hasher.processRemaining(ByteBuffer)` method.
assert (lengthInBytes % 8 == 0): "lengthInBytes must be a multiple of 8 (word-aligned)";
+ int h1 = hashBytesByInt(base, offset, lengthInBytes, seed);
+ return fmix(h1, lengthInBytes);
+ }
+
+ public static int hashUnsafeBytes(Object base, long offset, int lengthInBytes, int seed) {
+ assert (lengthInBytes >= 0): "lengthInBytes cannot be negative";
+ int lengthAligned = lengthInBytes - lengthInBytes % 4;
+ int h1 = hashBytesByInt(base, offset, lengthAligned, seed);
+ for (int i = lengthAligned; i < lengthInBytes; i++) {
+ int halfWord = Platform.getByte(base, offset + i);
+ int k1 = mixK1(halfWord);
+ h1 = mixH1(h1, k1);
+ }
+ return fmix(h1, lengthInBytes);
+ }
+
+ private static int hashBytesByInt(Object base, long offset, int lengthInBytes, int seed) {
+ assert (lengthInBytes % 4 == 0);
int h1 = seed;
for (int i = 0; i < lengthInBytes; i += 4) {
int halfWord = Platform.getInt(base, offset + i);
int k1 = mixK1(halfWord);
h1 = mixH1(h1, k1);
}
- return fmix(h1, lengthInBytes);
+ return h1;
}
public int hashLong(long input) {
+ return hashLong(input, seed);
+ }
+
+ public static int hashLong(long input, int seed) {
int low = (int) input;
int high = (int) (input >>> 32);