author    Wenchen Fan <cloud0fan@outlook.com>    2015-06-10 18:22:47 -0700
committer Reynold Xin <rxin@databricks.com>      2015-06-10 18:22:47 -0700
commit    4e42842e82e058d54329bd66185d8a7e77ab335a (patch)
tree      f2641ba3bedf08241cd71932fc0c86c7dbd61770
parent    6a47114bc297f0bce874e425feb1c24a5c26cef0 (diff)
[SPARK-8164] transformExpressions should support nested expression sequence
Currently we only support `Seq[Expression]`; we should also handle cases like `Seq[Seq[Expression]]` so that we can remove the unnecessary `GroupExpression`.

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6706 from cloud-fan/clean and squashes the following commits:

60a1193 [Wenchen Fan] support nested expression sequence and remove GroupExpression
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala            |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala       | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala              | 22
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala |  2
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala          | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala                          |  4
6 files changed, 30 insertions(+), 30 deletions(-)
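Before diving into the diff: the core of the change is making expression transformation recurse into nested sequences instead of looking only one level deep. A minimal standalone sketch of the idea (`Expr`, `Lit`, and `transformNested` are illustrative toy names, not Spark's classes):

```scala
object NestedTransformSketch {
  sealed trait Expr
  case class Lit(value: Any) extends Expr

  // Recurse into arbitrarily nested Seqs, applying the rule to every Expr.
  def transformNested(arg: Any)(rule: PartialFunction[Expr, Expr]): Any = arg match {
    case e: Expr     => rule.applyOrElse(e, identity[Expr])
    case seq: Seq[_] => seq.map(transformNested(_)(rule)) // handles Seq[Seq[...]]
    case other       => other
  }

  // A Seq[Seq[Expr]] is transformed directly, with no wrapper expression:
  val result = transformNested(Seq(Seq(Lit(1)), Seq(Lit(2)))) {
    case Lit(v) => Lit(v.toString)
  } // == Seq(Seq(Lit("1")), Seq(Lit("2")))
}
```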
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c4f12cfe87..cbd8def4f1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -172,8 +172,8 @@ class Analyzer(
* expressions which equal GroupBy expressions with Literal(null), if those expressions
* are not set for this grouping set (according to the bit mask).
*/
- private[this] def expand(g: GroupingSets): Seq[GroupExpression] = {
- val result = new scala.collection.mutable.ArrayBuffer[GroupExpression]
+ private[this] def expand(g: GroupingSets): Seq[Seq[Expression]] = {
+ val result = new scala.collection.mutable.ArrayBuffer[Seq[Expression]]
g.bitmasks.foreach { bitmask =>
// get the non selected grouping attributes according to the bit mask
@@ -194,7 +194,7 @@ class Analyzer(
Literal.create(bitmask, IntegerType)
})
- result += GroupExpression(substitution)
+ result += substitution
}
result.toSeq
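For readers unfamiliar with grouping sets, here is a rough sketch of what one iteration of the loop above produces. The toy types and the bit convention (bit i set means the i-th group-by expression is selected) are assumptions for illustration; Spark's actual bit layout may differ:

```scala
object GroupingSetSketch {
  sealed trait Expr
  case class Attr(name: String) extends Expr
  case class Lit(value: Any) extends Expr

  // One projection per bitmask: group-by expressions whose bit is set are
  // kept, the rest become null literals, and the bitmask itself is appended
  // so downstream operators can tell the grouping sets apart.
  def expandOne(groupBy: Seq[Expr], bitmask: Int): Seq[Expr] =
    groupBy.zipWithIndex.map {
      case (e, i) if (bitmask & (1 << i)) != 0 => e
      case _                                   => Lit(null)
    } :+ Lit(bitmask)

  // expandOne(Seq(Attr("a"), Attr("b")), bitmask = 1)
  //   == Seq(Attr("a"), Lit(null), Lit(1))
}
```

After this patch each such projection is a plain `Seq[Expression]` rather than a `GroupExpression`.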
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index a05794f1db..63dd5f9854 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -239,18 +239,6 @@ abstract class UnaryExpression extends Expression with trees.UnaryNode[Expressio
}
}
-// TODO Semantically we probably not need GroupExpression
-// All we need is holding the Seq[Expression], and ONLY used in doing the
-// expressions transformation correctly. Probably will be removed since it's
-// not like a real expressions.
-case class GroupExpression(children: Seq[Expression]) extends Expression {
- self: Product =>
- override def eval(input: Row): Any = throw new UnsupportedOperationException
- override def nullable: Boolean = false
- override def foldable: Boolean = false
- override def dataType: DataType = throw new UnsupportedOperationException
-}
-
/**
* Expressions that require a specific `DataType` as input should implement this trait
* so that the proper type conversions can be performed in the analyzer.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index eff5c61644..2f545bb432 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -81,17 +81,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
}
}
- val newArgs = productIterator.map {
+ def recursiveTransform(arg: Any): AnyRef = arg match {
case e: Expression => transformExpressionDown(e)
case Some(e: Expression) => Some(transformExpressionDown(e))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
- case seq: Traversable[_] => seq.map {
- case e: Expression => transformExpressionDown(e)
- case other => other
- }
+ case seq: Traversable[_] => seq.map(recursiveTransform)
case other: AnyRef => other
- }.toArray
+ }
+
+ val newArgs = productIterator.map(recursiveTransform).toArray
if (changed) makeCopy(newArgs) else this
}
@@ -114,17 +113,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
}
}
- val newArgs = productIterator.map {
+ def recursiveTransform(arg: Any): AnyRef = arg match {
case e: Expression => transformExpressionUp(e)
case Some(e: Expression) => Some(transformExpressionUp(e))
case m: Map[_, _] => m
case d: DataType => d // Avoid unpacking Structs
- case seq: Traversable[_] => seq.map {
- case e: Expression => transformExpressionUp(e)
- case other => other
- }
+ case seq: Traversable[_] => seq.map(recursiveTransform)
case other: AnyRef => other
- }.toArray
+ }
+
+ val newArgs = productIterator.map(recursiveTransform).toArray
if (changed) makeCopy(newArgs) else this
}
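The rebuild pattern above (transform each constructor argument, then `makeCopy`) can be sketched outside Spark like this; `Node` and `rewriteArgs` are illustrative names, and a direct rebuild stands in for Spark's reflection-based `makeCopy`:

```scala
object RebuildSketch {
  // Stand-in for a TreeNode: constructor arguments may hold values to
  // rewrite at arbitrary Seq nesting depth.
  case class Node(exprs: Seq[Seq[Int]], name: String)

  def rewriteArgs(n: Node)(f: Int => Int): Node = {
    def recur(arg: Any): Any = arg match {
      case i: Int      => f(i)
      case seq: Seq[_] => seq.map(recur) // same recursion as recursiveTransform
      case other       => other
    }
    val newArgs = n.productIterator.map(recur).toArray
    // A plain case class can be rebuilt directly for the sketch.
    Node(newArgs(0).asInstanceOf[Seq[Seq[Int]]], newArgs(1).asInstanceOf[String])
  }

  // rewriteArgs(Node(Seq(Seq(1), Seq(2)), "n"))(_ + 10)
  //   == Node(Seq(Seq(11), Seq(12)), "n")
}
```

Note that the `Map` and `DataType` cases in the real code intentionally stop the recursion ("Avoid unpacking Structs"), so expressions nested inside those values are left untouched.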
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index e77e5c27b6..963c782091 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -226,7 +226,7 @@ case class Window(
* @param child Child operator
*/
case class Expand(
- projections: Seq[GroupExpression],
+ projections: Seq[Seq[Expression]],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def statistics: Statistics = {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 67db3d5e6d..8ec79c3d4d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -31,6 +31,11 @@ case class Dummy(optKey: Option[Expression]) extends Expression {
override def eval(input: Row): Any = null.asInstanceOf[Any]
}
+case class ComplexPlan(exprs: Seq[Seq[Expression]])
+ extends org.apache.spark.sql.catalyst.plans.logical.LeafNode {
+ override def output: Seq[Attribute] = Nil
+}
+
class TreeNodeSuite extends SparkFunSuite {
test("top node changed") {
val after = Literal(1) transform { case Literal(1, _) => Literal(2) }
@@ -220,4 +225,13 @@ class TreeNodeSuite extends SparkFunSuite {
assert(expected === actual)
}
}
+
+ test("transformExpressions on nested expression sequence") {
+ val plan = ComplexPlan(Seq(Seq(Literal(1)), Seq(Literal(2))))
+ val actual = plan.transformExpressions {
+ case Literal(value, _) => Literal(value.toString)
+ }
+ val expected = ComplexPlan(Seq(Seq(Literal("1")), Seq(Literal("2"))))
+ assert(expected === actual)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
index f16ca36909..4b601c1192 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Expand.scala
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{UnknownPartitioning, Partit
*/
@DeveloperApi
case class Expand(
- projections: Seq[GroupExpression],
+ projections: Seq[Seq[Expression]],
output: Seq[Attribute],
child: SparkPlan)
extends UnaryNode {
@@ -49,7 +49,7 @@ case class Expand(
// workers via closure. However we can't assume the Projection
// is serializable because of the code gen, so we have to
// create the projections within each of the partition processing.
- val groups = projections.map(ee => newProjection(ee.children, child.output)).toArray
+ val groups = projections.map(ee => newProjection(ee, child.output)).toArray
new Iterator[Row] {
private[this] var result: Row = _
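Conceptually, the physical operator applies every projection to every input row. A simplified sketch with plain functions standing in for Spark's code-generated `Projection` (all names illustrative):

```scala
object ExpandSketch {
  type Row = Seq[Any]
  type Projection = Row => Row

  // Each input row is emitted once per projection, so the output contains
  // projections.size rows for every input row.
  def expand(projections: Seq[Projection], input: Iterator[Row]): Iterator[Row] =
    input.flatMap(row => projections.iterator.map(p => p(row)))

  // expand(Seq(r => r :+ 0, r => r.map(_ => null) :+ 1), Iterator(Seq("a"))).toList
  //   == List(List("a", 0), List(null, 1))
}
```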