aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-01-31 22:43:03 -0800
committerDavies Liu <davies.liu@gmail.com>2016-01-31 22:43:03 -0800
commitc1da4d421ab78772ffa52ad46e5bdfb4e5268f47 (patch)
treeda4be6e0c6142592d845daa5bc0928b7000c3769
parent5a8b978fabb60aa178274f86432c63680c8b351a (diff)
downloadspark-c1da4d421ab78772ffa52ad46e5bdfb4e5268f47.tar.gz
spark-c1da4d421ab78772ffa52ad46e5bdfb4e5268f47.tar.bz2
spark-c1da4d421ab78772ffa52ad46e5bdfb4e5268f47.zip
[SPARK-13093] [SQL] improve null check in nullSafeCodeGen for unary, binary and ternary expression
The current implementation is sub-optimal in two ways: (1) if an expression is always nullable, e.g. `Unhex`, we can still remove the null check for its children when they are not nullable; (2) if an expression has some non-nullable children, we can still remove the null check for those children and keep the null check for the others. This PR improves this by making the null check elimination more fine-grained. Author: Wenchen Fan <wenchen@databricks.com>. Closes #10987 from cloud-fan/null-check.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala104
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala32
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala16
3 files changed, 85 insertions, 67 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index db17ba7c84..353fb92581 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -320,7 +320,7 @@ abstract class UnaryExpression extends Expression {
/**
* Called by unary expressions to generate a code block that returns null if its parent returns
- * null, and if not not null, use `f` to generate the expression.
+ * null, and if not null, use `f` to generate the expression.
*
* As an example, the following does a boolean inversion (i.e. NOT).
* {{{
@@ -340,7 +340,7 @@ abstract class UnaryExpression extends Expression {
/**
* Called by unary expressions to generate a code block that returns null if its parent returns
- * null, and if not not null, use `f` to generate the expression.
+ * null, and if not null, use `f` to generate the expression.
*
* @param f function that accepts the non-null evaluation result name of child and returns Java
* code to compute the output.
@@ -349,20 +349,23 @@ abstract class UnaryExpression extends Expression {
ctx: CodegenContext,
ev: ExprCode,
f: String => String): String = {
- val eval = child.gen(ctx)
+ val childGen = child.gen(ctx)
+ val resultCode = f(childGen.value)
+
if (nullable) {
- eval.code + s"""
- boolean ${ev.isNull} = ${eval.isNull};
+ val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode)
+ s"""
+ ${childGen.code}
+ boolean ${ev.isNull} = ${childGen.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${eval.isNull}) {
- ${f(eval.value)}
- }
+ $nullSafeEval
"""
} else {
ev.isNull = "false"
- eval.code + s"""
+ s"""
+ ${childGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- ${f(eval.value)}
+ $resultCode
"""
}
}
@@ -440,29 +443,31 @@ abstract class BinaryExpression extends Expression {
ctx: CodegenContext,
ev: ExprCode,
f: (String, String) => String): String = {
- val eval1 = left.gen(ctx)
- val eval2 = right.gen(ctx)
- val resultCode = f(eval1.value, eval2.value)
+ val leftGen = left.gen(ctx)
+ val rightGen = right.gen(ctx)
+ val resultCode = f(leftGen.value, rightGen.value)
+
if (nullable) {
+ val nullSafeEval =
+ leftGen.code + ctx.nullSafeExec(left.nullable, leftGen.isNull) {
+ rightGen.code + ctx.nullSafeExec(right.nullable, rightGen.isNull) {
+ s"""
+ ${ev.isNull} = false; // resultCode could change nullability.
+ $resultCode
+ """
+ }
+ }
+
s"""
- ${eval1.code}
- boolean ${ev.isNull} = ${eval1.isNull};
+ boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${eval2.code}
- if (!${eval2.isNull}) {
- $resultCode
- } else {
- ${ev.isNull} = true;
- }
- }
+ $nullSafeEval
"""
-
} else {
ev.isNull = "false"
s"""
- ${eval1.code}
- ${eval2.code}
+ ${leftGen.code}
+ ${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode
"""
@@ -527,7 +532,7 @@ abstract class TernaryExpression extends Expression {
/**
* Default behavior of evaluation according to the default nullability of TernaryExpression.
- * If subclass of BinaryExpression override nullable, probably should also override this.
+ * If subclass of TernaryExpression override nullable, probably should also override this.
*/
override def eval(input: InternalRow): Any = {
val exprs = children
@@ -553,11 +558,11 @@ abstract class TernaryExpression extends Expression {
sys.error(s"BinaryExpressions must override either eval or nullSafeEval")
/**
- * Short hand for generating binary evaluation code.
+ * Short hand for generating ternary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
- * @param f accepts two variable names and returns Java code to compute the output.
+ * @param f accepts three variable names and returns Java code to compute the output.
*/
protected def defineCodeGen(
ctx: CodegenContext,
@@ -569,41 +574,46 @@ abstract class TernaryExpression extends Expression {
}
/**
- * Short hand for generating binary evaluation code.
+ * Short hand for generating ternary evaluation code.
* If either of the sub-expressions is null, the result of this computation
* is assumed to be null.
*
- * @param f function that accepts the 2 non-null evaluation result names of children
+ * @param f function that accepts the 3 non-null evaluation result names of children
* and returns Java code to compute the output.
*/
protected def nullSafeCodeGen(
ctx: CodegenContext,
ev: ExprCode,
f: (String, String, String) => String): String = {
- val evals = children.map(_.gen(ctx))
- val resultCode = f(evals(0).value, evals(1).value, evals(2).value)
+ val leftGen = children(0).gen(ctx)
+ val midGen = children(1).gen(ctx)
+ val rightGen = children(2).gen(ctx)
+ val resultCode = f(leftGen.value, midGen.value, rightGen.value)
+
if (nullable) {
+ val nullSafeEval =
+ leftGen.code + ctx.nullSafeExec(children(0).nullable, leftGen.isNull) {
+ midGen.code + ctx.nullSafeExec(children(1).nullable, midGen.isNull) {
+ rightGen.code + ctx.nullSafeExec(children(2).nullable, rightGen.isNull) {
+ s"""
+ ${ev.isNull} = false; // resultCode could change nullability.
+ $resultCode
+ """
+ }
+ }
+ }
+
s"""
- ${evals(0).code}
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${evals(0).isNull}) {
- ${evals(1).code}
- if (!${evals(1).isNull}) {
- ${evals(2).code}
- if (!${evals(2).isNull}) {
- ${ev.isNull} = false; // resultCode could change nullability
- $resultCode
- }
- }
- }
+ $nullSafeEval
"""
} else {
ev.isNull = "false"
s"""
- ${evals(0).code}
- ${evals(1).code}
- ${evals(2).code}
+ ${leftGen.code}
+ ${midGen.code}
+ ${rightGen.code}
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
$resultCode
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 21f9198073..a30aba1617 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -402,18 +402,38 @@ class CodegenContext {
}
/**
- * Generates code for greater of two expressions.
- *
- * @param dataType data type of the expressions
- * @param c1 name of the variable of expression 1's output
- * @param c2 name of the variable of expression 2's output
- */
+ * Generates code for greater of two expressions.
+ *
+ * @param dataType data type of the expressions
+ * @param c1 name of the variable of expression 1's output
+ * @param c2 name of the variable of expression 2's output
+ */
def genGreater(dataType: DataType, c1: String, c2: String): String = javaType(dataType) match {
case JAVA_BYTE | JAVA_SHORT | JAVA_INT | JAVA_LONG => s"$c1 > $c2"
case _ => s"(${genComp(dataType, c1, c2)}) > 0"
}
/**
+ * Generates code for null safe execution: `execute` is wrapped in an `if (!isNull)` guard when
+ * `nullable` is true, and is emitted unguarded (prefixed with a newline) otherwise.
+ *
+ * @param nullable decides whether the null check is added; if false, `execute` is not guarded.
+ * @param isNull the code (a boolean expression) that checks if the input is null.
+ * @param execute the code that should only be executed when the input is not null.
+ */
+ def nullSafeExec(nullable: Boolean, isNull: String)(execute: String): String = {
+ if (nullable) {
+ s"""
+ if (!$isNull) {
+ $execute
+ }
+ """
+ } else {
+ "\n" + execute
+ }
+ }
+
+ /**
* List of java data types that have special accessors and setters in [[InternalRow]].
*/
val primitiveTypes =
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 8480c3f9a1..36e1fa1176 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -327,7 +327,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
ev.isNull = "false"
val childrenHash = children.map { child =>
val childGen = child.gen(ctx)
- childGen.code + generateNullCheck(child.nullable, childGen.isNull) {
+ childGen.code + ctx.nullSafeExec(child.nullable, childGen.isNull) {
computeHash(childGen.value, child.dataType, ev.value, ctx)
}
}.mkString("\n")
@@ -338,18 +338,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
"""
}
- private def generateNullCheck(nullable: Boolean, isNull: String)(execution: String): String = {
- if (nullable) {
- s"""
- if (!$isNull) {
- $execution
- }
- """
- } else {
- "\n" + execution
- }
- }
-
private def nullSafeElementHash(
input: String,
index: String,
@@ -359,7 +347,7 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
ctx: CodegenContext): String = {
val element = ctx.freshName("element")
- generateNullCheck(nullable, s"$input.isNullAt($index)") {
+ ctx.nullSafeExec(nullable, s"$input.isNullAt($index)") {
s"""
final ${ctx.javaType(elementType)} $element = ${ctx.getValue(input, elementType, index)};
${computeHash(element, elementType, result, ctx)}