about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala 117
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/TestData.scala 4
2 files changed, 64 insertions, 57 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index 07f4d2946c..8b4cf5bac0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -19,17 +19,13 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfterEach
-import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.catalyst.plans.JoinType
-import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner, LeftSemi}
-import org.apache.spark.sql.execution._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
+import org.apache.spark.sql.catalyst.plans.{FullOuter, Inner, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.test.TestSQLContext
import org.apache.spark.sql.test.TestSQLContext._
class JoinSuite extends QueryTest with BeforeAndAfterEach {
-
// Ensures tables are loaded.
TestData
@@ -41,54 +37,65 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
assert(planned.size === 1)
}
- test("join operator selection") {
- def assertJoin(sqlString: String, c: Class[_]): Any = {
- val rdd = sql(sqlString)
- val physical = rdd.queryExecution.sparkPlan
- val operators = physical.collect {
- case j: ShuffledHashJoin => j
- case j: HashOuterJoin => j
- case j: LeftSemiJoinHash => j
- case j: BroadcastHashJoin => j
- case j: LeftSemiJoinBNL => j
- case j: CartesianProduct => j
- case j: BroadcastNestedLoopJoin => j
- }
-
- assert(operators.size === 1)
- if (operators(0).getClass() != c) {
- fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
- }
+ def assertJoin(sqlString: String, c: Class[_]): Any = {
+ val rdd = sql(sqlString)
+ val physical = rdd.queryExecution.sparkPlan
+ val operators = physical.collect {
+ case j: ShuffledHashJoin => j
+ case j: HashOuterJoin => j
+ case j: LeftSemiJoinHash => j
+ case j: BroadcastHashJoin => j
+ case j: LeftSemiJoinBNL => j
+ case j: CartesianProduct => j
+ case j: BroadcastNestedLoopJoin => j
+ }
+
+ assert(operators.size === 1)
+ if (operators(0).getClass() != c) {
+ fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical")
}
+ }
- val cases1 = Seq(
- ("SELECT * FROM testData left semi join testData2 ON key = a", classOf[LeftSemiJoinHash]),
- ("SELECT * FROM testData left semi join testData2", classOf[LeftSemiJoinBNL]),
- ("SELECT * FROM testData join testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData join testData2 where key=2", classOf[CartesianProduct]),
- ("SELECT * FROM testData left join testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData right join testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData full outer join testData2", classOf[CartesianProduct]),
- ("SELECT * FROM testData left join testData2 where key=2", classOf[CartesianProduct]),
- ("SELECT * FROM testData right join testData2 where key=2", classOf[CartesianProduct]),
- ("SELECT * FROM testData full outer join testData2 where key=2", classOf[CartesianProduct]),
- ("SELECT * FROM testData join testData2 where key>a", classOf[CartesianProduct]),
- ("SELECT * FROM testData full outer join testData2 where key>a", classOf[CartesianProduct]),
- ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin]),
- ("SELECT * FROM testData left join testData2 ON key = a", classOf[HashOuterJoin]),
- ("SELECT * FROM testData right join testData2 ON key = a where key=2",
+ test("join operator selection") {
+ clearCache()
+
+ Seq(
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]),
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]),
+ ("SELECT * FROM testData JOIN testData2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData LEFT JOIN testData2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData RIGHT JOIN testData2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2", classOf[CartesianProduct]),
+ ("SELECT * FROM testData JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a", classOf[CartesianProduct]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[ShuffledHashJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2 ON key = a", classOf[HashOuterJoin]),
+ ("SELECT * FROM testData RIGHT JOIN testData2 ON key = a where key = 2",
classOf[HashOuterJoin]),
- ("SELECT * FROM testData right join testData2 ON key = a and key=2",
+ ("SELECT * FROM testData right join testData2 ON key = a and key = 2",
classOf[HashOuterJoin]),
- ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin]),
- ("SELECT * FROM testData join testData2 ON key = a", classOf[ShuffledHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a and key=2", classOf[ShuffledHashJoin]),
- ("SELECT * FROM testData join testData2 ON key = a where key=2", classOf[ShuffledHashJoin])
- // TODO add BroadcastNestedLoopJoin
- )
- cases1.foreach { c => assertJoin(c._1, c._2) }
+ ("SELECT * FROM testData full outer join testData2 ON key = a", classOf[HashOuterJoin])
+ // TODO add BroadcastNestedLoopJoin
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+ }
+
+ test("broadcasted hash join operator selection") {
+ clearCache()
+ sql("CACHE TABLE testData")
+
+ Seq(
+ ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]),
+ ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+
+ sql("UNCACHE TABLE testData")
}
test("multiple-key equi-join is hash-join") {
@@ -171,7 +178,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)
-
+
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'n > 1)),
(1, "A", null, null) ::
@@ -180,7 +187,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)
-
+
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'N > 1)),
(1, "A", null, null) ::
@@ -189,7 +196,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "d") ::
(5, "E", null, null) ::
(6, "F", null, null) :: Nil)
-
+
checkAnswer(
upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N && 'l > 'L)),
(1, "A", 1, "a") ::
@@ -300,7 +307,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
-
+
checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("left.N".attr !== 3))),
(1, "A", null, null) ::
@@ -310,7 +317,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach {
(4, "D", 4, "D") ::
(null, null, 5, "E") ::
(null, null, 6, "F") :: Nil)
-
+
checkAnswer(
left.join(right, FullOuter, Some(("left.N".attr === "right.N".attr) && ("right.N".attr !== 3))),
(1, "A", null, null) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
index 6c38575b13..c4dd3e860f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
@@ -80,7 +80,7 @@ object TestData {
UpperCaseData(3, "C") ::
UpperCaseData(4, "D") ::
UpperCaseData(5, "E") ::
- UpperCaseData(6, "F") :: Nil)
+ UpperCaseData(6, "F") :: Nil).toSchemaRDD
upperCaseData.registerTempTable("upperCaseData")
case class LowerCaseData(n: Int, l: String)
@@ -89,7 +89,7 @@ object TestData {
LowerCaseData(1, "a") ::
LowerCaseData(2, "b") ::
LowerCaseData(3, "c") ::
- LowerCaseData(4, "d") :: Nil)
+ LowerCaseData(4, "d") :: Nil).toSchemaRDD
lowerCaseData.registerTempTable("lowerCaseData")
case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])