aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <simonh@tw.ibm.com>2016-04-01 13:00:55 -0700
committerMichael Armbrust <michael@databricks.com>2016-04-01 13:00:55 -0700
commitdf68beb85de59bb6d35b2a8a3b85dbc447798bf5 (patch)
tree56ac7f97f44da2392f739387aa2568faa07df053
parent381358fbe9afbe205299cbbea4c43148e2e69468 (diff)
downloadspark-df68beb85de59bb6d35b2a8a3b85dbc447798bf5.tar.gz
spark-df68beb85de59bb6d35b2a8a3b85dbc447798bf5.tar.bz2
spark-df68beb85de59bb6d35b2a8a3b85dbc447798bf5.zip
[SPARK-13995][SQL] Extract correct IsNotNull constraints for Expression
## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-13995 We infer relative `IsNotNull` constraints from logical plan's expressions in `constructIsNotNullConstraints` now. However, we don't consider the case of (nested) `Cast`. For example: val tr = LocalRelation('a.int, 'b.long) val plan = tr.where('a.attr === 'b.attr).analyze Then, the plan's constraints will have `IsNotNull(Cast(resolveColumn(tr, "a"), LongType))`, instead of `IsNotNull(resolveColumn(tr, "a"))`. This PR fixes it. Besides, as `IsNotNull` constraints are most useful for `Attribute`, we should recurse through any `Expression` that is null intolerant and construct `IsNotNull` constraints for all `Attribute`s under these Expressions. For example, consider the following constraints: val df = Seq((1,2,3)).toDF("a", "b", "c") df.where("a + b = c").queryExecution.analyzed.constraints The inferred isnotnull constraints should be isnotnull(a), isnotnull(b), isnotnull(c), instead of isnotnull(a + b) and isnotnull(c). ## How was this patch tested? Test is added into `ConstraintPropagationSuite`. Author: Liang-Chi Hsieh <simonh@tw.ibm.com> Closes #11809 from viirya/constraint-cast.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala25
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala7
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala17
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala33
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala85
7 files changed, 134 insertions, 37 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a965cc8d53..d842ffdc66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -112,7 +112,7 @@ object Cast {
}
/** Cast the child expression to the target data type. */
-case class Cast(child: Expression, dataType: DataType) extends UnaryExpression {
+case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {
override def toString: String = s"cast($child as ${dataType.simpleString})"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1e9c971800..b388091538 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -24,7 +24,8 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval
-case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class UnaryMinus(child: Expression) extends UnaryExpression
+ with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp
override def sql: String = s"(-${child.sql})"
}
-case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class UnaryPositive(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def prettyName: String = "positive"
override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)
@@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects
@ExpressionDescription(
usage = "_FUNC_(expr) - Returns the absolute value of the numeric value",
extended = "> SELECT _FUNC_('-1');\n1")
-case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes {
+case class Abs(child: Expression)
+ extends UnaryExpression with ExpectsInputTypes with NullIntolerant {
override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
@@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic {
def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right))
}
-case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic {
}
}
-case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Subtract(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.NumericAndInterval
@@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti
}
}
-case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Multiply(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti
protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2)
}
-case class Divide(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Divide(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -269,7 +275,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
}
}
-case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Remainder(left: Expression, right: Expression)
+ extends BinaryArithmetic with NullIntolerant {
override def inputType: AbstractDataType = NumericType
@@ -457,7 +464,7 @@ case class MinOf(left: Expression, right: Expression)
override def symbol: String = "min"
}
-case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic {
+case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant {
override def toString: String = s"pmod($left, $right)"
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index a5b5758167..262582ca5d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -97,7 +97,7 @@ trait NamedExpression extends Expression {
}
}
-abstract class Attribute extends LeafExpression with NamedExpression {
+abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant {
override def references: AttributeSet = AttributeSet(this)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index f1fa13daa7..23baa6f783 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -92,4 +92,11 @@ package object expressions {
StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
}
}
+
+ /**
+ * When an expression inherits this, meaning the expression is null intolerant (i.e. any null
+ * input will result in null output). We will use this information during constructing IsNotNull
+ * constraints.
+ */
+ trait NullIntolerant
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index e23ad5596b..4eb33258ac 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -90,7 +90,7 @@ trait PredicateHelper {
case class Not(child: Expression)
- extends UnaryExpression with Predicate with ImplicitCastInputTypes {
+ extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant {
override def toString: String = s"NOT $child"
@@ -402,7 +402,8 @@ private[sql] object Equality {
}
-case class EqualTo(left: Expression, right: Expression) extends BinaryComparison {
+case class EqualTo(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = AnyDataType
@@ -467,7 +468,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp
}
-case class LessThan(left: Expression, right: Expression) extends BinaryComparison {
+case class LessThan(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -479,7 +481,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso
}
-case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+case class LessThanOrEqual(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -491,7 +494,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo
}
-case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison {
+case class GreaterThan(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
@@ -503,7 +507,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar
}
-case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison {
+case class GreaterThanOrEqual(left: Expression, right: Expression)
+ extends BinaryComparison with NullIntolerant {
override def inputType: AbstractDataType = TypeCollection.Ordered
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d31164fe94..22a4461e66 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* returns a constraint of the form `isNotNull(a)`
*/
private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = {
- var isNotNullConstraints = Set.empty[Expression]
-
- // First, we propagate constraints if the condition consists of equality and ranges. For all
- // other cases, we return an empty set of constraints
- constraints.foreach {
- case EqualTo(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case GreaterThan(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case GreaterThanOrEqual(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case LessThan(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case LessThanOrEqual(l, r) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case Not(EqualTo(l, r)) =>
- isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r))
- case _ => // No inference
- }
+ // First, we propagate constraints from the null intolerant expressions.
+ var isNotNullConstraints: Set[Expression] =
+ constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_))
// Second, we infer additional constraints from non-nullable attributes that are part of the
// operator's output
@@ -73,6 +57,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
}
/**
+ * Recursively explores the expressions which are null intolerant and returns all attributes
+ * in these expressions.
+ */
+ private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match {
+ case a: Attribute => Seq(a)
+ case _: NullIntolerant | IsNotNull(_: NullIntolerant) =>
+ expr.children.flatMap(scanNullIntolerantExpr)
+ case _ => Seq.empty[Attribute]
+ }
+
+ /**
* Infers an additional set of constraints from a given set of equality constraints.
* For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an
* additional constraint of the form `b = 5`
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index e5063599a3..5cbb889f8e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
class ConstraintPropagationSuite extends SparkFunSuite {
@@ -219,6 +219,89 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "b")))))
}
+ test("infer constraints on cast") {
+ val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+ verifyConstraints(
+ tr.where('a.attr === 'b.attr &&
+ 'c.attr + 100 > 'd.attr &&
+ IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints,
+ ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"),
+ Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")),
+ IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
+ }
+
+ test("infer isnotnull constraints from compound expressions") {
+ val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+ verifyConstraints(
+ tr.where('a.attr + 'b.attr === 'c.attr &&
+ IsNotNull(
+ Cast(
+ Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
+ Cast(resolveColumn(tr, "c"), LongType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "e")),
+ IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))
+
+ verifyConstraints(
+ tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
+ Cast(resolveColumn(tr, "c"), LongType),
+ Cast(resolveColumn(tr, "d"), DoubleType) /
+ Cast(Cast(10, LongType), DoubleType) ===
+ Cast(resolveColumn(tr, "e"), DoubleType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ verifyConstraints(
+ tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
+ Cast(resolveColumn(tr, "c"), LongType),
+ Cast(resolveColumn(tr, "d"), DoubleType) /
+ Cast(Cast(10, LongType), DoubleType) <
+ Cast(resolveColumn(tr, "e"), DoubleType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ verifyConstraints(
+ tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
+ ExpressionSet(Seq(
+ (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
+ (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
+ Cast(resolveColumn(tr, "e") * 1000, LongType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
+ verifyConstraints(
+ tr.where('a.attr === 'c.attr &&
+ IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints,
+ ExpressionSet(Seq(
+ resolveColumn(tr, "a") === resolveColumn(tr, "c"),
+ IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "c")))))
+ }
+
test("infer IsNotNull constraints from non-nullable attributes") {
val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
AttributeReference("c", StringType, nullable = false)())