diff options
Diffstat (limited to 'sql')
3 files changed, 147 insertions, 103 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 1af5437647..271ef33090 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -185,8 +185,7 @@ case class Alias(child: Expression, name: String)( override def sql: String = { val qualifiersString = if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".") - val aliasName = if (isGenerated) s"$name#${exprId.id}" else s"$name" - s"${child.sql} AS $qualifiersString${quoteIdentifier(aliasName)}" + s"${child.sql} AS $qualifiersString${quoteIdentifier(name)}" } } @@ -302,8 +301,7 @@ case class AttributeReference( override def sql: String = { val qualifiersString = if (qualifiers.isEmpty) "" else qualifiers.map(quoteIdentifier).mkString("", ".", ".") - val attrRefName = if (isGenerated) s"$name#${exprId.id}" else s"$name" - s"$qualifiersString${quoteIdentifier(attrRefName)}" + s"$qualifiersString${quoteIdentifier(name)}" } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 760335bba5..3bc8e9a5e0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -24,6 +24,7 @@ import scala.util.control.NonFatal import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.logical._ @@ -54,8 +55,26 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi def toSQL: String = { val canonicalizedPlan = Canonicalizer.execute(logicalPlan) + val outputNames = logicalPlan.output.map(_.name) + val qualifiers = logicalPlan.output.flatMap(_.qualifiers).distinct + + // Keep the qualifier information by using it as sub-query name, if there is only one qualifier + // present. + val finalName = if (qualifiers.length == 1) { + qualifiers.head + } else { + SQLBuilder.newSubqueryName + } + + // Canonicalizer will remove all naming information, we should add it back by adding an extra + // Project and alias the outputs. + val aliasedOutput = canonicalizedPlan.output.zip(outputNames).map { + case (attr, name) => Alias(attr.withQualifiers(Nil), name)() + } + val finalPlan = Project(aliasedOutput, SubqueryAlias(finalName, canonicalizedPlan)) + try { - val replaced = canonicalizedPlan.transformAllExpressions { + val replaced = finalPlan.transformAllExpressions { case e: SubqueryExpression => SubqueryHolder(new SQLBuilder(e.query, sqlContext).toSQL) case e: NonSQLExpression => @@ -109,23 +128,6 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case Limit(limitExpr, child) => s"${toSQL(child)} LIMIT ${limitExpr.sql}" - case p: Sample if p.isTableSample => - val fraction = math.min(100, math.max(0, (p.upperBound - p.lowerBound) * 100)) - p.child match { - case m: MetastoreRelation => - val aliasName = m.alias.getOrElse("") - build( - s"`${m.databaseName}`.`${m.tableName}`", - "TABLESAMPLE(" + fraction + " PERCENT)", - aliasName) - case s: SubqueryAlias => - val aliasName = if (s.child.isInstanceOf[SubqueryAlias]) s.alias else "" - val plan = if (s.child.isInstanceOf[SubqueryAlias]) s.child else s - build(toSQL(plan), "TABLESAMPLE(" + fraction + " PERCENT)", aliasName) - case _ => - build(toSQL(p.child), "TABLESAMPLE(" + fraction + " PERCENT)") - } - case Filter(condition, child) => val whereOrHaving = child match { case _: Aggregate => "HAVING" @@ -147,18 +149,7 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi case p: Except => build("(" + toSQL(p.left), ") EXCEPT (", toSQL(p.right) + ")") - case p: SubqueryAlias => - p.child match { - // Persisted data source relation - case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) => - s"${quoteIdentifier(database)}.${quoteIdentifier(table)}" - // Parentheses is not used for persisted data source relations - // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1 - case SubqueryAlias(_, _: LogicalRelation | _: MetastoreRelation) => - build(toSQL(p.child), "AS", p.alias) - case _ => - build("(" + toSQL(p.child) + ")", "AS", p.alias) - } + case p: SubqueryAlias => build("(" + toSQL(p.child) + ")", "AS", p.alias) case p: Join => build( @@ -168,11 +159,12 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi toSQL(p.right), p.condition.map(" ON " + _.sql).getOrElse("")) - case p: MetastoreRelation => - build( - s"${quoteIdentifier(p.databaseName)}.${quoteIdentifier(p.tableName)}", - p.alias.map(a => s" AS ${quoteIdentifier(a)}").getOrElse("") - ) + case SQLTable(database, table, _, sample) => + val qualifiedName = s"${quoteIdentifier(database)}.${quoteIdentifier(table)}" + sample.map { case (lowerBound, upperBound) => + val fraction = math.min(100, math.max(0, (upperBound - lowerBound) * 100)) + qualifiedName + " TABLESAMPLE(" + fraction + " PERCENT)" + }.getOrElse(qualifiedName) case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) if orders.map(_.child) == partitionExprs => @@ -274,8 +266,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val groupingSetSQL = "GROUPING SETS(" + groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")" - val aggExprs = agg.aggregateExpressions.map { case expr => - expr.transformDown { + val aggExprs = agg.aggregateExpressions.map { case aggExpr => + val originalAggExpr = aggExpr.transformDown { // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back. case ar: AttributeReference if ar == gid => GroupingID(Nil) case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar) @@ -286,6 +278,15 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi val idx = groupByExprs.length - 1 - value.asInstanceOf[Int] groupByExprs.lift(idx).map(Grouping).getOrElse(a) } + + originalAggExpr match { + // Ancestor operators may reference the output of this grouping set, and we use exprId to + // generate a unique name for each attribute, so we should make sure the transformed + // aggregate expression won't change the output, i.e. exprId and alias name should remain + // the same. + case ne: NamedExpression if ne.exprId == aggExpr.exprId => ne + case e => Alias(e, normalizedName(aggExpr))(exprId = aggExpr.exprId) + } } build( @@ -308,6 +309,8 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi ) } + private def normalizedName(n: NamedExpression): String = "gen_attr_" + n.exprId.id + object Canonicalizer extends RuleExecutor[LogicalPlan] { override protected def batches: Seq[Batch] = Seq( Batch("Collapse Project", FixedPoint(100), @@ -316,31 +319,55 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // `Aggregate`s. CollapseProject), Batch("Recover Scoping Info", Once, - // Used to handle other auxiliary `Project`s added by analyzer (e.g. - // `ResolveAggregateFunctions` rule) - AddSubquery, - // Previous rule will add extra sub-queries, this rule is used to re-propagate and update - // the qualifiers bottom up, e.g.: - // - // Sort - // ordering = t1.a - // Project - // projectList = [t1.a, t1.b] - // Subquery gen_subquery - // child ... - // - // will be transformed to: - // - // Sort - // ordering = gen_subquery.a - // Project - // projectList = [gen_subquery.a, gen_subquery.b] - // Subquery gen_subquery - // child ... - UpdateQualifiers + // Remove all sub queries, as we will insert new ones when it's necessary. + EliminateSubqueryAliases, + // A logical plan is allowed to have same-name outputs with different qualifiers(e.g. the + // `Join` operator). However, this kind of plan can't be put under a sub query as we will + // erase and assign a new qualifier to all outputs and make it impossible to distinguish + // same-name outputs. This rule renames all attributes, to guarantee different + // attributes(with different exprId) always have different names. It also removes all + // qualifiers, as attributes have unique names now and we don't need qualifiers to resolve + // ambiguity. + NormalizedAttribute, + // Finds the table relations and wrap them with `SQLTable`s. If there are any `Sample` + // operators on top of a table relation, merge the sample information into `SQLTable` of + // that table relation, as we can only convert table sample to standard SQL string. + ResolveSQLTable, + // Insert sub queries on top of operators that need to appear after FROM clause. + AddSubquery ) ) + object NormalizedAttribute extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions { + case a: AttributeReference => + AttributeReference(normalizedName(a), a.dataType)(exprId = a.exprId, qualifiers = Nil) + case a: Alias => + Alias(a.child, normalizedName(a))(exprId = a.exprId, qualifiers = Nil) + } + } + + object ResolveSQLTable extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown { + case Sample(lowerBound, upperBound, _, _, ExtractSQLTable(table)) => + aliasColumns(table.withSample(lowerBound, upperBound)) + case ExtractSQLTable(table) => + aliasColumns(table) + } + + /** + * Aliases the table columns to the generated attribute names, as we use exprId to generate + * unique name for each attribute when normalize attributes, and we can't reference table + * columns with their real names. + */ + private def aliasColumns(table: SQLTable): LogicalPlan = { + val aliasedOutput = table.output.map { attr => + Alias(attr, normalizedName(attr))(exprId = attr.exprId) + } + addSubquery(Project(aliasedOutput, table)) + } + } + object AddSubquery extends Rule[LogicalPlan] { override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp { // This branch handles aggregate functions within HAVING clauses. For example: @@ -354,55 +381,56 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi // +- Filter ... // +- Aggregate ... // +- MetastoreRelation default, src, None - case plan @ Project(_, Filter(_, _: Aggregate)) => wrapChildWithSubquery(plan) + case p @ Project(_, f @ Filter(_, _: Aggregate)) => p.copy(child = addSubquery(f)) - case w @ Window(_, _, _, Filter(_, _: Aggregate)) => wrapChildWithSubquery(w) + case w @ Window(_, _, _, f @ Filter(_, _: Aggregate)) => w.copy(child = addSubquery(f)) - case plan @ Project(_, - _: SubqueryAlias - | _: Filter - | _: Join - | _: MetastoreRelation - | OneRowRelation - | _: LocalLimit - | _: GlobalLimit - | _: Sample - ) => plan - - case plan: Project => wrapChildWithSubquery(plan) + case p: Project => p.copy(child = addSubqueryIfNeeded(p.child)) // We will generate "SELECT ... FROM ..." for Window operator, so its child operator should // be able to put in the FROM clause, or we wrap it with a subquery. - case w @ Window(_, _, _, - _: SubqueryAlias - | _: Filter - | _: Join - | _: MetastoreRelation - | OneRowRelation - | _: LocalLimit - | _: GlobalLimit - | _: Sample - ) => w - - case w: Window => wrapChildWithSubquery(w) - } + case w: Window => w.copy(child = addSubqueryIfNeeded(w.child)) - private def wrapChildWithSubquery(plan: UnaryNode): LogicalPlan = { - val newChild = SubqueryAlias(SQLBuilder.newSubqueryName, plan.child) - plan.withNewChildren(Seq(newChild)) + case j: Join => j.copy( + left = addSubqueryIfNeeded(j.left), + right = addSubqueryIfNeeded(j.right)) } } - object UpdateQualifiers extends Rule[LogicalPlan] { - override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp { - case plan => - val inputAttributes = plan.children.flatMap(_.output) - plan transformExpressions { - case a: AttributeReference if !plan.producedAttributes.contains(a) => - val qualifier = inputAttributes.find(_ semanticEquals a).map(_.qualifiers) - a.withQualifiers(qualifier.getOrElse(Nil)) - } - } + private def addSubquery(plan: LogicalPlan): SubqueryAlias = { + SubqueryAlias(SQLBuilder.newSubqueryName, plan) + } + + private def addSubqueryIfNeeded(plan: LogicalPlan): LogicalPlan = plan match { + case _: SubqueryAlias => plan + case _: Filter => plan + case _: Join => plan + case _: LocalLimit => plan + case _: GlobalLimit => plan + case _: SQLTable => plan + case OneRowRelation => plan + case _ => addSubquery(plan) + } + } + + case class SQLTable( + database: String, + table: String, + output: Seq[Attribute], + sample: Option[(Double, Double)] = None) extends LeafNode { + def withSample(lowerBound: Double, upperBound: Double): SQLTable = + this.copy(sample = Some(lowerBound -> upperBound)) + } + + object ExtractSQLTable { + def unapply(plan: LogicalPlan): Option[SQLTable] = plan match { + case l @ LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) => + Some(SQLTable(database, table, l.output.map(_.withQualifiers(Nil)))) + + case m: MetastoreRelation => + Some(SQLTable(m.databaseName, m.tableName, m.output.map(_.withQualifiers(Nil)))) + + case _ => None } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 198652b355..f02ecb48d5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -550,4 +550,22 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { |WINDOW w AS (PARTITION BY key % 5 ORDER BY key) """.stripMargin) } + + test("window with join") { + checkHiveQl( + """ + |SELECT x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key) + |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key + """.stripMargin) + } + + test("join 2 tables and aggregate function in having clause") { + checkHiveQl( + """ + |SELECT COUNT(a.value), b.KEY, a.KEY + |FROM parquet_t1 a, parquet_t1 b + |GROUP BY a.KEY, b.KEY + |HAVING MAX(a.KEY) > 0 + """.stripMargin) + } } |