author    Michael Armbrust <michael@databricks.com>  2015-03-24 12:28:01 -0700
committer Michael Armbrust <michael@databricks.com>  2015-03-24 12:28:01 -0700
commit    3fa3d121dfec60f9768d3859e8450ee482b2d4e8 (patch)
tree      9d1fb6719e07af1a2b23640f09257a2b08205358 /sql
parent    26c6ce3d2947df5a294b1ad4a22fae5d31d06c19 (diff)
[SPARK-6054][SQL] Fix transformations of TreeNodes that hold StructTypes
Due to a recent change that made `StructType` a `Seq`, we started inadvertently
turning `StructType`s into generic `Traversable`s when attempting nested tree
transformations. In this PR we explicitly avoid descending into `DataType`s to
avoid this bug.

Author: Michael Armbrust <michael@databricks.com>

Closes #5157 from marmbrus/udfFix and squashes the following commits:

26f7087 [Michael Armbrust] Fix transformations of TreeNodes that hold StructTypes
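For context, a minimal stand-alone sketch of the failure mode described above (hypothetical snippet, assuming only Spark's `org.apache.spark.sql.types` classes on the classpath): because `StructType` extends `Seq[StructField]`, the generic collection case in a tree transformation matches it and rebuilds it as a plain `Seq`.

import org.apache.spark.sql.types.{IntegerType, StructField, StructType}

val schema: Any = StructType(Seq(
  StructField("x", IntegerType),
  StructField("y", IntegerType)))

// The generic collection case from the transformation's pattern match:
val copied = schema match {
  case seq: Traversable[_] => seq.map(identity) // rebuilds the schema as a plain Seq
  case other => other
}

// The schema has silently lost its type; a later makeCopy that expects a
// StructType constructor argument then fails with IllegalArgumentException.
assert(!copied.isInstanceOf[StructType])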
Diffstat (limited to 'sql')
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala |  2
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala  | 20
sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala                     |  6
3 files changed, 25 insertions(+), 3 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 48191f3119..bd9291e9ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -85,6 +85,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionDown(e)
case Some(e: Expression) => Some(transformExpressionDown(e))
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionDown(e)
case other => other
@@ -117,6 +118,7 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
case e: Expression => transformExpressionUp(e)
case Some(e: Expression) => Some(transformExpressionUp(e))
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case seq: Traversable[_] => seq.map {
case e: Expression => transformExpressionUp(e)
case other => other
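The placement of the new case matters: a `StructType` is simultaneously a `DataType` and a `Traversable[StructField]`, and match cases are tried top to bottom, so the `DataType` case must sit above the `Traversable` case. A distilled sketch of that ordering (hypothetical helper, not the actual Catalyst code):

import org.apache.spark.sql.types.{DataType, IntegerType, StructField, StructType}

// Swapping the first two cases reintroduces the bug.
def transformArg(arg: Any): Any = arg match {
  case d: DataType => d                         // schemas pass through unchanged
  case seq: Traversable[_] => seq.map(identity) // ordinary collections get rebuilt
  case other => other
}

val schema = StructType(Seq(StructField("a", IntegerType)))
assert(transformArg(schema).isInstanceOf[StructType]) // still a StructType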
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index f84ffe4e17..0ae9f6b296 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors._
+import org.apache.spark.sql.types.DataType
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int)
@@ -220,6 +221,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg)
}
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformDown(rule)
@@ -276,6 +278,7 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
Some(arg)
}
case m: Map[_,_] => m
+ case d: DataType => d // Avoid unpacking Structs
case args: Traversable[_] => args.map {
case arg: TreeNode[_] if children contains arg =>
val newChild = arg.asInstanceOf[BaseType].transformUp(rule)
@@ -307,10 +310,15 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
* @param newArgs the new product arguments.
*/
def makeCopy(newArgs: Array[AnyRef]): this.type = attachTree(this, "makeCopy") {
+ val defaultCtor =
+ getClass.getConstructors
+ .find(_.getParameterTypes.size != 0)
+ .headOption
+ .getOrElse(sys.error(s"No valid constructor for $nodeName"))
+
try {
CurrentOrigin.withOrigin(origin) {
// Skip no-arg constructors that are just there for kryo.
- val defaultCtor = getClass.getConstructors.find(_.getParameterTypes.size != 0).head
if (otherCopyArgs.isEmpty) {
defaultCtor.newInstance(newArgs: _*).asInstanceOf[this.type]
} else {
@@ -320,8 +328,14 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
} catch {
case e: java.lang.IllegalArgumentException =>
throw new TreeNodeException(
- this, s"Failed to copy node. Is otherCopyArgs specified correctly for $nodeName? "
- + s"Exception message: ${e.getMessage}.")
+this,
+s"""
+|Failed to copy node.
+|Is otherCopyArgs specified correctly for $nodeName.
+|Exception message: ${e.getMessage}
+|ctor: $defaultCtor?
+|args: ${newArgs.mkString(", ")}
+""".stripMargin)
}
}
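Hoisting `defaultCtor` out of the `try` block is what lets the new error message name the constructor and arguments that failed. A self-contained sketch of the same reflective-copy pattern (the toy `Node` class and `makeCopy` helper are illustrative, not the real `TreeNode` code):

// Toy stand-in for a Catalyst TreeNode subclass.
case class Node(name: String, width: Integer)

def makeCopy(proto: AnyRef, newArgs: Array[AnyRef]): AnyRef = {
  // Skip no-arg constructors that exist only for serializers such as Kryo.
  val defaultCtor = proto.getClass.getConstructors
    .find(_.getParameterTypes.nonEmpty)
    .getOrElse(sys.error(s"No valid constructor for ${proto.getClass.getSimpleName}"))
  try {
    defaultCtor.newInstance(newArgs: _*).asInstanceOf[AnyRef]
  } catch {
    case e: IllegalArgumentException =>
      // defaultCtor is in scope here, so the error can report it.
      sys.error(
        s"""Failed to copy node: ${e.getMessage}
           |ctor: $defaultCtor
           |args: ${newArgs.mkString(", ")}""".stripMargin)
  }
}

// makeCopy(Node("n", 1), Array("m", Int.box(2))) returns Node("m", 2);
// a mismatched argument list instead produces the detailed error above.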
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index be105c6e83..d615542ab5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -50,4 +50,10 @@ class UDFSuite extends QueryTest {
.select($"ret.f1").head().getString(0)
assert(result === "test")
}
+
+ test("udf that is transformed") {
+ udf.register("makeStruct", (x: Int, y: Int) => (x, y))
+ // 1 + 1 is constant folded causing a transformation.
+ assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
+ }
}
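The `1 + 1` in the new test is deliberate: it forces the optimizer's constant-folding rewrite, which copies the expression tree, including the UDF node whose result type is a `StructType`; without the fix, that copy is what corrupted the schema. A toy model of this kind of bottom-up rewrite (an illustrative ADT, not Catalyst's actual ConstantFolding rule):

// Miniature expression tree with a foldable addition and a UDF call.
sealed trait Expr
case class Lit(v: Int) extends Expr
case class Add(l: Expr, r: Expr) extends Expr
case class Udf(name: String, args: Seq[Expr]) extends Expr

def fold(e: Expr): Expr = e match {
  case Add(l, r) => (fold(l), fold(r)) match {
    case (Lit(a), Lit(b)) => Lit(a + b)         // the rewrite the test provokes
    case (fl, fr)         => Add(fl, fr)
  }
  case Udf(n, args) => Udf(n, args.map(fold))   // the UDF node is copied too
  case lit => lit
}

// fold(Udf("makeStruct", Seq(Add(Lit(1), Lit(1)), Lit(2))))
//   == Udf("makeStruct", Seq(Lit(2), Lit(2)))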