aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorHerman van Hovell <hvanhovell@questtec.nl>2016-02-21 12:32:31 -0800
committerReynold Xin <rxin@databricks.com>2016-02-21 12:32:31 -0800
commitb6a873d6d4682796f55dbafadd0b5cad881f96ea (patch)
treef1d7431df3827d27fdd3eec7bd2503336d08ecc5 /sql
parentaf441ddbd13f48c09b458c451d7bba3965a878d1 (diff)
downloadspark-b6a873d6d4682796f55dbafadd0b5cad881f96ea.tar.gz
spark-b6a873d6d4682796f55dbafadd0b5cad881f96ea.tar.bz2
spark-b6a873d6d4682796f55dbafadd0b5cad881f96ea.zip
[SPARK-13136][SQL] Create a dedicated Broadcast exchange operator
Quite a few Spark SQL join operators broadcast one side of the join to all nodes. The are a few problems with this: - This conflates broadcasting (a data exchange) with joining. Data exchanges should be managed by a different operator. - All these nodes implement their own (duplicate) broadcasting logic. - Re-use of indices is quite hard. This PR defines both a ```BroadcastDistribution``` and ```BroadcastPartitioning```, these contain a `BroadcastMode`. The `BroadcastMode` defines the way in which we transform the Array of `InternalRow`'s into an index. We currently support the following `BroadcastMode`'s: - IdentityBroadcastMode: This broadcasts the rows in their original form. - HashSetBroadcastMode: This applies a projection to the input rows, deduplicates these rows and broadcasts the resulting `Set`. - HashedRelationBroadcastMode: This transforms the input rows into a `HashedRelation`, and broadcasts this index. To match this distribution we implement a ```BroadcastExchange``` operator which will perform the broadcast for us, and have ```EnsureRequirements``` plan this operator. The old Exchange operator has been renamed into ShuffleExchange in order to clearly separate between Shuffled and Broadcasted exchanges. Finally the classes in Exchange.scala have been moved to a dedicated package. cc rxin davies Author: Herman van Hovell <hvanhovell@questtec.nl> Closes #11083 from hvanhovell/SPARK-13136.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala35
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala30
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala34
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala89
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala261
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala)46
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala)255
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala75
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala25
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala25
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala51
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala23
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala4
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala21
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala21
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala5
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala11
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala3
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala13
27 files changed, 658 insertions, 433 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
new file mode 100644
index 0000000000..c646dcfa11
--- /dev/null
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.plans.physical
+
+import org.apache.spark.sql.catalyst.InternalRow
+
+/**
+ * Marker trait to identify the shape in which tuples are broadcasted. Typical examples of this are
+ * identity (tuples remain unchanged) or hashed (tuples are converted into some hash index).
+ */
+trait BroadcastMode {
+ def transform(rows: Array[InternalRow]): Any
+}
+
+/**
+ * IdentityBroadcastMode requires that rows are broadcasted in their original form.
+ */
+case object IdentityBroadcastMode extends BroadcastMode {
+ override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index d6e10c412c..45e2841ec9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, IntegerType}
@@ -76,6 +77,12 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
}
/**
+ * Represents data where tuples are broadcasted to every node. It is quite common that the
+ * entire set of tuples is transformed into different data structure.
+ */
+case class BroadcastDistribution(mode: BroadcastMode) extends Distribution
+
+/**
* Describes how an operator's output is split across partitions. The `compatibleWith`,
* `guarantees`, and `satisfies` methods describe relationships between child partitionings,
* target partitionings, and [[Distribution]]s. These relations are described more precisely in
@@ -213,7 +220,10 @@ case class RoundRobinPartitioning(numPartitions: Int) extends Partitioning {
case object SinglePartition extends Partitioning {
val numPartitions = 1
- override def satisfies(required: Distribution): Boolean = true
+ override def satisfies(required: Distribution): Boolean = required match {
+ case _: BroadcastDistribution => false
+ case _ => true
+ }
override def compatibleWith(other: Partitioning): Boolean = other.numPartitions == 1
@@ -351,3 +361,21 @@ case class PartitioningCollection(partitionings: Seq[Partitioning])
partitionings.map(_.toString).mkString("(", " or ", ")")
}
}
+
+/**
+ * Represents a partitioning where rows are collected, transformed and broadcasted to each
+ * node in the cluster.
+ */
+case class BroadcastPartitioning(mode: BroadcastMode) extends Partitioning {
+ override val numPartitions: Int = 1
+
+ override def satisfies(required: Distribution): Boolean = required match {
+ case BroadcastDistribution(m) if m == mode => true
+ case _ => false
+ }
+
+ override def compatibleWith(other: Partitioning): Boolean = other match {
+ case BroadcastPartitioning(m) if m == mode => true
+ case _ => false
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 932df36b85..a2f386850c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -41,6 +41,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan,
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
@@ -59,7 +60,6 @@ import org.apache.spark.util.Utils
* @groupname config Configuration
* @groupname dataframes Custom DataFrame Creation
* @groupname Ungrouped Support functions for language integrated queries
- *
* @since 1.0.0
*/
class SQLContext private[sql](
@@ -313,10 +313,10 @@ class SQLContext private[sql](
}
/**
- * Returns true if the [[Queryable]] is currently cached in-memory.
- * @group cachemgmt
- * @since 1.3.0
- */
+ * Returns true if the [[Queryable]] is currently cached in-memory.
+ * @group cachemgmt
+ * @since 1.3.0
+ */
private[sql] def isCached(qName: Queryable): Boolean = {
cacheManager.lookupCachedData(qName).nonEmpty
}
@@ -364,6 +364,7 @@ class SQLContext private[sql](
/**
* Converts $"col name" into an [[Column]].
+ *
* @since 1.3.0
*/
// This must live here to preserve binary compatibility with Spark < 1.5.
@@ -728,7 +729,6 @@ class SQLContext private[sql](
* cached/persisted before, it's also unpersisted.
*
* @param tableName the name of the table to be unregistered.
- *
* @group basic
* @since 1.3.0
*/
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 477a9460d7..3be4cce045 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -24,6 +24,7 @@ import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration._
import org.apache.spark.Logging
+import org.apache.spark.broadcast
import org.apache.spark.rdd.{RDD, RDDOperationScope}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
@@ -108,15 +109,30 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq.fill(children.size)(Nil)
/**
- * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute
- * after adding query plan information to created RDDs for visualization.
- * Concrete implementations of SparkPlan should override doExecute instead.
+ * Returns the result of this query as an RDD[InternalRow] by delegating to doExecute after
+ * preparations. Concrete implementations of SparkPlan should override doExecute.
*/
- final def execute(): RDD[InternalRow] = {
+ final def execute(): RDD[InternalRow] = executeQuery {
+ doExecute()
+ }
+
+ /**
+ * Returns the result of this query as a broadcast variable by delegating to doBroadcast after
+ * preparations. Concrete implementations of SparkPlan should override doBroadcast.
+ */
+ final def executeBroadcast[T](): broadcast.Broadcast[T] = executeQuery {
+ doExecuteBroadcast()
+ }
+
+ /**
+ * Execute a query after preparing the query and adding query plan information to created RDDs
+ * for visualization.
+ */
+ private final def executeQuery[T](query: => T): T = {
RDDOperationScope.withScope(sparkContext, nodeName, false, true) {
prepare()
waitForSubqueries()
- doExecute()
+ query
}
}
@@ -193,6 +209,14 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
protected def doExecute(): RDD[InternalRow]
/**
+ * Overridden by concrete implementations of SparkPlan.
+ * Produces the result of the query as a broadcast variable.
+ */
+ protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ throw new UnsupportedOperationException(s"$nodeName does not implement doExecuteBroadcast")
+ }
+
+ /**
* Runs this query returning the result as an array.
*/
def executeCollect(): Array[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 382654afac..7347156398 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -17,7 +17,8 @@
package org.apache.spark.sql.execution
-import org.apache.spark.sql.{execution, Strategy}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
+import org.apache.spark.sql.Strategy
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
@@ -25,6 +26,7 @@ import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution
import org.apache.spark.sql.execution.{DescribeCommand => RunnableDescribeCommand}
import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
import org.apache.spark.sql.execution.datasources.{CreateTableUsing, CreateTempTableUsing, DescribeCommand => LogicalDescribeCommand, _}
@@ -328,7 +330,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
- execution.Exchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil
+ ShuffleExchange(RoundRobinPartitioning(numPartitions), planLater(child)) :: Nil
} else {
execution.Coalesce(numPartitions, planLater(child)) :: Nil
}
@@ -367,7 +369,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case r @ logical.Range(start, end, step, numSlices, output) =>
execution.Range(start, step, numSlices, r.numElements, output) :: Nil
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
- execution.Exchange(HashPartitioning(
+ exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
case e @ python.EvaluatePython(udf, child, _) =>
python.BatchPythonEvaluation(udf, e.output, planLater(child)) :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 990eeb22b6..d79b547137 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
@@ -172,6 +173,10 @@ case class InputAdapter(child: SparkPlan) extends LeafNode with CodegenSupport {
child.execute()
}
+ override def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ child.doExecuteBroadcast()
+ }
+
override def supportCodegen: Boolean = false
override def upstream(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
new file mode 100644
index 0000000000..40cad4b1a7
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchange.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import scala.concurrent.{Await, ExecutionContext, Future}
+import scala.concurrent.duration._
+
+import org.apache.spark.broadcast
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastMode, BroadcastPartitioning, Partitioning}
+import org.apache.spark.sql.execution.{SparkPlan, SQLExecution, UnaryNode}
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * A [[BroadcastExchange]] collects, transforms and finally broadcasts the result of a transformed
+ * SparkPlan.
+ */
+case class BroadcastExchange(
+ mode: BroadcastMode,
+ child: SparkPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output
+
+ override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
+
+ @transient
+ private val timeout: Duration = {
+ val timeoutValue = sqlContext.conf.broadcastTimeout
+ if (timeoutValue < 0) {
+ Duration.Inf
+ } else {
+ timeoutValue.seconds
+ }
+ }
+
+ @transient
+ private lazy val relationFuture: Future[broadcast.Broadcast[Any]] = {
+ // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+ val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+ Future {
+ // This will run in another thread. Set the execution id so that we can connect these jobs
+ // with the correct execution.
+ SQLExecution.withExecutionId(sparkContext, executionId) {
+ // Note that we use .executeCollect() because we don't want to convert data to Scala types
+ val input: Array[InternalRow] = child.executeCollect()
+
+ // Construct and broadcast the relation.
+ sparkContext.broadcast(mode.transform(input))
+ }
+ }(BroadcastExchange.executionContext)
+ }
+
+ override protected def doPrepare(): Unit = {
+ // Materialize the future.
+ relationFuture
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ throw new UnsupportedOperationException(
+ "BroadcastExchange does not support the execute() code path.")
+ }
+
+ override protected[sql] def doExecuteBroadcast[T](): broadcast.Broadcast[T] = {
+ val result = Await.result(relationFuture, timeout)
+ result.asInstanceOf[broadcast.Broadcast[T]]
+ }
+}
+
+object BroadcastExchange {
+ private[execution] val executionContext = ExecutionContext.fromExecutorService(
+ ThreadUtils.newDaemonCachedThreadPool("broadcast-exchange", 128))
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
new file mode 100644
index 0000000000..709a424636
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/EnsureRequirements.scala
@@ -0,0 +1,261 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.exchange
+
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
+
+/**
+ * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
+ * of input data meets the
+ * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
+ * each operator by inserting [[ShuffleExchange]] Operators where required. Also ensure that the
+ * input partition ordering requirements are met.
+ */
+private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
+ private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions
+
+ private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize
+
+ private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled
+
+ private def minNumPostShufflePartitions: Option[Int] = {
+ val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions
+ if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
+ }
+
+ /**
+ * Given a required distribution, returns a partitioning that satisfies that distribution.
+ */
+ private def createPartitioning(
+ requiredDistribution: Distribution,
+ numPartitions: Int): Partitioning = {
+ requiredDistribution match {
+ case AllTuples => SinglePartition
+ case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
+ case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
+ case dist => sys.error(s"Do not know how to satisfy distribution $dist")
+ }
+ }
+
+ /**
+ * Adds [[ExchangeCoordinator]] to [[ShuffleExchange]]s if adaptive query execution is enabled
+ * and partitioning schemes of these [[ShuffleExchange]]s support [[ExchangeCoordinator]].
+ */
+ private def withExchangeCoordinator(
+ children: Seq[SparkPlan],
+ requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = {
+ val supportsCoordinator =
+ if (children.exists(_.isInstanceOf[ShuffleExchange])) {
+ // Right now, ExchangeCoordinator only support HashPartitionings.
+ children.forall {
+ case e @ ShuffleExchange(hash: HashPartitioning, _, _) => true
+ case child =>
+ child.outputPartitioning match {
+ case hash: HashPartitioning => true
+ case collection: PartitioningCollection =>
+ collection.partitionings.forall(_.isInstanceOf[HashPartitioning])
+ case _ => false
+ }
+ }
+ } else {
+ // In this case, although we do not have Exchange operators, we may still need to
+ // shuffle data when we have more than one children because data generated by
+ // these children may not be partitioned in the same way.
+ // Please see the comment in withCoordinator for more details.
+ val supportsDistribution =
+ requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
+ children.length > 1 && supportsDistribution
+ }
+
+ val withCoordinator =
+ if (adaptiveExecutionEnabled && supportsCoordinator) {
+ val coordinator =
+ new ExchangeCoordinator(
+ children.length,
+ targetPostShuffleInputSize,
+ minNumPostShufflePartitions)
+ children.zip(requiredChildDistributions).map {
+ case (e: ShuffleExchange, _) =>
+ // This child is an Exchange, we need to add the coordinator.
+ e.copy(coordinator = Some(coordinator))
+ case (child, distribution) =>
+ // If this child is not an Exchange, we need to add an Exchange for now.
+ // Ideally, we can try to avoid this Exchange. However, when we reach here,
+ // there are at least two children operators (because if there is a single child
+ // and we can avoid Exchange, supportsCoordinator will be false and we
+ // will not reach here.). Although we can make two children have the same number of
+ // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different.
+ // For example, let's say we have the following plan
+ // Join
+ // / \
+ // Agg Exchange
+ // / \
+ // Exchange t2
+ // /
+ // t1
+ // In this case, because a post-shuffle partition can include multiple pre-shuffle
+ // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes
+ // after shuffle. So, even we can use the child Exchange operator of the Join to
+ // have a number of post-shuffle partitions that matches the number of partitions of
+ // Agg, we cannot say these two children are partitioned in the same way.
+ // Here is another case
+ // Join
+ // / \
+ // Agg1 Agg2
+ // / \
+ // Exchange1 Exchange2
+ // / \
+ // t1 t2
+ // In this case, two Aggs shuffle data with the same column of the join condition.
+ // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same
+ // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2
+ // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle
+ // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its
+ // pre-shuffle partitions by using another partitionStartIndices [0, 4].
+ // So, Agg1 and Agg2 are actually not co-partitioned.
+ //
+ // It will be great to introduce a new Partitioning to represent the post-shuffle
+ // partitions when one post-shuffle partition includes multiple pre-shuffle partitions.
+ val targetPartitioning =
+ createPartitioning(distribution, defaultNumPreShufflePartitions)
+ assert(targetPartitioning.isInstanceOf[HashPartitioning])
+ ShuffleExchange(targetPartitioning, child, Some(coordinator))
+ }
+ } else {
+ // If we do not need ExchangeCoordinator, the original children are returned.
+ children
+ }
+
+ withCoordinator
+ }
+
+ private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
+ val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
+ val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
+ var children: Seq[SparkPlan] = operator.children
+ assert(requiredChildDistributions.length == children.length)
+ assert(requiredChildOrderings.length == children.length)
+
+ // Ensure that the operator's children satisfy their output distribution requirements:
+ children = children.zip(requiredChildDistributions).map {
+ case (child, distribution) if child.outputPartitioning.satisfies(distribution) =>
+ child
+ case (child, BroadcastDistribution(mode)) =>
+ BroadcastExchange(mode, child)
+ case (child, distribution) =>
+ ShuffleExchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
+ }
+
+ // If the operator has multiple children and specifies child output distributions (e.g. join),
+ // then the children's output partitionings must be compatible:
+ def requireCompatiblePartitioning(distribution: Distribution): Boolean = distribution match {
+ case UnspecifiedDistribution => false
+ case BroadcastDistribution(_) => false
+ case _ => true
+ }
+ if (children.length > 1
+ && requiredChildDistributions.exists(requireCompatiblePartitioning)
+ && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
+
+ // First check if the existing partitions of the children all match. This means they are
+ // partitioned by the same partitioning into the same number of partitions. In that case,
+ // don't try to make them match `defaultPartitions`, just use the existing partitioning.
+ val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
+ val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
+ case (child, distribution) =>
+ child.outputPartitioning.guarantees(
+ createPartitioning(distribution, maxChildrenNumPartitions))
+ }
+
+ children = if (useExistingPartitioning) {
+ // We do not need to shuffle any child's output.
+ children
+ } else {
+ // We need to shuffle at least one child's output.
+ // Now, we will determine the number of partitions that will be used by created
+ // partitioning schemes.
+ val numPartitions = {
+ // Let's see if we need to shuffle all child's outputs when we use
+ // maxChildrenNumPartitions.
+ val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
+ case (child, distribution) =>
+ !child.outputPartitioning.guarantees(
+ createPartitioning(distribution, maxChildrenNumPartitions))
+ }
+ // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
+ // number of partitions. Otherwise, we use maxChildrenNumPartitions.
+ if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
+ }
+
+ children.zip(requiredChildDistributions).map {
+ case (child, distribution) =>
+ val targetPartitioning = createPartitioning(distribution, numPartitions)
+ if (child.outputPartitioning.guarantees(targetPartitioning)) {
+ child
+ } else {
+ child match {
+ // If child is an exchange, we replace it with
+ // a new one having targetPartitioning.
+ case ShuffleExchange(_, c, _) => ShuffleExchange(targetPartitioning, c)
+ case _ => ShuffleExchange(targetPartitioning, child)
+ }
+ }
+ }
+ }
+ }
+
+ // Now, we need to add ExchangeCoordinator if necessary.
+ // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges.
+ // However, with the way that we plan the query, we do not have a place where we have a
+ // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator
+ // at here for now.
+ // Once we finish https://issues.apache.org/jira/browse/SPARK-10665,
+ // we can first add Exchanges and then add coordinator once we have a DAG of query fragments.
+ children = withExchangeCoordinator(children, requiredChildDistributions)
+
+ // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
+ children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
+ if (requiredOrdering.nonEmpty) {
+ // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
+ if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) {
+ Sort(requiredOrdering, global = false, child = child)
+ } else {
+ child
+ }
+ } else {
+ child
+ }
+ }
+
+ operator.withNewChildren(children)
+ }
+
+ def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
+ case operator @ ShuffleExchange(partitioning, child, _) =>
+ child.children match {
+ case ShuffleExchange(childPartitioning, baseChild, _)::Nil =>
+ if (childPartitioning.guarantees(partitioning)) child else operator
+ case _ => operator
+ }
+ case operator: SparkPlan => ensureDistributionAndOrdering(operator)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
index 07015e5a5a..6f3bb0ad2b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ExchangeCoordinator.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.execution.exchange
import java.util.{HashMap => JHashMap, Map => JMap}
import javax.annotation.concurrent.GuardedBy
@@ -25,6 +25,7 @@ import scala.collection.mutable.ArrayBuffer
import org.apache.spark.{Logging, MapOutputStatistics, ShuffleDependency, SimpleFutureAction}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan}
/**
* A coordinator used to determines how we shuffle data between stages generated by Spark SQL.
@@ -33,9 +34,9 @@ import org.apache.spark.sql.catalyst.InternalRow
*
* A coordinator is constructed with three parameters, `numExchanges`,
* `targetPostShuffleInputSize`, and `minNumPostShufflePartitions`.
- * - `numExchanges` is used to indicated that how many [[Exchange]]s that will be registered to
- * this coordinator. So, when we start to do any actual work, we have a way to make sure that
- * we have got expected number of [[Exchange]]s.
+ * - `numExchanges` is used to indicated that how many [[ShuffleExchange]]s that will be registered
+ * to this coordinator. So, when we start to do any actual work, we have a way to make sure that
+ * we have got expected number of [[ShuffleExchange]]s.
* - `targetPostShuffleInputSize` is the targeted size of a post-shuffle partition's
* input data size. With this parameter, we can estimate the number of post-shuffle partitions.
* This parameter is configured through
@@ -45,26 +46,27 @@ import org.apache.spark.sql.catalyst.InternalRow
* partitions.
*
* The workflow of this coordinator is described as follows:
- * - Before the execution of a [[SparkPlan]], for an [[Exchange]] operator,
+ * - Before the execution of a [[SparkPlan]], for an [[ShuffleExchange]] operator,
* if an [[ExchangeCoordinator]] is assigned to it, it registers itself to this coordinator.
* This happens in the `doPrepare` method.
- * - Once we start to execute a physical plan, an [[Exchange]] registered to this coordinator will
- * call `postShuffleRDD` to get its corresponding post-shuffle [[ShuffledRowRDD]].
- * If this coordinator has made the decision on how to shuffle data, this [[Exchange]] will
- * immediately get its corresponding post-shuffle [[ShuffledRowRDD]].
+ * - Once we start to execute a physical plan, an [[ShuffleExchange]] registered to this
+ * coordinator will call `postShuffleRDD` to get its corresponding post-shuffle
+ * [[ShuffledRowRDD]].
+ * If this coordinator has made the decision on how to shuffle data, this [[ShuffleExchange]]
+ * will immediately get its corresponding post-shuffle [[ShuffledRowRDD]].
* - If this coordinator has not made the decision on how to shuffle data, it will ask those
- * registered [[Exchange]]s to submit their pre-shuffle stages. Then, based on the the size
- * statistics of pre-shuffle partitions, this coordinator will determine the number of
+ * registered [[ShuffleExchange]]s to submit their pre-shuffle stages. Then, based on the the
+ * size statistics of pre-shuffle partitions, this coordinator will determine the number of
* post-shuffle partitions and pack multiple pre-shuffle partitions with continuous indices
* to a single post-shuffle partition whenever necessary.
* - Finally, this coordinator will create post-shuffle [[ShuffledRowRDD]]s for all registered
- * [[Exchange]]s. So, when an [[Exchange]] calls `postShuffleRDD`, this coordinator can
- * lookup the corresponding [[RDD]].
+ * [[ShuffleExchange]]s. So, when an [[ShuffleExchange]] calls `postShuffleRDD`, this coordinator
+ * can lookup the corresponding [[RDD]].
*
* The strategy used to determine the number of post-shuffle partitions is described as follows.
* To determine the number of post-shuffle partitions, we have a target input size for a
* post-shuffle partition. Once we have size statistics of pre-shuffle partitions from stages
- * corresponding to the registered [[Exchange]]s, we will do a pass of those statistics and
+ * corresponding to the registered [[ShuffleExchange]]s, we will do a pass of those statistics and
* pack pre-shuffle partitions with continuous indices to a single post-shuffle partition until
* the size of a post-shuffle partition is equal or greater than the target size.
* For example, we have two stages with the following pre-shuffle partition size statistics:
@@ -83,11 +85,11 @@ private[sql] class ExchangeCoordinator(
extends Logging {
// The registered Exchange operators.
- private[this] val exchanges = ArrayBuffer[Exchange]()
+ private[this] val exchanges = ArrayBuffer[ShuffleExchange]()
// This map is used to lookup the post-shuffle ShuffledRowRDD for an Exchange operator.
- private[this] val postShuffleRDDs: JMap[Exchange, ShuffledRowRDD] =
- new JHashMap[Exchange, ShuffledRowRDD](numExchanges)
+ private[this] val postShuffleRDDs: JMap[ShuffleExchange, ShuffledRowRDD] =
+ new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges)
// A boolean that indicates if this coordinator has made decision on how to shuffle data.
// This variable will only be updated by doEstimationIfNecessary, which is protected by
@@ -95,11 +97,11 @@ private[sql] class ExchangeCoordinator(
@volatile private[this] var estimated: Boolean = false
/**
- * Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be
- * called in the `doPrepare` method of an [[Exchange]] operator.
+ * Registers an [[ShuffleExchange]] operator to this coordinator. This method is only allowed to
+ * be called in the `doPrepare` method of an [[ShuffleExchange]] operator.
*/
@GuardedBy("this")
- def registerExchange(exchange: Exchange): Unit = synchronized {
+ def registerExchange(exchange: ShuffleExchange): Unit = synchronized {
exchanges += exchange
}
@@ -199,7 +201,7 @@ private[sql] class ExchangeCoordinator(
// Make sure we have the expected number of registered Exchange operators.
assert(exchanges.length == numExchanges)
- val newPostShuffleRDDs = new JHashMap[Exchange, ShuffledRowRDD](numExchanges)
+ val newPostShuffleRDDs = new JHashMap[ShuffleExchange, ShuffledRowRDD](numExchanges)
// Submit all map stages
val shuffleDependencies = ArrayBuffer[ShuffleDependency[Int, InternalRow, InternalRow]]()
@@ -254,7 +256,7 @@ private[sql] class ExchangeCoordinator(
}
}
- def postShuffleRDD(exchange: Exchange): ShuffledRowRDD = {
+ def postShuffleRDD(exchange: ShuffleExchange): ShuffledRowRDD = {
doEstimationIfNecessary()
if (!postShuffleRDDs.containsKey(exchange)) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
index e30adefc69..de21d7705e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/ShuffleExchange.scala
@@ -15,7 +15,7 @@
* limitations under the License.
*/
-package org.apache.spark.sql.execution
+package org.apache.spark.sql.execution.exchange
import java.util.Random
@@ -24,19 +24,18 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.shuffle.sort.SortShuffleManager
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.errors.attachTree
-import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeProjection}
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution._
import org.apache.spark.util.MutablePair
/**
* Performs a shuffle that will result in the desired `newPartitioning`.
*/
-case class Exchange(
+case class ShuffleExchange(
var newPartitioning: Partitioning,
child: SparkPlan,
@transient coordinator: Option[ExchangeCoordinator]) extends UnaryNode {
@@ -81,7 +80,8 @@ case class Exchange(
* the returned ShuffleDependency will be the input of shuffle.
*/
private[sql] def prepareShuffleDependency(): ShuffleDependency[Int, InternalRow, InternalRow] = {
- Exchange.prepareShuffleDependency(child.execute(), child.output, newPartitioning, serializer)
+ ShuffleExchange.prepareShuffleDependency(
+ child.execute(), child.output, newPartitioning, serializer)
}
/**
@@ -116,9 +116,9 @@ case class Exchange(
}
}
-object Exchange {
- def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
- Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
+object ShuffleExchange {
+ def apply(newPartitioning: Partitioning, child: SparkPlan): ShuffleExchange = {
+ ShuffleExchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
}
/**
@@ -259,238 +259,3 @@ object Exchange {
dependency
}
}
-
-/**
- * Ensures that the [[org.apache.spark.sql.catalyst.plans.physical.Partitioning Partitioning]]
- * of input data meets the
- * [[org.apache.spark.sql.catalyst.plans.physical.Distribution Distribution]] requirements for
- * each operator by inserting [[Exchange]] Operators where required. Also ensure that the
- * input partition ordering requirements are met.
- */
-private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
- private def defaultNumPreShufflePartitions: Int = sqlContext.conf.numShufflePartitions
-
- private def targetPostShuffleInputSize: Long = sqlContext.conf.targetPostShuffleInputSize
-
- private def adaptiveExecutionEnabled: Boolean = sqlContext.conf.adaptiveExecutionEnabled
-
- private def minNumPostShufflePartitions: Option[Int] = {
- val minNumPostShufflePartitions = sqlContext.conf.minNumPostShufflePartitions
- if (minNumPostShufflePartitions > 0) Some(minNumPostShufflePartitions) else None
- }
-
- /**
- * Given a required distribution, returns a partitioning that satisfies that distribution.
- */
- private def createPartitioning(
- requiredDistribution: Distribution,
- numPartitions: Int): Partitioning = {
- requiredDistribution match {
- case AllTuples => SinglePartition
- case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
- case OrderedDistribution(ordering) => RangePartitioning(ordering, numPartitions)
- case dist => sys.error(s"Do not know how to satisfy distribution $dist")
- }
- }
-
- /**
- * Adds [[ExchangeCoordinator]] to [[Exchange]]s if adaptive query execution is enabled
- * and partitioning schemes of these [[Exchange]]s support [[ExchangeCoordinator]].
- */
- private def withExchangeCoordinator(
- children: Seq[SparkPlan],
- requiredChildDistributions: Seq[Distribution]): Seq[SparkPlan] = {
- val supportsCoordinator =
- if (children.exists(_.isInstanceOf[Exchange])) {
- // Right now, ExchangeCoordinator only support HashPartitionings.
- children.forall {
- case e @ Exchange(hash: HashPartitioning, _, _) => true
- case child =>
- child.outputPartitioning match {
- case hash: HashPartitioning => true
- case collection: PartitioningCollection =>
- collection.partitionings.forall(_.isInstanceOf[HashPartitioning])
- case _ => false
- }
- }
- } else {
- // In this case, although we do not have Exchange operators, we may still need to
- // shuffle data when we have more than one children because data generated by
- // these children may not be partitioned in the same way.
- // Please see the comment in withCoordinator for more details.
- val supportsDistribution =
- requiredChildDistributions.forall(_.isInstanceOf[ClusteredDistribution])
- children.length > 1 && supportsDistribution
- }
-
- val withCoordinator =
- if (adaptiveExecutionEnabled && supportsCoordinator) {
- val coordinator =
- new ExchangeCoordinator(
- children.length,
- targetPostShuffleInputSize,
- minNumPostShufflePartitions)
- children.zip(requiredChildDistributions).map {
- case (e: Exchange, _) =>
- // This child is an Exchange, we need to add the coordinator.
- e.copy(coordinator = Some(coordinator))
- case (child, distribution) =>
- // If this child is not an Exchange, we need to add an Exchange for now.
- // Ideally, we can try to avoid this Exchange. However, when we reach here,
- // there are at least two children operators (because if there is a single child
- // and we can avoid Exchange, supportsCoordinator will be false and we
- // will not reach here.). Although we can make two children have the same number of
- // post-shuffle partitions. Their numbers of pre-shuffle partitions may be different.
- // For example, let's say we have the following plan
- // Join
- // / \
- // Agg Exchange
- // / \
- // Exchange t2
- // /
- // t1
- // In this case, because a post-shuffle partition can include multiple pre-shuffle
- // partitions, a HashPartitioning will not be strictly partitioned by the hashcodes
- // after shuffle. So, even we can use the child Exchange operator of the Join to
- // have a number of post-shuffle partitions that matches the number of partitions of
- // Agg, we cannot say these two children are partitioned in the same way.
- // Here is another case
- // Join
- // / \
- // Agg1 Agg2
- // / \
- // Exchange1 Exchange2
- // / \
- // t1 t2
- // In this case, two Aggs shuffle data with the same column of the join condition.
- // After we use ExchangeCoordinator, these two Aggs may not be partitioned in the same
- // way. Let's say that Agg1 and Agg2 both have 5 pre-shuffle partitions and 2
- // post-shuffle partitions. It is possible that Agg1 fetches those pre-shuffle
- // partitions by using a partitionStartIndices [0, 3]. However, Agg2 may fetch its
- // pre-shuffle partitions by using another partitionStartIndices [0, 4].
- // So, Agg1 and Agg2 are actually not co-partitioned.
- //
- // It will be great to introduce a new Partitioning to represent the post-shuffle
- // partitions when one post-shuffle partition includes multiple pre-shuffle partitions.
- val targetPartitioning =
- createPartitioning(distribution, defaultNumPreShufflePartitions)
- assert(targetPartitioning.isInstanceOf[HashPartitioning])
- Exchange(targetPartitioning, child, Some(coordinator))
- }
- } else {
- // If we do not need ExchangeCoordinator, the original children are returned.
- children
- }
-
- withCoordinator
- }
-
- private def ensureDistributionAndOrdering(operator: SparkPlan): SparkPlan = {
- val requiredChildDistributions: Seq[Distribution] = operator.requiredChildDistribution
- val requiredChildOrderings: Seq[Seq[SortOrder]] = operator.requiredChildOrdering
- var children: Seq[SparkPlan] = operator.children
- assert(requiredChildDistributions.length == children.length)
- assert(requiredChildOrderings.length == children.length)
-
- // Ensure that the operator's children satisfy their output distribution requirements:
- children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
- if (child.outputPartitioning.satisfies(distribution)) {
- child
- } else {
- Exchange(createPartitioning(distribution, defaultNumPreShufflePartitions), child)
- }
- }
-
- // If the operator has multiple children and specifies child output distributions (e.g. join),
- // then the children's output partitionings must be compatible:
- if (children.length > 1
- && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
- && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
-
- // First check if the existing partitions of the children all match. This means they are
- // partitioned by the same partitioning into the same number of partitions. In that case,
- // don't try to make them match `defaultPartitions`, just use the existing partitioning.
- val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
- val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
- case (child, distribution) => {
- child.outputPartitioning.guarantees(
- createPartitioning(distribution, maxChildrenNumPartitions))
- }
- }
-
- children = if (useExistingPartitioning) {
- // We do not need to shuffle any child's output.
- children
- } else {
- // We need to shuffle at least one child's output.
- // Now, we will determine the number of partitions that will be used by created
- // partitioning schemes.
- val numPartitions = {
- // Let's see if we need to shuffle all child's outputs when we use
- // maxChildrenNumPartitions.
- val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
- case (child, distribution) => {
- !child.outputPartitioning.guarantees(
- createPartitioning(distribution, maxChildrenNumPartitions))
- }
- }
- // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
- // number of partitions. Otherwise, we use maxChildrenNumPartitions.
- if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
- }
-
- children.zip(requiredChildDistributions).map {
- case (child, distribution) => {
- val targetPartitioning =
- createPartitioning(distribution, numPartitions)
- if (child.outputPartitioning.guarantees(targetPartitioning)) {
- child
- } else {
- child match {
- // If child is an exchange, we replace it with
- // a new one having targetPartitioning.
- case Exchange(_, c, _) => Exchange(targetPartitioning, c)
- case _ => Exchange(targetPartitioning, child)
- }
- }
- }
- }
- }
- }
-
- // Now, we need to add ExchangeCoordinator if necessary.
- // Actually, it is not a good idea to add ExchangeCoordinators while we are adding Exchanges.
- // However, with the way that we plan the query, we do not have a place where we have a
- // global picture of all shuffle dependencies of a post-shuffle stage. So, we add coordinator
- // at here for now.
- // Once we finish https://issues.apache.org/jira/browse/SPARK-10665,
- // we can first add Exchanges and then add coordinator once we have a DAG of query fragments.
- children = withExchangeCoordinator(children, requiredChildDistributions)
-
- // Now that we've performed any necessary shuffles, add sorts to guarantee output orderings:
- children = children.zip(requiredChildOrderings).map { case (child, requiredOrdering) =>
- if (requiredOrdering.nonEmpty) {
- // If child.outputOrdering is [a, b] and requiredOrdering is [a], we do not need to sort.
- if (requiredOrdering != child.outputOrdering.take(requiredOrdering.length)) {
- Sort(requiredOrdering, global = false, child = child)
- } else {
- child
- }
- } else {
- child
- }
- }
-
- operator.withNewChildren(children)
- }
-
- def apply(plan: SparkPlan): SparkPlan = plan.transformUp {
- case operator @ Exchange(partitioning, child, _) =>
- child.children match {
- case Exchange(childPartitioning, baseChild, _)::Nil =>
- if (childPartitioning.guarantees(partitioning)) child else operator
- case _ => operator
- }
- case operator: SparkPlan => ensureDistributionAndOrdering(operator)
- }
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index a64da22580..ddc08822f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -17,9 +17,6 @@
package org.apache.spark.sql.execution.joins
-import scala.concurrent._
-import scala.concurrent.duration._
-
import org.apache.spark.TaskContext
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
@@ -27,10 +24,9 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, RightOuter}
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, Partitioning, UnspecifiedDistribution}
-import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan, SQLExecution}
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution}
+import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.util.ThreadUtils
import org.apache.spark.util.collection.CompactBuffer
/**
@@ -52,60 +48,25 @@ case class BroadcastHashJoin(
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
- val timeout: Duration = {
- val timeoutValue = sqlContext.conf.broadcastTimeout
- if (timeoutValue < 0) {
- Duration.Inf
- } else {
- timeoutValue.seconds
- }
- }
-
override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
- override def requiredChildDistribution: Seq[Distribution] =
- UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
-
- // Use lazy so that we won't do broadcast when calling explain but still cache the broadcast value
- // for the same query.
- @transient
- private lazy val broadcastFuture = {
- // broadcastFuture is used in "doExecute". Therefore we can get the execution id correctly here.
- val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
- Future {
- // This will run in another thread. Set the execution id so that we can connect these jobs
- // with the correct execution.
- SQLExecution.withExecutionId(sparkContext, executionId) {
- // Note that we use .execute().collect() because we don't want to convert data to Scala
- // types
- val input: Array[InternalRow] = buildPlan.execute().map { row =>
- row.copy()
- }.collect()
- // The following line doesn't run in a job so we cannot track the metric value. However, we
- // have already tracked it in the above lines. So here we can use
- // `SQLMetrics.nullLongMetric` to ignore it.
- // TODO: move this check into HashedRelation
- val hashed = if (canJoinKeyFitWithinLong) {
- LongHashedRelation(
- input.iterator, buildSideKeyGenerator, input.size)
- } else {
- HashedRelation(
- input.iterator, buildSideKeyGenerator, input.size)
- }
- sparkContext.broadcast(hashed)
- }
- }(BroadcastHashJoin.broadcastHashJoinExecutionContext)
- }
-
- protected override def doPrepare(): Unit = {
- broadcastFuture
+ override def requiredChildDistribution: Seq[Distribution] = {
+ val mode = HashedRelationBroadcastMode(
+ canJoinKeyFitWithinLong,
+ rewriteKeyExpr(buildKeys),
+ buildPlan.output)
+ buildSide match {
+ case BuildLeft =>
+ BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil
+ case BuildRight =>
+ UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
+ }
}
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val broadcastRelation = Await.result(broadcastFuture, timeout)
-
+ val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
streamedPlan.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow()
val hashTable = broadcastRelation.value
@@ -160,7 +121,7 @@ case class BroadcastHashJoin(
*/
private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = {
// create a name for HashedRelation
- val broadcastRelation = Await.result(broadcastFuture, timeout)
+ val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]()
val broadcast = ctx.addReferenceObj("broadcast", broadcastRelation)
val relationTerm = ctx.freshName("relation")
val clsName = broadcastRelation.value.getClass.getName
@@ -362,9 +323,3 @@ case class BroadcastHashJoin(
}
}
}
-
-object BroadcastHashJoin {
-
- private[joins] val broadcastHashJoinExecutionContext = ExecutionContext.fromExecutorService(
- ThreadUtils.newDaemonCachedThreadPool("broadcast-hash-join", 128))
-}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
index 4f1cfd2e81..1f99fbedde 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastLeftSemiJoinHash.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.{InternalAccumulator, TaskContext}
+import org.apache.spark.TaskContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -38,25 +39,25 @@ case class BroadcastLeftSemiJoinHash(
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+ override def requiredChildDistribution: Seq[Distribution] = {
+ val mode = if (condition.isEmpty) {
+ HashSetBroadcastMode(rightKeys, right.output)
+ } else {
+ HashedRelationBroadcastMode(canJoinKeyFitWithinLong = false, rightKeys, right.output)
+ }
+ UnspecifiedDistribution :: BroadcastDistribution(mode) :: Nil
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val input = right.execute().map { row =>
- row.copy()
- }.collect()
-
if (condition.isEmpty) {
- val hashSet = buildKeyHashSet(input.toIterator)
- val broadcastedRelation = sparkContext.broadcast(hashSet)
-
+ val broadcastedRelation = right.executeBroadcast[java.util.Set[InternalRow]]()
left.execute().mapPartitionsInternal { streamIter =>
hashSemiJoin(streamIter, broadcastedRelation.value, numOutputRows)
}
} else {
- val hashRelation =
- HashedRelation(input.toIterator, rightKeyGenerator, input.size)
- val broadcastedRelation = sparkContext.broadcast(hashRelation)
-
+ val broadcastedRelation = right.executeBroadcast[HashedRelation]()
left.execute().mapPartitionsInternal { streamIter =>
val hashedRelation = broadcastedRelation.value
TaskContext.get().taskMetrics().incPeakExecutionMemory(hashedRelation.getMemorySize)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
index 4585cbda92..e8bd7f69db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala
@@ -21,7 +21,7 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.util.collection.{BitSet, CompactBuffer}
@@ -33,7 +33,6 @@ case class BroadcastNestedLoopJoin(
buildSide: BuildSide,
joinType: JoinType,
condition: Option[Expression]) extends BinaryNode {
- // TODO: Override requiredChildDistribution.
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -44,8 +43,15 @@ case class BroadcastNestedLoopJoin(
case BuildLeft => (right, left)
}
+ override def requiredChildDistribution: Seq[Distribution] = buildSide match {
+ case BuildLeft =>
+ BroadcastDistribution(IdentityBroadcastMode) :: UnspecifiedDistribution :: Nil
+ case BuildRight =>
+ UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
+ }
+
private[this] def genResultProjection: InternalRow => InternalRow = {
- UnsafeProjection.create(schema)
+ UnsafeProjection.create(schema)
}
override def outputPartitioning: Partitioning = streamed.outputPartitioning
@@ -73,15 +79,14 @@ case class BroadcastNestedLoopJoin(
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map { row =>
- row.copy()
- }.collect().toIndexedSeq)
+ val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
/** All rows that either match both-way, or rows from streamed joined with nulls. */
val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
+ val relation = broadcastedRelation.value
+
val matchedRows = new CompactBuffer[InternalRow]
- val includedBroadcastTuples = new BitSet(broadcastedRelation.value.size)
+ val includedBroadcastTuples = new BitSet(relation.length)
val joinedRow = new JoinedRow
val leftNulls = new GenericMutableRow(left.output.size)
@@ -92,8 +97,8 @@ case class BroadcastNestedLoopJoin(
var i = 0
var streamRowMatched = false
- while (i < broadcastedRelation.value.size) {
- val broadcastedRow = broadcastedRelation.value(i)
+ while (i < relation.length) {
+ val broadcastedRow = relation(i)
buildSide match {
case BuildRight if boundCondition(joinedRow(streamedRow, broadcastedRow)) =>
matchedRows += resultProj(joinedRow(streamedRow, broadcastedRow)).copy()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
index 0220e0b8a7..1cb6a00617 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashSemiJoin.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.metric.LongSQLMetric
@@ -44,22 +45,7 @@ trait HashSemiJoin {
protected def buildKeyHashSet(
buildIter: Iterator[InternalRow]): java.util.Set[InternalRow] = {
- val hashSet = new java.util.HashSet[InternalRow]()
-
- // Create a Hash set of buildKeys
- val rightKey = rightKeyGenerator
- while (buildIter.hasNext) {
- val currentRow = buildIter.next()
- val rowKey = rightKey(currentRow)
- if (!rowKey.anyNull) {
- val keyExists = hashSet.contains(rowKey)
- if (!keyExists) {
- hashSet.add(rowKey.copy())
- }
- }
- }
-
- hashSet
+ HashSemiJoin.buildKeyHashSet(rightKeys, right.output, buildIter)
}
protected def hashSemiJoin(
@@ -92,3 +78,36 @@ trait HashSemiJoin {
}
}
}
+
+private[execution] object HashSemiJoin {
+ def buildKeyHashSet(
+ keys: Seq[Expression],
+ attributes: Seq[Attribute],
+ rows: Iterator[InternalRow]): java.util.HashSet[InternalRow] = {
+ val hashSet = new java.util.HashSet[InternalRow]()
+
+ // Create a Hash set of buildKeys
+ val key = UnsafeProjection.create(keys, attributes)
+ while (rows.hasNext) {
+ val currentRow = rows.next()
+ val rowKey = key(currentRow)
+ if (!rowKey.anyNull) {
+ val keyExists = hashSet.contains(rowKey)
+ if (!keyExists) {
+ hashSet.add(rowKey.copy())
+ }
+ }
+ }
+ hashSet
+ }
+}
+
+/** HashSetBroadcastMode requires that the input rows are broadcasted as a set. */
+private[execution] case class HashSetBroadcastMode(
+ keys: Seq[Expression],
+ attributes: Seq[Attribute]) extends BroadcastMode {
+
+ override def transform(rows: Array[InternalRow]): java.util.HashSet[InternalRow] = {
+ HashSemiJoin.buildKeyHashSet(keys, attributes, rows.iterator)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0978570d42..606269bf25 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -25,12 +25,11 @@ import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.memory.{StaticMemoryManager, TaskMemoryManager}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.execution.SparkSqlSerializer
+import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode
+import org.apache.spark.sql.execution.{SparkPlan, SparkSqlSerializer}
import org.apache.spark.sql.execution.local.LocalNode
-import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
-import org.apache.spark.unsafe.memory.MemoryLocation
import org.apache.spark.util.{KnownSizeEstimation, SizeEstimator, Utils}
import org.apache.spark.util.collection.CompactBuffer
@@ -675,3 +674,20 @@ private[joins] object LongHashedRelation {
}
}
}
+
+/** The HashedRelationBroadcastMode requires that rows are broadcasted as a HashedRelation. */
+private[execution] case class HashedRelationBroadcastMode(
+ canJoinKeyFitWithinLong: Boolean,
+ keys: Seq[Expression],
+ attributes: Seq[Attribute]) extends BroadcastMode {
+
+ def transform(rows: Array[InternalRow]): HashedRelation = {
+ val generator = UnsafeProjection.create(keys, attributes)
+ if (canJoinKeyFitWithinLong) {
+ LongHashedRelation(rows.iterator, generator, rows.length)
+ } else {
+ HashedRelation(rows.iterator, generator, rows.length)
+ }
+ }
+}
+
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
index ce758d63b3..df6dac8818 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinBNL.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.joins
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.Partitioning
-import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
+import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.metric.SQLMetrics
/**
@@ -29,9 +29,7 @@ import org.apache.spark.sql.execution.metric.SQLMetrics
* for hash join.
*/
case class LeftSemiJoinBNL(
- streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
- extends BinaryNode {
- // TODO: Override requiredChildDistribution.
+ streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression]) extends BinaryNode {
override private[sql] lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
@@ -46,27 +44,28 @@ case class LeftSemiJoinBNL(
/** The Broadcast relation */
override def right: SparkPlan = broadcast
+ override def requiredChildDistribution: Seq[Distribution] = {
+ UnspecifiedDistribution :: BroadcastDistribution(IdentityBroadcastMode) :: Nil
+ }
+
@transient private lazy val boundCondition =
newPredicate(condition.getOrElse(Literal(true)), left.output ++ right.output)
protected override def doExecute(): RDD[InternalRow] = {
val numOutputRows = longMetric("numOutputRows")
- val broadcastedRelation =
- sparkContext.broadcast(broadcast.execute().map { row =>
- row.copy()
- }.collect().toIndexedSeq)
+ val broadcastedRelation = broadcast.executeBroadcast[Array[InternalRow]]()
streamed.execute().mapPartitions { streamedIter =>
val joinedRow = new JoinedRow
+ val relation = broadcastedRelation.value
streamedIter.filter(streamedRow => {
var i = 0
var matched = false
- while (i < broadcastedRelation.value.size && !matched) {
- val broadcastedRow = broadcastedRelation.value(i)
- if (boundCondition(joinedRow(streamedRow, broadcastedRow))) {
+ while (i < relation.length && !matched) {
+ if (boundCondition(joinedRow(streamedRow, relation(i)))) {
matched = true
}
i += 1
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index ef76847bcb..cd543d4195 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.LazilyGeneratedOrdering
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
/**
@@ -38,7 +39,8 @@ case class CollectLimit(limit: Int, child: SparkPlan) extends UnaryNode {
private val serializer: Serializer = new UnsafeRowSerializer(child.output.size)
protected override def doExecute(): RDD[InternalRow] = {
val shuffled = new ShuffledRowRDD(
- Exchange.prepareShuffleDependency(child.execute(), child.output, SinglePartition, serializer))
+ ShuffleExchange.prepareShuffleDependency(
+ child.execute(), child.output, SinglePartition, serializer))
shuffled.mapPartitionsInternal(_.take(limit))
}
}
@@ -110,7 +112,8 @@ case class TakeOrderedAndProject(
}
}
val shuffled = new ShuffledRowRDD(
- Exchange.prepareShuffleDependency(localTopK, child.output, SinglePartition, serializer))
+ ShuffleExchange.prepareShuffleDependency(
+ localTopK, child.output, SinglePartition, serializer))
shuffled.mapPartitions { iter =>
val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
if (projectList.isDefined) {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index e8d0678989..83d7953aaf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -23,9 +23,9 @@ import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
import org.apache.spark.Accumulators
-import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.PhysicalRDD
import org.apache.spark.sql.execution.columnar._
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
@@ -357,7 +357,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext
* Verifies that the plan for `df` contains `expected` number of Exchange operators.
*/
private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = {
- assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected)
+ assert(df.queryExecution.executedPlan.collect { case e: ShuffleExchange => e }.size == expected)
}
test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 99ba2e2061..50a246489e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -26,8 +26,8 @@ import org.scalatest.Matchers._
import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Union}
-import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSQLContext}
import org.apache.spark.sql.test.SQLTestData.TestData2
@@ -1119,7 +1119,7 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
}
atFirstAgg = true
}
- case e: Exchange => atFirstAgg = false
+ case e: ShuffleExchange => atFirstAgg = false
case _ =>
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
index 35ff1c40fe..b1c588a63d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala
@@ -21,6 +21,7 @@ import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{MapOutputStatistics, SparkConf, SparkContext, SparkFunSuite}
import org.apache.spark.sql._
+import org.apache.spark.sql.execution.exchange.{ExchangeCoordinator, ShuffleExchange}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.TestSQLContext
@@ -297,13 +298,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Then, let's look at the number of post-shuffle partitions estimated
// by the ExchangeCoordinator.
val exchanges = agg.queryExecution.executedPlan.collect {
- case e: Exchange => e
+ case e: ShuffleExchange => e
}
assert(exchanges.length === 1)
minNumPostShufflePartitions match {
case Some(numPartitions) =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 3)
case o =>
@@ -311,7 +312,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
case None =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 2)
case o =>
@@ -348,13 +349,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Then, let's look at the number of post-shuffle partitions estimated
// by the ExchangeCoordinator.
val exchanges = join.queryExecution.executedPlan.collect {
- case e: Exchange => e
+ case e: ShuffleExchange => e
}
assert(exchanges.length === 2)
minNumPostShufflePartitions match {
case Some(numPartitions) =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 3)
case o =>
@@ -362,7 +363,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
case None =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 2)
case o =>
@@ -404,13 +405,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Then, let's look at the number of post-shuffle partitions estimated
// by the ExchangeCoordinator.
val exchanges = join.queryExecution.executedPlan.collect {
- case e: Exchange => e
+ case e: ShuffleExchange => e
}
assert(exchanges.length === 4)
minNumPostShufflePartitions match {
case Some(numPartitions) =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 3)
case o =>
@@ -456,13 +457,13 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll {
// Then, let's look at the number of post-shuffle partitions estimated
// by the ExchangeCoordinator.
val exchanges = join.queryExecution.executedPlan.collect {
- case e: Exchange => e
+ case e: ShuffleExchange => e
}
assert(exchanges.length === 3)
minNumPostShufflePartitions match {
case Some(numPartitions) =>
exchanges.foreach {
- case e: Exchange =>
+ case e: ShuffleExchange =>
assert(e.coordinator.isDefined)
assert(e.outputPartitioning.numPartitions === 3)
case o =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 87bff3295f..d4f22de90c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.test.SharedSQLContext
class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
@@ -28,7 +29,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
val input = (1 to 1000).map(Tuple1.apply)
checkAnswer(
input.toDF(),
- plan => Exchange(SinglePartition, plan),
+ plan => ShuffleExchange(SinglePartition, plan),
input.map(Row.fromTuple)
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 250ce8f866..4de56783fa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal,
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Repartition}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.columnar.{InMemoryColumnarTableScan, InMemoryRelation}
+import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ShuffleExchange}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SharedSQLContext
@@ -212,7 +213,7 @@ class PlannerSuite extends SharedSQLContext {
| JOIN tiny ON (small.key = tiny.key)
""".stripMargin
).queryExecution.executedPlan.collect {
- case exchange: Exchange => exchange
+ case exchange: ShuffleExchange => exchange
}.length
assert(numExchanges === 5)
}
@@ -227,7 +228,7 @@ class PlannerSuite extends SharedSQLContext {
| JOIN tiny ON (normal.key = tiny.key)
""".stripMargin
).queryExecution.executedPlan.collect {
- case exchange: Exchange => exchange
+ case exchange: ShuffleExchange => exchange
}.length
assert(numExchanges === 5)
}
@@ -295,7 +296,7 @@ class PlannerSuite extends SharedSQLContext {
)
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.isEmpty) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
fail(s"Exchange should have been added:\n$outputPlan")
}
}
@@ -333,7 +334,7 @@ class PlannerSuite extends SharedSQLContext {
)
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.isEmpty) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.isEmpty) {
fail(s"Exchange should have been added:\n$outputPlan")
}
}
@@ -353,7 +354,7 @@ class PlannerSuite extends SharedSQLContext {
)
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.nonEmpty) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
fail(s"Exchange should not have been added:\n$outputPlan")
}
}
@@ -376,7 +377,7 @@ class PlannerSuite extends SharedSQLContext {
)
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.nonEmpty) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.nonEmpty) {
fail(s"No Exchanges should have been added:\n$outputPlan")
}
}
@@ -435,7 +436,7 @@ class PlannerSuite extends SharedSQLContext {
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 5)
val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
assert(!childPartitioning.satisfies(distribution))
- val inputPlan = Exchange(finalPartitioning,
+ val inputPlan = ShuffleExchange(finalPartitioning,
DummySparkPlan(
children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
requiredChildDistribution = Seq(distribution),
@@ -444,7 +445,7 @@ class PlannerSuite extends SharedSQLContext {
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.size == 2) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.size == 2) {
fail(s"Topmost Exchange should have been eliminated:\n$outputPlan")
}
}
@@ -455,7 +456,7 @@ class PlannerSuite extends SharedSQLContext {
val finalPartitioning = HashPartitioning(Literal(1) :: Nil, 8)
val childPartitioning = HashPartitioning(Literal(2) :: Nil, 5)
assert(!childPartitioning.satisfies(distribution))
- val inputPlan = Exchange(finalPartitioning,
+ val inputPlan = ShuffleExchange(finalPartitioning,
DummySparkPlan(
children = DummySparkPlan(outputPartitioning = childPartitioning) :: Nil,
requiredChildDistribution = Seq(distribution),
@@ -464,7 +465,7 @@ class PlannerSuite extends SharedSQLContext {
val outputPlan = EnsureRequirements(sqlContext).apply(inputPlan)
assertDistributionRequirementsAreSatisfied(outputPlan)
- if (outputPlan.collect { case e: Exchange => true }.size == 1) {
+ if (outputPlan.collect { case e: ShuffleExchange => true }.size == 1) {
fail(s"Topmost Exchange should not have been eliminated:\n$outputPlan")
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
index e25b5e0610..a256ee95a1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala
@@ -22,7 +22,8 @@ import scala.reflect.ClassTag
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext}
-import org.apache.spark.sql.{QueryTest, SQLConf, SQLContext}
+import org.apache.spark.sql.{QueryTest, SQLContext}
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.functions._
/**
@@ -62,7 +63,7 @@ class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll {
// Comparison at the end is for broadcast left semi join
val joinExpression = df1("key") === df2("key") && df1("value") > df2("value")
val df3 = df1.join(broadcast(df2), joinExpression, joinType)
- val plan = df3.queryExecution.sparkPlan
+ val plan = EnsureRequirements(sqlContext).apply(df3.queryExecution.sparkPlan)
assert(plan.collect { case p: T => p }.size === 1)
plan.executeCollect()
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index e22a810a6b..6dfff3770b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
@@ -88,7 +89,15 @@ class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
leftPlan: SparkPlan,
rightPlan: SparkPlan,
side: BuildSide) = {
- joins.BroadcastHashJoin(leftKeys, rightKeys, Inner, side, boundCondition, leftPlan, rightPlan)
+ val broadcastJoin = joins.BroadcastHashJoin(
+ leftKeys,
+ rightKeys,
+ Inner,
+ side,
+ boundCondition,
+ leftPlan,
+ rightPlan)
+ EnsureRequirements(sqlContext).apply(broadcastJoin)
}
def makeSortMergeJoin(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index f4b01fbad0..cd6b6fcbb1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 9c86084f9b..f3ad8409e5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.execution.exchange.EnsureRequirements
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
index 9ba645626f..a05a57c0f5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedReadSuite.scala
@@ -22,8 +22,9 @@ import java.io.File
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
-import org.apache.spark.sql.execution.{Exchange, PhysicalRDD}
+import org.apache.spark.sql.execution.PhysicalRDD
import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSourceStrategy}
+import org.apache.spark.sql.execution.exchange.ShuffleExchange
import org.apache.spark.sql.execution.joins.SortMergeJoin
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
@@ -252,8 +253,8 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
assert(joined.queryExecution.executedPlan.isInstanceOf[SortMergeJoin])
val joinOperator = joined.queryExecution.executedPlan.asInstanceOf[SortMergeJoin]
- assert(joinOperator.left.find(_.isInstanceOf[Exchange]).isDefined == shuffleLeft)
- assert(joinOperator.right.find(_.isInstanceOf[Exchange]).isDefined == shuffleRight)
+ assert(joinOperator.left.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleLeft)
+ assert(joinOperator.right.find(_.isInstanceOf[ShuffleExchange]).isDefined == shuffleRight)
}
}
}
@@ -312,7 +313,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
agged.sort("i", "j"),
df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
- assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty)
}
}
@@ -326,7 +327,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
agged.sort("i", "j"),
df1.groupBy("i", "j").agg(max("k")).sort("i", "j"))
- assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isEmpty)
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isEmpty)
}
}
@@ -339,7 +340,7 @@ class BucketedReadSuite extends QueryTest with SQLTestUtils with TestHiveSinglet
val agged = hiveContext.table("bucketed_table").groupBy("i").count()
// make sure we fall back to non-bucketing mode and can't avoid shuffle
- assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[Exchange]).isDefined)
+ assert(agged.queryExecution.executedPlan.find(_.isInstanceOf[ShuffleExchange]).isDefined)
checkAnswer(agged.sort("i"), df1.groupBy("i").count().sort("i"))
}
}