From 3d46d796a3a2b60b37dc318652eded5e992be1e5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 19 Apr 2016 21:38:15 +0800 Subject: [SPARK-14577][SQL] Add spark.sql.codegen.maxCaseBranches config option ## What changes were proposed in this pull request? We currently disable codegen for `CaseWhen` if the number of branches is greater than 20 (in CaseWhen.MAX_NUM_CASES_FOR_CODEGEN). It would be better if this value is a non-public config defined in SQLConf. ## How was this patch tested? Pass the Jenkins tests (including a new testcase `Support spark.sql.codegen.maxCaseBranches option`) Author: Dongjoon Hyun Closes #12353 from dongjoon-hyun/SPARK-14577. --- .../apache/spark/sql/catalyst/CatalystConf.scala | 4 +- .../expressions/conditionalExpressions.scala | 86 ++++++++++------- .../spark/sql/catalyst/optimizer/Optimizer.scala | 14 ++- .../catalyst/optimizer/OptimizeCodegenSuite.scala | 102 +++++++++++++++++++++ .../spark/sql/execution/WholeStageCodegen.scala | 1 - .../org/apache/spark/sql/internal/SQLConf.scala | 8 ++ 6 files changed, 180 insertions(+), 35 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index abba866821..0efe3c4d45 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -29,6 +29,7 @@ trait CatalystConf { def groupByOrdinal: Boolean def optimizerMaxIterations: Int + def maxCaseBranchesForCodegen: Int /** * Returns the [[Resolver]] for the current configuration, which can be used to determine if two @@ -45,6 +46,7 @@ case class SimpleCatalystConf( caseSensitiveAnalysis: Boolean, orderByOrdinal: Boolean = true, groupByOrdinal: Boolean = true, - optimizerMaxIterations: Int = 100) + optimizerMaxIterations: Int = 100, + maxCaseBranchesForCodegen: Int = 20) extends CatalystConf { } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 336649c0fd..e97e08947a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -81,18 +81,15 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } /** - * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". - * When a = true, returns b; when c = true, returns d; else returns e. + * Abstract parent class for common logic in CaseWhen and CaseWhenCodegen. * * @param branches seq of (branch condition, branch value) * @param elseValue optional value for the else branch */ -// scalastyle:off line.size.limit -@ExpressionDescription( - usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") -// scalastyle:on line.size.limit -case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) - extends Expression with CodegenFallback { +abstract class CaseWhenBase( + branches: Seq[(Expression, Expression)], + elseValue: Option[Expression]) + extends Expression with Serializable { override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue @@ -142,16 +139,58 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E } } - def shouldCodegen: Boolean = { - branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN + override def toString: String = { + val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString + val elseCase = elseValue.map(" ELSE " + _).getOrElse("") + "CASE" + cases + elseCase + " END" + } + + override def sql: String = { + val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString + val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") + "CASE" + cases + elseCase + " END" + } +} + + +/** + * Case statements of the form "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END". + * When a = true, returns b; when c = true, returns d; else returns e. + * + * @param branches seq of (branch condition, branch value) + * @param elseValue optional value for the else branch + */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") +// scalastyle:on line.size.limit +case class CaseWhen( + val branches: Seq[(Expression, Expression)], + val elseValue: Option[Expression] = None) + extends CaseWhenBase(branches, elseValue) with CodegenFallback with Serializable { + + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + super[CodegenFallback].doGenCode(ctx, ev) + } + + def toCodegen(): CaseWhenCodegen = { + CaseWhenCodegen(branches, elseValue) } +} + +/** + * CaseWhen expression used when code generation condition is satisfied. + * OptimizeCodegen optimizer replaces CaseWhen into CaseWhenCodegen. + * + * @param branches seq of (branch condition, branch value) + * @param elseValue optional value for the else branch + */ +case class CaseWhenCodegen( + val branches: Seq[(Expression, Expression)], + val elseValue: Option[Expression] = None) + extends CaseWhenBase(branches, elseValue) with Serializable { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - if (!shouldCodegen) { - // Fallback to interpreted mode if there are too many branches, as it may reach the - // 64K limit (limit on bytecode size for a single function). - return super[CodegenFallback].doGenCode(ctx, ev) - } // Generate code that looks like: // // condA = ... @@ -202,26 +241,10 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $generatedCode""") } - - override def toString: String = { - val cases = branches.map { case (c, v) => s" WHEN $c THEN $v" }.mkString - val elseCase = elseValue.map(" ELSE " + _).getOrElse("") - "CASE" + cases + elseCase + " END" - } - - override def sql: String = { - val cases = branches.map { case (c, v) => s" WHEN ${c.sql} THEN ${v.sql}" }.mkString - val elseCase = elseValue.map(" ELSE " + _.sql).getOrElse("") - "CASE" + cases + elseCase + " END" - } } /** Factory methods for CaseWhen. */ object CaseWhen { - - // The maximum number of switches supported with codegen. - val MAX_NUM_CASES_FOR_CODEGEN = 20 - def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = { CaseWhen(branches, Option(elseValue)) } @@ -242,7 +265,6 @@ object CaseWhen { } } - /** * Case statements of the form "CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END". * When a = b, returns c; when a = d, returns e; else returns f. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c46bdfb2b5..b806b725a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -104,7 +104,9 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("LocalRelation", fixedPoint, ConvertToLocalRelation) :: Batch("Subquery", Once, - OptimizeSubqueries) :: Nil + OptimizeSubqueries) :: + Batch("OptimizeCodegen", Once, + OptimizeCodegen(conf)) :: Nil } /** @@ -863,6 +865,16 @@ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Optimizes expressions by replacing according to CodeGen configuration. + */ +case class OptimizeCodegen(conf: CatalystConf) extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + case e @ CaseWhen(branches, _) if branches.size < conf.maxCaseBranchesForCodegen => + e.toCodegen() + } +} + /** * Combines all adjacent [[Union]] operators into a single [[Union]]. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala new file mode 100644 index 0000000000..4385b0e019 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizeCodegenSuite.scala @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + + +class OptimizeCodegenSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("OptimizeCodegen", Once, OptimizeCodegen(SimpleCatalystConf(true))) :: Nil + } + + protected def assertEquivalent(e1: Expression, e2: Expression): Unit = { + val correctAnswer = Project(Alias(e2, "out")() :: Nil, OneRowRelation).analyze + val actual = Optimize.execute(Project(Alias(e1, "out")() :: Nil, OneRowRelation).analyze) + comparePlans(actual, correctAnswer) + } + + test("Codegen only when the number of branches is small.") { + assertEquivalent( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen()) + + assertEquivalent( + CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2)), + CaseWhen(List.fill(100)(TrueLiteral, Literal(1)), Literal(2))) + } + + test("Nested CaseWhen Codegen.") { + assertEquivalent( + CaseWhen( + Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), Literal(3))), + CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5))), + CaseWhen( + Seq((CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), Literal(3))), + CaseWhen(Seq((TrueLiteral, Literal(4))), Literal(5)).toCodegen()).toCodegen()) + } + + test("Multiple CaseWhen in one operator.") { + val plan = OneRowRelation + .select( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), + CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)), + CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6))).analyze + val correctAnswer = OneRowRelation + .select( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), + CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0)), + CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen()).analyze + val optimized = Optimize.execute(plan) + comparePlans(optimized, correctAnswer) + } + + test("Multiple CaseWhen in different operators") { + val plan = OneRowRelation + .select( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)), + CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) + .where( + LessThan( + CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) + ).analyze + val correctAnswer = OneRowRelation + .select( + CaseWhen(Seq((TrueLiteral, Literal(1))), Literal(2)).toCodegen(), + CaseWhen(Seq((FalseLiteral, Literal(3))), Literal(4)).toCodegen(), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) + .where( + LessThan( + CaseWhen(Seq((TrueLiteral, Literal(5))), Literal(6)).toCodegen(), + CaseWhen(List.fill(20)((TrueLiteral, Literal(0))), Literal(0))) + ).analyze + val optimized = Optimize.execute(plan) + comparePlans(optimized, correctAnswer) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 29b66e3dee..46eaede5e7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -429,7 +429,6 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { case e: LeafExpression => true - case e: CaseWhen => e.shouldCodegen // CodegenFallback requires the input to be an InternalRow case e: CodegenFallback => false case _ => true diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7f206bdb9b..4ae8278a9d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -402,6 +402,12 @@ object SQLConf { .intConf .createWithDefault(200) + val MAX_CASES_BRANCHES = SQLConfigBuilder("spark.sql.codegen.maxCaseBranches") + .internal() + .doc("The maximum number of switches supported with codegen.") + .intConf + .createWithDefault(20) + val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf @@ -529,6 +535,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) + def maxCaseBranchesForCodegen: Int = getConf(MAX_CASES_BRANCHES) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) -- cgit v1.2.3