diff options
Diffstat (limited to 'sql')
10 files changed, 148 insertions, 31 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index f4d1dbaf28..ec659ce789 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -60,8 +60,9 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi /** * Represents data where tuples have been ordered according to the `ordering` * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for - * the ordering expressions are contiguous and will never be split across partitions. + * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the + * same value for the ordering expressions are contiguous and will never be split across + * partitions. */ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { require( @@ -86,8 +87,12 @@ sealed trait Partitioning { */ def satisfies(required: Distribution): Boolean - /** Returns the expressions that are used to key the partitioning. */ - def keyExpressions: Seq[Expression] + /** + * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] + * guarantees the same partitioning scheme described by `other`. + */ + // TODO: Add an example once we have the `nullSafe` concept. + def guarantees(other: Partitioning): Boolean } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -96,7 +101,7 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } - override def keyExpressions: Seq[Expression] = Nil + override def guarantees(other: Partitioning): Boolean = false } case object SinglePartition extends Partitioning { @@ -104,7 +109,10 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def keyExpressions: Seq[Expression] = Nil + override def guarantees(other: Partitioning): Boolean = other match { + case SinglePartition => true + case _ => false + } } case object BroadcastPartitioning extends Partitioning { @@ -112,7 +120,10 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def keyExpressions: Seq[Expression] = Nil + override def guarantees(other: Partitioning): Boolean = other match { + case BroadcastPartitioning => true + case _ => false + } } /** @@ -127,7 +138,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = expressions.toSet + lazy val clusteringSet = expressions.toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true @@ -136,7 +147,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def keyExpressions: Seq[Expression] = expressions + override def guarantees(other: Partitioning): Boolean = other match { + case o: HashPartitioning => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case _ => false + } } /** @@ -170,5 +185,57 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ => false } - override def keyExpressions: Seq[Expression] = ordering.map(_.child) + override def guarantees(other: Partitioning): Boolean = other match { + case o: RangePartitioning => this == o + case _ => false + } +} + +/** + * A collection of [[Partitioning]]s that can be used to describe the partitioning + * scheme of the output of a physical operator. It is usually used for an operator + * that has multiple children. In this case, a [[Partitioning]] in this collection + * describes how this operator's output is partitioned based on expressions from + * a child. For example, for a Join operator on two tables `A` and `B` + * with a join condition `A.key1 = B.key2`, assuming we use HashPartitioning schema, + * there are two [[Partitioning]]s can be used to describe how the output of + * this Join operator is partitioned, which are `HashPartitioning(A.key1)` and + * `HashPartitioning(B.key2)`. It is also worth noting that `partitionings` + * in this collection do not need to be equivalent, which is useful for + * Outer Join operators. + */ +case class PartitioningCollection(partitionings: Seq[Partitioning]) + extends Expression with Partitioning with Unevaluable { + + require( + partitionings.map(_.numPartitions).distinct.length == 1, + s"PartitioningCollection requires all of its partitionings have the same numPartitions.") + + override def children: Seq[Expression] = partitionings.collect { + case expr: Expression => expr + } + + override def nullable: Boolean = false + + override def dataType: DataType = IntegerType + + override val numPartitions = partitionings.map(_.numPartitions).distinct.head + + /** + * Returns true if any `partitioning` of this collection satisfies the given + * [[Distribution]]. + */ + override def satisfies(required: Distribution): Boolean = + partitionings.exists(_.satisfies(required)) + + /** + * Returns true if any `partitioning` of this collection guarantees + * the given [[Partitioning]]. + */ + override def guarantees(other: Partitioning): Boolean = + partitionings.exists(_.guarantees(other)) + + override def toString: String = { + partitionings.map(_.toString).mkString("(", " or ", ")") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index c046dbf4dc..827f7ce692 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning is the output partitioning") { + test("HashPartitioning (with nullSafe = true) is the output partitioning") { // Cases which do not need an exchange between two data properties. checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala index 6bd57f010a..05b009d193 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala @@ -209,7 +209,7 @@ private[sql] case class EnsureRequirements(sqlContext: SQLContext) extends Rule[ child: SparkPlan): SparkPlan = { def addShuffleIfNecessary(child: SparkPlan): SparkPlan = { - if (child.outputPartitioning != partitioning) { + if (!child.outputPartitioning.guarantees(partitioning)) { Exchange(partitioning, child) } else { child diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala index 77e7fe7100..309716a0ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashOuterJoin.scala @@ -24,7 +24,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, UnspecifiedDistribution}
+import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, UnspecifiedDistribution}
import org.apache.spark.sql.catalyst.plans.{JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
import org.apache.spark.util.ThreadUtils
@@ -57,6 +57,8 @@ case class BroadcastHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] =
UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+ override def outputPartitioning: Partitioning = streamedPlan.outputPartitioning
+
@transient
private val broadcastFuture = future {
// Note that we use .execute().collect() because we don't want to convert data to Scala types
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index 7e671e7914..a323aea4ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -22,7 +22,6 @@ import java.util.{HashMap => JavaHashMap} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -38,14 +37,6 @@ trait HashOuterJoin { val left: SparkPlan val right: SparkPlan - override def outputPartitioning: Partitioning = joinType match { - case LeftOuter => left.outputPartitioning - case RightOuter => right.outputPartitioning - case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions) - case x => - throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType") - } - override def output: Seq[Attribute] = { joinType match { case LeftOuter => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala index 26a664104d..68ccd34d8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/LeftSemiJoinHash.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.physical.ClusteredDistribution +import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, Distribution, ClusteredDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -37,7 +37,9 @@ case class LeftSemiJoinHash( right: SparkPlan, condition: Option[Expression]) extends BinaryNode with HashSemiJoin { - override def requiredChildDistribution: Seq[ClusteredDistribution] = + override def outputPartitioning: Partitioning = left.outputPartitioning + + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 5439e10a60..fc6efe87bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning} +import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} /** @@ -38,9 +38,10 @@ case class ShuffledHashJoin( right: SparkPlan) extends BinaryNode with HashJoin { - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) - override def requiredChildDistribution: Seq[ClusteredDistribution] = + override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil protected override def doExecute(): RDD[InternalRow] = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala index d29b593207..eee8ad800f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashOuterJoin.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.physical.{Distribution, ClusteredDistribution}
+import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -44,6 +44,14 @@ case class ShuffledHashOuterJoin( override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+ override def outputPartitioning: Partitioning = joinType match {
+ case LeftOuter => left.outputPartitioning
+ case RightOuter => right.outputPartitioning
+ case FullOuter => UnknownPartitioning(left.outputPartitioning.numPartitions)
+ case x =>
+ throw new IllegalArgumentException(s"HashOuterJoin should not take $x as the JoinType")
+ }
+
protected override def doExecute(): RDD[InternalRow] = {
val joinedRow = new JoinedRow()
left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index bb18b5403f..41be78afd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -40,7 +40,8 @@ case class SortMergeJoin( override def output: Seq[Attribute] = left.output ++ right.output - override def outputPartitioning: Partitioning = left.outputPartitioning + override def outputPartitioning: Partitioning = + PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) override def requiredChildDistribution: Seq[Distribution] = ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 845ce669f0..18b0e54dc7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -23,14 +23,18 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin} import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.test.TestSQLContext.planner._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.{Row, SQLConf, execution} +import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution} -class PlannerSuite extends SparkFunSuite { +class PlannerSuite extends SparkFunSuite with SQLTestUtils { + + override def sqlContext: SQLContext = TestSQLContext + private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption) val planned = @@ -157,4 +161,45 @@ class PlannerSuite extends SparkFunSuite { val planned = planner.TakeOrderedAndProject(query) assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject]) } + + test("PartitioningCollection") { + withTempTable("normal", "small", "tiny") { + testData.registerTempTable("normal") + testData.limit(10).registerTempTable("small") + testData.limit(3).registerTempTable("tiny") + + // Disable broadcast join + withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { + { + val numExchanges = sql( + """ + |SELECT * + |FROM + | normal JOIN small ON (normal.key = small.key) + | JOIN tiny ON (small.key = tiny.key) + """.stripMargin + ).queryExecution.executedPlan.collect { + case exchange: Exchange => exchange + }.length + assert(numExchanges === 3) + } + + { + // This second query joins on different keys: + val numExchanges = sql( + """ + |SELECT * + |FROM + | normal JOIN small ON (normal.key = small.key) + | JOIN tiny ON (normal.key = tiny.key) + """.stripMargin + ).queryExecution.executedPlan.collect { + case exchange: Exchange => exchange + }.length + assert(numExchanges === 3) + } + + } + } + } } |