aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala82
1 files changed, 43 insertions, 39 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 072445af4f..8bce404735 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -315,25 +315,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
protected def transformChildren(
rule: PartialFunction[BaseType, BaseType],
nextOperation: (BaseType, PartialFunction[BaseType, BaseType]) => BaseType): BaseType = {
- var changed = false
- val newArgs = mapProductIterator {
- case arg: TreeNode[_] if containsChild(arg) =>
- val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
- if (!(newChild fastEquals arg)) {
- changed = true
- newChild
- } else {
- arg
- }
- case Some(arg: TreeNode[_]) if containsChild(arg) =>
- val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
- if (!(newChild fastEquals arg)) {
- changed = true
- Some(newChild)
- } else {
- Some(arg)
- }
- case m: Map[_, _] => m.mapValues {
+ if (children.nonEmpty) {
+ var changed = false
+ val newArgs = mapProductIterator {
case arg: TreeNode[_] if containsChild(arg) =>
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
if (!(newChild fastEquals arg)) {
@@ -342,33 +326,53 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {
} else {
arg
}
- case other => other
- }.view.force // `mapValues` is lazy and we need to force it to materialize
- case d: DataType => d // Avoid unpacking Structs
- case args: Traversable[_] => args.map {
- case arg: TreeNode[_] if containsChild(arg) =>
+ case Some(arg: TreeNode[_]) if containsChild(arg) =>
val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
if (!(newChild fastEquals arg)) {
changed = true
- newChild
+ Some(newChild)
} else {
- arg
+ Some(arg)
}
- case tuple @ (arg1: TreeNode[_], arg2: TreeNode[_]) =>
- val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
- val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
- if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
- changed = true
- (newChild1, newChild2)
- } else {
- tuple
- }
- case other => other
+ case m: Map[_, _] => m.mapValues {
+ case arg: TreeNode[_] if containsChild(arg) =>
+ val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case other => other
+ }.view.force // `mapValues` is lazy and we need to force it to materialize
+ case d: DataType => d // Avoid unpacking Structs
+ case args: Traversable[_] => args.map {
+ case arg: TreeNode[_] if containsChild(arg) =>
+ val newChild = nextOperation(arg.asInstanceOf[BaseType], rule)
+ if (!(newChild fastEquals arg)) {
+ changed = true
+ newChild
+ } else {
+ arg
+ }
+ case tuple@(arg1: TreeNode[_], arg2: TreeNode[_]) =>
+ val newChild1 = nextOperation(arg1.asInstanceOf[BaseType], rule)
+ val newChild2 = nextOperation(arg2.asInstanceOf[BaseType], rule)
+ if (!(newChild1 fastEquals arg1) || !(newChild2 fastEquals arg2)) {
+ changed = true
+ (newChild1, newChild2)
+ } else {
+ tuple
+ }
+ case other => other
+ }
+ case nonChild: AnyRef => nonChild
+ case null => null
}
- case nonChild: AnyRef => nonChild
- case null => null
+ if (changed) makeCopy(newArgs) else this
+ } else {
+ this
}
- if (changed) makeCopy(newArgs) else this
}
/**