aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala14
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala6
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala1
4 files changed, 38 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index c3e9fa33e6..5ceb36513f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -86,7 +86,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
* @param elseValue optional value for the else branch
*/
case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None)
- extends Expression {
+ extends Expression with CodegenFallback {
override def children: Seq[Expression] = branches.flatMap(b => b._1 :: b._2 :: Nil) ++ elseValue
@@ -136,7 +136,16 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
}
}
+ def shouldCodegen: Boolean = {
+ branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN
+ }
+
override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+ if (!shouldCodegen) {
+ // Fallback to interpreted mode if there are too many branches, as it may reach the
+ // 64K limit (limit on bytecode size for a single function).
+ return super[CodegenFallback].genCode(ctx, ev)
+ }
// Generate code that looks like:
//
// condA = ...
@@ -205,6 +214,9 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E
/** Factory methods for CaseWhen. */
object CaseWhen {
+ // The maxium number of switches supported with codegen.
+ val MAX_NUM_CASES_FOR_CODEGEN = 20
+
def apply(branches: Seq[(Expression, Expression)], elseValue: Expression): CaseWhen = {
CaseWhen(branches, Option(elseValue))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index 37bfe98d3a..a76517a89c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -203,7 +203,7 @@ case class Literal protected (value: Any, dataType: DataType)
case FloatType =>
val v = value.asInstanceOf[Float]
if (v.isNaN || v.isInfinite) {
- super.genCode(ctx, ev)
+ super[CodegenFallback].genCode(ctx, ev)
} else {
ev.isNull = "false"
ev.value = s"${value}f"
@@ -212,7 +212,7 @@ case class Literal protected (value: Any, dataType: DataType)
case DoubleType =>
val v = value.asInstanceOf[Double]
if (v.isNaN || v.isInfinite) {
- super.genCode(ctx, ev)
+ super[CodegenFallback].genCode(ctx, ev)
} else {
ev.isNull = "false"
ev.value = s"${value}D"
@@ -232,7 +232,7 @@ case class Literal protected (value: Any, dataType: DataType)
""
// eval() version may be faster for non-primitive types
case other =>
- super.genCode(ctx, ev)
+ super[CodegenFallback].genCode(ctx, ev)
}
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index b5413fbe2b..260dfb3f42 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -58,6 +58,27 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
}
}
+ test("SPARK-13242: case-when expression with large number of branches (or cases)") {
+ val cases = 50
+ val clauses = 20
+
+ // Generate an individual case
+ def generateCase(n: Int): (Expression, Expression) = {
+ val condition = (1 to clauses)
+ .map(c => EqualTo(BoundReference(0, StringType, false), Literal(s"$c:$n")))
+ .reduceLeft[Expression]((l, r) => Or(l, r))
+ (condition, Literal(n))
+ }
+
+ val expression = CaseWhen((1 to cases).map(generateCase(_)))
+
+ val plan = GenerateMutableProjection.generate(Seq(expression))()
+ val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}")))
+ val actual = plan(input).toSeq(Seq(expression.dataType))
+
+ assert(actual(0) == cases)
+ }
+
test("test generated safe and unsafe projection") {
val schema = new StructType(Array(
StructField("a", StringType, true),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 45578d50bf..dd831e60cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -416,6 +416,7 @@ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Ru
private def supportCodegen(e: Expression): Boolean = e match {
case e: LeafExpression => true
+ case e: CaseWhen => e.shouldCodegen
// CodegenFallback requires the input to be an InternalRow
case e: CodegenFallback => false
case _ => true