Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala   | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala  | 54
2 files changed, 61 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 434c033c49..abbd8facd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -159,10 +159,15 @@ object EliminateSerialization extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case d @ DeserializeToObject(_, _, s: SerializeFromObject)
         if d.outputObjectType == s.inputObjectType =>
-      // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
-      val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
-      Project(objAttr :: Nil, s.child)
-
+      // A workaround for SPARK-14803. Remove this after it is fixed.
+      if (d.outputObjectType.isInstanceOf[ObjectType] &&
+          d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
+        s.child
+      } else {
+        // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+        val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+        Project(objAttr :: Nil, s.child)
+      }
     case a @ AppendColumns(_, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjectType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 4a1bdb0b8a..84339f439a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -17,11 +17,12 @@
 
 package org.apache.spark.sql.catalyst.plans.logical
 
-import org.apache.spark.sql.Encoder
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.{Encoder, Row}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
 import org.apache.spark.sql.catalyst.encoders._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
 
 object CatalystSerde {
   def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -29,13 +30,26 @@ object CatalystSerde {
     DeserializeToObject(deserializer, generateObjAttr[T], child)
   }
 
+  def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
+    val deserializer = UnresolvedDeserializer(encoder.deserializer)
+    DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
+  }
+
   def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
     SerializeFromObject(encoderFor[T].namedExpressions, child)
   }
 
+  def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
+    SerializeFromObject(encoder.namedExpressions, child)
+  }
+
   def generateObjAttr[T : Encoder]: Attribute = {
     AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
   }
+
+  def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
+    AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
+  }
 }
 
 /**
@@ -106,6 +120,42 @@ case class MapPartitions(
     outputObjAttr: Attribute,
     child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
 
+object MapPartitionsInR {
+  def apply(
+      func: Array[Byte],
+      packageNames: Array[Byte],
+      broadcastVars: Array[Broadcast[Object]],
+      schema: StructType,
+      encoder: ExpressionEncoder[Row],
+      child: LogicalPlan): LogicalPlan = {
+    val deserialized = CatalystSerde.deserialize(child, encoder)
+    val mapped = MapPartitionsInR(
+      func,
+      packageNames,
+      broadcastVars,
+      encoder.schema,
+      schema,
+      CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
+      deserialized)
+    CatalystSerde.serialize(mapped, RowEncoder(schema))
+  }
+}
+
+/**
+ * A relation produced by applying a serialized R function `func` to each partition of the `child`.
+ *
+ */
+case class MapPartitionsInR(
+    func: Array[Byte],
+    packageNames: Array[Byte],
+    broadcastVars: Array[Broadcast[Object]],
+    inputSchema: StructType,
+    outputSchema: StructType,
+    outputObjAttr: Attribute,
+    child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
+  override lazy val schema = outputSchema
+}
+
 object MapElements {
   def apply[T : Encoder, U : Encoder](
       func: AnyRef,
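
For context on why the Optimizer.scala change matters to the new node: MapPartitionsInR.apply builds its plan as serialize(map(deserialize(child))), so chaining two such calls puts a DeserializeToObject directly on top of a SerializeFromObject of the same Row object type, which is exactly the pattern EliminateSerialization removes. Below is a minimal, self-contained Scala sketch of that collapse; Plan, Serialize, Deserialize, MapInR, and eliminate are illustrative stand-ins rather than the real Catalyst classes, and the sketch models only the simplified Row path above that returns the child directly (skipping the expr-id-preserving Project).

// Simplified stand-ins for Catalyst plan nodes; illustrative only,
// not the real org.apache.spark.sql.catalyst types.
sealed trait Plan
case class Leaf(name: String) extends Plan
case class Serialize(objType: String, child: Plan) extends Plan
case class Deserialize(objType: String, child: Plan) extends Plan
case class MapInR(child: Plan) extends Plan

object EliminateSerializationSketch {
  // A Deserialize sitting directly on a Serialize of the same object
  // type is a no-op round trip, so both nodes are dropped and the
  // underlying child is kept (the SPARK-14803 workaround path above).
  def eliminate(plan: Plan): Plan = plan match {
    case Deserialize(d, Serialize(s, child)) if d == s => eliminate(child)
    case Serialize(t, child)                           => Serialize(t, eliminate(child))
    case Deserialize(t, child)                         => Deserialize(t, eliminate(child))
    case MapInR(child)                                 => MapInR(eliminate(child))
    case leaf: Leaf                                    => leaf
  }

  def main(args: Array[String]): Unit = {
    // Two chained R map-partitions calls produce
    // serialize(map(deserialize(serialize(map(deserialize(src)))))):
    val plan = Serialize("Row",
      MapInR(Deserialize("Row",
        Serialize("Row",
          MapInR(Deserialize("Row", Leaf("src")))))))
    println(eliminate(plan))
  }
}

Run as-is, the sketch prints Serialize(Row,MapInR(MapInR(Deserialize(Row,Leaf(src))))): the redundant inner round trip between the two R functions is gone, while the outer serialize/deserialize boundary survives.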