From 2cef1bb0b560a03aa7308f694b0c66347b90c9ea Mon Sep 17 00:00:00 2001
From: Nong Li
Date: Mon, 2 Nov 2015 19:18:45 -0800
Subject: [SPARK-5354][SQL] Cached tables should preserve partitioning and ord…
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

…ering.

For cached tables, we can just maintain the partitioning and ordering from
the source relation.

Author: Nong Li

Closes #9404 from nongli/spark-5354.
---
 .../sql/columnar/InMemoryColumnarTableScan.scala   |  7 +++
 .../org/apache/spark/sql/execution/Exchange.scala  | 40 +++++++++++----
 .../org/apache/spark/sql/CachedTableSuite.scala    | 59 ++++++++++++++++++++++
 3 files changed, 97 insertions(+), 9 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index b4607b12fc..7eb1ad7cd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
 import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan}
 import org.apache.spark.sql.types.UserDefinedType
 import org.apache.spark.storage.StorageLevel
@@ -209,6 +210,12 @@ private[sql] case class InMemoryColumnarTableScan(
 
   override def output: Seq[Attribute] = attributes
 
+  // The cached version does not change the outputPartitioning of the original SparkPlan.
+  override def outputPartitioning: Partitioning = relation.child.outputPartitioning
+
+  // The cached version does not change the outputOrdering of the original SparkPlan.
+  override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering
+
   override def outputsUnsafeRows: Boolean = true
 
   private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
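
The two overrides above are the heart of the change: the cached scan now reports the partitioning and ordering of the plan it wraps instead of the default (unknown) partitioning, so the planner can see that a cached, pre-partitioned relation already satisfies a required distribution. The following stand-alone Scala sketch models that interaction; the object and member names are made up for illustration, and the types are simplified stand-ins rather than the real Catalyst classes.

  // Minimal, self-contained sketch of why forwarding the child's partitioning matters.
  object PartitioningPropagationSketch {
    // Illustrative stand-ins for the Catalyst partitioning types; not the real classes.
    sealed trait Partitioning
    case class HashPartitioning(columns: Seq[String], numPartitions: Int) extends Partitioning
    case object UnknownPartitioning extends Partitioning

    // Every physical operator reports how its output rows are partitioned.
    trait PhysicalOp { def outputPartitioning: Partitioning }

    // A scan over data that is already hash-partitioned by some column.
    case class Scan(outputPartitioning: Partitioning) extends PhysicalOp

    // Cached scan before the patch: the partitioning information is dropped.
    case class CachedScanBefore(child: PhysicalOp) extends PhysicalOp {
      def outputPartitioning: Partitioning = UnknownPartitioning
    }

    // Cached scan after the patch: the child's partitioning is forwarded unchanged,
    // mirroring the `relation.child.outputPartitioning` override in the hunk above.
    case class CachedScanAfter(child: PhysicalOp) extends PhysicalOp {
      def outputPartitioning: Partitioning = child.outputPartitioning
    }

    // Toy version of the planner's check: shuffle unless the operator is already
    // hash-partitioned on the required clustering columns.
    def needsExchange(op: PhysicalOp, requiredClustering: Seq[String]): Boolean =
      op.outputPartitioning match {
        case HashPartitioning(cols, _) => cols != requiredClustering
        case _                         => true
      }

    def main(args: Array[String]): Unit = {
      val scan = Scan(HashPartitioning(Seq("key"), numPartitions = 8))
      println(needsExchange(CachedScanBefore(scan), Seq("key"))) // true: an extra shuffle is planned
      println(needsExchange(CachedScanAfter(scan), Seq("key")))  // false: the shuffle is avoided
    }
  }

The real decision is made by EnsureRequirements in Exchange.scala, which this patch also updates below.
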
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 7f60c8f5ea..e81108b788 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -194,12 +194,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
  */
 private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
   // TODO: Determine the number of partitions.
-  private def numPartitions: Int = sqlContext.conf.numShufflePartitions
+  private def defaultPartitions: Int = sqlContext.conf.numShufflePartitions
 
   /**
    * Given a required distribution, returns a partitioning that satisfies that distribution.
    */
-  private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = {
+  private def createPartitioning(requiredDistribution: Distribution,
+      numPartitions: Int): Partitioning = {
     requiredDistribution match {
       case AllTuples => SinglePartition
       case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
@@ -220,7 +221,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
       if (child.outputPartitioning.satisfies(distribution)) {
         child
       } else {
-        Exchange(canonicalPartitioning(distribution), child)
+        Exchange(createPartitioning(distribution, defaultPartitions), child)
       }
     }
 
@@ -229,12 +230,33 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
     if (children.length > 1
         && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
         && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
-      children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
-        val targetPartitioning = canonicalPartitioning(distribution)
-        if (child.outputPartitioning.guarantees(targetPartitioning)) {
-          child
-        } else {
-          Exchange(targetPartitioning, child)
+
+      // First check if the existing partitions of the children all match. This means they are
+      // partitioned by the same partitioning into the same number of partitions. In that case,
+      // don't try to make them match `defaultPartitions`; just use the existing partitioning.
+      // TODO: this should be a cost-based decision. For example, a big relation should probably
+      // maintain its existing number of partitions and smaller partitions should be shuffled.
+      // defaultPartitions is arbitrary.
+      val numPartitions = children.head.outputPartitioning.numPartitions
+      val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
+        case (child, distribution) => {
+          child.outputPartitioning.guarantees(
+            createPartitioning(distribution, numPartitions))
+        }
+      }
+
+      children = if (useExistingPartitioning) {
+        children
+      } else {
+        children.zip(requiredChildDistributions).map {
+          case (child, distribution) => {
+            val targetPartitioning = createPartitioning(distribution, defaultPartitions)
+            if (child.outputPartitioning.guarantees(targetPartitioning)) {
+              child
+            } else {
+              Exchange(targetPartitioning, child)
+            }
+          }
         }
       }
     }
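
With `createPartitioning` now parameterized by a partition count, EnsureRequirements first checks whether every child already guarantees a partitioning built from the children's existing partition count; only when that check fails does it shuffle to `defaultPartitions`. The sketch below isolates that decision with simplified stand-in types; the names are made up for illustration, and `guarantees` is modeled as plain equality rather than the real Catalyst semantics.

  // Minimal sketch of the "use existing partitioning" decision added above.
  object ExistingPartitioningSketch {
    // Illustrative stand-ins for Catalyst's Partitioning/Distribution; not the real classes.
    case class HashPartitioning(columns: Seq[String], numPartitions: Int) {
      // Toy model: a partitioning guarantees another iff they are identical.
      def guarantees(other: HashPartitioning): Boolean = this == other
    }
    case class ClusteredDistribution(clustering: Seq[String])

    // Mirrors createPartitioning(distribution, numPartitions) for the clustered case.
    def createPartitioning(required: ClusteredDistribution, numPartitions: Int): HashPartitioning =
      HashPartitioning(required.clustering, numPartitions)

    // True if every child can keep its current partitioning, i.e. no shuffle is needed.
    def useExistingPartitioning(
        children: Seq[HashPartitioning],
        required: Seq[ClusteredDistribution]): Boolean = {
      val numPartitions = children.head.numPartitions
      children.zip(required).forall { case (child, distribution) =>
        child.guarantees(createPartitioning(distribution, numPartitions))
      }
    }

    def main(args: Array[String]): Unit = {
      val joinRequirements = Seq(ClusteredDistribution(Seq("key")), ClusteredDistribution(Seq("a")))
      // Both sides hash-partitioned on their join key into 6 partitions: keep them as-is.
      println(useExistingPartitioning(
        Seq(HashPartitioning(Seq("key"), 6), HashPartitioning(Seq("a"), 6)),
        joinRequirements)) // true
      // Partition counts disagree (6 vs. 3): fall back to shuffling to the default count.
      println(useExistingPartitioning(
        Seq(HashPartitioning(Seq("key"), 6), HashPartitioning(Seq("a"), 3)),
        joinRequirements)) // false
    }
  }

The mismatched case is exactly what the new test below exercises: two cached tables distributed into 6 and 3 partitions produce a plan with two Exchange operators, while matching distributions produce none.
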
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index fd566c8276..605954b105 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql
 
 import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.execution.Exchange
 import org.apache.spark.sql.execution.PhysicalRDD
 
 import scala.concurrent.duration._
@@ -353,4 +354,62 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
     assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3)
     assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0)
   }
+
+  /**
+   * Verifies that the plan for `df` contains `expected` number of Exchange operators.
+   */
+  private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = {
+    assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected)
+  }
+
+  test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
+    val table3x = testData.unionAll(testData).unionAll(testData)
+    table3x.registerTempTable("testData3x")
+
+    sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable")
+    sqlContext.cacheTable("orderedTable")
+    assertCached(sqlContext.table("orderedTable"))
+    // Should not have an exchange as the query is already sorted on the group by key.
+    verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0)
+    checkAnswer(
+      sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"),
+      sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect())
+    sqlContext.uncacheTable("orderedTable")
+
+    // Set up two tables distributed in the same way. Try this with the data distributed into
+    // different numbers of partitions.
+    for (numPartitions <- 1 until 10 by 4) {
+      testData.distributeBy(Column("key") :: Nil, numPartitions).registerTempTable("t1")
+      testData2.distributeBy(Column("a") :: Nil, numPartitions).registerTempTable("t2")
+      sqlContext.cacheTable("t1")
+      sqlContext.cacheTable("t2")
+
+      // Joining them should result in no exchanges.
+      verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0)
+      checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
+        sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))
+
+      // Grouping on the partition key should result in no exchanges.
+      verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
+      checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
+        sql("SELECT count(*) FROM testData GROUP BY key"))
+
+      sqlContext.uncacheTable("t1")
+      sqlContext.uncacheTable("t2")
+      sqlContext.dropTempTable("t1")
+      sqlContext.dropTempTable("t2")
+    }
+
+    // Distribute the tables into non-matching numbers of partitions. Need to shuffle.
+    testData.distributeBy(Column("key") :: Nil, 6).registerTempTable("t1")
+    testData2.distributeBy(Column("a") :: Nil, 3).registerTempTable("t2")
+    sqlContext.cacheTable("t1")
+    sqlContext.cacheTable("t2")
+
+    verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2)
+    sqlContext.uncacheTable("t1")
+    sqlContext.uncacheTable("t2")
+    sqlContext.dropTempTable("t1")
+    sqlContext.dropTempTable("t2")
+  }
 }
-- 
cgit v1.2.3