author     Nong Li <nongli@gmail.com>        2015-11-02 19:18:45 -0800
committer  Yin Huai <yhuai@databricks.com>   2015-11-02 19:18:45 -0800
commit     2cef1bb0b560a03aa7308f694b0c66347b90c9ea (patch)
tree       32c9ecc3aeedc9cf3c02e6be920c768e861080e3 /sql
parent     21ad846238a9a79564e2e99a1def89fd31a0870d (diff)
[SPARK-5354][SQL] Cached tables should preserve partitioning and ordering.

For cached tables, we can just maintain the partitioning and ordering from the source relation.

Author: Nong Li <nongli@gmail.com>

Closes #9404 from nongli/spark-5354.
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                 | 40
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala                   | 59
3 files changed, 97 insertions(+), 9 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
index b4607b12fc..7eb1ad7cd8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/columnar/InMemoryColumnarTableScan.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{ConvertToUnsafe, LeafNode, SparkPlan}
import org.apache.spark.sql.types.UserDefinedType
import org.apache.spark.storage.StorageLevel
@@ -209,6 +210,12 @@ private[sql] case class InMemoryColumnarTableScan(

  override def output: Seq[Attribute] = attributes

+  // The cached version does not change the outputPartitioning of the original SparkPlan.
+  override def outputPartitioning: Partitioning = relation.child.outputPartitioning
+
+  // The cached version does not change the outputOrdering of the original SparkPlan.
+  override def outputOrdering: Seq[SortOrder] = relation.child.outputOrdering
+
  override def outputsUnsafeRows: Boolean = true

  private def statsFor(a: Attribute) = relation.partitionStatistics.forAttribute(a)
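
The two overrides above are the whole of the scan-side change. As a rough illustration (a minimal, self-contained Scala sketch with made-up stand-in types such as PlanNode, SourcePlan and CachedScan, not the real SparkPlan or Catalyst classes), the point is that a scan over cached data returns rows exactly as the cached child plan produced them, so it can simply forward that plan's partitioning and ordering instead of reporting nothing:

// Minimal sketch with hypothetical stand-in types; only the delegation pattern mirrors the patch.
sealed trait Partitioning
case object UnknownPartitioning extends Partitioning
final case class HashPartitioning(columns: Seq[String], numPartitions: Int) extends Partitioning

final case class SortOrder(column: String, ascending: Boolean = true)

trait PlanNode {
  // By default a node says nothing about how its output is laid out.
  def outputPartitioning: Partitioning = UnknownPartitioning
  def outputOrdering: Seq[SortOrder] = Nil
}

// The plan whose result was cached, e.g. a sort that is hash-partitioned and ordered on `key`.
final case class SourcePlan(partitioning: Partitioning, ordering: Seq[SortOrder]) extends PlanNode {
  override def outputPartitioning: Partitioning = partitioning
  override def outputOrdering: Seq[SortOrder] = ordering
}

// Scanning the cached data neither moves rows between partitions nor reorders them,
// so both properties can simply be forwarded from the child, as the overrides above do.
final case class CachedScan(child: PlanNode) extends PlanNode {
  override def outputPartitioning: Partitioning = child.outputPartitioning
  override def outputOrdering: Seq[SortOrder] = child.outputOrdering
}

// A source plan hash-partitioned and sorted on "key": the cached scan reports the same layout.
val cached = CachedScan(SourcePlan(HashPartitioning(Seq("key"), numPartitions = 8), Seq(SortOrder("key"))))
assert(cached.outputPartitioning == HashPartitioning(Seq("key"), 8))
assert(cached.outputOrdering == Seq(SortOrder("key")))

With the cached scan reporting these properties, a planner rule that inspects a child's outputPartitioning, such as EnsureRequirements in the next file, can avoid inserting an Exchange over cached data.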
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 7f60c8f5ea..e81108b788 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -194,12 +194,13 @@ case class Exchange(newPartitioning: Partitioning, child: SparkPlan) extends Una
 */
private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[SparkPlan] {
  // TODO: Determine the number of partitions.
-  private def numPartitions: Int = sqlContext.conf.numShufflePartitions
+  private def defaultPartitions: Int = sqlContext.conf.numShufflePartitions

  /**
   * Given a required distribution, returns a partitioning that satisfies that distribution.
   */
-  private def canonicalPartitioning(requiredDistribution: Distribution): Partitioning = {
+  private def createPartitioning(requiredDistribution: Distribution,
+      numPartitions: Int): Partitioning = {
    requiredDistribution match {
      case AllTuples => SinglePartition
      case ClusteredDistribution(clustering) => HashPartitioning(clustering, numPartitions)
@@ -220,7 +221,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
        if (child.outputPartitioning.satisfies(distribution)) {
          child
        } else {
-          Exchange(canonicalPartitioning(distribution), child)
+          Exchange(createPartitioning(distribution, defaultPartitions), child)
        }
      }
@@ -229,12 +230,33 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[
    if (children.length > 1
        && requiredChildDistributions.toSet != Set(UnspecifiedDistribution)
        && !Partitioning.allCompatible(children.map(_.outputPartitioning))) {
-      children = children.zip(requiredChildDistributions).map { case (child, distribution) =>
-        val targetPartitioning = canonicalPartitioning(distribution)
-        if (child.outputPartitioning.guarantees(targetPartitioning)) {
-          child
-        } else {
-          Exchange(targetPartitioning, child)
+
+      // First check if the existing partitions of the children all match. This means they are
+      // partitioned by the same partitioning into the same number of partitions. In that case,
+      // don't try to make them match `defaultPartitions`, just use the existing partitioning.
+      // TODO: this should be a cost-based decision. For example, a big relation should probably
+      // maintain its existing number of partitions and smaller partitions should be shuffled.
+      // defaultPartitions is arbitrary.
+      val numPartitions = children.head.outputPartitioning.numPartitions
+      val useExistingPartitioning = children.zip(requiredChildDistributions).forall {
+        case (child, distribution) => {
+          child.outputPartitioning.guarantees(
+            createPartitioning(distribution, numPartitions))
+        }
+      }
+
+      children = if (useExistingPartitioning) {
+        children
+      } else {
+        children.zip(requiredChildDistributions).map {
+          case (child, distribution) => {
+            val targetPartitioning = createPartitioning(distribution, defaultPartitions)
+            if (child.outputPartitioning.guarantees(targetPartitioning)) {
+              child
+            } else {
+              Exchange(targetPartitioning, child)
+            }
+          }
        }
      }
    }
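
The hunk above is the planner-side half of the change: before falling back to a shuffle, EnsureRequirements now checks whether the children's existing partitioning already satisfies the required distributions with one common partition count. The following is a hedged, self-contained sketch of just that decision, using made-up names and types (planChildren, a plain HashPartitioning case class, and an exact-equality check standing in for the real guarantees test), not the actual rule:

// Hypothetical stand-in types and helper; only the control flow mirrors the patch.
final case class HashPartitioning(columns: Seq[String], numPartitions: Int)

def planChildren(
    children: Seq[(String, Option[HashPartitioning])],  // (plan name, existing partitioning)
    requiredColumns: Seq[Seq[String]],                   // required clustering per child
    defaultPartitions: Int): Seq[String] = {
  // Candidate partition count taken from the first child, as in the patch.
  val numPartitions = children.head._2.map(_.numPartitions).getOrElse(defaultPartitions)
  val useExistingPartitioning = children.zip(requiredColumns).forall {
    case ((_, partitioning), columns) =>
      partitioning.contains(HashPartitioning(columns, numPartitions))
  }
  if (useExistingPartitioning) {
    children.map(_._1)  // every child reused as-is, no shuffles at all
  } else {
    children.zip(requiredColumns).map { case ((name, partitioning), columns) =>
      val target = HashPartitioning(columns, defaultPartitions)
      if (partitioning.contains(target)) name else s"Exchange($name -> $target)"
    }
  }
}

// Two cached tables hash-partitioned on their join keys into the same 6 partitions:
// both are reused untouched, the "zero exchanges" case the new test checks.
planChildren(
  Seq("t1" -> Some(HashPartitioning(Seq("key"), 6)),
      "t2" -> Some(HashPartitioning(Seq("a"), 6))),
  Seq(Seq("key"), Seq("a")),
  defaultPartitions = 200)

// With mismatched partition counts (6 vs. 3), no common partitioning works for both
// children, so each side is shuffled: the "two exchanges" case in the test below.
planChildren(
  Seq("t1" -> Some(HashPartitioning(Seq("key"), 6)),
      "t2" -> Some(HashPartitioning(Seq("a"), 3))),
  Seq(Seq("key"), Seq("a")),
  defaultPartitions = 200)

The two calls mirror the cases exercised by the new test: matching partition counts plan with zero exchanges, while 6 vs. 3 partitions forces both sides to shuffle.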
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index fd566c8276..605954b105 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql

import org.apache.spark.sql.catalyst.analysis.NoSuchTableException
+import org.apache.spark.sql.execution.Exchange
import org.apache.spark.sql.execution.PhysicalRDD

import scala.concurrent.duration._
@@ -353,4 +354,62 @@ class CachedTableSuite extends QueryTest with SharedSQLContext {
    assert(sparkPlan.collect { case e: InMemoryColumnarTableScan => e }.size === 3)
    assert(sparkPlan.collect { case e: PhysicalRDD => e }.size === 0)
  }
+
+  /**
+   * Verifies that the plan for `df` contains the `expected` number of Exchange operators.
+   */
+  private def verifyNumExchanges(df: DataFrame, expected: Int): Unit = {
+    assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.size == expected)
+  }
+
+  test("A cached table preserves the partitioning and ordering of its cached SparkPlan") {
+    val table3x = testData.unionAll(testData).unionAll(testData)
+    table3x.registerTempTable("testData3x")
+
+    sql("SELECT key, value FROM testData3x ORDER BY key").registerTempTable("orderedTable")
+    sqlContext.cacheTable("orderedTable")
+    assertCached(sqlContext.table("orderedTable"))
+    // Should not have an exchange as the query is already sorted on the group-by key.
+    verifyNumExchanges(sql("SELECT key, count(*) FROM orderedTable GROUP BY key"), 0)
+    checkAnswer(
+      sql("SELECT key, count(*) FROM orderedTable GROUP BY key ORDER BY key"),
+      sql("SELECT key, count(*) FROM testData3x GROUP BY key ORDER BY key").collect())
+    sqlContext.uncacheTable("orderedTable")
+
+    // Set up two tables distributed in the same way. Try this with the data distributed into
+    // different numbers of partitions.
+    for (numPartitions <- 1 until 10 by 4) {
+      testData.distributeBy(Column("key") :: Nil, numPartitions).registerTempTable("t1")
+      testData2.distributeBy(Column("a") :: Nil, numPartitions).registerTempTable("t2")
+      sqlContext.cacheTable("t1")
+      sqlContext.cacheTable("t2")
+
+      // Joining them should result in no exchanges.
+      verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 0)
+      checkAnswer(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"),
+        sql("SELECT * FROM testData t1 JOIN testData2 t2 ON t1.key = t2.a"))
+
+      // Grouping on the partition key should result in no exchanges.
+      verifyNumExchanges(sql("SELECT count(*) FROM t1 GROUP BY key"), 0)
+      checkAnswer(sql("SELECT count(*) FROM t1 GROUP BY key"),
+        sql("SELECT count(*) FROM testData GROUP BY key"))
+
+      sqlContext.uncacheTable("t1")
+      sqlContext.uncacheTable("t2")
+      sqlContext.dropTempTable("t1")
+      sqlContext.dropTempTable("t2")
+    }
+
+    // Distribute the tables into non-matching numbers of partitions; the join needs to shuffle.
+    testData.distributeBy(Column("key") :: Nil, 6).registerTempTable("t1")
+    testData2.distributeBy(Column("a") :: Nil, 3).registerTempTable("t2")
+    sqlContext.cacheTable("t1")
+    sqlContext.cacheTable("t2")
+
+    verifyNumExchanges(sql("SELECT * FROM t1 t1 JOIN t2 t2 ON t1.key = t2.a"), 2)
+    sqlContext.uncacheTable("t1")
+    sqlContext.uncacheTable("t2")
+    sqlContext.dropTempTable("t1")
+    sqlContext.dropTempTable("t2")
+  }
}