aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala | 27
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 37
2 files changed, 58 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 127797c097..6c75a7a502 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -63,15 +63,30 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val childrenGen = children.map(_.genCode(ctx))
- val childrenVars = childrenGen.zip(children).map {
- case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
- }
+ val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
+ case (childGen, child) =>
+ // SPARK-18125: The children vars are local variables. If the result expression uses
+ // splitExpression, those variables cannot be accessed so compilation fails.
+ // To fix it, we use class variables to hold those local variables.
+ val classChildVarName = ctx.freshName("classChildVar")
+ val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
+ ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
+ ctx.addMutableState("boolean", classChildVarIsNull, "")
+
+ val classChildVar =
+ LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)
+
+ val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
+ s"${classChildVar.isNull} = ${childGen.isNull};"
+
+ (classChildVar, initCode)
+ }.unzip
val resultGen = result.transform {
- case b: BoundReference => childrenVars(b.ordinal)
+ case b: BoundReference => classChildrenVars(b.ordinal)
}.genCode(ctx)
- ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code,
- isNull = resultGen.isNull, value = resultGen.value)
+ ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
+ resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6fa7b04877..a8dd422aa0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -923,6 +923,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
.groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
}
+ test("SPARK-18125: Spark generated code causes CompileException") {
+ val data = Array(
+ Route("a", "b", 1),
+ Route("a", "b", 2),
+ Route("a", "c", 2),
+ Route("a", "d", 10),
+ Route("b", "a", 1),
+ Route("b", "a", 5),
+ Route("b", "c", 6))
+ val ds = sparkContext.parallelize(data).toDF.as[Route]
+
+ val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
+ .groupByKey(r => (r.src, r.dest))
+ .reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) =>
+ GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes)
+ }.map(_._2)
+
+ val expected = Seq(
+ GroupedRoutes("a", "d", Seq(Route("a", "d", 10))),
+ GroupedRoutes("b", "c", Seq(Route("b", "c", 6))),
+ GroupedRoutes("a", "b", Seq(Route("a", "b", 1), Route("a", "b", 2))),
+ GroupedRoutes("b", "a", Seq(Route("b", "a", 1), Route("b", "a", 5))),
+ GroupedRoutes("a", "c", Seq(Route("a", "c", 2)))
+ )
+
+ implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] = new Ordering[GroupedRoutes] {
+ override def compare(x: GroupedRoutes, y: GroupedRoutes): Int = {
+ x.toString.compareTo(y.toString)
+ }
+ }
+
+ checkDatasetUnorderly(grped, expected: _*)
+ }
+
test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") {
val resultValue = 12345
val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1)
@@ -1071,3 +1105,6 @@ object DatasetTransform {
ds.map(_ + 1)
}
}
+
+case class Route(src: String, dest: String, cost: Int)
+case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])