Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala | 35
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala | 30
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala | 34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala | 8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala | 5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala | 89
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala | 261
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala) | 46
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala) | 255
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala | 75
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala | 25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala | 25
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala | 51
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala | 23
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala | 7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala | 21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala | 21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala | 5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala | 11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala | 3
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala | 13
27 files changed, 658 insertions, 433 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
new file mode 100644
index 0000000000..c646dcfa11
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.physical
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * Marker trait to identify the shape in which tuples are broadcasted. Typical examples of this are
+ * identity (tuples remain unchanged) or hashed (tuples are converted into some hash index).
+ */
+trait BroadcastMode {
+ def transform(rows: Array[InternalRow]): Any
+}
+
+/**
+ * IdentityBroadcastMode requires that rows are broadcasted in their original form.
+ */
+case object IdentityBroadcastMode extends BroadcastMode {
+ override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows
+}
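
The BroadcastMode contract above is small but central: it decides what shape the collected rows take before they are broadcast, with the identity and hashed forms as the typical cases. Below is a simplified, standalone Scala sketch of that idea (not part of this patch; SimpleBroadcastMode and the tuple element type are invented stand-ins for the InternalRow-based trait):

    // Stand-in for BroadcastMode, using plain tuples instead of InternalRow.
    trait SimpleBroadcastMode {
      def transform(rows: Array[(Int, String)]): Any
    }

    // Identity: ship the collected rows unchanged, like IdentityBroadcastMode.
    case object SimpleIdentityMode extends SimpleBroadcastMode {
      override def transform(rows: Array[(Int, String)]): Array[(Int, String)] = rows
    }

    // Hashed: build a key -> rows index once on the driver so every task can probe it cheaply.
    case object SimpleHashedMode extends SimpleBroadcastMode {
      override def transform(rows: Array[(Int, String)]): Map[Int, Array[(Int, String)]] =
        rows.groupBy(_._1)
    }

    object BroadcastModeSketch {
      def main(args: Array[String]): Unit = {
        val rows = Array(1 -> "a", 2 -> "b", 1 -> "c")
        println(SimpleIdentityMode.transform(rows).length)   // 3: the same rows, unchanged
        println(SimpleHashedMode.transform(rows)(1).length)   // 2: the rows with key 1 ("a" and "c")
      }
    }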
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index d6e10c412c..45e2841ec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -76,6 +77,12 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
}
/**
+ * Represents data where tuples are broadcasted to every node. It is quite common that the
+ * entire set of tuples is transformed into a different data structure.
+ */
+case class BroadcastDistribution(mode: BroadcastMode) extends Distribution
+
+/**
* Describes how an operator's output is split across partitions. The `compatibleWith`,
* `guarantees`, and `satisfies` methods describe relationships between child partitionings,
* target partitionings, and [[Distribution]]s. These relations are described more precisely in
@@ -213,7 +220,10 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning {
case object SinglePartition extends Partitioning {
val numPartitions = 1
- override def satisfies(required: Distribution): Boolean = true
+ override def satisfies(required: Distribution): Boolean = required match {
+ case _: BroadcastDistribution => false
+ case _ => true
+ }
override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1
@@ -351,3 +361,21 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
partitionings.map(_.toString).mkString("(", " or ", ")")
}
}
+
+/**
+ * Represents a partitioning where rows are collected, transformed and broadcasted to each
+ * node in the cluster.
+ */
+case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
+ override val numPartitions: Int = 1
+
+ override def satisfies(required: Distribution): Boolean = required match {
+ case BroadcastDistribution(m) if m == mode => true
+ case _ => false
+ }
+
+ override def compatibleWith(other: Partitioning): Boolean = other match {
+ case BroadcastPartitioning(m) if m == mode => true
+ case _ => false
+ }
+}
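
The useful invariant in the new partitioning cases is that a broadcast requirement is only satisfied by a broadcast with the same mode, and that SinglePartition no longer claims to satisfy it. Assuming the catalyst classes changed above are on the classpath, the expected behaviour can be expressed as a few assertions (a hedged sketch of what a test could check, not an existing suite):

    import org.apache.spark.sql.catalyst.plans.physical._

    val p = BroadcastPartitioning(IdentityBroadcastMode)
    assert(p.satisfies(BroadcastDistribution(IdentityBroadcastMode)))                 // same mode => satisfied
    assert(!p.satisfies(AllTuples))                                                   // broadcast answers only broadcast needs
    assert(!SinglePartition.satisfies(BroadcastDistribution(IdentityBroadcastMode)))  // changed above
    assert(SinglePartition.satisfies(AllTuples))                                      // unchanged behaviour
    assert(p.compatibleWith(BroadcastPartitioning(IdentityBroadcastMode)))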
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 932df36b85..a2f386850c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan,
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
@@ -59,7 +60,6 @@ import org.apache.spark.util.Utils
* @groupname config Configuration
* @groupname dataframes Custom DataFrame Creation
* @groupname Ungrouped Support functions for language integrated queries
- *
* @since 1.0.0
*/
class SQLContext private[sql](
@@ -313,10 +313,10 @@ class SQLContext private[sql](
}
/**
- * Returns true if the [[Queryable]] is currently cached in-memory.
- * @group cachemgmt
- * @since 1.3.0
- */
+ * Returns true if the [[Queryable]] is currently cached in-memory.
+ * @group cachemgmt
+ * @since 1.3.0
+ */
private[sql] def isCached(qName: Queryable): Boolean = {
cacheManager.lookupCachedData(qName).nonEmpty
}
@@ -364,6 +364,7 @@ class SQLContext private[sql](
/**
* Converts $"col name" into an [[Column]].
+ *
* @since 1.3.0
*/
// This must live here to preserve binary compatibility with Spark < 1.5.
@@ -728,7 +729,6 @@ class SQLContext private[sql](
* cached/persisted before, it's also unpersisted.
*
* @param tableName the name of the table to be unregistered.
- *
* @group basic
* @since 1.3.0
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 477a9460d7..3be4cce045 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -24,6 +24,7 @@ import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
import org.apache.spark.Logging
+import org.apache.spark.broadcast
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -108,15 +109,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
/**
- * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute
- * after adding query plan information to created RDDs for visualization.
- * Concrete implementations of SparkPlan should override doExecute instead.
+ * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute after
+ * preparations. Concrete implementations of SparkPlan should override doExecute.
*/
- final def execute(): RDD[InternalRow] = {
+ final def execute(): RDD[InternalRow] = executeQuery {
+ doExecute()
+ }
+
+ /**
+ * Returns the result of this query as a broadcast variable by delegating to doBroadcast after
+ * preparations. Concrete implementations of SparkPlan should override doBroadcast.
+ */
+ final def executeBroadcast[T](): broadcast.Broadcast[T] = executeQuery {
+ doExecuteBroadcast()
+ }
+
+ /**
+ * Execute a query after preparing the query and adding query plan information to created RDDs
+ * for visualization.
+ */
+ private final def executeQuery[T](query: => T): T = {
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
prepare()
waitForSubqueries()
- doExecute()
+ query
}
}
@@ -193,6 +209,14 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
protected def doExecute(): RDD[InternalRow]
/**
+ * Overridden by concrete implementations of SparkPlan.
+ * Produces the result of the query as a broadcast variable.
+ */
+ protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ throw new UnsupportedOperationException(s"$nodeName does not implement doExecuteBroadcast")
+ }
+
+ /**
* Runs this query returning the result as an array.
*/
def executeCollect(): Array[InternalRow] = {
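
The SparkPlan change is a template-method refactoring: execute() and the new executeBroadcast() both funnel through the same preparation wrapper (executeQuery), and only the final step differs, with doExecuteBroadcast throwing by default. A standalone sketch of that shape, with invented names (MiniPlan, runPrepared) and without the RDD scope and subquery machinery:

    // Both public entry points share the preparation wrapper; only the protected hook differs.
    abstract class MiniPlan {
      final def execute(): Seq[Int] = runPrepared { doExecute() }
      final def executeBroadcast(): Int = runPrepared { doBroadcast() }

      private def runPrepared[T](body: => T): T = {
        prepare()   // shared preparation (Spark also sets up the RDD scope and waits for subqueries)
        body        // operator-specific work
      }

      protected def prepare(): Unit = ()
      protected def doExecute(): Seq[Int]
      protected def doBroadcast(): Int =
        throw new UnsupportedOperationException("this node does not implement doBroadcast")
    }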
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 382654afac..7347156398 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.{execution, Strategy}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -25,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
@@ -328,7 +330,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
- execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil
+ ShuffleExchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil
} else {
execution.Coalesce(numPartitions, planLater(child)) :: Nil
}
@@ -367,7 +369,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r @ logical.Range(start, end, step, numSlices, output) =>
execution.Range(start, step, numSlices, r.numElements, output) :: Nil
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
- execution.Exchange(HashPartitioning(
+ exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
case e @ python.EvaluatePython(udf, child, _) =>
python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 990eeb22b6..d79b547137 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
@@ -172,6 +173,10 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
child.execute()
}
+ override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ child.doExecuteBroadcast()
+ }
+
override def supportCodegen: Boolean = false
override def upstream(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
new file mode 100644
index 0000000000..40cad4b1a7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.duration._
+
+import org.apache.spark.broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryNode}
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A [[BroadcastExchange]] collects, transforms and finally broadcasts the result of a transformed
+ * SparkPlan.
+ */
+case class BroadcastExchange(
+ mode: BroadcastMode,
+ child: SparkPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
+
+ @transient
+ private val timeout: Duration = {
+ val timeoutValue = sqlContext.conf.broadcastTimeout
+ if (timeoutValue < 0) {
+ Duration.Inf
+ } else {
+ timeoutValue.seconds
+ }
+ }
+
+ @transient
+ private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+ // relationFuture is only materialized in "doPrepare" and "doExecuteBroadcast". Therefore we can
+ // get the execution id correctly here.
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ Future {
+ // This will run in another thread. Set the execution id so that we can connect these jobs
+ // with the correct execution.
+ SQLExecution.withExecutionId(sparkContext, executionId) {
+ // Note that we use .executeCollect() because we don't want to convert data to Scala types
+ val input: Array[InternalRow] = child.executeCollect()
+
+ // Construct and broadcast the relation.
+ sparkContext.broadcast(mode.transform(input))
+ }
+ }(BroadcastExchange.executionContext)
+ }
+
+ override protected def doPrepare(): Unit = {
+ // Materialize the future.
+ relationFuture
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ "BroadcastExchange does not support the execute() code path.")
+ }
+
+ override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ val result = Await.result(relationFuture, timeout)
+ result.asInstanceOf[broadcast.Broadcast[T]]
+ }
+}
+
+object BroadcastExchange {
+ private[execution] val executionContext = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128))
+}
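
BroadcastExchange splits its work across the plan lifecycle: doPrepare only forces the lazy future (so the collect-and-broadcast job starts early on a dedicated thread pool), doExecuteBroadcast blocks on that future with the configured timeout, and the plain execute() path is deliberately unsupported. A standalone sketch of that prepare/await lifecycle using only the standard library (Stage and its members are invented names):

    import java.util.concurrent.Executors
    import scala.concurrent.{Await, ExecutionContext, Future}
    import scala.concurrent.duration._

    object BroadcastLifecycleSketch {
      private val pool = ExecutionContext.fromExecutorService(Executors.newFixedThreadPool(4))

      final class Stage(build: () => Seq[Int]) {
        // Started at most once; in Spark this is where collect, mode.transform and broadcast happen.
        private lazy val relationFuture: Future[Seq[Int]] = Future { build() }(pool)

        def prepare(): Unit = relationFuture                        // materialize the lazy future early
        def result(timeout: Duration): Seq[Int] = Await.result(relationFuture, timeout)
      }

      def main(args: Array[String]): Unit = {
        val stage = new Stage(() => (1 to 5).map(_ * 2))
        stage.prepare()                                             // kick off the background work
        println(stage.result(10.seconds))                           // Vector(2, 4, 6, 8, 10)
        pool.shutdown()
      }
    }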
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
new file mode 100644
index 0000000000..709a424636
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+
+/**
+ * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
+ * of input data meets the
+ * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
+ * each operator by inserting [[ShuffleExchange]] operators where required. Also ensures that the
+ * input partition ordering requirements are met.
+ */
+private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
+ private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions
+
+ private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize
+
+ private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled
+
+ private def minNumPostShufflePartitions: Option[Int] = {
+ val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions
+ if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
+ }
+
+ /**
+ * Given a required distribution, returns a partitioning that satisfies that distribution.
+ */
+ private def createPartitioning(
+ requiredDistribution: Distribution,
+ numPartitions: Int): Partitioning = {
+ requiredDistribution match {
+ case AllTuples => SinglePartition
+ case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+ case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
+ case dist => sys.error(s"Do not know how to satisfy distribution $dist")
+ }
+ }
+
+ /**
+ * Adds [[ExchangeCoordinator]] to [[ShuffleExchange]]s if adaptive query execution is enabled
+ * and partitioning schemes of these [[ShuffleExchange]]s support [[ExchangeCoordinator]].
+ */
+ private def withExchangeCoordinator(
+ children: Seq[SparkPlan],
+ requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = {
+ val supportsCoordinator =
+ if (children.exists(_.isInstanceOf[ShuffleExchange])) {
+ // Right now, ExchangeCoordinator only supports HashPartitionings.
+ children.forall {
+ case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true
+ case child =>
+ child.outputPartitioning match {
+ case hash: HashPartitioning => true
+ case collection: PartitioningCollection =>
+ collection.partitionings.forall(_.isInstanceOf[HashPartitioning])
+ case _ => false
+ }
+ }
+ } else {
+ // In this case, although we do not have Exchange operators, we may still need to
+ // shuffle data when we have more than one child because data generated by
+ // these children may not be partitioned in the same way.
+ // Please see the comment in withCoordinator for more details.
+ val supportsDistribution =
+ requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
+ children.length > 1 && supportsDistribution
+ }
+
+ val withCoordinator =
+ if (adaptiveExecutionEnabled && supportsCoordinator) {
+ val coordinator =
+ new ExchangeCoordinator(
+ children.length,
+ targetPostShuffleInputSize,
+ minNumPostShufflePartitions)
+ children.zip(requiredChildDistributions).map {
+ case (e: ShuffleExchange, _) =>
+ // This child is an Exchange, we need to add the coordinator.
+ e.copy(coordinator = Some(coordinator))
+ case (child, distribution) =>
+ // If this child is not an Exchange, we need to add an Exchange for now.
+ // Ideally, we can try to avoid this Exchange. However, when we reach here,
+ // there are at least two children operators (because if there is a single child
+ // and we can avoid Exchange, supportsCoordinator will be false and we
+ // will not reach here.). Although we can make two children have the same number of
+ // post-shuffle partitions, their numbers of pre-shuffle partitions may be different.
+ // For example, let's say we have the following plan
+ // Join
+ // / \
+ // Agg Exchange
+ // / \
+ // Exchange t2
+ // /
+ // t1
+ // In this case, because a post-shuffle partition can include multiple pre-shuffle
+ // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes
+ // after shuffle. So, even if we can use the child Exchange operator of the Join to
+ // have a number of post-shuffle partitions that matches the number of partitions of
+ // Agg, we cannot say these two children are partitioned in the same way.
+ // Here is another case
+ // Join
+ // / \
+ // Agg1 Agg2
+ // / \
+ // Exchange1 Exchange2
+ // / \
+ // t1 t2
+ // In this case, two Aggs shuffle data with the same column of the join condition.
+ // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same
+ // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2
+ // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle
+ // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its
+ // pre-shuffle partitions by using another partitionStartIndices [0, 4].
+ // So, Agg1 and Agg2 are actually not co-partitioned.
+ //
+ // It will be great to introduce a new Partitioning to represent the post-shuffle
+ // partitions when one post-shuffle partition includes multiple pre-shuffle partitions.
+ val targetPartitioning =
+ createPartitioning(distribution, defaultNumPreShufflePartitions)
+ assert(targetPartitioning.isInstanceOf[HashPartitioning])
+ ShuffleExchange(targetPartitioning, child, Some(coordinator))
+ }
+ } else {
+ // If we do not need ExchangeCoordinator, the original children are returned.
+ children
+ }
+
+ withCoordinator
+ }
+
+ private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+ val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
+ val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
+ var children: Seq[SparkPlan] = operator.children
+ assert(requiredChildDistributions.length == children.length)
+ assert(requiredChildOrderings.length == children.length)
+
+ // Ensure that the operator's children satisfy their output distribution requirements:
+ children = children.zip(requiredChildDistributions).map {
+ case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+ child
+ case (child, BroadcastDistribution(mode)) =>
+ BroadcastExchange(mode, child)
+ case (child, distribution) =>
+ ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+ }
+
+ // If the operator has multiple children and specifies child output distributions (e.g. join),
+ // then the children's output partitionings must be compatible:
+ def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match {
+ case UnspecifiedDistribution => false
+ case BroadcastDistribution(_) => false
+ case _ => true
+ }
+ if (children.length > 1
+ && requiredChildDistributions.exists(requireCompatiblePartitioning)
+ && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+
+ // First check if the existing partitions of the children all match. This means they are
+ // partitioned by the same partitioning into the same number of partitions. In that case,
+ // don't try to make them match `defaultPartitions`, just use the existing partitioning.
+ val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
+ val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
+ case (child, distribution) =>
+ child.outputPartitioning.guarantees(
+ createPartitioning(distribution, maxChildrenNumPartitions))
+ }
+
+ children = if (useExistingPartitioning) {
+ // We do not need to shuffle any child's output.
+ children
+ } else {
+ // We need to shuffle at least one child's output.
+ // Now, we will determine the number of partitions that will be used by created
+ // partitioning schemes.
+ val numPartitions = {
+ // Let's see if we need to shuffle all children's outputs when we use
+ // maxChildrenNumPartitions.
+ val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
+ case (child, distribution) =>
+ !child.outputPartitioning.guarantees(
+ createPartitioning(distribution, maxChildrenNumPartitions))
+ }
+ // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
+ // number of partitions. Otherwise, we use maxChildrenNumPartitions.
+ if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
+ }
+
+ children.zip(requiredChildDistributions).map {
+ case (child, distribution) =>
+ val targetPartitioning = createPartitioning(distribution, numPartitions)
+ if (child.outputPartitioning.guarantees(targetPartitioning)) {
+ child
+ } else {
+ child match {
+ // If child is an exchange, we replace it with
+ // a new one having targetPartitioning.
+ case ShuffleExchange(_, c, _) => ShuffleExchange(targetPartitioning, c)
+ case _ => ShuffleExchange(targetPartitioning, child)
+ }
+ }
+ }
+ }
+ }
+
+ // Now, we need to add ExchangeCoordinator if necessary.
+ // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges.
+ // However, with the way that we plan the query, we do not have a place where we have a
+ // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator
+ // at here for now.
+ // Once we finish https://issues.apache.org/jira/browse/SPARK-10665,
+ // we can first add Exchanges and then add coordinator once we have a DAG of query fragments.
+ children = withExchangeCoordinator(children, requiredChildDistributions)
+
+ // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
+ children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+ if (requiredOrdering.nonEmpty) {
+ // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
+ if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) {
+ Sort(requiredOrdering, global = false, child = child)
+ } else {
+ child
+ }
+ } else {
+ child
+ }
+ }
+
+ operator.withNewChildren(children)
+ }
+
+ def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+ case operator @ ShuffleExchange(partitioning, child, _) =>
+ child.children match {
+ case ShuffleExchange(childPartitioning, baseChild, _)::Nil =>
+ if (childPartitioning.guarantees(partitioning)) child else operator
+ case _ => operator
+ }
+ case operator: SparkPlan => ensureDistributionAndOrdering(operator)
+ }
+}
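
The heart of EnsureRequirements is the per-child decision in ensureDistributionAndOrdering: keep a child that already satisfies its required distribution, wrap a BroadcastDistribution requirement in a BroadcastExchange, and otherwise insert a ShuffleExchange whose partitioning is derived from the distribution via createPartitioning. A standalone sketch of just that decision, with simplified stand-in types (Requirement, Plan and friends are invented, not the Spark classes):

    sealed trait Requirement
    case object Unspecified extends Requirement
    case object NeedsBroadcast extends Requirement
    final case class NeedsClustering(keys: Seq[String]) extends Requirement

    sealed trait Plan { def satisfies(r: Requirement): Boolean = r == Unspecified }
    final case class Leaf(name: String) extends Plan
    final case class Broadcast(child: Plan) extends Plan
    final case class Shuffle(keys: Seq[String], child: Plan) extends Plan

    object EnsureRequirementsSketch {
      def ensure(child: Plan, required: Requirement): Plan =
        if (child.satisfies(required)) child                    // already fine: leave it alone
        else required match {
          case NeedsBroadcast        => Broadcast(child)        // like BroadcastExchange(mode, child)
          case NeedsClustering(keys) => Shuffle(keys, child)    // like ShuffleExchange(HashPartitioning(...), child)
          case Unspecified           => child
        }

      def main(args: Array[String]): Unit = {
        println(ensure(Leaf("scan"), NeedsBroadcast))             // Broadcast(Leaf(scan))
        println(ensure(Leaf("scan"), NeedsClustering(Seq("id")))) // Shuffle(List(id),Leaf(scan))
      }
    }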
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index 07015e5a5a..6f3bb0ad2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.execution.exchange
import java.util.{HashMap => JHashMap, Map => JMap}
import javax.annotation.concurrent.GuardedBy
@@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Logging, MapOutputStatistics, ShuffleDependency, SimpleFutureAction}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan}
/**
* A coordinator used to determine how we shuffle data between stages generated by Spark SQL.
@@ -33,9 +34,9 @@ import org.apache.spark.sql.catalyst.InternalRow
*
* A coordinator is constructed with three parameters, `numExchanges`,
* `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`.
- * - `numExchanges` is used to indicated that how many [[Exchange]]s that will be registered to
- * this coordinator. So, when we start to do any actual work, we have a way to make sure that
- * we have got expected number of [[Exchange]]s.
+ * - `numExchanges` is used to indicate how many [[ShuffleExchange]]s will be registered
+ * to this coordinator. So, when we start to do any actual work, we have a way to make sure that
+ * we have got the expected number of [[ShuffleExchange]]s.
* - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's
* input data size. With this parameter, we can estimate the number of post-shuffle partitions.
* This parameter is configured through
@@ -45,26 +46,27 @@ import org.apache.spark.sql.catalyst.InternalRow
* partitions.
*
* The workflow of this coordinator is described as follows:
- * - Before the execution of a [[SparkPlan]], for an [[Exchange]] operator,
+ * - Before the execution of a [[SparkPlan]], for a [[ShuffleExchange]] operator,
* if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator.
* This happens in the `doPrepare` method.
- * - Once we start to execute a physical plan, an [[Exchange]] registered to this coordinator will
- * call `postShuffleRDD` to get its corresponding post-shuffle [[ShuffledRowRDD]].
- * If this coordinator has made the decision on how to shuffle data, this [[Exchange]] will
- * immediately get its corresponding post-shuffle [[ShuffledRowRDD]].
+ * - Once we start to execute a physical plan, a [[ShuffleExchange]] registered to this
+ * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle
+ * [[ShuffledRowRDD]].
+ * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]]
+ * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]].
* - If this coordinator has not made the decision on how to shuffle data, it will ask those
- * registered [[Exchange]]s to submit their pre-shuffle stages. Then, based on the the size
- * statistics of pre-shuffle partitions, this coordinator will determine the number of
+ * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the
+ * size statistics of pre-shuffle partitions, this coordinator will determine the number of
* post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices
* to a single post-shuffle partition whenever necessary.
* - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered
- * [[Exchange]]s. So, when an [[Exchange]] calls `postShuffleRDD`, this coordinator can
- * lookup the corresponding [[RDD]].
+ * [[ShuffleExchange]]s. So, when a [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator
+ * can lookup the corresponding [[RDD]].
*
* The strategy used to determine the number of post-shuffle partitions is described as follows.
* To determine the number of post-shuffle partitions, we have a target input size for a
* post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages
- * corresponding to the registered [[Exchange]]s, we will do a pass of those statistics and
+ * corresponding to the registered [[ShuffleExchange]]s, we will do a pass of those statistics and
* pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until
* the size of a post-shuffle partition is equal or greater than the target size.
* For example, we have two stages with the following pre-shuffle partition size statistics:
@@ -83,11 +85,11 @@ private[sql] class ExchangeCoordinator(
extends Logging {
// The registered Exchange operators.
- private[this] val exchanges = ArrayBuffer[Exchange]()
+ private[this] val exchanges = ArrayBuffer[ShuffleExchange]()
// This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator.
- private[this] val postShuffleRDDs: JMap[Exchange, ShuffledRowRDD] =
- new JHashMap[Exchange, ShuffledRowRDD](numExchanges)
+ private[this] val postShuffleRDDs: JMap[ShuffleExchange, ShuffledRowRDD] =
+ new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges)
// A boolean that indicates if this coordinator has made decision on how to shuffle data.
// This variable will only be updated by doEstimationIfNecessary, which is protected by
@@ -95,11 +97,11 @@ private[sql] class ExchangeCoordinator(
@volatile private[this] var estimated: Boolean = false
/**
- * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be
- * called in the `doPrepare` method of an [[Exchange]] operator.
+ * Registers a [[ShuffleExchange]] operator with this coordinator. This method is only allowed to
+ * be called in the `doPrepare` method of a [[ShuffleExchange]] operator.
*/
@GuardedBy("this")
- def registerExchange(exchange: Exchange): Unit = synchronized {
+ def registerExchange(exchange: ShuffleExchange): Unit = synchronized {
exchanges += exchange
}
@@ -199,7 +201,7 @@ private[sql] class ExchangeCoordinator(
// Make sure we have the expected number of registered Exchange operators.
assert(exchanges.length == numExchanges)
- val newPostShuffleRDDs = new JHashMap[Exchange, ShuffledRowRDD](numExchanges)
+ val newPostShuffleRDDs = new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges)
// Submit all map stages
val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]()
@@ -254,7 +256,7 @@ private[sql] class ExchangeCoordinator(
}
}
- def postShuffleRDD(exchange: Exchange): ShuffledRowRDD = {
+ def postShuffleRDD(exchange: ShuffleExchange): ShuffledRowRDD = {
doEstimationIfNecessary()
if (!postShuffleRDDs.containsKey(exchange)) {
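
The strategy described in the comment above boils down to a single pass over the per-partition size statistics: walk the pre-shuffle partitions in index order and start a new post-shuffle partition once the current group has reached the target size. A simplified standalone sketch of that packing rule (the real coordinator additionally sums sizes across all registered exchanges and honours minNumPostShufflePartitions):

    import scala.collection.mutable.ArrayBuffer

    object PartitionPackingSketch {
      // Returns the pre-shuffle partition index at which each post-shuffle partition starts.
      def partitionStartIndices(sizes: Array[Long], targetSize: Long): Array[Int] = {
        val starts = ArrayBuffer(0)
        var current = 0L
        for (i <- sizes.indices) {
          // Close the running group once it has reached the target, and start a new one at i.
          if (i > 0 && current >= targetSize) {
            starts += i
            current = 0L
          }
          current += sizes(i)
        }
        starts.toArray
      }

      def main(args: Array[String]): Unit = {
        // Target 100: groups [70, 30] and [60, 90] -> post-shuffle partitions start at indices 0 and 2.
        println(partitionStartIndices(Array(70L, 30L, 60L, 90L), 100L).mkString(", "))
      }
    }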
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index e30adefc69..de21d7705e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.execution.exchange
import java.util.Random
@@ -24,19 +24,18 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
import org.apache.spark.util.MutablePair
/**
* Performs a shuffle that will result in the desired `newPartitioning`.
*/
-case class Exchange(
+case class ShuffleExchange(
var newPartitioning: Partitioning,
child: SparkPlan,
@transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode {
@@ -81,7 +80,8 @@ case class Exchange(
* the returned ShuffleDependency will be the input of shuffle.
*/
private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
- Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer)
+ ShuffleExchange.prepareShuffleDependency(
+ child.execute(), child.output, newPartitioning, serializer)
}
/**
@@ -116,9 +116,9 @@ case class Exchange(
}
}
-object Exchange {
- def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
- Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
+object ShuffleExchange {
+ def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = {
+ ShuffleExchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
}
/**
@@ -259,238 +259,3 @@ object Exchange {
dependency
}
}
-
-/**
- * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
- * of input data meets the
- * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
- * each operator by inserting [[Exchange]] Operators where required. Also ensure that the
- * input partition ordering requirements are met.
- */
-private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
- private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions
-
- private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize
-
- private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled
-
- private def minNumPostShufflePartitions: Option[Int] = {
- val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions
- if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
- }
-
- /**
- * Given a required distribution, returns a partitioning that satisfies that distribution.
- */
- private def createPartitioning(
- requiredDistribution: Distribution,
- numPartitions: Int): Partitioning = {
- requiredDistribution match {
- case AllTuples => SinglePartition
- case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
- case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
- case dist => sys.error(s"Do not know how to satisfy distribution $dist")
- }
- }
-
- /**
- * Adds [[ExchangeCoordinator]] to [[Exchange]]s if adaptive query execution is enabled
- * and partitioning schemes of these [[Exchange]]s support [[ExchangeCoordinator]].
- */
- private def withExchangeCoordinator(
- children: Seq[SparkPlan],
- requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = {
- val supportsCoordinator =
- if (children.exists(_.isInstanceOf[Exchange])) {
- // Right now, ExchangeCoordinator only support HashPartitionings.
- children.forall {
- case e @ Exchange(hash: HashPartitioning, _, _) => true
- case child =>
- child.outputPartitioning match {
- case hash: HashPartitioning => true
- case collection: PartitioningCollection =>
- collection.partitionings.forall(_.isInstanceOf[HashPartitioning])
- case _ => false
- }
- }
- } else {
- // In this case, although we do not have Exchange operators, we may still need to
- // shuffle data when we have more than one children because data generated by
- // these children may not be partitioned in the same way.
- // Please see the comment in withCoordinator for more details.
- val supportsDistribution =
- requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
- children.length > 1 && supportsDistribution
- }
-
- val withCoordinator =
- if (adaptiveExecutionEnabled && supportsCoordinator) {
- val coordinator =
- new ExchangeCoordinator(
- children.length,
- targetPostShuffleInputSize,
- minNumPostShufflePartitions)
- children.zip(requiredChildDistributions).map {
- case (e: Exchange, _) =>
- // This child is an Exchange, we need to add the coordinator.
- e.copy(coordinator = Some(coordinator))
- case (child, distribution) =>
- // If this child is not an Exchange, we need to add an Exchange for now.
- // Ideally, we can try to avoid this Exchange. However, when we reach here,
- // there are at least two children operators (because if there is a single child
- // and we can avoid Exchange, supportsCoordinator will be false and we
- // will not reach here.). Although we can make two children have the same number of
- // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different.
- // For example, let's say we have the following plan
- // Join
- // / \
- // Agg Exchange
- // / \
- // Exchange t2
- // /
- // t1
- // In this case, because a post-shuffle partition can include multiple pre-shuffle
- // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes
- // after shuffle. So, even we can use the child Exchange operator of the Join to
- // have a number of post-shuffle partitions that matches the number of partitions of
- // Agg, we cannot say these two children are partitioned in the same way.
- // Here is another case
- // Join
- // / \
- // Agg1 Agg2
- // / \
- // Exchange1 Exchange2
- // / \
- // t1 t2
- // In this case, two Aggs shuffle data with the same column of the join condition.
- // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same
- // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2
- // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle
- // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its
- // pre-shuffle partitions by using another partitionStartIndices [0, 4].
- // So, Agg1 and Agg2 are actually not co-partitioned.
- //
- // It will be great to introduce a new Partitioning to represent the post-shuffle
- // partitions when one post-shuffle partition includes multiple pre-shuffle partitions.
- val targetPartitioning =
- createPartitioning(distribution, defaultNumPreShufflePartitions)
- assert(targetPartitioning.isInstanceOf[HashPartitioning])
- Exchange(targetPartitioning, child, Some(coordinator))
- }
- } else {
- // If we do not need ExchangeCoordinator, the original children are returned.
- children
- }
-
- withCoordinator
- }
-
- private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
- val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
- val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
- var children: Seq[SparkPlan] = operator.children
- assert(requiredChildDistributions.length == children.length)
- assert(requiredChildOrderings.length == children.length)
-
- // Ensure that the operator's children satisfy their output distribution requirements:
- children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
- if (child.outputPartitioning.satisfies(distribution)) {
- child
- } else {
- Exchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
- }
- }
-
- // If the operator has multiple children and specifies child output distributions (e.g. join),
- // then the children's output partitionings must be compatible:
- if (children.length > 1
- && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
- && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
-
- // First check if the existing partitions of the children all match. This means they are
- // partitioned by the same partitioning into the same number of partitions. In that case,
- // don't try to make them match `defaultPartitions`, just use the existing partitioning.
- val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
- val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
- case (child, distribution) => {
- child.outputPartitioning.guarantees(
- createPartitioning(distribution, maxChildrenNumPartitions))
- }
- }
-
- children = if (useExistingPartitioning) {
- // We do not need to shuffle any child's output.
- children
- } else {
- // We need to shuffle at least one child's output.
- // Now, we will determine the number of partitions that will be used by created
- // partitioning schemes.
- val numPartitions = {
- // Let's see if we need to shuffle all child's outputs when we use
- // maxChildrenNumPartitions.
- val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
- case (child, distribution) => {
- !child.outputPartitioning.guarantees(
- createPartitioning(distribution, maxChildrenNumPartitions))
- }
- }
- // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
- // number of partitions. Otherwise, we use maxChildrenNumPartitions.
- if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
- }
-
- children.zip(requiredChildDistributions).map {
- case (child, distribution) => {
- val targetPartitioning =
- createPartitioning(distribution, numPartitions)
- if (child.outputPartitioning.guarantees(targetPartitioning)) {
- child
- } else {
- child match {
- // If child is an exchange, we replace it with
- // a new one having targetPartitioning.
- case Exchange(_, c, _) => Exchange(targetPartitioning, c)
- case _ => Exchange(targetPartitioning, child)
- }
- }
- }
- }
- }
- }
-
- // Now, we need to add ExchangeCoordinator if necessary.
- // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges.
- // However, with the way that we plan the query, we do not have a place where we have a
- // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator
- // at here for now.
- // Once we finish https://issues.apache.org/jira/browse/SPARK-10665,
- // we can first add Exchanges and then add coordinator once we have a DAG of query fragments.
- children = withExchangeCoordinator(children, requiredChildDistributions)
-
- // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
- children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
- if (requiredOrdering.nonEmpty) {
- // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
- if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) {
- Sort(requiredOrdering, global = false, child = child)
- } else {
- child
- }
- } else {
- child
- }
- }
-
- operator.withNewChildren(children)
- }
-
- def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- case operator @ Exchange(partitioning, child, _) =>
- child.children match {
- case Exchange(childPartitioning, baseChild, _)::Nil =>
- if (childPartitioning.guarantees(partitioning)) child else operator
- case _ => operator
- }
- case operator: SparkPlan => ensureDistributionAndOrdering(operator)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index a64da22580..ddc08822f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -17,9 +17,6 @@
package org.apache.spark.sql.execution.joins
-import scala.concurrent._
-import scala.concurrent.duration._
-
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -27,10 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.collection.CompactBuffer
/**
@@ -52,60 +48,25 @@ case class BroadcastHashJoin(
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
- val timeout: Duration = {
- val timeoutValue = sqlContext.conf.broadcastTimeout
- if (timeoutValue < 0) {
- Duration.Inf
- } else {
- timeoutValue.seconds
- }
- }
-
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
- override def requiredChildDistribution: Seq[Distribution] =
- UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
- // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
- // for the same query.
- @transient
- private lazy val broadcastFuture = {
- // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- Future {
- // This will run in another thread. Set the execution id so that we can connect these jobs
- // with the correct execution.
- SQLExecution.withExecutionId(sparkContext, executionId) {
- // Note that we use .execute().collect() because we don't want to convert data to Scala
- // types
- val input: Array[InternalRow] = buildPlan.execute().map { row =>
- row.copy()
- }.collect()
- // The following line doesn't run in a job so we cannot track the metric value. However, we
- // have already tracked it in the above lines. So here we can use
- // `SQLMetrics.nullLongMetric` to ignore it.
- // TODO: move this check into HashedRelation
- val hashed = if (canJoinKeyFitWithinLong) {
- LongHashedRelation(
- input.iterator, buildSideKeyGenerator, input.size)
- } else {
- HashedRelation(
- input.iterator, buildSideKeyGenerator, input.size)
- }
- sparkContext.broadcast(hashed)
- }
- }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
- }
-
- protected override def doPrepare(): Unit = {
- broadcastFuture
+ override def requiredChildDistribution: Seq[Distribution] = {
+ val mode = HashedRelationBroadcastMode(
+ canJoinKeyFitWithinLong,
+ rewriteKeyExpr(buildKeys),
+ buildPlan.output)
+ buildSide match {
+ case BuildLeft =>
+ BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
+ case BuildRight =>
+ UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
+ }
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val broadcastRelation = Await.result(broadcastFuture, timeout)
-
+ val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
@@ -160,7 +121,7 @@ case class BroadcastHashJoin(
*/
private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
// create a name for HashedRelation
- val broadcastRelation = Await.result(broadcastFuture, timeout)
+ val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
val relationTerm = ctx.freshName("relation")
val clsName = broadcastRelation.value.getClass.getName
@@ -362,9 +323,3 @@ case class BroadcastHashJoin(
}
}
}
-
-object BroadcastHashJoin {
-
- private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService(
- ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128))
-}
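
The net effect of this change to BroadcastHashJoin is a producer/consumer split: instead of collecting and broadcasting the hash relation itself, the join declares a BroadcastDistribution(HashedRelationBroadcastMode(...)) on its build side, EnsureRequirements plans a BroadcastExchange there, and the join merely consumes the result through executeBroadcast. A standalone sketch of that split with invented stand-in types (BroadcastVar, HashedBuildSide, BroadcastStage):

    final case class BroadcastVar[T](value: T)                 // stand-in for Broadcast[T]

    trait Node {
      def executeBroadcast[T](): BroadcastVar[T] =
        throw new UnsupportedOperationException("this node does not produce a broadcast")
    }

    // What a broadcast exchange would produce for a hashed mode: a key -> rows index.
    final case class HashedBuildSide(index: Map[Int, Seq[String]])

    final case class BroadcastStage(rows: Seq[(Int, String)]) extends Node {
      override def executeBroadcast[T](): BroadcastVar[T] = {
        val hashed = HashedBuildSide(rows.groupBy(_._1).map { case (k, v) => k -> v.map(_._2) })
        BroadcastVar(hashed.asInstanceOf[T])                   // the build happens here, not in the join
      }
    }

    object BroadcastHashJoinSketch {
      def main(args: Array[String]): Unit = {
        val build    = BroadcastStage(Seq(1 -> "a", 2 -> "b"))
        val streamed = Seq(1 -> "x", 3 -> "y")
        val hashed   = build.executeBroadcast[HashedBuildSide]().value   // the join only consumes
        val joined   = streamed.flatMap { case (k, v) =>
          hashed.index.getOrElse(k, Nil).map(b => (k, v, b))
        }
        println(joined)                                                  // List((1,x,a))
      }
    }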
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 4f1cfd2e81..1f99fbedde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.{InternalAccumulator, TaskContext}
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -38,25 +39,25 @@ case class BroadcastLeftSemiJoinHash(
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ override def requiredChildDistribution: Seq[Distribution] = {
+ val mode = if (condition.isEmpty) {
+ HashSetBroadcastMode(rightKeys, right.output)
+ } else {
+ HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
+ }
+ UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val input = right.execute().map { row =>
- row.copy()
- }.collect()
-
if (condition.isEmpty) {
- val hashSet = buildKeyHashSet(input.toIterator)
- val broadcastedRelation = sparkContext.broadcast(hashSet)
-
+ val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]()
left.execute().mapPartitionsInternal { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
}
} else {
- val hashRelation =
- HashedRelation(input.toIterator, rightKeyGenerator, input.size)
- val broadcastedRelation = sparkContext.broadcast(hashRelation)
-
+ val broadcastedRelation = right.executeBroadcast[HashedRelation]()
left.execute().mapPartitionsInternal { streamIter =>
val hashedRelation = broadcastedRelation.value
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
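
One detail worth noting in the hunk above: the type parameter given to executeBroadcast has to agree with the BroadcastMode declared in requiredChildDistribution, because that mode decides what BroadcastExchange builds on the driver. Restated as a small sketch using the names from this file:

    // condition.isEmpty  -> HashSetBroadcastMode        -> executeBroadcast[java.util.Set[InternalRow]]
    // condition.nonEmpty -> HashedRelationBroadcastMode -> executeBroadcast[HashedRelation]
    val mode: BroadcastMode = if (condition.isEmpty) {
      HashSetBroadcastMode(rightKeys, right.output)
    } else {
      HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
    }
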
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 4585cbda92..e8bd7f69db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
@@ -33,7 +33,6 @@ case class BroadcastNestedLoopJoin(
buildSide: BuildSide,
joinType: JoinType,
condition: Option[Expression]) extends BinaryNode {
- // TODO: Override requiredChildDistribution.
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -44,8 +43,15 @@ case class BroadcastNestedLoopJoin(
case BuildLeft => (right, left)
}
+ override def requiredChildDistribution: Seq[Distribution] = buildSide match {
+ case BuildLeft =>
+ BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil
+ case BuildRight =>
+ UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
+ }
+
private[this] def genResultProjection: InternalRow => InternalRow = {
- UnsafeProjection.create(schema)
+ UnsafeProjection.create(schema)
}
override def outputPartitioning: Partitioning = streamed.outputPartitioning
@@ -73,15 +79,14 @@ case class BroadcastNestedLoopJoin(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map { row =>
- row.copy()
- }.collect().toIndexedSeq)
+ val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
/** All rows that either match both-way, or rows from streamed joined with nulls. */
val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
+ val relation = broadcastedRelation.value
+
val matchedRows = new CompactBuffer[InternalRow]
- val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
+ val includedBroadcastTuples = new BitSet(relation.length)
val joinedRow = new JoinedRow
val leftNulls = new GenericMutableRow(left.output.size)
@@ -92,8 +97,8 @@ case class BroadcastNestedLoopJoin(
var i = 0
var streamRowMatched = false
- while (i < broadcastedRelation.value.size) {
- val broadcastedRow = broadcastedRelation.value(i)
+ while (i < relation.length) {
+ val broadcastedRow = relation(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy()
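
The IdentityBroadcastMode used here asks BroadcastExchange to ship the collected build-side rows unchanged, which is why the join requests a plain Array[InternalRow] from executeBroadcast. Hoisting broadcastedRelation.value into a per-partition local is a small but deliberate tweak: the broadcast handle is dereferenced once per task instead of inside the inner probe loop. A compact sketch of that access pattern, assuming broadcastedRelation is the handle obtained above:

    streamed.execute().mapPartitions { streamedIter =>
      val relation = broadcastedRelation.value   // one broadcast lookup per task
      streamedIter.map { streamedRow =>
        var i = 0
        while (i < relation.length) {            // index the local array, not the broadcast
          // ... evaluate the join condition against relation(i) here ...
          i += 1
        }
        streamedRow
      }
    }
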
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 0220e0b8a7..1cb6a00617 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.LongSQLMetric
@@ -44,22 +45,7 @@ trait HashSemiJoin {
protected def buildKeyHashSet(
buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
- val hashSet = new java.util.HashSet[InternalRow]()
-
- // Create a Hash set of buildKeys
- val rightKey = rightKeyGenerator
- while (buildIter.hasNext) {
- val currentRow = buildIter.next()
- val rowKey = rightKey(currentRow)
- if (!rowKey.anyNull) {
- val keyExists = hashSet.contains(rowKey)
- if (!keyExists) {
- hashSet.add(rowKey.copy())
- }
- }
- }
-
- hashSet
+ HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
}
protected def hashSemiJoin(
@@ -92,3 +78,36 @@ trait HashSemiJoin {
}
}
}
+
+private[execution] object HashSemiJoin {
+ def buildKeyHashSet(
+ keys: Seq[Expression],
+ attributes: Seq[Attribute],
+ rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
+ val hashSet = new java.util.HashSet[InternalRow]()
+
+ // Create a Hash set of buildKeys
+ val key = UnsafeProjection.create(keys, attributes)
+ while (rows.hasNext) {
+ val currentRow = rows.next()
+ val rowKey = key(currentRow)
+ if (!rowKey.anyNull) {
+ val keyExists = hashSet.contains(rowKey)
+ if (!keyExists) {
+ hashSet.add(rowKey.copy())
+ }
+ }
+ }
+ hashSet
+ }
+}
+
+/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */
+private[execution] case class HashSetBroadcastMode(
+ keys: Seq[Expression],
+ attributes: Seq[Attribute]) extends BroadcastMode {
+
+ override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
+ HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
+ }
+}
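
HashSetBroadcastMode illustrates the general contract: a BroadcastMode is essentially a driver-side function from the collected build rows to whatever payload the consuming operator wants broadcast. A hedged sketch of another mode following the same shape; DistinctRowBroadcastMode is a made-up name, not part of this patch, and it assumes transform may return any driver-built object, as the two modes added here do:

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode

    /** Hypothetical mode that broadcasts the build side as a de-duplicated row set. */
    private[execution] case object DistinctRowBroadcastMode extends BroadcastMode {
      override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
        val set = new java.util.HashSet[InternalRow]()
        var i = 0
        while (i < rows.length) {
          set.add(rows(i))  // rows collected on the driver are assumed to be stable copies
          i += 1
        }
        set
      }
    }
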
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0978570d42..606269bf25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -25,12 +25,11 @@ import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
import org.apache.spark.sql.execution.local.LocalNode
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.MemoryLocation
import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
import org.apache.spark.util.collection.CompactBuffer
@@ -675,3 +674,20 @@ private[joins] object LongHashedRelation {
}
}
}
+
+/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */
+private[execution] case class HashedRelationBroadcastMode(
+ canJoinKeyFitWithinLong: Boolean,
+ keys: Seq[Expression],
+ attributes: Seq[Attribute]) extends BroadcastMode {
+
+ def transform(rows: Array[InternalRow]): HashedRelation = {
+ val generator = UnsafeProjection.create(keys, attributes)
+ if (canJoinKeyFitWithinLong) {
+ LongHashedRelation(rows.iterator, generator, rows.length)
+ } else {
+ HashedRelation(rows.iterator, generator, rows.length)
+ }
+ }
+}
+
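
HashedRelationBroadcastMode carries just enough information (key expressions, the attributes they resolve against, and whether the key fits in a long) for BroadcastExchange to build the hash relation on the driver after collecting the child's rows. A sketch of how a join describes its build side with it, reusing the constructor shown above; the relation itself is only materialized later, when the exchange calls transform:

    // mirrors the mode construction used by the joins in this patch
    val mode = HashedRelationBroadcastMode(
      canJoinKeyFitWithinLong = false,   // generic HashedRelation rather than LongHashedRelation
      keys = rightKeys,                  // build-side join key expressions
      attributes = right.output)         // schema the keys are resolved against
    val requirement = BroadcastDistribution(mode)  // what EnsureRequirements will satisfy
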
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index ce758d63b3..df6dac8818 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
@@ -29,9 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
* for hash join.
*/
case class LeftSemiJoinBNL(
- streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
- extends BinaryNode {
- // TODO: Override requiredChildDistribution.
+ streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode {
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -46,27 +44,28 @@ case class LeftSemiJoinBNL(
/** The Broadcast relation */
override def right: SparkPlan = broadcast
+ override def requiredChildDistribution: Seq[Distribution] = {
+ UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
+ }
+
@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map { row =>
- row.copy()
- }.collect().toIndexedSeq)
+ val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow
+ val relation = broadcastedRelation.value
streamedIter.filter(streamedRow => {
var i = 0
var matched = false
- while (i < broadcastedRelation.value.size && !matched) {
- val broadcastedRow = broadcastedRelation.value(i)
- if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+ while (i < relation.length && !matched) {
+ if (boundCondition(joinedRow(streamedRow, relation(i)))) {
matched = true
}
i += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index ef76847bcb..cd543d4195 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
/**
@@ -38,7 +39,8 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
protected override def doExecute(): RDD[InternalRow] = {
val shuffled = new ShuffledRowRDD(
- Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer))
+ ShuffleExchange.prepareShuffleDependency(
+ child.execute(), child.output, SinglePartition, serializer))
shuffled.mapPartitionsInternal(_.take(limit))
}
}
@@ -110,7 +112,8 @@ case class TakeOrderedAndProject(
}
}
val shuffled = new ShuffledRowRDD(
- Exchange.prepareShuffleDependency(localTopK, child.output, SinglePartition, serializer))
+ ShuffleExchange.prepareShuffleDependency(
+ localTopK, child.output, SinglePartition, serializer))
shuffled.mapPartitions { iter =>
val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
if (projectList.isDefined) {
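
Both limit operators above follow the same recipe: ShuffleExchange.prepareShuffleDependency packages an RDD of rows, their output attributes, a target partitioning, and a serializer into a shuffle dependency, and ShuffledRowRDD then reads the shuffled rows back. Spelled out with the arguments from the CollectLimit hunk:

    val dependency = ShuffleExchange.prepareShuffleDependency(
      child.execute(),                             // rows to redistribute
      child.output,                                // their attributes
      SinglePartition,                             // send everything to a single reducer
      new UnsafeRowSerializer(child.output.size))  // row serializer used by the shuffle
    val singlePartitionRows = new ShuffledRowRDD(dependency)
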
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index e8d0678989..83d7953aaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -23,9 +23,9 @@ import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
import org.apache.spark.Accumulators
-import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.PhysicalRDD
import org.apache.spark.sql.execution.columnar._
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
@@ -357,7 +357,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
* Verifies that the plan for `df` contains `expected` number of Exchange operators.
*/
private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = {
- assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected)
+ assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected)
}
test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 99ba2e2061..50a246489e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -26,8 +26,8 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
-import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
import org.apache.spark.sql.test.SQLTestData.TestData2
@@ -1119,7 +1119,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
atFirstAgg = true
}
- case e: Exchange => atFirstAgg = false
+ case e: ShuffleExchange => atFirstAgg = false
case _ =>
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 35ff1c40fe..b1c588a63d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql._
+import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
@@ -297,13 +298,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Then, let's look at the number of post-shuffle partitions estimated
// by the ExchangeCoordinator.
val exchanges = agg.queryExecution.executedPlan.collect {
- case e: Exchange => e
+ case e: ShuffleExchange => e
}
assert(exchanges.length === 1)
minNumPostShufflePartitions match {
case Some(numPartitions) =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 3)
case o =>
@@ -311,7 +312,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
case None =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 2)
case o =>
@@ -348,13 +349,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Then, let's look at the number of post-shuffle partitions estimated
// by the ExchangeCoordinator.
val exchanges = join.queryExecution.executedPlan.collect {
- case e: Exchange => e
+ case e: ShuffleExchange => e
}
assert(exchanges.length === 2)
minNumPostShufflePartitions match {
case Some(numPartitions) =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 3)
case o =>
@@ -362,7 +363,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
case None =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 2)
case o =>
@@ -404,13 +405,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Then, let's look at the number of post-shuffle partitions estimated
// by the ExchangeCoordinator.
val exchanges = join.queryExecution.executedPlan.collect {
- case e: Exchange => e
+ case e: ShuffleExchange => e
}
assert(exchanges.length === 4)
minNumPostShufflePartitions match {
case Some(numPartitions) =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 3)
case o =>
@@ -456,13 +457,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Then, let's look at the number of post-shuffle partitions estimated
// by the ExchangeCoordinator.
val exchanges = join.queryExecution.executedPlan.collect {
- case e: Exchange => e
+ case e: ShuffleExchange => e
}
assert(exchanges.length === 3)
minNumPostShufflePartitions match {
case Some(numPartitions) =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 3)
case o =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 87bff3295f..d4f22de90c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.test.SharedSQLContext
class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
@@ -28,7 +29,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
val input = (1 to 1000).map(Tuple1.apply)
checkAnswer(
input.toDF(),
- plan => Exchange(SinglePartition, plan),
+ plan => ShuffleExchange(SinglePartition, plan),
input.map(Row.fromTuple)
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 250ce8f866..4de56783fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal,
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -212,7 +213,7 @@ class PlannerSuite extends SharedSQLContext {
| JOIN tiny ON (small.key = tiny.key)
""".stripMargin
).queryExecution.executedPlan.collect {
- case exchange: Exchange => exchange
+ case exchange: ShuffleExchange => exchange
}.length
assert(numExchanges === 5)
}
@@ -227,7 +228,7 @@ class PlannerSuite extends SharedSQLContext {
| JOIN tiny ON (normal.key = tiny.key)
""".stripMargin
).queryExecution.executedPlan.collect {
- case exchange: Exchange => exchange
+ case exchange: ShuffleExchange => exchange
}.length
assert(numExchanges === 5)
}
@@ -295,7 +296,7 @@ class PlannerSuite extends SharedSQLContext {
)
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.isEmpty) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
fail(s"Exchange should have been added:\n$outputPlan")
}
}
@@ -333,7 +334,7 @@ class PlannerSuite extends SharedSQLContext {
)
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.isEmpty) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
fail(s"Exchange should have been added:\n$outputPlan")
}
}
@@ -353,7 +354,7 @@ class PlannerSuite extends SharedSQLContext {
)
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.nonEmpty) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
fail(s"Exchange should not have been added:\n$outputPlan")
}
}
@@ -376,7 +377,7 @@ class PlannerSuite extends SharedSQLContext {
)
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.nonEmpty) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
fail(s"No Exchanges should have been added:\n$outputPlan")
}
}
@@ -435,7 +436,7 @@ class PlannerSuite extends SharedSQLContext {
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
assert(!childPartitioning.satisfies(distribution))
- val inputPlan = Exchange(finalPartitioning,
+ val inputPlan = ShuffleExchange(finalPartitioning,
DummySparkPlan(
children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
requiredChildDistribution = Seq(distribution),
@@ -444,7 +445,7 @@ class PlannerSuite extends SharedSQLContext {
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.size == 2) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) {
fail(s"Topmost Exchange should have been eliminated:\n$outputPlan")
}
}
@@ -455,7 +456,7 @@ class PlannerSuite extends SharedSQLContext {
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8)
val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
assert(!childPartitioning.satisfies(distribution))
- val inputPlan = Exchange(finalPartitioning,
+ val inputPlan = ShuffleExchange(finalPartitioning,
DummySparkPlan(
children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
requiredChildDistribution = Seq(distribution),
@@ -464,7 +465,7 @@ class PlannerSuite extends SharedSQLContext {
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.size == 1) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) {
fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index e25b5e0610..a256ee95a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
-import org.apache.spark.sql.{QueryTest, SQLConf, SQLContext}
+import org.apache.spark.sql.{QueryTest, SQLContext}
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.functions._
/**
@@ -62,7 +63,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
// Comparison at the end is for broadcast left semi join
val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
val df3 = df1.join(broadcast(df2), joinExpression, joinType)
- val plan = df3.queryExecution.sparkPlan
+ val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan)
assert(plan.collect { case p: T => p }.size === 1)
plan.executeCollect()
}
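
This test change follows directly from moving broadcasting into the planner: a hand-assembled broadcast join no longer broadcasts its own build side, so the test has to run EnsureRequirements to insert the BroadcastExchange before collecting. InnerJoinSuite below applies the same fix. Roughly:

    // sqlContext comes from the shared test fixture; df3 is the joined DataFrame above
    val prepared = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan)
    prepared.executeCollect()  // the join now expects its build child to be a BroadcastExchange
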
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index e22a810a6b..6dfff3770b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -88,7 +89,15 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
- joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan)
+ val broadcastJoin = joins.BroadcastHashJoin(
+ leftKeys,
+ rightKeys,
+ Inner,
+ side,
+ boundCondition,
+ leftPlan,
+ rightPlan)
+ EnsureRequirements(sqlContext).apply(broadcastJoin)
}
def makeSortMergeJoin(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index f4b01fbad0..cd6b6fcbb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 9c86084f9b..f3ad8409e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 9ba645626f..a05a57c0f5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -22,8 +22,9 @@ import java.io.File
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.{Exchange, PhysicalRDD}
+import org.apache.spark.sql.execution.PhysicalRDD
import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.execution.joins.SortMergeJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -252,8 +253,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin])
val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin]
- assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft)
- assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight)
+ assert(joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft)
+ assert(joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight)
}
}
}
@@ -312,7 +313,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
agged.sort("i", "j"),
df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
- assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty)
}
}
@@ -326,7 +327,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
agged.sort("i", "j"),
df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
- assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty)
}
}
@@ -339,7 +340,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
val agged = hiveContext.table("bucketed_table").groupBy("i").count()
// make sure we fall back to non-bucketing mode and can't avoid shuffle
- assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined)
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isDefined)
checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i"))
}
}