 common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java            |   2 +-
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala                 |  11 +-
 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala | 247 ++++++++++++++++++-
 3 files changed, 252 insertions(+), 8 deletions(-)
diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java
index c7ea9085eb..73577437ac 100644
--- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java
+++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/expressions/HiveHasher.java
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.expressions;
import org.apache.spark.unsafe.Platform;
/**
- * Simulates Hive's hashing function at
+ * Simulates Hive's hashing function from Hive v1.2.1 at
 * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashCode()
*/
public class HiveHasher {
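
For orientation, and not part of the patch itself: the Javadoc now names the exact Hive release being simulated. A hedged Scala sketch of the primitive rules HiveHasher mirrors follows; the names are hypothetical, but the behavior matches the expectations added in HashExpressionsSuite below (ints hash to themselves, longs fold their two halves, booleans map to 1/0):

    // Minimal sketch of Hive 1.2.1 primitive hashing as mirrored by HiveHasher.
    // Hypothetical names; the real reference is
    // org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashCode().
    object HiveHashPrimitives {
      def hashInt(i: Int): Int = i                         // an int hashes to itself
      def hashLong(l: Long): Int = ((l >>> 32) ^ l).toInt  // XOR the upper half into the lower
      def hashBoolean(b: Boolean): Int = if (b) 1 else 0
      // Assumption: strings fold their UTF-8 bytes with the classic 31-multiplier.
      def hashBytes(bytes: Array[Byte]): Int =
        bytes.foldLeft(0)((result, b) => 31 * result + b)
    }
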
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index e14f0544c2..2d9c2e4206 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -573,10 +573,9 @@ object XxHash64Function extends InterpretedHashFunction {
}
}
-
/**
- * Simulates Hive's hashing function at
- * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive
+ * Simulates Hive's hashing function from Hive v1.2.1 at
+ * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashCode()
*
 * We should use this hash function for both shuffle and bucketing of Hive tables, so that
 * we can guarantee shuffle and bucketing have the same data distribution.
@@ -595,7 +594,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
override protected def hasherClassName: String = classOf[HiveHasher].getName
override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
- HiveHashFunction.hash(value, dataType, seed).toInt
+ HiveHashFunction.hash(value, dataType, this.seed).toInt
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -781,12 +780,12 @@ object HiveHashFunction extends InterpretedHashFunction {
var i = 0
val length = struct.numFields
while (i < length) {
- result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt
+ result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt
i += 1
}
result
- case _ => super.hash(value, dataType, seed)
+ case _ => super.hash(value, dataType, 0)
}
}
}
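
Why the hardcoded zero seeds above are the right fix: Hive's ObjectInspectorUtils#hashCode() takes no seed at all, so threading `seed + 1` through nested struct fields diverged from Hive; likewise, computeHash now passes the expression's fixed initial seed (`this.seed`) instead of a running per-field value. A hedged Scala sketch of the fold being pinned down, with worked checks against the expectations added in the test suite below (`foldHiveHash` is a hypothetical helper, not the real HiveHashFunction API):

    // Struct fields and array elements both fold their element hashes as
    // result = 31 * result + elementHash, always starting from a seed of 0.
    def foldHiveHash(elementHashes: Seq[Int]): Int =
      elementHashes.foldLeft(0)((result, h) => 31 * result + h)

    // Worked checks (Int arithmetic wraps on overflow, as in Java/Hive):
    assert(foldHiveHash(Seq(1, 2, 3)) == 1026)                        // struct(1, 2, 3)
    assert(foldHiveHash(Seq(1, 10000, Int.MaxValue)) == -2147172688)  // array(1, 10000, Int.MaxValue)
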
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 0326292652..0cb3a79eee 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -19,16 +19,20 @@ package org.apache.spark.sql.catalyst.expressions
import java.nio.charset.StandardCharsets
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.types.{ArrayType, StructType, _}
import org.apache.spark.unsafe.types.UTF8String
class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+ val random = new scala.util.Random
test("md5") {
checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
@@ -71,6 +75,247 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}
+
+ def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
+ // Note: all expected hashes need to be computed using Hive 1.2.1
+ val actual = HiveHashFunction.hash(input, dataType, seed = 0)
+
+ withClue(s"hash mismatch for input = `$input` of type `$dataType`.") {
+ assert(actual == expected)
+ }
+ }
+
+ def checkHiveHashForIntegralType(dataType: DataType): Unit = {
+ // corner cases
+ checkHiveHash(null, dataType, 0)
+ checkHiveHash(1, dataType, 1)
+ checkHiveHash(0, dataType, 0)
+ checkHiveHash(-1, dataType, -1)
+ checkHiveHash(Int.MaxValue, dataType, Int.MaxValue)
+ checkHiveHash(Int.MinValue, dataType, Int.MinValue)
+
+ // random values
+ for (_ <- 0 until 10) {
+ val input = random.nextInt()
+ checkHiveHash(input, dataType, input)
+ }
+ }
+
+ test("hive-hash for null") {
+ checkHiveHash(null, NullType, 0)
+ }
+
+ test("hive-hash for boolean") {
+ checkHiveHash(true, BooleanType, 1)
+ checkHiveHash(false, BooleanType, 0)
+ }
+
+ test("hive-hash for byte") {
+ checkHiveHashForIntegralType(ByteType)
+ }
+
+ test("hive-hash for short") {
+ checkHiveHashForIntegralType(ShortType)
+ }
+
+ test("hive-hash for int") {
+ checkHiveHashForIntegralType(IntegerType)
+ }
+
+ test("hive-hash for long") {
+ checkHiveHash(1L, LongType, 1L)
+ checkHiveHash(0L, LongType, 0L)
+ checkHiveHash(-1L, LongType, 0L)
+ checkHiveHash(Long.MaxValue, LongType, -2147483648)
+ // Hive fails to parse this, but the hashing function itself can handle this input
+ checkHiveHash(Long.MinValue, LongType, -2147483648)
+
+ for (_ <- 0 until 10) {
+ val input = random.nextLong()
+ checkHiveHash(input, LongType, ((input >>> 32) ^ input).toInt)
+ }
+ }
+
+ test("hive-hash for float") {
+ checkHiveHash(0F, FloatType, 0)
+ checkHiveHash(0.0F, FloatType, 0)
+ checkHiveHash(1.1F, FloatType, 1066192077L)
+ checkHiveHash(-1.1F, FloatType, -1081291571)
+ checkHiveHash(99999999.99999999999F, FloatType, 1287568416L)
+ checkHiveHash(Float.MaxValue, FloatType, 2139095039)
+ checkHiveHash(Float.MinValue, FloatType, -8388609)
+ }
+
+ test("hive-hash for double") {
+ checkHiveHash(0, DoubleType, 0)
+ checkHiveHash(0.0, DoubleType, 0)
+ checkHiveHash(1.1, DoubleType, -1503133693)
+ checkHiveHash(-1.1, DoubleType, 644349955)
+ checkHiveHash(1000000000.000001, DoubleType, 1104006509)
+ checkHiveHash(1000000000.0000000000000000000000001, DoubleType, 1104006501)
+ checkHiveHash(9999999999999999999.9999999999999999999, DoubleType, 594568676)
+ checkHiveHash(Double.MaxValue, DoubleType, -2146435072)
+ checkHiveHash(Double.MinValue, DoubleType, 1048576)
+ }
+
+ test("hive-hash for string") {
+ checkHiveHash(UTF8String.fromString("apache spark"), StringType, 1142704523L)
+ checkHiveHash(UTF8String.fromString("!@#$%^&*()_+=-"), StringType, -613724358L)
+ checkHiveHash(UTF8String.fromString("abcdefghijklmnopqrstuvwxyz"), StringType, 958031277L)
+ checkHiveHash(UTF8String.fromString("AbCdEfGhIjKlMnOpQrStUvWxYz012"), StringType, -648013852L)
+ // scalastyle:off nonascii
+ checkHiveHash(UTF8String.fromString("数据砖头"), StringType, -898686242L)
+ checkHiveHash(UTF8String.fromString("नमस्ते"), StringType, 2006045948L)
+ // scalastyle:on nonascii
+ }
+
+ test("hive-hash for array") {
+ // empty array
+ checkHiveHash(
+ input = new GenericArrayData(Array[Int]()),
+ dataType = ArrayType(IntegerType, containsNull = false),
+ expected = 0)
+
+ // basic case
+ checkHiveHash(
+ input = new GenericArrayData(Array(1, 10000, Int.MaxValue)),
+ dataType = ArrayType(IntegerType, containsNull = false),
+ expected = -2147172688L)
+
+ // with negative values
+ checkHiveHash(
+ input = new GenericArrayData(Array(-1L, 0L, 999L, Int.MinValue.toLong)),
+ dataType = ArrayType(LongType, containsNull = false),
+ expected = -2147452680L)
+
+ // with nulls only
+ val arrayTypeWithNull = ArrayType(IntegerType, containsNull = true)
+ checkHiveHash(
+ input = new GenericArrayData(Array(null, null)),
+ dataType = arrayTypeWithNull,
+ expected = 0)
+
+ // mix with null
+ checkHiveHash(
+ input = new GenericArrayData(Array(-12221, 89, null, 767)),
+ dataType = arrayTypeWithNull,
+ expected = -363989515)
+
+ // nested with array
+ checkHiveHash(
+ input = new GenericArrayData(
+ Array(
+ new GenericArrayData(Array(1234L, -9L, 67L)),
+ new GenericArrayData(Array(null, null)),
+ new GenericArrayData(Array(55L, -100L, -2147452680L))
+ )),
+ dataType = ArrayType(ArrayType(LongType)),
+ expected = -1007531064)
+
+ // nested with map
+ checkHiveHash(
+ input = new GenericArrayData(
+ Array(
+ new ArrayBasedMapData(
+ new GenericArrayData(Array(-99, 1234)),
+ new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
+ new ArrayBasedMapData(
+ new GenericArrayData(Array(67)),
+ new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
+ )),
+ dataType = ArrayType(MapType(IntegerType, StringType)),
+ expected = 1139205955)
+ }
+
+ test("hive-hash for map") {
+ val mapType = MapType(IntegerType, StringType)
+
+ // empty map
+ checkHiveHash(
+ input = new ArrayBasedMapData(new GenericArrayData(Array()), new GenericArrayData(Array())),
+ dataType = mapType,
+ expected = 0)
+
+ // basic case
+ checkHiveHash(
+ input = new ArrayBasedMapData(
+ new GenericArrayData(Array(1, 2)),
+ new GenericArrayData(Array(UTF8String.fromString("foo"), UTF8String.fromString("bar")))),
+ dataType = mapType,
+ expected = 198872)
+
+ // with null value
+ checkHiveHash(
+ input = new ArrayBasedMapData(
+ new GenericArrayData(Array(55, -99)),
+ new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))),
+ dataType = mapType,
+ expected = 1142704473)
+
+ // nesting (only values can be nested, as keys have to be of a primitive datatype)
+ val nestedMapType = MapType(IntegerType, MapType(IntegerType, StringType))
+ checkHiveHash(
+ input = new ArrayBasedMapData(
+ new GenericArrayData(Array(1, -100)),
+ new GenericArrayData(
+ Array(
+ new ArrayBasedMapData(
+ new GenericArrayData(Array(-99, 1234)),
+ new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
+ new ArrayBasedMapData(
+ new GenericArrayData(Array(67)),
+ new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
+ ))),
+ dataType = nestedMapType,
+ expected = -1142817416)
+ }
+
+ test("hive-hash for struct") {
+ // basic
+ val row = new GenericInternalRow(Array[Any](1, 2, 3))
+ checkHiveHash(
+ input = row,
+ dataType =
+ new StructType()
+ .add("col1", IntegerType)
+ .add("col2", IntegerType)
+ .add("col3", IntegerType),
+ expected = 1026)
+
+ // mix of several datatypes
+ val structType = new StructType()
+ .add("null", NullType)
+ .add("boolean", BooleanType)
+ .add("byte", ByteType)
+ .add("short", ShortType)
+ .add("int", IntegerType)
+ .add("long", LongType)
+ .add("arrayOfString", arrayOfString)
+ .add("mapOfString", mapOfString)
+
+ val rowValues = new ArrayBuffer[Any]()
+ rowValues += null
+ rowValues += true
+ rowValues += 1
+ rowValues += 2
+ rowValues += Int.MaxValue
+ rowValues += Long.MinValue
+ rowValues += new GenericArrayData(Array(
+ UTF8String.fromString("apache spark"),
+ UTF8String.fromString("hello world")
+ ))
+ rowValues += new ArrayBasedMapData(
+ new GenericArrayData(Array(UTF8String.fromString("project"), UTF8String.fromString("meta"))),
+ new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))
+ )
+
+ val row2 = new GenericInternalRow(rowValues.toArray)
+ checkHiveHash(
+ input = row2,
+ dataType = structType,
+ expected = -2119012447)
+ }
+
private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)
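
One more worked check, illustrative rather than from the patch: for maps, Hive sums key-hash XOR value-hash per entry, which reproduces the "basic case" expectation of 198872 above. For ASCII text, Hive's string hash coincides with Java's String.hashCode, a 31-fold over the characters:

    val fooHash = "foo".hashCode  // 101574
    val barHash = "bar".hashCode  // 97299
    assert((1 ^ fooHash) + (2 ^ barHash) == 198872)  // map(1 -> "foo", 2 -> "bar")
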