aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSean Zhong <seanzhong@databricks.com>2016-06-06 22:40:21 -0700
committerCheng Lian <lian@databricks.com>2016-06-06 22:40:21 -0700
commit0e0904a2fce3c4447c24f1752307b6d01ffbd0ad (patch)
tree75861951e2866cfb4788cce9b6ac0c4e3c7e4dbd
parentc409e23abd128dad33557025f1e824ef47e6222f (diff)
downloadspark-0e0904a2fce3c4447c24f1752307b6d01ffbd0ad.tar.gz
spark-0e0904a2fce3c4447c24f1752307b6d01ffbd0ad.tar.bz2
spark-0e0904a2fce3c4447c24f1752307b6d01ffbd0ad.zip
[SPARK-15632][SQL] Typed Filter should NOT change the Dataset schema
## What changes were proposed in this pull request? This PR makes sure the typed Filter doesn't change the Dataset schema. **Before the change:** ``` scala> val df = spark.range(0,9) scala> df.schema res12: org.apache.spark.sql.types.StructType = StructType(StructField(id,LongType,false)) scala> val afterFilter = df.filter(_=>true) scala> afterFilter.schema // !!! schema is CHANGED!!! Column name is changed from id to value, nullable is changed from false to true. res13: org.apache.spark.sql.types.StructType = StructType(StructField(value,LongType,true)) ``` SerializeFromObject and DeserializeToObject are inserted to wrap the Filter, and these two can possibly change the schema of Dataset. **After the change:** ``` scala> afterFilter.schema // schema is NOT changed. res47: org.apache.spark.sql.types.StructType = StructType(StructField(id,LongType,false)) ``` ## How was this patch tested? Unit test. Author: Sean Zhong <seanzhong@databricks.com> Closes #13529 from clockfly/spark-15632.
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala16
-rw-r--r--sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java13
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala6
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala2
5 files changed, 31 insertions, 10 deletions
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
index 289c16aef4..63d87bfb6d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
@@ -57,7 +57,9 @@ class TypedFilterOptimizationSuite extends PlanTest {
comparePlans(optimized, expected)
}
- test("embed deserializer in filter condition if there is only one filter") {
+ // TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules
+ // for typed filters.
+ ignore("embed deserializer in typed filter condition if there is only one filter") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: (Int, Int)) => i._1 > 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 96c871d034..6cbc27d91c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1944,11 +1944,11 @@ class Dataset[T] private[sql](
*/
@Experimental
def filter(func: T => Boolean): Dataset[T] = {
- val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
- val condition = Invoke(function, "apply", BooleanType, deserialized.output)
- val filter = Filter(condition, deserialized)
- withTypedPlan(CatalystSerde.serialize[T](filter))
+ val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil)
+ val filter = Filter(condition, logicalPlan)
+ withTypedPlan(filter)
}
/**
@@ -1961,11 +1961,11 @@ class Dataset[T] private[sql](
*/
@Experimental
def filter(func: FilterFunction[T]): Dataset[T] = {
- val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
- val condition = Invoke(function, "call", BooleanType, deserialized.output)
- val filter = Filter(condition, deserialized)
- withTypedPlan(CatalystSerde.serialize[T](filter))
+ val condition = Invoke(function, "call", BooleanType, deserializer :: Nil)
+ val filter = Filter(condition, logicalPlan)
+ withTypedPlan(filter)
}
/**
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 8354a5bdac..37577accfd 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -92,6 +92,19 @@ public class JavaDatasetSuite implements Serializable {
Assert.assertFalse(iter.hasNext());
}
+ // SPARK-15632: typed filter should preserve the underlying logical schema
+ @Test
+ public void testTypedFilterPreservingSchema() {
+ Dataset<Long> ds = spark.range(10);
+ Dataset<Long> ds2 = ds.filter(new FilterFunction<Long>() {
+ @Override
+ public boolean call(Long value) throws Exception {
+ return value > 3;
+ }
+ });
+ Assert.assertEquals(ds.schema(), ds2.schema());
+ }
+
@Test
public void testCommonOperation() {
List<String> data = Arrays.asList("hello", "world");
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index bf2b0a2c7c..11b52bdead 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -225,6 +225,12 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
"b")
}
+ test("SPARK-15632: typed filter should preserve the underlying logical schema") {
+ val ds = spark.range(10)
+ val ds2 = ds.filter(_ > 3)
+ assert(ds.schema.equals(ds2.schema))
+ }
+
test("foreach") {
val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
val acc = sparkContext.longAccumulator
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 68f0ee864f..f26e5e7b69 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -97,7 +97,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val plan = ds.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegenExec] &&
- p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[SerializeFromObjectExec]).isDefined)
+ p.asInstanceOf[WholeStageCodegenExec].child.isInstanceOf[FilterExec]).isDefined)
assert(ds.collect() === Array(0, 6))
}