author    Wenchen Fan <wenchen@databricks.com>   2016-06-30 08:15:08 +0800
committer Cheng Lian <lian@databricks.com>       2016-06-30 08:15:08 +0800
commit    d063898bebaaf4ec2aad24c3ac70aabdbf97a190
tree      e10e1ed9765961338730237002e1b78c3f1f184b /sql/catalyst/src/main
parent    2eaabfa4142d4050be2b45fd277ff5c7fa430581
[SPARK-16134][SQL] optimizer rules for typed filter
## What changes were proposed in this pull request?

This PR adds 3 optimizer rules for typed filter:

1. push typed filter down through `SerializeFromObject` and eliminate the deserialization in the filter condition.
2. pull typed filter up through `DeserializeToObject` and eliminate the deserialization in the filter condition.
3. combine adjacent typed filters and share the deserialized object among all the condition expressions.

This PR also adds a `TypedFilter` logical plan, separate from the normal `Filter`, so that the concept is clearer and it is easier to write optimizer rules.

## How was this patch tested?

`TypedFilterOptimizationSuite`

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13846 from cloud-fan/filter.
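As a quick illustration (not part of this patch), the sketch below shows Dataset-level patterns these rules target; the `Person` class, field names, predicates, and app name are hypothetical:

```scala
import org.apache.spark.sql.SparkSession

object TypedFilterPatterns {
  case class Person(name: String, age: Int)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("typed-filter-sketch").getOrCreate()
    import spark.implicits._

    val ds = Seq(Person("ann", 20), Person("bob", 35)).toDS()

    // Rules 1/2: a typed filter adjacent to SerializeFromObject / DeserializeToObject can be
    // rewritten to call the predicate on the already-deserialized object, avoiding one extra
    // deserialization per row.
    val pushDown = ds.map(p => p.copy(age = p.age + 1)).filter(_.age > 30)
    val pullUp   = ds.filter(_.age > 18).map(_.name)

    // Rule 3: two adjacent typed filters over the same object type are combined into a single
    // TypedFilter whose function is the conjunction of the two predicates.
    val combined = ds.filter(_.age > 18).filter(_.name.startsWith("a"))

    // explain(true) prints the optimized plans produced by these rules.
    pushDown.explain(true)
    pullUp.explain(true)
    combined.explain(true)

    spark.stop()
  }
}
```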
Diffstat (limited to 'sql/catalyst/src/main')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala                        |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala |  1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                | 98
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala               | 47
4 files changed, 91 insertions, 61 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 2ca990d19a..84c9cc8c8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -293,11 +293,7 @@ package object dsl {
def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
- def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
- val deserialized = logicalPlan.deserialize[T]
- val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
- Filter(condition, deserialized).serialize[T]
- }
+ def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan)
def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 502d791c6e..127797c097 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -45,6 +45,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
var maxOrdinal = -1
result foreach {
case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
+ case _ =>
}
if (maxOrdinal > children.length) {
return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
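An aside (not part of this patch): the `case _ =>` added above guards against the usual Scala gotcha that a pattern-matching block passed to `foreach` throws a `MatchError` on any element no case covers. A minimal standalone sketch (the object name and values are made up):

```scala
object MatchErrorSketch extends App {
  val nodes: Seq[Any] = Seq(1, "not an Int")

  // Without a wildcard case this would throw scala.MatchError("not an Int"):
  // nodes.foreach { case i: Int => println(i) }

  // With the wildcard, non-matching elements are simply ignored, which is the
  // behaviour the code above relies on.
  nodes.foreach {
    case i: Int => println(i)
    case _ =>
  }
}
```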
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9bc8cea377..842d6bc26f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
@@ -110,8 +111,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
- EmbedSerializerInFilter,
- RemoveAliasOnlyProject) ::
+ CombineTypedFilters) ::
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
Batch("OptimizeCodegen", Once,
@@ -206,15 +206,33 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
- if d.outputObjectType == s.inputObjectType =>
+ if d.outputObjAttr.dataType == s.inputObjAttr.dataType =>
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
// We will remove it later in RemoveAliasOnlyProject rule.
- val objAttr =
- Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId)
+ val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
Project(objAttr :: Nil, s.child)
+
case a @ AppendColumns(_, _, _, s: SerializeFromObject)
- if a.deserializer.dataType == s.inputObjectType =>
+ if a.deserializer.dataType == s.inputObjAttr.dataType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
+
+ // If there is a `SerializeFromObject` under the typed filter and its input object type is the
+ // same as the type of the typed filter's deserializer, we can convert the typed filter to a
+ // normal filter without deserialization in the condition, and push it down through
+ // `SerializeFromObject`.
+ // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save the extra
+ // deserialization, but `ds.map(...).as[AnotherType].filter(...)` cannot be optimized.
+ case f @ TypedFilter(_, _, s: SerializeFromObject)
+ if f.deserializer.dataType == s.inputObjAttr.dataType =>
+ s.copy(child = f.withObjectProducerChild(s.child))
+
+ // If there is a `DeserializeToObject` above the typed filter and its output object type is the
+ // same as the type of the typed filter's deserializer, we can convert the typed filter to a
+ // normal filter without deserialization in the condition, and pull it up through
+ // `DeserializeToObject`.
+ // e.g. `ds.filter(...).map(...)` can be optimized by this rule to save the extra
+ // deserialization, but `ds.filter(...).as[AnotherType].map(...)` cannot be optimized.
+ case d @ DeserializeToObject(_, _, f: TypedFilter)
+ if d.outputObjAttr.dataType == f.deserializer.dataType =>
+ f.withObjectProducerChild(d.copy(child = f.child))
}
}
@@ -1645,54 +1663,30 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
}
/**
- * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
- * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
- * the deserializer in filter condition to save the extra serialization at last.
+ * Combines two adjacent [[TypedFilter]]s that operate on the same type of object in their
+ * conditions into one, merging the filter functions into a single conjunctive function.
*/
-object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+object CombineTypedFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject))
- // SPARK-15632: Conceptually, filter operator should never introduce schema change. This
- // optimization rule also relies on this assumption. However, Dataset typed filter operator
- // does introduce schema changes in some cases. Thus, we only enable this optimization when
- //
- // 1. either input and output schemata are exactly the same, or
- // 2. both input and output schemata are single-field schema and share the same type.
- //
- // The 2nd case is included because encoders for primitive types always have only a single
- // field with hard-coded field name "value".
- // TODO Cleans this up after fixing SPARK-15632.
- if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) =>
-
- val numObjects = condition.collect {
- case a: Attribute if a == d.output.head => a
- }.length
-
- if (numObjects > 1) {
- // If the filter condition references the object more than one times, we should not embed
- // deserializer in it as the deserialization will happen many times and slow down the
- // execution.
- // TODO: we can still embed it if we can make sure subexpression elimination works here.
- s
- } else {
- val newCondition = condition transform {
- case a: Attribute if a == d.output.head => d.deserializer
- }
- val filter = Filter(newCondition, d.child)
-
- // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`.
- // We will remove it later in RemoveAliasOnlyProject rule.
- val objAttrs = filter.output.zip(s.output).map { case (fout, sout) =>
- Alias(fout, fout.name)(exprId = sout.exprId)
- }
- Project(objAttrs, filter)
- }
- }
-
- def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = {
- (lhs, rhs) match {
- case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType
- case _ => false
+ case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+ if t1.deserializer.dataType == t2.deserializer.dataType =>
+ TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+ }
+
+ private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
+ (func1, func2) match {
+ case (f1: FilterFunction[_], f2: FilterFunction[_]) =>
+ input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
+ f2.asInstanceOf[FilterFunction[Any]].call(input)
+ case (f1: FilterFunction[_], f2) =>
+ input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
+ f2.asInstanceOf[Any => Boolean](input)
+ case (f1, f2: FilterFunction[_]) =>
+ input => f1.asInstanceOf[Any => Boolean].apply(input) &&
+ f2.asInstanceOf[FilterFunction[Any]].call(input)
+ case (f1, f2) =>
+ input => f1.asInstanceOf[Any => Boolean].apply(input) &&
+ f2.asInstanceOf[Any => Boolean].apply(input)
}
}
}
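An illustrative aside (not part of this patch): `combineFilterFunction` above conjoins predicates that may come from either the Scala API (`T => Boolean`) or the Java API (`FilterFunction[T]`). A minimal standalone sketch of the mixed case, with hypothetical predicates `isAdult` and `isEven`:

```scala
import org.apache.spark.api.java.function.FilterFunction

object CombineFiltersSketch {
  def main(args: Array[String]): Unit = {
    // Scala-style predicate, as passed to Dataset.filter(func: T => Boolean).
    val isAdult: Int => Boolean = _ >= 18

    // Java-style predicate, as passed to Dataset.filter(func: FilterFunction[T]).
    val isEven = new FilterFunction[Int] {
      override def call(value: Int): Boolean = value % 2 == 0
    }

    // The conjunction CombineTypedFilters builds for the mixed case: the inner filter's
    // predicate (t2.func) is evaluated first, then the outer one (t1.func).
    val combined: Int => Boolean = x => isAdult(x) && isEven.call(x)

    println(Seq(17, 18, 19, 20).filter(combined)) // List(18, 20)
  }
}
```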
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 7beeeb4f04..e1890edcbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -17,11 +17,15 @@
package org.apache.spark.sql.catalyst.plans.logical
+import scala.language.existentials
+
+import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._
object CatalystSerde {
@@ -45,13 +49,11 @@ object CatalystSerde {
*/
trait ObjectProducer extends LogicalPlan {
// The attribute that reference to the single object field this operator outputs.
- protected def outputObjAttr: Attribute
+ def outputObjAttr: Attribute
override def output: Seq[Attribute] = outputObjAttr :: Nil
override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
-
- def outputObjectType: DataType = outputObjAttr.dataType
}
/**
@@ -64,7 +66,7 @@ trait ObjectConsumer extends UnaryNode {
// This operator always need all columns of its child, even it doesn't reference to.
override def references: AttributeSet = child.outputSet
- def inputObjectType: DataType = child.output.head.dataType
+ def inputObjAttr: Attribute = child.output.head
}
/**
@@ -167,6 +169,43 @@ case class MapElements(
outputObjAttr: Attribute,
child: LogicalPlan) extends ObjectConsumer with ObjectProducer
+object TypedFilter {
+ def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
+ TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+ }
+}
+
+/**
+ * A relation produced by applying `func` to each element of the `child` and filtering them by the
+ * resulting boolean value.
+ *
+ * This is logically equivalent to a normal [[Filter]] operator whose condition expression decodes
+ * the input row to an object and applies the given function to the decoded object. However, we
+ * need the encapsulation of [[TypedFilter]] to make the concept clearer and to make it easier to
+ * write optimizer rules.
+ */
+case class TypedFilter(
+ func: AnyRef,
+ deserializer: Expression,
+ child: LogicalPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output
+
+ def withObjectProducerChild(obj: LogicalPlan): Filter = {
+ assert(obj.output.length == 1)
+ Filter(typedCondition(obj.output.head), obj)
+ }
+
+ def typedCondition(input: Expression): Expression = {
+ val (funcClass, methodName) = func match {
+ case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
+ case _ => classOf[Any => Boolean] -> "apply"
+ }
+ val funcObj = Literal.create(func, ObjectType(funcClass))
+ Invoke(funcObj, methodName, BooleanType, input :: Nil)
+ }
+}
+
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](