| author | Liang-Chi Hsieh <simonh@tw.ibm.com> | 2016-04-01 13:00:55 -0700 |
| --- | --- | --- |
| committer | Michael Armbrust <michael@databricks.com> | 2016-04-01 13:00:55 -0700 |
| commit | df68beb85de59bb6d35b2a8a3b85dbc447798bf5 (patch) | |
| tree | 56ac7f97f44da2392f739387aa2568faa07df053 /sql | |
| parent | 381358fbe9afbe205299cbbea4c43148e2e69468 (diff) | |
[SPARK-13995][SQL] Extract correct IsNotNull constraints for Expression
## What changes were proposed in this pull request?
JIRA: https://issues.apache.org/jira/browse/SPARK-13995
We currently infer `IsNotNull` constraints from a logical plan's expressions in `constructIsNotNullConstraints`. However, we don't consider the case of (nested) `Cast`.
For example:
    val tr = LocalRelation('a.int, 'b.long)
    val plan = tr.where('a.attr === 'b.attr).analyze
Then, the plan's constraints will have `IsNotNull(Cast(resolveColumn(tr, "a"), LongType))`, instead of `IsNotNull(resolveColumn(tr, "a"))`. This PR fixes it.
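For context, the constraint set above can be inspected with the Catalyst test DSL; the snippet below is only illustrative and is not part of the patch (the imports mirror those used in `ConstraintPropagationSuite`):

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.dsl.plans._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    val tr = LocalRelation('a.int, 'b.long)
    val plan = tr.where('a.attr === 'b.attr).analyze
    // Before this patch the set contains IsNotNull(Cast(a, LongType));
    // with the fix it contains IsNotNull(a) and IsNotNull(b) for the attributes themselves.
    plan.constraints.foreach(println)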
Besides, since `IsNotNull` constraints are most useful for `Attribute`s, we should recurse into any `Expression` that is null intolerant and construct `IsNotNull` constraints for all `Attribute`s found under such expressions.
For example, consider the following constraints:
    val df = Seq((1,2,3)).toDF("a", "b", "c")
    df.where("a + b = c").queryExecution.analyzed.constraints
The inferred isnotnull constraints should be isnotnull(a), isnotnull(b) and isnotnull(c), instead of isnotnull(a + b) and isnotnull(c). A sketch of the recursive scan is shown below.
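Concretely, that recursion scans each constraint expression, descends only through null-intolerant nodes, and collects the `Attribute`s underneath; each collected attribute is then wrapped in `IsNotNull`. The following is a simplified sketch: the object and method names `NotNullInference` and `inferIsNotNull` are illustrative only, and the real helper added in this patch also descends through `IsNotNull` wrappers.

    object NotNullInference {
      import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, IsNotNull, NullIntolerant}

      // Collect every Attribute reachable only through null-intolerant expressions.
      def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
        case a: Attribute => Seq(a)                  // a bare attribute must itself be non-null
        case _: NullIntolerant =>                    // null input implies null output, keep descending
          expr.children.flatMap(scanNullIntolerantExpr)
        case _ => Seq.empty[Attribute]               // null-tolerant node (e.g. Coalesce): stop here
      }

      // Wrap each collected attribute in IsNotNull, as constructIsNotNullConstraints does.
      def inferIsNotNull(constraints: Set[Expression]): Set[Expression] =
        constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
    }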
## How was this patch tested?
Tests are added to `ConstraintPropagationSuite`.
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Closes #11809 from viirya/constraint-cast.
Diffstat (limited to 'sql')
7 files changed, 134 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a965cc8d53..d842ffdc66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -112,7 +112,7 @@ object Cast {
 }
 
 /** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {
 
   override def toString: String = s"cast($child as ${dataType.simpleString})"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1e9c971800..b388091538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 
-case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class UnaryMinus(child: Expression) extends UnaryExpression
+    with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
   override def sql: String = s"(-${child.sql})"
 }
 
-case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class UnaryPositive(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
   override def prettyName: String = "positive"
 
   override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
 @ExpressionDescription(
   usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
   extended = "> SELECT _FUNC_('-1');\n1")
-case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Abs(child: Expression)
+  extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
 
   override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
@@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic {
   def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
 }
 
-case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
   }
 }
 
-case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Subtract(left: Expression, right: Expression)
+  extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
   }
 }
 
-case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Multiply(left: Expression, right: Expression)
+  extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = NumericType
@@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
   protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
 }
 
-case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Divide(left: Expression, right: Expression)
+  extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = NumericType
@@ -269,7 +275,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
   }
 }
 
-case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Remainder(left: Expression, right: Expression)
+  extends BinaryArithmetic with NullIntolerant {
 
   override def inputType: AbstractDataType = NumericType
@@ -457,7 +464,7 @@ case class MinOf(left: Expression, right: Expression)
   override def symbol: String = "min"
 }
 
-case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
 
   override def toString: String = s"pmod($left, $right)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a5b5758167..262582ca5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -97,7 +97,7 @@ trait NamedExpression extends Expression {
   }
 }
 
-abstract class Attribute extends LeafExpression with NamedExpression {
+abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant {
 
   override def references: AttributeSet = AttributeSet(this)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index f1fa13daa7..23baa6f783 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -92,4 +92,11 @@ package object expressions {
       StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
     }
   }
+
+  /**
+   * When an expression inherits this, meaning the expression is null intolerant (i.e. any null
+   * input will result in null output). We will use this information during constructing IsNotNull
+   * constraints.
+   */
+  trait NullIntolerant
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index e23ad5596b..4eb33258ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -90,7 +90,7 @@ trait PredicateHelper {
 
 case class Not(child: Expression)
-  extends UnaryExpression with Predicate with ImplicitCastInputTypes {
+  extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant {
 
   override def toString: String = s"NOT $child"
@@ -402,7 +402,8 @@ private[sql] object Equality {
 }
 
-case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
+case class EqualTo(left: Expression, right: Expression)
+  extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = AnyDataType
@@ -467,7 +468,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
 }
 
-case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
+case class LessThan(left: Expression, right: Expression)
+  extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -479,7 +481,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
 }
 
-case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+case class LessThanOrEqual(left: Expression, right: Expression)
+  extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -491,7 +494,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
 }
 
-case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
+case class GreaterThan(left: Expression, right: Expression)
+  extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -503,7 +507,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
 }
 
-case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+case class GreaterThanOrEqual(left: Expression, right: Expression)
+  extends BinaryComparison with NullIntolerant {
 
   override def inputType: AbstractDataType = TypeCollection.Ordered
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d31164fe94..22a4461e66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
    * returns a constraint of the form `isNotNull(a)`
    */
   private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
-    var isNotNullConstraints = Set.empty[Expression]
-
-    // First, we propagate constraints if the condition consists of equality and ranges. For all
-    // other cases, we return an empty set of constraints
-    constraints.foreach {
-      case EqualTo(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case GreaterThan(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case GreaterThanOrEqual(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case LessThan(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case LessThanOrEqual(l, r) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case Not(EqualTo(l, r)) =>
-        isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
-      case _ => // No inference
-    }
+    // First, we propagate constraints from the null intolerant expressions.
+    var isNotNullConstraints: Set[Expression] =
+      constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
 
     // Second, we infer additional constraints from non-nullable attributes that are part of the
     // operator's output
@@ -73,6 +57,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   }
 
   /**
+   * Recursively explores the expressions which are null intolerant and returns all attributes
+   * in these expressions.
+   */
+  private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
+    case a: Attribute => Seq(a)
+    case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
+      expr.children.flatMap(scanNullIntolerantExpr)
+    case _ => Seq.empty[Attribute]
+  }
+
+  /**
    * Infers an additional set of constraints from a given set of equality constraints.
    * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
    * additional constraint of the form `b = 5`
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index e5063599a3..5cbb889f8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
 
 class ConstraintPropagationSuite extends SparkFunSuite {
@@ -219,6 +219,89 @@ class ConstraintPropagationSuite extends SparkFunSuite {
       IsNotNull(resolveColumn(tr, "b")))))
   }
 
+  test("infer constraints on cast") {
+    val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+    verifyConstraints(
+      tr.where('a.attr === 'b.attr &&
+        'c.attr + 100 > 'd.attr &&
+        IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints,
+      ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"),
+        Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "d")),
+        IsNotNull(resolveColumn(tr, "e")),
+        IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
+  }
+
+  test("infer isnotnull constraints from compound expressions") {
+    val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+    verifyConstraints(
+      tr.where('a.attr + 'b.attr === 'c.attr &&
+        IsNotNull(
+          Cast(
+            Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
+      ExpressionSet(Seq(
+        Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
+          Cast(resolveColumn(tr, "c"), LongType),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "e")),
+        IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))
+
+    verifyConstraints(
+      tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
+      ExpressionSet(Seq(
+        Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
+          Cast(resolveColumn(tr, "c"), LongType),
+        Cast(resolveColumn(tr, "d"), DoubleType) /
+          Cast(Cast(10, LongType), DoubleType) ===
+            Cast(resolveColumn(tr, "e"), DoubleType),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "d")),
+        IsNotNull(resolveColumn(tr, "e")))))
+
+    verifyConstraints(
+      tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
+      ExpressionSet(Seq(
+        Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
+          Cast(resolveColumn(tr, "c"), LongType),
+        Cast(resolveColumn(tr, "d"), DoubleType) /
+          Cast(Cast(10, LongType), DoubleType) <
+            Cast(resolveColumn(tr, "e"), DoubleType),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "d")),
+        IsNotNull(resolveColumn(tr, "e")))))
+
+    verifyConstraints(
+      tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
+      ExpressionSet(Seq(
+        (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
+          (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
+            Cast(resolveColumn(tr, "e") * 1000, LongType),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "b")),
+        IsNotNull(resolveColumn(tr, "c")),
+        IsNotNull(resolveColumn(tr, "d")),
+        IsNotNull(resolveColumn(tr, "e")))))
+
+    // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
+    verifyConstraints(
+      tr.where('a.attr === 'c.attr &&
+        IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints,
+      ExpressionSet(Seq(
+        resolveColumn(tr, "a") === resolveColumn(tr, "c"),
+        IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
+        IsNotNull(resolveColumn(tr, "a")),
+        IsNotNull(resolveColumn(tr, "c")))))
+  }
+
   test("infer IsNotNull constraints from non-nullable attributes") {
     val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
       AttributeReference("c", StringType, nullable = false)())