author | Wenchen Fan <wenchen@databricks.com> | 2016-06-30 08:15:08 +0800
---|---|---
committer | Cheng Lian <lian@databricks.com> | 2016-06-30 08:15:08 +0800
commit | d063898bebaaf4ec2aad24c3ac70aabdbf97a190 (patch) |
tree | e10e1ed9765961338730237002e1b78c3f1f184b /sql |
parent | 2eaabfa4142d4050be2b45fd277ff5c7fa430581 (diff) |
download | spark-d063898bebaaf4ec2aad24c3ac70aabdbf97a190.tar.gz spark-d063898bebaaf4ec2aad24c3ac70aabdbf97a190.tar.bz2 spark-d063898bebaaf4ec2aad24c3ac70aabdbf97a190.zip |
[SPARK-16134][SQL] optimizer rules for typed filter
## What changes were proposed in this pull request?
This PR adds 3 optimizer rules for typed filter:
1. push a typed filter down through `SerializeFromObject` and eliminate the deserialization in the filter condition.
2. pull a typed filter up through `DeserializeToObject` and eliminate the deserialization in the filter condition.
3. combine adjacent typed filters and share the deserialized object among all the condition expressions (see the sketch below).
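At the Dataset API level, the plan shapes these rules target look roughly like this (a minimal sketch meant for spark-shell; the `Person` case class and local session are made up for illustration):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical data type and session, for illustration only.
case class Person(name: String, age: Int)

val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
import spark.implicits._
val ds = Seq(Person("a", 30), Person("b", 15)).toDS()

// Rule 1: the typed filter sits above the SerializeFromObject produced by map(),
// so it is turned into a normal Filter and pushed below the serialization.
ds.map(p => p).filter(_.age > 21)

// Rule 2: the typed filter sits below the DeserializeToObject introduced by map(),
// so it is pulled above the deserialization and reuses the object.
ds.filter(_.age > 21).map(p => p)

// Rule 3: adjacent typed filters on the same object type collapse into a single
// TypedFilter whose function is the conjunction of both predicates.
ds.filter(_.age > 21).filter(_.name.nonEmpty)
```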
This PR also adds a `TypedFilter` logical plan to keep typed filters separate from normal filters, which makes the concept clearer and the optimizer rules easier to write.
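During planning, a `TypedFilter` becomes an ordinary `FilterExec` whose condition invokes the user function on the deserialized object (see the `SparkStrategies` change below). A minimal sketch of that lowering for the Scala-function case, assuming spark-catalyst from this patch is on the classpath:

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, Literal}
import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types.{BooleanType, ObjectType}

// Mirrors TypedFilter.typedCondition for a Scala T => Boolean predicate;
// the Java FilterFunction case invokes "call" instead of "apply".
def typedConditionSketch(func: AnyRef, deserializer: Expression): Expression = {
  val funcObj = Literal.create(func, ObjectType(classOf[Any => Boolean]))
  Invoke(funcObj, "apply", BooleanType, deserializer :: Nil)
}
```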
## How was this patch tested?
`TypedFilterOptimizationSuite`
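Beyond the suite, the combined-filter behavior can be spot-checked from a local shell; this usage sketch is not part of the patch:

```scala
import org.apache.spark.sql.SparkSession

// Back-to-back typed filters should now be planned as a single filter;
// explain(true) prints the optimized plan so this can be verified by eye.
val spark = SparkSession.builder().master("local[1]").appName("check").getOrCreate()
import spark.implicits._

val filtered = Seq(1, -2, 3, -4).toDS().filter(_ > 0).filter(_ % 2 == 1)
filtered.explain(true)
assert(filtered.collect().sorted.sameElements(Array(1, 3)))
spark.stop()
```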
Author: Wenchen Fan <wenchen@databricks.com>
Closes #13846 from cloud-fan/filter.
Diffstat (limited to 'sql')
8 files changed, 162 insertions, 91 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 2ca990d19a..84c9cc8c8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -293,11 +293,7 @@ package object dsl {
 
     def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
 
-    def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
-      val deserialized = logicalPlan.deserialize[T]
-      val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
-      Filter(condition, deserialized).serialize[T]
-    }
+    def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan)
 
     def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 502d791c6e..127797c097 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -45,6 +45,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
     var maxOrdinal = -1
     result foreach {
       case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
+      case _ =>
     }
     if (maxOrdinal > children.length) {
       return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9bc8cea377..842d6bc26f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
 import scala.collection.immutable.HashSet
 import scala.collection.mutable.ArrayBuffer
 
+import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
 import org.apache.spark.sql.catalyst.analysis._
 import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
@@ -110,8 +111,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
     Batch("Decimal Optimizations", fixedPoint,
       DecimalAggregates) ::
     Batch("Typed Filter Optimization", fixedPoint,
-      EmbedSerializerInFilter,
-      RemoveAliasOnlyProject) ::
+      CombineTypedFilters) ::
     Batch("LocalRelation", fixedPoint,
       ConvertToLocalRelation) ::
     Batch("OptimizeCodegen", Once,
@@ -206,15 +206,33 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
 object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
-        if d.outputObjectType == s.inputObjectType =>
+        if d.outputObjAttr.dataType == s.inputObjAttr.dataType =>
       // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
       // We will remove it later in RemoveAliasOnlyProject rule.
-      val objAttr =
-        Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId)
+      val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
       Project(objAttr :: Nil, s.child)
+
     case a @ AppendColumns(_, _, _, s: SerializeFromObject)
-        if a.deserializer.dataType == s.inputObjectType =>
+        if a.deserializer.dataType == s.inputObjAttr.dataType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
+
+    // If there is a `SerializeFromObject` under the typed filter and its input object type is the
+    // same as the typed filter's deserializer, we can convert the typed filter to a normal filter
+    // without deserialization in the condition, and push it down through `SerializeFromObject`.
+    // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
+    // but `ds.map(...).as[AnotherType].filter(...)` can not be optimized.
+    case f @ TypedFilter(_, _, s: SerializeFromObject)
+        if f.deserializer.dataType == s.inputObjAttr.dataType =>
+      s.copy(child = f.withObjectProducerChild(s.child))
+
+    // If there is a `DeserializeToObject` upon the typed filter and its output object type is the
+    // same as the typed filter's deserializer, we can convert the typed filter to a normal filter
+    // without deserialization in the condition, and pull it up through `DeserializeToObject`.
+    // e.g. `ds.filter(...).map(...)` can be optimized by this rule to save extra deserialization,
+    // but `ds.filter(...).as[AnotherType].map(...)` can not be optimized.
+    case d @ DeserializeToObject(_, _, f: TypedFilter)
+        if d.outputObjAttr.dataType == f.deserializer.dataType =>
+      f.withObjectProducerChild(d.copy(child = f.child))
   }
 }
@@ -1645,54 +1663,30 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
 }
 
 /**
- * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
- * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
- * the deserializer in filter condition to save the extra serialization at last.
+ * Combines two adjacent [[TypedFilter]]s, which operate on the same type of object in the
+ * condition, into one, merging the filter functions into one conjunctive function.
  */
-object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+object CombineTypedFilters extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject))
-      // SPARK-15632: Conceptually, filter operator should never introduce schema change. This
-      // optimization rule also relies on this assumption. However, Dataset typed filter operator
-      // does introduce schema changes in some cases. Thus, we only enable this optimization when
-      //
-      // 1. either input and output schemata are exactly the same, or
-      // 2. both input and output schemata are single-field schema and share the same type.
-      //
-      // The 2nd case is included because encoders for primitive types always have only a single
-      // field with hard-coded field name "value".
-      // TODO Cleans this up after fixing SPARK-15632.
-      if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) =>
-
-      val numObjects = condition.collect {
-        case a: Attribute if a == d.output.head => a
-      }.length
-
-      if (numObjects > 1) {
-        // If the filter condition references the object more than one times, we should not embed
-        // deserializer in it as the deserialization will happen many times and slow down the
-        // execution.
-        // TODO: we can still embed it if we can make sure subexpression elimination works here.
-        s
-      } else {
-        val newCondition = condition transform {
-          case a: Attribute if a == d.output.head => d.deserializer
-        }
-        val filter = Filter(newCondition, d.child)
-
-        // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`.
-        // We will remove it later in RemoveAliasOnlyProject rule.
-        val objAttrs = filter.output.zip(s.output).map { case (fout, sout) =>
-          Alias(fout, fout.name)(exprId = sout.exprId)
-        }
-        Project(objAttrs, filter)
-      }
-  }
-
-  def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = {
-    (lhs, rhs) match {
-      case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType
-      case _ => false
+    case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+        if t1.deserializer.dataType == t2.deserializer.dataType =>
+      TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+  }
+
+  private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
+    (func1, func2) match {
+      case (f1: FilterFunction[_], f2: FilterFunction[_]) =>
+        input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
+          f2.asInstanceOf[FilterFunction[Any]].call(input)
+      case (f1: FilterFunction[_], f2) =>
+        input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
+          f2.asInstanceOf[Any => Boolean](input)
+      case (f1, f2: FilterFunction[_]) =>
+        input => f1.asInstanceOf[Any => Boolean].apply(input) &&
+          f2.asInstanceOf[FilterFunction[Any]].call(input)
+      case (f1, f2) =>
+        input => f1.asInstanceOf[Any => Boolean].apply(input) &&
+          f2.asInstanceOf[Any => Boolean].apply(input)
     }
   }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 7beeeb4f04..e1890edcbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -17,11 +17,15 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
+import scala.language.existentials
+
+import org.apache.spark.api.java.function.FilterFunction
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.{Encoder, Row}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
 import org.apache.spark.sql.types._
 
 object CatalystSerde {
@@ -45,13 +49,11 @@ object CatalystSerde {
  */
 trait ObjectProducer extends LogicalPlan {
   // The attribute that reference to the single object field this operator outputs.
-  protected def outputObjAttr: Attribute
+  def outputObjAttr: Attribute
 
   override def output: Seq[Attribute] = outputObjAttr :: Nil
 
   override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
-
-  def outputObjectType: DataType = outputObjAttr.dataType
 }
 
 /**
@@ -64,7 +66,7 @@ trait ObjectConsumer extends UnaryNode {
   // This operator always need all columns of its child, even it doesn't reference to.
   override def references: AttributeSet = child.outputSet
 
-  def inputObjectType: DataType = child.output.head.dataType
+  def inputObjAttr: Attribute = child.output.head
 }
 
 /**
@@ -167,6 +169,43 @@ case class MapElements(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends ObjectConsumer with ObjectProducer
 
+object TypedFilter {
+  def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
+    TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+  }
+}
+
+/**
+ * A relation produced by applying `func` to each element of the `child` and filtering them by the
+ * resulting boolean value.
+ *
+ * This is logically equal to a normal [[Filter]] operator whose condition expression decodes the
+ * input row to an object and applies the given function to it. However we need the encapsulation
+ * of [[TypedFilter]] to make the concept clearer and make it easier to write optimizer rules.
+ */
+case class TypedFilter(
+    func: AnyRef,
+    deserializer: Expression,
+    child: LogicalPlan) extends UnaryNode {
+
+  override def output: Seq[Attribute] = child.output
+
+  def withObjectProducerChild(obj: LogicalPlan): Filter = {
+    assert(obj.output.length == 1)
+    Filter(typedCondition(obj.output.head), obj)
+  }
+
+  def typedCondition(input: Expression): Expression = {
+    val (funcClass, methodName) = func match {
+      case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
+      case _ => classOf[Any => Boolean] -> "apply"
+    }
+    val funcObj = Literal.create(func, ObjectType(funcClass))
+    Invoke(funcObj, methodName, BooleanType, input :: Nil)
+  }
+}
+
 /** Factory for constructing new `AppendColumn` nodes. */
 object AppendColumns {
   def apply[T : Encoder, U : Encoder](
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
index 63d87bfb6d..56f096f3ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer
 
 import scala.reflect.runtime.universe.TypeTag
 
-import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, TypedFilter}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
 import org.apache.spark.sql.types.BooleanType
 
@@ -33,44 +32,91 @@ class TypedFilterOptimizationSuite extends PlanTest {
     val batches =
       Batch("EliminateSerialization", FixedPoint(50),
        EliminateSerialization) ::
-      Batch("EmbedSerializerInFilter", FixedPoint(50),
-        EmbedSerializerInFilter) :: Nil
+      Batch("CombineTypedFilters", FixedPoint(50),
+        CombineTypedFilters) :: Nil
   }
 
   implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
 
-  test("back to back filter") {
+  test("filter after serialize with the same object type") {
     val input = LocalRelation('_1.int, '_2.int)
-    val f1 = (i: (Int, Int)) => i._1 > 0
-    val f2 = (i: (Int, Int)) => i._2 > 0
+    val f = (i: (Int, Int)) => i._1 > 0
 
-    val query = input.filter(f1).filter(f2).analyze
+    val query = input
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)]
+      .filter(f).analyze
     val optimized = Optimize.execute(query)
 
-    val expected = input.deserialize[(Int, Int)]
-      .where(callFunction(f1, BooleanType, 'obj))
-      .select('obj.as("obj"))
-      .where(callFunction(f2, BooleanType, 'obj))
+    val expected = input
+      .deserialize[(Int, Int)]
+      .where(callFunction(f, BooleanType, 'obj))
       .serialize[(Int, Int)].analyze
 
     comparePlans(optimized, expected)
   }
 
-  // TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules
-  // for typed filters.
-  ignore("embed deserializer in typed filter condition if there is only one filter") {
+  test("filter after serialize with different object types") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f = (i: OtherTuple) => i._1 > 0
+
+    val query = input
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)]
+      .filter(f).analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query)
+  }
+
+  test("filter before deserialize with the same object type") {
     val input = LocalRelation('_1.int, '_2.int)
     val f = (i: (Int, Int)) => i._1 > 0
 
-    val query = input.filter(f).analyze
+    val query = input
+      .filter(f)
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)].analyze
     val optimized = Optimize.execute(query)
 
-    val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
-    val condition = callFunction(f, BooleanType, deserializer)
-    val expected = input.where(condition).select('_1.as("_1"), '_2.as("_2")).analyze
+    val expected = input
+      .deserialize[(Int, Int)]
+      .where(callFunction(f, BooleanType, 'obj))
+      .serialize[(Int, Int)].analyze
 
     comparePlans(optimized, expected)
   }
+
+  test("filter before deserialize with different object types") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f = (i: OtherTuple) => i._1 > 0
+
+    val query = input
+      .filter(f)
+      .deserialize[(Int, Int)]
+      .serialize[(Int, Int)].analyze
+    val optimized = Optimize.execute(query)
+    comparePlans(optimized, query)
+  }
+
+  test("back to back filter with the same object type") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f1 = (i: (Int, Int)) => i._1 > 0
+    val f2 = (i: (Int, Int)) => i._2 > 0
+
+    val query = input.filter(f1).filter(f2).analyze
+    val optimized = Optimize.execute(query)
+    assert(optimized.collect { case t: TypedFilter => t }.length == 1)
+  }
+
+  test("back to back filter with different object types") {
+    val input = LocalRelation('_1.int, '_2.int)
+    val f1 = (i: (Int, Int)) => i._1 > 0
+    val f2 = (i: OtherTuple) => i._2 > 0
+
+    val query = input.filter(f1).filter(f2).analyze
+    val optimized = Optimize.execute(query)
+    assert(optimized.collect { case t: TypedFilter => t }.length == 2)
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a6581eb563..e64669a19c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1997,11 +1997,7 @@ class Dataset[T] private[sql](
    */
   @Experimental
   def filter(func: T => Boolean): Dataset[T] = {
-    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
-    val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
-    val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil)
-    val filter = Filter(condition, logicalPlan)
-    withTypedPlan(filter)
+    withTypedPlan(TypedFilter(func, logicalPlan))
   }
 
   /**
@@ -2014,11 +2010,7 @@ class Dataset[T] private[sql](
    */
   @Experimental
   def filter(func: FilterFunction[T]): Dataset[T] = {
-    val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
-    val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
-    val condition = Invoke(function, "call", BooleanType, deserializer :: Nil)
-    val filter = Filter(condition, logicalPlan)
-    withTypedPlan(filter)
+    withTypedPlan(TypedFilter(func, logicalPlan))
   }
 
   /**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index b619d4edc3..5e643ea75a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -385,6 +385,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.ProjectExec(projectList, planLater(child)) :: Nil
       case logical.Filter(condition, child) =>
         execution.FilterExec(condition, planLater(child)) :: Nil
+      case f: logical.TypedFilter =>
+        execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
       case e @ logical.Expand(_, _, child) =>
         execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
       case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index b15f38c2a7..ab505139a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -238,6 +238,7 @@ abstract class QueryTest extends PlanTest {
       case _: ObjectConsumer => return
       case _: ObjectProducer => return
       case _: AppendColumns => return
+      case _: TypedFilter => return
       case _: LogicalRelation => return
       case p if p.getClass.getSimpleName == "MetastoreRelation" => return
       case _: MemoryPlan => return
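For intuition, the function merging done by `CombineTypedFilters` boils down to predicate conjunction. A standalone sketch of the Scala-function branch (plain Scala, no Spark required; the real rule also handles Java's `FilterFunction` via its `call` method):

```scala
// Mirrors combineFilterFunction for the (Scala, Scala) case: the combined
// predicate applies the earlier filter first, then the later one.
object CombineSketch {
  def combine(first: Any => Boolean, second: Any => Boolean): Any => Boolean =
    input => first(input) && second(input)

  def main(args: Array[String]): Unit = {
    val positive = (i: Any) => i.asInstanceOf[Int] > 0
    val odd = (i: Any) => i.asInstanceOf[Int] % 2 == 1
    val both = combine(positive, odd)
    assert(both(3) && !both(4) && !both(-3))
    println("conjunction behaves as expected")
  }
}
```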