From a89e8b6122ee5a1517fbcf405b1686619db56696 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 15 Dec 2015 18:29:19 -0800 Subject: [SPARK-10477][SQL] using DSL in ColumnPruningSuite to improve readability Author: Wenchen Fan Closes #8645 from cloud-fan/test. --- .../apache/spark/sql/catalyst/dsl/package.scala | 7 ++-- .../catalyst/optimizer/ColumnPruningSuite.scala | 41 ++++++++++++---------- 2 files changed, 27 insertions(+), 21 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index af594c25c5..e50971173c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -275,13 +275,14 @@ package object dsl { def unionAll(otherPlan: LogicalPlan): LogicalPlan = Union(logicalPlan, otherPlan) - // TODO specify the output column names def generate( generator: Generator, join: Boolean = false, outer: Boolean = false, - alias: Option[String] = None): LogicalPlan = - Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan) + alias: Option[String] = None, + outputNames: Seq[String] = Nil): LogicalPlan = + Generate(generator, join = join, outer = outer, alias, + outputNames.map(UnresolvedAttribute(_)), logicalPlan) def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan = InsertIntoTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 4a1e7ceaf3..9bf61ae091 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.expressions.Explode import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Project, LocalRelation, Generate, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ @@ -35,12 +35,11 @@ class ColumnPruningSuite extends PlanTest { test("Column pruning for Generate when Generate.join = false") { val input = LocalRelation('a.int, 'b.array(StringType)) - val query = Generate(Explode('b), false, false, None, 's.string :: Nil, input).analyze + val query = input.generate(Explode('b), join = false).analyze + val optimized = Optimize.execute(query) - val correctAnswer = - Generate(Explode('b), false, false, None, 's.string :: Nil, - Project('b.attr :: Nil, input)).analyze + val correctAnswer = input.select('b).generate(Explode('b), join = false).analyze comparePlans(optimized, correctAnswer) } @@ -49,16 +48,19 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('a.int, 'b.int, 'c.array(StringType)) val query = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(Seq('a, 's), - Generate(Explode('c), true, false, None, 's.string :: Nil, - Project(Seq('a, 'c), - input))).analyze + input + .select('a, 'c) + .generate(Explode('c), join = true, outputNames = "explode" :: Nil) + .select('a, 'explode) + .analyze comparePlans(optimized, correctAnswer) } @@ -67,15 +69,18 @@ class ColumnPruningSuite extends PlanTest { val input = LocalRelation('b.array(StringType)) val query = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), true, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = true, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze + val optimized = Optimize.execute(query) val correctAnswer = - Project(('s + 1).as("s+1") :: Nil, - Generate(Explode('b), false, false, None, 's.string :: Nil, - input)).analyze + input + .generate(Explode('b), join = false, outputNames = "explode" :: Nil) + .select(('explode + 1).as("result")) + .analyze comparePlans(optimized, correctAnswer) } -- cgit v1.2.3