diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-06-30 08:15:08 +0800 |
---|---|---|
committer | Cheng Lian <lian@databricks.com> | 2016-06-30 08:15:08 +0800 |
commit | d063898bebaaf4ec2aad24c3ac70aabdbf97a190 (patch) | |
tree | e10e1ed9765961338730237002e1b78c3f1f184b /sql/catalyst/src/main | |
parent | 2eaabfa4142d4050be2b45fd277ff5c7fa430581 (diff) | |
download | spark-d063898bebaaf4ec2aad24c3ac70aabdbf97a190.tar.gz spark-d063898bebaaf4ec2aad24c3ac70aabdbf97a190.tar.bz2 spark-d063898bebaaf4ec2aad24c3ac70aabdbf97a190.zip |
[SPARK-16134][SQL] optimizer rules for typed filter
## What changes were proposed in this pull request?
This PR adds 3 optimizer rules for typed filter:
1. push typed filter down through `SerializeFromObject` and eliminate the deserialization in filter condition.
2. pull typed filter up through `SerializeFromObject` and eliminate the deserialization in filter condition.
3. combine adjacent typed filters and share the deserialized object among all the condition expressions.
This PR also adds `TypedFilter` logical plan, to separate it from normal filter, so that the concept is more clear and it's easier to write optimizer rules.
## How was this patch tested?
`TypedFilterOptimizationSuite`
Author: Wenchen Fan <wenchen@databricks.com>
Closes #13846 from cloud-fan/filter.
Diffstat (limited to 'sql/catalyst/src/main')
4 files changed, 91 insertions, 61 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 2ca990d19a..84c9cc8c8e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -293,11 +293,7 @@ package object dsl { def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) - def filter[T : Encoder](func: T => Boolean): LogicalPlan = { - val deserialized = logicalPlan.deserialize[T] - val condition = expressions.callFunction(func, BooleanType, deserialized.output.head) - Filter(condition, deserialized).serialize[T] - } + def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan) def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala index 502d791c6e..127797c097 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -45,6 +45,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) var maxOrdinal = -1 result foreach { case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal + case _ => } if (maxOrdinal > children.length) { return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 9bc8cea377..842d6bc26f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -21,6 +21,7 @@ import scala.annotation.tailrec import scala.collection.immutable.HashSet import scala.collection.mutable.ArrayBuffer +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog} @@ -110,8 +111,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) Batch("Decimal Optimizations", fixedPoint, DecimalAggregates) :: Batch("Typed Filter Optimization", fixedPoint, - EmbedSerializerInFilter, - RemoveAliasOnlyProject) :: + CombineTypedFilters) :: Batch("LocalRelation", fixedPoint, ConvertToLocalRelation) :: Batch("OptimizeCodegen", Once, @@ -206,15 +206,33 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] { object EliminateSerialization extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case d @ DeserializeToObject(_, _, s: SerializeFromObject) - if d.outputObjectType == s.inputObjectType => + if d.outputObjAttr.dataType == s.inputObjAttr.dataType => // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. // We will remove it later in RemoveAliasOnlyProject rule. - val objAttr = - Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId) + val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId) Project(objAttr :: Nil, s.child) + case a @ AppendColumns(_, _, _, s: SerializeFromObject) - if a.deserializer.dataType == s.inputObjectType => + if a.deserializer.dataType == s.inputObjAttr.dataType => AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) + + // If there is a `SerializeFromObject` under typed filter and its input object type is same with + // the typed filter's deserializer, we can convert typed filter to normal filter without + // deserialization in condition, and push it down through `SerializeFromObject`. + // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization, + // but `ds.map(...).as[AnotherType].filter(...)` can not be optimized. + case f @ TypedFilter(_, _, s: SerializeFromObject) + if f.deserializer.dataType == s.inputObjAttr.dataType => + s.copy(child = f.withObjectProducerChild(s.child)) + + // If there is a `DeserializeToObject` upon typed filter and its output object type is same with + // the typed filter's deserializer, we can convert typed filter to normal filter without + // deserialization in condition, and pull it up through `DeserializeToObject`. + // e.g. `ds.filter(...).map(...)` can be optimized by this rule to save extra deserialization, + // but `ds.filter(...).as[AnotherType].map(...)` can not be optimized. + case d @ DeserializeToObject(_, _, f: TypedFilter) + if d.outputObjAttr.dataType == f.deserializer.dataType => + f.withObjectProducerChild(d.copy(child = f.child)) } } @@ -1645,54 +1663,30 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic } /** - * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a - * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed - * the deserializer in filter condition to save the extra serialization at last. + * Combines two adjacent [[TypedFilter]]s, which operate on same type object in condition, into one, + * mering the filter functions into one conjunctive function. */ -object EmbedSerializerInFilter extends Rule[LogicalPlan] { +object CombineTypedFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) - // SPARK-15632: Conceptually, filter operator should never introduce schema change. This - // optimization rule also relies on this assumption. However, Dataset typed filter operator - // does introduce schema changes in some cases. Thus, we only enable this optimization when - // - // 1. either input and output schemata are exactly the same, or - // 2. both input and output schemata are single-field schema and share the same type. - // - // The 2nd case is included because encoders for primitive types always have only a single - // field with hard-coded field name "value". - // TODO Cleans this up after fixing SPARK-15632. - if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) => - - val numObjects = condition.collect { - case a: Attribute if a == d.output.head => a - }.length - - if (numObjects > 1) { - // If the filter condition references the object more than one times, we should not embed - // deserializer in it as the deserialization will happen many times and slow down the - // execution. - // TODO: we can still embed it if we can make sure subexpression elimination works here. - s - } else { - val newCondition = condition transform { - case a: Attribute if a == d.output.head => d.deserializer - } - val filter = Filter(newCondition, d.child) - - // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`. - // We will remove it later in RemoveAliasOnlyProject rule. - val objAttrs = filter.output.zip(s.output).map { case (fout, sout) => - Alias(fout, fout.name)(exprId = sout.exprId) - } - Project(objAttrs, filter) - } - } - - def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = { - (lhs, rhs) match { - case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType - case _ => false + case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child)) + if t1.deserializer.dataType == t2.deserializer.dataType => + TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child) + } + + private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = { + (func1, func2) match { + case (f1: FilterFunction[_], f2: FilterFunction[_]) => + input => f1.asInstanceOf[FilterFunction[Any]].call(input) && + f2.asInstanceOf[FilterFunction[Any]].call(input) + case (f1: FilterFunction[_], f2) => + input => f1.asInstanceOf[FilterFunction[Any]].call(input) && + f2.asInstanceOf[Any => Boolean](input) + case (f1, f2: FilterFunction[_]) => + input => f1.asInstanceOf[Any => Boolean].apply(input) && + f2.asInstanceOf[FilterFunction[Any]].call(input) + case (f1, f2) => + input => f1.asInstanceOf[Any => Boolean].apply(input) && + f2.asInstanceOf[Any => Boolean].apply(input) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 7beeeb4f04..e1890edcbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -17,11 +17,15 @@ package org.apache.spark.sql.catalyst.plans.logical +import scala.language.existentials + +import org.apache.spark.api.java.function.FilterFunction import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.{Encoder, Row} import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.objects.Invoke import org.apache.spark.sql.types._ object CatalystSerde { @@ -45,13 +49,11 @@ object CatalystSerde { */ trait ObjectProducer extends LogicalPlan { // The attribute that reference to the single object field this operator outputs. - protected def outputObjAttr: Attribute + def outputObjAttr: Attribute override def output: Seq[Attribute] = outputObjAttr :: Nil override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) - - def outputObjectType: DataType = outputObjAttr.dataType } /** @@ -64,7 +66,7 @@ trait ObjectConsumer extends UnaryNode { // This operator always need all columns of its child, even it doesn't reference to. override def references: AttributeSet = child.outputSet - def inputObjectType: DataType = child.output.head.dataType + def inputObjAttr: Attribute = child.output.head } /** @@ -167,6 +169,43 @@ case class MapElements( outputObjAttr: Attribute, child: LogicalPlan) extends ObjectConsumer with ObjectProducer +object TypedFilter { + def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = { + TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child) + } +} + +/** + * A relation produced by applying `func` to each element of the `child` and filter them by the + * resulting boolean value. + * + * This is logically equal to a normal [[Filter]] operator whose condition expression is decoding + * the input row to object and apply the given function with decoded object. However we need the + * encapsulation of [[TypedFilter]] to make the concept more clear and make it easier to write + * optimizer rules. + */ +case class TypedFilter( + func: AnyRef, + deserializer: Expression, + child: LogicalPlan) extends UnaryNode { + + override def output: Seq[Attribute] = child.output + + def withObjectProducerChild(obj: LogicalPlan): Filter = { + assert(obj.output.length == 1) + Filter(typedCondition(obj.output.head), obj) + } + + def typedCondition(input: Expression): Expression = { + val (funcClass, methodName) = func match { + case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call" + case _ => classOf[Any => Boolean] -> "apply" + } + val funcObj = Literal.create(func, ObjectType(funcClass)) + Invoke(funcObj, methodName, BooleanType, input :: Nil) + } +} + /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { def apply[T : Encoder, U : Encoder]( |