author     Wenchen Fan <wenchen@databricks.com>    2016-04-07 17:23:34 -0700
committer  Yin Huai <yhuai@databricks.com>         2016-04-07 17:23:34 -0700
commit     49fb237081bbca0d811aa48aa06f4728fea62781 (patch)
tree       674f92f37feb32ef62c2905169d9c435ad337500 /sql
parent     ae1db91d158d1ae62a0ab7ea74467679ca050101 (diff)
[SPARK-14270][SQL] whole stage codegen support for typed filter
## What changes were proposed in this pull request?

We currently implement typed filter with `MapPartitions`, which doesn't work well with whole stage codegen. This PR uses `Filter` to implement typed filter instead, so we get whole stage codegen support for free.

This PR also introduces `DeserializeToObject` and `SerializeFromObject` to separate serialization logic from object operators, so that it's easier to write optimization rules for adjacent object operators.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12061 from cloud-fan/whole-stage-codegen.
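
For illustration, here is a minimal sketch of what the new plan shape looks like from the Dataset API, mirroring the usage exercised in the new WholeStageCodegenSuite test below; the object name `TypedFilterPlanDemo` is illustrative and not part of this patch:

import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

object TypedFilterPlanDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext("local[*]", "typed-filter-demo")
    val sqlContext = new SQLContext(sc)

    // With this patch, a typed filter is planned as
    //   SerializeFromObject(Filter(<invoke func on obj>, DeserializeToObject(...)))
    // instead of MapPartitions, so it can participate in whole stage codegen.
    val ds = sqlContext.range(10).filter(_ % 2 == 0)

    // The analyzed plan shows the serde pair around Filter; the executed plan
    // should show a WholeStageCodegen node covering the Filter.
    println(ds.queryExecution.analyzed.numberedTreeString)
    println(ds.queryExecution.executedPlan.treeString)
    println(ds.collect().mkString(", "))  // 0, 2, 4, 6, 8

    sc.stop()
  }
}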
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                            |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala                                  | 19
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                          | 39
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala                         | 37
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala       | 74
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                                   | 16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                                 |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala                                         | 67
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala                                          | 76
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                                                 |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala                          | 21
11 files changed, 342 insertions, 15 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 7bcba421fd..3555a6d7fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1670,6 +1670,8 @@ object CleanupAliases extends Rule[LogicalPlan] {
// Operators that operate on objects should only have expressions from encoders, which should
// never have extra aliases.
case o: ObjectOperator => o
+ case d: DeserializeToObject => d
+ case s: SerializeFromObject => s
case other =>
var stop = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 105947028d..1e7296664b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp}
import scala.language.implicitConversions
+import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -166,6 +167,14 @@ package object dsl {
case target => UnresolvedStar(Option(target))
}
+ def callFunction[T, U](
+ func: T => U,
+ returnType: DataType,
+ argument: Expression): Expression = {
+ val function = Literal.create(func, ObjectType(classOf[T => U]))
+ Invoke(function, "apply", returnType, argument :: Nil)
+ }
+
implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name }
// TODO more implicit class for literal?
implicit class DslString(val s: String) extends ImplicitOperators {
@@ -270,6 +279,16 @@ package object dsl {
def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
+ def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
+ val deserialized = logicalPlan.deserialize[T]
+ val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
+ Filter(condition, deserialized).serialize[T]
+ }
+
+ def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
+
+ def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan)
+
def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan)
def join(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index f581810c26..619514e8aa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -93,6 +93,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] {
EliminateSerialization) ::
Batch("Decimal Optimizations", FixedPoint(100),
DecimalAggregates) ::
+ Batch("Typed Filter Optimization", FixedPoint(100),
+ EmbedSerializerInFilter) ::
Batch("LocalRelation", FixedPoint(100),
ConvertToLocalRelation) ::
Batch("Subquery", Once,
@@ -147,12 +149,18 @@ object EliminateSerialization extends Rule[LogicalPlan] {
child = childWithoutSerialization)
case m @ MapElements(_, deserializer, _, child: ObjectOperator)
- if !deserializer.isInstanceOf[Attribute] &&
- deserializer.dataType == child.outputObject.dataType =>
+ if !deserializer.isInstanceOf[Attribute] &&
+ deserializer.dataType == child.outputObject.dataType =>
val childWithoutSerialization = child.withObjectOutput
m.copy(
deserializer = childWithoutSerialization.output.head,
child = childWithoutSerialization)
+
+ case d @ DeserializeToObject(_, s: SerializeFromObject)
+ if d.outputObjectType == s.inputObjectType =>
+ // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+ val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+ Project(objAttr :: Nil, s.child)
}
}
@@ -1329,3 +1337,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] {
}
}
}
+
+/**
+ * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
+ * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
+ * the deserializer in the filter condition to avoid the extra serialization at the end.
+ */
+object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) =>
+ val numObjects = condition.collect {
+ case a: Attribute if a == d.output.head => a
+ }.length
+
+ if (numObjects > 1) {
+ // If the filter condition references the object more than once, we should not embed the
+ // deserializer in it as the deserialization will happen many times and slow down the
+ // execution.
+ // TODO: we can still embed it if we can make sure subexpression elimination works here.
+ s
+ } else {
+ val newCondition = condition transform {
+ case a: Attribute if a == d.output.head => d.deserializer.child
+ }
+ Filter(newCondition, d.child)
+ }
+ }
+}
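
As a rough sketch of the rewrite this rule performs, the following uses the catalyst test DSL in the same way as the new TypedFilterOptimizationSuite added below; the object name `EmbedSerializerSketch` is illustrative only:

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.optimizer.EmbedSerializerInFilter
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

object EmbedSerializerSketch {
  implicit def productEncoder[T <: Product : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder[T]()

  val input = LocalRelation('_1.int, '_2.int)
  val f = (i: (Int, Int)) => i._1 > 0

  // Before: SerializeFromObject(Filter(<cond>, DeserializeToObject(input))), as built by the DSL.
  val query = input.filter(f).analyze

  // After: the serde pair is gone and the deserializer is embedded in the Filter condition,
  // i.e. roughly input.where(callFunction(f, BooleanType, <deserializer>)).
  val optimized = EmbedSerializerInFilter(query)
}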
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index ec33a538a9..6df46189b6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -21,7 +21,42 @@ import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ObjectType, StructType}
+import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+
+object CatalystSerde {
+ def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
+ val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
+ DeserializeToObject(Alias(deserializer, "obj")(), child)
+ }
+
+ def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
+ SerializeFromObject(encoderFor[T].namedExpressions, child)
+ }
+}
+
+/**
+ * Takes the input row from child and turns it into an object using the given deserializer expression.
+ * The output of this operator is a single-field safe row containing the deserialized object.
+ */
+case class DeserializeToObject(
+ deserializer: Alias,
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+ def outputObjectType: DataType = deserializer.dataType
+}
+
+/**
+ * Takes the input object from child and turns it into an unsafe row using the given serializer
+ * expression. The output of its child must be a single-field row containing the input object.
+ */
+case class SerializeFromObject(
+ serializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ def inputObjectType: DataType = child.output.head.dataType
+}
/**
* A trait for logical operators that apply user defined functions to domain objects.
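
To make the new serde nodes concrete, here is a minimal construction sketch built on the catalyst DSL and an implicit ExpressionEncoder, following the pattern of the test suites in this patch; the object name `SerdeNodesSketch` and its values are illustrative:

import scala.reflect.runtime.universe.TypeTag

import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, LocalRelation}

object SerdeNodesSketch {
  implicit def productEncoder[T <: Product : TypeTag]: ExpressionEncoder[T] = ExpressionEncoder[T]()

  val input = LocalRelation('_1.int, '_2.int)

  // DeserializeToObject: after analysis its output is a single `obj` attribute of ObjectType,
  // holding the deserialized (Int, Int) object.
  val deserialized = CatalystSerde.deserialize[(Int, Int)](input)

  // SerializeFromObject: flattens the object back into the encoder's named expressions
  // (`_1`, `_2`), so the round trip has the same schema as `input`.
  val roundTrip = CatalystSerde.serialize[(Int, Int)](deserialized)
}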
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
new file mode 100644
index 0000000000..1fae64e3bc
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
@@ -0,0 +1,74 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.BooleanType
+
+class TypedFilterOptimizationSuite extends PlanTest {
+ object Optimize extends RuleExecutor[LogicalPlan] {
+ val batches =
+ Batch("EliminateSerialization", FixedPoint(50),
+ EliminateSerialization) ::
+ Batch("EmbedSerializerInFilter", FixedPoint(50),
+ EmbedSerializerInFilter) :: Nil
+ }
+
+ implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
+
+ test("back to back filter") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f1 = (i: (Int, Int)) => i._1 > 0
+ val f2 = (i: (Int, Int)) => i._2 > 0
+
+ val query = input.filter(f1).filter(f2).analyze
+
+ val optimized = Optimize.execute(query)
+
+ val expected = input.deserialize[(Int, Int)]
+ .where(callFunction(f1, BooleanType, 'obj))
+ .select('obj.as("obj"))
+ .where(callFunction(f2, BooleanType, 'obj))
+ .serialize[(Int, Int)].analyze
+
+ comparePlans(optimized, expected)
+ }
+
+ test("embed deserializer in filter condition if there is only one filter") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f = (i: (Int, Int)) => i._1 > 0
+
+ val query = input.filter(f).analyze
+
+ val optimized = Optimize.execute(query)
+
+ val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
+ val condition = callFunction(f, BooleanType, deserializer)
+ val expected = input.where(condition).analyze
+
+ comparePlans(optimized, expected)
+ }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2854d5f9da..2f6d8d109f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1879,7 +1879,13 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+ def filter(func: T => Boolean): Dataset[T] = {
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
+ val condition = Invoke(function, "apply", BooleanType, deserialized.output)
+ val filter = Filter(condition, deserialized)
+ withTypedPlan(CatalystSerde.serialize[T](filter))
+ }
/**
* :: Experimental ::
@@ -1890,7 +1896,13 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
+ def filter(func: FilterFunction[T]): Dataset[T] = {
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
+ val condition = Invoke(function, "call", BooleanType, deserialized.output)
+ val filter = Filter(condition, deserialized)
+ withTypedPlan(CatalystSerde.serialize[T](filter))
+ }
/**
* :: Experimental ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index eee2b946e3..c15aaed365 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -346,6 +346,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
throw new IllegalStateException(
"logical intersect operator should have been replaced by semi-join in the optimizer")
+ case logical.DeserializeToObject(deserializer, child) =>
+ execution.DeserializeToObject(deserializer, planLater(child)) :: Nil
+ case logical.SerializeFromObject(serializer, child) =>
+ execution.SerializeFromObject(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, in, out, child) =>
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
case logical.MapElements(f, in, out, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index f48f3f09c7..d2ab18ef0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -28,6 +28,73 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.ObjectType
/**
+ * Takes the input row from child and turns it into an object using the given deserializer expression.
+ * The output of this operator is a single-field safe row containing the deserialized object.
+ */
+case class DeserializeToObject(
+ deserializer: Alias,
+ child: SparkPlan) extends UnaryNode with CodegenSupport {
+ override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val bound = ExpressionCanonicalizer.execute(
+ BindReferences.bindReference(deserializer, child.output))
+ ctx.currentVars = input
+ val resultVars = bound.gen(ctx) :: Nil
+ consume(ctx, resultVars)
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
+ iter.map(projection)
+ }
+ }
+}
+
+/**
+ * Takes the input object from child and turns it into an unsafe row using the given serializer
+ * expression. The output of its child must be a single-field row containing the input object.
+ */
+case class SerializeFromObject(
+ serializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryNode with CodegenSupport {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val bound = serializer.map { expr =>
+ ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
+ }
+ ctx.currentVars = input
+ val resultVars = bound.map(_.gen(ctx))
+ consume(ctx, resultVars)
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val projection = UnsafeProjection.create(serializer)
+ iter.map(projection)
+ }
+ }
+}
+
+/**
* Helper functions for physical operators that work with user defined objects.
*/
trait ObjectOperator extends SparkPlan {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index 6eb952445f..5f3dd906fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -28,16 +28,10 @@ object DatasetBenchmark {
case class Data(l: Long, s: String)
- def main(args: Array[String]): Unit = {
- val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
- val sqlContext = new SQLContext(sparkContext)
-
+ def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
import sqlContext.implicits._
- val numRows = 10000000
val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
- val numChains = 10
-
val benchmark = new Benchmark("back-to-back map", numRows)
val func = (d: Data) => Data(d.l + 1, d.s)
@@ -61,7 +55,7 @@ object DatasetBenchmark {
res.queryExecution.toRdd.foreach(_ => Unit)
}
- val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
benchmark.addCase("RDD") { iter =>
var res = rdd
var i = 0
@@ -72,6 +66,63 @@ object DatasetBenchmark {
res.foreach(_ => Unit)
}
+ benchmark
+ }
+
+ def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
+ import sqlContext.implicits._
+
+ val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val benchmark = new Benchmark("back-to-back filter", numRows)
+
+ val func = (d: Data, i: Int) => d.l % (100L + i) == 0L
+ val funcs = 0.until(numChains).map { i =>
+ (d: Data) => func(d, i)
+ }
+ benchmark.addCase("Dataset") { iter =>
+ var res = df.as[Data]
+ var i = 0
+ while (i < numChains) {
+ res = res.filter(funcs(i))
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("DataFrame") { iter =>
+ var res = df
+ var i = 0
+ while (i < numChains) {
+ res = res.filter($"l" % (100L + i) === 0L)
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
+ var i = 0
+ while (i < numChains) {
+ res = res.filter(funcs(i))
+ i += 1
+ }
+ res.foreach(_ => Unit)
+ }
+
+ benchmark
+ }
+
+ def main(args: Array[String]): Unit = {
+ val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
+ val sqlContext = new SQLContext(sparkContext)
+
+ val numRows = 10000000
+ val numChains = 10
+
+ val benchmark = backToBackMap(sqlContext, numRows, numChains)
+ val benchmark2 = backToBackFilter(sqlContext, numRows, numChains)
+
/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
@@ -82,5 +133,14 @@ object DatasetBenchmark {
RDD 216 / 237 46.3 21.6 4.2X
*/
benchmark.run()
+
+ /*
+ back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Dataset 585 / 628 17.1 58.5 1.0X
+ DataFrame 62 / 80 160.7 6.2 9.4X
+ RDD 205 / 220 48.7 20.5 2.8X
+ */
+ benchmark2.run()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 48a077d0e5..826862835a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.streaming.MemoryPlan
+import org.apache.spark.sql.types.ObjectType
abstract class QueryTest extends PlanTest {
@@ -204,6 +205,7 @@ abstract class QueryTest extends PlanTest {
case _: MemoryPlan => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
+ case Literal(_, _: ObjectType) => return
}
// bypass hive tests before we fix all corner cases in hive module.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index f73ca887f1..4474cfcf6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.{Encoders, Row}
+import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
@@ -82,4 +81,22 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
assert(ds.collect() === 0.until(10).map(_.toString).toArray)
}
+
+ test("typed filter should be included in WholeStageCodegen") {
+ val ds = sqlContext.range(10).filter(_ % 2 == 0)
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Filter]).isDefined)
+ assert(ds.collect() === Array(0, 2, 4, 6, 8))
+ }
+
+ test("back-to-back typed filter should be included in WholeStageCodegen") {
+ val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined)
+ assert(ds.collect() === Array(0, 6))
+ }
}