aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTarek Auel <tarek.auel@googlemail.com>2015-07-20 23:33:07 -0700
committerReynold Xin <rxin@databricks.com>2015-07-20 23:33:07 -0700
commit1ddd0f2f1688560f88470e312b72af04364e2d49 (patch)
tree14c4a5f5d58d3f20b606930283cb4bc8b46d0d73
parent228ab65a4eeef8a42eb4713edf72b50590f63176 (diff)
downloadspark-1ddd0f2f1688560f88470e312b72af04364e2d49.tar.gz
spark-1ddd0f2f1688560f88470e312b72af04364e2d49.tar.bz2
spark-1ddd0f2f1688560f88470e312b72af04364e2d49.zip
[SPARK-9161][SQL] codegen FormatNumber
Jira https://issues.apache.org/jira/browse/SPARK-9161 Author: Tarek Auel <tarek.auel@googlemail.com> Closes #7545 from tarekauel/SPARK-9161 and squashes the following commits: 21425c8 [Tarek Auel] [SPARK-9161][SQL] codegen FormatNumber
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala68
1 files changed, 54 insertions, 14 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 438215e8e6..92fefe1585 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -902,22 +902,15 @@ case class FormatNumber(x: Expression, d: Expression)
@transient
private val numberFormat: DecimalFormat = new DecimalFormat("")
- override def eval(input: InternalRow): Any = {
- val xObject = x.eval(input)
- if (xObject == null) {
+ override protected def nullSafeEval(xObject: Any, dObject: Any): Any = {
+ val dValue = dObject.asInstanceOf[Int]
+ if (dValue < 0) {
return null
}
- val dObject = d.eval(input)
-
- if (dObject == null || dObject.asInstanceOf[Int] < 0) {
- return null
- }
- val dValue = dObject.asInstanceOf[Int]
-
if (dValue != lastDValue) {
// construct a new DecimalFormat only if a new dValue
- pattern.delete(0, pattern.length())
+ pattern.delete(0, pattern.length)
pattern.append("#,###,###,###,###,###,##0")
// decimal place
@@ -930,9 +923,10 @@ case class FormatNumber(x: Expression, d: Expression)
pattern.append("0")
}
}
- val dFormat = new DecimalFormat(pattern.toString())
- lastDValue = dValue;
- numberFormat.applyPattern(dFormat.toPattern())
+ val dFormat = new DecimalFormat(pattern.toString)
+ lastDValue = dValue
+
+ numberFormat.applyPattern(dFormat.toPattern)
}
x.dataType match {
@@ -947,6 +941,52 @@ case class FormatNumber(x: Expression, d: Expression)
}
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ nullSafeCodeGen(ctx, ev, (num, d) => {
+
+ def typeHelper(p: String): String = {
+ x.dataType match {
+ case _ : DecimalType => s"""$p.toJavaBigDecimal()"""
+ case _ => s"$p"
+ }
+ }
+
+ val sb = classOf[StringBuffer].getName
+ val df = classOf[DecimalFormat].getName
+ val lastDValue = ctx.freshName("lastDValue")
+ val pattern = ctx.freshName("pattern")
+ val numberFormat = ctx.freshName("numberFormat")
+ val i = ctx.freshName("i")
+ val dFormat = ctx.freshName("dFormat")
+ ctx.addMutableState("int", lastDValue, s"$lastDValue = -100;")
+ ctx.addMutableState(sb, pattern, s"$pattern = new $sb();")
+ ctx.addMutableState(df, numberFormat, s"""$numberFormat = new $df("");""")
+
+ s"""
+ if ($d >= 0) {
+ $pattern.delete(0, $pattern.length());
+ if ($d != $lastDValue) {
+ $pattern.append("#,###,###,###,###,###,##0");
+
+ if ($d > 0) {
+ $pattern.append(".");
+ for (int $i = 0; $i < $d; $i++) {
+ $pattern.append("0");
+ }
+ }
+ $df $dFormat = new $df($pattern.toString());
+ $lastDValue = $d;
+ $numberFormat.applyPattern($dFormat.toPattern());
+ ${ev.primitive} = UTF8String.fromString($numberFormat.format(${typeHelper(num)}));
+ }
+ } else {
+ ${ev.primitive} = null;
+ ${ev.isNull} = true;
+ }
+ """
+ })
+ }
+
override def prettyName: String = "format_number"
}