-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala | 27 |
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala | 37 |
2 files changed, 58 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 127797c097..6c75a7a502 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -63,15 +63,30 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
 
   override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
     val childrenGen = children.map(_.genCode(ctx))
-    val childrenVars = childrenGen.zip(children).map {
-      case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType)
-    }
+    val (classChildrenVars, initClassChildrenVars) = childrenGen.zip(children).map {
+      case (childGen, child) =>
+        // SPARK-18125: The children vars are local variables. If the result expression uses
+        // splitExpression, those variables cannot be accessed so compilation fails.
+        // To fix it, we use class variables to hold those local variables.
+        val classChildVarName = ctx.freshName("classChildVar")
+        val classChildVarIsNull = ctx.freshName("classChildVarIsNull")
+        ctx.addMutableState(ctx.javaType(child.dataType), classChildVarName, "")
+        ctx.addMutableState("boolean", classChildVarIsNull, "")
+
+        val classChildVar =
+          LambdaVariable(classChildVarName, classChildVarIsNull, child.dataType)
+
+        val initCode = s"${classChildVar.value} = ${childGen.value};\n" +
+          s"${classChildVar.isNull} = ${childGen.isNull};"
+
+        (classChildVar, initCode)
+    }.unzip
 
     val resultGen = result.transform {
-      case b: BoundReference => childrenVars(b.ordinal)
+      case b: BoundReference => classChildrenVars(b.ordinal)
     }.genCode(ctx)
 
-    ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code,
-      isNull = resultGen.isNull, value = resultGen.value)
+    ExprCode(code = childrenGen.map(_.code).mkString("\n") + initClassChildrenVars.mkString("\n") +
+      resultGen.code, isNull = resultGen.isNull, value = resultGen.value)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 6fa7b04877..a8dd422aa0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -923,6 +923,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
   }
 
+  test("SPARK-18125: Spark generated code causes CompileException") {
+    val data = Array(
+      Route("a", "b", 1),
+      Route("a", "b", 2),
+      Route("a", "c", 2),
+      Route("a", "d", 10),
+      Route("b", "a", 1),
+      Route("b", "a", 5),
+      Route("b", "c", 6))
+    val ds = sparkContext.parallelize(data).toDF.as[Route]
+
+    val grped = ds.map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
+      .groupByKey(r => (r.src, r.dest))
+      .reduceGroups { (g1: GroupedRoutes, g2: GroupedRoutes) =>
+        GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes)
+      }.map(_._2)
+
+    val expected = Seq(
+      GroupedRoutes("a", "d", Seq(Route("a", "d", 10))),
+      GroupedRoutes("b", "c", Seq(Route("b", "c", 6))),
+      GroupedRoutes("a", "b", Seq(Route("a", "b", 1), Route("a", "b", 2))),
+      GroupedRoutes("b", "a", Seq(Route("b", "a", 1), Route("b", "a", 5))),
+      GroupedRoutes("a", "c", Seq(Route("a", "c", 2)))
+    )
+
+    implicit def ordering[GroupedRoutes]: Ordering[GroupedRoutes] = new Ordering[GroupedRoutes] {
+      override def compare(x: GroupedRoutes, y: GroupedRoutes): Int = {
+        x.toString.compareTo(y.toString)
+      }
+    }
+
+    checkDatasetUnorderly(grped, expected: _*)
+  }
+
   test("SPARK-18189: Fix serialization issue in KeyValueGroupedDataset") {
     val resultValue = 12345
     val keyValueGrouped = Seq((1, 2), (3, 4)).toDS().groupByKey(_._1)
@@ -1071,3 +1105,6 @@ object DatasetTransform {
     ds.map(_ + 1)
   }
 }
+
+case class Route(src: String, dest: String, cost: Int)
+case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])
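For context, the following is a minimal standalone sketch (not part of the patch) of the Dataset pattern that the new regression test exercises: merging Route rows into one GroupedRoutes per (src, dest) key via reduceGroups, which is the typed-aggregation path that goes through ReferenceToExpressions code generation. It assumes a local SparkSession and the Spark 2.x Dataset API; the object name and driver setup are illustrative, only Route and GroupedRoutes come from the patch.

import org.apache.spark.sql.SparkSession

// Same shapes as the case classes added to DatasetSuite.scala.
case class Route(src: String, dest: String, cost: Int)
case class GroupedRoutes(src: String, dest: String, routes: Seq[Route])

// Hypothetical driver, for illustration only.
object Spark18125Sketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("SPARK-18125 sketch")
      .getOrCreate()
    import spark.implicits._

    val routes = Seq(
      Route("a", "b", 1), Route("a", "b", 2), Route("a", "c", 2), Route("a", "d", 10),
      Route("b", "a", 1), Route("b", "a", 5), Route("b", "c", 6)).toDS()

    // Merge all routes sharing a (src, dest) key into a single GroupedRoutes value.
    // Per the comment added in ReferenceToExpressions.doGenCode, this kind of typed
    // aggregation could previously fail to compile its generated code when the result
    // expression was split into helper methods that cannot see local variables.
    val grouped = routes
      .map(r => GroupedRoutes(r.src, r.dest, Seq(r)))
      .groupByKey(r => (r.src, r.dest))
      .reduceGroups((g1, g2) => GroupedRoutes(g1.src, g1.dest, g1.routes ++ g2.routes))
      .map(_._2)

    grouped.collect().foreach(println)
    spark.stop()
  }
}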