Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 63
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 55
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala | 9
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala | 22
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala | 4
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala | 3
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala | 10
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala | 92
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala | 29
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 15
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala | 20
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala | 6
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala | 6
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 38
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala | 72
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 49
16 files changed, 403 insertions, 90 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index c62d5ead86..371d72ef5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -21,7 +21,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.{DataType, StructType}
-abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanType] {
+abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanType] {
self: PlanType =>
def output: Seq[Attribute]
@@ -237,4 +237,65 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
}
override def innerChildren: Seq[PlanType] = subqueries
+
+ /**
+ * Canonicalized copy of this query plan.
+ */
+ protected lazy val canonicalized: PlanType = this
+
+ /**
+ * Returns true when the given query plan will return the same results as this query plan.
+ *
+ * Since it is likely undecidable to generally determine if two given plans will produce the same
+ * results, it is okay for this function to return false, even if the results are actually
+ * the same. Such behavior will not affect correctness, only the application of performance
+ * enhancements like caching. However, it is not acceptable to return true if the results could
+ * possibly be different.
+ *
+ * By default this function performs a modified version of equality that is tolerant of cosmetic
+ * differences like attribute naming and/or expression id differences. Operators that
+ * can do better should override this function.
+ */
+ def sameResult(plan: PlanType): Boolean = {
+ val canonicalizedLeft = this.canonicalized
+ val canonicalizedRight = plan.canonicalized
+ canonicalizedLeft.getClass == canonicalizedRight.getClass &&
+ canonicalizedLeft.children.size == canonicalizedRight.children.size &&
+ canonicalizedLeft.cleanArgs == canonicalizedRight.cleanArgs &&
+ (canonicalizedLeft.children, canonicalizedRight.children).zipped.forall(_ sameResult _)
+ }
+
+ /**
+ * All the attributes that are used for this plan.
+ */
+ lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output)
+
+ private def cleanExpression(e: Expression): Expression = e match {
+ case a: Alias =>
+ // As the root of the expression, Alias will always take an arbitrary exprId, so we need
+ // to erase that for equality testing.
+ val cleanedExprId =
+ Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated)
+ BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true)
+ case other =>
+ BindReferences.bindReference(other, allAttributes, allowFailures = true)
+ }
+
+ /** Args that have been cleaned so that differences in expression ids do not affect equality */
+ protected lazy val cleanArgs: Seq[Any] = {
+ def cleanArg(arg: Any): Any = arg match {
+ case e: Expression => cleanExpression(e).canonicalized
+ case other => other
+ }
+
+ productIterator.map {
+ // Children are checked using sameResult above.
+ case tn: TreeNode[_] if containsChild(tn) => null
+ case e: Expression => cleanArg(e)
+ case s: Option[_] => s.map(cleanArg)
+ case s: Seq[_] => s.map(cleanArg)
+ case m: Map[_, _] => m.mapValues(cleanArg)
+ case other => other
+ }.toSeq
+ }
}
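
A minimal sketch of what the new canonicalization buys, assuming a SQLContext named sqlContext is in scope (as in the test suites further down in this diff): two physical plans built independently carry different ExprIds, but sameResult() compares them through canonicalized/cleanArgs and treats them as interchangeable.

    val df = sqlContext.range(100)
    val plan1 = df.groupBy().count().queryExecution.executedPlan
    val plan2 = df.groupBy().count().queryExecution.executedPlan
    // The two aggregates differ only in ExprIds, which cleanArgs erases.
    assert(plan1.sameResult(plan2))
    // A semantically different plan is still rejected.
    assert(!plan1.sameResult(df.groupBy().sum().queryExecution.executedPlan))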
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 31e775d60f..b32c7d0fcb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -114,60 +114,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
def childrenResolved: Boolean = children.forall(_.resolved)
- /**
- * Returns true when the given logical plan will return the same results as this logical plan.
- *
- * Since its likely undecidable to generally determine if two given plans will produce the same
- * results, it is okay for this function to return false, even if the results are actually
- * the same. Such behavior will not affect correctness, only the application of performance
- * enhancements like caching. However, it is not acceptable to return true if the results could
- * possibly be different.
- *
- * By default this function performs a modified version of equality that is tolerant of cosmetic
- * differences like attribute naming and or expression id differences. Logical operators that
- * can do better should override this function.
- */
- def sameResult(plan: LogicalPlan): Boolean = {
- val cleanLeft = EliminateSubqueryAliases(this)
- val cleanRight = EliminateSubqueryAliases(plan)
-
- cleanLeft.getClass == cleanRight.getClass &&
- cleanLeft.children.size == cleanRight.children.size && {
- logDebug(
- s"[${cleanRight.cleanArgs.mkString(", ")}] == [${cleanLeft.cleanArgs.mkString(", ")}]")
- cleanRight.cleanArgs == cleanLeft.cleanArgs
- } &&
- (cleanLeft.children, cleanRight.children).zipped.forall(_ sameResult _)
- }
-
- /** Args that have cleaned such that differences in expression id should not affect equality */
- protected lazy val cleanArgs: Seq[Any] = {
- val input = children.flatMap(_.output)
- def cleanExpression(e: Expression) = e match {
- case a: Alias =>
- // As the root of the expression, Alias will always take an arbitrary exprId, we need
- // to erase that for equality testing.
- val cleanedExprId =
- Alias(a.child, a.name)(ExprId(-1), a.qualifiers, isGenerated = a.isGenerated)
- BindReferences.bindReference(cleanedExprId, input, allowFailures = true)
- case other => BindReferences.bindReference(other, input, allowFailures = true)
- }
-
- productIterator.map {
- // Children are checked using sameResult above.
- case tn: TreeNode[_] if containsChild(tn) => null
- case e: Expression => cleanExpression(e)
- case s: Option[_] => s.map {
- case e: Expression => cleanExpression(e)
- case other => other
- }
- case s: Seq[_] => s.map {
- case e: Expression => cleanExpression(e)
- case other => other
- }
- case other => other
- }.toSeq
- }
+ override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this)
/**
* Optionally resolves the given strings to a [[NamedExpression]] using the input from all child
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
index e01f69f813..9dfdf4da78 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -25,6 +25,11 @@ import org.apache.spark.sql.catalyst.InternalRow
*/
trait BroadcastMode {
def transform(rows: Array[InternalRow]): Any
+
+ /**
+ * Returns true iff this [[BroadcastMode]] generates the same result as `other`.
+ */
+ def compatibleWith(other: BroadcastMode): Boolean
}
/**
@@ -33,4 +38,8 @@ trait BroadcastMode {
case object IdentityBroadcastMode extends BroadcastMode {
// TODO: pack the UnsafeRows into single bytes array.
override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows
+
+ override def compatibleWith(other: BroadcastMode): Boolean = {
+ this eq other
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
index 9019e5dfd6..247f55da1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.execution.exchange.ReusedExchange
import org.apache.spark.sql.execution.metric.SQLMetricInfo
import org.apache.spark.util.Utils
@@ -31,13 +32,28 @@ class SparkPlanInfo(
val simpleString: String,
val children: Seq[SparkPlanInfo],
val metadata: Map[String, String],
- val metrics: Seq[SQLMetricInfo])
+ val metrics: Seq[SQLMetricInfo]) {
+
+ override def hashCode(): Int = {
+ // hashCode of simpleString should be good enough to distinguish the nodes from each other
+ // within a single plan
+ simpleString.hashCode
+ }
+
+ override def equals(other: Any): Boolean = other match {
+ case o: SparkPlanInfo =>
+ nodeName == o.nodeName && simpleString == o.simpleString && children == o.children
+ case _ => false
+ }
+}
private[sql] object SparkPlanInfo {
def fromSparkPlan(plan: SparkPlan): SparkPlanInfo = {
-
- val children = plan.children ++ plan.subqueries
+ val children = plan match {
+ case ReusedExchange(_, child) => child :: Nil
+ case _ => plan.children ++ plan.subqueries
+ }
val metrics = plan.metrics.toSeq.map { case (key, metric) =>
new SQLMetricInfo(metric.name.getOrElse(key), metric.id,
Utils.getFormattedClassName(metric.param))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index f07add83d5..f856634cf7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -46,6 +46,10 @@ case class TungstenAggregate(
require(TungstenAggregate.supportsAggregate(aggregateBufferAttributes))
+ override lazy val allAttributes: Seq[Attribute] =
+ child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++
+ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes)
+
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"),
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 4a9e736f7a..4901298227 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -166,6 +166,9 @@ case class Range(
private[sql] override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ // output attributes should not affect the results
+ override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
+
override def upstreams(): Seq[RDD[InternalRow]] = {
sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
.map(i => InternalRow(i)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
index 40cad4b1a7..1a5c6a66c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
@@ -34,12 +34,16 @@ import org.apache.spark.util.ThreadUtils
*/
case class BroadcastExchange(
mode: BroadcastMode,
- child: SparkPlan) extends UnaryNode {
-
- override def output: Seq[Attribute] = child.output
+ child: SparkPlan) extends Exchange {
override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
+ override def sameResult(plan: SparkPlan): Boolean = plan match {
+ case p: BroadcastExchange =>
+ mode.compatibleWith(p.mode) && child.sameResult(p.child)
+ case _ => false
+ }
+
@transient
private val timeout: Duration = {
val timeoutValue = sqlContext.conf.broadcastTimeout
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
new file mode 100644
index 0000000000..12513e9106
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
+import org.apache.spark.sql.types.StructType
+
+/**
+ * An interface for exchanges.
+ */
+abstract class Exchange extends UnaryNode {
+ override def output: Seq[Attribute] = child.output
+}
+
+/**
+ * A wrapper that gives a reused exchange its own output: two exchanges that produce
+ * logically identical output still have distinct sets of output attribute ids, so we need to
+ * preserve the original ids because downstream operators expect them.
+ */
+case class ReusedExchange(override val output: Seq[Attribute], child: Exchange) extends LeafNode {
+
+ override def sameResult(plan: SparkPlan): Boolean = {
+ // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here.
+ plan.sameResult(child)
+ }
+
+ def doExecute(): RDD[InternalRow] = {
+ child.execute()
+ }
+
+ override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ child.executeBroadcast()
+ }
+
+ // Do not repeat the same tree in explain.
+ override def treeChildren: Seq[SparkPlan] = Nil
+}
+
+/**
+ * Find duplicated exchanges in the Spark plan, then use the same exchange for all the
+ * references.
+ */
+private[sql] case class ReuseExchange(sqlContext: SQLContext) extends Rule[SparkPlan] {
+
+ def apply(plan: SparkPlan): SparkPlan = {
+ if (!sqlContext.conf.exchangeReuseEnabled) {
+ return plan
+ }
+ // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
+ val exchanges = mutable.HashMap[StructType, ArrayBuffer[Exchange]]()
+ plan.transformUp {
+ case exchange: Exchange =>
+ // Exchanges that produce the same results usually also have the same schemas (same column names).
+ val sameSchema = exchanges.getOrElseUpdate(exchange.schema, ArrayBuffer[Exchange]())
+ val samePlan = sameSchema.find { e =>
+ exchange.sameResult(e)
+ }
+ if (samePlan.isDefined) {
+ // Keep the output of this exchange; the following plans require it to resolve
+ // attributes.
+ ReusedExchange(exchange.output, samePlan.get)
+ } else {
+ sameSchema += exchange
+ exchange
+ }
+ }
+ }
+}
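
The rewrite is easiest to see end to end from the DataFrame API. A rough sketch, modeled on the "reuse exchange" test added to DataFrameSuite below (the broadcast threshold is lowered there so the self-join is planned as a shuffle join): after planning, one side of the self-join keeps its ShuffleExchange and the other is replaced by a ReusedExchange that points at it.

    import org.apache.spark.sql.execution.exchange.{ReusedExchange, ShuffleExchange}

    sqlContext.setConf("spark.sql.autoBroadcastJoinThreshold", "2")
    val df = sqlContext.range(100)
    val join = df.join(df, "id")
    val plan = join.queryExecution.executedPlan
    // Only one shuffle survives; the duplicate becomes a ReusedExchange.
    assert(plan.collect { case e: ShuffleExchange => e }.size == 1)
    assert(plan.collect { case e: ReusedExchange => e }.size == 1)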
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index de21d7705e..4eb4d9adbd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -38,7 +38,7 @@ import org.apache.spark.util.MutablePair
case class ShuffleExchange(
var newPartitioning: Partitioning,
child: SparkPlan,
- @transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode {
+ @transient coordinator: Option[ExchangeCoordinator]) extends Exchange {
override def nodeName: String = {
val extraInfo = coordinator match {
@@ -55,8 +55,6 @@ case class ShuffleExchange(
override def outputPartitioning: Partitioning = newPartitioning
- override def output: Seq[Attribute] = child.output
-
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
override protected def doPrepare(): Unit = {
@@ -103,16 +101,25 @@ case class ShuffleExchange(
new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
}
+ /**
+ * Caches the created ShuffledRowRDD so we can reuse it.
+ */
+ private var cachedShuffleRDD: ShuffledRowRDD = null
+
protected override def doExecute(): RDD[InternalRow] = attachTree(this, "execute") {
- coordinator match {
- case Some(exchangeCoordinator) =>
- val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
- assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
- shuffleRDD
- case None =>
- val shuffleDependency = prepareShuffleDependency()
- preparePostShuffleRDD(shuffleDependency)
+ // Return the same ShuffledRowRDD if this plan is used by multiple plans.
+ if (cachedShuffleRDD == null) {
+ cachedShuffleRDD = coordinator match {
+ case Some(exchangeCoordinator) =>
+ val shuffleRDD = exchangeCoordinator.postShuffleRDD(this)
+ assert(shuffleRDD.partitions.length == newPartitioning.numPartitions)
+ shuffleRDD
+ case None =>
+ val shuffleDependency = prepareShuffleDependency()
+ preparePostShuffleRDD(shuffleDependency)
+ }
}
+ cachedShuffleRDD
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 9a3cdaf697..99f8841c87 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -681,7 +681,7 @@ private[execution] case class HashedRelationBroadcastMode(
keys: Seq[Expression],
attributes: Seq[Attribute]) extends BroadcastMode {
- def transform(rows: Array[InternalRow]): HashedRelation = {
+ override def transform(rows: Array[InternalRow]): HashedRelation = {
val generator = UnsafeProjection.create(keys, attributes)
if (canJoinKeyFitWithinLong) {
LongHashedRelation(rows.iterator, generator, rows.length)
@@ -689,5 +689,18 @@ private[execution] case class HashedRelationBroadcastMode(
HashedRelation(rows.iterator, generator, rows.length)
}
}
+
+ private lazy val canonicalizedKeys: Seq[Expression] = {
+ keys.map { e =>
+ BindReferences.bindReference(e.canonicalized, attributes)
+ }
+ }
+
+ override def compatibleWith(other: BroadcastMode): Boolean = other match {
+ case m: HashedRelationBroadcastMode =>
+ canJoinKeyFitWithinLong == m.canJoinKeyFitWithinLong &&
+ canonicalizedKeys == m.canonicalizedKeys
+ case _ => false
+ }
}
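
The canonicalizedKeys comparison makes compatibleWith() ignore ExprId differences in the join keys while still requiring the same key expressions and packing. A small sketch, lifted from the ExchangeSuite test added below (HashedRelationBroadcastMode is private[execution], so like the test this has to live in the org.apache.spark.sql.execution package):

    import org.apache.spark.sql.catalyst.expressions.Literal
    import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
    import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode

    val identity = IdentityBroadcastMode
    val longKeyed = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq())
    val generic = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq())
    assert(longKeyed.compatibleWith(longKeyed))   // same keys, same packing
    assert(!longKeyed.compatibleWith(generic))    // different keys and packing
    assert(!identity.compatibleWith(longKeyed))   // IdentityBroadcastMode only matches itself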
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
index 83372aa2e9..94d318e702 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala
@@ -64,7 +64,8 @@ private[sql] object SparkPlanGraph {
val nodeIdGenerator = new AtomicLong(0)
val nodes = mutable.ArrayBuffer[SparkPlanGraphNode]()
val edges = mutable.ArrayBuffer[SparkPlanGraphEdge]()
- buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null)
+ val exchanges = mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]()
+ buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, null, null, exchanges)
new SparkPlanGraph(nodes, edges)
}
@@ -74,7 +75,8 @@ private[sql] object SparkPlanGraph {
nodes: mutable.ArrayBuffer[SparkPlanGraphNode],
edges: mutable.ArrayBuffer[SparkPlanGraphEdge],
parent: SparkPlanGraphNode,
- subgraph: SparkPlanGraphCluster): Unit = {
+ subgraph: SparkPlanGraphCluster,
+ exchanges: mutable.HashMap[SparkPlanInfo, SparkPlanGraphNode]): Unit = {
planInfo.nodeName match {
case "WholeStageCodegen" =>
val cluster = new SparkPlanGraphCluster(
@@ -84,13 +86,14 @@ private[sql] object SparkPlanGraph {
mutable.ArrayBuffer[SparkPlanGraphNode]())
nodes += cluster
buildSparkPlanGraphNode(
- planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster)
+ planInfo.children.head, nodeIdGenerator, nodes, edges, parent, cluster, exchanges)
case "InputAdapter" =>
- buildSparkPlanGraphNode(planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null)
+ buildSparkPlanGraphNode(
+ planInfo.children.head, nodeIdGenerator, nodes, edges, parent, null, exchanges)
case "Subquery" if subgraph != null =>
// Subquery should not be included in WholeStageCodegen
- buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null)
- case _ =>
+ buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
+ case name =>
val metrics = planInfo.metrics.map { metric =>
SQLPlanMetric(metric.name, metric.accumulatorId,
SQLMetrics.getMetricParam(metric.metricParam))
@@ -103,12 +106,15 @@ private[sql] object SparkPlanGraph {
} else {
subgraph.nodes += node
}
+ if (name == "ShuffleExchange" || name == "BroadcastExchange") {
+ exchanges += planInfo -> node
+ }
if (parent != null) {
edges += SparkPlanGraphEdge(node.id, parent.id)
}
planInfo.children.foreach(
- buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph))
+ buildSparkPlanGraphNode(_, nodeIdGenerator, nodes, edges, node, subgraph, exchanges))
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 1d1e288441..384102e5ea 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -504,6 +504,10 @@ object SQLConf {
" method",
isPublic = false)
+ val EXCHANGE_REUSE_ENABLED = booleanConf("spark.sql.exchange.reuse",
+ defaultValue = Some(true),
+ doc = "When true, the planner will try to find out duplicated exchanges and re-use them",
+ isPublic = false)
object Deprecated {
val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks"
@@ -564,6 +568,8 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED)
+ def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)
+
def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW)
def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE)
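
The new flag defaults to true and is internal (isPublic = false), but it can still be toggled through the session configuration when the rewrite needs to be ruled out, for example while debugging a plan. A hypothetical sketch, assuming a SQLContext named sqlContext:

    // Disable exchange reuse for the current session, then restore the default.
    sqlContext.setConf("spark.sql.exchange.reuse", "false")
    // ... queries planned here keep every Exchange ...
    sqlContext.setConf("spark.sql.exchange.reuse", "true")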
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 6f81794b29..98ada4d58a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -24,10 +24,9 @@ import org.apache.spark.sql.catalyst.parser.ParserInterface
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources.{DataSourceAnalysis, PreInsertCastAndRename, ResolveDataSource}
-import org.apache.spark.sql.execution.exchange.EnsureRequirements
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange}
import org.apache.spark.sql.util.ExecutionListenerManager
-
/**
* A class that holds all session-specific state in a given [[SQLContext]].
*/
@@ -94,7 +93,8 @@ private[sql] class SessionState(ctx: SQLContext) {
override val batches: Seq[Batch] = Seq(
Batch("Subquery", Once, PlanSubqueries(ctx)),
Batch("Add exchange", Once, EnsureRequirements(ctx)),
- Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx))
+ Batch("Whole stage codegen", Once, CollapseCodegenStages(ctx)),
+ Batch("Reuse duplicated exchanges", Once, ReuseExchange(ctx))
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 55153cda31..26775c3700 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -25,9 +25,9 @@ import scala.util.Random
import org.scalatest.Matchers._
import org.apache.spark.SparkException
-import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
+import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, OneRowRelation, Union}
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
-import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
@@ -1316,6 +1316,40 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
+ test("reuse exchange") {
+ withSQLConf("spark.sql.autoBroadcastJoinThreshold" -> "2") {
+ val df = sqlContext.range(100)
+ val join = df.join(df, "id")
+ val plan = join.queryExecution.executedPlan
+ checkAnswer(join, df)
+ assert(
+ join.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1)
+ assert(join.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 1)
+ val broadcasted = broadcast(join)
+ val join2 = join.join(broadcasted, "id").join(broadcasted, "id")
+ checkAnswer(join2, df)
+ assert(
+ join2.queryExecution.executedPlan.collect { case e: ShuffleExchange => true }.size === 1)
+ assert(
+ join2.queryExecution.executedPlan.collect { case e: BroadcastExchange => true }.size === 1)
+ assert(
+ join2.queryExecution.executedPlan.collect { case e: ReusedExchange => true }.size === 4)
+ }
+ }
+
+ test("sameResult() on aggregate") {
+ val df = sqlContext.range(100)
+ val agg1 = df.groupBy().count()
+ val agg2 = df.groupBy().count()
+ // Two aggregates that differ only in their ExprIds should have the same result.
+ assert(agg1.queryExecution.executedPlan.sameResult(agg2.queryExecution.executedPlan))
+ val agg3 = df.groupBy().sum()
+ assert(!agg1.queryExecution.executedPlan.sameResult(agg3.queryExecution.executedPlan))
+ val df2 = sqlContext.range(101)
+ val agg4 = df2.groupBy().count()
+ assert(!agg1.queryExecution.executedPlan.sameResult(agg4.queryExecution.executedPlan))
+ }
+
test("SPARK-12512: support `.` in column name for withColumn()") {
val df = Seq("a" -> "b").toDF("col.a", "col.b")
checkAnswer(df.select(df("*")), Row("a", "b"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index d4f22de90c..9f159d1e1e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -18,8 +18,10 @@
package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
-import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal}
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, IdentityBroadcastMode, SinglePartition}
+import org.apache.spark.sql.execution.exchange.{BroadcastExchange, ReusedExchange, ShuffleExchange}
+import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode
import org.apache.spark.sql.test.SharedSQLContext
class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
@@ -33,4 +35,70 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
input.map(Row.fromTuple)
)
}
+
+ test("compatible BroadcastMode") {
+ val mode1 = IdentityBroadcastMode
+ val mode2 = HashedRelationBroadcastMode(true, Literal(1) :: Nil, Seq())
+ val mode3 = HashedRelationBroadcastMode(false, Literal("s") :: Nil, Seq())
+
+ assert(mode1.compatibleWith(mode1))
+ assert(!mode1.compatibleWith(mode2))
+ assert(!mode2.compatibleWith(mode1))
+ assert(mode2.compatibleWith(mode2))
+ assert(!mode2.compatibleWith(mode3))
+ assert(mode3.compatibleWith(mode3))
+ }
+
+ test("BroadcastExchange same result") {
+ val df = sqlContext.range(10)
+ val plan = df.queryExecution.executedPlan
+ val output = plan.output
+ assert(plan sameResult plan)
+
+ val exchange1 = BroadcastExchange(IdentityBroadcastMode, plan)
+ val hashMode = HashedRelationBroadcastMode(true, output, plan.output)
+ val exchange2 = BroadcastExchange(hashMode, plan)
+ val hashMode2 =
+ HashedRelationBroadcastMode(true, Alias(output.head, "id2")() :: Nil, plan.output)
+ val exchange3 = BroadcastExchange(hashMode2, plan)
+ val exchange4 = ReusedExchange(output, exchange3)
+
+ assert(exchange1 sameResult exchange1)
+ assert(exchange2 sameResult exchange2)
+ assert(exchange3 sameResult exchange3)
+ assert(exchange4 sameResult exchange4)
+
+ assert(!exchange1.sameResult(exchange2))
+ assert(!exchange2.sameResult(exchange3))
+ assert(!exchange3.sameResult(exchange4))
+ assert(exchange4 sameResult exchange3)
+ }
+
+ test("ShuffleExchange same result") {
+ val df = sqlContext.range(10)
+ val plan = df.queryExecution.executedPlan
+ val output = plan.output
+ assert(plan sameResult plan)
+
+ val part1 = HashPartitioning(output, 1)
+ val exchange1 = ShuffleExchange(part1, plan)
+ val exchange2 = ShuffleExchange(part1, plan)
+ val part2 = HashPartitioning(output, 2)
+ val exchange3 = ShuffleExchange(part2, plan)
+ val part3 = HashPartitioning(output ++ output, 2)
+ val exchange4 = ShuffleExchange(part3, plan)
+ val exchange5 = ReusedExchange(output, exchange4)
+
+ assert(exchange1 sameResult exchange1)
+ assert(exchange2 sameResult exchange2)
+ assert(exchange3 sameResult exchange3)
+ assert(exchange4 sameResult exchange4)
+ assert(exchange5 sameResult exchange5)
+
+ assert(exchange1 sameResult exchange2)
+ assert(!exchange2.sameResult(exchange3))
+ assert(!exchange3.sameResult(exchange4))
+ assert(!exchange4.sameResult(exchange5))
+ assert(exchange5 sameResult exchange4)
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index a733237a5e..ab0a7ff628 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -23,15 +23,14 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
-import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange}
+import org.apache.spark.sql.execution.columnar.InMemoryRelation
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReusedExchange, ReuseExchange, ShuffleExchange}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-
class PlannerSuite extends SharedSQLContext {
import testImplicits._
@@ -472,6 +471,50 @@ class PlannerSuite extends SharedSQLContext {
}
// ---------------------------------------------------------------------------------------------
+
+ test("Reuse exchanges") {
+ val distribution = ClusteredDistribution(Literal(1) :: Nil)
+ val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
+ val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
+ assert(!childPartitioning.satisfies(distribution))
+ val shuffle = ShuffleExchange(finalPartitioning,
+ DummySparkPlan(
+ children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
+ requiredChildDistribution = Seq(distribution),
+ requiredChildOrdering = Seq(Seq.empty)),
+ None)
+
+ val inputPlan = SortMergeJoin(
+ Literal(1) :: Nil,
+ Literal(1) :: Nil,
+ None,
+ shuffle,
+ shuffle)
+
+ val outputPlan = ReuseExchange(sqlContext).apply(inputPlan)
+ if (outputPlan.collect { case e: ReusedExchange => true }.size != 1) {
+ fail(s"Should re-use the shuffle:\n$outputPlan")
+ }
+ if (outputPlan.collect { case e: ShuffleExchange => true }.size != 1) {
+ fail(s"Should have only one shuffle:\n$outputPlan")
+ }
+
+ // nested exchanges
+ val inputPlan2 = SortMergeJoin(
+ Literal(1) :: Nil,
+ Literal(1) :: Nil,
+ None,
+ ShuffleExchange(finalPartitioning, inputPlan),
+ ShuffleExchange(finalPartitioning, inputPlan))
+
+ val outputPlan2 = ReuseExchange(sqlContext).apply(inputPlan2)
+ if (outputPlan2.collect { case e: ReusedExchange => true }.size != 2) {
+ fail(s"Should re-use the two shuffles:\n$outputPlan2")
+ }
+ if (outputPlan2.collect { case e: ShuffleExchange => true }.size != 2) {
+ fail(s"Should have only two shuffles:\n$outputPlan")
+ }
+ }
}
// Used for unit-testing EnsureRequirements