diff options
author | Sean Zhong <seanzhong@databricks.com> | 2016-06-06 22:40:21 -0700 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-06-06 22:40:21 -0700 |
commit | 0e0904a2fce3c4447c24f1752307b6d01ffbd0ad (patch) | |
tree | 75861951e2866cfb4788cce9b6ac0c4e3c7e4dbd | |
parent | c409e23abd128dad33557025f1e824ef47e6222f (diff) | |
download | spark-0e0904a2fce3c4447c24f1752307b6d01ffbd0ad.tar.gz spark-0e0904a2fce3c4447c24f1752307b6d01ffbd0ad.tar.bz2 spark-0e0904a2fce3c4447c24f1752307b6d01ffbd0ad.zip |
[SPARK-15632][SQL] Typed Filter should NOT change the Dataset schema
## What changes were proposed in this pull request?
This PR makes sure the typed Filter doesn't change the Dataset schema.
**Before the change:**
```
scala> val df = spark.range(0,9)
scala> df.schema
res12: org.apache.spark.sql.types.StructType = StructType(StructField(id,LongType,false))
scala> val afterFilter = df.filter(_=>true)
scala> afterFilter.schema // !!! schema is CHANGED!!! Column name is changed from id to value, nullable is changed from false to true.
res13: org.apache.spark.sql.types.StructType = StructType(StructField(value,LongType,true))
```
SerializeFromObject and DeserializeToObject are inserted to wrap the Filter, and these two can possibly change the schema of Dataset.
**After the change:**
```
scala> afterFilter.schema // schema is NOT changed.
res47: org.apache.spark.sql.types.StructType = StructType(StructField(id,LongType,false))
```
## How was this patch tested?
Unit test.
Author: Sean Zhong <seanzhong@databricks.com>
Closes #13529 from clockfly/spark-15632.
5 files changed, 31 insertions, 10 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala index 289c16aef4..63d87bfb6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -57,7 +57,9 @@ class TypedFilterOptimizationSuite extends PlanTest { comparePlans(optimized, expected) } - test("embed deserializer in filter condition if there is only one filter") { + // TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules + // for typed filters. + ignore("embed deserializer in typed filter condition if there is only one filter") { val input = LocalRelation('_1.int, '_2.int) val f = (i: (Int, Int)) => i._1 > 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 96c871d034..6cbc27d91c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1944,11 +1944,11 @@ class Dataset[T] private[sql]( */ @Experimental def filter(func: T => Boolean): Dataset[T] = { - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) val function = Literal.create(func, ObjectType(classOf[T => Boolean])) - val condition = Invoke(function, "apply", BooleanType, deserialized.output) - val filter = Filter(condition, deserialized) - withTypedPlan(CatalystSerde.serialize[T](filter)) + val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil) + val filter = Filter(condition, logicalPlan) + withTypedPlan(filter) } /** @@ -1961,11 +1961,11 @@ class Dataset[T] private[sql]( */ @Experimental def filter(func: FilterFunction[T]): Dataset[T] = { - val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]])) - val condition = Invoke(function, "call", BooleanType, deserialized.output) - val filter = Filter(condition, deserialized) - withTypedPlan(CatalystSerde.serialize[T](filter)) + val condition = Invoke(function, "call", BooleanType, deserializer :: Nil) + val filter = Filter(condition, logicalPlan) + withTypedPlan(filter) } /** diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 8354a5bdac..37577accfd 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -92,6 +92,19 @@ public class JavaDatasetSuite implements Serializable { Assert.assertFalse(iter.hasNext()); } + // SPARK-15632: typed filter should preserve the underlying logical schema + @Test + public void testTypedFilterPreservingSchema() { + Dataset<Long> ds = spark.range(10); + Dataset<Long> ds2 = ds.filter(new FilterFunction<Long>() { + @Override + public boolean call(Long value) throws Exception { + return value > 3; + } + }); + Assert.assertEquals(ds.schema(), ds2.schema()); + } + @Test public void testCommonOperation() { List<String> data = Arrays.asList("hello", "world"); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index bf2b0a2c7c..11b52bdead 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -225,6 +225,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext { "b") } + test("SPARK-15632: typed filter should preserve the underlying logical schema") { + val ds = spark.range(10) + val ds2 = ds.filter(_ > 3) + assert(ds.schema.equals(ds2.schema)) + } + test("foreach") { val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS() val acc = sparkContext.longAccumulator diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 68f0ee864f..f26e5e7b69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -97,7 +97,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { val plan = ds.queryExecution.executedPlan assert(plan.find(p => p.isInstanceOf[WholeStageCodegenExec] && - p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined) + p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined) assert(ds.collect() === Array(0, 6)) } |