Diffstat (limited to 'sql/catalyst')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala   | 13
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala  | 54
2 files changed, 61 insertions(+), 6 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 434c033c49..abbd8facd3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -159,10 +159,15 @@ object EliminateSerialization extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case d @ DeserializeToObject(_, _, s: SerializeFromObject)
if d.outputObjectType == s.inputObjectType =>
- // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
- val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
- Project(objAttr :: Nil, s.child)
-
+ // A workaround for SPARK-14803. Remove this after it is fixed.
+ if (d.outputObjectType.isInstanceOf[ObjectType] &&
+ d.outputObjectType.asInstanceOf[ObjectType].cls == classOf[org.apache.spark.sql.Row]) {
+ s.child
+ } else {
+ // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`.
+ val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId)
+ Project(objAttr :: Nil, s.child)
+ }
case a @ AppendColumns(_, _, _, s: SerializeFromObject)
if a.deserializer.dataType == s.inputObjectType =>
AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
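(For context, a minimal sketch of the plan shape the new Row branch targets. `objectChild` and `schema` are assumed placeholders for a plan that produces external Rows and its schema; everything else is API touched by this patch.)

    import org.apache.spark.sql.catalyst.encoders.RowEncoder
    import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde

    // Round-tripping an object-producing plan through serialize/deserialize with the
    // same Row encoder yields a DeserializeToObject directly on top of a
    // SerializeFromObject, with both object types equal to ObjectType(classOf[Row]).
    val encoder = RowEncoder(schema)
    val roundTrip = CatalystSerde.deserialize(
      CatalystSerde.serialize(objectChild, encoder), encoder)
    // Back-to-back dapply-style calls produce exactly this shape. With the
    // SPARK-14803 workaround above, EliminateSerialization rewrites `roundTrip`
    // to `objectChild` directly, instead of keeping an aliasing Project as it
    // does for other object types.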
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index 4a1bdb0b8a..84339f439a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -17,11 +17,12 @@
package org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.Encoder
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.{Encoder, Row}
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{DataType, StructType}
+import org.apache.spark.sql.types._
object CatalystSerde {
def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = {
@@ -29,13 +30,26 @@ object CatalystSerde {
DeserializeToObject(deserializer, generateObjAttr[T], child)
}
+ def deserialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): DeserializeToObject = {
+ val deserializer = UnresolvedDeserializer(encoder.deserializer)
+ DeserializeToObject(deserializer, generateObjAttrForRow(encoder), child)
+ }
+
def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = {
SerializeFromObject(encoderFor[T].namedExpressions, child)
}
+ def serialize(child: LogicalPlan, encoder: ExpressionEncoder[Row]): SerializeFromObject = {
+ SerializeFromObject(encoder.namedExpressions, child)
+ }
+
def generateObjAttr[T : Encoder]: Attribute = {
AttributeReference("obj", encoderFor[T].deserializer.dataType, nullable = false)()
}
+
+ def generateObjAttrForRow(encoder: ExpressionEncoder[Row]): Attribute = {
+ AttributeReference("obj", encoder.deserializer.dataType, nullable = false)()
+ }
}
/**
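(Side note on the Row overloads above: there is no implicit Encoder[Row], so the typed helpers' `T : Encoder` context bound cannot be reused; callers build an ExpressionEncoder[Row] from a schema instead. A hedged usage sketch, with `schema` assumed to be a StructType in scope:)

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
    import org.apache.spark.sql.catalyst.plans.logical.CatalystSerde

    val rowEnc: ExpressionEncoder[Row] = RowEncoder(schema)
    // Produces an "obj" AttributeReference whose type is ObjectType(classOf[Row]),
    // the case the EliminateSerialization workaround above checks for.
    val objAttr = CatalystSerde.generateObjAttrForRow(rowEnc)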
@@ -106,6 +120,42 @@ case class MapPartitions(
outputObjAttr: Attribute,
child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer
+object MapPartitionsInR {
+ def apply(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ schema: StructType,
+ encoder: ExpressionEncoder[Row],
+ child: LogicalPlan): LogicalPlan = {
+ val deserialized = CatalystSerde.deserialize(child, encoder)
+ val mapped = MapPartitionsInR(
+ func,
+ packageNames,
+ broadcastVars,
+ encoder.schema,
+ schema,
+ CatalystSerde.generateObjAttrForRow(RowEncoder(schema)),
+ deserialized)
+ CatalystSerde.serialize(mapped, RowEncoder(schema))
+ }
+}
+
+/**
+ * A relation produced by applying a serialized R function `func` to each partition of the `child`.
+ *
+ */
+case class MapPartitionsInR(
+ func: Array[Byte],
+ packageNames: Array[Byte],
+ broadcastVars: Array[Broadcast[Object]],
+ inputSchema: StructType,
+ outputSchema: StructType,
+ outputObjAttr: Attribute,
+ child: LogicalPlan) extends UnaryNode with ObjectConsumer with ObjectProducer {
+ override lazy val schema = outputSchema
+}
+
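(A rough usage sketch for the companion `apply` above; `func`, `packages`, `bcVars`, `outputSchema`, `inputEncoder` and `child` are assumed to be in scope with the parameter types shown in the patch.)

    import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, MapPartitionsInR}

    // Wraps `child` so the serialized R function consumes and produces external Rows:
    //   deserialize(child) -> MapPartitionsInR(...) -> serialize with RowEncoder(outputSchema)
    val plan: LogicalPlan = MapPartitionsInR(
      func,          // Array[Byte]: serialized R closure
      packages,      // Array[Byte]: serialized R package names
      bcVars,        // Array[Broadcast[Object]]
      outputSchema,  // StructType of the rows returned by the R function
      inputEncoder,  // ExpressionEncoder[Row] for `child`'s rows
      child)
    // The result is a SerializeFromObject whose child is the MapPartitionsInR node,
    // which in turn reads from a DeserializeToObject over `child`.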
object MapElements {
def apply[T : Encoder, U : Encoder](
func: AnyRef,