-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala  |  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala                          | 20
3 files changed, 37 insertions(+), 4 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
index 8d8cc152ff..607c7c877c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala
@@ -69,8 +69,17 @@ class EquivalentExpressions {
*/
def addExprTree(root: Expression, ignoreLeaf: Boolean = true): Unit = {
val skip = root.isInstanceOf[LeafExpression] && ignoreLeaf
- // the children of CodegenFallback will not be used to generate code (call eval() instead)
- if (!skip && !addExpr(root) && !root.isInstanceOf[CodegenFallback]) {
+ // There are some special expressions whose children we should not recurse into.
+ // 1. CodegenFallback: its children will not be used to generate code (eval() is called instead)
+ // 2. ReferenceToExpressions: it is effectively an explicit form of sub-expression elimination.
+ val shouldRecurse = root match {
+ // TODO: some expressions implement `CodegenFallback` but can still do codegen,
+ // e.g. `CaseWhen`; we should support them.
+ case _: CodegenFallback => false
+ case _: ReferenceToExpressions => false
+ case _ => true
+ }
+ if (!skip && !addExpr(root) && shouldRecurse) {
root.children.foreach(addExprTree(_, ignoreLeaf))
}
}
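
Side note: a minimal, self-contained sketch of the traversal rule introduced above, using simplified stand-ins (Node, Fallback, Reference, SubExprCollector) for Spark's Expression, CodegenFallback, ReferenceToExpressions and EquivalentExpressions. It only illustrates the "count the node, but don't descend when codegen won't emit its children" pattern and is not the actual implementation:

import scala.collection.mutable

sealed trait Node { def children: Seq[Node] }
case class Plain(children: Node*) extends Node
// children are evaluated via eval(), never compiled
case class Fallback(children: Node*) extends Node
// already an explicit reference to a shared sub-expression
case class Reference(children: Node*) extends Node

class SubExprCollector {
  private val counts = mutable.Map.empty[Node, Int]

  def add(root: Node): Unit = {
    val seenBefore = counts.contains(root)
    counts(root) = counts.getOrElse(root, 0) + 1
    val shouldRecurse = root match {
      case _: Fallback | _: Reference => false
      case _ => true
    }
    // only descend the first time we see a node, and only when its
    // children will actually appear in generated code
    if (!seenBefore && shouldRecurse) root.children.foreach(add)
  }
}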
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 38ac13b208..d29c27c14b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -110,13 +110,17 @@ class CodegenContext {
}
def declareMutableStates(): String = {
- mutableStates.map { case (javaType, variableName, _) =>
+ // It's possible to add the same mutable state twice, e.g. via `mergeExpressions` in
+ // `TypedAggregateExpression`, so we call `distinct` here to remove the duplicates.
+ mutableStates.distinct.map { case (javaType, variableName, _) =>
s"private $javaType $variableName;"
}.mkString("\n")
}
def initMutableStates(): String = {
- mutableStates.map(_._3).mkString("\n")
+ // It's possible to add the same mutable state twice, e.g. via `mergeExpressions` in
+ // `TypedAggregateExpression`, so we call `distinct` here to remove the duplicates.
+ mutableStates.distinct.map(_._3).mkString("\n")
}
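
Side note: a minimal sketch of the failure mode the two `distinct` calls above guard against, under the assumption (suggested by the test below) that registering the same mutable state twice would otherwise declare the same field twice in the generated class and fail verification. MiniCodegenContext is an illustrative stand-in, not Spark's CodegenContext API:

import scala.collection.mutable

class MiniCodegenContext {
  // (javaType, variableName, initCode), mirroring the triples used above
  private val mutableStates = mutable.ArrayBuffer.empty[(String, String, String)]

  def addMutableState(javaType: String, name: String, init: String): Unit =
    mutableStates += ((javaType, name, init))

  // Without .distinct, registering the same state twice would emit
  // "private boolean agg_initialized;" twice and the generated class
  // would fail with a duplicate-field error.
  def declareMutableStates(): String =
    mutableStates.distinct.map { case (t, n, _) => s"private $t $n;" }.mkString("\n")

  def initMutableStates(): String =
    mutableStates.distinct.map(_._3).mkString("\n")
}

// Usage: the same state registered twice is declared and initialized only once.
// val ctx = new MiniCodegenContext
// ctx.addMutableState("boolean", "agg_initialized", "agg_initialized = false;")
// ctx.addMutableState("boolean", "agg_initialized", "agg_initialized = false;")
// println(ctx.declareMutableStates())  // prints a single declaration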
/**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 0d84a594f7..6eae3ed7ad 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import scala.language.postfixOps
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.expressions.scala.typed
import org.apache.spark.sql.functions._
@@ -72,6 +73,16 @@ object NameAgg extends Aggregator[AggData, String, String] {
}
+object SeqAgg extends Aggregator[AggData, Seq[Int], Seq[Int]] {
+ def zero: Seq[Int] = Nil
+ def reduce(b: Seq[Int], a: AggData): Seq[Int] = a.a +: b
+ def merge(b1: Seq[Int], b2: Seq[Int]): Seq[Int] = b1 ++ b2
+ def finish(r: Seq[Int]): Seq[Int] = r
+ override def bufferEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+ override def outputEncoder: Encoder[Seq[Int]] = ExpressionEncoder()
+}
+
+
class ParameterizedTypeSum[IN, OUT : Numeric : Encoder](f: IN => OUT)
extends Aggregator[IN, OUT, OUT] {
@@ -212,4 +223,13 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
val df = Seq(1 -> "a", 2 -> "b", 3 -> "b").toDF("i", "j")
checkAnswer(df.groupBy($"j").agg(RowAgg.toColumn), Row("a", 1) :: Row("b", 5) :: Nil)
}
+
+ test("SPARK-14675: ClassFormatError when use Seq as Aggregator buffer type") {
+ val ds = Seq(AggData(1, "a"), AggData(2, "a")).toDS()
+
+ checkDataset(
+ ds.groupByKey(_.b).agg(SeqAgg.toColumn),
+ "a" -> Seq(1, 2)
+ )
+ }
}
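
For context, a hedged sketch of how an Aggregator with a Seq buffer is used from user code once this fix is in place; it assumes a SparkSession named `spark` and reuses the AggData and SeqAgg definitions from the test file above, and is illustrative rather than part of the patch:

import spark.implicits._

val ds = Seq(AggData(1, "a"), AggData(2, "a"), AggData(3, "b")).toDS()

// Before this patch, a Seq-backed aggregation buffer could make the generated
// aggregate code fail with a ClassFormatError; with the fix it returns one
// collected Seq[Int] per key.
val result = ds.groupByKey(_.b).agg(SeqAgg.toColumn).collect()
// e.g. Array(("a", Seq(1, 2)), ("b", Seq(3)))  -- element order within each Seq may vary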