Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache')
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  | 13 +++++++++----
 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala | 17 ++++++++++++++++-
 2 files changed, 25 insertions(+), 5 deletions(-)
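In brief: the patch threads two new fields through the typed logical operators MapElements, TypedFilter, and AppendColumns — the runtime class of the lambda's argument (argumentClass) and its Catalyst schema (argumentSchema), both read off the operator's input Encoder — and widens the optimizer patterns that match these operators accordingly. A minimal sketch of where the two values come from (the helper name is hypothetical; the patch inlines the two expressions at each construction site):

    import org.apache.spark.sql.Encoder
    import org.apache.spark.sql.types.StructType

    // Hypothetical helper mirroring the two calls the patch inlines:
    // an Encoder carries the input type's ClassTag and its Catalyst schema.
    def argumentInfo[T : Encoder]: (Class[_], StructType) = {
      val enc = implicitly[Encoder[T]]
      (enc.clsTag.runtimeClass, enc.schema)
    }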
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 75130007b9..e34a478818 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -214,7 +214,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
       val objAttr = Alias(s.inputObjAttr, s.inputObjAttr.name)(exprId = d.outputObjAttr.exprId)
       Project(objAttr :: Nil, s.child)
 
-    case a @ AppendColumns(_, _, _, s: SerializeFromObject)
+    case a @ AppendColumns(_, _, _, _, _, s: SerializeFromObject)
         if a.deserializer.dataType == s.inputObjAttr.dataType =>
       AppendColumnsWithObject(a.func, s.serializer, a.serializer, s.child)
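This case fuses an AppendColumns that sits directly on a SerializeFromObject (with matching object types) into AppendColumnsWithObject, skipping a serialize/deserialize round trip. Schematically (a sketch, not verbatim plan output):

    // Before: AppendColumns(func, cls, schema, deser, ser,
    //           SerializeFromObject(serializer, child))
    // After:  AppendColumnsWithObject(func, serializer, ser, child)
    //   -- the object produced by `child` now feeds both serializers directly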
@@ -223,7 +223,7 @@ object EliminateSerialization extends Rule[LogicalPlan] {
     // deserialization in condition, and push it down through `SerializeFromObject`.
     // e.g. `ds.map(...).filter(...)` can be optimized by this rule to save extra deserialization,
     // but `ds.map(...).as[AnotherType].filter(...)` can not be optimized.
-    case f @ TypedFilter(_, _, s: SerializeFromObject)
+    case f @ TypedFilter(_, _, _, _, s: SerializeFromObject)
         if f.deserializer.dataType == s.inputObjAttr.dataType =>
       s.copy(child = f.withObjectProducerChild(s.child))
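To make the comment above concrete, a hedged sketch of the two shapes (assuming a SparkSession named spark with its implicits imported):

    import spark.implicits._

    val ds = spark.range(10).as[Long]

    // Optimizable: the TypedFilter sits directly on SerializeFromObject and
    // the object types line up, so the filter runs on the map's output object
    // before it is serialized back to rows.
    ds.map(_ + 1).filter(_ % 2 == 0)

    // Not optimizable: .as[...] re-types the data between map and filter, so
    // the filter's deserializer no longer matches the serialized object type.
    ds.map(_ + 1).as[java.lang.Long].filter(_ % 2 == 0)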
@@ -1703,9 +1703,14 @@ case class GetCurrentDatabase(sessionCatalog: SessionCatalog) extends Rule[LogicalPlan] {
  */
 object CombineTypedFilters extends Rule[LogicalPlan] {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
-    case t1 @ TypedFilter(_, _, t2 @ TypedFilter(_, _, child))
+    case t1 @ TypedFilter(_, _, _, _, t2 @ TypedFilter(_, _, _, _, child))
         if t1.deserializer.dataType == t2.deserializer.dataType =>
-      TypedFilter(combineFilterFunction(t2.func, t1.func), t1.deserializer, child)
+      TypedFilter(
+        combineFilterFunction(t2.func, t1.func),
+        t1.argumentClass,
+        t1.argumentSchema,
+        t1.deserializer,
+        child)
   }
 
   private def combineFilterFunction(func1: AnyRef, func2: AnyRef): Any => Boolean = {
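The helper's job (its body is not shown in this hunk) is to fuse the two predicates into one conjunction, applying the inner filter (t2, which ran first) before the outer one. A minimal sketch under the assumption that both arguments are plain Scala closures; the real helper also has to accept Java FilterFunction instances coming from the Java API:

    // Sketch only: assumes both funcs are Scala Any => Boolean closures.
    def combineFilterFunctionSketch(func1: AnyRef, func2: AnyRef): Any => Boolean = {
      val inner = func1.asInstanceOf[Any => Boolean]  // t2.func, applied first
      val outer = func2.asInstanceOf[Any => Boolean]  // t1.func
      input => inner(input) && outer(input)           // short-circuits like two filters
    }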
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
index e1890edcbb..fefe5a3953 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala
@@ -155,6 +155,8 @@ object MapElements {
     val deserialized = CatalystSerde.deserialize[T](child)
     val mapped = MapElements(
       func,
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
       CatalystSerde.generateObjAttr[U],
       deserialized)
     CatalystSerde.serialize[U](mapped)
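The builder thus sandwiches the typed map between object-boundary nodes; for ds.map(f) the resulting logical plan is roughly (schematic, not verbatim explain output):

    SerializeFromObject
    +- MapElements(f, classOf[T], <schema of T>, outputObjAttr)
       +- DeserializeToObject
          +- <plan of ds>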
@@ -166,12 +168,19 @@ object MapElements {
  */
 case class MapElements(
     func: AnyRef,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     outputObjAttr: Attribute,
     child: LogicalPlan) extends ObjectConsumer with ObjectProducer
 
 object TypedFilter {
   def apply[T : Encoder](func: AnyRef, child: LogicalPlan): TypedFilter = {
-    TypedFilter(func, UnresolvedDeserializer(encoderFor[T].deserializer), child)
+    TypedFilter(
+      func,
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
+      UnresolvedDeserializer(encoderFor[T].deserializer),
+      child)
   }
 }
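As an illustration of what the builder now records (the Person type is hypothetical, not part of the patch): filtering a Dataset[Person] captures the class and schema of the lambda's argument alongside the closure itself.

    case class Person(name: String, age: Int)

    // Assuming spark.implicits._ in scope so Encoder[Person] is derived:
    val people = spark.createDataset(Seq(Person("ann", 30), Person("bob", 12)))
    val adults = people.filter((p: Person) => p.age > 21)
    // The logical plan now contains, schematically:
    //   TypedFilter(<closure>, classOf[Person], <schema: name string, age int>,
    //     UnresolvedDeserializer(...), <plan of people>)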
@@ -186,6 +195,8 @@ object TypedFilter {
  */
 case class TypedFilter(
     func: AnyRef,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     deserializer: Expression,
     child: LogicalPlan) extends UnaryNode {
@@ -213,6 +224,8 @@ object AppendColumns {
       child: LogicalPlan): AppendColumns = {
     new AppendColumns(
       func.asInstanceOf[Any => Any],
+      implicitly[Encoder[T]].clsTag.runtimeClass,
+      implicitly[Encoder[T]].schema,
       UnresolvedDeserializer(encoderFor[T].deserializer),
       encoderFor[U].namedExpressions,
       child)
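For context, AppendColumns is the node planned by APIs such as Dataset.groupByKey: it deserializes each row to T, applies func, and appends the serialized columns of the result to the child's output. The new fields record T's class and schema exactly as in the other builders. A brief sketch, reusing the hypothetical Person dataset from the illustration above:

    // groupByKey plans, schematically:
    //   AppendColumns(p => p.age, classOf[Person], <Person schema>,
    //     UnresolvedDeserializer(...), <key columns>, <plan of people>)
    val byAge = people.groupByKey(p => p.age)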
@@ -228,6 +241,8 @@ object AppendColumns {
  */
 case class AppendColumns(
     func: Any => Any,
+    argumentClass: Class[_],
+    argumentSchema: StructType,
     deserializer: Expression,
     serializer: Seq[NamedExpression],
     child: LogicalPlan) extends UnaryNode {