From 49fb237081bbca0d811aa48aa06f4728fea62781 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 7 Apr 2016 17:23:34 -0700 Subject: [SPARK-14270][SQL] whole stage codegen support for typed filter ## What changes were proposed in this pull request? We currently implement typed filter with `MapPartitions`, which doesn't work well with whole stage codegen. This PR uses `Filter` to implement typed filter, so we get whole stage codegen support for free. This PR also introduces `DeserializeToObject` and `SerializeFromObject` to separate serialization logic from object operators, so that it's easier to write optimization rules for adjacent object operators. (Two short illustrative sketches are appended after the patch.) ## How was this patch tested? Existing tests. Author: Wenchen Fan Closes #12061 from cloud-fan/whole-stage-codegen. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 + .../apache/spark/sql/catalyst/dsl/package.scala | 19 ++++++ .../spark/sql/catalyst/optimizer/Optimizer.scala | 39 ++++++++++- .../spark/sql/catalyst/plans/logical/object.scala | 37 ++++++++++- .../optimizer/TypedFilterOptimizationSuite.scala | 74 +++++++++++++++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 16 ++++- .../spark/sql/execution/SparkStrategies.scala | 4 ++ .../org/apache/spark/sql/execution/objects.scala | 67 +++++++++++++++++++ .../org/apache/spark/sql/DatasetBenchmark.scala | 76 +++++++++++++++++++--- .../scala/org/apache/spark/sql/QueryTest.scala | 2 + .../sql/execution/WholeStageCodegenSuite.scala | 21 +++++- 11 files changed, 342 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7bcba421fd..3555a6d7fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1670,6 +1670,8 @@ object CleanupAliases extends Rule[LogicalPlan] { // Operators that operate on objects should only have expressions from encoders, which should // never have extra aliases. case o: ObjectOperator => o + case d: DeserializeToObject => d + case s: SerializeFromObject => s + case other => var stop = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 105947028d..1e7296664b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -166,6 +167,14 @@ package object dsl { case target => UnresolvedStar(Option(target)) } + def callFunction[T, U]( + func: T => U, + returnType: DataType, + argument: Expression): Expression = { + val function = Literal.create(func, ObjectType(classOf[T => U])) + Invoke(function, "apply", returnType, argument :: Nil) + } + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators { @@ -270,6 +279,16 @@ package object dsl { def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) + def filter[T : Encoder](func: T => Boolean): LogicalPlan = { + val deserialized = logicalPlan.deserialize[T] + val condition = expressions.callFunction(func, BooleanType, deserialized.output.head) + Filter(condition, deserialized).serialize[T] + } + + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) + + def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f581810c26..619514e8aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -93,6 +93,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { EliminateSerialization) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: + Batch("Typed Filter Optimization", FixedPoint(100), + EmbedSerializerInFilter) :: Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation) :: Batch("Subquery", Once, @@ -147,12 +149,18 @@ object EliminateSerialization extends Rule[LogicalPlan] { child = childWithoutSerialization) case m @ MapElements(_, deserializer, _, child: ObjectOperator) - if !deserializer.isInstanceOf[Attribute] && - deserializer.dataType == child.outputObject.dataType => + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => val childWithoutSerialization = child.withObjectOutput m.copy( deserializer = childWithoutSerialization.output.head, child = childWithoutSerialization) + + case d @ DeserializeToObject(_, s: SerializeFromObject) + if d.outputObjectType == s.inputObjectType => + // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. + val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) + Project(objAttr :: Nil, s.child) } } @@ -1329,3 +1337,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } } } + +/** + * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a + * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed + * the deserializer in filter condition to save the extra serialization at last. + */ +object EmbedSerializerInFilter extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) => + val numObjects = condition.collect { + case a: Attribute if a == d.output.head => a + }.length + + if (numObjects > 1) { + // If the filter condition references the object more than one times, we should not embed + // deserializer in it as the deserialization will happen many times and slow down the + // execution. + // TODO: we can still embed it if we can make sure subexpression elimination works here. 
+ s + } else { + val newCondition = condition transform { + case a: Attribute if a == d.output.head => d.deserializer.child + } + Filter(newCondition, d.child) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index ec33a538a9..6df46189b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -21,7 +21,42 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{ObjectType, StructType} +import org.apache.spark.sql.types.{DataType, ObjectType, StructType} + +object CatalystSerde { + def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) + DeserializeToObject(Alias(deserializer, "obj")(), child) + } + + def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { + SerializeFromObject(encoderFor[T].namedExpressions, child) + } +} + +/** + * Takes the input row from child and turns it into object using the given deserializer expression. + * The output of this operator is a single-field safe row containing the deserialized object. + */ +case class DeserializeToObject( + deserializer: Alias, + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = deserializer.toAttribute :: Nil + + def outputObjectType: DataType = deserializer.dataType +} + +/** + * Takes the input object from child and turns in into unsafe row using the given serializer + * expression. The output of its child must be a single-field row containing the input object. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + def inputObjectType: DataType = child.output.head.dataType +} /** * A trait for logical operators that apply user defined functions to domain objects. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala new file mode 100644 index 0000000000..1fae64e3bc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.BooleanType + +class TypedFilterOptimizationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("EliminateSerialization", FixedPoint(50), + EliminateSerialization) :: + Batch("EmbedSerializerInFilter", FixedPoint(50), + EmbedSerializerInFilter) :: Nil + } + + implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + + test("back to back filter") { + val input = LocalRelation('_1.int, '_2.int) + val f1 = (i: (Int, Int)) => i._1 > 0 + val f2 = (i: (Int, Int)) => i._2 > 0 + + val query = input.filter(f1).filter(f2).analyze + + val optimized = Optimize.execute(query) + + val expected = input.deserialize[(Int, Int)] + .where(callFunction(f1, BooleanType, 'obj)) + .select('obj.as("obj")) + .where(callFunction(f2, BooleanType, 'obj)) + .serialize[(Int, Int)].analyze + + comparePlans(optimized, expected) + } + + test("embed deserializer in filter condition if there is only one filter") { + val input = LocalRelation('_1.int, '_2.int) + val f = (i: (Int, Int)) => i._1 > 0 + + val query = input.filter(f).analyze + + val optimized = Optimize.execute(query) + + val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer) + val condition = callFunction(f, BooleanType, deserializer) + val expected = input.where(condition).analyze + + comparePlans(optimized, expected) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2854d5f9da..2f6d8d109f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1879,7 +1879,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) + def filter(func: T => Boolean): Dataset[T] = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val function = Literal.create(func, ObjectType(classOf[T => Boolean])) + val condition = Invoke(function, "apply", BooleanType, deserialized.output) + val filter = Filter(condition, deserialized) + withTypedPlan(CatalystSerde.serialize[T](filter)) + } /** * :: Experimental :: @@ -1890,7 +1896,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) + def filter(func: FilterFunction[T]): Dataset[T] = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]])) + val condition = Invoke(function, "call", BooleanType, deserialized.output) + val filter = Filter(condition, deserialized) + withTypedPlan(CatalystSerde.serialize[T](filter)) + } /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala 
index eee2b946e3..c15aaed365 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -346,6 +346,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new IllegalStateException( "logical intersect operator should have been replaced by semi-join in the optimizer") + case logical.DeserializeToObject(deserializer, child) => + execution.DeserializeToObject(deserializer, planLater(child)) :: Nil + case logical.SerializeFromObject(serializer, child) => + execution.SerializeFromObject(serializer, planLater(child)) :: Nil case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil case logical.MapElements(f, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index f48f3f09c7..d2ab18ef0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -27,6 +27,73 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.ObjectType +/** + * Takes the input row from child and turns it into object using the given deserializer expression. + * The output of this operator is a single-field safe row containing the deserialized object. + */ +case class DeserializeToObject( + deserializer: Alias, + child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = deserializer.toAttribute :: Nil + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val bound = ExpressionCanonicalizer.execute( + BindReferences.bindReference(deserializer, child.output)) + ctx.currentVars = input + val resultVars = bound.gen(ctx) :: Nil + consume(ctx, resultVars) + } + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output) + iter.map(projection) + } + } +} + +/** + * Takes the input object from child and turns in into unsafe row using the given serializer + * expression. The output of its child must be a single-field row containing the input object. 
+ */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val bound = serializer.map { expr => + ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output)) + } + ctx.currentVars = input + val resultVars = bound.map(_.gen(ctx)) + consume(ctx, resultVars) + } + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val projection = UnsafeProjection.create(serializer) + iter.map(projection) + } + } +} + /** * Helper functions for physical operators that work with user defined objects. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 6eb952445f..5f3dd906fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -28,16 +28,10 @@ object DatasetBenchmark { case class Data(l: Long, s: String) - def main(args: Array[String]): Unit = { - val sparkContext = new SparkContext("local[*]", "Dataset benchmark") - val sqlContext = new SQLContext(sparkContext) - + def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { import sqlContext.implicits._ - val numRows = 10000000 val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) - val numChains = 10 - val benchmark = new Benchmark("back-to-back map", numRows) val func = (d: Data) => Data(d.l + 1, d.s) @@ -61,7 +55,7 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -72,6 +66,63 @@ object DatasetBenchmark { res.foreach(_ => Unit) } + benchmark + } + + def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { + import sqlContext.implicits._ + + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val benchmark = new Benchmark("back-to-back filter", numRows) + + val func = (d: Data, i: Int) => d.l % (100L + i) == 0L + val funcs = 0.until(numChains).map { i => + (d: Data) => func(d, i) + } + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] + var i = 0 + while (i < numChains) { + res = res.filter(funcs(i)) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.filter($"l" % (100L + i) === 0L) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = rdd.filter(funcs(i)) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark + } + + def main(args: Array[String]): Unit = { + val sparkContext = new 
SparkContext("local[*]", "Dataset benchmark") + val sqlContext = new SQLContext(sparkContext) + + val numRows = 10000000 + val numChains = 10 + + val benchmark = backToBackMap(sqlContext, numRows, numChains) + val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) + /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz @@ -82,5 +133,14 @@ object DatasetBenchmark { RDD 216 / 237 46.3 21.6 4.2X */ benchmark.run() + + /* + back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Dataset 585 / 628 17.1 58.5 1.0X + DataFrame 62 / 80 160.7 6.2 9.4X + RDD 205 / 220 48.7 20.5 2.8X + */ + benchmark2.run() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 48a077d0e5..826862835a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan +import org.apache.spark.sql.types.ObjectType abstract class QueryTest extends PlanTest { @@ -204,6 +205,7 @@ abstract class QueryTest extends PlanTest { case _: MemoryPlan => return }.transformAllExpressions { case a: ImperativeAggregate => return + case Literal(_, _: ObjectType) => return } // bypass hive tests before we fix all corner cases in hive module. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index f73ca887f1..4474cfcf6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.api.java.function.MapFunction -import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions.{avg, broadcast, col, max} @@ -82,4 +81,22 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) assert(ds.collect() === 0.until(10).map(_.toString).toArray) } + + test("typed filter should be included in WholeStageCodegen") { + val ds = sqlContext.range(10).filter(_ % 2 == 0) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Filter]).isDefined) + assert(ds.collect() === Array(0, 2, 4, 6, 8)) + } + + test("back-to-back typed filter should be included in WholeStageCodegen") { + val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined) + assert(ds.collect() === Array(0, 6)) + } } -- cgit v1.2.3
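Sketch 1: a minimal usage example in Scala, mirroring the cases added to `WholeStageCodegenSuite` and `DatasetBenchmark` above. The local-mode setup and application name are illustrative assumptions, not part of the patch. With this change a typed filter is expressed as a `Filter` sandwiched between `DeserializeToObject` and `SerializeFromObject`, so after optimization it runs inside a single `WholeStageCodegen` subtree instead of falling back to `MapPartitions`.

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

// Local-mode setup, mirroring DatasetBenchmark above (names are illustrative).
val sparkContext = new SparkContext("local[*]", "typed-filter-sketch")
val sqlContext = new SQLContext(sparkContext)

// Back-to-back typed filters, the same shape as the new WholeStageCodegenSuite test.
val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)

// The physical plan should now show a single WholeStageCodegen subtree containing the typed
// filters (with a SerializeFromObject on top in the chained case) and no MapPartitions operator.
ds.explain()

println(ds.collect().mkString(", "))  // expected: 0, 6
```

Running `ds.explain()` on a build with and without this patch is the quickest way to confirm that typed filters are no longer planned as `MapPartitions`.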
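Sketch 2: the guard in `EmbedSerializerInFilter` (skip the rewrite when the filter condition references the deserialized object more than once) is easier to see in isolation. The following self-contained Scala sketch uses toy expression types, not Catalyst's classes; it only models the substitution and the reference-count check, under the assumption that re-evaluating the deserializer once per reference is the cost the rule is trying to avoid.

```scala
// Toy expression tree standing in for Catalyst expressions. ObjAttr plays the role of the single
// attribute produced by DeserializeToObject; Deserializer stands for the (potentially expensive)
// deserializer expression that the rule may inline into the filter condition.
sealed trait Expr
case object ObjAttr extends Expr
case object Deserializer extends Expr
case class Call(name: String, args: List[Expr]) extends Expr
case class And(left: Expr, right: Expr) extends Expr

// Count how many times the condition references the deserialized object.
def countObjRefs(e: Expr): Int = e match {
  case ObjAttr       => 1
  case Call(_, args) => args.map(countObjRefs).sum
  case And(l, r)     => countObjRefs(l) + countObjRefs(r)
  case _             => 0
}

// Replace every reference to the object attribute with the deserializer expression itself.
def substitute(e: Expr): Expr = e match {
  case ObjAttr       => Deserializer
  case Call(n, args) => Call(n, args.map(substitute))
  case And(l, r)     => And(substitute(l), substitute(r))
  case other         => other
}

// Mirrors the rule's guard: embed the deserializer only when the object is referenced at most
// once; otherwise deserialization would run once per reference, so leave the plan alone (None).
def embedSerializerInFilter(condition: Expr): Option[Expr] =
  if (countObjRefs(condition) > 1) None else Some(substitute(condition))

// One reference: the deserializer is embedded and the surrounding serde pair can be dropped.
assert(embedSerializerInFilter(Call("isValid", List(ObjAttr))) ==
  Some(Call("isValid", List(Deserializer))))
// Two references: no rewrite, matching the TODO about subexpression elimination in the patch.
assert(embedSerializerInFilter(And(Call("f", List(ObjAttr)), Call("g", List(ObjAttr)))).isEmpty)
```

In the real rule the substituted expression is `d.deserializer.child` (the expression under the `obj` alias), and the TODO in the patch notes that the guard could be relaxed once subexpression elimination is guaranteed to apply.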