aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala13
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala13
2 files changed, 26 insertions, 0 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index aa15f4a823..b53c0b5bec 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -727,6 +727,19 @@ object FoldablePropagation extends Rule[LogicalPlan] {
case j @ Join(_, _, LeftOuter | RightOuter | FullOuter, _) =>
stop = true
j
+
+ // These 3 operators take attributes as constructor parameters, and these attributes
+ // can't be replaced by alias.
+ case m: MapGroups =>
+ stop = true
+ m
+ case f: FlatMapGroupsInR =>
+ stop = true
+ f
+ case c: CoGroup =>
+ stop = true
+ c
+
case p: LogicalPlan if !stop => p.transformExpressions {
case a: AttributeReference if foldableMap.contains(a) =>
foldableMap(a)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 88fb1472b6..8ce6ea66b6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -878,6 +878,19 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val ds = spark.createDataset(data)(enc)
checkDataset(ds, (("a", "b"), "c"), (null, "d"))
}
+
+ test("SPARK-16995: flat mapping on Dataset containing a column created with lit/expr") {
+ val df = Seq("1").toDF("a")
+
+ import df.sparkSession.implicits._
+
+ checkDataset(
+ df.withColumn("b", lit(0)).as[ClassData]
+ .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
+ checkDataset(
+ df.withColumn("b", expr("0")).as[ClassData]
+ .groupByKey(_.a).flatMapGroups { case (x, iter) => List[Int]() })
+ }
}
case class Generic[T](id: T, value: Double)