aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-10-27 20:26:38 -0700
committerYin Huai <yhuai@databricks.com>2015-10-27 20:26:38 -0700
commitd9c6039897236c3f1e4503aa95c5c9b07b32eadd (patch)
tree3f9ddb1f1c7b91ef3cb7073cd7094522e48b340f /sql/core
parentb960a890561eaf3795b93c621bd95be81e56f5b7 (diff)
downloadspark-d9c6039897236c3f1e4503aa95c5c9b07b32eadd.tar.gz
spark-d9c6039897236c3f1e4503aa95c5c9b07b32eadd.tar.bz2
spark-d9c6039897236c3f1e4503aa95c5c9b07b32eadd.zip
[SPARK-10484] [SQL] Optimize the cartesian join with broadcast join for some cases
In some cases, we can broadcast the smaller relation in a cartesian join, which improves performance significantly. Author: Cheng Hao <hao.cheng@intel.com> Closes #8652 from chenghao-intel/cartesian.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala38
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala92
4 files changed, 125 insertions, 15 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
index b346f43fae..0f98fe88b2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala
@@ -44,8 +44,9 @@ class SparkPlanner(val sqlContext: SQLContext) extends SparkStrategies {
EquiJoinSelection ::
InMemoryScans ::
BasicOperators ::
+ BroadcastNestedLoop ::
CartesianProduct ::
- BroadcastNestedLoopJoin :: Nil)
+ DefaultJoin :: Nil)
/**
* Used to build table scan operators where complex projection and filtering are done using
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 637deff4e2..ee97162853 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -294,25 +294,24 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
-
- object BroadcastNestedLoopJoin extends Strategy {
+ object BroadcastNestedLoop extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Join(left, right, joinType, condition) =>
- val buildSide =
- if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
- joins.BuildRight
- } else {
- joins.BuildLeft
- }
- joins.BroadcastNestedLoopJoin(
- planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+ case logical.Join(
+ CanBroadcast(left), right, joinType, condition) if joinType != LeftSemiJoin =>
+ execution.joins.BroadcastNestedLoopJoin(
+ planLater(left), planLater(right), joins.BuildLeft, joinType, condition) :: Nil
+ case logical.Join(
+ left, CanBroadcast(right), joinType, condition) if joinType != LeftSemiJoin =>
+ execution.joins.BroadcastNestedLoopJoin(
+ planLater(left), planLater(right), joins.BuildRight, joinType, condition) :: Nil
case _ => Nil
}
}
object CartesianProduct extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- case logical.Join(left, right, _, None) =>
+ // TODO CartesianProduct doesn't support the Left Semi Join
+ case logical.Join(left, right, joinType, None) if joinType != LeftSemiJoin =>
execution.joins.CartesianProduct(planLater(left), planLater(right)) :: Nil
case logical.Join(left, right, Inner, Some(condition)) =>
execution.Filter(condition,
@@ -321,6 +320,21 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
+ object DefaultJoin extends Strategy {
+ def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case logical.Join(left, right, joinType, condition) =>
+ val buildSide =
+ if (right.statistics.sizeInBytes <= left.statistics.sizeInBytes) {
+ joins.BuildRight
+ } else {
+ joins.BuildLeft
+ }
+ joins.BroadcastNestedLoopJoin(
+ planLater(left), planLater(right), buildSide, joinType, condition) :: Nil
+ case _ => Nil
+ }
+ }
+
protected lazy val singleRowRdd = sparkContext.parallelize(Seq(InternalRow()), 1)
object TakeOrderedAndProject extends Strategy {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index efef8c8a8b..05d20f511a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.collection.CompactBuffer
@@ -67,7 +67,10 @@ case class BroadcastNestedLoopJoin(
left.output.map(_.withNullability(true)) ++ right.output
case FullOuter =>
left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true))
- case x =>
+ case Inner =>
+ // TODO we can avoid breaking the lineage, since we union an empty RDD for Inner Join case
+ left.output ++ right.output
+ case x => // TODO support the Left Semi Join
throw new IllegalArgumentException(
s"BroadcastNestedLoopJoin should not take $x as the JoinType")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index b1fb068158..a9ca46cab0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -28,6 +28,10 @@ class JoinSuite extends QueryTest with SharedSQLContext {
setupTestData()
+ def statisticSizeInByte(df: DataFrame): BigInt = {
+ df.queryExecution.optimizedPlan.statistics.sizeInBytes
+ }
+
test("equi-join is hash-join") {
val x = testData2.as("x")
val y = testData2.as("y")
@@ -466,6 +470,94 @@ class JoinSuite extends QueryTest with SharedSQLContext {
sql("UNCACHE TABLE testData")
}
+ test("cross join with broadcast") {
+ sql("CACHE TABLE testData")
+
+ val sizeInByteOfTestData = statisticSizeInByte(sqlContext.table("testData"))
+
+ // set the threshold to be just greater than the statistics size of the cached table testData
+ withSQLConf(
+ SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> (sizeInByteOfTestData + 1).toString()) {
+
+ assert(statisticSizeInByte(sqlContext.table("testData2")) >
+ sqlContext.conf.autoBroadcastJoinThreshold)
+
+ assert(statisticSizeInByte(sqlContext.table("testData")) <
+ sqlContext.conf.autoBroadcastJoinThreshold)
+
+ Seq(
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a",
+ classOf[LeftSemiJoinHash]),
+ ("SELECT * FROM testData LEFT SEMI JOIN testData2",
+ classOf[LeftSemiJoinBNL]),
+ ("SELECT * FROM testData JOIN testData2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData JOIN testData2 WHERE key = 2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData RIGHT JOIN testData2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData LEFT JOIN testData2 WHERE key = 2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData RIGHT JOIN testData2 WHERE key = 2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key = 2",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData JOIN testData2 WHERE key > a",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData FULL OUTER JOIN testData2 WHERE key > a",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData left JOIN testData2 WHERE (key * a != key + a)",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData right JOIN testData2 WHERE (key * a != key + a)",
+ classOf[BroadcastNestedLoopJoin]),
+ ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)",
+ classOf[BroadcastNestedLoopJoin])
+ ).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
+
+ checkAnswer(
+ sql(
+ """
+ SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key = 2
+ """.stripMargin),
+ Row("2", 1, 1) ::
+ Row("2", 1, 2) ::
+ Row("2", 2, 1) ::
+ Row("2", 2, 2) ::
+ Row("2", 3, 1) ::
+ Row("2", 3, 2) :: Nil)
+
+ checkAnswer(
+ sql(
+ """
+ SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y WHERE x.key < y.a
+ """.stripMargin),
+ Row("1", 2, 1) ::
+ Row("1", 2, 2) ::
+ Row("1", 3, 1) ::
+ Row("1", 3, 2) ::
+ Row("2", 3, 1) ::
+ Row("2", 3, 2) :: Nil)
+
+ checkAnswer(
+ sql(
+ """
+ SELECT x.value, y.a, y.b FROM testData x JOIN testData2 y ON x.key < y.a
+ """.stripMargin),
+ Row("1", 2, 1) ::
+ Row("1", 2, 2) ::
+ Row("1", 3, 1) ::
+ Row("1", 3, 2) ::
+ Row("2", 3, 1) ::
+ Row("2", 3, 2) :: Nil)
+ }
+
+ sql("UNCACHE TABLE testData")
+ }
+
test("left semi join") {
val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(df,