aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorWenchen Fan <cloud0fan@outlook.com>2015-12-15 18:29:19 -0800
committerAndrew Or <andrew@databricks.com>2015-12-15 18:29:19 -0800
commita89e8b6122ee5a1517fbcf405b1686619db56696 (patch)
tree1af3ad686d5944f08ed4e384209b05ef484ad039 /sql/catalyst
parentc5b6b398d5e368626e589feede80355fb74c2bd8 (diff)
downloadspark-a89e8b6122ee5a1517fbcf405b1686619db56696.tar.gz
spark-a89e8b6122ee5a1517fbcf405b1686619db56696.tar.bz2
spark-a89e8b6122ee5a1517fbcf405b1686619db56696.zip
[SPARK-10477][SQL] using DSL in ColumnPruningSuite to improve readability
Author: Wenchen Fan <cloud0fan@outlook.com> Closes #8645 from cloud-fan/test.
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala7
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala41
2 files changed, 27 insertions, 21 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index af594c25c5..e50971173c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -275,13 +275,14 @@ package object dsl {
def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan)
- // TODO specify the output column names
def generate(
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
- alias: Option[String] = None): LogicalPlan =
- Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)
+ alias: Option[String] = None,
+ outputNames: Seq[String] = Nil): LogicalPlan =
+ Generate(generator, join = join, outer = outer, alias,
+ outputNames.map(UnresolvedAttribute(_)), logicalPlan)
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 4a1e7ceaf3..9bf61ae091 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.plans.PlanTest
-import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
@@ -35,12 +35,11 @@ class ColumnPruningSuite extends PlanTest {
test("Column pruning for Generate when Generate.join = false") {
val input = LocalRelation('a.int, 'b.array(StringType))
- val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze
+ val query = input.generate(Explode('b), join = false).analyze
+
val optimized = Optimize.execute(query)
- val correctAnswer =
- Generate(Explode('b), false, false, None, 's.string :: Nil,
- Project('b.attr :: Nil, input)).analyze
+ val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze
comparePlans(optimized, correctAnswer)
}
@@ -49,16 +48,19 @@ class ColumnPruningSuite extends PlanTest {
val input = LocalRelation('a.int, 'b.int, 'c.array(StringType))
val query =
- Project(Seq('a, 's),
- Generate(Explode('c), true, false, None, 's.string :: Nil,
- input)).analyze
+ input
+ .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
+ .select('a, 'explode)
+ .analyze
+
val optimized = Optimize.execute(query)
val correctAnswer =
- Project(Seq('a, 's),
- Generate(Explode('c), true, false, None, 's.string :: Nil,
- Project(Seq('a, 'c),
- input))).analyze
+ input
+ .select('a, 'c)
+ .generate(Explode('c), join = true, outputNames = "explode" :: Nil)
+ .select('a, 'explode)
+ .analyze
comparePlans(optimized, correctAnswer)
}
@@ -67,15 +69,18 @@ class ColumnPruningSuite extends PlanTest {
val input = LocalRelation('b.array(StringType))
val query =
- Project(('s + 1).as("s+1") :: Nil,
- Generate(Explode('b), true, false, None, 's.string :: Nil,
- input)).analyze
+ input
+ .generate(Explode('b), join = true, outputNames = "explode" :: Nil)
+ .select(('explode + 1).as("result"))
+ .analyze
+
val optimized = Optimize.execute(query)
val correctAnswer =
- Project(('s + 1).as("s+1") :: Nil,
- Generate(Explode('b), false, false, None, 's.string :: Nil,
- input)).analyze
+ input
+ .generate(Explode('b), join = false, outputNames = "explode" :: Nil)
+ .select(('explode + 1).as("result"))
+ .analyze
comparePlans(optimized, correctAnswer)
}