aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDongjoon Hyun <dongjoon@apache.org>2016-04-19 21:38:15 +0800
committerWenchen Fan <wenchen@databricks.com>2016-04-19 21:38:15 +0800
commit3d46d796a3a2b60b37dc318652eded5e992be1e5 (patch)
treea8157c731edb6f061fa163ef995bd4a839fccfdc
parent74fe235ab5ed169fb30d9d2c04077b90d1bf1b95 (diff)
downloadspark-3d46d796a3a2b60b37dc318652eded5e992be1e5.tar.gz
spark-3d46d796a3a2b60b37dc318652eded5e992be1e5.tar.bz2
spark-3d46d796a3a2b60b37dc318652eded5e992be1e5.zip
[SPARK-14577][SQL] Add spark.sql.codegen.maxCaseBranches config option
## What changes were proposed in this pull request? We currently disable codegen for `CaseWhen` if the number of branches is greater than 20 (in CaseWhen.MAX_NUM_CASES_FOR_CODEGEN). It would be better if this value is a non-public config defined in SQLConf. ## How was this patch tested? Pass the Jenkins tests (including a new testcase `Support spark.sql.codegen.maxCaseBranches option`) Author: Dongjoon Hyun <dongjoon@apache.org> Closes #12353 from dongjoon-hyun/SPARK-14577.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala86
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala14
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala102
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala8
6 files changed, 180 insertions, 35 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
index abba866821..0efe3c4d45 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala
@@ -29,6 +29,7 @@ trait CatalystConf {
def groupByOrdinal: Boolean
def optimizerMaxIterations: Int
+ def maxCaseBranchesForCodegen: Int
/**
* Returns the [[Resolver]] for the current configuration, which can be used to determine if two
@@ -45,6 +46,7 @@ case class SimpleCatalystConf(
caseSensitiveAnalysis: Boolean,
orderByOrdinal: Boolean = true,
groupByOrdinal: Boolean = true,
- optimizerMaxIterations: Int = 100)
+ optimizerMaxIterations: Int = 100,
+ maxCaseBranchesForCodegen: Int = 20)
extends CatalystConf {
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index 336649c0fd..e97e08947a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -81,18 +81,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
}
/**
- * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
- * When a = true, returns b; when c = true, returns d; else returns e.
+ * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen.
*
* @param branches seq of (branch condition, branch value)
* @param elseValue optional value for the else branch
*/
-// scalastyle:off line.size.limit
-@ExpressionDescription(
- usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
-// scalastyle:on line.size.limit
-case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
- extends Expression with CodegenFallback {
+abstract class CaseWhenBase(
+ branches: Seq[(Expression, Expression)],
+ elseValue: Option[Expression])
+ extends Expression with Serializable {
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
@@ -142,16 +139,58 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
}
}
- def shouldCodegen: Boolean = {
- branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN
+ override def toString: String = {
+ val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
+ val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
+ "CASE" + cases + elseCase + " END"
+ }
+
+ override def sql: String = {
+ val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
+ val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
+ "CASE" + cases + elseCase + " END"
+ }
+}
+
+
+/**
+ * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END".
+ * When a = true, returns b; when c = true, returns d; else returns e.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
+ */
+// scalastyle:off line.size.limit
+@ExpressionDescription(
+ usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.")
+// scalastyle:on line.size.limit
+case class CaseWhen(
+ val branches: Seq[(Expression, Expression)],
+ val elseValue: Option[Expression] = None)
+ extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable {
+
+ override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ super[CodegenFallback].doGenCode(ctx, ev)
+ }
+
+ def toCodegen(): CaseWhenCodegen = {
+ CaseWhenCodegen(branches, elseValue)
}
+}
+
+/**
+ * CaseWhen expression used when code generation condition is satisfied.
+ * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen.
+ *
+ * @param branches seq of (branch condition, branch value)
+ * @param elseValue optional value for the else branch
+ */
+case class CaseWhenCodegen(
+ val branches: Seq[(Expression, Expression)],
+ val elseValue: Option[Expression] = None)
+ extends CaseWhenBase(branches, elseValue) with Serializable {
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
- if (!shouldCodegen) {
- // Fallback to interpreted mode if there are too many branches, as it may reach the
- // 64K limit (limit on bytecode size for a single function).
- return super[CodegenFallback].doGenCode(ctx, ev)
- }
// Generate code that looks like:
//
// condA = ...
@@ -202,26 +241,10 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$generatedCode""")
}
-
- override def toString: String = {
- val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString
- val elseCase = elseValue.map(" ELSE " + _).getOrElse("")
- "CASE" + cases + elseCase + " END"
- }
-
- override def sql: String = {
- val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString
- val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("")
- "CASE" + cases + elseCase + " END"
- }
}
/** Factory methods for CaseWhen. */
object CaseWhen {
-
- // The maximum number of switches supported with codegen.
- val MAX_NUM_CASES_FOR_CODEGEN = 20
-
def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
CaseWhen(branches, Option(elseValue))
}
@@ -242,7 +265,6 @@ object CaseWhen {
}
}
-
/**
* Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END".
* When a = b, returns c; when a = d, returns e; else returns f.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c46bdfb2b5..b806b725a8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -104,7 +104,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
Batch("Subquery", Once,
- OptimizeSubqueries) :: Nil
+ OptimizeSubqueries) ::
+ Batch("OptimizeCodegen", Once,
+ OptimizeCodegen(conf)) :: Nil
}
/**
@@ -864,6 +866,16 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper {
}
/**
+ * Optimizes expressions by replacing according to CodeGen configuration.
+ */
+case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ case e @ CaseWhen(branches, _) if branches.size < conf.maxCaseBranchesForCodegen =>
+ e.toCodegen()
+ }
+}
+
+/**
* Combines all adjacent [[Union]] operators into a single [[Union]].
*/
object CombineUnions extends Rule[LogicalPlan] {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
new file mode 100644
index 0000000000..4385b0e019
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.SimpleCatalystConf
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.Literal._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.rules._
+
+
+class OptimizeCodegenSuite extends PlanTest {
+
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil
+ }
+
+ protected def assertEquivalent(e1: Expression, e2: Expression): Unit = {
+ val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze
+ val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze)
+ comparePlans(actual, correctAnswer)
+ }
+
+ test("Codegen only when the number of branches is small.") {
+ assertEquivalent(
+ CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
+ CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen())
+
+ assertEquivalent(
+ CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)),
+ CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)))
+ }
+
+ test("Nested CaseWhen Codegen.") {
+ assertEquivalent(
+ CaseWhen(
+ Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))),
+ CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))),
+ CaseWhen(
+ Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))),
+ CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen())
+ }
+
+ test("Multiple CaseWhen in one operator.") {
+ val plan = OneRowRelation
+ .select(
+ CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
+ CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
+ CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
+ CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze
+ val correctAnswer = OneRowRelation
+ .select(
+ CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
+ CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
+ CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)),
+ CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze
+ val optimized = Optimize.execute(plan)
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("Multiple CaseWhen in different operators") {
+ val plan = OneRowRelation
+ .select(
+ CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)),
+ CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)),
+ CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+ .where(
+ LessThan(
+ CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)),
+ CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+ ).analyze
+ val correctAnswer = OneRowRelation
+ .select(
+ CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(),
+ CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(),
+ CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+ .where(
+ LessThan(
+ CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(),
+ CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)))
+ ).analyze
+ val optimized = Optimize.execute(plan)
+ comparePlans(optimized, correctAnswer)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 29b66e3dee..46eaede5e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -429,7 +429,6 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
- case e: CaseWhen => e.shouldCodegen
// CodegenFallback requires the input to be an InternalRow
case e: CodegenFallback => false
case _ => true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 7f206bdb9b..4ae8278a9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -402,6 +402,12 @@ object SQLConf {
.intConf
.createWithDefault(200)
+ val MAX_CASES_BRANCHES = SQLConfigBuilder("spark.sql.codegen.maxCaseBranches")
+ .internal()
+ .doc("The maximum number of switches supported with codegen.")
+ .intConf
+ .createWithDefault(20)
+
val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes")
.doc("The maximum number of bytes to pack into a single partition when reading files.")
.longConf
@@ -529,6 +535,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging {
def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS)
+ def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES)
+
def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED)
def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW)