diff options
Diffstat (limited to 'sql/catalyst')
5 files changed, 131 insertions, 138 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6591559426..0e2fd43983 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1672,9 +1672,9 @@ object CleanupAliases extends Rule[LogicalPlan] { // Operators that operate on objects should only have expressions from encoders, which should // never have extra aliases. - case o: ObjectOperator => o - case d: DeserializeToObject => d - case s: SerializeFromObject => s + case o: ObjectConsumer => o + case o: ObjectProducer => o + case a: AppendColumns => a case other => var stop = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 958966328b..085e95f542 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -245,6 +245,10 @@ package object dsl { def struct(attrs: AttributeReference*): AttributeReference = struct(StructType.fromAttributes(attrs)) + /** Creates a new AttributeReference of object type */ + def obj(cls: Class[_]): AttributeReference = + AttributeReference(s, ObjectType(cls), nullable = true)() + /** Create a function. */ def function(exprs: Expression*): UnresolvedFunction = UnresolvedFunction(s, exprs, isDistinct = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b806b725a8..0a5232b2d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -153,29 +153,16 @@ object SamplePushDown extends Rule[LogicalPlan] { * representation of data item. For example back to back map operations. */ object EliminateSerialization extends Rule[LogicalPlan] { - // TODO: find a more general way to do this optimization. def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case m @ MapPartitions(_, deserializer, _, child: ObjectOperator) - if !deserializer.isInstanceOf[Attribute] && - deserializer.dataType == child.outputObject.dataType => - val childWithoutSerialization = child.withObjectOutput - m.copy( - deserializer = childWithoutSerialization.output.head, - child = childWithoutSerialization) - - case m @ MapElements(_, deserializer, _, child: ObjectOperator) - if !deserializer.isInstanceOf[Attribute] && - deserializer.dataType == child.outputObject.dataType => - val childWithoutSerialization = child.withObjectOutput - m.copy( - deserializer = childWithoutSerialization.output.head, - child = childWithoutSerialization) - - case d @ DeserializeToObject(_, s: SerializeFromObject) + case d @ DeserializeToObject(_, _, s: SerializeFromObject) if d.outputObjectType == s.inputObjectType => // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) Project(objAttr :: Nil, s.child) + + case a @ AppendColumns(_, _, _, s: SerializeFromObject) + if a.deserializer.dataType == s.inputObjectType => + AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) } } @@ -366,9 +353,9 @@ object ColumnPruning extends Rule[LogicalPlan] { } a.copy(child = Expand(newProjects, newOutput, grandChild)) - // Prunes the unused columns from child of MapPartitions - case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty => - mp.copy(child = prunedChild(child, mp.references)) + // Prunes the unused columns from child of `DeserializeToObject` + case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty => + d.copy(child = prunedChild(child, d.references)) // Prunes the unused columns from child of Aggregate/Expand/Generate case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty => @@ -1453,7 +1440,7 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] { s } else { val newCondition = condition transform { - case a: Attribute if a == d.output.head => d.deserializer.child + case a: Attribute if a == d.output.head => d.deserializer } Filter(newCondition, d.child) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 6df46189b6..4a1bdb0b8a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -21,126 +21,111 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{DataType, ObjectType, StructType} +import org.apache.spark.sql.types.{DataType, StructType} object CatalystSerde { def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) - DeserializeToObject(Alias(deserializer, "obj")(), child) + DeserializeToObject(deserializer, generateObjAttr[T], child) } def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { SerializeFromObject(encoderFor[T].namedExpressions, child) } + + def generateObjAttr[T : Encoder]: Attribute = { + AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)() + } } /** - * Takes the input row from child and turns it into object using the given deserializer expression. - * The output of this operator is a single-field safe row containing the deserialized object. + * A trait for logical operators that produces domain objects as output. + * The output of this operator is a single-field safe row containing the produced object. */ -case class DeserializeToObject( - deserializer: Alias, - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = deserializer.toAttribute :: Nil +trait ObjectProducer extends LogicalPlan { + // The attribute that reference to the single object field this operator outputs. + protected def outputObjAttr: Attribute + + override def output: Seq[Attribute] = outputObjAttr :: Nil + + override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr) - def outputObjectType: DataType = deserializer.dataType + def outputObjectType: DataType = outputObjAttr.dataType } /** - * Takes the input object from child and turns in into unsafe row using the given serializer - * expression. The output of its child must be a single-field row containing the input object. + * A trait for logical operators that consumes domain objects as input. + * The output of its child must be a single-field row containing the input object. */ -case class SerializeFromObject( - serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode { - override def output: Seq[Attribute] = serializer.map(_.toAttribute) +trait ObjectConsumer extends UnaryNode { + assert(child.output.length == 1) + + // This operator always need all columns of its child, even it doesn't reference to. + override def references: AttributeSet = child.outputSet def inputObjectType: DataType = child.output.head.dataType } /** - * A trait for logical operators that apply user defined functions to domain objects. + * Takes the input row from child and turns it into object using the given deserializer expression. */ -trait ObjectOperator extends LogicalPlan { +case class DeserializeToObject( + deserializer: Expression, + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectProducer - /** The serializer that is used to produce the output of this operator. */ - def serializer: Seq[NamedExpression] +/** + * Takes the input object from child and turns it into unsafe row using the given serializer + * expression. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectConsumer { override def output: Seq[Attribute] = serializer.map(_.toAttribute) - - /** - * The object type that is produced by the user defined function. Note that the return type here - * is the same whether or not the operator is output serialized data. - */ - def outputObject: NamedExpression = - Alias(serializer.head.collect { case b: BoundReference => b }.head, "obj")() - - /** - * Returns a copy of this operator that will produce an object instead of an encoded row. - * Used in the optimizer when transforming plans to remove unneeded serialization. - */ - def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) { - this - } else { - withNewSerializer(outputObject :: Nil) - } - - /** Returns a copy of this operator with a different serializer. */ - def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy { - productIterator.map { - case c if c == serializer => newSerializer - case other: AnyRef => other - }.toArray - } } object MapPartitions { def apply[T : Encoder, U : Encoder]( func: Iterator[T] => Iterator[U], - child: LogicalPlan): MapPartitions = { - MapPartitions( + child: LogicalPlan): LogicalPlan = { + val deserialized = CatalystSerde.deserialize[T](child) + val mapped = MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - UnresolvedDeserializer(encoderFor[T].deserializer), - encoderFor[U].namedExpressions, - child) + CatalystSerde.generateObjAttr[U], + deserialized) + CatalystSerde.serialize[U](mapped) } } /** * A relation produced by applying `func` to each partition of the `child`. - * - * @param deserializer used to extract the input to `func` from an input row. - * @param serializer use to serialize the output of `func`. */ case class MapPartitions( func: Iterator[Any] => Iterator[Any], - deserializer: Expression, - serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer object MapElements { def apply[T : Encoder, U : Encoder]( func: AnyRef, - child: LogicalPlan): MapElements = { - MapElements( + child: LogicalPlan): LogicalPlan = { + val deserialized = CatalystSerde.deserialize[T](child) + val mapped = MapElements( func, - UnresolvedDeserializer(encoderFor[T].deserializer), - encoderFor[U].namedExpressions, - child) + CatalystSerde.generateObjAttr[U], + deserialized) + CatalystSerde.serialize[U](mapped) } } /** * A relation produced by applying `func` to each element of the `child`. - * - * @param deserializer used to extract the input to `func` from an input row. - * @param serializer use to serialize the output of `func`. */ case class MapElements( func: AnyRef, - deserializer: Expression, - serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { @@ -156,7 +141,7 @@ object AppendColumns { } /** - * A relation produced by applying `func` to each partition of the `child`, concatenating the + * A relation produced by applying `func` to each element of the `child`, concatenating the * resulting columns at the end of the input row. * * @param deserializer used to extract the input to `func` from an input row. @@ -166,28 +151,41 @@ case class AppendColumns( func: Any => Any, deserializer: Expression, serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator { + child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output ++ newColumns def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) } +/** + * An optimized version of [[AppendColumns]], that can be executed on deserialized object directly. + */ +case class AppendColumnsWithObject( + func: Any => Any, + childSerializer: Seq[NamedExpression], + newColumnsSerializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectConsumer { + + override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute) +} + /** Factory for constructing new `MapGroups` nodes. */ object MapGroups { def apply[K : Encoder, T : Encoder, U : Encoder]( func: (K, Iterator[T]) => TraversableOnce[U], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], - child: LogicalPlan): MapGroups = { - new MapGroups( + child: LogicalPlan): LogicalPlan = { + val mapped = new MapGroups( func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes), - encoderFor[U].namedExpressions, groupingAttributes, dataAttributes, + CatalystSerde.generateObjAttr[U], child) + CatalystSerde.serialize[U](mapped) } } @@ -198,43 +196,43 @@ object MapGroups { * * @param keyDeserializer used to extract the key object for each group. * @param valueDeserializer used to extract the items in the iterator from an input row. - * @param serializer use to serialize the output of `func`. */ case class MapGroups( func: (Any, Iterator[Any]) => TraversableOnce[Any], keyDeserializer: Expression, valueDeserializer: Expression, - serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], - child: LogicalPlan) extends UnaryNode with ObjectOperator + outputObjAttr: Attribute, + child: LogicalPlan) extends UnaryNode with ObjectProducer /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { - def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder]( - func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], + def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder]( + func: (K, Iterator[L], Iterator[R]) => TraversableOnce[OUT], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan): CoGroup = { + right: LogicalPlan): LogicalPlan = { require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) - CoGroup( + val cogrouped = CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to // resolve the `keyDeserializer` based on either of them, here we pick the left one. - UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup), - UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr), - UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr), - encoderFor[Result].namedExpressions, + UnresolvedDeserializer(encoderFor[K].deserializer, leftGroup), + UnresolvedDeserializer(encoderFor[L].deserializer, leftAttr), + UnresolvedDeserializer(encoderFor[R].deserializer, rightAttr), leftGroup, rightGroup, leftAttr, rightAttr, + CatalystSerde.generateObjAttr[OUT], left, right) + CatalystSerde.serialize[OUT](cogrouped) } } @@ -247,10 +245,10 @@ case class CoGroup( keyDeserializer: Expression, leftDeserializer: Expression, rightDeserializer: Expression, - serializer: Seq[NamedExpression], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], + outputObjAttr: Attribute, left: LogicalPlan, - right: LogicalPlan) extends BinaryNode with ObjectOperator + right: LogicalPlan) extends BinaryNode with ObjectProducer diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala index 9177737560..3c033ddc37 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala @@ -22,8 +22,7 @@ import scala.reflect.runtime.universe.TypeTag import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder -import org.apache.spark.sql.catalyst.expressions.NewInstance -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, MapPartitions} +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.rules.RuleExecutor @@ -37,40 +36,45 @@ class EliminateSerializationSuite extends PlanTest { } implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() - private val func = identity[Iterator[(Int, Int)]] _ - private val func2 = identity[Iterator[OtherTuple]] _ + implicit private def intEncoder = ExpressionEncoder[Int]() - def assertObjectCreations(count: Int, plan: LogicalPlan): Unit = { - val newInstances = plan.flatMap(_.expressions.collect { - case n: NewInstance => n - }) + test("back to back serialization") { + val input = LocalRelation('obj.obj(classOf[(Int, Int)])) + val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze + val optimized = Optimize.execute(plan) + val expected = input.select('obj.as("obj")).analyze + comparePlans(optimized, expected) + } - if (newInstances.size != count) { - fail( - s""" - |Wrong number of object creations in plan: ${newInstances.size} != $count - |$plan - """.stripMargin) - } + test("back to back serialization with object change") { + val input = LocalRelation('obj.obj(classOf[OtherTuple])) + val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze + val optimized = Optimize.execute(plan) + comparePlans(optimized, plan) } - test("back to back MapPartitions") { - val input = LocalRelation('_1.int, '_2.int) - val plan = - MapPartitions(func, - MapPartitions(func, input)) + test("back to back serialization in AppendColumns") { + val input = LocalRelation('obj.obj(classOf[(Int, Int)])) + val func = (item: (Int, Int)) => item._1 + val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze + + val optimized = Optimize.execute(plan) + + val expected = AppendColumnsWithObject( + func.asInstanceOf[Any => Any], + productEncoder[(Int, Int)].namedExpressions, + intEncoder.namedExpressions, + input).analyze - val optimized = Optimize.execute(plan.analyze) - assertObjectCreations(1, optimized) + comparePlans(optimized, expected) } - test("back to back with object change") { - val input = LocalRelation('_1.int, '_2.int) - val plan = - MapPartitions(func, - MapPartitions(func2, input)) + test("back to back serialization in AppendColumns with object change") { + val input = LocalRelation('obj.obj(classOf[OtherTuple])) + val func = (item: (Int, Int)) => item._1 + val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze - val optimized = Optimize.execute(plan.analyze) - assertObjectCreations(2, optimized) + val optimized = Optimize.execute(plan) + comparePlans(optimized, plan) } } |