Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala                     |   6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala                           |   4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala                   |  31
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala                  | 166
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala |  62
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala                                            |  10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                          |  22
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala                        |   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala                                  | 160
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                                          |   4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala                   |   2
11 files changed, 254 insertions(+), 217 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index 6591559426..0e2fd43983 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -1672,9 +1672,9 @@ object CleanupAliases extends Rule[LogicalPlan] {
// Operators that operate on objects should only have expressions from encoders, which should
// never have extra aliases.
- case o: ObjectOperator => o
- case d: DeserializeToObject => d
- case s: SerializeFromObject => s
+ case o: ObjectConsumer => o
+ case o: ObjectProducer => o
+ case a: AppendColumns => a
case other =>
var stop = false
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 958966328b..085e95f542 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -245,6 +245,10 @@ package object dsl {
def struct(attrs: AttributeReference*): AttributeReference =
struct(StructType.fromAttributes(attrs))
+ /** Creates a new AttributeReference of object type */
+ def obj(cls: Class[_]): AttributeReference =
+ AttributeReference(s, ObjectType(cls), nullable = true)()
+
/** Create a function. */
def function(exprs: Expression*): UnresolvedFunction =
UnresolvedFunction(s, exprs, isDistinct = false)
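
For context, the new `obj` builder sits alongside the existing `struct` and `function` helpers and turns a Symbol into an object-typed attribute. A minimal usage sketch, mirroring the updated EliminateSerializationSuite further below:

    import org.apache.spark.sql.catalyst.dsl.expressions._
    import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

    // a single-column relation whose only attribute has type ObjectType(classOf[(Int, Int)])
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))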
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b806b725a8..0a5232b2d4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -153,29 +153,16 @@ object SamplePushDown extends Rule[LogicalPlan] {
 * representation of data items, for example back-to-back map operations.
*/
object EliminateSerialization extends Rule[LogicalPlan] {
- // TODO: find a more general way to do this optimization.
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case m @ MapPartitions(_, deserializer, _, child: ObjectOperator)
- if !deserializer.isInstanceOf[Attribute] &&
- deserializer.dataType == child.outputObject.dataType =>
- val childWithoutSerialization = child.withObjectOutput
- m.copy(
- deserializer = childWithoutSerialization.output.head,
- child = childWithoutSerialization)
-
- case m @ MapElements(_, deserializer, _, child: ObjectOperator)
- if !deserializer.isInstanceOf[Attribute] &&
- deserializer.dataType == child.outputObject.dataType =>
- val childWithoutSerialization = child.withObjectOutput
- m.copy(
- deserializer = childWithoutSerialization.output.head,
- child = childWithoutSerialization)
-
- case d @ DeserializeToObject(_, s: SerializeFromObject)
+ case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
// Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
Project(objAttr :: Nil, s.child)
+
+ case a @ AppendColumns(_, _, _, s: SerializeFromObject)
+ if a.deserializer.dataType == s.inputObjectType =>
+ AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
}
}
@@ -366,9 +353,9 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
a.copy(child = Expand(newProjects, newOutput, grandChild))
- // Prunes the unused columns from child of MapPartitions
- case mp @ MapPartitions(_, _, _, child) if (child.outputSet -- mp.references).nonEmpty =>
- mp.copy(child = prunedChild(child, mp.references))
+ // Prunes the unused columns from child of `DeserializeToObject`
+ case d @ DeserializeToObject(_, _, child) if (child.outputSet -- d.references).nonEmpty =>
+ d.copy(child = prunedChild(child, d.references))
// Prunes the unused columns from child of Aggregate/Expand/Generate
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
@@ -1453,7 +1440,7 @@ object EmbedSerializerInFilter extends Rule[LogicalPlan] {
s
} else {
val newCondition = condition transform {
- case a: Attribute if a == d.output.head => d.deserializer.child
+ case a: Attribute if a == d.output.head => d.deserializer
}
Filter(newCondition, d.child)
}
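
To see what the simplified EliminateSerialization rule buys, here is a sketch of its two rewrites, grounded in the updated EliminateSerializationSuite further below (the plans are illustrative, not additional rule code):

    // 1. A full object round-trip collapses to a Project that merely preserves
    //    the output expression id of the DeserializeToObject:
    val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
    val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
    // optimizes to: input.select('obj.as("obj"))

    // 2. AppendColumns over SerializeFromObject is fused into the new
    //    AppendColumnsWithObject operator, so `func` runs directly on the
    //    child's object output and one serialization round-trip is skipped.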
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 6df46189b6..4a1bdb0b8a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -21,126 +21,111 @@ import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, ObjectType, StructType}
+import org.apache.spark.sql.types.{DataType, StructType}
object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer)
- DeserializeToObject(Alias(deserializer, "obj")(), child)
+ DeserializeToObject(deserializer, generateObjAttr[T], child)
}
def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
SerializeFromObject(encoderFor[T].namedExpressions, child)
}
+
+ def generateObjAttr[T : Encoder]: Attribute = {
+ AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
+ }
}
/**
- * Takes the input row from child and turns it into object using the given deserializer expression.
- * The output of this operator is a single-field safe row containing the deserialized object.
+ * A trait for logical operators that produce domain objects as output.
+ * The output of this operator is a single-field safe row containing the produced object.
*/
-case class DeserializeToObject(
- deserializer: Alias,
- child: LogicalPlan) extends UnaryNode {
- override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+trait ObjectProducer extends LogicalPlan {
+ // The attribute that references the single object field this operator outputs.
+ protected def outputObjAttr: Attribute
+
+ override def output: Seq[Attribute] = outputObjAttr :: Nil
+
+ override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
- def outputObjectType: DataType = deserializer.dataType
+ def outputObjectType: DataType = outputObjAttr.dataType
}
/**
- * Takes the input object from child and turns in into unsafe row using the given serializer
- * expression. The output of its child must be a single-field row containing the input object.
+ * A trait for logical operators that consume domain objects as input.
+ * The output of its child must be a single-field row containing the input object.
*/
-case class SerializeFromObject(
- serializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode {
- override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+trait ObjectConsumer extends UnaryNode {
+ assert(child.output.length == 1)
+
+ // This operator always needs all columns of its child, even the ones it doesn't reference.
+ override def references: AttributeSet = child.outputSet
def inputObjectType: DataType = child.output.head.dataType
}
/**
- * A trait for logical operators that apply user defined functions to domain objects.
+ * Takes the input row from the child and turns it into an object using the given deserializer expression.
*/
-trait ObjectOperator extends LogicalPlan {
+case class DeserializeToObject(
+ deserializer: Expression,
+ outputObjAttr: Attribute,
+ child: LogicalPlan) extends UnaryNode with ObjectProducer
- /** The serializer that is used to produce the output of this operator. */
- def serializer: Seq[NamedExpression]
+/**
+ * Takes the input object from the child and turns it into an unsafe row using the given serializer
+ * expression.
+ */
+case class SerializeFromObject(
+ serializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode with ObjectConsumer {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
-
- /**
- * The object type that is produced by the user defined function. Note that the return type here
- * is the same whether or not the operator is output serialized data.
- */
- def outputObject: NamedExpression =
- Alias(serializer.head.collect { case b: BoundReference => b }.head, "obj")()
-
- /**
- * Returns a copy of this operator that will produce an object instead of an encoded row.
- * Used in the optimizer when transforming plans to remove unneeded serialization.
- */
- def withObjectOutput: LogicalPlan = if (output.head.dataType.isInstanceOf[ObjectType]) {
- this
- } else {
- withNewSerializer(outputObject :: Nil)
- }
-
- /** Returns a copy of this operator with a different serializer. */
- def withNewSerializer(newSerializer: Seq[NamedExpression]): LogicalPlan = makeCopy {
- productIterator.map {
- case c if c == serializer => newSerializer
- case other: AnyRef => other
- }.toArray
- }
}
object MapPartitions {
def apply[T : Encoder, U : Encoder](
func: Iterator[T] => Iterator[U],
- child: LogicalPlan): MapPartitions = {
- MapPartitions(
+ child: LogicalPlan): LogicalPlan = {
+ val deserialized = CatalystSerde.deserialize[T](child)
+ val mapped = MapPartitions(
func.asInstanceOf[Iterator[Any] => Iterator[Any]],
- UnresolvedDeserializer(encoderFor[T].deserializer),
- encoderFor[U].namedExpressions,
- child)
+ CatalystSerde.generateObjAttr[U],
+ deserialized)
+ CatalystSerde.serialize[U](mapped)
}
}
/**
* A relation produced by applying `func` to each partition of the `child`.
- *
- * @param deserializer used to extract the input to `func` from an input row.
- * @param serializer use to serialize the output of `func`.
*/
case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
- deserializer: Expression,
- serializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectOperator
+ outputObjAttr: Attribute,
+ child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
object MapElements {
def apply[T : Encoder, U : Encoder](
func: AnyRef,
- child: LogicalPlan): MapElements = {
- MapElements(
+ child: LogicalPlan): LogicalPlan = {
+ val deserialized = CatalystSerde.deserialize[T](child)
+ val mapped = MapElements(
func,
- UnresolvedDeserializer(encoderFor[T].deserializer),
- encoderFor[U].namedExpressions,
- child)
+ CatalystSerde.generateObjAttr[U],
+ deserialized)
+ CatalystSerde.serialize[U](mapped)
}
}
/**
* A relation produced by applying `func` to each element of the `child`.
- *
- * @param deserializer used to extract the input to `func` from an input row.
- * @param serializer use to serialize the output of `func`.
*/
case class MapElements(
func: AnyRef,
- deserializer: Expression,
- serializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectOperator
+ outputObjAttr: Attribute,
+ child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
/** Factory for constructing new `AppendColumns` nodes. */
object AppendColumns {
@@ -156,7 +141,7 @@ object AppendColumns {
}
/**
- * A relation produced by applying `func` to each partition of the `child`, concatenating the
+ * A relation produced by applying `func` to each element of the `child`, concatenating the
* resulting columns at the end of the input row.
*
* @param deserializer used to extract the input to `func` from an input row.
@@ -166,28 +151,41 @@ case class AppendColumns(
func: Any => Any,
deserializer: Expression,
serializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectOperator {
+ child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output ++ newColumns
def newColumns: Seq[Attribute] = serializer.map(_.toAttribute)
}
+/**
+ * An optimized version of [[AppendColumns]] that can be executed directly on deserialized objects.
+ */
+case class AppendColumnsWithObject(
+ func: Any => Any,
+ childSerializer: Seq[NamedExpression],
+ newColumnsSerializer: Seq[NamedExpression],
+ child: LogicalPlan) extends UnaryNode with ObjectConsumer {
+
+ override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute)
+}
+
/** Factory for constructing new `MapGroups` nodes. */
object MapGroups {
def apply[K : Encoder, T : Encoder, U : Encoder](
func: (K, Iterator[T]) => TraversableOnce[U],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
- child: LogicalPlan): MapGroups = {
- new MapGroups(
+ child: LogicalPlan): LogicalPlan = {
+ val mapped = new MapGroups(
func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]],
UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes),
UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes),
- encoderFor[U].namedExpressions,
groupingAttributes,
dataAttributes,
+ CatalystSerde.generateObjAttr[U],
child)
+ CatalystSerde.serialize[U](mapped)
}
}
@@ -198,43 +196,43 @@ object MapGroups {
*
* @param keyDeserializer used to extract the key object for each group.
* @param valueDeserializer used to extract the items in the iterator from an input row.
- * @param serializer use to serialize the output of `func`.
*/
case class MapGroups(
func: (Any, Iterator[Any]) => TraversableOnce[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
- serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
- child: LogicalPlan) extends UnaryNode with ObjectOperator
+ outputObjAttr: Attribute,
+ child: LogicalPlan) extends UnaryNode with ObjectProducer
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
- def apply[Key : Encoder, Left : Encoder, Right : Encoder, Result : Encoder](
- func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
+ def apply[K : Encoder, L : Encoder, R : Encoder, OUT : Encoder](
+ func: (K, Iterator[L], Iterator[R]) => TraversableOnce[OUT],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
left: LogicalPlan,
- right: LogicalPlan): CoGroup = {
+ right: LogicalPlan): LogicalPlan = {
require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup))
- CoGroup(
+ val cogrouped = CoGroup(
func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]],
// The `leftGroup` and `rightGroup` are guaranteed to be of the same schema, so it's safe to
// resolve the `keyDeserializer` based on either of them, here we pick the left one.
- UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup),
- UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr),
- UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr),
- encoderFor[Result].namedExpressions,
+ UnresolvedDeserializer(encoderFor[K].deserializer, leftGroup),
+ UnresolvedDeserializer(encoderFor[L].deserializer, leftAttr),
+ UnresolvedDeserializer(encoderFor[R].deserializer, rightAttr),
leftGroup,
rightGroup,
leftAttr,
rightAttr,
+ CatalystSerde.generateObjAttr[OUT],
left,
right)
+ CatalystSerde.serialize[OUT](cogrouped)
}
}
@@ -247,10 +245,10 @@ case class CoGroup(
keyDeserializer: Expression,
leftDeserializer: Expression,
rightDeserializer: Expression,
- serializer: Seq[NamedExpression],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
+ outputObjAttr: Attribute,
left: LogicalPlan,
- right: LogicalPlan) extends BinaryNode with ObjectOperator
+ right: LogicalPlan) extends BinaryNode with ObjectProducer
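
The net effect of the factory changes above: typed operators no longer carry their own serializer and deserializer. Each factory sandwiches a bare object operator between DeserializeToObject and SerializeFromObject, which EliminateSerialization can then cancel between chained operations. A sketch of the plan MapPartitions.apply now produces:

    SerializeFromObject(encoderFor[U].namedExpressions,       // object -> rows
      MapPartitions(func, generateObjAttr[U],                 // object -> object
        DeserializeToObject(                                  // rows -> object
          UnresolvedDeserializer(encoderFor[T].deserializer),
          generateObjAttr[T], child)))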
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
index 9177737560..3c033ddc37 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSerializationSuite.scala
@@ -22,8 +22,7 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions.NewInstance
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, MapPartitions}
+import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -37,40 +36,45 @@ class EliminateSerializationSuite extends PlanTest {
}
implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]()
- private val func = identity[Iterator[(Int, Int)]] _
- private val func2 = identity[Iterator[OtherTuple]] _
+ implicit private def intEncoder = ExpressionEncoder[Int]()
- def assertObjectCreations(count: Int, plan: LogicalPlan): Unit = {
- val newInstances = plan.flatMap(_.expressions.collect {
- case n: NewInstance => n
- })
+ test("back to back serialization") {
+ val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
+ val plan = input.serialize[(Int, Int)].deserialize[(Int, Int)].analyze
+ val optimized = Optimize.execute(plan)
+ val expected = input.select('obj.as("obj")).analyze
+ comparePlans(optimized, expected)
+ }
- if (newInstances.size != count) {
- fail(
- s"""
- |Wrong number of object creations in plan: ${newInstances.size} != $count
- |$plan
- """.stripMargin)
- }
+ test("back to back serialization with object change") {
+ val input = LocalRelation('obj.obj(classOf[OtherTuple]))
+ val plan = input.serialize[OtherTuple].deserialize[(Int, Int)].analyze
+ val optimized = Optimize.execute(plan)
+ comparePlans(optimized, plan)
}
- test("back to back MapPartitions") {
- val input = LocalRelation('_1.int, '_2.int)
- val plan =
- MapPartitions(func,
- MapPartitions(func, input))
+ test("back to back serialization in AppendColumns") {
+ val input = LocalRelation('obj.obj(classOf[(Int, Int)]))
+ val func = (item: (Int, Int)) => item._1
+ val plan = AppendColumns(func, input.serialize[(Int, Int)]).analyze
+
+ val optimized = Optimize.execute(plan)
+
+ val expected = AppendColumnsWithObject(
+ func.asInstanceOf[Any => Any],
+ productEncoder[(Int, Int)].namedExpressions,
+ intEncoder.namedExpressions,
+ input).analyze
- val optimized = Optimize.execute(plan.analyze)
- assertObjectCreations(1, optimized)
+ comparePlans(optimized, expected)
}
- test("back to back with object change") {
- val input = LocalRelation('_1.int, '_2.int)
- val plan =
- MapPartitions(func,
- MapPartitions(func2, input))
+ test("back to back serialization in AppendColumns with object change") {
+ val input = LocalRelation('obj.obj(classOf[OtherTuple]))
+ val func = (item: (Int, Int)) => item._1
+ val plan = AppendColumns(func, input.serialize[OtherTuple]).analyze
- val optimized = Optimize.execute(plan.analyze)
- assertObjectCreations(2, optimized)
+ val optimized = Optimize.execute(plan)
+ comparePlans(optimized, plan)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1a09d70fb9..3c708cbf29 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2251,16 +2251,16 @@ class Dataset[T] private[sql](
def unpersist(): this.type = unpersist(blocking = false)
/**
- * Represents the content of the [[Dataset]] as an [[RDD]] of [[Row]]s. Note that the RDD is
- * memoized. Once called, it won't change even if you change any query planning related Spark SQL
- * configurations (e.g. `spark.sql.shuffle.partitions`).
+ * Represents the content of the [[Dataset]] as an [[RDD]] of [[T]].
*
* @group rdd
* @since 1.6.0
*/
lazy val rdd: RDD[T] = {
- queryExecution.toRdd.mapPartitions { rows =>
- rows.map(boundTEncoder.fromRow)
+ val objectType = unresolvedTEncoder.deserializer.dataType
+ val deserialized = CatalystSerde.deserialize[T](logicalPlan)
+ sqlContext.executePlan(deserialized).toRdd.mapPartitions { rows =>
+ rows.map(_.get(0, objectType).asInstanceOf[T])
}
}
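
A minimal usage sketch of the reworked `rdd` (the Person class and file path are hypothetical):

    case class Person(name: String, age: Long)
    val ds: Dataset[Person] = sqlContext.read.json("people.json").as[Person]
    // deserialization is now planned as a DeserializeToObject operator; each
    // resulting row is a single-field safe row from which the object is unwrapped
    val rdd: RDD[Person] = ds.rdd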
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c15aaed365..a4b0fa59db 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -346,21 +346,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
throw new IllegalStateException(
"logical intersect operator should have been replaced by semi-join in the optimizer")
- case logical.DeserializeToObject(deserializer, child) =>
- execution.DeserializeToObject(deserializer, planLater(child)) :: Nil
+ case logical.DeserializeToObject(deserializer, objAttr, child) =>
+ execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil
case logical.SerializeFromObject(serializer, child) =>
execution.SerializeFromObject(serializer, planLater(child)) :: Nil
- case logical.MapPartitions(f, in, out, child) =>
- execution.MapPartitions(f, in, out, planLater(child)) :: Nil
- case logical.MapElements(f, in, out, child) =>
- execution.MapElements(f, in, out, planLater(child)) :: Nil
+ case logical.MapPartitions(f, objAttr, child) =>
+ execution.MapPartitions(f, objAttr, planLater(child)) :: Nil
+ case logical.MapElements(f, objAttr, child) =>
+ execution.MapElements(f, objAttr, planLater(child)) :: Nil
case logical.AppendColumns(f, in, out, child) =>
execution.AppendColumns(f, in, out, planLater(child)) :: Nil
- case logical.MapGroups(f, key, in, out, grouping, data, child) =>
- execution.MapGroups(f, key, in, out, grouping, data, planLater(child)) :: Nil
- case logical.CoGroup(f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr, left, right) =>
+ case logical.AppendColumnsWithObject(f, childSer, newSer, child) =>
+ execution.AppendColumnsWithObject(f, childSer, newSer, planLater(child)) :: Nil
+ case logical.MapGroups(f, key, value, grouping, data, objAttr, child) =>
+ execution.MapGroups(f, key, value, grouping, data, objAttr, planLater(child)) :: Nil
+ case logical.CoGroup(f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr, left, right) =>
execution.CoGroup(
- f, keyObj, lObj, rObj, out, lGroup, rGroup, lAttr, rAttr,
+ f, key, lObj, rObj, lGroup, rGroup, lAttr, rAttr, oAttr,
planLater(left), planLater(right)) :: Nil
case logical.Repartition(numPartitions, shuffle, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 46eaede5e7..23b2eabd0c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -473,6 +473,10 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] {
* Inserts a WholeStageCodegen on top of those that support codegen.
*/
private def insertWholeStageCodegen(plan: SparkPlan): SparkPlan = plan match {
+ // For operators that output domain objects, do not insert WholeStageCodegen, as a domain
+ // object cannot be written into an unsafe row.
+ case plan if plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType] =>
+ plan.withNewChildren(plan.children.map(insertWholeStageCodegen))
case plan: CodegenSupport if supportCodegen(plan) =>
WholeStageCodegen(insertInputAdapter(plan))
case other =>
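
The added guard checks for exactly one object-typed output column. An equivalent standalone predicate, written out only for clarity (this helper does not exist in Spark):

    private def producesDomainObject(plan: SparkPlan): Boolean =
      plan.output.length == 1 && plan.output.head.dataType.isInstanceOf[ObjectType]

Such plans are excluded because an ObjectType value cannot be written into an UnsafeRow, while WholeStageCodegen materializes its output as unsafe rows; codegen is still inserted into the children.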
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index e7261fc512..7c8bc7fed8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -25,16 +25,19 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.types.ObjectType
+import org.apache.spark.sql.types.{DataType, ObjectType}
/**
 * Takes the input row from the child and turns it into an object using the given deserializer expression.
* The output of this operator is a single-field safe row containing the deserialized object.
*/
case class DeserializeToObject(
- deserializer: Alias,
+ deserializer: Expression,
+ outputObjAttr: Attribute,
child: SparkPlan) extends UnaryNode with CodegenSupport {
- override def output: Seq[Attribute] = deserializer.toAttribute :: Nil
+
+ override def output: Seq[Attribute] = outputObjAttr :: Nil
+ override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -67,6 +70,7 @@ case class DeserializeToObject(
case class SerializeFromObject(
serializer: Seq[NamedExpression],
child: SparkPlan) extends UnaryNode with CodegenSupport {
+
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
override def inputRDDs(): Seq[RDD[InternalRow]] = {
@@ -98,60 +102,71 @@ case class SerializeFromObject(
* Helper functions for physical operators that work with user defined objects.
*/
trait ObjectOperator extends SparkPlan {
- def generateToObject(objExpr: Expression, inputSchema: Seq[Attribute]): InternalRow => Any = {
- val objectProjection = GenerateSafeProjection.generate(objExpr :: Nil, inputSchema)
- (i: InternalRow) => objectProjection(i).get(0, objExpr.dataType)
+ def deserializeRowToObject(
+ deserializer: Expression,
+ inputSchema: Seq[Attribute]): InternalRow => Any = {
+ val proj = GenerateSafeProjection.generate(deserializer :: Nil, inputSchema)
+ (i: InternalRow) => proj(i).get(0, deserializer.dataType)
}
- def generateToRow(serializer: Seq[Expression]): Any => InternalRow = {
- val outputProjection = if (serializer.head.dataType.isInstanceOf[ObjectType]) {
- GenerateSafeProjection.generate(serializer)
- } else {
- GenerateUnsafeProjection.generate(serializer)
+ def serializeObjectToRow(serializer: Seq[Expression]): Any => UnsafeRow = {
+ val proj = GenerateUnsafeProjection.generate(serializer)
+ val objType = serializer.head.collect { case b: BoundReference => b.dataType }.head
+ val objRow = new SpecificMutableRow(objType :: Nil)
+ (o: Any) => {
+ objRow(0) = o
+ proj(objRow)
}
- val inputType = serializer.head.collect { case b: BoundReference => b.dataType }.head
- val outputRow = new SpecificMutableRow(inputType :: Nil)
+ }
+
+ def wrapObjectToRow(objType: DataType): Any => InternalRow = {
+ val outputRow = new SpecificMutableRow(objType :: Nil)
(o: Any) => {
outputRow(0) = o
- outputProjection(outputRow)
+ outputRow
}
}
+
+ def unwrapObjectFromRow(objType: DataType): InternalRow => Any = {
+ (i: InternalRow) => i.get(0, objType)
+ }
}
/**
- * Applies the given function to each input row and encodes the result.
+ * Applies the given function to the input object iterator.
+ * The output of its child must be a single-field row containing the input object.
*/
case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
- deserializer: Expression,
- serializer: Seq[NamedExpression],
+ outputObjAttr: Attribute,
child: SparkPlan) extends UnaryNode with ObjectOperator {
- override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def output: Seq[Attribute] = outputObjAttr :: Nil
+ override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
- val getObject = generateToObject(deserializer, child.output)
- val outputObject = generateToRow(serializer)
+ val getObject = unwrapObjectFromRow(child.output.head.dataType)
+ val outputObject = wrapObjectToRow(outputObjAttr.dataType)
func(iter.map(getObject)).map(outputObject)
}
}
}
/**
- * Applies the given function to each input row and encodes the result.
+ * Applies the given function to each input object.
+ * The output of its child must be a single-field row containing the input object.
*
- * Note that, each serializer expression needs the result object which is returned by the given
- * function, as input. This operator uses some tricks to make sure we only calculate the result
- * object once. We don't use [[Project]] directly as subexpression elimination doesn't work with
- * whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of
- * a project while explain.
+ * This operator is kind of a safe version of [[Project]]: as its output is a custom object, we
+ * need to use a safe row to contain it.
*/
case class MapElements(
func: AnyRef,
- deserializer: Expression,
- serializer: Seq[NamedExpression],
+ outputObjAttr: Attribute,
child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport {
- override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+
+ override def output: Seq[Attribute] = outputObjAttr :: Nil
+ override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -167,23 +182,14 @@ case class MapElements(
case _ => classOf[Any => Any] -> "apply"
}
val funcObj = Literal.create(func, ObjectType(funcClass))
- val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType
- val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer))
+ val callFunc = Invoke(funcObj, methodName, outputObjAttr.dataType, child.output)
val bound = ExpressionCanonicalizer.execute(
BindReferences.bindReference(callFunc, child.output))
ctx.currentVars = input
- val evaluated = bound.genCode(ctx)
-
- val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType)
- val outputFields = serializer.map(_ transform {
- case _: BoundReference => resultObj
- })
- val resultVars = outputFields.map(_.genCode(ctx))
- s"""
- ${evaluated.code}
- ${consume(ctx, resultVars)}
- """
+ val resultVars = bound.genCode(ctx) :: Nil
+
+ consume(ctx, resultVars)
}
override protected def doExecute(): RDD[InternalRow] = {
@@ -191,9 +197,10 @@ case class MapElements(
case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i)
case _ => func.asInstanceOf[Any => Any]
}
+
child.execute().mapPartitionsInternal { iter =>
- val getObject = generateToObject(deserializer, child.output)
- val outputObject = generateToRow(serializer)
+ val getObject = unwrapObjectFromRow(child.output.head.dataType)
+ val outputObject = wrapObjectToRow(outputObjAttr.dataType)
iter.map(row => outputObject(callFunc(getObject(row))))
}
}
@@ -216,15 +223,43 @@ case class AppendColumns(
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
- val getObject = generateToObject(deserializer, child.output)
+ val getObject = deserializeRowToObject(deserializer, child.output)
val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
- val outputObject = generateToRow(serializer)
+ val outputObject = serializeObjectToRow(serializer)
iter.map { row =>
val newColumns = outputObject(func(getObject(row)))
+ combiner.join(row.asInstanceOf[UnsafeRow], newColumns): InternalRow
+ }
+ }
+ }
+}
+
+/**
+ * An optimized version of [[AppendColumns]] that can be executed directly on deserialized objects.
+ */
+case class AppendColumnsWithObject(
+ func: Any => Any,
+ inputSerializer: Seq[NamedExpression],
+ newColumnsSerializer: Seq[NamedExpression],
+ child: SparkPlan) extends UnaryNode with ObjectOperator {
+
+ override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute)
- // This operates on the assumption that we always serialize the result...
- combiner.join(row.asInstanceOf[UnsafeRow], newColumns.asInstanceOf[UnsafeRow]): InternalRow
+ private def inputSchema = inputSerializer.map(_.toAttribute).toStructType
+ private def newColumnSchema = newColumnsSerializer.map(_.toAttribute).toStructType
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ child.execute().mapPartitionsInternal { iter =>
+ val getChildObject = unwrapObjectFromRow(child.output.head.dataType)
+ val outputChildObject = serializeObjectToRow(inputSerializer)
+ val outputNewColumnObj = serializeObjectToRow(newColumnsSerializer)
+ val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema)
+
+ iter.map { row =>
+ val childObj = getChildObject(row)
+ val newColumns = outputNewColumnObj(func(childObj))
+ combiner.join(outputChildObject(childObj), newColumns): InternalRow
}
}
}
@@ -232,19 +267,19 @@ case class AppendColumns(
/**
* Groups the input rows together and calls the function with each group and an iterator containing
- * all elements in the group. The result of this function is encoded and flattened before
- * being output.
+ * all elements in the group. The result of this function is flattened before being output.
*/
case class MapGroups(
func: (Any, Iterator[Any]) => TraversableOnce[Any],
keyDeserializer: Expression,
valueDeserializer: Expression,
- serializer: Seq[NamedExpression],
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
+ outputObjAttr: Attribute,
child: SparkPlan) extends UnaryNode with ObjectOperator {
- override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+ override def output: Seq[Attribute] = outputObjAttr :: Nil
+ override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(groupingAttributes) :: Nil
@@ -256,9 +291,9 @@ case class MapGroups(
child.execute().mapPartitionsInternal { iter =>
val grouped = GroupedIterator(iter, groupingAttributes, child.output)
- val getKey = generateToObject(keyDeserializer, groupingAttributes)
- val getValue = generateToObject(valueDeserializer, dataAttributes)
- val outputObject = generateToRow(serializer)
+ val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes)
+ val getValue = deserializeRowToObject(valueDeserializer, dataAttributes)
+ val outputObject = wrapObjectToRow(outputObjAttr.dataType)
grouped.flatMap { case (key, rowIter) =>
val result = func(
@@ -273,22 +308,23 @@ case class MapGroups(
/**
* Co-groups the data from left and right children, and calls the function with each group and 2
* iterators containing all elements in the group from left and right side.
- * The result of this function is encoded and flattened before being output.
+ * The result of this function is flattened before being output.
*/
case class CoGroup(
func: (Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any],
keyDeserializer: Expression,
leftDeserializer: Expression,
rightDeserializer: Expression,
- serializer: Seq[NamedExpression],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
leftAttr: Seq[Attribute],
rightAttr: Seq[Attribute],
+ outputObjAttr: Attribute,
left: SparkPlan,
right: SparkPlan) extends BinaryNode with ObjectOperator {
- override def output: Seq[Attribute] = serializer.map(_.toAttribute)
+ override def output: Seq[Attribute] = outputObjAttr :: Nil
+ override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
@@ -301,10 +337,10 @@ case class CoGroup(
val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
- val getKey = generateToObject(keyDeserializer, leftGroup)
- val getLeft = generateToObject(leftDeserializer, leftAttr)
- val getRight = generateToObject(rightDeserializer, rightAttr)
- val outputObject = generateToRow(serializer)
+ val getKey = deserializeRowToObject(keyDeserializer, leftGroup)
+ val getLeft = deserializeRowToObject(leftDeserializer, leftAttr)
+ val getRight = deserializeRowToObject(rightDeserializer, rightAttr)
+ val outputObject = wrapObjectToRow(outputObjAttr.dataType)
new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
case (key, leftResult, rightResult) =>
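
The old generateToObject/generateToRow pair is thus split into four helpers: deserializeRowToObject and serializeObjectToRow evaluate real (de)serializer expressions, while unwrapObjectFromRow and wrapObjectToRow are the cheap paths for single-field object rows that need no projection at all. The typical ObjectConsumer-and-ObjectProducer loop, as in MapPartitions above, reduces to:

    val getObject = unwrapObjectFromRow(child.output.head.dataType)   // row -> Any
    val outputObject = wrapObjectToRow(outputObjAttr.dataType)        // Any -> single-field row
    func(iter.map(getObject)).map(outputObject)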
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 23a0ce215f..2dca792c83 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -201,7 +201,9 @@ abstract class QueryTest extends PlanTest {
val logicalPlan = df.queryExecution.analyzed
// bypass some cases that we can't handle currently.
logicalPlan.transform {
- case _: ObjectOperator => return
+ case _: ObjectConsumer => return
+ case _: ObjectProducer => return
+ case _: AppendColumns => return
case _: LogicalRelation => return
case _: MemoryPlan => return
}.transformAllExpressions {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 8efd9de29e..d7cf1dc6aa 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -79,7 +79,7 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
val plan = ds.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
- p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined)
+ p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined)
assert(ds.collect() === 0.until(10).map(_.toString).toArray)
}