aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala 65
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala 47
2 files changed, 79 insertions, 33 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 8ddac19bf1..05c5e2f4cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -63,45 +63,16 @@ trait HashJoin {
protected lazy val (buildKeys, streamedKeys) = {
require(leftKeys.map(_.dataType) == rightKeys.map(_.dataType),
"Join keys from two sides should have same types")
- val lkeys = rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
- val rkeys = rewriteKeyExpr(rightKeys).map(BindReferences.bindReference(_, right.output))
+ val lkeys = HashJoin.rewriteKeyExpr(leftKeys).map(BindReferences.bindReference(_, left.output))
+ val rkeys = HashJoin.rewriteKeyExpr(rightKeys)
+ .map(BindReferences.bindReference(_, right.output))
buildSide match {
case BuildLeft => (lkeys, rkeys)
case BuildRight => (rkeys, lkeys)
}
}
- /**
- * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long.
- *
- * If not, returns the original expressions.
- */
- private def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
- var keyExpr: Expression = null
- var width = 0
- keys.foreach { e =>
- e.dataType match {
- case dt: IntegralType if dt.defaultSize <= 8 - width =>
- if (width == 0) {
- if (e.dataType != LongType) {
- keyExpr = Cast(e, LongType)
- } else {
- keyExpr = e
- }
- width = dt.defaultSize
- } else {
- val bits = dt.defaultSize * 8
- keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
- BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
- width -= bits
- }
- // TODO: support BooleanType, DateType and TimestampType
- case other =>
- return keys
- }
- }
- keyExpr :: Nil
- }
+
protected def buildSideKeyGenerator(): Projection =
UnsafeProjection.create(buildKeys)
@@ -247,3 +218,31 @@ trait HashJoin {
}
}
}
+
+object HashJoin {
+ /**
+ * Try to rewrite the keys as a single LongType key so we can use getLong(), if the keys fit
+ * within a long (all are IntegralType and their default sizes total at most 8 bytes).
+ * If not, returns the original expressions.
+ */
+ private[joins] def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = {
+ assert(keys.nonEmpty)
+ // TODO: support BooleanType, DateType and TimestampType
+ if (keys.exists(!_.dataType.isInstanceOf[IntegralType])
+ || keys.map(_.dataType.defaultSize).sum > 8) {
+ return keys
+ }
+
+ var keyExpr: Expression = if (keys.head.dataType != LongType) {
+ Cast(keys.head, LongType)
+ } else {
+ keys.head
+ }
+ keys.tail.foreach { e =>
+ val bits = e.dataType.defaultSize * 8
+ keyExpr = BitwiseOr(ShiftLeft(keyExpr, Literal(bits)),
+ BitwiseAnd(Cast(e, LongType), Literal((1L << bits) - 1)))
+ }
+ keyExpr :: Nil
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index 97adffa8ce..83db81ea3f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -21,11 +21,13 @@ import scala.reflect.ClassTag
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession}
+import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft}
import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.types.{LongType, ShortType}
/**
* Test various broadcast join operators.
@@ -153,4 +155,49 @@ class BroadcastJoinSuite extends QueryTest with SQLTestUtils {
cases.foreach(assertBroadcastJoin)
}
}
+
+ test("join key rewritten") {
+ val l = Literal(1L)
+ val i = Literal(2)
+ val s = Literal.create(3, ShortType)
+ val ss = Literal("hello")
+
+ assert(HashJoin.rewriteKeyExpr(l :: Nil) === l :: Nil)
+ assert(HashJoin.rewriteKeyExpr(l :: l :: Nil) === l :: l :: Nil)
+ assert(HashJoin.rewriteKeyExpr(l :: i :: Nil) === l :: i :: Nil)
+
+ assert(HashJoin.rewriteKeyExpr(i :: Nil) === Cast(i, LongType) :: Nil)
+ assert(HashJoin.rewriteKeyExpr(i :: l :: Nil) === i :: l :: Nil)
+ assert(HashJoin.rewriteKeyExpr(i :: i :: Nil) ===
+ BitwiseOr(ShiftLeft(Cast(i, LongType), Literal(32)),
+ BitwiseAnd(Cast(i, LongType), Literal((1L << 32) - 1))) :: Nil)
+ assert(HashJoin.rewriteKeyExpr(i :: i :: i :: Nil) === i :: i :: i :: Nil)
+
+ assert(HashJoin.rewriteKeyExpr(s :: Nil) === Cast(s, LongType) :: Nil)
+ assert(HashJoin.rewriteKeyExpr(s :: l :: Nil) === s :: l :: Nil)
+ assert(HashJoin.rewriteKeyExpr(s :: s :: Nil) ===
+ BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+ BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+ assert(HashJoin.rewriteKeyExpr(s :: s :: s :: Nil) ===
+ BitwiseOr(ShiftLeft(
+ BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+ BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+ Literal(16)),
+ BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+ assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: Nil) ===
+ BitwiseOr(ShiftLeft(
+ BitwiseOr(ShiftLeft(
+ BitwiseOr(ShiftLeft(Cast(s, LongType), Literal(16)),
+ BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+ Literal(16)),
+ BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))),
+ Literal(16)),
+ BitwiseAnd(Cast(s, LongType), Literal((1L << 16) - 1))) :: Nil)
+ assert(HashJoin.rewriteKeyExpr(s :: s :: s :: s :: s :: Nil) ===
+ s :: s :: s :: s :: s :: Nil)
+
+ assert(HashJoin.rewriteKeyExpr(ss :: Nil) === ss :: Nil)
+ assert(HashJoin.rewriteKeyExpr(l :: ss :: Nil) === l :: ss :: Nil)
+ assert(HashJoin.rewriteKeyExpr(i :: ss :: Nil) === i :: ss :: Nil)
+ }
}