diff options
author | Tarek Auel <tarek.auel@googlemail.com> | 2015-07-20 09:35:45 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-20 09:35:45 -0700 |
commit | 5112b7f58b9b8031ff79b9184dafe12b71ba1f79 (patch) | |
tree | ef54633ab56f0d2ef145b5483658fd69d0e5cb03 /sql/catalyst | |
parent | d0b4e93f7e92ea59058cc457a5586a4d9a596d71 (diff) | |
download | spark-5112b7f58b9b8031ff79b9184dafe12b71ba1f79.tar.gz spark-5112b7f58b9b8031ff79b9184dafe12b71ba1f79.tar.bz2 spark-5112b7f58b9b8031ff79b9184dafe12b71ba1f79.zip |
[SPARK-9153][SQL] codegen StringLPad/StringRPad
Jira: https://issues.apache.org/jira/browse/SPARK-9153
Author: Tarek Auel <tarek.auel@googlemail.com>
Closes #7527 from tarekauel/SPARK-9153 and squashes the following commits:
3840c6b [Tarek Auel] [SPARK-9153] removed codegen fallback
92b6a5d [Tarek Auel] [SPARK-9153] codegen lpad/rpad
Diffstat (limited to 'sql/catalyst')
2 files changed, 58 insertions, 2 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 5f8ac716f7..6608036f01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -401,7 +401,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. */ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -432,6 +432,31 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) } } + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val lenGen = len.gen(ctx) + val strGen = str.gen(ctx) + val padGen = pad.gen(ctx) + + s""" + ${lenGen.code} + boolean ${ev.isNull} = ${lenGen.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${strGen.code} + if (!${strGen.isNull}) { + ${padGen.code} + if (!${padGen.isNull}) { + ${ev.primitive} = ${strGen.primitive}.lpad(${lenGen.primitive}, ${padGen.primitive}); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } + override def prettyName: String = "lpad" } @@ -439,7 +464,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. */ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends Expression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil override def foldable: Boolean = children.forall(_.foldable) @@ -470,6 +495,31 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) } } + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val lenGen = len.gen(ctx) + val strGen = str.gen(ctx) + val padGen = pad.gen(ctx) + + s""" + ${lenGen.code} + boolean ${ev.isNull} = ${lenGen.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${ev.isNull}) { + ${strGen.code} + if (!${strGen.isNull}) { + ${padGen.code} + if (!${padGen.isNull}) { + ${ev.primitive} = ${strGen.primitive}.rpad(${lenGen.primitive}, ${padGen.primitive}); + } else { + ${ev.isNull} = true; + } + } else { + ${ev.isNull} = true; + } + } + """ + } + override def prettyName: String = "rpad" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 96f433be8b..d5731229df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -413,18 +413,24 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val row1 = create_row("hi", 5, "??") val row2 = create_row("hi", 1, "?") val row3 = create_row(null, 1, "?") + val row4 = create_row("hi", null, "?") + val row5 = create_row("hi", 1, null) checkEvaluation(StringLPad(Literal("hi"), Literal(5), Literal("??")), "???hi", row1) checkEvaluation(StringLPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) checkEvaluation(StringLPad(s1, s2, s3), "???hi", row1) checkEvaluation(StringLPad(s1, s2, s3), "h", row2) checkEvaluation(StringLPad(s1, s2, s3), null, row3) + checkEvaluation(StringLPad(s1, s2, s3), null, row4) + checkEvaluation(StringLPad(s1, s2, s3), null, row5) checkEvaluation(StringRPad(Literal("hi"), Literal(5), Literal("??")), "hi???", row1) checkEvaluation(StringRPad(Literal("hi"), Literal(1), Literal("??")), "h", row1) checkEvaluation(StringRPad(s1, s2, s3), "hi???", row1) checkEvaluation(StringRPad(s1, s2, s3), "h", row2) checkEvaluation(StringRPad(s1, s2, s3), null, row3) + checkEvaluation(StringRPad(s1, s2, s3), null, row4) + checkEvaluation(StringRPad(s1, s2, s3), null, row5) } test("REPEAT") { |