diff options
author | Michael Armbrust <michael@databricks.com> | 2014-06-15 11:28:34 +0200 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-06-15 11:28:55 +0200 |
commit | 868cf421ec59c731da2a02f94d62a22f96b60a01 (patch) | |
tree | 6cd6e37a21ebb722d89991d7a76ab411e9e0220c /sql | |
parent | 05d85c86ecbf75f7bb13efcf24b3af4e9e3ef612 (diff) | |
download | spark-868cf421ec59c731da2a02f94d62a22f96b60a01.tar.gz spark-868cf421ec59c731da2a02f94d62a22f96b60a01.tar.bz2 spark-868cf421ec59c731da2a02f94d62a22f96b60a01.zip |
[SQL] Support transforming TreeNodes with Option children.
Thanks goes to @marmbrus for his implementation.
Author: Michael Armbrust <michael@databricks.com>
Author: Zongheng Yang <zongheng.y@gmail.com>
Closes #1074 from concretevitamin/option-treenode and squashes the following commits:
ef27b85 [Zongheng Yang] Merge pull request #1 from marmbrus/pr/1074
73133c2 [Michael Armbrust] TreeNodes can't be inner classes.
ab78420 [Zongheng Yang] Add a test.
2ccb721 [Michael Armbrust] Add support for transformation of optional children.
(cherry picked from commit 269fc62b20ee5f9cd60a8f133c29f662d17071b1)
Signed-off-by: Michael Armbrust <michael@databricks.com>
Diffstat (limited to 'sql')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala | 19 | ||||
-rw-r--r-- | sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala | 27 |
2 files changed, 45 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 0369129393..cd04bdf02c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -187,6 +187,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformDown(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -231,6 +239,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } else { arg } + case Some(arg: TreeNode[_]) if children contains arg => + val newChild = arg.asInstanceOf[BaseType].transformUp(rule) + if (!(newChild fastEquals arg)) { + changed = true + Some(newChild) + } else { + Some(arg) + } case m: Map[_,_] => m case args: Traversable[_] => args.map { case arg: TreeNode[_] if children contains arg => @@ -273,7 +289,8 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] { } catch { case e: java.lang.IllegalArgumentException => throw new TreeNodeException( - this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName?") + this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? " + + s"Exception message: ${e.getMessage}.") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala index 1ddc41a731..6344874538 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala @@ -22,6 +22,17 @@ import scala.collection.mutable.ArrayBuffer import org.scalatest.FunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.{StringType, NullType} + +case class Dummy(optKey: Option[Expression]) extends Expression { + def children = optKey.toSeq + def references = Set.empty[Attribute] + def nullable = true + def dataType = NullType + override lazy val resolved = true + type EvaluatedType = Any + def eval(input: Row) = null.asInstanceOf[Any] +} class TreeNodeSuite extends FunSuite { test("top node changed") { @@ -75,4 +86,20 @@ class TreeNodeSuite extends FunSuite { assert(expected === actual) } + + test("transform works on nodes with Option children") { + val dummy1 = Dummy(Some(Literal("1", StringType))) + val dummy2 = Dummy(None) + val toZero: PartialFunction[Expression, Expression] = { case Literal(_, _) => Literal(0) } + + var actual = dummy1 transformDown toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy1 transformUp toZero + assert(actual === Dummy(Some(Literal(0)))) + + actual = dummy2 transform toZero + assert(actual === Dummy(None)) + } + } |