aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2016-05-31 10:08:00 -0700
committerWenchen Fan <wenchen@databricks.com>2016-05-31 10:08:00 -0700
commitd67c82e4b647dacd0b789d72c9eaf4dc7d329dbd (patch)
treefa744ff350d34fe476dc03137e1714b2bfc808ac /sql
parent2bfc4f15214a870b3e067f06f37eb506b0070a1f (diff)
downloadspark-d67c82e4b647dacd0b789d72c9eaf4dc7d329dbd.tar.gz
spark-d67c82e4b647dacd0b789d72c9eaf4dc7d329dbd.tar.bz2
spark-d67c82e4b647dacd0b789d72c9eaf4dc7d329dbd.zip
[SPARK-15647][SQL] Fix Boundary Cases in OptimizeCodegen Rule
#### What changes were proposed in this pull request? The following condition in the Optimizer rule `OptimizeCodegen` is not right. ```Scala branches.size < conf.maxCaseBranchesForCodegen ``` - The number of branches in case when clause should be `branches.size + elseBranch.size`. - `maxCaseBranchesForCodegen` is the maximum boundary for enabling codegen. Thus, we should use `<=` instead of `<`. This PR is to fix this boundary case and also add missing test cases for verifying the conf `MAX_CASES_BRANCHES`. #### How was this patch tested? Added test cases in `SQLConfSuite` Author: gatorsmile <gatorsmile@gmail.com> Closes #13392 from gatorsmile/maxCaseWhen.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala8
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala29
2 files changed, 35 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 688c77d3ca..93762ad1b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -937,8 +937,12 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
*/
case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case e @ CaseWhen(branches, _) if branches.size < conf.maxCaseBranchesForCodegen =>
- e.toCodegen()
+ case e: CaseWhen if canCodegen(e) => e.toCodegen()
+ }
+
+ private def canCodegen(e: CaseWhen): Boolean = {
+ val numBranches = e.branches.size + e.elseValue.size
+ numBranches <= conf.maxCaseBranchesForCodegen
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
index 3d4fc75e83..2cd3f475b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/internal/SQLConfSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.internal
import org.apache.spark.sql.{QueryTest, Row, SparkSession, SQLContext}
+import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.test.{SharedSQLContext, TestSQLContext}
class SQLConfSuite extends QueryTest with SharedSQLContext {
@@ -219,4 +220,32 @@ class SQLConfSuite extends QueryTest with SharedSQLContext {
}
}
+ test("MAX_CASES_BRANCHES") {
+ withTable("tab1") {
+ spark.range(10).write.saveAsTable("tab1")
+ val sql_one_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 END FROM tab1"
+ val sql_two_branch_caseWhen = "SELECT CASE WHEN id = 1 THEN 1 ELSE 0 END FROM tab1"
+
+ withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "0") {
+ assert(!sql(sql_one_branch_caseWhen)
+ .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+ assert(!sql(sql_two_branch_caseWhen)
+ .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+ }
+
+ withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "1") {
+ assert(sql(sql_one_branch_caseWhen)
+ .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+ assert(!sql(sql_two_branch_caseWhen)
+ .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+ }
+
+ withSQLConf(SQLConf.MAX_CASES_BRANCHES.key -> "2") {
+ assert(sql(sql_one_branch_caseWhen)
+ .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+ assert(sql(sql_two_branch_caseWhen)
+ .queryExecution.executedPlan.isInstanceOf[WholeStageCodegenExec])
+ }
+ }
+ }
}