author     Wenchen Fan <wenchen@databricks.com>   2016-05-18 21:43:07 -0700
committer  Davies Liu <davies.liu@gmail.com>      2016-05-18 21:43:07 -0700
commit     661c21049b62ebfaf788dcbc31d62a09e206265b (patch)
tree       79528af3875b1485ef4e368430f81b40ad634d63 /sql
parent     5c9117a3ed373461529f9f9306668ed4149c63fb (diff)
[SPARK-15381] [SQL] physical object operator should define reference correctly
## What changes were proposed in this pull request?

Whole-stage codegen relies on `SparkPlan.references` to do some of its optimizations. The physical object operators should therefore be consistent with their logical counterparts and define `references` correctly.

## How was this patch tested?

New test in DatasetSuite.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #13167 from cloud-fan/bug.
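For readers who want to try the regression locally, here is a minimal, self-contained sketch that mirrors the new `DatasetSuite` test in the diff below, written against the Spark 2.0-era API this patch targets. The `SparkSession` setup and the object name `ReferencesRepro` are illustrative additions, not part of the patch:

```scala
// Standalone repro of the scenario covered by the new DatasetSuite test below.
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.encoders.RowEncoder

object ReferencesRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("SPARK-15381").getOrCreate()
    import spark.implicits._

    val df = Seq(1 -> 2).toDF("a", "b")
    // The identity map goes through a physical object operator; selecting columns
    // afterwards is only safe if that operator reports that it references both "a" and "b".
    val result = df.map(row => row)(RowEncoder(df.schema)).select("b", "a").collect()
    assert(result.sameElements(Array(Row(2, 1))))

    spark.stop()
  }
}
```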
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala         |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala                 | 94
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                      |  5
4 files changed, 64 insertions, 47 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 84339f439a..98ce5dd2ef 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -94,7 +94,7 @@ case class DeserializeToObject(
*/
case class SerializeFromObject(
serializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectConsumer {
+ child: LogicalPlan) extends ObjectConsumer {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
}
@@ -118,7 +118,7 @@ object MapPartitions {
case class MapPartitions(
func: Iterator[Any] => Iterator[Any],
outputObjAttr: Attribute,
- child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
+ child: LogicalPlan) extends ObjectConsumer with ObjectProducer
object MapPartitionsInR {
def apply(
@@ -152,7 +152,7 @@ case class MapPartitionsInR(
inputSchema: StructType,
outputSchema: StructType,
outputObjAttr: Attribute,
- child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
+ child: LogicalPlan) extends ObjectConsumer with ObjectProducer {
override lazy val schema = outputSchema
}
@@ -175,7 +175,7 @@ object MapElements {
case class MapElements(
func: AnyRef,
outputObjAttr: Attribute,
- child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
+ child: LogicalPlan) extends ObjectConsumer with ObjectProducer
/** Factory for constructing new `AppendColumn` nodes. */
object AppendColumns {
@@ -215,7 +215,7 @@ case class AppendColumnsWithObject(
func: Any => Any,
childSerializer: Seq[NamedExpression],
newColumnsSerializer: Seq[NamedExpression],
- child: LogicalPlan) extends UnaryNode with ObjectConsumer {
+ child: LogicalPlan) extends ObjectConsumer {
override def output: Seq[Attribute] = (childSerializer ++ newColumnsSerializer).map(_.toAttribute)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index faf359f548..5cfb6d5363 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -303,7 +303,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
"logical except operator should have been replaced by anti-join in the optimizer")
case logical.DeserializeToObject(deserializer, objAttr, child) =>
- execution.DeserializeToObject(deserializer, objAttr, planLater(child)) :: Nil
+ execution.DeserializeToObjectExec(deserializer, objAttr, planLater(child)) :: Nil
case logical.SerializeFromObject(serializer, child) =>
execution.SerializeFromObjectExec(serializer, planLater(child)) :: Nil
case logical.MapPartitions(f, objAttr, child) =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
index 3ff991392d..5fced940b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala
@@ -28,17 +28,41 @@ import org.apache.spark.sql.catalyst.expressions.objects.Invoke
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.types.{DataType, ObjectType}
+
+/**
+ * Physical version of `ObjectProducer`.
+ */
+trait ObjectProducerExec extends SparkPlan {
+ // The attribute that references the single object field this operator outputs.
+ protected def outputObjAttr: Attribute
+
+ override def output: Seq[Attribute] = outputObjAttr :: Nil
+
+ override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+
+ def outputObjectType: DataType = outputObjAttr.dataType
+}
+
+/**
+ * Physical version of `ObjectConsumer`.
+ */
+trait ObjectConsumerExec extends UnaryExecNode {
+ assert(child.output.length == 1)
+
+ // This operator always needs all columns of its child, even if it doesn't reference them directly.
+ override def references: AttributeSet = child.outputSet
+
+ def inputObjectType: DataType = child.output.head.dataType
+}
+
/**
* Takes the input row from child and turns it into object using the given deserializer expression.
* The output of this operator is a single-field safe row containing the deserialized object.
*/
-case class DeserializeToObject(
+case class DeserializeToObjectExec(
deserializer: Expression,
outputObjAttr: Attribute,
- child: SparkPlan) extends UnaryExecNode with CodegenSupport {
-
- override def output: Seq[Attribute] = outputObjAttr :: Nil
- override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+ child: SparkPlan) extends UnaryExecNode with ObjectProducerExec with CodegenSupport {
override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -70,7 +94,7 @@ case class DeserializeToObject(
*/
case class SerializeFromObjectExec(
serializer: Seq[NamedExpression],
- child: SparkPlan) extends UnaryExecNode with CodegenSupport {
+ child: SparkPlan) extends ObjectConsumerExec with CodegenSupport {
override def output: Seq[Attribute] = serializer.map(_.toAttribute)
@@ -102,7 +126,7 @@ case class SerializeFromObjectExec(
/**
* Helper functions for physical operators that work with user defined objects.
*/
-trait ObjectOperator extends SparkPlan {
+object ObjectOperator {
def deserializeRowToObject(
deserializer: Expression,
inputSchema: Seq[Attribute]): InternalRow => Any = {
@@ -141,15 +165,12 @@ case class MapPartitionsExec(
func: Iterator[Any] => Iterator[Any],
outputObjAttr: Attribute,
child: SparkPlan)
- extends UnaryExecNode with ObjectOperator {
-
- override def output: Seq[Attribute] = outputObjAttr :: Nil
- override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+ extends ObjectConsumerExec with ObjectProducerExec {
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
- val getObject = unwrapObjectFromRow(child.output.head.dataType)
- val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+ val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+ val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
func(iter.map(getObject)).map(outputObject)
}
}
@@ -166,10 +187,7 @@ case class MapElementsExec(
func: AnyRef,
outputObjAttr: Attribute,
child: SparkPlan)
- extends UnaryExecNode with ObjectOperator with CodegenSupport {
-
- override def output: Seq[Attribute] = outputObjAttr :: Nil
- override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+ extends ObjectConsumerExec with ObjectProducerExec with CodegenSupport {
override def inputRDDs(): Seq[RDD[InternalRow]] = {
child.asInstanceOf[CodegenSupport].inputRDDs()
@@ -202,8 +220,8 @@ case class MapElementsExec(
}
child.execute().mapPartitionsInternal { iter =>
- val getObject = unwrapObjectFromRow(child.output.head.dataType)
- val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+ val getObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+ val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
iter.map(row => outputObject(callFunc(getObject(row))))
}
}
@@ -218,7 +236,7 @@ case class AppendColumnsExec(
func: Any => Any,
deserializer: Expression,
serializer: Seq[NamedExpression],
- child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+ child: SparkPlan) extends UnaryExecNode {
override def output: Seq[Attribute] = child.output ++ serializer.map(_.toAttribute)
@@ -226,9 +244,9 @@ case class AppendColumnsExec(
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
- val getObject = deserializeRowToObject(deserializer, child.output)
+ val getObject = ObjectOperator.deserializeRowToObject(deserializer, child.output)
val combiner = GenerateUnsafeRowJoiner.create(child.schema, newColumnSchema)
- val outputObject = serializeObjectToRow(serializer)
+ val outputObject = ObjectOperator.serializeObjectToRow(serializer)
iter.map { row =>
val newColumns = outputObject(func(getObject(row)))
@@ -246,7 +264,7 @@ case class AppendColumnsWithObjectExec(
func: Any => Any,
inputSerializer: Seq[NamedExpression],
newColumnsSerializer: Seq[NamedExpression],
- child: SparkPlan) extends UnaryExecNode with ObjectOperator {
+ child: SparkPlan) extends ObjectConsumerExec {
override def output: Seq[Attribute] = (inputSerializer ++ newColumnsSerializer).map(_.toAttribute)
@@ -255,9 +273,9 @@ case class AppendColumnsWithObjectExec(
override protected def doExecute(): RDD[InternalRow] = {
child.execute().mapPartitionsInternal { iter =>
- val getChildObject = unwrapObjectFromRow(child.output.head.dataType)
- val outputChildObject = serializeObjectToRow(inputSerializer)
- val outputNewColumnOjb = serializeObjectToRow(newColumnsSerializer)
+ val getChildObject = ObjectOperator.unwrapObjectFromRow(child.output.head.dataType)
+ val outputChildObject = ObjectOperator.serializeObjectToRow(inputSerializer)
+ val outputNewColumnOjb = ObjectOperator.serializeObjectToRow(newColumnsSerializer)
val combiner = GenerateUnsafeRowJoiner.create(inputSchema, newColumnSchema)
iter.map { row =>
@@ -280,10 +298,7 @@ case class MapGroupsExec(
groupingAttributes: Seq[Attribute],
dataAttributes: Seq[Attribute],
outputObjAttr: Attribute,
- child: SparkPlan) extends UnaryExecNode with ObjectOperator {
-
- override def output: Seq[Attribute] = outputObjAttr :: Nil
- override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+ child: SparkPlan) extends UnaryExecNode with ObjectProducerExec {
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(groupingAttributes) :: Nil
@@ -295,9 +310,9 @@ case class MapGroupsExec(
child.execute().mapPartitionsInternal { iter =>
val grouped = GroupedIterator(iter, groupingAttributes, child.output)
- val getKey = deserializeRowToObject(keyDeserializer, groupingAttributes)
- val getValue = deserializeRowToObject(valueDeserializer, dataAttributes)
- val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+ val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, groupingAttributes)
+ val getValue = ObjectOperator.deserializeRowToObject(valueDeserializer, dataAttributes)
+ val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
grouped.flatMap { case (key, rowIter) =>
val result = func(
@@ -325,10 +340,7 @@ case class CoGroupExec(
rightAttr: Seq[Attribute],
outputObjAttr: Attribute,
left: SparkPlan,
- right: SparkPlan) extends BinaryExecNode with ObjectOperator {
-
- override def output: Seq[Attribute] = outputObjAttr :: Nil
- override def producedAttributes: AttributeSet = AttributeSet(outputObjAttr)
+ right: SparkPlan) extends BinaryExecNode with ObjectProducerExec {
override def requiredChildDistribution: Seq[Distribution] =
ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
@@ -341,10 +353,10 @@ case class CoGroupExec(
val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
- val getKey = deserializeRowToObject(keyDeserializer, leftGroup)
- val getLeft = deserializeRowToObject(leftDeserializer, leftAttr)
- val getRight = deserializeRowToObject(rightDeserializer, rightAttr)
- val outputObject = wrapObjectToRow(outputObjAttr.dataType)
+ val getKey = ObjectOperator.deserializeRowToObject(keyDeserializer, leftGroup)
+ val getLeft = ObjectOperator.deserializeRowToObject(leftDeserializer, leftAttr)
+ val getRight = ObjectOperator.deserializeRowToObject(rightDeserializer, rightAttr)
+ val outputObject = ObjectOperator.wrapObjectToRow(outputObjAttr.dataType)
new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
case (key, leftResult, rightResult) =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 1935e41185..52e706285c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -711,6 +711,11 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
assert(e.message.contains("already exists"))
dataset.sparkSession.catalog.dropTempView("tempView")
}
+
+ test("SPARK-15381: physical object operator should define `reference` correctly") {
+ val df = Seq(1 -> 2).toDF("a", "b")
+ checkAnswer(df.map(row => row)(RowEncoder(df.schema)).select("b", "a"), Row(2, 1))
+ }
}
case class Generic[T](id: T, value: Double)
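
To make the motivation behind `ObjectConsumerExec.references` concrete, here is a deliberately tiny, self-contained model in plain Scala with no Spark dependency. All names (`ToyPlan`, `Scan`, `SerializeWholeRow`, `columnsSurvivingPruning`) are invented for illustration; it only mirrors the idea that a pruning pass which trusts `references` will drop a column a whole-row serializing operator still needs, unless that operator reports its child's entire output:

```scala
// A toy plan model: each node exposes `output` (columns it produces) and
// `references` (child columns it needs). A naive pruning pass keeps only
// referenced columns. None of this is Spark API; it just mirrors the idea.
sealed trait ToyPlan {
  def output: Seq[String]
  def references: Set[String]
}

case class Scan(output: Seq[String]) extends ToyPlan {
  val references: Set[String] = Set.empty
}

// Turns a whole row into an object, so it needs *every* child column,
// even though no column name appears in its expressions.
case class SerializeWholeRow(child: ToyPlan, honestReferences: Boolean) extends ToyPlan {
  val output: Seq[String] = Seq("obj")
  // The fix in this commit corresponds to `honestReferences = true`.
  val references: Set[String] =
    if (honestReferences) child.output.toSet else Set.empty
}

object PruneDemo extends App {
  // Keep only the child columns the parent claims to reference.
  def columnsSurvivingPruning(plan: ToyPlan): Seq[String] = plan match {
    case s: SerializeWholeRow => s.child.output.filter(s.references.contains)
    case other                => other.output
  }

  val scan = Scan(Seq("a", "b"))
  println(columnsSurvivingPruning(SerializeWholeRow(scan, honestReferences = false))) // List()
  println(columnsSurvivingPruning(SerializeWholeRow(scan, honestReferences = true)))  // List(a, b)
}
```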