author     Yin Huai <yhuai@databricks.com>   2015-11-06 11:13:51 -0800
committer  Yin Huai <yhuai@databricks.com>   2015-11-06 11:13:51 -0800
commit     8211aab0793cf64202b99be4f31bb8a9ae77050d (patch)
tree       52472ec5354ce3ede17f0060fac3750f9f7cacf0 /sql
parent     c048929c6a9f7ce57f384037cd6c0bf5751c447a (diff)
[SPARK-9858][SQL] Add an ExchangeCoordinator to estimate the number of post-shuffle partitions for aggregates and joins (follow-up)
https://issues.apache.org/jira/browse/SPARK-9858

This PR is the follow-up work of https://github.com/apache/spark/pull/9276. It addresses JoshRosen's comments.

Author: Yin Huai <yhuai@databricks.com>

Closes #9453 from yhuai/numReducer-followUp.
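Note: the coordinator's core estimation step is to walk the pre-shuffle partitions in order and merge adjacent ones until a target input size is reached, recording where each post-shuffle partition starts. The following is a minimal standalone Scala sketch of that idea, not Spark's implementation; the names (estimateStartIndices, targetSize) and the sample sizes are illustrative.

object CoalesceSketch {
  // bytesByPartition: one array per shuffle dependency, all of the same length,
  // giving the size in bytes of each pre-shuffle (map output) partition.
  // Returns the pre-shuffle partition index at which each post-shuffle partition starts.
  def estimateStartIndices(bytesByPartition: Array[Array[Long]], targetSize: Long): Array[Int] = {
    val numPreShufflePartitions = bytesByPartition.head.length
    val startIndices = scala.collection.mutable.ArrayBuffer(0)
    var currentSize = 0L
    var i = 0
    while (i < numPreShufflePartitions) {
      // Bytes a post-shuffle partition would read for pre-shuffle partition i, across all shuffles.
      val size = bytesByPartition.map(_(i)).sum
      if (i > 0 && currentSize + size > targetSize) {
        startIndices += i   // start a new post-shuffle partition at index i
        currentSize = size
      } else {
        currentSize += size
      }
      i += 1
    }
    startIndices.toArray
  }

  def main(args: Array[String]): Unit = {
    val stats = Array(Array(10L, 5L, 40L, 1L, 1L), Array(10L, 5L, 40L, 1L, 1L))
    // With a target of 64 bytes: partitions 0-1 are merged, 2 stands alone, 3-4 are merged.
    println(estimateStartIndices(stats, targetSize = 64L).mkString(", ")) // prints: 0, 2, 3
  }
}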
Diffstat (limited to 'sql')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala |   8
 sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                       |  40
 sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala            |  31
 sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala                         | 150
4 files changed, 167 insertions, 62 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 9312c8123e..86b9417477 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -165,11 +165,6 @@ sealed trait Partitioning {
* produced by `A` could have also been produced by `B`.
*/
def guarantees(other: Partitioning): Boolean = this == other
-
- def withNumPartitions(newNumPartitions: Int): Partitioning = {
- throw new IllegalStateException(
- s"It is not allowed to call withNumPartitions method of a ${this.getClass.getSimpleName}")
- }
}
object Partitioning {
@@ -254,9 +249,6 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
- override def withNumPartitions(newNumPartitions: Int): HashPartitioning = {
- HashPartitioning(expressions, newNumPartitions)
- }
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 0f72ec6cc1..a4ce328c1a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -242,7 +242,7 @@ case class Exchange(
// update the number of post-shuffle partitions.
specifiedPartitionStartIndices.foreach { indices =>
assert(newPartitioning.isInstanceOf[HashPartitioning])
- newPartitioning = newPartitioning.withNumPartitions(indices.length)
+ newPartitioning = UnknownPartitioning(indices.length)
}
new ShuffledRowRDD(shuffleDependency, specifiedPartitionStartIndices)
}
@@ -262,7 +262,7 @@ case class Exchange(
object Exchange {
def apply(newPartitioning: Partitioning, child: SparkPlan): Exchange = {
- Exchange(newPartitioning, child, None: Option[ExchangeCoordinator])
+ Exchange(newPartitioning, child, coordinator = None: Option[ExchangeCoordinator])
}
}
@@ -315,7 +315,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
child.outputPartitioning match {
case hash: HashPartitioning => true
case collection: PartitioningCollection =>
- collection.partitionings.exists(_.isInstanceOf[HashPartitioning])
+ collection.partitionings.forall(_.isInstanceOf[HashPartitioning])
case _ => false
}
}
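Note on the exists -> forall change above: with exists, a PartitioningCollection containing a single HashPartitioning alongside non-hash partitionings would pass the check. A tiny standalone illustration with toy types (not Spark's Partitioning classes):

object ForallVsExists {
  sealed trait Partitioning
  case class Hash(numPartitions: Int) extends Partitioning
  case class RangeBased(numPartitions: Int) extends Partitioning

  def main(args: Array[String]): Unit = {
    val mixed: Seq[Partitioning] = Seq(Hash(8), RangeBased(8))
    println(mixed.exists(_.isInstanceOf[Hash]))  // true  -- old check, overly permissive
    println(mixed.forall(_.isInstanceOf[Hash]))  // false -- new check, requires all hash-based
  }
}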
@@ -416,28 +416,48 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
// First check if the existing partitions of the children all match. This means they are
// partitioned by the same partitioning into the same number of partitions. In that case,
// don't try to make them match `defaultPartitions`, just use the existing partitioning.
- // TODO: this should be a cost based decision. For example, a big relation should probably
- // maintain its existing number of partitions and smaller partitions should be shuffled.
- // defaultPartitions is arbitrary.
- val numPartitions = children.head.outputPartitioning.numPartitions
+ val maxChildrenNumPartitions = children.map(_.outputPartitioning.numPartitions).max
val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
case (child, distribution) => {
child.outputPartitioning.guarantees(
- createPartitioning(distribution, numPartitions))
+ createPartitioning(distribution, maxChildrenNumPartitions))
}
}
children = if (useExistingPartitioning) {
+ // We do not need to shuffle any child's output.
children
} else {
+ // We need to shuffle at least one child's output.
+ // Now, we will determine the number of partitions that will be used by the newly
+ // created partitioning schemes.
+ val numPartitions = {
+ // Let's see whether we need to shuffle all children's outputs when we use
+ // maxChildrenNumPartitions.
+ val shufflesAllChildren = children.zip(requiredChildDistributions).forall {
+ case (child, distribution) => {
+ !child.outputPartitioning.guarantees(
+ createPartitioning(distribution, maxChildrenNumPartitions))
+ }
+ }
+ // If we need to shuffle all children, we use defaultNumPreShufflePartitions as the
+ // number of partitions. Otherwise, we use maxChildrenNumPartitions.
+ if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
+ }
+
children.zip(requiredChildDistributions).map {
case (child, distribution) => {
val targetPartitioning =
- createPartitioning(distribution, defaultNumPreShufflePartitions)
+ createPartitioning(distribution, numPartitions)
if (child.outputPartitioning.guarantees(targetPartitioning)) {
child
} else {
- Exchange(targetPartitioning, child)
+ child match {
+ // If the child is an Exchange, we replace it with
+ // a new one having targetPartitioning.
+ case Exchange(_, c, _) => Exchange(targetPartitioning, c)
+ case _ => Exchange(targetPartitioning, child)
+ }
}
}
}
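Note: the block above decides how many partitions the newly created partitioning schemes should use. A hedged standalone sketch of that rule, with illustrative names and numbers (this is not the planner's API):

object NumPartitionsRule {
  // satisfiedAt(n) answers, per child: "does this child's existing partitioning already
  // guarantee the required distribution if the target partitioning used n partitions?"
  def choose(
      childNumPartitions: Seq[Int],
      satisfiedAt: Int => Seq[Boolean],
      defaultNumPreShufflePartitions: Int): Int = {
    val maxChildrenNumPartitions = childNumPartitions.max
    // If no child is usable at maxChildrenNumPartitions, every child gets shuffled anyway,
    // so fall back to the configured default; otherwise reuse the largest existing count.
    val shufflesAllChildren = satisfiedAt(maxChildrenNumPartitions).forall(satisfied => !satisfied)
    if (shufflesAllChildren) defaultNumPreShufflePartitions else maxChildrenNumPartitions
  }

  def main(args: Array[String]): Unit = {
    // Children partitioned into 6 and 3 partitions; the 6-partition side already matches the
    // join key, so 6 is reused and only the other side is shuffled.
    println(choose(Seq(6, 3), n => Seq(n == 6, false), defaultNumPreShufflePartitions = 200)) // 6
    // Neither child is usable (e.g. both partitioned on non-join columns), so every child is
    // shuffled into the default number of partitions.
    println(choose(Seq(6, 3), _ => Seq(false, false), defaultNumPreShufflePartitions = 200)) // 200
  }
}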
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
index 8dbd69e1f4..827fdd2784 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExchangeCoordinator.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution
import java.util.{Map => JMap, HashMap => JHashMap}
+import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ArrayBuffer
@@ -97,6 +98,7 @@ private[sql] class ExchangeCoordinator(
* Registers an [[Exchange]] operator to this coordinator. This method is only allowed to be
* called in the `doPrepare` method of an [[Exchange]] operator.
*/
+ @GuardedBy("this")
def registerExchange(exchange: Exchange): Unit = synchronized {
exchanges += exchange
}
@@ -109,7 +111,7 @@ private[sql] class ExchangeCoordinator(
*/
private[sql] def estimatePartitionStartIndices(
mapOutputStatistics: Array[MapOutputStatistics]): Array[Int] = {
- // If we have mapOutputStatistics.length <= numExchange, it is because we do not submit
+ // If we have mapOutputStatistics.length < numExchanges, it is because we do not submit
// a stage when the number of partitions of this dependency is 0.
assert(mapOutputStatistics.length <= numExchanges)
@@ -121,6 +123,8 @@ private[sql] class ExchangeCoordinator(
val totalPostShuffleInputSize = mapOutputStatistics.map(_.bytesByPartitionId.sum).sum
// The max at here is to make sure that when we have an empty table, we
// only have a single post-shuffle partition.
+ // There is no particular reason that we pick 16. We just need a number to
+ // prevent maxPostShuffleInputSize from being set to 0.
val maxPostShuffleInputSize =
math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16)
math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize)
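Note: a short worked example of the target-size formula above, run with made-up numbers (the 64 MB advisory value and the 5-partition request are assumptions for illustration only):

object TargetSizeExample {
  def main(args: Array[String]): Unit = {
    val advisoryTargetPostShuffleInputSize = 64L * 1024 * 1024  // assumed advisory target: 64 MB
    val numPartitions = 5                                       // assumed requested number of partitions
    val totalPostShuffleInputSize = 100L * 1024 * 1024          // assumed 100 MB of total shuffle data

    // ceil(100 MB / 5) = 20 MB; the max(..., 16) guard only matters for empty inputs (size 0).
    val maxPostShuffleInputSize =
      math.max(math.ceil(totalPostShuffleInputSize / numPartitions.toDouble).toLong, 16)

    // min(20 MB, 64 MB) = 20 MB per post-shuffle partition, which yields at least 5 partitions.
    val targetSize = math.min(maxPostShuffleInputSize, advisoryTargetPostShuffleInputSize)
    println(targetSize) // 20971520 bytes, i.e. 20 MB
  }
}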
@@ -135,6 +139,12 @@ private[sql] class ExchangeCoordinator(
// Make sure we do get the same number of pre-shuffle partitions for those stages.
val distinctNumPreShufflePartitions =
mapOutputStatistics.map(stats => stats.bytesByPartitionId.length).distinct
+ // The reason we expect a single value for the number of pre-shuffle partitions
+ // is that when we add Exchanges, we set the number of pre-shuffle partitions
+ // (i.e. map output partitions) using a static setting, which is the value of
+ // spark.sql.shuffle.partitions. Even if two input RDDs have different
+ // numbers of partitions, they will have the same number of pre-shuffle partitions
+ // (i.e. map output partitions).
assert(
distinctNumPreShufflePartitions.length == 1,
"There should be only one distinct value of the number pre-shuffle partitions " +
@@ -177,6 +187,7 @@ private[sql] class ExchangeCoordinator(
partitionStartIndices.toArray
}
+ @GuardedBy("this")
private def doEstimationIfNecessary(): Unit = synchronized {
// It is unlikely that this method will be called from multiple threads
// (when multiple threads trigger the execution of THIS physical)
@@ -209,11 +220,11 @@ private[sql] class ExchangeCoordinator(
// Wait for the finishes of those submitted map stages.
val mapOutputStatistics = new Array[MapOutputStatistics](submittedStageFutures.length)
- i = 0
- while (i < submittedStageFutures.length) {
+ var j = 0
+ while (j < submittedStageFutures.length) {
// This call is a blocking call. If the stage has not finished, we will wait at here.
- mapOutputStatistics(i) = submittedStageFutures(i).get()
- i += 1
+ mapOutputStatistics(j) = submittedStageFutures(j).get()
+ j += 1
}
// Now, we estimate partitionStartIndices. partitionStartIndices.length will be the
@@ -225,14 +236,14 @@ private[sql] class ExchangeCoordinator(
Some(estimatePartitionStartIndices(mapOutputStatistics))
}
- i = 0
- while (i < numExchanges) {
- val exchange = exchanges(i)
+ var k = 0
+ while (k < numExchanges) {
+ val exchange = exchanges(k)
val rdd =
- exchange.preparePostShuffleRDD(shuffleDependencies(i), partitionStartIndices)
+ exchange.preparePostShuffleRDD(shuffleDependencies(k), partitionStartIndices)
newPostShuffleRDDs.put(exchange, rdd)
- i += 1
+ k += 1
}
// Finally, we set postShuffleRDDs and estimated.
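Note: the test changes below rely on a verifyNumExchanges helper that is defined elsewhere in CachedTableSuite and is not part of this diff. A hedged sketch of what such a check plausibly looks like (the trait, method name, and message here are illustrative, not the suite's actual code):

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.execution.Exchange

trait ExchangeCountCheck {
  // Sketch only: count the Exchange operators left in the executed plan and compare
  // against the expected number of shuffles.
  protected def verifyNumExchangesSketch(df: DataFrame, expected: Int): Unit = {
    val numExchanges = df.queryExecution.executedPlan.collect { case e: Exchange => e }.length
    assert(numExchanges == expected,
      s"expected $expected Exchange operator(s) but found $numExchanges")
  }
}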
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index dbcb011f60..bce94dafad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -29,12 +29,12 @@ import org.scalatest.concurrent.Eventually._
import org.apache.spark.Accumulators
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.{SQLTestUtils, SharedSQLContext}
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
private case class BigData(s: String)
-class CachedTableSuite extends QueryTest with SharedSQLContext {
+class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext {
import testImplicits._
def rddIdOf(tableName: String): Int = {
@@ -375,53 +375,135 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"),
sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect())
sqlContext.uncacheTable("orderedTable")
+ sqlContext.dropTempTable("orderedTable")
// Set up two tables distributed in the same way. Try this with the data distributed into
// different number of partitions.
for (numPartitions <- 1 until 10 by 4) {
- testData.repartition(numPartitions, $"key").registerTempTable("t1")
- testData2.repartition(numPartitions, $"a").registerTempTable("t2")
+ withTempTable("t1", "t2") {
+ testData.repartition(numPartitions, $"key").registerTempTable("t1")
+ testData2.repartition(numPartitions, $"a").registerTempTable("t2")
+ sqlContext.cacheTable("t1")
+ sqlContext.cacheTable("t2")
+
+ // Joining them should result in no exchanges.
+ verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0)
+ checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
+ sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))
+
+ // Grouping on the partition key should result in no exchanges
+ verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
+ checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
+ sql("SELECT count(*) FROM testData GROUP BY key"))
+
+ sqlContext.uncacheTable("t1")
+ sqlContext.uncacheTable("t2")
+ }
+ }
+
+ // Distribute the tables into non-matching numbers of partitions. Need to shuffle one side.
+ withTempTable("t1", "t2") {
+ testData.repartition(6, $"key").registerTempTable("t1")
+ testData2.repartition(3, $"a").registerTempTable("t2")
sqlContext.cacheTable("t1")
sqlContext.cacheTable("t2")
- // Joining them should result in no exchanges.
- verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0)
- checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
- sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))
+ val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
+ verifyNumExchanges(query, 1)
+ assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6)
+ checkAnswer(
+ query,
+ testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
+ sqlContext.uncacheTable("t1")
+ sqlContext.uncacheTable("t2")
+ }
- // Grouping on the partition key should result in no exchanges
- verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
- checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
- sql("SELECT count(*) FROM testData GROUP BY key"))
+ // One side of the join is not partitioned in the desired way. Need to shuffle one side.
+ withTempTable("t1", "t2") {
+ testData.repartition(6, $"value").registerTempTable("t1")
+ testData2.repartition(6, $"a").registerTempTable("t2")
+ sqlContext.cacheTable("t1")
+ sqlContext.cacheTable("t2")
+ val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
+ verifyNumExchanges(query, 1)
+ assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6)
+ checkAnswer(
+ query,
+ testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
sqlContext.uncacheTable("t1")
sqlContext.uncacheTable("t2")
- sqlContext.dropTempTable("t1")
- sqlContext.dropTempTable("t2")
}
- // Distribute the tables into non-matching number of partitions. Need to shuffle.
- testData.repartition(6, $"key").registerTempTable("t1")
- testData2.repartition(3, $"a").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ withTempTable("t1", "t2") {
+ testData.repartition(6, $"value").registerTempTable("t1")
+ testData2.repartition(12, $"a").registerTempTable("t2")
+ sqlContext.cacheTable("t1")
+ sqlContext.cacheTable("t2")
- verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2)
- sqlContext.uncacheTable("t1")
- sqlContext.uncacheTable("t2")
- sqlContext.dropTempTable("t1")
- sqlContext.dropTempTable("t2")
+ val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
+ verifyNumExchanges(query, 1)
+ assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 12)
+ checkAnswer(
+ query,
+ testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
+ sqlContext.uncacheTable("t1")
+ sqlContext.uncacheTable("t2")
+ }
- // One side of join is not partitioned in the desired way. Need to shuffle.
- testData.repartition(6, $"value").registerTempTable("t1")
- testData2.repartition(6, $"a").registerTempTable("t2")
- sqlContext.cacheTable("t1")
- sqlContext.cacheTable("t2")
+ // One side of the join is not partitioned in the desired way. Since the side that is
+ // already partitioned has fewer partitions than the side that is not partitioned,
+ // we shuffle both sides.
+ withTempTable("t1", "t2") {
+ testData.repartition(6, $"value").registerTempTable("t1")
+ testData2.repartition(3, $"a").registerTempTable("t2")
+ sqlContext.cacheTable("t1")
+ sqlContext.cacheTable("t2")
- verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2)
- sqlContext.uncacheTable("t1")
- sqlContext.uncacheTable("t2")
- sqlContext.dropTempTable("t1")
- sqlContext.dropTempTable("t2")
+ val query = sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a")
+ verifyNumExchanges(query, 2)
+ checkAnswer(
+ query,
+ testData.join(testData2, $"key" === $"a").select($"key", $"value", $"a", $"b"))
+ sqlContext.uncacheTable("t1")
+ sqlContext.uncacheTable("t2")
+ }
+
+ // repartition's column ordering is different from the GROUP BY column ordering,
+ // but they use the same set of columns.
+ withTempTable("t1") {
+ testData.repartition(6, $"value", $"key").registerTempTable("t1")
+ sqlContext.cacheTable("t1")
+
+ val query = sql("SELECT value, key from t1 group by key, value")
+ verifyNumExchanges(query, 0)
+ checkAnswer(
+ query,
+ testData.distinct().select($"value", $"key"))
+ sqlContext.uncacheTable("t1")
+ }
+
+ // repartition's column ordering is different from the join condition's column ordering.
+ // We will still shuffle because the hash code of a row depends on the column ordering.
+ // If we do not shuffle, we may partition the two tables in two totally different ways.
+ // See PartitioningSuite for more details.
+ withTempTable("t1", "t2") {
+ val df1 = testData
+ df1.repartition(6, $"value", $"key").registerTempTable("t1")
+ val df2 = testData2.select($"a", $"b".cast("string"))
+ df2.repartition(6, $"a", $"b").registerTempTable("t2")
+ sqlContext.cacheTable("t1")
+ sqlContext.cacheTable("t2")
+
+ val query =
+ sql("SELECT key, value, a, b FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a and t1.value = t2.b")
+ verifyNumExchanges(query, 1)
+ assert(query.queryExecution.executedPlan.outputPartitioning.numPartitions === 6)
+ checkAnswer(
+ query,
+ df1.join(df2, $"key" === $"a" && $"value" === $"b").select($"key", $"value", $"a", $"b"))
+ sqlContext.uncacheTable("t1")
+ sqlContext.uncacheTable("t2")
+ }
}
}
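Note on the comment in the last test above (hash codes depend on column ordering): hashing the same values in a different order produces a different hash, so the same logical row can be routed to different partitions. A standalone illustration using Scala's MurmurHash3 (Spark's HashPartitioning uses its own hash expressions, so this only demonstrates the principle):

import scala.util.hashing.MurmurHash3

object ColumnOrderingMatters {
  def main(args: Array[String]): Unit = {
    val key = 42
    val value = "forty-two"
    val numPartitions = 6

    def partitionOf(hash: Int): Int = ((hash % numPartitions) + numPartitions) % numPartitions

    // Hash the same two column values in the two possible orders.
    val hashKeyValue = MurmurHash3.orderedHash(Seq(key, value))
    val hashValueKey = MurmurHash3.orderedHash(Seq(value, key))

    // The resulting partitions usually differ, which is why the planner still adds a shuffle
    // when the repartitioning columns match the join keys only as a set, not in order.
    println(s"partition by (key, value): ${partitionOf(hashKeyValue)}")
    println(s"partition by (value, key): ${partitionOf(hashValueKey)}")
  }
}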