diff options
2 files changed, 90 insertions, 13 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
index afc190e697..bacedec1ae 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala
@@ -64,19 +64,75 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi
val trueEval = trueValue.genCode(ctx)
val falseEval = falseValue.genCode(ctx)
- ev.copy(code = s"""
- ${condEval.code}
- boolean ${ev.isNull} = false;
- ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${condEval.isNull} && ${condEval.value}) {
- ${trueEval.code}
- ${ev.isNull} = ${trueEval.isNull};
- ${ev.value} = ${trueEval.value};
- } else {
- ${falseEval.code}
- ${ev.isNull} = ${falseEval.isNull};
- ${ev.value} = ${falseEval.value};
- }""")
+ // place generated code of condition, true value and false value in separate methods if
+ // their code combined is large
+ val combinedLength = condEval.code.length + trueEval.code.length + falseEval.code.length
+ val generatedCode = if (combinedLength > 1024 &&
+ // Split these expressions only if they are created from a row object
+ (ctx.INPUT_ROW != null && ctx.currentVars == null)) {
+ val (condFuncName, condGlobalIsNull, condGlobalValue) =
+ createAndAddFunction(ctx, condEval, predicate.dataType, "evalIfCondExpr")
+ val (trueFuncName, trueGlobalIsNull, trueGlobalValue) =
+ createAndAddFunction(ctx, trueEval, trueValue.dataType, "evalIfTrueExpr")
+ val (falseFuncName, falseGlobalIsNull, falseGlobalValue) =
+ createAndAddFunction(ctx, falseEval, falseValue.dataType, "evalIfFalseExpr")
+ s"""
+ $condFuncName(${ctx.INPUT_ROW});
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!$condGlobalIsNull && $condGlobalValue) {
+ $trueFuncName(${ctx.INPUT_ROW});
+ ${ev.isNull} = $trueGlobalIsNull;
+ ${ev.value} = $trueGlobalValue;
+ } else {
+ $falseFuncName(${ctx.INPUT_ROW});
+ ${ev.isNull} = $falseGlobalIsNull;
+ ${ev.value} = $falseGlobalValue;
+ }
+ """
+ }
+ else {
+ s"""
+ ${condEval.code}
+ boolean ${ev.isNull} = false;
+ ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+ if (!${condEval.isNull} && ${condEval.value}) {
+ ${trueEval.code}
+ ${ev.isNull} = ${trueEval.isNull};
+ ${ev.value} = ${trueEval.value};
+ } else {
+ ${falseEval.code}
+ ${ev.isNull} = ${falseEval.isNull};
+ ${ev.value} = ${falseEval.value};
+ }
+ """
+ }
+ ev.copy(code = generatedCode)
+ }
+ private def createAndAddFunction( // Wraps already-generated expression code in its own generated method that stores the result in two mutable-state variables registered on ctx.
+ ctx: CodegenContext, // receives the two new mutable-state variables and the new function
+ ev: ExprCode, // generated code to wrap; its code, isNull and value are embedded in the function body
+ dataType: DataType, // selects the Java type and default value of the value variable
+ baseFuncName: String): (String, String, String) = { // baseFuncName seeds ctx.freshName; result is (function name, isNull variable name, value variable name)
+ val globalIsNull = ctx.freshName("isNull") // name of the variable that will carry the null flag out of the function
+ ctx.addMutableState("boolean", globalIsNull, s"$globalIsNull = false;") // registered as a boolean with init code "<name> = false;"
+ val globalValue = ctx.freshName("value") // name of the variable that will carry the computed value out of the function
+ ctx.addMutableState(ctx.javaType(dataType), globalValue, // typed from dataType; init code on the next line assigns the type's default value
+ s"$globalValue = ${ctx.defaultValue(dataType)};")
+ val funcName = ctx.freshName(baseFuncName)
+ val funcBody = // Java source of a private void method: run ev.code, then copy ev.isNull / ev.value into the two variables above
+ s"""
+ |private void $funcName(InternalRow ${ctx.INPUT_ROW}) {
+ | ${ev.code.trim}
+ | $globalIsNull = ${ev.isNull};
+ | $globalValue = ${ev.value};
+ |}
+ """.stripMargin // NOTE(review): the parameter is named ctx.INPUT_ROW, presumably so row references inside ev.code resolve to it — confirm
+ ctx.addNewFunction(funcName, funcBody) // hands the method source to ctx under funcName
+ (funcName, globalIsNull, globalValue) // the void method returns nothing, so its result is read from the two variables
override def toString: String = s"if ($predicate) $trueValue else $falseValue"
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 0cb201e4da..0f4b4b5bc8 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -97,6 +97,27 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
assert(actual(0) == cases)
+ test("SPARK-18091: split large if expressions into blocks due to JVM code size limit") { // builds a deeply nested If over one string column and checks the generated projection returns that string
+ val inStr = "StringForTesting"
+ val row = create_row(inStr) // single-column input row holding inStr
+ val inputStrAttr = 'a.string.at(0) // NOTE(review): presumably binds string attribute `a` to ordinal 0 of the row — confirm against the expression DSL
+ var strExpr: Expression = inputStrAttr // accumulator: replaced by a larger expression on every loop pass
+ for (_ <- 1 to 13) { // each pass embeds the previous strExpr three times (condition, true branch, false branch), so the tree roughly triples
+ strExpr = If(EqualTo(Decode(Encode(strExpr, "utf-8"), "utf-8"), inputStrAttr), // condition: utf-8 encode/decode round trip of strExpr compared with the input attribute
+ strExpr, strExpr) // both branches are the same expression, so the value does not depend on the condition
+ }
+ val expressions = Seq(strExpr)
+ val plan = GenerateUnsafeProjection.generate(expressions, true) // NOTE(review): the boolean is presumably the subexpression-elimination flag — confirm against GenerateUnsafeProjection.generate
+ val actual = plan(row).toSeq(expressions.map(_.dataType)) // evaluate the projection on the row and read the result back using the expressions' data types
+ val expected = Seq(UTF8String.fromString(inStr)) // every nesting level ultimately yields the input attribute
+ if (!checkResult(actual, expected)) {
+ fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
+ }
+ }
test("SPARK-14793: split wide array creation into blocks due to JVM code size limit") {
val length = 5000
val expressions = Seq(CreateArray(List.fill(length)(EqualTo(Literal(1), Literal(1)))))