diff options
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala | 87 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala | 2 |
2 files changed, 78 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index f4d1dbaf28..ec659ce789 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -60,8 +60,9 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi /** * Represents data where tuples have been ordered according to the `ordering` * [[Expression Expressions]]. This is a strictly stronger guarantee than - * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for - * the ordering expressions are contiguous and will never be split across partitions. + * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the + * same value for the ordering expressions are contiguous and will never be split across + * partitions. */ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { require( @@ -86,8 +87,12 @@ sealed trait Partitioning { */ def satisfies(required: Distribution): Boolean - /** Returns the expressions that are used to key the partitioning. */ - def keyExpressions: Seq[Expression] + /** + * Returns true iff we can say that the partitioning scheme of this [[Partitioning]] + * guarantees the same partitioning scheme described by `other`. + */ + // TODO: Add an example once we have the `nullSafe` concept. 
+ def guarantees(other: Partitioning): Boolean } case class UnknownPartitioning(numPartitions: Int) extends Partitioning { @@ -96,7 +101,7 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning { case _ => false } - override def keyExpressions: Seq[Expression] = Nil + override def guarantees(other: Partitioning): Boolean = false } case object SinglePartition extends Partitioning { @@ -104,7 +109,10 @@ case object SinglePartition extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def keyExpressions: Seq[Expression] = Nil + override def guarantees(other: Partitioning): Boolean = other match { + case SinglePartition => true + case _ => false + } } case object BroadcastPartitioning extends Partitioning { @@ -112,7 +120,10 @@ case object BroadcastPartitioning extends Partitioning { override def satisfies(required: Distribution): Boolean = true - override def keyExpressions: Seq[Expression] = Nil + override def guarantees(other: Partitioning): Boolean = other match { + case BroadcastPartitioning => true + case _ => false + } } /** @@ -127,7 +138,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) override def nullable: Boolean = false override def dataType: DataType = IntegerType - private[this] lazy val clusteringSet = expressions.toSet + lazy val clusteringSet = expressions.toSet override def satisfies(required: Distribution): Boolean = required match { case UnspecifiedDistribution => true @@ -136,7 +147,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int) case _ => false } - override def keyExpressions: Seq[Expression] = expressions + override def guarantees(other: Partitioning): Boolean = other match { + case o: HashPartitioning => + this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions + case _ => false + } } /** @@ -170,5 +185,57 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int) case _ 
=> false } - override def keyExpressions: Seq[Expression] = ordering.map(_.child) + override def guarantees(other: Partitioning): Boolean = other match { + case o: RangePartitioning => this == o + case _ => false + } +} + +/** + * A collection of [[Partitioning]]s that can be used to describe the partitioning + * scheme of the output of a physical operator. It is usually used for an operator + * that has multiple children. In this case, a [[Partitioning]] in this collection + * describes how this operator's output is partitioned based on expressions from + * a child. For example, for a Join operator on two tables `A` and `B` + * with a join condition `A.key1 = B.key2`, assuming we use HashPartitioning scheme, + * there are two [[Partitioning]]s that can be used to describe how the output of + * this Join operator is partitioned, which are `HashPartitioning(A.key1)` and + * `HashPartitioning(B.key2)`. It is also worth noting that `partitionings` + * in this collection do not need to be equivalent, which is useful for + * Outer Join operators. + */ +case class PartitioningCollection(partitionings: Seq[Partitioning]) + extends Expression with Partitioning with Unevaluable { + + require( + partitionings.map(_.numPartitions).distinct.length == 1, + s"PartitioningCollection requires all of its partitionings have the same numPartitions.") + + override def children: Seq[Expression] = partitionings.collect { + case expr: Expression => expr + } + + override def nullable: Boolean = false + + override def dataType: DataType = IntegerType + + override val numPartitions = partitionings.map(_.numPartitions).distinct.head + + /** + * Returns true if any `partitioning` of this collection satisfies the given + * [[Distribution]]. + */ + override def satisfies(required: Distribution): Boolean = + partitionings.exists(_.satisfies(required)) + + /** + * Returns true if any `partitioning` of this collection guarantees + * the given [[Partitioning]]. 
+ */ + override def guarantees(other: Partitioning): Boolean = + partitionings.exists(_.guarantees(other)) + + override def toString: String = { + partitionings.map(_.toString).mkString("(", " or ", ")") + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala index c046dbf4dc..827f7ce692 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala @@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite { } } - test("HashPartitioning is the output partitioning") { + test("HashPartitioning (with nullSafe = true) is the output partitioning") { // Cases which do not need an exchange between two data properties. checkSatisfied( HashPartitioning(Seq('a, 'b, 'c), 10), |