aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-06-30 08:08:15 -0700
committerDavies Liu <davies@databricks.com>2015-06-30 08:08:15 -0700
commit865a834e51ac3074811a11fd99a36d942f7f7de8 (patch)
treedb8186903d6581f71b63552c0b8c2eb66ce42caa /sql
parent08fab4843845136358f3a7251e8d90135126b419 (diff)
downloadspark-865a834e51ac3074811a11fd99a36d942f7f7de8.tar.gz
spark-865a834e51ac3074811a11fd99a36d942f7f7de8.tar.bz2
spark-865a834e51ac3074811a11fd99a36d942f7f7de8.zip
[SPARK-8723] [SQL] improve divide and remainder code gen
We can avoid execution of both left and right expression by null and zero check. Author: Wenchen Fan <cloud0fan@outlook.com> Closes #7111 from cloud-fan/cg and squashes the following commits: d6b12ef [Wenchen Fan] improve divide and remainder code gen
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala54
1 files changed, 36 insertions, 18 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index ae765c1653..5363b35568 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -216,23 +216,32 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
- val test = if (left.dataType.isInstanceOf[DecimalType]) {
+ val isZero = if (dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitive}.isZero()"
} else {
s"${eval2.primitive} == 0"
}
- val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol "
- val javaType = ctx.javaType(left.dataType)
- eval1.code + eval2.code +
- s"""
+ val javaType = ctx.javaType(dataType)
+ val divide = if (dataType.isInstanceOf[DecimalType]) {
+ s"${eval1.primitive}.$decimalMethod(${eval2.primitive})"
+ } else {
+ s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})"
+ }
+ s"""
+ ${eval2.code}
boolean ${ev.isNull} = false;
- ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
- if (${eval1.isNull} || ${eval2.isNull} || $test) {
+ $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)};
+ if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
- ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive}));
+ ${eval1.code}
+ if (${eval1.isNull}) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.primitive} = $divide;
+ }
}
- """
+ """
}
}
@@ -273,23 +282,32 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val eval1 = left.gen(ctx)
val eval2 = right.gen(ctx)
- val test = if (left.dataType.isInstanceOf[DecimalType]) {
+ val isZero = if (dataType.isInstanceOf[DecimalType]) {
s"${eval2.primitive}.isZero()"
} else {
s"${eval2.primitive} == 0"
}
- val method = if (left.dataType.isInstanceOf[DecimalType]) s".$decimalMethod" else s" $symbol "
- val javaType = ctx.javaType(left.dataType)
- eval1.code + eval2.code +
- s"""
+ val javaType = ctx.javaType(dataType)
+ val remainder = if (dataType.isInstanceOf[DecimalType]) {
+ s"${eval1.primitive}.$decimalMethod(${eval2.primitive})"
+ } else {
+ s"($javaType)(${eval1.primitive} $symbol ${eval2.primitive})"
+ }
+ s"""
+ ${eval2.code}
boolean ${ev.isNull} = false;
- ${ctx.javaType(left.dataType)} ${ev.primitive} = ${ctx.defaultValue(left.dataType)};
- if (${eval1.isNull} || ${eval2.isNull} || $test) {
+ $javaType ${ev.primitive} = ${ctx.defaultValue(javaType)};
+ if (${eval2.isNull} || $isZero) {
${ev.isNull} = true;
} else {
- ${ev.primitive} = ($javaType) (${eval1.primitive}$method(${eval2.primitive}));
+ ${eval1.code}
+ if (${eval1.isNull}) {
+ ${ev.isNull} = true;
+ } else {
+ ${ev.primitive} = $remainder;
+ }
}
- """
+ """
}
}