From 3e40f6c3d6fc0bcd828d09031fa3994925394889 Mon Sep 17 00:00:00 2001
From: Tejas Patil
Date: Fri, 24 Feb 2017 09:46:42 -0800
Subject: [SPARK-17495][SQL] Add more tests for hive hash

## What changes were proposed in this pull request?

This PR adds tests for hive-hash by comparing the outputs generated against Hive 1.2.1. The following datatypes are covered by this PR:
- null
- boolean
- byte
- short
- int
- long
- float
- double
- string
- array
- map
- struct

Datatypes that I have _NOT_ covered, but will work on separately, are:
- Decimal (handled separately in https://github.com/apache/spark/pull/17056)
- TimestampType
- DateType
- CalendarIntervalType

## How was this patch tested?

NA

Author: Tejas Patil

Closes #17049 from tejasapatil/SPARK-17495_remaining_types.
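As a quick reference, the scalar semantics that the new expected values encode can be sketched as follows. This is an editor's illustration for review, not code from this patch; the object and method names below are hypothetical, the real implementations being `HiveHasher` and `HiveHashFunction`:

```scala
// Hypothetical sketch of the Hive 1.2.1 scalar hash rules the new tests encode.
object HiveHashSketch {
  def hashBoolean(b: Boolean): Int = if (b) 1 else 0   // true -> 1, false -> 0
  def hashInt(i: Int): Int = i                         // byte/short/int hash to their own value
  def hashLong(v: Long): Int = ((v >>> 32) ^ v).toInt  // fold the two 32-bit halves
  def hashFloat(f: Float): Int = java.lang.Float.floatToIntBits(f)
  def hashDouble(d: Double): Int = hashLong(java.lang.Double.doubleToLongBits(d))
  // Strings fold over their UTF-8 bytes; for ASCII this equals String.hashCode.
  def hashString(s: String): Int =
    s.getBytes(java.nio.charset.StandardCharsets.UTF_8).foldLeft(0)((h, b) => 31 * h + b)
}
```

Under these rules `hashLong(-1L) == 0` and `hashLong(Long.MaxValue) == Int.MinValue`, matching the expectations in the long test below.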
---
 .../spark/sql/catalyst/expressions/hash.scala      |  11 +-
 .../expressions/HashExpressionsSuite.scala         | 247 ++++++++++++++++++++-
 2 files changed, 251 insertions(+), 7 deletions(-)

(limited to 'sql/catalyst')

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index e14f0544c2..2d9c2e4206 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -573,10 +573,9 @@ object XxHash64Function extends InterpretedHashFunction {
   }
 }
 
-
 /**
- * Simulates Hive's hashing function at
- * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode() in Hive
+ * Simulates Hive's hashing function from Hive v1.2.1 at
+ * org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils#hashcode()
  *
  * We should use this hash function for both shuffle and bucket of Hive tables, so that
  * we can guarantee shuffle and bucketing have same data distribution
@@ -595,7 +594,7 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
   override protected def hasherClassName: String = classOf[HiveHasher].getName
 
   override protected def computeHash(value: Any, dataType: DataType, seed: Int): Int = {
-    HiveHashFunction.hash(value, dataType, seed).toInt
+    HiveHashFunction.hash(value, dataType, this.seed).toInt
   }
 
   override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
@@ -781,12 +780,12 @@ object HiveHashFunction extends InterpretedHashFunction {
         var i = 0
         val length = struct.numFields
         while (i < length) {
-          result = (31 * result) + hash(struct.get(i, types(i)), types(i), seed + 1).toInt
+          result = (31 * result) + hash(struct.get(i, types(i)), types(i), 0).toInt
          i += 1
        }
        result

-      case _ => super.hash(value, dataType, seed)
+      case _ => super.hash(value, dataType, 0)
    }
  }
}
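// Editor's sketch (not part of this patch) of why the seed changes above
// matter: Hive 1.2.1 folds struct fields as r = 31 * r + hash(field),
// starting from 0 and hashing every nested value with seed 0, so threading
// `seed + 1` (or the caller's seed) into nested fields diverges from Hive.
// Worked example for the struct (1, 2, 3) exercised by the new tests:
val fieldHashes = Seq(1, 2, 3)                                 // int fields hash to themselves
val structHash = fieldHashes.foldLeft(0)((r, h) => 31 * r + h) // ((0*31+1)*31+2)*31+3
assert(structHash == 1026)                                     // expected value in the struct test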
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 0326292652..0cb3a79eee 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -19,16 +19,20 @@ package org.apache.spark.sql.catalyst.expressions
 
 import java.nio.charset.StandardCharsets
 
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.commons.codec.digest.DigestUtils
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{RandomDataGenerator, Row}
 import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
+import org.apache.spark.sql.types.{ArrayType, StructType, _}
 import org.apache.spark.unsafe.types.UTF8String
 
 class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
+  val random = new scala.util.Random
 
   test("md5") {
     checkEvaluation(Md5(Literal("ABC".getBytes(StandardCharsets.UTF_8))),
@@ -71,6 +75,247 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
     checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
   }
 
+
+  def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
+    // Note: all expected hashes need to be computed using Hive 1.2.1
+    val actual = HiveHashFunction.hash(input, dataType, seed = 0)
+
+    withClue(s"hash mismatch for input = `$input` of type `$dataType`.") {
+      assert(actual == expected)
+    }
+  }
+
+  def checkHiveHashForIntegralType(dataType: DataType): Unit = {
+    // corner cases
+    checkHiveHash(null, dataType, 0)
+    checkHiveHash(1, dataType, 1)
+    checkHiveHash(0, dataType, 0)
+    checkHiveHash(-1, dataType, -1)
+    checkHiveHash(Int.MaxValue, dataType, Int.MaxValue)
+    checkHiveHash(Int.MinValue, dataType, Int.MinValue)
+
+    // random values
+    for (_ <- 0 until 10) {
+      val input = random.nextInt()
+      checkHiveHash(input, dataType, input)
+    }
+  }
+
+  test("hive-hash for null") {
+    checkHiveHash(null, NullType, 0)
+  }
+
+  test("hive-hash for boolean") {
+    checkHiveHash(true, BooleanType, 1)
+    checkHiveHash(false, BooleanType, 0)
+  }
+
+  test("hive-hash for byte") {
+    checkHiveHashForIntegralType(ByteType)
+  }
+
+  test("hive-hash for short") {
+    checkHiveHashForIntegralType(ShortType)
+  }
+
+  test("hive-hash for int") {
+    checkHiveHashForIntegralType(IntegerType)
+  }
+
+  test("hive-hash for long") {
+    checkHiveHash(1L, LongType, 1L)
+    checkHiveHash(0L, LongType, 0L)
+    checkHiveHash(-1L, LongType, 0L)
+    checkHiveHash(Long.MaxValue, LongType, -2147483648)
+    // Hive fails to parse this, but the hashing function itself can handle this input
+    checkHiveHash(Long.MinValue, LongType, -2147483648)
+
+    for (_ <- 0 until 10) {
+      val input = random.nextLong()
+      checkHiveHash(input, LongType, ((input >>> 32) ^ input).toInt)
+    }
+  }
+
+  test("hive-hash for float") {
+    checkHiveHash(0F, FloatType, 0)
+    checkHiveHash(0.0F, FloatType, 0)
+    checkHiveHash(1.1F, FloatType, 1066192077L)
+    checkHiveHash(-1.1F, FloatType, -1081291571)
+    checkHiveHash(99999999.99999999999F, FloatType, 1287568416L)
+    checkHiveHash(Float.MaxValue, FloatType, 2139095039)
+    checkHiveHash(Float.MinValue, FloatType, -8388609)
+  }
+
+  test("hive-hash for double") {
+    checkHiveHash(0, DoubleType, 0)
+    checkHiveHash(0.0, DoubleType, 0)
+    checkHiveHash(1.1, DoubleType, -1503133693)
+    checkHiveHash(-1.1, DoubleType, 644349955)
+    checkHiveHash(1000000000.000001, DoubleType, 1104006509)
+    checkHiveHash(1000000000.0000000000000000000000001, DoubleType, 1104006501)
+    checkHiveHash(9999999999999999999.9999999999999999999, DoubleType, 594568676)
+    checkHiveHash(Double.MaxValue, DoubleType, -2146435072)
+    checkHiveHash(Double.MinValue, DoubleType, 1048576)
+  }
+
+  test("hive-hash for string") {
+    checkHiveHash(UTF8String.fromString("apache spark"), StringType, 1142704523L)
+    checkHiveHash(UTF8String.fromString("!@#$%^&*()_+=-"), StringType, -613724358L)
+    checkHiveHash(UTF8String.fromString("abcdefghijklmnopqrstuvwxyz"), StringType, 958031277L)
+    checkHiveHash(UTF8String.fromString("AbCdEfGhIjKlMnOpQrStUvWxYz012"), StringType, -648013852L)
+    // scalastyle:off nonascii
+    checkHiveHash(UTF8String.fromString("数据砖头"), StringType, -898686242L)
+    checkHiveHash(UTF8String.fromString("नमस्ते"), StringType, 2006045948L)
+    // scalastyle:on nonascii
+  }
+
+  test("hive-hash for array") {
+    // empty array
+    checkHiveHash(
+      input = new GenericArrayData(Array[Int]()),
+      dataType = ArrayType(IntegerType, containsNull = false),
+      expected = 0)
+
+    // basic case
+    checkHiveHash(
+      input = new GenericArrayData(Array(1, 10000, Int.MaxValue)),
+      dataType = ArrayType(IntegerType, containsNull = false),
+      expected = -2147172688L)
+
+    // with negative values
+    checkHiveHash(
+      input = new GenericArrayData(Array(-1L, 0L, 999L, Int.MinValue.toLong)),
+      dataType = ArrayType(LongType, containsNull = false),
+      expected = -2147452680L)
+
+    // with nulls only
+    val arrayTypeWithNull = ArrayType(IntegerType, containsNull = true)
+    checkHiveHash(
+      input = new GenericArrayData(Array(null, null)),
+      dataType = arrayTypeWithNull,
+      expected = 0)
+
+    // mix with null
+    checkHiveHash(
+      input = new GenericArrayData(Array(-12221, 89, null, 767)),
+      dataType = arrayTypeWithNull,
+      expected = -363989515)
+
+    // nested with array
+    checkHiveHash(
+      input = new GenericArrayData(
+        Array(
+          new GenericArrayData(Array(1234L, -9L, 67L)),
+          new GenericArrayData(Array(null, null)),
+          new GenericArrayData(Array(55L, -100L, -2147452680L))
+        )),
+      dataType = ArrayType(ArrayType(LongType)),
+      expected = -1007531064)
+
+    // nested with map
+    checkHiveHash(
+      input = new GenericArrayData(
+        Array(
+          new ArrayBasedMapData(
+            new GenericArrayData(Array(-99, 1234)),
+            new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
+          new ArrayBasedMapData(
+            new GenericArrayData(Array(67)),
+            new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
+        )),
+      dataType = ArrayType(MapType(IntegerType, StringType)),
+      expected = 1139205955)
+  }
+
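// Editor's sketch (not part of this patch) of the collection rules that the
// array/map expectations encode: arrays fold like structs,
// r = 31 * r + hash(element), while maps accumulate hash(key) ^ hash(value).
val arrayHash = Seq(1, 10000, Int.MaxValue).foldLeft(0)((r, h) => 31 * r + h)
assert(arrayHash == -2147172688) // the "basic case" in the array test above
// For map(1 -> "foo", 2 -> "bar"); ASCII string hashes equal String.hashCode:
val mapHash = (1 ^ "foo".hashCode) + (2 ^ "bar".hashCode)
assert(mapHash == 198872)        // the "basic case" in the map test below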
+  test("hive-hash for map") {
+    val mapType = MapType(IntegerType, StringType)
+
+    // empty map
+    checkHiveHash(
+      input = new ArrayBasedMapData(new GenericArrayData(Array()), new GenericArrayData(Array())),
+      dataType = mapType,
+      expected = 0)
+
+    // basic case
+    checkHiveHash(
+      input = new ArrayBasedMapData(
+        new GenericArrayData(Array(1, 2)),
+        new GenericArrayData(Array(UTF8String.fromString("foo"), UTF8String.fromString("bar")))),
+      dataType = mapType,
+      expected = 198872)
+
+    // with null value
+    checkHiveHash(
+      input = new ArrayBasedMapData(
+        new GenericArrayData(Array(55, -99)),
+        new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))),
+      dataType = mapType,
+      expected = 1142704473)
+
+    // nesting (only values can be nested as keys have to be primitive datatype)
+    val nestedMapType = MapType(IntegerType, MapType(IntegerType, StringType))
+    checkHiveHash(
+      input = new ArrayBasedMapData(
+        new GenericArrayData(Array(1, -100)),
+        new GenericArrayData(
+          Array(
+            new ArrayBasedMapData(
+              new GenericArrayData(Array(-99, 1234)),
+              new GenericArrayData(Array(UTF8String.fromString("sql"), null))),
+            new ArrayBasedMapData(
+              new GenericArrayData(Array(67)),
+              new GenericArrayData(Array(UTF8String.fromString("apache spark"))))
+          ))),
+      dataType = nestedMapType,
+      expected = -1142817416)
+  }
+
+  test("hive-hash for struct") {
+    // basic
+    val row = new GenericInternalRow(Array[Any](1, 2, 3))
+    checkHiveHash(
+      input = row,
+      dataType =
+        new StructType()
+          .add("col1", IntegerType)
+          .add("col2", IntegerType)
+          .add("col3", IntegerType),
+      expected = 1026)
+
+    // mix of several datatypes
+    val structType = new StructType()
+      .add("null", NullType)
+      .add("boolean", BooleanType)
+      .add("byte", ByteType)
+      .add("short", ShortType)
+      .add("int", IntegerType)
+      .add("long", LongType)
+      .add("arrayOfString", arrayOfString)
+      .add("mapOfString", mapOfString)
+
+    val rowValues = new ArrayBuffer[Any]()
+    rowValues += null
+    rowValues += true
+    rowValues += 1
+    rowValues += 2
+    rowValues += Int.MaxValue
+    rowValues += Long.MinValue
+    rowValues += new GenericArrayData(Array(
+      UTF8String.fromString("apache spark"),
+      UTF8String.fromString("hello world")
+    ))
+    rowValues += new ArrayBasedMapData(
+      new GenericArrayData(Array(UTF8String.fromString("project"), UTF8String.fromString("meta"))),
+      new GenericArrayData(Array(UTF8String.fromString("apache spark"), null))
+    )
+
+    val row2 = new GenericInternalRow(rowValues.toArray)
+    checkHiveHash(
+      input = row2,
+      dataType = structType,
+      expected = -2119012447)
+  }
+
   private val structOfString = new StructType().add("str", StringType)
   private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
   private val arrayOfString = ArrayType(StringType)
--
cgit v1.2.3