diff options
10 files changed, 171 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java index 1a351933a3..b8d3c49100 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java @@ -566,6 +566,10 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42); } + public int hashCode(int seed) { + return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed); + } + @Override public boolean equals(Object other) { if (other instanceof UnsafeRow) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 57d1a1107e..5c2aa3c06b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -49,7 +49,7 @@ trait FunctionRegistry { class SimpleFunctionRegistry extends FunctionRegistry { - private val functionBuilders = + private[sql] val functionBuilders = StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false) override def registerFunction( @@ -278,6 +278,7 @@ object FunctionRegistry { // misc functions expression[Crc32]("crc32"), expression[Md5]("md5"), + expression[Murmur3Hash]("hash"), expression[Sha1]("sha"), expression[Sha1]("sha1"), expression[Sha2]("sha2"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index d0ec99b232..8834924687 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -22,6 +22,8 @@ import java.util.zip.CRC32 import org.apache.commons.codec.digest.DigestUtils +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String @@ -177,3 +179,45 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp }) } } + +/** + * A function that calculates hash value for a group of expressions. Note that the `seed` argument + * is not exposed to users and should only be set inside spark SQL. + * + * Internally this function will write arguments into an [[UnsafeRow]], and calculate hash code of + * the unsafe row using murmur3 hasher with a seed. + * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle + * and bucketing have same data distribution. + */ +case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression { + def this(arguments: Seq[Expression]) = this(arguments, 42) + + override def dataType: DataType = IntegerType + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = false + + override def checkInputDataTypes(): TypeCheckResult = { + if (children.isEmpty) { + TypeCheckResult.TypeCheckFailure("arguments of function hash cannot be empty") + } else { + TypeCheckResult.TypeCheckSuccess + } + } + + private lazy val unsafeProjection = UnsafeProjection.create(children) + + override def eval(input: InternalRow): Any = { + unsafeProjection(input).hashCode(seed) + } + + override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children) + ev.isNull = "false" + s""" + ${unsafeRow.code} + final int ${ev.value} = ${unsafeRow.value}.hashCode($seed); + """ + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala index 8f4faab7ba..b17f8d5ec7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala @@ -99,7 +99,7 @@ class RowEncoderSuite extends SparkFunSuite { .add("binary", BinaryType) .add("date", DateType) .add("timestamp", TimestampType) - .add("udt", new ExamplePointUDT, false)) + .add("udt", new ExamplePointUDT)) encodeDecodeTest( new StructType() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index 75d17417e5..9175568f43 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.commons.codec.digest.DigestUtils import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType} +import org.apache.spark.sql.{Row, RandomDataGenerator} +import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder} +import org.apache.spark.sql.types._ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { @@ -59,4 +61,73 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Crc32(Literal.create(null, BinaryType)), null) checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } + + private val structOfString = new StructType().add("str", StringType) + private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) + private val arrayOfString = ArrayType(StringType) + private val arrayOfNull = ArrayType(NullType) + private val mapOfString = MapType(StringType, StringType) + private val arrayOfUDT = ArrayType(new ExamplePointUDT, false) + + testMurmur3Hash( + new StructType() + .add("null", NullType) + .add("boolean", BooleanType) + .add("byte", ByteType) + .add("short", ShortType) + .add("int", IntegerType) + .add("long", LongType) + .add("float", FloatType) + .add("double", DoubleType) + .add("decimal", DecimalType.SYSTEM_DEFAULT) + .add("string", StringType) + .add("binary", BinaryType) + .add("date", DateType) + .add("timestamp", TimestampType) + .add("udt", new ExamplePointUDT)) + + testMurmur3Hash( + new StructType() + .add("arrayOfNull", arrayOfNull) + .add("arrayOfString", arrayOfString) + .add("arrayOfArrayOfString", ArrayType(arrayOfString)) + .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType))) + .add("arrayOfMap", ArrayType(mapOfString)) + .add("arrayOfStruct", ArrayType(structOfString)) + .add("arrayOfUDT", arrayOfUDT)) + + testMurmur3Hash( + new StructType() + .add("mapOfIntAndString", MapType(IntegerType, StringType)) + .add("mapOfStringAndArray", MapType(StringType, arrayOfString)) + .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType)) + .add("mapOfArray", MapType(arrayOfString, arrayOfString)) + .add("mapOfStringAndStruct", MapType(StringType, structOfString)) + .add("mapOfStructAndString", MapType(structOfString, StringType)) + .add("mapOfStruct", MapType(structOfString, structOfString))) + + testMurmur3Hash( + new StructType() + .add("structOfString", structOfString) + .add("structOfStructOfString", new StructType().add("struct", structOfString)) + .add("structOfArray", new StructType().add("array", arrayOfString)) + .add("structOfMap", new StructType().add("map", mapOfString)) + .add("structOfArrayAndMap", + new StructType().add("array", arrayOfString).add("map", mapOfString)) + .add("structOfUDT", structOfUDT)) + + private def testMurmur3Hash(inputSchema: StructType): Unit = { + val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get + val encoder = RowEncoder(inputSchema) + val seed = scala.util.Random.nextInt() + test(s"murmur3 hash: ${inputSchema.simpleString}") { + for (_ <- 1 to 10) { + val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow] + val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map { + case (value, dt) => Literal.create(value, dt) + } + checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed)) + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 2b3db398aa..e223e32fd7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1813,6 +1813,17 @@ object functions extends LegacyFunctions { */ def crc32(e: Column): Column = withExpr { Crc32(e.expr) } + /** + * Calculates the hash code of given columns, and returns the result as a int column. + * + * @group misc_funcs + * @since 2.0 + */ + @scala.annotation.varargs + def hash(col: Column, cols: Column*): Column = withExpr { + new Murmur3Hash((col +: cols).map(_.expr)) + } + ////////////////////////////////////////////////////////////////////////////////////////////// // String functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 115b617c21..72845711ad 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2057,4 +2057,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { ) } + test("hash function") { + val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") + withTempTable("tbl") { + df.registerTempTable("tbl") + checkAnswer( + df.select(hash($"i", $"j")), + sql("SELECT hash(i, j) from tbl") + ) + } + } } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 2b0e48dbfc..bd1a52e5f3 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -53,6 +53,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5) // Enable in-memory partition pruning for testing purposes TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true) + // Use Hive hash expression instead of the native one + TestHive.functionRegistry.unregisterFunction("hash") RuleExecutor.resetTime() } @@ -62,6 +64,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { Locale.setDefault(originalLocale) TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize) TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning) + TestHive.functionRegistry.restore() // For debugging dump some statistics about how much time was spent in various optimizer rules. logWarning(RuleExecutor.dumpTimeSpent()) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 013fbab0a8..66d5f20d88 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -31,10 +31,13 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.sql.{SQLContext, SQLConf} import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.CacheTableCommand import org.apache.spark.sql.hive._ import org.apache.spark.sql.hive.execution.HiveNativeCommand +import org.apache.spark.sql.hive.client.ClientWrapper import org.apache.spark.util.{ShutdownHookManager, Utils} import org.apache.spark.{SparkConf, SparkContext} @@ -451,6 +454,27 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) { logError("FATAL ERROR: Failed to reset TestDB state.", e) } } + + @transient + override protected[sql] lazy val functionRegistry = new TestHiveFunctionRegistry( + org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive) +} + +private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: ClientWrapper) + extends HiveFunctionRegistry(fr, client) { + + private val removedFunctions = + collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))] + + def unregisterFunction(name: String): Unit = { + fr.functionBuilders.remove(name).foreach(f => removedFunctions += name -> f) + } + + def restore(): Unit = { + removedFunctions.foreach { + case (name, (info, builder)) => fr.registerFunction(name, info, builder) + } + } } private[hive] object TestHiveContext { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 8a5acaf3e1..acd1130f27 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -387,9 +387,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { createQueryTest("partitioned table scan", "SELECT ds, hr, key, value FROM srcpart") - createQueryTest("hash", - "SELECT hash('test') FROM src LIMIT 1") - createQueryTest("create table as", """ |CREATE TABLE createdtable AS SELECT * FROM src; |