-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala  16
2 files changed, 35 insertions(+), 2 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 48d70099b6..688c77d3ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1597,7 +1597,19 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
*/
object EmbedSerializerInFilter extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) =>
+ case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject))
+ // SPARK-15632: Conceptually, a filter operator should never introduce schema changes. This
+ // optimization rule also relies on that assumption. However, the Dataset typed filter operator
+ // does introduce schema changes in some cases. Thus, we only enable this optimization when
+ //
+ // 1. the input and output schemata are exactly the same, or
+ // 2. both input and output schemata consist of a single field of the same type.
+ //
+ // The 2nd case is included because encoders for primitive types always produce a single
+ // field with the hard-coded field name "value".
+ // TODO: Clean this up after fixing SPARK-15632.
+ if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) =>
+
val numObjects = condition.collect {
case a: Attribute if a == d.output.head => a
}.length
@@ -1622,6 +1634,13 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] {
Project(objAttrs, filter)
}
}
+
+ def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = {
+ (lhs, rhs) match {
+ case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType
+ case _ => false
+ }
+ }
}
/**
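
The guard above only embeds the serializer into the filter when doing so cannot change the plan schema: either the schema produced by SerializeFromObject matches the schema beneath DeserializeToObject exactly, or both consist of a single field of the same type. Below is a minimal standalone sketch of that comparison, assuming only spark-catalyst on the classpath; the SchemaGuardSketch object and its example schemata are illustrative and not part of the patch.

import org.apache.spark.sql.types._

object SchemaGuardSketch {
  // Same logic as the samePrimitiveType helper added above: true only when
  // both schemata have exactly one field and the two field types match.
  def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean =
    (lhs, rhs) match {
      case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType
      case _ => false
    }

  def main(args: Array[String]): Unit = {
    val primitiveInt = StructType(Array(StructField("value", IntegerType)))
    val renamedInt   = StructType(Array(StructField("id", IntegerType)))
    val twoFields    = StructType(Array(
      StructField("a", StringType),
      StructField("b", IntegerType)))

    // Field names are ignored: single-field schemata of the same type match
    // even though the names "value" and "id" differ.
    println(samePrimitiveType(primitiveInt, renamedInt)) // true
    // Schemata with different field counts never match.
    println(samePrimitiveType(primitiveInt, twoFields))  // false
  }
}

Ignoring field names in the single-field case is deliberate: encoders for primitive types always name their only field "value", so a logically identical single-column schema may differ from it in name alone.
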
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index e395007999..8fc4dc9f17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -706,7 +706,7 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
val dataset = Seq(1, 2, 3).toDS()
dataset.createOrReplaceTempView("tempView")
- // Overrrides the existing temporary view with same name
+ // Overrides the existing temporary view with same name
// No exception should be thrown here.
dataset.createOrReplaceTempView("tempView")
@@ -769,6 +769,20 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
checkShowString(ds, expected)
}
+
+ test(
+ "SPARK-15112: EmbedDeserializerInFilter should not optimize plan fragment that changes schema"
+ ) {
+ val ds = Seq(1 -> "foo", 2 -> "bar").toDF("b", "a").as[ClassData]
+
+ assertResult(Seq(ClassData("foo", 1), ClassData("bar", 2))) {
+ ds.collect().toSeq
+ }
+
+ assertResult(Seq(ClassData("bar", 2))) {
+ ds.filter(_.b > 1).collect().toSeq
+ }
+ }
}
case class Generic[T](id: T, value: Double)
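
The new test pins down the scenario the optimizer guard protects against: the DataFrame columns are created in the order ("b", "a"), which differs from the declared field order of ClassData(a, b), so the typed filter's deserialize/serialize round trip reorders fields by name. A standalone reproduction of the same scenario might look as follows; this is a sketch assuming a Spark 2.x SparkSession, and the TypedFilterSchemaRepro object, app name, and local master are illustrative.

import org.apache.spark.sql.SparkSession

case class ClassData(a: String, b: Int)

object TypedFilterSchemaRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      .appName("SPARK-15632 repro")
      .getOrCreate()
    import spark.implicits._

    // Column order ("b", "a") deliberately differs from the field order of
    // ClassData(a, b); .as[ClassData] resolves the columns by name.
    val ds = Seq(1 -> "foo", 2 -> "bar").toDF("b", "a").as[ClassData]

    // Without the guard in Optimizer.scala, EmbedSerializerInFilter would
    // remove the deserialize/serialize pair around this typed filter and
    // leave the underlying (b, a) schema in place, so downstream operators
    // would read the columns in the wrong order. With the guard, this
    // prints the single row ClassData("bar", 2).
    ds.filter(_.b > 1).show()

    spark.stop()
  }
}
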