diff options
author | Xiao Li <gatorsmile@gmail.com> | 2017-02-22 17:26:56 -0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2017-02-22 17:26:56 -0800 |
commit | dc005ed53c87216efff50268009217ba26e34a10 (patch) | |
tree | a60d01a64d761433eee31b8e60e07f9a1085b336 /sql/catalyst/src | |
parent | 4661d30b988bf773ab45a15b143efb2908d33743 (diff) | |
download | spark-dc005ed53c87216efff50268009217ba26e34a10.tar.gz spark-dc005ed53c87216efff50268009217ba26e34a10.tar.bz2 spark-dc005ed53c87216efff50268009217ba26e34a10.zip |
[SPARK-19658][SQL] Set NumPartitions of RepartitionByExpression In Parser
### What changes were proposed in this pull request?
Currently, if `numPartitions` is not set in `RepartitionByExpression`, it is filled in from `spark.sql.shuffle.partitions` during planning. However, this does not follow the general resolution process. This PR sets the value in the `Parser` instead, so that the `Optimizer` can use it for plan optimization.
### How was this patch tested?
Added a test case.
Author: Xiao Li <gatorsmile@gmail.com>
Closes #16988 from gatorsmile/resolveRepartition.
Diffstat (limited to 'sql/catalyst/src')
6 files changed, 25 insertions, 21 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3c53132339..c062e4e84b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -373,8 +373,8 @@ package object dsl { def repartition(num: Integer): LogicalPlan = Repartition(num, shuffle = true, logicalPlan) - def distribute(exprs: Expression*)(n: Int = -1): LogicalPlan = - RepartitionByExpression(exprs, logicalPlan, numPartitions = if (n < 0) None else Some(n)) + def distribute(exprs: Expression*)(n: Int): LogicalPlan = + RepartitionByExpression(exprs, logicalPlan, numPartitions = n) def analyze: LogicalPlan = EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0c13e3e93a..af846a09a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -578,7 +578,7 @@ object CollapseRepartition extends Rule[LogicalPlan] { RepartitionByExpression(exprs, child, numPartitions) // Case 3 case Repartition(numPartitions, _, r: RepartitionByExpression) => - r.copy(numPartitions = Some(numPartitions)) + r.copy(numPartitions = numPartitions) // Case 3 case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) => RepartitionByExpression(exprs, child, numPartitions) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 08a6dd136b..926a37b363 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -242,20 +242,20 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Sort(sort.asScala.map(visitSortItem), global = false, query) } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { // DISTRIBUTE BY ... - RepartitionByExpression(expressionList(distributeBy), query) + withRepartitionByExpression(ctx, expressionList(distributeBy), query) } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { // SORT BY ... DISTRIBUTE BY ... Sort( sort.asScala.map(visitSortItem), global = false, - RepartitionByExpression(expressionList(distributeBy), query)) + withRepartitionByExpression(ctx, expressionList(distributeBy), query)) } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { // CLUSTER BY ... val expressions = expressionList(clusterBy) Sort( expressions.map(SortOrder(_, Ascending)), global = false, - RepartitionByExpression(expressions, query)) + withRepartitionByExpression(ctx, expressions, query)) } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { // [EMPTY] query @@ -274,6 +274,16 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** + * Create a clause for DISTRIBUTE BY. + */ + protected def withRepartitionByExpression( + ctx: QueryOrganizationContext, + expressions: Seq[Expression], + query: LogicalPlan): LogicalPlan = { + throw new ParseException("DISTRIBUTE BY is not supported", ctx) + } + + /** * Create a logical plan using a query specification. 
*/ override def visitQuerySpecification( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index af57632516..d17d12cd83 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -844,18 +844,13 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) * information about the number of partitions during execution. Used when a specific ordering or * distribution is expected by the consumer of the query result. Use [[Repartition]] for RDD-like * `coalesce` and `repartition`. - * If `numPartitions` is not specified, the number of partitions will be the number set by - * `spark.sql.shuffle.partitions`. */ case class RepartitionByExpression( partitionExpressions: Seq[Expression], child: LogicalPlan, - numPartitions: Option[Int] = None) extends UnaryNode { + numPartitions: Int) extends UnaryNode { - numPartitions match { - case Some(n) => require(n > 0, s"Number of partitions ($n) must be positive.") - case None => // Ok - } + require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") override def maxRows: Option[Long] = child.maxRows override def output: Seq[Attribute] = child.output diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 786e0f49b4..01737e0a17 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -21,11 +21,12 @@ import java.util.TimeZone import org.scalatest.ShouldMatchers -import 
org.apache.spark.sql.catalyst.{SimpleCatalystConf, TableIdentifier} +import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.{Cross, Inner} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.Cross import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types._ @@ -192,12 +193,13 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { } test("pull out nondeterministic expressions from RepartitionByExpression") { - val plan = RepartitionByExpression(Seq(Rand(33)), testRelation) + val plan = RepartitionByExpression(Seq(Rand(33)), testRelation, numPartitions = 10) val projected = Alias(Rand(33), "_nondeterministic")() val expected = Project(testRelation.output, RepartitionByExpression(Seq(projected.toAttribute), - Project(testRelation.output :+ projected, testRelation))) + Project(testRelation.output :+ projected, testRelation), + numPartitions = 10)) checkAnalysis(plan, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 2c14252426..67d5d2202b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -152,10 +152,7 @@ class PlanParserSuite extends PlanTest { val orderSortDistrClusterClauses = Seq( ("", basePlan), (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), - (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), - (" distribute by a, b", basePlan.distribute('a, 'b)()), - (" distribute by a sort by b", basePlan.distribute('a)().sortBy('b.asc)), - (" cluster by a, b", 
basePlan.distribute('a, 'b)().sortBy('a.asc, 'b.asc)) + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)) ) orderSortDistrClusterClauses.foreach { |