diff options
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 13 | ||||
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala | 17 |
2 files changed, 25 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 75130007b9..e34a478818 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -214,7 +214,7 @@ object EliminateSerialization extends Rule[LogicalPlan] { val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId) Project(objAttr :: Nil, s.child) - case a @ AppendColumns(_, _, _, s: SerializeFromObject) + case a @ AppendColumns(_, _, _, _, _, s: SerializeFromObject) if a.deserializer.dataType == s.inputObjAttr.dataType => AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child) @@ -223,7 +223,7 @@ object EliminateSerialization extends Rule[LogicalPlan] { // deserialization in condition, and push it down through `SerializeFromObject`. // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization, // but `ds.map(...).as[AnotherType].filter(...)` can not be optimized. - case f @ TypedFilter(_, _, s: SerializeFromObject) + case f @ TypedFilter(_, _, _, _, s: SerializeFromObject) if f.deserializer.dataType == s.inputObjAttr.dataType => s.copy(child = f.withObjectProducerChild(s.child)) @@ -1703,9 +1703,14 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[Logic */ object CombineTypedFilters extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child)) + case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child)) if t1.deserializer.dataType == t2.deserializer.dataType => - TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child) + TypedFilter( + combineFilterFunction(t2.func, t1.func), + t1.argumentClass, + t1.argumentSchema, + t1.deserializer, + child) } private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index e1890edcbb..fefe5a3953 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -155,6 +155,8 @@ object MapElements { val deserialized = CatalystSerde.deserialize[T](child) val mapped = MapElements( func, + implicitly[Encoder[T]].clsTag.runtimeClass, + implicitly[Encoder[T]].schema, CatalystSerde.generateObjAttr[U], deserialized) CatalystSerde.serialize[U](mapped) @@ -166,12 +168,19 @@ object MapElements { */ case class MapElements( func: AnyRef, + argumentClass: Class[_], + argumentSchema: StructType, outputObjAttr: Attribute, child: LogicalPlan) extends ObjectConsumer with ObjectProducer object TypedFilter { def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = { - TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child) + TypedFilter( + func, + implicitly[Encoder[T]].clsTag.runtimeClass, + implicitly[Encoder[T]].schema, + UnresolvedDeserializer(encoderFor[T].deserializer), + child) } } @@ -186,6 +195,8 @@ object TypedFilter { */ case class TypedFilter( func: AnyRef, + argumentClass: Class[_], + argumentSchema: StructType, deserializer: Expression, child: LogicalPlan) extends UnaryNode { @@ -213,6 +224,8 @@ object AppendColumns { child: LogicalPlan): AppendColumns = { new AppendColumns( func.asInstanceOf[Any => Any], + implicitly[Encoder[T]].clsTag.runtimeClass, + implicitly[Encoder[T]].schema, UnresolvedDeserializer(encoderFor[T].deserializer), encoderFor[U].namedExpressions, child) @@ -228,6 +241,8 @@ object AppendColumns { */ case class AppendColumns( func: Any => Any, + argumentClass: Class[_], + argumentSchema: StructType, deserializer: Expression, serializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode { |