author    Tejas Patil <tejasp@fb.com>    2017-03-06 10:16:20 -0800
committer Wenchen Fan <wenchen@databricks.com>    2017-03-06 10:16:20 -0800
commit    2a0bc867a4a1dad4ecac47701199e540d345ff4f (patch)
tree      649c56d710dfa9c4b5cf6ba7cedc8bcba4c7d268
parent    207067ead6db6dc87b0d144a658e2564e3280a89 (diff)
[SPARK-17495][SQL] Support Decimal type in Hive-hash
## What changes were proposed in this pull request?

Add support for the Decimal datatype to Hive hash. [Hive internally normalises decimals](https://github.com/apache/hive/blob/4ba713ccd85c3706d195aeef9476e6e6363f1c21/storage-api/src/java/org/apache/hadoop/hive/common/type/HiveDecimalV1.java#L307) and I have ported that logic as-is to HiveHash.

## How was this patch tested?

Added unit tests.

Author: Tejas Patil <tejasp@fb.com>

Closes #17056 from tejasapatil/SPARK-17495_decimal.
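For readers who want to see the ported normalization outside the diff, here is a minimal, self-contained Scala sketch of the approach: normalize the `java.math.BigDecimal` the way `HiveDecimalV1.normalize()` does, then take its `hashCode`. The object name `DecimalNormalizationSketch` is illustrative only; the authoritative version is `HiveHashFunction.normalizeDecimal` in the patch below.

```scala
import java.math.{BigDecimal => JBigDecimal, RoundingMode}

// Illustrative sketch of the normalization ported from Hive's HiveDecimalV1.normalize().
object DecimalNormalizationSketch {
  private val MaxPrecision = 38
  private val MaxScale = 38

  private def trim(input: JBigDecimal): JBigDecimal = {
    if (input.compareTo(JBigDecimal.ZERO) == 0) {
      // Special case for 0, because java doesn't strip zeros correctly on that number.
      JBigDecimal.ZERO
    } else {
      val stripped = input.stripTrailingZeros
      // Hive does not allow negative-scale decimals.
      if (stripped.scale < 0) stripped.setScale(0) else stripped
    }
  }

  def normalize(input: JBigDecimal): JBigDecimal = {
    if (input == null) return null
    var result = trim(input)
    val intDigits = result.precision - result.scale
    if (intDigits > MaxPrecision) return null
    val maxScale = math.min(MaxScale, math.min(MaxPrecision - intDigits, result.scale))
    if (result.scale > maxScale) {
      result = result.setScale(maxScale, RoundingMode.HALF_UP)
      // Rounding may reintroduce trailing zeros, so trim once more.
      result = trim(result)
    }
    result
  }

  // The Hive hash of a decimal is the hashCode of the normalized java.math.BigDecimal.
  def hiveHashOfDecimal(input: JBigDecimal): Int = normalize(input).hashCode()

  def main(args: Array[String]): Unit = {
    println(hiveHashOfDecimal(new JBigDecimal("18")))                // 558
    println(hiveHashOfDecimal(new JBigDecimal("123456.1234567890"))) // 1871500252
  }
}
```

The two printed values match the expected hashes for those inputs in the unit tests added by this patch.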
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala                 | 56
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala | 46
2 files changed, 99 insertions(+), 3 deletions(-)
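The new code path can be exercised directly through the interpreted hash function, the same way the added test helper does. A small usage sketch, assuming a Spark build that contains this patch (the object name `HiveHashDecimalExample` is illustrative):

```scala
import java.math.{BigDecimal => JBigDecimal}

import org.apache.spark.sql.catalyst.expressions.HiveHashFunction
import org.apache.spark.sql.types.{DataTypes, Decimal}

object HiveHashDecimalExample {
  def main(args: Array[String]): Unit = {
    // Build a Decimal the same way checkHiveHashForDecimal does in the test below.
    val decimalType = DataTypes.createDecimalType(38, 0)
    val value = Decimal(new JBigDecimal("18"))
    assert(value.changePrecision(38, 0))
    // Expected hashes in the new test were computed with Hive 1.2.1; decimal 18 hashes to 558.
    println(HiveHashFunction.hash(value, decimalType, seed = 0)) // 558
  }
}
```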
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
index 2d9c2e4206..03101b4bfc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/hash.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.expressions
+import java.math.{BigDecimal, RoundingMode}
import java.security.{MessageDigest, NoSuchAlgorithmException}
import java.util.zip.CRC32
@@ -580,7 +581,7 @@ object XxHash64Function extends InterpretedHashFunction {
* We should use this hash function for both shuffle and bucket of Hive tables, so that
* we can guarantee shuffle and bucketing have same data distribution
*
- * TODO: Support Decimal and date related types
+ * TODO: Support date related types
*/
@ExpressionDescription(
usage = "_FUNC_(expr1, expr2, ...) - Returns a hash value of the arguments.")
@@ -635,6 +636,16 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
override protected def genHashBytes(b: String, result: String): String =
s"$result = $hasherClassName.hashUnsafeBytes($b, Platform.BYTE_ARRAY_OFFSET, $b.length);"
+ override protected def genHashDecimal(
+ ctx: CodegenContext,
+ d: DecimalType,
+ input: String,
+ result: String): String = {
+ s"""
+ $result = ${HiveHashFunction.getClass.getName.stripSuffix("$")}.normalizeDecimal(
+ $input.toJavaBigDecimal()).hashCode();"""
+ }
+
override protected def genHashCalendarInterval(input: String, result: String): String = {
s"""
$result = (31 * $hasherClassName.hashInt($input.months)) +
@@ -732,6 +743,44 @@ object HiveHashFunction extends InterpretedHashFunction {
HiveHasher.hashUnsafeBytes(base, offset, len)
}
+ private val HIVE_DECIMAL_MAX_PRECISION = 38
+ private val HIVE_DECIMAL_MAX_SCALE = 38
+
+ // Mimics normalization done for decimals in Hive at HiveDecimalV1.normalize()
+ def normalizeDecimal(input: BigDecimal): BigDecimal = {
+ if (input == null) return null
+
+ def trimDecimal(input: BigDecimal) = {
+ var result = input
+ if (result.compareTo(BigDecimal.ZERO) == 0) {
+ // Special case for 0, because java doesn't strip zeros correctly on that number.
+ result = BigDecimal.ZERO
+ } else {
+ result = result.stripTrailingZeros
+ if (result.scale < 0) {
+ // no negative scale decimals
+ result = result.setScale(0)
+ }
+ }
+ result
+ }
+
+ var result = trimDecimal(input)
+ val intDigits = result.precision - result.scale
+ if (intDigits > HIVE_DECIMAL_MAX_PRECISION) {
+ return null
+ }
+
+ val maxScale = Math.min(HIVE_DECIMAL_MAX_SCALE,
+ Math.min(HIVE_DECIMAL_MAX_PRECISION - intDigits, result.scale))
+ if (result.scale > maxScale) {
+ result = result.setScale(maxScale, RoundingMode.HALF_UP)
+ // Trimming is again necessary, because rounding may introduce new trailing 0's.
+ result = trimDecimal(result)
+ }
+ result
+ }
+
override def hash(value: Any, dataType: DataType, seed: Long): Long = {
value match {
case null => 0
@@ -785,7 +834,10 @@ object HiveHashFunction extends InterpretedHashFunction {
}
result
- case _ => super.hash(value, dataType, 0)
+ case d: Decimal =>
+ normalizeDecimal(d.toJavaBigDecimal).hashCode()
+
+ case _ => super.hash(value, dataType, seed)
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
index 0cb3a79eee..0c77dc2709 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HashExpressionsSuite.scala
@@ -75,7 +75,6 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}
-
def checkHiveHash(input: Any, dataType: DataType, expected: Long): Unit = {
// Note : All expected hashes need to be computed using Hive 1.2.1
val actual = HiveHashFunction.hash(input, dataType, seed = 0)
@@ -371,6 +370,51 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
new StructType().add("array", arrayOfString).add("map", mapOfString))
.add("structOfUDT", structOfUDT))
+ test("hive-hash for decimal") {
+ def checkHiveHashForDecimal(
+ input: String,
+ precision: Int,
+ scale: Int,
+ expected: Long): Unit = {
+ val decimalType = DataTypes.createDecimalType(precision, scale)
+ val decimal = {
+ val value = Decimal.apply(new java.math.BigDecimal(input))
+ if (value.changePrecision(precision, scale)) value else null
+ }
+
+ checkHiveHash(decimal, decimalType, expected)
+ }
+
+ checkHiveHashForDecimal("18", 38, 0, 558)
+ checkHiveHashForDecimal("-18", 38, 0, -558)
+ checkHiveHashForDecimal("-18", 38, 12, -558)
+ checkHiveHashForDecimal("18446744073709001000", 38, 19, 0)
+ checkHiveHashForDecimal("-18446744073709001000", 38, 22, 0)
+ checkHiveHashForDecimal("-18446744073709001000", 38, 3, 17070057)
+ checkHiveHashForDecimal("18446744073709001000", 38, 4, -17070057)
+ checkHiveHashForDecimal("9223372036854775807", 38, 4, 2147482656)
+ checkHiveHashForDecimal("-9223372036854775807", 38, 5, -2147482656)
+ checkHiveHashForDecimal("00000.00000000000", 38, 34, 0)
+ checkHiveHashForDecimal("-00000.00000000000", 38, 11, 0)
+ checkHiveHashForDecimal("123456.1234567890", 38, 2, 382713974)
+ checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+ checkHiveHashForDecimal("123456.1234567890", 38, 10, 1871500252)
+ checkHiveHashForDecimal("-123456.1234567890", 38, 10, -1871500234)
+ checkHiveHashForDecimal("123456.1234567890", 38, 0, 3827136)
+ checkHiveHashForDecimal("-123456.1234567890", 38, 0, -3827136)
+ checkHiveHashForDecimal("123456.1234567890", 38, 20, 1871500252)
+ checkHiveHashForDecimal("-123456.1234567890", 38, 20, -1871500234)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 0, 3827136)
+ checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 0, -3827136)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 10, 1871500252)
+ checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 10, -1871500234)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 20, 236317582)
+ checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 20, -236317544)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 30, 1728235666)
+ checkHiveHashForDecimal("-123456.123456789012345678901234567890", 38, 30, -1728235608)
+ checkHiveHashForDecimal("123456.123456789012345678901234567890", 38, 31, 1728235666)
+ }
+
test("SPARK-18207: Compute hash for a lot of expressions") {
val N = 1000
val wideRow = new GenericInternalRow(