diff options
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala | 25 |
1 files changed, 21 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 09c200fa83..d4fc9e4da9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { }) } + private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = { + val common = a.intersect(b) + // The constraint with only one reference could be easily inferred as predicate + // Grouping the constraints by its references so we can combine the constraints with same + // reference together + val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + // loosen the constraints by: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2) + val others = (othera.keySet intersect otherb.keySet).map { attr => + Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And)) + } + common ++ others + } + override protected def validConstraints: Set[Expression] = { children .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) - .reduce(_ intersect _) + .reduce(merge(_, _)) } } @@ -252,7 +266,7 @@ case class Join( override def output: Seq[Attribute] = { joinType match { - case LeftSemi => + case LeftExistence(_) => left.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -276,7 +290,7 @@ case class Join( .union(splitConjunctivePredicates(condition.get).toSet) case Inner => left.constraints.union(right.constraints) - case LeftSemi => + case LeftExistence(_) => left.constraints case LeftOuter => left.constraints @@ -519,7 +533,6 @@ case class Expand( projections: 
Seq[Seq[Expression]], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) @@ -527,6 +540,10 @@ case class Expand( val sizeInBytes = super.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } + + // This operator can reuse attributes (for example making them null when doing a roll up) so + // the constraints of the child may no longer be valid. + override protected def validConstraints: Set[Expression] = Set.empty[Expression] } /** |