author    Tarek Auel <tarek.auel@googlemail.com>     2015-07-21 15:47:40 -0700
committer Michael Armbrust <michael@databricks.com>  2015-07-21 15:47:40 -0700
commit    d4c7a7a3642a74ad40093c96c4bf45a62a470605 (patch)
tree      386e4a13c97095fba35e3be45f8105d74ccc6701
parent    c07838b5a9cdf96c0f49055ea1c397e0f0e915d2 (diff)
[SPARK-9154] [SQL] codegen StringFormat
Jira: https://issues.apache.org/jira/browse/SPARK-9154

Fixes a bug in #7546. marmbrus I can't reopen the other PR because I didn't close it. Can you trigger Jenkins?

Author: Tarek Auel <tarek.auel@googlemail.com>

Closes #7571 from tarekauel/SPARK-9154 and squashes the following commits:

dcae272 [Tarek Auel] [SPARK-9154][SQL] build fix
1487602 [Tarek Auel] Merge remote-tracking branch 'upstream/master' into SPARK-9154
f512c5f [Tarek Auel] [SPARK-9154][SQL] build fix
a943d3e [Tarek Auel] [SPARK-9154] implicit input cast, added tests for null, support for null primitives
10b4de8 [Tarek Auel] [SPARK-9154][SQL] codegen removed fallback trait
cd8322b [Tarek Auel] [SPARK-9154][SQL] codegen string format
086caba [Tarek Auel] [SPARK-9154][SQL] codegen string format
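A minimal usage sketch (not part of the patch) of the printf-style formatting this change covers, assuming a SQLContext whose implicits are in scope so that toDF and $"..." work:

    import org.apache.spark.sql.functions._

    val df = Seq(("aa%d%s", 123, "cc")).toDF("a", "b", "c")
    df.select(formatString($"a", $"b", $"c")).collect()   // Array(Row("aa123cc"))
    df.selectExpr("printf(a, b, c)").collect()            // same result via the SQL alias printf

The tests added below exercise the same behaviour, including null arguments, which are rendered as the string "null".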
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala       | 42
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala | 18
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/functions.scala                                       | 11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala                            | 10
4 files changed, 70 insertions(+), 11 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index fe57d17f1e..1f18a6e9ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -526,7 +526,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression)
/**
* Returns the input formatted according to printf-style format strings
*/
-case class StringFormat(children: Expression*) extends Expression with CodegenFallback {
+case class StringFormat(children: Expression*) extends Expression with ImplicitCastInputTypes {
require(children.nonEmpty, "printf() should take at least 1 argument")
@@ -536,6 +536,10 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
private def format: Expression = children(0)
private def args: Seq[Expression] = children.tail
+ override def inputTypes: Seq[AbstractDataType] =
+ StringType :: List.fill(children.size - 1)(AnyDataType)
+
+
override def eval(input: InternalRow): Any = {
val pattern = format.eval(input)
if (pattern == null) {
@@ -551,6 +555,42 @@ case class StringFormat(children: Expression*) extends Expression with CodegenFa
}
}
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val pattern = children.head.gen(ctx)
+
+ val argListGen = children.tail.map(x => (x.dataType, x.gen(ctx)))
+ val argListCode = argListGen.map(_._2.code + "\n")
+
+ val argListString = argListGen.foldLeft("")((s, v) => {
+ val nullSafeString =
+ if (ctx.boxedType(v._1) != ctx.javaType(v._1)) {
+ // Java primitives get boxed in order to allow null values.
+ s"(${v._2.isNull}) ? (${ctx.boxedType(v._1)}) null : " +
+ s"new ${ctx.boxedType(v._1)}(${v._2.primitive})"
+ } else {
+ s"(${v._2.isNull}) ? null : ${v._2.primitive}"
+ }
+ s + "," + nullSafeString
+ })
+
+ val form = ctx.freshName("formatter")
+ val formatter = classOf[java.util.Formatter].getName
+ val sb = ctx.freshName("sb")
+ val stringBuffer = classOf[StringBuffer].getName
+ s"""
+ ${pattern.code}
+ boolean ${ev.isNull} = ${pattern.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${ev.isNull}) {
+ ${argListCode.mkString}
+ $stringBuffer $sb = new $stringBuffer();
+ $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US);
+ $form.format(${pattern.primitive}.toString() $argListString);
+ ${ev.primitive} = UTF8String.fromString($sb.toString());
+ }
+ """
+ }
+
override def prettyName: String = "printf"
}
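For context, a small standalone sketch (not from the patch) of the java.util.Formatter pattern that the generated code above relies on: primitive arguments are boxed so that nulls can flow through, and Formatter renders a null argument as the string "null":

    import java.util.{Formatter, Locale}

    val sb = new StringBuffer()
    val formatter = new Formatter(sb, Locale.US)
    // one boxed Integer and one null argument, mirroring the null-safe argument list built in genCode
    formatter.format("aa%d%s", Integer.valueOf(12), null)
    println(sb.toString)  // aa12null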
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 96c540ab36..3c2d88731b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -351,18 +351,16 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
}
test("FORMAT") {
- val f = 'f.string.at(0)
- val d1 = 'd.int.at(1)
- val s1 = 's.int.at(2)
-
- val row1 = create_row("aa%d%s", 12, "cc")
- val row2 = create_row(null, 12, "cc")
- checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+ checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
checkEvaluation(StringFormat(Literal("aa")), "aa", create_row(null))
- checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a", row1)
+ checkEvaluation(StringFormat(Literal("aa%d%s"), Literal(123), Literal("a")), "aa123a")
+ checkEvaluation(StringFormat(Literal("aa%d%s"), 12, "cc"), "aa12cc")
- checkEvaluation(StringFormat(f, d1, s1), "aa12cc", row1)
- checkEvaluation(StringFormat(f, d1, s1), null, row2)
+ checkEvaluation(StringFormat(Literal.create(null, StringType), 12, "cc"), null)
+ checkEvaluation(
+ StringFormat(Literal("aa%d%s"), Literal.create(null, IntegerType), "cc"), "aanullcc")
+ checkEvaluation(
+ StringFormat(Literal("aa%d%s"), 12, Literal.create(null, StringType)), "aa12null")
}
test("INSTR") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index d94d733582..e5ff8ae7e3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1743,6 +1743,17 @@ object functions {
/**
* Format strings in printf-style.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def formatString(format: Column, arguments: Column*): Column = {
+ StringFormat((format +: arguments).map(_.expr): _*)
+ }
+
+ /**
+ * Format strings in printf-style.
* NOTE: `format` is the string value of the formatter, not column name.
*
* @group string_funcs
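As a brief illustration (not part of the patch): the new overload takes the format itself as a Column, while the string-based formatString overload exercised in the updated test below takes the format as a literal string and the remaining arguments as column names.

    // assumes a DataFrame df with columns a (a format string), b and c, and $ from implicits
    df.select(
      formatString($"a", $"b", $"c"),    // format read from column a
      formatString("aa%d%s", "b", "c"))  // format given as a string literal, b and c as column names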
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index d1f855903c..3702e73b4e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -132,6 +132,16 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(
df.selectExpr("printf(a, b, c)"),
Row("aa123cc"))
+
+ val df2 = Seq(("aa%d%s".getBytes, 123, "cc")).toDF("a", "b", "c")
+
+ checkAnswer(
+ df2.select(formatString($"a", $"b", $"c"), formatString("aa%d%s", "b", "c")),
+ Row("aa123cc", "aa123cc"))
+
+ checkAnswer(
+ df2.selectExpr("printf(a, b, c)"),
+ Row("aa123cc"))
}
test("string instr function") {