author    Takuya UESHIN <ueshin@happy-camper.st>  2016-05-19 22:55:44 -0700
committer Reynold Xin <rxin@databricks.com>       2016-05-19 22:55:44 -0700
commit    d5e1c5acde95158db38448526c8afad4a6d21dc2
tree      e49c0106823ca961e29571e1f7f168c4dd964da0
parent    e384c7fbb94cef3c18e8fa8d06159b76b88b5167
[SPARK-15313][SQL] EmbedSerializerInFilter rule should keep exprIds of output of surrounded SerializeFromObject.
## What changes were proposed in this pull request?
The following code:
```
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
ds.filter(_._1 == "b").select(expr("_1").as[String]).foreach(println(_))
```
throws an exception:
```
org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: _1#420
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:88)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:87)
...
Cause: java.lang.RuntimeException: Couldn't find _1#420 in [_1#416,_2#417]
at scala.sys.package$.error(package.scala:27)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:94)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:88)
at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:49)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:88)
at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:87)
...
```
This is because the `EmbedSerializerInFilter` rule drops the `exprId`s of the output of the surrounded `SerializeFromObject`.
The analyzed and optimized plans of the above example are as follows:
```
== Analyzed Logical Plan ==
_1: string
Project [_1#420]
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, scala.Tuple2]._1, true) AS _1#420,input[0, scala.Tuple2]._2 AS _2#421]
+- Filter <function1>.apply
+- DeserializeToObject newInstance(class scala.Tuple2), obj#419: scala.Tuple2
+- LocalRelation [_1#416,_2#417], [[0,1800000001,1,61],[0,1800000001,2,62],[0,1800000001,3,63]]
== Optimized Logical Plan ==
!Project [_1#420]
+- Filter <function1>.apply
+- LocalRelation [_1#416,_2#417], [[0,1800000001,1,61],[0,1800000001,2,62],[0,1800000001,3,63]]
```
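Note the `!` marking `Project [_1#420]` as invalid: Catalyst binds attribute references by `exprId`, not by name, so `_1#420` can no longer be resolved against the child output `[_1#416,_2#417]`. The following is a minimal standalone sketch of that binding step, using hypothetical stand-in types rather than Spark's actual Catalyst classes:
```
// A minimal sketch of exprId-based attribute binding (hypothetical `Attr`
// type for illustration; not Spark's AttributeReference).
object BindingFailureSketch extends App {
  case class Attr(name: String, exprId: Long) {
    override def toString: String = s"$name#$exprId"
  }

  // Simplified BindReferences: locate the target in the child's output by
  // exprId and return its ordinal; names are never consulted.
  def bindReference(target: Attr, childOutput: Seq[Attr]): Int = {
    val ordinal = childOutput.indexWhere(_.exprId == target.exprId)
    if (ordinal == -1) {
      sys.error(s"Couldn't find $target in ${childOutput.mkString("[", ",", "]")}")
    }
    ordinal
  }

  // The optimized Project still references _1#420 ...
  val projectAttr = Attr("_1", 420L)
  // ... but with SerializeFromObject gone, the child now exposes the
  // LocalRelation's original attributes.
  val childOutput = Seq(Attr("_1", 416L), Attr("_2", 417L))

  bindReference(projectAttr, childOutput) // error: Couldn't find _1#420 in [_1#416,_2#417]
}
```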
This PR fixes the `EmbedSerializerInFilter` rule so that it keeps the `exprId`s of the output of the surrounded `SerializeFromObject`.
The plans after this patch are as follows:
```
== Analyzed Logical Plan ==
_1: string
Project [_1#420]
+- SerializeFromObject [staticinvoke(class org.apache.spark.unsafe.types.UTF8String, StringType, fromString, input[0, scala.Tuple2]._1, true) AS _1#420,input[0, scala.Tuple2]._2 AS _2#421]
+- Filter <function1>.apply
+- DeserializeToObject newInstance(class scala.Tuple2), obj#419: scala.Tuple2
+- LocalRelation [_1#416,_2#417], [[0,1800000001,1,61],[0,1800000001,2,62],[0,1800000001,3,63]]
== Optimized Logical Plan ==
Project [_1#416]
+- Filter <function1>.apply
+- LocalRelation [_1#416,_2#417], [[0,1800000001,1,61],[0,1800000001,2,62],[0,1800000001,3,63]]
```
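In the same simplified model, the fix's mechanics look like this (hypothetical `Attr`/`Alias` stand-ins again, mirroring the `objAttrs` construction in the Optimizer.scala hunk below): the rewritten `Filter` is wrapped in an alias-only `Project` that re-exposes its output under the `exprId`s the removed `SerializeFromObject` used to produce, and `RemoveAliasOnlyProject` later collapses that `Project` by rewriting downstream references, which is why the final plan references `_1#416`.
```
// Sketch of the fix: alias the filter's output under the serializer's old
// exprIds (hypothetical stand-in types; not Spark's Catalyst classes).
object AliasFixSketch extends App {
  case class Attr(name: String, exprId: Long)
  case class Alias(child: Attr, name: String, exprId: Long) {
    override def toString: String = s"${child.name}#${child.exprId} AS $name#$exprId"
  }

  // Output of the rewritten Filter, carrying the LocalRelation's exprIds ...
  val filterOutput = Seq(Attr("_1", 416L), Attr("_2", 417L))
  // ... and the exprIds the surrounding plan still expects, i.e. those of the
  // SerializeFromObject that was optimized away.
  val serializerOutput = Seq(Attr("_1", 420L), Attr("_2", 421L))

  // Mirrors the patch: zip the two outputs and alias each filter attribute
  // under the serializer's original exprId.
  val projectList = filterOutput.zip(serializerOutput).map { case (fout, sout) =>
    Alias(fout, fout.name, sout.exprId)
  }

  // Prints "_1#416 AS _1#420, _2#417 AS _2#421": upstream references to
  // _1#420 bind again, and RemoveAliasOnlyProject later eliminates this
  // Project, rewriting _1#420 to _1#416 as in the optimized plan above.
  println(projectList.mkString(", "))
}
```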
## How was this patch tested?
Existing tests, plus a new test added to `DatasetSuite` to check that `filter and then select` works.
Author: Takuya UESHIN <ueshin@happy-camper.st>
Closes #13096 from ueshin/issues/SPARK-15313.
3 files changed, 18 insertions, 3 deletions
```
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 6825b65e2b..a6fb34cbfb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -109,7 +109,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) ::
     Batch("Typed Filter Optimization", fixedPoint,
-      EmbedSerializerInFilter) ::
+      EmbedSerializerInFilter,
+      RemoveAliasOnlyProject) ::
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation) ::
     Batch("OptimizeCodegen", Once,
@@ -1611,7 +1612,14 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] {
       val newCondition = condition transform {
         case a: Attribute if a == d.output.head => d.deserializer
       }
-      Filter(newCondition, d.child)
+      val filter = Filter(newCondition, d.child)
+
+      // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`.
+      // We will remove it later in RemoveAliasOnlyProject rule.
+      val objAttrs = filter.output.zip(s.output).map { case (fout, sout) =>
+        Alias(fout, fout.name)(exprId = sout.exprId)
+      }
+      Project(objAttrs, filter)
     }
   }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
index 1fae64e3bc..289c16aef4 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
@@ -67,7 +67,7 @@ class TypedFilterOptimizationSuite extends PlanTest {
     val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
     val condition = callFunction(f, BooleanType, deserializer)
 
-    val expected = input.where(condition).analyze
+    val expected = input.where(condition).select('_1.as("_1"), '_2.as("_2")).analyze
 
     comparePlans(optimized, expected)
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 52e706285c..0ffbd6db12 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -205,6 +205,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       ("b", 2))
   }
 
+  test("filter and then select") {
+    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
+    checkDataset(
+      ds.filter(_._1 == "b").select(expr("_1").as[String]),
+      ("b"))
+  }
+
   test("foreach") {
     val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
     val acc = sparkContext.longAccumulator
```