author     Wenchen Fan <wenchen@databricks.com>   2016-01-04 18:49:41 -0800
committer  Reynold Xin <rxin@databricks.com>      2016-01-04 18:49:41 -0800
commit     b1a771231e20df157fb3e780287390a883c0cc6f (patch)
tree       980b90de0e7b173df0fb48a42f0faac251a9ff36 /sql
parent     77ab49b8575d2ebd678065fa70b0343d532ab9c2 (diff)
[SPARK-12480][SQL] add Hash expression that can calculate hash value for a group of expressions
Just write the arguments into an unsafe row and use murmur3 to calculate the hash code.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10435 from cloud-fan/hash-expr.
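End to end, the new function is callable from both SQL and the DataFrame API. A minimal usage sketch, assuming an in-scope `sqlContext` (the DataFrame below is hypothetical; it mirrors the test added to SQLQuerySuite in this patch):

    import org.apache.spark.sql.functions.hash
    import sqlContext.implicits._

    val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
    df.registerTempTable("tbl")

    // Both paths go through the same Murmur3Hash expression, so they
    // return identical murmur3-based hash codes (default seed 42).
    sqlContext.sql("SELECT hash(i, j) FROM tbl").show()
    df.select(hash($"i", $"j")).show()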
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java                    |  4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala              |  3
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala                       | 44
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala               |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala         | 73
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                           | 11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                                       | 10
-rw-r--r--  sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala |  3
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala                                  | 24
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala                       |  3
10 files changed, 171 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
index 1a351933a3..b8d3c49100 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeRow.java
@@ -566,6 +566,10 @@ public final class UnsafeRow extends MutableRow implements Externalizable, KryoS
return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, 42);
}
+ public int hashCode(int seed) {
+ return Murmur3_x86_32.hashUnsafeWords(baseObject, baseOffset, sizeInBytes, seed);
+ }
+
@Override
public boolean equals(Object other) {
if (other instanceof UnsafeRow) {
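For intuition about the seeded overload just added, here is an illustrative sketch of seeded murmur3 over raw bytes using Scala's standard library. This is not Spark's code path: `UnsafeRow.hashCode(seed)` instead hashes the row's backing bytes word-by-word via `Murmur3_x86_32.hashUnsafeWords`.

    import scala.util.hashing.MurmurHash3

    // Illustration only: the same payload hashed under two seeds yields
    // two different, but individually stable, hash codes.
    val payload: Array[Byte] = "example row bytes".getBytes("UTF-8")
    val defaultSeed = MurmurHash3.bytesHash(payload, 42)   // like hashCode()
    val customSeed  = MurmurHash3.bytesHash(payload, 123)  // like hashCode(123)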
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 57d1a1107e..5c2aa3c06b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -49,7 +49,7 @@ trait FunctionRegistry {
class SimpleFunctionRegistry extends FunctionRegistry {
- private val functionBuilders =
+ private[sql] val functionBuilders =
StringKeyHashMap[(ExpressionInfo, FunctionBuilder)](caseSensitive = false)
override def registerFunction(
@@ -278,6 +278,7 @@ object FunctionRegistry {
// misc functions
expression[Crc32]("crc32"),
expression[Md5]("md5"),
+ expression[Murmur3Hash]("hash"),
expression[Sha1]("sha"),
expression[Sha1]("sha1"),
expression[Sha2]("sha2"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index d0ec99b232..8834924687 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -22,6 +22,8 @@ import java.util.zip.CRC32
import org.apache.commons.codec.digest.DigestUtils
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -177,3 +179,45 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp
})
}
}
+
+/**
+ * A function that calculates a hash value for a group of expressions. Note that the `seed`
+ * argument is not exposed to users and should only be set inside Spark SQL.
+ *
+ * Internally this function writes the arguments into an [[UnsafeRow]] and calculates the hash
+ * code of that unsafe row using a murmur3 hasher with the given seed.
+ * We should use this hash function for both shuffle and bucketing, so that we can guarantee
+ * that they have the same data distribution.
+ */
+case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression {
+ def this(arguments: Seq[Expression]) = this(arguments, 42)
+
+ override def dataType: DataType = IntegerType
+
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override def nullable: Boolean = false
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.isEmpty) {
+ TypeCheckResult.TypeCheckFailure("arguments of function hash cannot be empty")
+ } else {
+ TypeCheckResult.TypeCheckSuccess
+ }
+ }
+
+ private lazy val unsafeProjection = UnsafeProjection.create(children)
+
+ override def eval(input: InternalRow): Any = {
+ unsafeProjection(input).hashCode(seed)
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val unsafeRow = GenerateUnsafeProjection.createCode(ctx, children)
+ ev.isNull = "false"
+ s"""
+ ${unsafeRow.code}
+ final int ${ev.value} = ${unsafeRow.value}.hashCode($seed);
+ """
+ }
+}
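A minimal sketch of evaluating the new expression directly, interpreted-style (this mirrors the checkEvaluation call in the MiscFunctionsSuite changes below; the literal values are arbitrary):

    import org.apache.spark.sql.catalyst.expressions.{Literal, Murmur3Hash}

    // Default seed (42) via the auxiliary constructor; literals ignore
    // the input row, so eval() needs no arguments here.
    val hashed = new Murmur3Hash(Seq(Literal(1), Literal("a"))).eval()

    // An explicit seed gives a different, but stable, hash code.
    val seeded = Murmur3Hash(Seq(Literal(1), Literal("a")), seed = 123).eval()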
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
index 8f4faab7ba..b17f8d5ec7 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/RowEncoderSuite.scala
@@ -99,7 +99,7 @@ class RowEncoderSuite extends SparkFunSuite {
.add("binary", BinaryType)
.add("date", DateType)
.add("timestamp", TimestampType)
- .add("udt", new ExamplePointUDT, false))
+ .add("udt", new ExamplePointUDT))
encodeDecodeTest(
new StructType()
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index 75d17417e5..9175568f43 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -20,7 +20,9 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.commons.codec.digest.DigestUtils
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.types.{IntegerType, StringType, BinaryType}
+import org.apache.spark.sql.{Row, RandomDataGenerator}
+import org.apache.spark.sql.catalyst.encoders.{ExamplePointUDT, RowEncoder}
+import org.apache.spark.sql.types._
class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
@@ -59,4 +61,73 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Crc32(Literal.create(null, BinaryType)), null)
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}
+
+ private val structOfString = new StructType().add("str", StringType)
+ private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
+ private val arrayOfString = ArrayType(StringType)
+ private val arrayOfNull = ArrayType(NullType)
+ private val mapOfString = MapType(StringType, StringType)
+ private val arrayOfUDT = ArrayType(new ExamplePointUDT, false)
+
+ testMurmur3Hash(
+ new StructType()
+ .add("null", NullType)
+ .add("boolean", BooleanType)
+ .add("byte", ByteType)
+ .add("short", ShortType)
+ .add("int", IntegerType)
+ .add("long", LongType)
+ .add("float", FloatType)
+ .add("double", DoubleType)
+ .add("decimal", DecimalType.SYSTEM_DEFAULT)
+ .add("string", StringType)
+ .add("binary", BinaryType)
+ .add("date", DateType)
+ .add("timestamp", TimestampType)
+ .add("udt", new ExamplePointUDT))
+
+ testMurmur3Hash(
+ new StructType()
+ .add("arrayOfNull", arrayOfNull)
+ .add("arrayOfString", arrayOfString)
+ .add("arrayOfArrayOfString", ArrayType(arrayOfString))
+ .add("arrayOfArrayOfInt", ArrayType(ArrayType(IntegerType)))
+ .add("arrayOfMap", ArrayType(mapOfString))
+ .add("arrayOfStruct", ArrayType(structOfString))
+ .add("arrayOfUDT", arrayOfUDT))
+
+ testMurmur3Hash(
+ new StructType()
+ .add("mapOfIntAndString", MapType(IntegerType, StringType))
+ .add("mapOfStringAndArray", MapType(StringType, arrayOfString))
+ .add("mapOfArrayAndInt", MapType(arrayOfString, IntegerType))
+ .add("mapOfArray", MapType(arrayOfString, arrayOfString))
+ .add("mapOfStringAndStruct", MapType(StringType, structOfString))
+ .add("mapOfStructAndString", MapType(structOfString, StringType))
+ .add("mapOfStruct", MapType(structOfString, structOfString)))
+
+ testMurmur3Hash(
+ new StructType()
+ .add("structOfString", structOfString)
+ .add("structOfStructOfString", new StructType().add("struct", structOfString))
+ .add("structOfArray", new StructType().add("array", arrayOfString))
+ .add("structOfMap", new StructType().add("map", mapOfString))
+ .add("structOfArrayAndMap",
+ new StructType().add("array", arrayOfString).add("map", mapOfString))
+ .add("structOfUDT", structOfUDT))
+
+ private def testMurmur3Hash(inputSchema: StructType): Unit = {
+ val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
+ val encoder = RowEncoder(inputSchema)
+ val seed = scala.util.Random.nextInt()
+ test(s"murmur3 hash: ${inputSchema.simpleString}") {
+ for (_ <- 1 to 10) {
+ val input = encoder.toRow(inputGenerator.apply().asInstanceOf[Row]).asInstanceOf[UnsafeRow]
+ val literals = input.toSeq(inputSchema).zip(inputSchema.map(_.dataType)).map {
+ case (value, dt) => Literal.create(value, dt)
+ }
+ checkEvaluation(Murmur3Hash(literals, seed), input.hashCode(seed))
+ }
+ }
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 2b3db398aa..e223e32fd7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1813,6 +1813,17 @@ object functions extends LegacyFunctions {
*/
def crc32(e: Column): Column = withExpr { Crc32(e.expr) }
+ /**
+ * Calculates the hash code of the given columns, and returns the result as an int column.
+ *
+ * @group misc_funcs
+ * @since 2.0
+ */
+ @scala.annotation.varargs
+ def hash(col: Column, cols: Column*): Column = withExpr {
+ new Murmur3Hash((col +: cols).map(_.expr))
+ }
+
//////////////////////////////////////////////////////////////////////////////////////////////
// String functions
//////////////////////////////////////////////////////////////////////////////////////////////
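Because the hash is deterministic for a given seed, it can drive stable partition or bucket assignment, which is the shuffle/bucketing motivation from the expression's scaladoc. A hedged sketch (the DataFrame `df` and the bucket count 8 are assumptions for illustration; `pmod` is the existing modulo function in this file):

    import org.apache.spark.sql.functions.{hash, lit, pmod}

    // Derive a stable, non-negative bucket id from the key columns' hash.
    val bucketed = df.withColumn("bucket", pmod(hash($"i", $"j"), lit(8)))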
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 115b617c21..72845711ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -2057,4 +2057,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
)
}
+ test("hash function") {
+ val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j")
+ withTempTable("tbl") {
+ df.registerTempTable("tbl")
+ checkAnswer(
+ df.select(hash($"i", $"j")),
+ sql("SELECT hash(i, j) from tbl")
+ )
+ }
+ }
}
diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
index 2b0e48dbfc..bd1a52e5f3 100644
--- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
+++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala
@@ -53,6 +53,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, 5)
// Enable in-memory partition pruning for testing purposes
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
+ // Use Hive hash expression instead of the native one
+ TestHive.functionRegistry.unregisterFunction("hash")
RuleExecutor.resetTime()
}
@@ -62,6 +64,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter {
Locale.setDefault(originalLocale)
TestHive.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
TestHive.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
+ TestHive.functionRegistry.restore()
// For debugging dump some statistics about how much time was spent in various optimizer rules.
logWarning(RuleExecutor.dumpTimeSpent())
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
index 013fbab0a8..66d5f20d88 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -31,10 +31,13 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe
import org.apache.spark.sql.{SQLContext, SQLConf}
import org.apache.spark.sql.catalyst.analysis._
+import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
+import org.apache.spark.sql.catalyst.expressions.ExpressionInfo
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.CacheTableCommand
import org.apache.spark.sql.hive._
import org.apache.spark.sql.hive.execution.HiveNativeCommand
+import org.apache.spark.sql.hive.client.ClientWrapper
import org.apache.spark.util.{ShutdownHookManager, Utils}
import org.apache.spark.{SparkConf, SparkContext}
@@ -451,6 +454,27 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
logError("FATAL ERROR: Failed to reset TestDB state.", e)
}
}
+
+ @transient
+ override protected[sql] lazy val functionRegistry = new TestHiveFunctionRegistry(
+ org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), this.executionHive)
+}
+
+private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: ClientWrapper)
+ extends HiveFunctionRegistry(fr, client) {
+
+ private val removedFunctions =
+ collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))]
+
+ def unregisterFunction(name: String): Unit = {
+ fr.functionBuilders.remove(name).foreach(f => removedFunctions += name -> f)
+ }
+
+ def restore(): Unit = {
+ removedFunctions.foreach {
+ case (name, (info, builder)) => fr.registerFunction(name, info, builder)
+ }
+ }
}
private[hive] object TestHiveContext {
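The intended usage pattern for the new registry, as exercised by the HiveCompatibilitySuite change above, is roughly (sketch only):

    // Temporarily fall back to Hive's own `hash` implementation, then
    // restore Spark's native one when the compatibility tests finish.
    TestHive.functionRegistry.unregisterFunction("hash")
    try {
      // ... run Hive compatibility queries ...
    } finally {
      TestHive.functionRegistry.restore()
    }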
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 8a5acaf3e1..acd1130f27 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -387,9 +387,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
createQueryTest("partitioned table scan",
"SELECT ds, hr, key, value FROM srcpart")
- createQueryTest("hash",
- "SELECT hash('test') FROM src LIMIT 1")
-
createQueryTest("create table as",
"""
|CREATE TABLE createdtable AS SELECT * FROM src;