From c1da4d421ab78772ffa52ad46e5bdfb4e5268f47 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 31 Jan 2016 22:43:03 -0800 Subject: [SPARK-13093] [SQL] improve null check in nullSafeCodeGen for unary, binary and ternary expression The current implementation is sub-optimal: * If an expression is always nullable, e.g. `Unhex`, we can still remove null check for children if they are not nullable. * If an expression has some non-nullable children, we can still remove null check for these children and keep null check for others. This PR improves this by making the null check elimination more fine-grained. Author: Wenchen Fan Closes #10987 from cloud-fan/null-check. --- .../sql/catalyst/expressions/Expression.scala | 104 +++++++++++---------- .../expressions/codegen/CodeGenerator.scala | 32 +++++-- .../spark/sql/catalyst/expressions/misc.scala | 16 +--- 3 files changed, 85 insertions(+), 67 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index db17ba7c84..353fb92581 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * As an example, the following does a boolean inversion (i.e. NOT). * {{{ @@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression { /** * Called by unary expressions to generate a code block that returns null if its parent returns - * null, and if not not null, use `f` to generate the expression. + * null, and if not null, use `f` to generate the expression. * * @param f function that accepts the non-null evaluation result name of child and returns Java * code to compute the output. @@ -349,20 +349,23 @@ abstract class UnaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: String => String): String = { - val eval = child.gen(ctx) + val childGen = child.gen(ctx) + val resultCode = f(childGen.value) + if (nullable) { - eval.code + s""" - boolean ${ev.isNull} = ${eval.isNull}; + val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) + s""" + ${childGen.code} + boolean ${ev.isNull} = ${childGen.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${eval.isNull}) { - ${f(eval.value)} - } + $nullSafeEval """ } else { ev.isNull = "false" - eval.code + s""" + s""" + ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${f(eval.value)} + $resultCode """ } } @@ -440,29 +443,31 @@ abstract class BinaryExpression extends Expression { ctx: CodegenContext, ev: ExprCode, f: (String, String) => String): String = { - val eval1 = left.gen(ctx) - val eval2 = right.gen(ctx) - val resultCode = f(eval1.value, eval2.value) + val leftGen = left.gen(ctx) + val rightGen = right.gen(ctx) + val resultCode = f(leftGen.value, rightGen.value) + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) { + rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } + s""" - ${eval1.code} - boolean ${ev.isNull} = ${eval1.isNull}; + boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${eval2.code} - if (!${eval2.isNull}) { - $resultCode - } else { - ${ev.isNull} = true; - } - } + $nullSafeEval """ - } else { ev.isNull = "false" s""" - ${eval1.code} - ${eval2.code} + ${leftGen.code} + ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode """ @@ -527,7 +532,7 @@ abstract class TernaryExpression extends Expression { /** * Default behavior of evaluation according to the default nullability of TernaryExpression. - * If subclass of BinaryExpression override nullable, probably should also override this. + * If subclass of TernaryExpression override nullable, probably should also override this. */ override def eval(input: InternalRow): Any = { val exprs = children @@ -553,11 +558,11 @@ abstract class TernaryExpression extends Expression { sys.error(s"BinaryExpressions must override either eval or nullSafeEval") /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f accepts two variable names and returns Java code to compute the output. + * @param f accepts three variable names and returns Java code to compute the output. */ protected def defineCodeGen( ctx: CodegenContext, @@ -569,41 +574,46 @@ abstract class TernaryExpression extends Expression { } /** - * Short hand for generating binary evaluation code. + * Short hand for generating ternary evaluation code. * If either of the sub-expressions is null, the result of this computation * is assumed to be null. * - * @param f function that accepts the 2 non-null evaluation result names of children + * @param f function that accepts the 3 non-null evaluation result names of children * and returns Java code to compute the output. */ protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, f: (String, String, String) => String): String = { - val evals = children.map(_.gen(ctx)) - val resultCode = f(evals(0).value, evals(1).value, evals(2).value) + val leftGen = children(0).gen(ctx) + val midGen = children(1).gen(ctx) + val rightGen = children(2).gen(ctx) + val resultCode = f(leftGen.value, midGen.value, rightGen.value) + if (nullable) { + val nullSafeEval = + leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) { + midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) { + rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) { + s""" + ${ev.isNull} = false; // resultCode could change nullability. + $resultCode + """ + } + } + } + s""" - ${evals(0).code} boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${evals(0).isNull}) { - ${evals(1).code} - if (!${evals(1).isNull}) { - ${evals(2).code} - if (!${evals(2).isNull}) { - ${ev.isNull} = false; // resultCode could change nullability - $resultCode - } - } - } + $nullSafeEval """ } else { ev.isNull = "false" s""" - ${evals(0).code} - ${evals(1).code} - ${evals(2).code} + ${leftGen.code} + ${midGen.code} + ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $resultCode """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 21f9198073..a30aba1617 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -402,17 +402,37 @@ class CodegenContext { } /** - * Generates code for greater of two expressions. - * - * @param dataType data type of the expressions - * @param c1 name of the variable of expression 1's output - * @param c2 name of the variable of expression 2's output - */ + * Generates code for greater of two expressions. + * + * @param dataType data type of the expressions + * @param c1 name of the variable of expression 1's output + * @param c2 name of the variable of expression 2's output + */ def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match { case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2" case _ => s"(${genComp(dataType, c1, c2)}) > 0" } + /** + * Generates code to do null safe execution, i.e. only execute the code when the input is not + * null by adding null check if necessary. + * + * @param nullable used to decide whether we should add null check or not. + * @param isNull the code to check if the input is null. + * @param execute the code that should only be executed when the input is not null. + */ + def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = { + if (nullable) { + s""" + if (!$isNull) { + $execute + } + """ + } else { + "\n" + execute + } + } + /** * List of java data types that have special accessors and setters in [[InternalRow]]. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8480c3f9a1..36e1fa1176 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ev.isNull = "false" val childrenHash = children.map { child => val childGen = child.gen(ctx) - childGen.code + generateNullCheck(child.nullable, childGen.isNull) { + childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) { computeHash(childGen.value, child.dataType, ev.value, ctx) } }.mkString("\n") @@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression """ } - private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = { - if (nullable) { - s""" - if (!$isNull) { - $execution - } - """ - } else { - "\n" + execution - } - } - private def nullSafeElementHash( input: String, index: String, @@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression ctx: CodegenContext): String = { val element = ctx.freshName("element") - generateNullCheck(nullable, s"$input.isNullAt($index)") { + ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") { s""" final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)}; ${computeHash(element, elementType, result, ctx)} -- cgit v1.2.3