author    Wenchen Fan <wenchen@databricks.com>    2016-06-30 08:15:08 +0800
committer Cheng Lian <lian@databricks.com>        2016-06-30 08:15:08 +0800
commit    d063898bebaaf4ec2aad24c3ac70aabdbf97a190 (patch)
tree      e10e1ed9765961338730237002e1b78c3f1f184b
parent    2eaabfa4142d4050be2b45fd277ff5c7fa430581 (diff)
[SPARK-16134][SQL] optimizer rules for typed filter
## What changes were proposed in this pull request?

This PR adds 3 optimizer rules for typed filter:

1. push the typed filter down through `SerializeFromObject` and eliminate the deserialization in the filter condition.
2. pull the typed filter up through `DeserializeToObject` and eliminate the deserialization in the filter condition.
3. combine adjacent typed filters and share the deserialized object among all the condition expressions.

This PR also adds a `TypedFilter` logical plan, to separate it from the normal filter, so that the concept is clearer and it is easier to write optimizer rules.

## How was this patch tested?

`TypedFilterOptimizationSuite`

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13846 from cloud-fan/filter.
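For illustration, here is a minimal sketch, not part of this patch, of the Dataset query shapes the three rules target (the local `SparkSession` setup and the `Person` class are hypothetical):

```scala
import org.apache.spark.sql.{Dataset, SparkSession}

object TypedFilterExample {
  case class Person(name: String, age: Int)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("typed-filter-example").getOrCreate()
    import spark.implicits._

    val people: Dataset[Person] = Seq(Person("a", 17), Person("b", 35)).toDS()

    // Rules 1 and 2: a typed filter adjacent to a typed map can reuse the object that
    // map already works on, instead of decoding the row a second time in the condition.
    val adjacent = people.map(p => p.copy(age = p.age + 1)).filter(_.age > 18)

    // Rule 3: back-to-back typed filters on the same type are combined into a single
    // TypedFilter whose function is the conjunction of both predicates.
    val backToBack = people.filter(_.age > 18).filter(_.name.nonEmpty)

    adjacent.explain(true)
    backToBack.explain(true)
    spark.stop()
  }
}
```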
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala                            |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala     |  1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                    | 98
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala                   | 47
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala | 86
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                             | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                           |  2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                                           |  1
8 files changed, 162 insertions(+), 91 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 2ca990d19a..84c9cc8c8e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -293,11 +293,7 @@ package object dsl {
def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan)
- def filter[T : Encoder](func: T => Boolean): LogicalPlan = {
- val deserialized = logicalPlan.deserialize[T]
- val condition = expressions.callFunction(func, BooleanType, deserialized.output.head)
- Filter(condition, deserialized).serialize[T]
- }
+ def filter[T : Encoder](func: T => Boolean): LogicalPlan = TypedFilter(func, logicalPlan)
def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
index 502d791c6e..127797c097 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala
@@ -45,6 +45,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression])
var maxOrdinal = -1
result foreach {
case b: BoundReference if b.ordinal > maxOrdinal => maxOrdinal = b.ordinal
+ case _ =>
}
if (maxOrdinal > children.length) {
return TypeCheckFailure(s"The result expression need $maxOrdinal input expressions, but " +
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 9bc8cea377..842d6bc26f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -21,6 +21,7 @@ import scala.annotation.tailrec
import scala.collection.immutable.HashSet
import scala.collection.mutable.ArrayBuffer
+import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf}
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
@@ -110,8 +111,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
Batch("Decimal Optimizations", fixedPoint,
DecimalAggregates) ::
Batch("Typed Filter Optimization", fixedPoint,
- EmbedSerializerInFilter,
- RemoveAliasOnlyProject) ::
+ CombineTypedFilters) ::
Batch("LocalRelation", fixedPoint,
ConvertToLocalRelation) ::
Batch("OptimizeCodegen", Once,
@@ -206,15 +206,33 @@ object RemoveAliasOnlyProject extends Rule[LogicalPlan] {
object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
- if d.outputObjectType == s.inputObjectType =>
+ if d.outputObjAttr.dataType == s.inputObjAttr.dataType =>
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
// We will remove it later in RemoveAliasOnlyProject rule.
- val objAttr =
- Alias(s.child.output.head, s.child.output.head.name)(exprId = d.output.head.exprId)
+ val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
Project(objAttr :: Nil, s.child)
+
case a @ AppendColumns(_, _, _, s: SerializeFromObject)
- if a.deserializer.dataType == s.inputObjectType =>
+ if a.deserializer.dataType == s.inputObjAttr.dataType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
+
+ // If there is a `SerializeFromObject` under the typed filter and its input object type is the
+ // same as the typed filter's deserializer, we can convert the typed filter to a normal filter
+ // without deserialization in the condition, and push it down through `SerializeFromObject`.
+ // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
+ // but `ds.map(...).as[AnotherType].filter(...)` cannot be optimized.
+ case f @ TypedFilter(_, _, s: SerializeFromObject)
+ if f.deserializer.dataType == s.inputObjAttr.dataType =>
+ s.copy(child = f.withObjectProducerChild(s.child))
+
+ // If there is a `DeserializeToObject` above the typed filter and its output object type is the
+ // same as the typed filter's deserializer, we can convert the typed filter to a normal filter
+ // without deserialization in the condition, and pull it up through `DeserializeToObject`.
+ // e.g. `ds.filter(...).map(...)` can be optimized by this rule to save extra deserialization,
+ // but `ds.filter(...).as[AnotherType].map(...)` cannot be optimized.
+ case d @ DeserializeToObject(_, _, f: TypedFilter)
+ if d.outputObjAttr.dataType == f.deserializer.dataType =>
+ f.withObjectProducerChild(d.copy(child = f.child))
}
}
@@ -1645,54 +1663,30 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic
}
/**
- * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a
- * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed
- * the deserializer in filter condition to save the extra serialization at last.
+ * Combines two adjacent [[TypedFilter]]s, which operate on the same type of object in the
+ * condition, into one, merging the filter functions into one conjunctive function.
*/
-object EmbedSerializerInFilter extends Rule[LogicalPlan] {
+object CombineTypedFilters extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject))
- // SPARK-15632: Conceptually, filter operator should never introduce schema change. This
- // optimization rule also relies on this assumption. However, Dataset typed filter operator
- // does introduce schema changes in some cases. Thus, we only enable this optimization when
- //
- // 1. either input and output schemata are exactly the same, or
- // 2. both input and output schemata are single-field schema and share the same type.
- //
- // The 2nd case is included because encoders for primitive types always have only a single
- // field with hard-coded field name "value".
- // TODO Cleans this up after fixing SPARK-15632.
- if s.schema == d.child.schema || samePrimitiveType(s.schema, d.child.schema) =>
-
- val numObjects = condition.collect {
- case a: Attribute if a == d.output.head => a
- }.length
-
- if (numObjects > 1) {
- // If the filter condition references the object more than one times, we should not embed
- // deserializer in it as the deserialization will happen many times and slow down the
- // execution.
- // TODO: we can still embed it if we can make sure subexpression elimination works here.
- s
- } else {
- val newCondition = condition transform {
- case a: Attribute if a == d.output.head => d.deserializer
- }
- val filter = Filter(newCondition, d.child)
-
- // Adds an extra Project here, to preserve the output expr id of `SerializeFromObject`.
- // We will remove it later in RemoveAliasOnlyProject rule.
- val objAttrs = filter.output.zip(s.output).map { case (fout, sout) =>
- Alias(fout, fout.name)(exprId = sout.exprId)
- }
- Project(objAttrs, filter)
- }
- }
-
- def samePrimitiveType(lhs: StructType, rhs: StructType): Boolean = {
- (lhs, rhs) match {
- case (StructType(Array(f1)), StructType(Array(f2))) => f1.dataType == f2.dataType
- case _ => false
+ case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+ if t1.deserializer.dataType == t2.deserializer.dataType =>
+ TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+ }
+
+ private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
+ (func1, func2) match {
+ case (f1: FilterFunction[_], f2: FilterFunction[_]) =>
+ input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
+ f2.asInstanceOf[FilterFunction[Any]].call(input)
+ case (f1: FilterFunction[_], f2) =>
+ input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
+ f2.asInstanceOf[Any => Boolean](input)
+ case (f1, f2: FilterFunction[_]) =>
+ input => f1.asInstanceOf[Any => Boolean].apply(input) &&
+ f2.asInstanceOf[FilterFunction[Any]].call(input)
+ case (f1, f2) =>
+ input => f1.asInstanceOf[Any => Boolean].apply(input) &&
+ f2.asInstanceOf[Any => Boolean].apply(input)
}
}
}
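For reference, a standalone sketch that mirrors the `combineFilterFunction` helper above, showing how a Scala `T => Boolean` and a Java `FilterFunction[T]` are erased to `Any` and merged into one conjunctive predicate (the demo values are illustrative only):

```scala
import org.apache.spark.api.java.function.FilterFunction

object CombineFilterFunctionsSketch {
  // Mirrors CombineTypedFilters.combineFilterFunction: handle both predicate shapes
  // and return a single function that applies them conjunctively.
  def combine(func1: AnyRef, func2: AnyRef): Any => Boolean = (func1, func2) match {
    case (f1: FilterFunction[_], f2: FilterFunction[_]) =>
      input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
        f2.asInstanceOf[FilterFunction[Any]].call(input)
    case (f1: FilterFunction[_], f2) =>
      input => f1.asInstanceOf[FilterFunction[Any]].call(input) &&
        f2.asInstanceOf[Any => Boolean](input)
    case (f1, f2: FilterFunction[_]) =>
      input => f1.asInstanceOf[Any => Boolean](input) &&
        f2.asInstanceOf[FilterFunction[Any]].call(input)
    case (f1, f2) =>
      input => f1.asInstanceOf[Any => Boolean](input) &&
        f2.asInstanceOf[Any => Boolean](input)
  }

  def main(args: Array[String]): Unit = {
    val gt0: Int => Boolean = _ > 0
    val lt10 = new FilterFunction[Int] { override def call(i: Int): Boolean = i < 10 }
    val combined = combine(gt0, lt10)
    println(Seq(-1, 5, 42).map(combined)) // List(false, true, false)
  }
}
```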
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 7beeeb4f04..e1890edcbb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -17,11 +17,15 @@
package org.apache.spark.sql.catalyst.plans.logical
+import scala.language.existentials
+
+import org.apache.spark.api.java.function.FilterFunction
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.types._
object CatalystSerde {
@@ -45,13 +49,11 @@ object CatalystSerde {
*/
trait ObjectProducer extends LogicalPlan {
// The attribute that reference to the single object field this operator outputs.
- protected def outputObjAttr: Attribute
+ def outputObjAttr: Attribute
override def output: Seq[Attribute] = outputObjAttr :: Nil
override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
-
- def outputObjectType: DataType = outputObjAttr.dataType
}
/**
@@ -64,7 +66,7 @@ trait ObjectConsumer extends UnaryNode {
// This operator always need all columns of its child, even it doesn't reference to.
override def references: AttributeSet = child.outputSet
- def inputObjectType: DataType = child.output.head.dataType
+ def inputObjAttr: Attribute = child.output.head
}
/**
@@ -167,6 +169,43 @@ case class MapElements(
outputObjAttr: Attribute,
child: LogicalPlan) extends ObjectConsumer with ObjectProducer
+object TypedFilter {
+ def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
+ TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+ }
+}
+
+/**
+ * A relation produced by applying `func` to each element of the `child` and filtering them by
+ * the resulting boolean value.
+ *
+ * This is logically equal to a normal [[Filter]] operator whose condition expression decodes the
+ * input row to an object and applies the given function to the decoded object. However, we need
+ * the encapsulation of [[TypedFilter]] to make the concept clearer and to make it easier to
+ * write optimizer rules.
+ */
+case class TypedFilter(
+ func: AnyRef,
+ deserializer: Expression,
+ child: LogicalPlan) extends UnaryNode {
+
+ override def output: Seq[Attribute] = child.output
+
+ def withObjectProducerChild(obj: LogicalPlan): Filter = {
+ assert(obj.output.length == 1)
+ Filter(typedCondition(obj.output.head), obj)
+ }
+
+ def typedCondition(input: Expression): Expression = {
+ val (funcClass, methodName) = func match {
+ case m: FilterFunction[_] => classOf[FilterFunction[_]] -> "call"
+ case _ => classOf[Any => Boolean] -> "apply"
+ }
+ val funcObj = Literal.create(func, ObjectType(funcClass))
+ Invoke(funcObj, methodName, BooleanType, input :: Nil)
+ }
+}
+
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
def apply[T : Encoder, U : Encoder](
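As a quick illustration, here is a sketch (not part of this patch) using the same catalyst test DSL as the suite below, showing that the DSL's `filter` now builds a `TypedFilter` node directly:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, TypedFilter}

object TypedFilterDslSketch {
  implicit val tupleEncoder: ExpressionEncoder[(Int, Int)] = ExpressionEncoder[(Int, Int)]()

  def main(args: Array[String]): Unit = {
    val input = LocalRelation('_1.int, '_2.int)
    val f = (i: (Int, Int)) => i._1 > 0

    // With this patch the DSL's filter wraps the function in a TypedFilter node
    // (see the dsl/package.scala hunk above) instead of eagerly building
    // DeserializeToObject + Filter + SerializeFromObject.
    val plan = input.filter(f)
    println(plan.collect { case t: TypedFilter => t }.length) // expected: 1
    println(plan)
  }
}
```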
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
index 63d87bfb6d..56f096f3ec 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql.catalyst.optimizer
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
+import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, TypedFilter}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.types.BooleanType
@@ -33,44 +32,91 @@ class TypedFilterOptimizationSuite extends PlanTest {
val batches =
Batch("EliminateSerialization", FixedPoint(50),
EliminateSerialization) ::
- Batch("EmbedSerializerInFilter", FixedPoint(50),
- EmbedSerializerInFilter) :: Nil
+ Batch("CombineTypedFilters", FixedPoint(50),
+ CombineTypedFilters) :: Nil
}
implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
- test("back to back filter") {
+ test("filter after serialize with the same object type") {
val input = LocalRelation('_1.int, '_2.int)
- val f1 = (i: (Int, Int)) => i._1 > 0
- val f2 = (i: (Int, Int)) => i._2 > 0
+ val f = (i: (Int, Int)) => i._1 > 0
- val query = input.filter(f1).filter(f2).analyze
+ val query = input
+ .deserialize[(Int, Int)]
+ .serialize[(Int, Int)]
+ .filter(f).analyze
val optimized = Optimize.execute(query)
- val expected = input.deserialize[(Int, Int)]
- .where(callFunction(f1, BooleanType, 'obj))
- .select('obj.as("obj"))
- .where(callFunction(f2, BooleanType, 'obj))
+ val expected = input
+ .deserialize[(Int, Int)]
+ .where(callFunction(f, BooleanType, 'obj))
.serialize[(Int, Int)].analyze
comparePlans(optimized, expected)
}
- // TODO: Remove this after we completely fix SPARK-15632 by adding optimization rules
- // for typed filters.
- ignore("embed deserializer in typed filter condition if there is only one filter") {
+ test("filter after serialize with different object types") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f = (i: OtherTuple) => i._1 > 0
+
+ val query = input
+ .deserialize[(Int, Int)]
+ .serialize[(Int, Int)]
+ .filter(f).analyze
+ val optimized = Optimize.execute(query)
+ comparePlans(optimized, query)
+ }
+
+ test("filter before deserialize with the same object type") {
val input = LocalRelation('_1.int, '_2.int)
val f = (i: (Int, Int)) => i._1 > 0
- val query = input.filter(f).analyze
+ val query = input
+ .filter(f)
+ .deserialize[(Int, Int)]
+ .serialize[(Int, Int)].analyze
val optimized = Optimize.execute(query)
- val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer)
- val condition = callFunction(f, BooleanType, deserializer)
- val expected = input.where(condition).select('_1.as("_1"), '_2.as("_2")).analyze
+ val expected = input
+ .deserialize[(Int, Int)]
+ .where(callFunction(f, BooleanType, 'obj))
+ .serialize[(Int, Int)].analyze
comparePlans(optimized, expected)
}
+
+ test("filter before deserialize with different object types") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f = (i: OtherTuple) => i._1 > 0
+
+ val query = input
+ .filter(f)
+ .deserialize[(Int, Int)]
+ .serialize[(Int, Int)].analyze
+ val optimized = Optimize.execute(query)
+ comparePlans(optimized, query)
+ }
+
+ test("back to back filter with the same object type") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f1 = (i: (Int, Int)) => i._1 > 0
+ val f2 = (i: (Int, Int)) => i._2 > 0
+
+ val query = input.filter(f1).filter(f2).analyze
+ val optimized = Optimize.execute(query)
+ assert(optimized.collect { case t: TypedFilter => t }.length == 1)
+ }
+
+ test("back to back filter with different object types") {
+ val input = LocalRelation('_1.int, '_2.int)
+ val f1 = (i: (Int, Int)) => i._1 > 0
+ val f2 = (i: OtherTuple) => i._2 > 0
+
+ val query = input.filter(f1).filter(f2).analyze
+ val optimized = Optimize.execute(query)
+ assert(optimized.collect { case t: TypedFilter => t }.length == 2)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a6581eb563..e64669a19c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1997,11 +1997,7 @@ class Dataset[T] private[sql](
*/
@Experimental
def filter(func: T => Boolean): Dataset[T] = {
- val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
- val function = Literal.create(func, ObjectType(classOf[T => Boolean]))
- val condition = Invoke(function, "apply", BooleanType, deserializer :: Nil)
- val filter = Filter(condition, logicalPlan)
- withTypedPlan(filter)
+ withTypedPlan(TypedFilter(func, logicalPlan))
}
/**
@@ -2014,11 +2010,7 @@ class Dataset[T] private[sql](
*/
@Experimental
def filter(func: FilterFunction[T]): Dataset[T] = {
- val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
- val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]]))
- val condition = Invoke(function, "call", BooleanType, deserializer :: Nil)
- val filter = Filter(condition, logicalPlan)
- withTypedPlan(filter)
+ withTypedPlan(TypedFilter(func, logicalPlan))
}
/**
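An end-to-end sketch, assuming a local `SparkSession` and not part of this patch, checking that the typed `filter` overloads now produce a `TypedFilter` logical node:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.TypedFilter

object DatasetTypedFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("dataset-typed-filter").getOrCreate()
    import spark.implicits._

    val ds = Seq(1, 2, 3).toDS().filter((i: Int) => i > 1)
    // The analyzed plan should contain one TypedFilter node; before this patch it
    // contained a Filter whose condition embedded the deserializer directly.
    val typedFilters = ds.queryExecution.analyzed.collect { case t: TypedFilter => t }
    println(typedFilters.length) // expected: 1
    spark.stop()
  }
}
```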
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index b619d4edc3..5e643ea75a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -385,6 +385,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.ProjectExec(projectList, planLater(child)) :: Nil
case logical.Filter(condition, child) =>
execution.FilterExec(condition, planLater(child)) :: Nil
+ case f: logical.TypedFilter =>
+ execution.FilterExec(f.typedCondition(f.deserializer), planLater(f.child)) :: Nil
case e @ logical.Expand(_, _, child) =>
execution.ExpandExec(e.projections, e.output, planLater(child)) :: Nil
case logical.Window(windowExprs, partitionSpec, orderSpec, child) =>
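And a sketch (illustrative, assuming a local `SparkSession`) of the expected physical plan: back-to-back typed filters should be combined by the optimizer and then planned by this strategy as a single `FilterExec`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.FilterExec

object PlannedTypedFilterSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("planned-typed-filter").getOrCreate()
    import spark.implicits._

    val ds = Seq(1, 2, 3, 4).toDS().filter((i: Int) => i > 1).filter((i: Int) => i < 4)
    // CombineTypedFilters should merge the two TypedFilters before the remaining one
    // is planned as a FilterExec over the deserialized object.
    val filters = ds.queryExecution.executedPlan.collect { case f: FilterExec => f }
    println(filters.length)      // expected: 1
    println(ds.collect().toSeq)  // expected: 2, 3
    spark.stop()
  }
}
```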
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index b15f38c2a7..ab505139a8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -238,6 +238,7 @@ abstract class QueryTest extends PlanTest {
case _: ObjectConsumer => return
case _: ObjectProducer => return
case _: AppendColumns => return
+ case _: TypedFilter => return
case _: LogicalRelation => return
case p if p.getClass.getSimpleName == "MetastoreRelation" => return
case _: MemoryPlan => return