author    Dongjoon Hyun <dongjoon@apache.org>    2016-07-07 19:45:43 +0800
committer Wenchen Fan <wenchen@databricks.com>   2016-07-07 19:45:43 +0800
commit    a04cab8f17fcac05f86d2c472558ab98923f91e3 (patch)
tree      7f2de5eceae3de1f071281d2512ea21361a3c5a4 /sql/catalyst
parent    6343f66557434ce889a25a7889d76d0d24188ced (diff)
[SPARK-16174][SQL] Improve `OptimizeIn` optimizer to remove literal repetitions
## What changes were proposed in this pull request?

This PR improves the `OptimizeIn` optimizer to remove literal repetitions from SQL `IN` predicates. This prevents redundant work caused by user mistakes and can also optimize some queries, such as [TPCDS-36](https://github.com/apache/spark/blob/master/sql/core/src/test/resources/tpcds/q36.sql#L19).

**Before**
```scala
scala> sql("select state from (select explode(array('CA','TN')) state) where state in ('TN','TN','TN','TN','TN','TN','TN')").explain
== Physical Plan ==
*Filter state#6 IN (TN,TN,TN,TN,TN,TN,TN)
+- Generate explode([CA,TN]), false, false, [state#6]
   +- Scan OneRowRelation[]
```

**After**
```scala
scala> sql("select state from (select explode(array('CA','TN')) state) where state in ('TN','TN','TN','TN','TN','TN','TN')").explain
== Physical Plan ==
*Filter state#6 IN (TN)
+- Generate explode([CA,TN]), false, false, [state#6]
   +- Scan OneRowRelation[]
```

## How was this patch tested?

Passes the Jenkins tests (including a new test case).

Author: Dongjoon Hyun <dongjoon@apache.org>

Closes #13876 from dongjoon-hyun/SPARK-16174.
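For readers skimming the patch below, here is a minimal, self-contained sketch of the dedup-then-convert idea. It is not the actual Catalyst rule: `Expr`, `Lit`, `Attr`, `In`, and `InSet` are simplified stand-ins for the Catalyst expression classes, `distinct` stands in for `ExpressionSet`, and the threshold is a hard-coded illustrative value rather than `conf.optimizerInSetConversionThreshold`.

```scala
object OptimizeInSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: Any) extends Expr
  case class In(value: Expr, list: Seq[Expr]) extends Expr
  case class InSet(value: Expr, hset: Set[Any]) extends Expr

  // Stand-in for conf.optimizerInSetConversionThreshold (illustrative value).
  val conversionThreshold = 10

  def optimizeIn(e: Expr): Expr = e match {
    // Only all-literal lists are touched, mirroring `inSetConvertible`.
    case in @ In(v, list) if list.forall(_.isInstanceOf[Lit]) =>
      val newList = list.distinct                               // drop literal repetitions
      if (newList.size > conversionThreshold) {
        InSet(v, newList.collect { case Lit(x) => x }.toSet)    // large list: hash-set lookup
      } else if (newList.size < list.size) {
        in.copy(list = newList)                                 // small list: keep In, deduplicated
      } else {
        in                                                      // nothing repeated: leave unchanged
      }
    case other => other
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the commit message example: seven copies of 'TN' collapse to one.
    println(optimizeIn(In(Attr("state"), Seq.fill(7)(Lit("TN")))))
    // => In(Attr(state),List(Lit(TN)))
  }
}
```

The real rule, shown in the `Optimizer.scala` hunk below, applies the same three-way branch inside `transformExpressionsDown`.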
Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala     1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala       20
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala  24
3 files changed, 39 insertions, 6 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
index a3b098afe5..734bacf727 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala
@@ -132,6 +132,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate
}
override def children: Seq[Expression] = value +: list
+ lazy val inSetConvertible = list.forall(_.isInstanceOf[Literal])
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
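As a quick spark-shell-style sketch of the new flag (imports assumed from the Spark 2.x source tree shown above): `inSetConvertible` is true exactly when every element of the `IN` list is a `Literal`.

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{In, Literal}

// All-literal list: eligible for deduplication and the InSet conversion.
In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))).inSetConvertible                 // true

// A non-literal element (here another attribute) makes the predicate ineligible,
// so the rewritten OptimizeIn rule leaves it untouched.
In(UnresolvedAttribute("a"), Seq(Literal(1), UnresolvedAttribute("b"))).inSetConvertible   // false
```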
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9ee1735069..03d15eabdd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -820,16 +820,24 @@ object ConstantFolding extends Rule[LogicalPlan] {
}
/**
- * Replaces [[In (value, seq[Literal])]] with optimized version[[InSet (value, HashSet[Literal])]]
- * which is much faster
+ * Optimize IN predicates:
+ * 1. Removes literal repetitions.
+ * 2. Replaces [[In (value, seq[Literal])]] with optimized version
+ * [[InSet (value, HashSet[Literal])]] which is much faster.
*/
case class OptimizeIn(conf: CatalystConf) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case q: LogicalPlan => q transformExpressionsDown {
- case In(v, list) if !list.exists(!_.isInstanceOf[Literal]) &&
- list.size > conf.optimizerInSetConversionThreshold =>
- val hSet = list.map(e => e.eval(EmptyRow))
- InSet(v, HashSet() ++ hSet)
+ case expr @ In(v, list) if expr.inSetConvertible =>
+ val newList = ExpressionSet(list).toSeq
+ if (newList.size > conf.optimizerInSetConversionThreshold) {
+ val hSet = newList.map(e => e.eval(EmptyRow))
+ InSet(v, HashSet() ++ hSet)
+ } else if (newList.size < list.size) {
+ expr.copy(list = newList)
+ } else { // newList.length == list.length
+ expr
+ }
}
}
}
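One consequence worth noting: the conversion threshold is now compared against the deduplicated list rather than the raw one. A hedged sketch, assuming the default conversion threshold of 10 (which the test name below, "not optimized to InSet when less than 10 items", also suggests):

```scala
import org.apache.spark.sql.catalyst.expressions.{ExpressionSet, Literal}

// 12 literals, but only 9 distinct values.
val list = Seq(1, 2, 3, 4, 5, 6, 7, 8, 9, 1, 2, 3).map(Literal(_))

// Old rule:  list.size == 12 > 10                 => rewritten to InSet over {1, ..., 9}
// New rule:  ExpressionSet(list).size == 9 <= 10  => stays an In, now with 9 literals
```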
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
index f1a4ea8280..0877207728 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeInSuite.scala
@@ -42,6 +42,30 @@ class OptimizeInSuite extends PlanTest {
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ test("OptimizedIn test: Remove deterministic repetitions") {
+ val originalQuery =
+ testRelation
+ .where(In(UnresolvedAttribute("a"),
+ Seq(Literal(1), Literal(1), Literal(2), Literal(2), Literal(1), Literal(2))))
+ .where(In(UnresolvedAttribute("b"),
+ Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"),
+ Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0),
+ Rand(0), Rand(0))))
+ .analyze
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .where(In(UnresolvedAttribute("a"), Seq(Literal(1), Literal(2))))
+ .where(In(UnresolvedAttribute("b"),
+ Seq(UnresolvedAttribute("a"), UnresolvedAttribute("a"),
+ Round(UnresolvedAttribute("a"), 0), Round(UnresolvedAttribute("a"), 0),
+ Rand(0), Rand(0))))
+ .analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
test("OptimizedIn test: In clause not optimized to InSet when less than 10 items") {
val originalQuery =
testRelation