aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala')
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala126
1 files changed, 125 insertions, 1 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
index e5063599a3..81cc6b123c 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.types.{IntegerType, StringType}
+import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType}
class ConstraintPropagationSuite extends SparkFunSuite {
@@ -88,6 +88,33 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(aliasedRelation.analyze, "a")))))
}
+ test("propagating constraints in expand") {
+ val tr = LocalRelation('a.int, 'b.int, 'c.int)
+
+ assert(tr.analyze.constraints.isEmpty)
+
+ // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation
+ // by creating notNullRelation.
+ val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2)
+ verifyConstraints(notNullRelation.analyze.constraints,
+ ExpressionSet(Seq(resolveColumn(notNullRelation.analyze, "c") > 10,
+ IsNotNull(resolveColumn(notNullRelation.analyze, "c")),
+ resolveColumn(notNullRelation.analyze, "a") < 5,
+ IsNotNull(resolveColumn(notNullRelation.analyze, "a")),
+ resolveColumn(notNullRelation.analyze, "b") > 2,
+ IsNotNull(resolveColumn(notNullRelation.analyze, "b")))))
+
+ val expand = Expand(
+ Seq(
+ Seq('c, Literal.create(null, StringType), 1),
+ Seq('c, 'a, 2)),
+ Seq('c, 'a, 'gid.int),
+ Project(Seq('a, 'c),
+ notNullRelation))
+ verifyConstraints(expand.analyze.constraints,
+ ExpressionSet(Seq.empty[Expression]))
+ }
+
test("propagating constraints in aliases") {
val tr = LocalRelation('a.int, 'b.string, 'c.int)
@@ -121,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite {
.analyze.constraints,
ExpressionSet(Seq(resolveColumn(tr1, "a") > 10,
IsNotNull(resolveColumn(tr1, "a")))))
+
+ val a = resolveColumn(tr1, "a")
+ verifyConstraints(tr1
+ .where('a.attr > 10)
+ .union(tr2.where('d.attr > 11))
+ .analyze.constraints,
+ ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a))))
+
+ val b = resolveColumn(tr1, "b")
+ verifyConstraints(tr1
+ .where('a.attr > 10 && 'b.attr < 10)
+ .union(tr2.where('d.attr > 11 && 'e.attr < 11))
+ .analyze.constraints,
+ ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b))))
}
test("propagating constraints in intersect") {
@@ -219,6 +260,89 @@ class ConstraintPropagationSuite extends SparkFunSuite {
IsNotNull(resolveColumn(tr, "b")))))
}
+ test("infer constraints on cast") {
+ val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+ verifyConstraints(
+ tr.where('a.attr === 'b.attr &&
+ 'c.attr + 100 > 'd.attr &&
+ IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints,
+ ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"),
+ Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")),
+ IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType)))))
+ }
+
+ test("infer isnotnull constraints from compound expressions") {
+ val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int)
+ verifyConstraints(
+ tr.where('a.attr + 'b.attr === 'c.attr &&
+ IsNotNull(
+ Cast(
+ Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") ===
+ Cast(resolveColumn(tr, "c"), LongType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "e")),
+ IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType)))))
+
+ verifyConstraints(
+ tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) ===
+ Cast(resolveColumn(tr, "c"), LongType),
+ Cast(resolveColumn(tr, "d"), DoubleType) /
+ Cast(Cast(10, LongType), DoubleType) ===
+ Cast(resolveColumn(tr, "e"), DoubleType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ verifyConstraints(
+ tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints,
+ ExpressionSet(Seq(
+ Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >=
+ Cast(resolveColumn(tr, "c"), LongType),
+ Cast(resolveColumn(tr, "d"), DoubleType) /
+ Cast(Cast(10, LongType), DoubleType) <
+ Cast(resolveColumn(tr, "e"), DoubleType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ verifyConstraints(
+ tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints,
+ ExpressionSet(Seq(
+ (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) -
+ (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) >
+ Cast(resolveColumn(tr, "e") * 1000, LongType),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "b")),
+ IsNotNull(resolveColumn(tr, "c")),
+ IsNotNull(resolveColumn(tr, "d")),
+ IsNotNull(resolveColumn(tr, "e")))))
+
+ // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null.
+ verifyConstraints(
+ tr.where('a.attr === 'c.attr &&
+ IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints,
+ ExpressionSet(Seq(
+ resolveColumn(tr, "a") === resolveColumn(tr, "c"),
+ IsNotNull(IsNotNull(resolveColumn(tr, "b"))),
+ IsNotNull(resolveColumn(tr, "a")),
+ IsNotNull(resolveColumn(tr, "c")))))
+ }
+
test("infer IsNotNull constraints from non-nullable attributes") {
val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(),
AttributeReference("c", StringType, nullable = false)())