aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2017-01-09 14:25:38 -0800
committerJosh Rosen <joshrosen@databricks.com>2017-01-09 14:25:38 -0800
commitfaabe69cc081145f43f9c68db1a7a8c5c39684fb (patch)
tree4b064365607e2bcb8b0d8d5160e65a7dce0c45ce
parent15c2bd01b03b1a07f10779f68118cd28f2c62c9a (diff)
downloadspark-faabe69cc081145f43f9c68db1a7a8c5c39684fb.tar.gz
spark-faabe69cc081145f43f9c68db1a7a8c5c39684fb.tar.bz2
spark-faabe69cc081145f43f9c68db1a7a8c5c39684fb.zip
[SPARK-18952] Regex strings not properly escaped in codegen for aggregations
## What changes were proposed in this pull request? If I use the function regexp_extract, and then in my regex string, use `\`, i.e. escape character, this fails codegen, because the `\` character is not properly escaped when codegen'd. Example stack trace: ``` /* 059 */ private int maxSteps = 2; /* 060 */ private int numRows = 0; /* 061 */ private org.apache.spark.sql.types.StructType keySchema = new org.apache.spark.sql.types.StructType().add("date_format(window#325.start, yyyy-MM-dd HH:mm)", org.apache.spark.sql.types.DataTypes.StringType) /* 062 */ .add("regexp_extract(source#310.description, ([a-zA-Z]+)\[.*, 1)", org.apache.spark.sql.types.DataTypes.StringType); /* 063 */ private org.apache.spark.sql.types.StructType valueSchema = new org.apache.spark.sql.types.StructType().add("sum", org.apache.spark.sql.types.DataTypes.LongType); /* 064 */ private Object emptyVBase; ... org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 62, Column 58: Invalid escape sequence at org.codehaus.janino.Scanner.scanLiteralCharacter(Scanner.java:918) at org.codehaus.janino.Scanner.produce(Scanner.java:604) at org.codehaus.janino.Parser.peekRead(Parser.java:3239) at org.codehaus.janino.Parser.parseArguments(Parser.java:3055) at org.codehaus.janino.Parser.parseSelector(Parser.java:2914) at org.codehaus.janino.Parser.parseUnaryExpression(Parser.java:2617) at org.codehaus.janino.Parser.parseMultiplicativeExpression(Parser.java:2573) at org.codehaus.janino.Parser.parseAdditiveExpression(Parser.java:2552) ``` In the codegend expression, the literal should use `\\` instead of `\` A similar problem was solved here: https://github.com/apache/spark/pull/15156. ## How was this patch tested? Regression test in `DataFrameAggregationSuite` Author: Burak Yavuz <brkyvz@gmail.com> Closes #16361 from brkyvz/reg-break.
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala12
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala9
3 files changed, 23 insertions, 10 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
index a77e178546..9316ebcdf1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/RowBasedHashMapGenerator.scala
@@ -43,28 +43,30 @@ class RowBasedHashMapGenerator(
extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
groupingKeySchema, bufferSchema) {
- protected def initializeAggregateHashMap(): String = {
+ override protected def initializeAggregateHashMap(): String = {
val generatedKeySchema: String =
s"new org.apache.spark.sql.types.StructType()" +
groupingKeySchema.map { key =>
+ val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
val generatedValueSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
+ val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
index 7418df90b8..0c40417db0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala
@@ -48,28 +48,30 @@ class VectorizedHashMapGenerator(
extends HashMapGenerator (ctx, aggregateExpressions, generatedClassName,
groupingKeySchema, bufferSchema) {
- protected def initializeAggregateHashMap(): String = {
+ override protected def initializeAggregateHashMap(): String = {
val generatedSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
(groupingKeySchema ++ bufferSchema).map { key =>
+ val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
val generatedAggBufferSchema: String =
s"new org.apache.spark.sql.types.StructType()" +
bufferSchema.map { key =>
+ val keyName = ctx.addReferenceMinorObj(key.name)
key.dataType match {
case d: DecimalType =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.createDecimalType(
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.createDecimalType(
|${d.precision}, ${d.scale}))""".stripMargin
case _ =>
- s""".add("${key.name}", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
+ s""".add("$keyName", org.apache.spark.sql.types.DataTypes.${key.dataType})"""
}
}.mkString("\n").concat(";")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 645175900f..7853b22fec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -97,6 +97,15 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
)
}
+ test("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") {
+ val df = Seq(("some[thing]", "random-string")).toDF("key", "val")
+
+ checkAnswer(
+ df.groupBy(regexp_extract('key, "([a-z]+)\\[", 1)).count(),
+ Row("some", 1) :: Nil
+ )
+ }
+
test("rollup") {
checkAnswer(
courseSales.rollup("course", "year").sum("earnings"),