aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala25
1 files changed, 21 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 09c200fa83..d4fc9e4da9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan {
})
}
+ private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = {
+ val common = a.intersect(b)
+ // The constraint with only one reference could be easily inferred as predicate
+ // Grouping the constraints by it's references so we can combine the constraints with same
+ // reference together
+ val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+ val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head)
+ // loose the constraints by: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2)
+ val others = (othera.keySet intersect otherb.keySet).map { attr =>
+ Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And))
+ }
+ common ++ others
+ }
+
override protected def validConstraints: Set[Expression] = {
children
.map(child => rewriteConstraints(children.head.output, child.output, child.constraints))
- .reduce(_ intersect _)
+ .reduce(merge(_, _))
}
}
@@ -252,7 +266,7 @@ case class Join(
override def output: Seq[Attribute] = {
joinType match {
- case LeftSemi =>
+ case LeftExistence(_) =>
left.output
case LeftOuter =>
left.output ++ right.output.map(_.withNullability(true))
@@ -276,7 +290,7 @@ case class Join(
.union(splitConjunctivePredicates(condition.get).toSet)
case Inner =>
left.constraints.union(right.constraints)
- case LeftSemi =>
+ case LeftExistence(_) =>
left.constraints
case LeftOuter =>
left.constraints
@@ -519,7 +533,6 @@ case class Expand(
projections: Seq[Seq[Expression]],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
-
override def references: AttributeSet =
AttributeSet(projections.flatten.flatMap(_.references))
@@ -527,6 +540,10 @@ case class Expand(
val sizeInBytes = super.statistics.sizeInBytes * projections.length
Statistics(sizeInBytes = sizeInBytes)
}
+
+ // This operator can reuse attributes (for example making them null when doing a roll up) so
+ // the contraints of the child may no longer be valid.
+ override protected def validConstraints: Set[Expression] = Set.empty[Expression]
}
/**