aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala | 87
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala           |  2
2 files changed, 78 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index f4d1dbaf28..ec659ce789 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -60,8 +60,9 @@ case class ClusteredDistribution(clustering: Seq[Expression]) extends Distributi
/**
* Represents data where tuples have been ordered according to the `ordering`
* [[Expression Expressions]]. This is a strictly stronger guarantee than
- * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the same value for
- * the ordering expressions are contiguous and will never be split across partitions.
+ * [[ClusteredDistribution]] as an ordering will ensure that tuples that share the
+ * same value for the ordering expressions are contiguous and will never be split across
+ * partitions.
*/
case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution {
require(
@@ -86,8 +87,12 @@ sealed trait Partitioning {
*/
def satisfies(required: Distribution): Boolean
- /** Returns the expressions that are used to key the partitioning. */
- def keyExpressions: Seq[Expression]
+ /**
+ * Returns true iff we can say that the partitioning scheme of this [[Partitioning]]
+ * guarantees the same partitioning scheme described by `other`.
+ */
+ // TODO: Add an example once we have the `nullSafe` concept.
+ def guarantees(other: Partitioning): Boolean
}
case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
@@ -96,7 +101,7 @@ case class UnknownPartitioning(numPartitions: Int) extends Partitioning {
case _ => false
}
- override def keyExpressions: Seq[Expression] = Nil
+ override def guarantees(other: Partitioning): Boolean = false
}
case object SinglePartition extends Partitioning {
@@ -104,7 +109,10 @@ case object SinglePartition extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def keyExpressions: Seq[Expression] = Nil
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case SinglePartition => true
+ case _ => false
+ }
}
case object BroadcastPartitioning extends Partitioning {
@@ -112,7 +120,10 @@ case object BroadcastPartitioning extends Partitioning {
override def satisfies(required: Distribution): Boolean = true
- override def keyExpressions: Seq[Expression] = Nil
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case BroadcastPartitioning => true
+ case _ => false
+ }
}
/**
@@ -127,7 +138,7 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
override def nullable: Boolean = false
override def dataType: DataType = IntegerType
- private[this] lazy val clusteringSet = expressions.toSet
+ lazy val clusteringSet = expressions.toSet
override def satisfies(required: Distribution): Boolean = required match {
case UnspecifiedDistribution => true
@@ -136,7 +147,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
- override def keyExpressions: Seq[Expression] = expressions
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case o: HashPartitioning =>
+ this.clusteringSet == o.clusteringSet && this.numPartitions == o.numPartitions
+ case _ => false
+ }
}
/**
@@ -170,5 +185,57 @@ case class RangePartitioning(ordering: Seq[SortOrder], numPartitions: Int)
case _ => false
}
- override def keyExpressions: Seq[Expression] = ordering.map(_.child)
+ override def guarantees(other: Partitioning): Boolean = other match {
+ case o: RangePartitioning => this == o
+ case _ => false
+ }
+}
+
+/**
+ * A collection of [[Partitioning]]s that can be used to describe the partitioning
+ * scheme of the output of a physical operator. It is usually used for an operator
+ * that has multiple children. In this case, a [[Partitioning]] in this collection
+ * describes how this operator's output is partitioned based on expressions from
+ * a child. For example, for a Join operator on two tables `A` and `B`
+ * with a join condition `A.key1 = B.key2`, assuming we use HashPartitioning schema,
+ * there are two [[Partitioning]]s that can be used to describe how the output of
+ * this Join operator is partitioned, which are `HashPartitioning(A.key1)` and
+ * `HashPartitioning(B.key2)`. It is also worth noting that `partitionings`
+ * in this collection do not need to be equivalent, which is useful for
+ * Outer Join operators.
+ */
+case class PartitioningCollection(partitionings: Seq[Partitioning])
+ extends Expression with Partitioning with Unevaluable {
+
+ require(
+ partitionings.map(_.numPartitions).distinct.length == 1,
+ s"PartitioningCollection requires all of its partitionings have the same numPartitions.")
+
+ override def children: Seq[Expression] = partitionings.collect {
+ case expr: Expression => expr
+ }
+
+ override def nullable: Boolean = false
+
+ override def dataType: DataType = IntegerType
+
+ override val numPartitions = partitionings.map(_.numPartitions).distinct.head
+
+ /**
+ * Returns true if any `partitioning` of this collection satisfies the given
+ * [[Distribution]].
+ */
+ override def satisfies(required: Distribution): Boolean =
+ partitionings.exists(_.satisfies(required))
+
+ /**
+ * Returns true if any `partitioning` of this collection guarantees
+ * the given [[Partitioning]].
+ */
+ override def guarantees(other: Partitioning): Boolean =
+ partitionings.exists(_.guarantees(other))
+
+ override def toString: String = {
+ partitionings.map(_.toString).mkString("(", " or ", ")")
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
index c046dbf4dc..827f7ce692 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/DistributionSuite.scala
@@ -42,7 +42,7 @@ class DistributionSuite extends SparkFunSuite {
}
}
- test("HashPartitioning is the output partitioning") {
+ test("HashPartitioning (with nullSafe = true) is the output partitioning") {
// Cases which do not need an exchange between two data properties.
checkSatisfied(
HashPartitioning(Seq('a, 'b, 'c), 10),