author    Wenchen Fan <wenchen@databricks.com>    2016-04-07 17:23:34 -0700
committer Yin Huai <yhuai@databricks.com>         2016-04-07 17:23:34 -0700
commit    49fb237081bbca0d811aa48aa06f4728fea62781 (patch)
tree      674f92f37feb32ef62c2905169d9c435ad337500 /sql/core
parent    ae1db91d158d1ae62a0ab7ea74467679ca050101 (diff)
[SPARK-14270][SQL] whole stage codegen support for typed filter
## What changes were proposed in this pull request?

We currently implement typed filter with `MapPartitions`, which doesn't work well with whole stage codegen. This PR uses `Filter` to implement typed filter, so we get whole stage codegen support for free. It also introduces `DeserializeToObject` and `SerializeFromObject` to separate serialization logic from the object operators, so that it's easier to write optimization rules for adjacent object operators.

## How was this patch tested?

Existing tests.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12061 from cloud-fan/whole-stage-codegen.
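As a rough illustration (not part of the patch itself), a typed filter like the one exercised in the new `WholeStageCodegenSuite` test below is now planned as `DeserializeToObject` -> `Filter` -> `SerializeFromObject` and can be wrapped in `WholeStageCodegen`, instead of falling back to `MapPartitions`:

```scala
// Sketch only: assumes an existing SQLContext named `sqlContext`, as in the tests below.
val ds = sqlContext.range(10).filter(_ % 2 == 0)  // typed filter taking a Scala closure

// The executed plan is expected to contain a WholeStageCodegen node wrapping the typed
// Filter; see the assertions added to WholeStageCodegenSuite in this patch.
ds.explain()
ds.collect()  // Array(0, 2, 4, 6, 8)
```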
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                           16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala          4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala                  67
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala                   76
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                           2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala   21
6 files changed, 174 insertions, 12 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 2854d5f9da..2f6d8d109f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1879,7 +1879,13 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func))
+ def filter(func: T => Boolean): Dataset[T] = {
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
+ val condition = Invoke(function, "apply", BooleanType, deserialized.output)
+ val filter = Filter(condition, deserialized)
+ withTypedPlan(CatalystSerde.serialize[T](filter))
+ }
/**
* :: Experimental ::
@@ -1890,7 +1896,13 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
@Experimental
- def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t))
+ def filter(func: FilterFunction[T]): Dataset[T] = {
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
+ val condition = Invoke(function, "call", BooleanType, deserialized.output)
+ val filter = Filter(condition, deserialized)
+ withTypedPlan(CatalystSerde.serialize[T](filter))
+ }
/**
* :: Experimental ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index eee2b946e3..c15aaed365 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -346,6 +346,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
throw new IllegalStateException(
"logical intersect operator should have been replaced by semi-join in the optimizer")
+ case logical.DeserializeToObject(deserializer, child) =>
+ execution.DeserializeToObject(deserializer, planLater(child)) :: Nil
+ case logical.SerializeFromObject(serializer, child) =>
+ execution.SerializeFromObject(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, in, out, child) =>
execution.MapPartitions(f, in, out, planLater(child)) :: Nil
case logical.MapElements(f, in, out, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index f48f3f09c7..d2ab18ef0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -28,6 +28,73 @@ import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.ObjectType
/**
+ * Takes the input row from child and turns it into an object using the given deserializer expression.
+ * The output of this operator is a single-field safe row containing the deserialized object.
+ */
+case class DeserializeToObject(
+ deserializer: Alias,
+ child: SparkPlan) extends UnaryNode with CodegenSupport {
+ override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val bound = ExpressionCanonicalizer.execute(
+ BindReferences.bindReference(deserializer, child.output))
+ ctx.currentVars = input
+ val resultVars = bound.gen(ctx) :: Nil
+ consume(ctx, resultVars)
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output)
+ iter.map(projection)
+ }
+ }
+}
+
+/**
+ * Takes the input object from child and turns it into an unsafe row using the given serializer
+ * expression. The output of its child must be a single-field row containing the input object.
+ */
+case class SerializeFromObject(
+ serializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryNode with CodegenSupport {
+ override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val bound = serializer.map { expr =>
+ ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output))
+ }
+ ctx.currentVars = input
+ val resultVars = bound.map(_.gen(ctx))
+ consume(ctx, resultVars)
+ }
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val projection = UnsafeProjection.create(serializer)
+ iter.map(projection)
+ }
+ }
+}
+
+/**
* Helper functions for physical operators that work with user defined objects.
*/
trait ObjectOperator extends SparkPlan {
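The commit message argues that separating serialization into `DeserializeToObject` and `SerializeFromObject` makes it easier to write optimization rules for adjacent object operators. As a hypothetical sketch (not part of this patch), such a rule over the new logical nodes could pattern-match a serialize/deserialize round trip directly; `RemoveSerializationRoundTrip` is an invented name, and a complete rule would also need to check that the object types and output attributes line up:

```scala
import org.apache.spark.sql.catalyst.plans.logical.{DeserializeToObject, LogicalPlan, SerializeFromObject}
import org.apache.spark.sql.catalyst.rules.Rule

// Hypothetical optimizer rule, sketched against the logical nodes added in this patch:
// a DeserializeToObject sitting directly on top of a SerializeFromObject is a
// serialize/deserialize round trip and can be collapsed to the object-producing child.
object RemoveSerializationRoundTrip extends Rule[LogicalPlan] {
  override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case DeserializeToObject(_, SerializeFromObject(_, child)) => child
  }
}
```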
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
index 6eb952445f..5f3dd906fe 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala
@@ -28,16 +28,10 @@ object DatasetBenchmark {
case class Data(l: Long, s: String)
- def main(args: Array[String]): Unit = {
- val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
- val sqlContext = new SQLContext(sparkContext)
-
+ def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
import sqlContext.implicits._
- val numRows = 10000000
val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
- val numChains = 10
-
val benchmark = new Benchmark("back-to-back map", numRows)
val func = (d: Data) => Data(d.l + 1, d.s)
@@ -61,7 +55,7 @@ object DatasetBenchmark {
res.queryExecution.toRdd.foreach(_ => Unit)
}
- val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
benchmark.addCase("RDD") { iter =>
var res = rdd
var i = 0
@@ -72,6 +66,63 @@ object DatasetBenchmark {
res.foreach(_ => Unit)
}
+ benchmark
+ }
+
+ def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = {
+ import sqlContext.implicits._
+
+ val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s"))
+ val benchmark = new Benchmark("back-to-back filter", numRows)
+
+ val func = (d: Data, i: Int) => d.l % (100L + i) == 0L
+ val funcs = 0.until(numChains).map { i =>
+ (d: Data) => func(d, i)
+ }
+ benchmark.addCase("Dataset") { iter =>
+ var res = df.as[Data]
+ var i = 0
+ while (i < numChains) {
+ res = res.filter(funcs(i))
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ benchmark.addCase("DataFrame") { iter =>
+ var res = df
+ var i = 0
+ while (i < numChains) {
+ res = res.filter($"l" % (100L + i) === 0L)
+ i += 1
+ }
+ res.queryExecution.toRdd.foreach(_ => Unit)
+ }
+
+ val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString))
+ benchmark.addCase("RDD") { iter =>
+ var res = rdd
+ var i = 0
+ while (i < numChains) {
+ res = res.filter(funcs(i))
+ i += 1
+ }
+ res.foreach(_ => Unit)
+ }
+
+ benchmark
+ }
+
+ def main(args: Array[String]): Unit = {
+ val sparkContext = new SparkContext("local[*]", "Dataset benchmark")
+ val sqlContext = new SQLContext(sparkContext)
+
+ val numRows = 10000000
+ val numChains = 10
+
+ val benchmark = backToBackMap(sqlContext, numRows, numChains)
+ val benchmark2 = backToBackFilter(sqlContext, numRows, numChains)
+
/*
Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4
Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz
@@ -82,5 +133,14 @@ object DatasetBenchmark {
RDD 216 / 237 46.3 21.6 4.2X
*/
benchmark.run()
+
+ /*
+ back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ Dataset 585 / 628 17.1 58.5 1.0X
+ DataFrame 62 / 80 160.7 6.2 9.4X
+ RDD 205 / 220 48.7 20.5 2.8X
+ */
+ benchmark2.run()
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 48a077d0e5..826862835a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.LogicalRDD
import org.apache.spark.sql.execution.columnar.InMemoryRelation
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.execution.streaming.MemoryPlan
+import org.apache.spark.sql.types.ObjectType
abstract class QueryTest extends PlanTest {
@@ -204,6 +205,7 @@ abstract class QueryTest extends PlanTest {
case _: MemoryPlan => return
}.transformAllExpressions {
case a: ImperativeAggregate => return
+ case Literal(_, _: ObjectType) => return
}
// bypass hive tests before we fix all corner cases in hive module.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index f73ca887f1..4474cfcf6e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -17,8 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.api.java.function.MapFunction
-import org.apache.spark.sql.{Encoders, Row}
+import org.apache.spark.sql.Row
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions.{avg, broadcast, col, max}
@@ -82,4 +81,22 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
assert(ds.collect() === 0.until(10).map(_.toString).toArray)
}
+
+ test("typed filter should be included in WholeStageCodegen") {
+ val ds = sqlContext.range(10).filter(_ % 2 == 0)
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Filter]).isDefined)
+ assert(ds.collect() === Array(0, 2, 4, 6, 8))
+ }
+
+ test("back-to-back typed filter should be included in WholeStageCodegen") {
+ val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0)
+ val plan = ds.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined)
+ assert(ds.collect() === Array(0, 6))
+ }
}