-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala      2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala                    623
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala          290
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala        1201
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala                 29
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala       54
6 files changed, 3 insertions, 2196 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index 2a3d3a173c..b4a7c05ee0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -154,7 +154,7 @@ trait CheckAnalysis extends PredicateHelper {
}
}
- // Skip subquery aliases added by the Analyzer and the SQLBuilder.
+ // Skip subquery aliases added by the Analyzer.
// For projects, do the necessary mapping and skip to its child.
def cleanQuery(p: LogicalPlan): LogicalPlan = p match {
case s: SubqueryAlias => cleanQuery(s.child)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
deleted file mode 100644
index d5a8566d07..0000000000
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ /dev/null
@@ -1,623 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import java.util.concurrent.atomic.AtomicLong
-
-import scala.collection.mutable.Map
-import scala.util.control.NonFatal
-
-import org.apache.spark.internal.Logging
-import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.catalyst.catalog.CatalogRelation
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.optimizer.{CollapseProject, CombineUnions}
-import org.apache.spark.sql.catalyst.plans.logical._
-import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor}
-import org.apache.spark.sql.catalyst.util.quoteIdentifier
-import org.apache.spark.sql.execution.datasources.LogicalRelation
-import org.apache.spark.sql.types.{ByteType, DataType, IntegerType, NullType}
-
-/**
- * A builder class used to convert a resolved logical plan into a SQL query string. Note that not
- * all resolved logical plans are convertible. They either don't have corresponding SQL
- * representations (e.g. logical plans that operate on local Scala collections), or are simply not
- * supported by this builder (yet).
- */
-class SQLBuilder private (
- logicalPlan: LogicalPlan,
- nextSubqueryId: AtomicLong,
- nextGenAttrId: AtomicLong,
- exprIdMap: Map[Long, Long]) extends Logging {
- require(logicalPlan.resolved,
- "SQLBuilder only supports resolved logical query plans. Current plan:\n" + logicalPlan)
-
- def this(logicalPlan: LogicalPlan) =
- this(logicalPlan, new AtomicLong(0), new AtomicLong(0), Map.empty[Long, Long])
-
- def this(df: Dataset[_]) = this(df.queryExecution.analyzed)
-
- private def newSubqueryName(): String = s"gen_subquery_${nextSubqueryId.getAndIncrement()}"
- private def normalizedName(n: NamedExpression): String = synchronized {
- "gen_attr_" + exprIdMap.getOrElseUpdate(n.exprId.id, nextGenAttrId.getAndIncrement())
- }
-
- def toSQL: String = {
- val canonicalizedPlan = Canonicalizer.execute(logicalPlan)
- val outputNames = logicalPlan.output.map(_.name)
- val qualifiers = logicalPlan.output.flatMap(_.qualifier).distinct
-
- // Keep the qualifier information by using it as sub-query name, if there is only one qualifier
- // present.
- val finalName = if (qualifiers.length == 1) {
- qualifiers.head
- } else {
- newSubqueryName()
- }
-
- // The Canonicalizer will remove all naming information, so we should add it back by adding an
- // extra Project and aliasing the outputs.
- val aliasedOutput = canonicalizedPlan.output.zip(outputNames).map {
- case (attr, name) => Alias(attr.withQualifier(None), name)()
- }
- val finalPlan = Project(aliasedOutput, SubqueryAlias(finalName, canonicalizedPlan, None))
-
- try {
- val replaced = finalPlan.transformAllExpressions {
- case s: SubqueryExpression =>
- val query = new SQLBuilder(s.plan, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL
- val sql = s match {
- case _: ListQuery => query
- case _: Exists => s"EXISTS($query)"
- case _ => s"($query)"
- }
- SubqueryHolder(sql)
- case e: NonSQLExpression =>
- throw new UnsupportedOperationException(
- s"Expression $e doesn't have a SQL representation"
- )
- case e => e
- }
-
- val generatedSQL = toSQL(replaced)
- logDebug(
- s"""Built SQL query string successfully from given logical plan:
- |
- |# Original logical plan:
- |${logicalPlan.treeString}
- |# Canonicalized logical plan:
- |${replaced.treeString}
- |# Generated SQL:
- |$generatedSQL
- """.stripMargin)
- generatedSQL
- } catch { case NonFatal(e) =>
- logDebug(
- s"""Failed to build SQL query string from given logical plan:
- |
- |# Original logical plan:
- |${logicalPlan.treeString}
- |# Canonicalized logical plan:
- |${canonicalizedPlan.treeString}
- """.stripMargin)
- throw e
- }
- }
-
- private def toSQL(node: LogicalPlan): String = node match {
- case Distinct(p: Project) =>
- projectToSQL(p, isDistinct = true)
-
- case p: Project =>
- projectToSQL(p, isDistinct = false)
-
- case a @ Aggregate(_, _, e @ Expand(_, _, p: Project)) if isGroupingSet(a, e, p) =>
- groupingSetToSQL(a, e, p)
-
- case p: Aggregate =>
- aggregateToSQL(p)
-
- case w: Window =>
- windowToSQL(w)
-
- case g: Generate =>
- generateToSQL(g)
-
- // This prevents a pattern of `((...) AS gen_subquery_0 LIMIT 1)` which does not work.
- // For example, `SELECT * FROM (SELECT id FROM tbl TABLESAMPLE (2 ROWS))` makes this plan.
- case Limit(limitExpr, child: SubqueryAlias) =>
- s"${toSQL(child)} LIMIT ${limitExpr.sql}"
-
- case Limit(limitExpr, child) =>
- s"(${toSQL(child)} LIMIT ${limitExpr.sql})"
-
- case Filter(condition, child) =>
- val whereOrHaving = child match {
- case _: Aggregate => "HAVING"
- case _ => "WHERE"
- }
- build(toSQL(child), whereOrHaving, condition.sql)
-
- case p @ Distinct(u: Union) if u.children.length > 1 =>
- val childrenSql = u.children.map(c => s"(${toSQL(c)})")
- childrenSql.mkString(" UNION DISTINCT ")
-
- case p: Union if p.children.length > 1 =>
- val childrenSql = p.children.map(c => s"(${toSQL(c)})")
- childrenSql.mkString(" UNION ALL ")
-
- case p: Intersect =>
- build("(" + toSQL(p.left), ") INTERSECT (", toSQL(p.right) + ")")
-
- case p: Except =>
- build("(" + toSQL(p.left), ") EXCEPT (", toSQL(p.right) + ")")
-
- case p: SubqueryAlias => build("(" + toSQL(p.child) + ")", "AS", p.alias)
-
- case p: Join =>
- build(
- toSQL(p.left),
- p.joinType.sql,
- "JOIN",
- toSQL(p.right),
- p.condition.map(" ON " + _.sql).getOrElse(""))
-
- case SQLTable(database, table, _, sample) =>
- val qualifiedName = s"${quoteIdentifier(database)}.${quoteIdentifier(table)}"
- sample.map { case (lowerBound, upperBound) =>
- val fraction = math.min(100, math.max(0, (upperBound - lowerBound) * 100))
- qualifiedName + " TABLESAMPLE(" + fraction + " PERCENT)"
- }.getOrElse(qualifiedName)
-
- case relation: CatalogRelation =>
- val m = relation.catalogTable
- val qualifiedName = s"${quoteIdentifier(m.database)}.${quoteIdentifier(m.identifier.table)}"
- qualifiedName
-
- case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _))
- if orders.map(_.child) == partitionExprs =>
- build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", "))
-
- case p: Sort =>
- build(
- toSQL(p.child),
- if (p.global) "ORDER BY" else "SORT BY",
- p.order.map(_.sql).mkString(", ")
- )
-
- case p: RepartitionByExpression =>
- build(
- toSQL(p.child),
- "DISTRIBUTE BY",
- p.partitionExpressions.map(_.sql).mkString(", ")
- )
-
- case p: ScriptTransformation =>
- scriptTransformationToSQL(p)
-
- case p: LocalRelation =>
- p.toSQL(newSubqueryName())
-
- case p: Range =>
- p.toSQL()
-
- case OneRowRelation =>
- ""
-
- case p: View =>
- toSQL(p.child)
-
- case _ =>
- throw new UnsupportedOperationException(s"unsupported plan $node")
- }
-
- /**
- * Concatenates a sequence of string segments into a single string, separating the segments with
- * single spaces. Each segment is trimmed so that only one space appears between segments.
- * For example, `build("a", " b ", " c")` becomes "a b c".
- */
- private def build(segments: String*): String =
- segments.map(_.trim).filter(_.nonEmpty).mkString(" ")
-
- private def projectToSQL(plan: Project, isDistinct: Boolean): String = {
- build(
- "SELECT",
- if (isDistinct) "DISTINCT" else "",
- plan.projectList.map(_.sql).mkString(", "),
- if (plan.child == OneRowRelation) "" else "FROM",
- toSQL(plan.child)
- )
- }
-
- private def scriptTransformationToSQL(plan: ScriptTransformation): String = {
- val inputRowFormatSQL = plan.ioschema.inputRowFormatSQL.getOrElse(
- throw new UnsupportedOperationException(
- s"unsupported row format ${plan.ioschema.inputRowFormat}"))
- val outputRowFormatSQL = plan.ioschema.outputRowFormatSQL.getOrElse(
- throw new UnsupportedOperationException(
- s"unsupported row format ${plan.ioschema.outputRowFormat}"))
-
- val outputSchema = plan.output.map { attr =>
- s"${attr.sql} ${attr.dataType.simpleString}"
- }.mkString(", ")
-
- build(
- "SELECT TRANSFORM",
- "(" + plan.input.map(_.sql).mkString(", ") + ")",
- inputRowFormatSQL,
- s"USING \'${plan.script}\'",
- "AS (" + outputSchema + ")",
- outputRowFormatSQL,
- if (plan.child == OneRowRelation) "" else "FROM",
- toSQL(plan.child)
- )
- }
-
- private def aggregateToSQL(plan: Aggregate): String = {
- val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ")
- build(
- "SELECT",
- plan.aggregateExpressions.map(_.sql).mkString(", "),
- if (plan.child == OneRowRelation) "" else "FROM",
- toSQL(plan.child),
- if (groupingSQL.isEmpty) "" else "GROUP BY",
- groupingSQL
- )
- }
-
- private def generateToSQL(g: Generate): String = {
- val columnAliases = g.generatorOutput.map(_.sql).mkString(", ")
-
- val childSQL = if (g.child == OneRowRelation) {
- // This only happens when we put UDTF in project list and there is no FROM clause. Because we
- // always generate LATERAL VIEW for `Generate`, here we use a trick to put a dummy sub-query
- // after FROM clause, so that we can generate a valid LATERAL VIEW SQL string.
- // For example, if the original SQL is: "SELECT EXPLODE(ARRAY(1, 2))", we will convert it into
- // LATERAL VIEW format, and generate:
- // SELECT col FROM (SELECT 1) sub_q0 LATERAL VIEW EXPLODE(ARRAY(1, 2)) sub_q1 AS col
- s"(SELECT 1) ${newSubqueryName()}"
- } else {
- toSQL(g.child)
- }
-
- // The final SQL string for Generate contains 7 parts:
- // 1. the SQL of child, can be a table or sub-query
- // 2. the LATERAL VIEW keyword
- // 3. an optional OUTER keyword
- // 4. the SQL of generator, e.g. EXPLODE(array_col)
- // 5. the table alias for output columns of generator.
- // 6. the AS keyword
- // 7. the column alias, can be more than one, e.g. AS key, value
- // A concrete example: "tbl LATERAL VIEW EXPLODE(map_col) sub_q AS key, value", and the builder
- // will put it in FROM clause later.
- build(
- childSQL,
- "LATERAL VIEW",
- if (g.outer) "OUTER" else "",
- g.generator.sql,
- newSubqueryName(),
- "AS",
- columnAliases
- )
- }
-
- private def sameOutput(output1: Seq[Attribute], output2: Seq[Attribute]): Boolean =
- output1.size == output2.size &&
- output1.zip(output2).forall(pair => pair._1.semanticEquals(pair._2))
-
- private def isGroupingSet(a: Aggregate, e: Expand, p: Project): Boolean = {
- assert(a.child == e && e.child == p)
- a.groupingExpressions.forall(_.isInstanceOf[Attribute]) && sameOutput(
- e.output.drop(p.child.output.length),
- a.groupingExpressions.map(_.asInstanceOf[Attribute]))
- }
-
- private def groupingSetToSQL(agg: Aggregate, expand: Expand, project: Project): String = {
- assert(agg.groupingExpressions.length > 1)
-
- // The last column of Expand is always grouping ID
- val gid = expand.output.last
-
- val numOriginalOutput = project.child.output.length
- // Assumption: Aggregate's groupingExpressions is composed of
- // 1) the grouping attributes
- // 2) gid, which is always the last one
- val groupByAttributes = agg.groupingExpressions.dropRight(1).map(_.asInstanceOf[Attribute])
- // Assumption: Project's projectList is composed of
- // 1) the original output (Project's child.output),
- // 2) the aliased group by expressions.
- val expandedAttributes = project.output.drop(numOriginalOutput)
- val groupByExprs = project.projectList.drop(numOriginalOutput).map(_.asInstanceOf[Alias].child)
- val groupingSQL = groupByExprs.map(_.sql).mkString(", ")
-
- // a map from group by attributes to the original group by expressions.
- val groupByAttrMap = AttributeMap(groupByAttributes.zip(groupByExprs))
- // a map from expanded attributes to the original group by expressions.
- val expandedAttrMap = AttributeMap(expandedAttributes.zip(groupByExprs))
-
- val groupingSet: Seq[Seq[Expression]] = expand.projections.map { project =>
- // Assumption: expand.projections is composed of
- // 1) the original output (Project's child.output),
- // 2) expanded attributes(or null literal)
- // 3) gid, which is always the last one in each project in Expand
- project.drop(numOriginalOutput).dropRight(1).collect {
- case attr: Attribute if expandedAttrMap.contains(attr) => expandedAttrMap(attr)
- }
- }
- val groupingSetSQL = "GROUPING SETS(" +
- groupingSet.map(e => s"(${e.map(_.sql).mkString(", ")})").mkString(", ") + ")"
-
- val aggExprs = agg.aggregateExpressions.map { case aggExpr =>
- val originalAggExpr = aggExpr.transformDown {
- // grouping_id() is converted to VirtualColumn.groupingIdName by Analyzer. Revert it back.
- case ar: AttributeReference if ar == gid => GroupingID(Nil)
- case ar: AttributeReference if groupByAttrMap.contains(ar) => groupByAttrMap(ar)
- case a @ Cast(BitwiseAnd(
- ShiftRight(ar: AttributeReference, Literal(value: Any, IntegerType)),
- Literal(1, IntegerType)), ByteType, _) if ar == gid =>
- // for converting an expression to its original SQL format grouping(col)
- val idx = groupByExprs.length - 1 - value.asInstanceOf[Int]
- groupByExprs.lift(idx).map(Grouping).getOrElse(a)
- }
-
- originalAggExpr match {
- // Ancestor operators may reference the output of this grouping set, and we use exprId to
- // generate a unique name for each attribute, so we should make sure the transformed
- // aggregate expression won't change the output, i.e. exprId and alias name should remain
- // the same.
- case ne: NamedExpression if ne.exprId == aggExpr.exprId => ne
- case e => Alias(e, normalizedName(aggExpr))(exprId = aggExpr.exprId)
- }
- }
-
- build(
- "SELECT",
- aggExprs.map(_.sql).mkString(", "),
- if (agg.child == OneRowRelation) "" else "FROM",
- toSQL(project.child),
- "GROUP BY",
- groupingSQL,
- groupingSetSQL
- )
- }
-
- private def windowToSQL(w: Window): String = {
- build(
- "SELECT",
- (w.child.output ++ w.windowExpressions).map(_.sql).mkString(", "),
- if (w.child == OneRowRelation) "" else "FROM",
- toSQL(w.child)
- )
- }
-
- object Canonicalizer extends RuleExecutor[LogicalPlan] {
- override protected def batches: Seq[Batch] = Seq(
- Batch("Prepare", FixedPoint(100),
- // The `WidenSetOperationTypes` analysis rule may introduce extra `Project`s over
- // `Aggregate`s to perform type casting. This rule merges these `Project`s into
- // `Aggregate`s.
- CollapseProject,
- // Parser is unable to parse the following query:
- // SELECT `u_1`.`id`
- // FROM (((SELECT `t0`.`id` FROM `default`.`t0`)
- // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`))
- // UNION ALL (SELECT `t0`.`id` FROM `default`.`t0`)) AS u_1
- // This rule combines adjacent Unions so we can generate a flat UNION ALL SQL string.
- CombineUnions),
- Batch("Recover Scoping Info", Once,
- // A logical plan is allowed to have same-name outputs with different qualifiers(e.g. the
- // `Join` operator). However, this kind of plan can't be put under a sub query as we will
- // erase and assign a new qualifier to all outputs and make it impossible to distinguish
- // same-name outputs. This rule renames all attributes, to guarantee different
- // attributes(with different exprId) always have different names. It also removes all
- // qualifiers, as attributes have unique names now and we don't need qualifiers to resolve
- // ambiguity.
- NormalizedAttribute,
- // Our analyzer will add one or more sub-queries above a table relation; this rule removes
- // these sub-queries so that the next rule can combine adjacent table relations and samples
- // into SQLTable.
- RemoveSubqueriesAboveSQLTable,
- // Finds the table relations and wraps them with `SQLTable`s. If there are any `Sample`
- // operators on top of a table relation, merges the sample information into the `SQLTable` of
- // that table relation, as only table samples can be converted to standard SQL strings.
- ResolveSQLTable,
- // Insert sub queries on top of operators that need to appear after FROM clause.
- AddSubquery,
- // Reconstruct subquery expressions.
- ConstructSubqueryExpressions
- )
- )
-
- object NormalizedAttribute extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan.transformAllExpressions {
- case a: AttributeReference =>
- AttributeReference(normalizedName(a), a.dataType)(exprId = a.exprId, qualifier = None)
- case a: Alias =>
- Alias(a.child, normalizedName(a))(exprId = a.exprId, qualifier = None)
- }
- }
-
- object RemoveSubqueriesAboveSQLTable extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
- case SubqueryAlias(_, t @ ExtractSQLTable(_), _) => t
- }
- }
-
- object ResolveSQLTable extends Rule[LogicalPlan] {
- override def apply(plan: LogicalPlan): LogicalPlan = plan.transformDown {
- case Sample(lowerBound, upperBound, _, _, ExtractSQLTable(table)) =>
- aliasColumns(table.withSample(lowerBound, upperBound))
- case ExtractSQLTable(table) =>
- aliasColumns(table)
- }
-
- /**
- * Aliases the table columns to the generated attribute names, as we use exprId to generate a
- * unique name for each attribute when normalizing attributes, and we can't reference table
- * columns with their real names.
- */
- private def aliasColumns(table: SQLTable): LogicalPlan = {
- val aliasedOutput = table.output.map { attr =>
- Alias(attr, normalizedName(attr))(exprId = attr.exprId)
- }
- addSubquery(Project(aliasedOutput, table))
- }
- }
-
- object AddSubquery extends Rule[LogicalPlan] {
- override def apply(tree: LogicalPlan): LogicalPlan = tree transformUp {
- // This branch handles aggregate functions within HAVING clauses. For example:
- //
- // SELECT key FROM src GROUP BY key HAVING max(value) > "val_255"
- //
- // This kind of query results in query plans of the following form because of analysis rule
- // `ResolveAggregateFunctions`:
- //
- // Project ...
- // +- Filter ...
- // +- Aggregate ...
- // +- MetastoreRelation default, src, None
- case p @ Project(_, f @ Filter(_, _: Aggregate)) => p.copy(child = addSubquery(f))
-
- case w @ Window(_, _, _, f @ Filter(_, _: Aggregate)) => w.copy(child = addSubquery(f))
-
- case p: Project => p.copy(child = addSubqueryIfNeeded(p.child))
-
- // We will generate "SELECT ... FROM ..." for the Window operator, so its child operator should
- // be one that can be put in the FROM clause; otherwise we wrap it with a sub-query.
- case w: Window => w.copy(child = addSubqueryIfNeeded(w.child))
-
- case j: Join => j.copy(
- left = addSubqueryIfNeeded(j.left),
- right = addSubqueryIfNeeded(j.right))
-
- // A special case for Generate. When we put UDTF in project list, followed by WHERE, e.g.
- // SELECT EXPLODE(arr) FROM tbl WHERE id > 1, the Filter operator will be under Generate
- // operator and we need to add a sub-query between them, as it's not allowed to have a WHERE
- // before LATERAL VIEW, e.g. "... FROM tbl WHERE id > 2 EXPLODE(arr) ..." is illegal.
- case g @ Generate(_, _, _, _, _, f: Filter) =>
- // Add an extra `Project` to make sure we can generate legal SQL string for sub-query,
- // for example, Subquery -> Filter -> Table will generate "(tbl WHERE ...) AS name", which
- // misses the SELECT part.
- val proj = Project(f.output, f)
- g.copy(child = addSubquery(proj))
- }
- }
-
- object ConstructSubqueryExpressions extends Rule[LogicalPlan] {
- def apply(tree: LogicalPlan): LogicalPlan = tree transformAllExpressions {
- case ScalarSubquery(query, conditions, exprId) if conditions.nonEmpty =>
- def rewriteAggregate(a: Aggregate): Aggregate = {
- val filter = Filter(conditions.reduce(And), addSubqueryIfNeeded(a.child))
- Aggregate(Nil, a.aggregateExpressions.take(1), filter)
- }
- val cleaned = query match {
- case Project(_, child) => child
- case child => child
- }
- val rewrite = cleaned match {
- case a: Aggregate =>
- rewriteAggregate(a)
- case Filter(c, a: Aggregate) =>
- Filter(c, rewriteAggregate(a))
- }
- ScalarSubquery(rewrite, Seq.empty, exprId)
-
- case PredicateSubquery(query, conditions, false, exprId) =>
- val subquery = addSubqueryIfNeeded(query)
- val plan = if (conditions.isEmpty) {
- subquery
- } else {
- Project(Seq(Alias(Literal(1), "1")()),
- Filter(conditions.reduce(And), subquery))
- }
- Exists(plan, exprId)
-
- case PredicateSubquery(query, conditions, true, exprId) =>
- val (in, correlated) = conditions.partition(_.isInstanceOf[EqualTo])
- val (outer, inner) = in.zipWithIndex.map {
- case (EqualTo(l, r), i) if query.outputSet.intersect(r.references).nonEmpty =>
- (l, Alias(r, s"_c$i")())
- case (EqualTo(r, l), i) =>
- (l, Alias(r, s"_c$i")())
- }.unzip
- val wrapped = addSubqueryIfNeeded(query)
- val filtered = if (correlated.nonEmpty) {
- Filter(conditions.reduce(And), wrapped)
- } else {
- wrapped
- }
- val value = outer match {
- case Seq(expr) => expr
- case exprs => CreateStruct(exprs)
- }
- In(value, Seq(ListQuery(Project(inner, filtered), exprId)))
- }
- }
-
- private def addSubquery(plan: LogicalPlan): SubqueryAlias = {
- SubqueryAlias(newSubqueryName(), plan, None)
- }
-
- private def addSubqueryIfNeeded(plan: LogicalPlan): LogicalPlan = plan match {
- case _: SubqueryAlias => plan
- case _: Filter => plan
- case _: Join => plan
- case _: LocalLimit => plan
- case _: GlobalLimit => plan
- case _: SQLTable => plan
- case _: Generate => plan
- case OneRowRelation => plan
- case _ => addSubquery(plan)
- }
- }
-
- case class SQLTable(
- database: String,
- table: String,
- output: Seq[Attribute],
- sample: Option[(Double, Double)] = None) extends LeafNode {
- def withSample(lowerBound: Double, upperBound: Double): SQLTable =
- this.copy(sample = Some(lowerBound -> upperBound))
- }
-
- object ExtractSQLTable {
- def unapply(plan: LogicalPlan): Option[SQLTable] = plan match {
- case l @ LogicalRelation(_, _, Some(catalogTable))
- if catalogTable.identifier.database.isDefined =>
- Some(SQLTable(
- catalogTable.identifier.database.get,
- catalogTable.identifier.table,
- l.output.map(_.withQualifier(None))))
-
- case relation: CatalogRelation =>
- val m = relation.catalogTable
- Some(SQLTable(m.database, m.identifier.table, relation.output.map(_.withQualifier(None))))
-
- case _ => None
- }
- }
-
- /**
- * A placeholder for the generated SQL of a subquery expression.
- */
- case class SubqueryHolder(override val sql: String) extends LeafExpression with Unevaluable {
- override def dataType: DataType = NullType
- override def nullable: Boolean = true
- }
-}
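
For reference, this is how the deleted builder was typically driven; a minimal sketch assuming an active SparkSession `spark` with a table `t1(key, value)` registered, mirroring the round-trip checks in the test suites removed below:

    import org.apache.spark.sql.catalyst.SQLBuilder

    // Analyze a query, convert the resolved plan back into a SQL string, and re-run that string.
    val df = spark.sql("SELECT key, count(value) FROM t1 GROUP BY key")
    // Throws UnsupportedOperationException if the plan contains expressions or operators
    // without a SQL representation (e.g. Scala UDFs).
    val regenerated = new SQLBuilder(df).toSQL
    val roundTripped = spark.sql(regenerated)  // expected to return the same rows as df
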
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
deleted file mode 100644
index 1daa6f7822..0000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/ExpressionToSQLSuite.scala
+++ /dev/null
@@ -1,290 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import scala.util.control.NonFatal
-
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SQLTestUtils
-
-class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
- import testImplicits._
-
- protected override def beforeAll(): Unit = {
- super.beforeAll()
- sql("DROP TABLE IF EXISTS t0")
- sql("DROP TABLE IF EXISTS t1")
- sql("DROP TABLE IF EXISTS t2")
-
- val bytes = Array[Byte](1, 2, 3, 4)
- Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0")
-
- spark
- .range(10)
- .select('id as 'key, concat(lit("val_"), 'id) as 'value)
- .write
- .saveAsTable("t1")
-
- spark.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
- }
-
- override protected def afterAll(): Unit = {
- try {
- sql("DROP TABLE IF EXISTS t0")
- sql("DROP TABLE IF EXISTS t1")
- sql("DROP TABLE IF EXISTS t2")
- } finally {
- super.afterAll()
- }
- }
-
- private def checkSqlGeneration(hiveQl: String): Unit = {
- val df = sql(hiveQl)
-
- val convertedSQL = try new SQLBuilder(df).toSQL catch {
- case NonFatal(e) =>
- fail(
- s"""Cannot convert the following HiveQL query plan back to SQL query string:
- |
- |# Original HiveQL query string:
- |$hiveQl
- |
- |# Resolved query plan:
- |${df.queryExecution.analyzed.treeString}
- """.stripMargin)
- }
-
- try {
- checkAnswer(sql(convertedSQL), df)
- } catch { case cause: Throwable =>
- fail(
- s"""Failed to execute converted SQL string or got wrong answer:
- |
- |# Converted SQL query string:
- |$convertedSQL
- |
- |# Original HiveQL query string:
- |$hiveQl
- |
- |# Resolved query plan:
- |${df.queryExecution.analyzed.treeString}
- """.stripMargin,
- cause)
- }
- }
-
- test("misc non-aggregate functions") {
- checkSqlGeneration("SELECT abs(15), abs(-15)")
- checkSqlGeneration("SELECT array(1,2,3)")
- checkSqlGeneration("SELECT coalesce(null, 1, 2)")
- checkSqlGeneration("SELECT explode(array(1,2,3))")
- checkSqlGeneration("SELECT explode_outer(array())")
- checkSqlGeneration("SELECT greatest(1,null,3)")
- checkSqlGeneration("SELECT if(1==2, 'yes', 'no')")
- checkSqlGeneration("SELECT isnan(15), isnan('invalid')")
- checkSqlGeneration("SELECT isnull(null), isnull('a')")
- checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')")
- checkSqlGeneration("SELECT least(1,null,3)")
- checkSqlGeneration("SELECT map(1, 'a', 2, 'b')")
- checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
- checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
- checkSqlGeneration("SELECT posexplode_outer(array())")
- checkSqlGeneration("SELECT inline_outer(array(struct('a', 1)))")
- checkSqlGeneration("SELECT rand(1)")
- checkSqlGeneration("SELECT randn(3)")
- checkSqlGeneration("SELECT struct(1,2,3)")
- }
-
- test("math functions") {
- checkSqlGeneration("SELECT acos(-1)")
- checkSqlGeneration("SELECT asin(-1)")
- checkSqlGeneration("SELECT atan(1)")
- checkSqlGeneration("SELECT atan2(1, 1)")
- checkSqlGeneration("SELECT bin(10)")
- checkSqlGeneration("SELECT cbrt(1000.0)")
- checkSqlGeneration("SELECT ceil(2.333)")
- checkSqlGeneration("SELECT ceiling(2.333)")
- checkSqlGeneration("SELECT cos(1.0)")
- checkSqlGeneration("SELECT cosh(1.0)")
- checkSqlGeneration("SELECT conv(15, 10, 16)")
- checkSqlGeneration("SELECT degrees(pi())")
- checkSqlGeneration("SELECT e()")
- checkSqlGeneration("SELECT exp(1.0)")
- checkSqlGeneration("SELECT expm1(1.0)")
- checkSqlGeneration("SELECT floor(-2.333)")
- checkSqlGeneration("SELECT factorial(5)")
- checkSqlGeneration("SELECT hex(10)")
- checkSqlGeneration("SELECT hypot(3, 4)")
- checkSqlGeneration("SELECT log(10.0)")
- checkSqlGeneration("SELECT log10(1000.0)")
- checkSqlGeneration("SELECT log1p(0.0)")
- checkSqlGeneration("SELECT log2(8.0)")
- checkSqlGeneration("SELECT ln(10.0)")
- checkSqlGeneration("SELECT negative(-1)")
- checkSqlGeneration("SELECT pi()")
- checkSqlGeneration("SELECT pmod(3, 2)")
- checkSqlGeneration("SELECT positive(3)")
- checkSqlGeneration("SELECT pow(2, 3)")
- checkSqlGeneration("SELECT power(2, 3)")
- checkSqlGeneration("SELECT radians(180.0)")
- checkSqlGeneration("SELECT rint(1.63)")
- checkSqlGeneration("SELECT round(31.415, -1)")
- checkSqlGeneration("SELECT shiftleft(2, 3)")
- checkSqlGeneration("SELECT shiftright(16, 3)")
- checkSqlGeneration("SELECT shiftrightunsigned(16, 3)")
- checkSqlGeneration("SELECT sign(-2.63)")
- checkSqlGeneration("SELECT signum(-2.63)")
- checkSqlGeneration("SELECT sin(1.0)")
- checkSqlGeneration("SELECT sinh(1.0)")
- checkSqlGeneration("SELECT sqrt(100.0)")
- checkSqlGeneration("SELECT tan(1.0)")
- checkSqlGeneration("SELECT tanh(1.0)")
- }
-
- test("aggregate functions") {
- checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT percentile_approx(value, 0.25) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT percentile_approx(value, array(0.25, 0.75)) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT percentile_approx(value, 0.25, 100) FROM t1 GROUP BY key")
- checkSqlGeneration(
- "SELECT percentile_approx(value, array(0.25, 0.75), 100) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT covar_pop(value, key) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT covar_samp(value, key) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT first(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT first_value(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT kurtosis(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT last(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT last_value(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT percentile(value, 0.25) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT percentile(value, array(0.25, 0.75)) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT stddev_samp(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT sum(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT variance(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT var_pop(value) FROM t1 GROUP BY key")
- checkSqlGeneration("SELECT var_samp(value) FROM t1 GROUP BY key")
- }
-
- test("string functions") {
- checkSqlGeneration("SELECT ascii('SparkSql')")
- checkSqlGeneration("SELECT base64(a) FROM t0")
- checkSqlGeneration("SELECT concat('This ', 'is ', 'a ', 'test')")
- checkSqlGeneration("SELECT concat_ws(' ', 'This', 'is', 'a', 'test')")
- checkSqlGeneration("SELECT decode(a, 'UTF-8') FROM t0")
- checkSqlGeneration("SELECT encode('SparkSql', 'UTF-8')")
- checkSqlGeneration("SELECT find_in_set('ab', 'abc,b,ab,c,def')")
- checkSqlGeneration("SELECT format_number(1234567.890, 2)")
- checkSqlGeneration("SELECT format_string('aa%d%s',123, 'cc')")
- checkSqlGeneration("SELECT get_json_object('{\"a\":\"bc\"}','$.a')")
- checkSqlGeneration("SELECT initcap('This is a test')")
- checkSqlGeneration("SELECT instr('This is a test', 'is')")
- checkSqlGeneration("SELECT lcase('SparkSql')")
- checkSqlGeneration("SELECT length('This is a test')")
- checkSqlGeneration("SELECT levenshtein('This is a test', 'Another test')")
- checkSqlGeneration("SELECT lower('SparkSql')")
- checkSqlGeneration("SELECT locate('is', 'This is a test', 3)")
- checkSqlGeneration("SELECT lpad('SparkSql', 16, 'Learning')")
- checkSqlGeneration("SELECT ltrim(' SparkSql ')")
- checkSqlGeneration("SELECT json_tuple('{\"f1\": \"value1\", \"f2\": \"value2\"}','f1')")
- checkSqlGeneration("SELECT printf('aa%d%s', 123, 'cc')")
- checkSqlGeneration("SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1)")
- checkSqlGeneration("SELECT regexp_replace('100-200', '(\\d+)', 'num')")
- checkSqlGeneration("SELECT repeat('SparkSql', 3)")
- checkSqlGeneration("SELECT reverse('SparkSql')")
- checkSqlGeneration("SELECT rpad('SparkSql', 16, ' is Cool')")
- checkSqlGeneration("SELECT rtrim(' SparkSql ')")
- checkSqlGeneration("SELECT soundex('SparkSql')")
- checkSqlGeneration("SELECT space(2)")
- checkSqlGeneration("SELECT split('aa2bb3cc', '[1-9]+')")
- checkSqlGeneration("SELECT space(2)")
- checkSqlGeneration("SELECT substr('This is a test', 1)")
- checkSqlGeneration("SELECT substring('This is a test', 1)")
- checkSqlGeneration("SELECT substring_index('www.apache.org','.',1)")
- checkSqlGeneration("SELECT translate('translate', 'rnlt', '123')")
- checkSqlGeneration("SELECT trim(' SparkSql ')")
- checkSqlGeneration("SELECT ucase('SparkSql')")
- checkSqlGeneration("SELECT unbase64('SparkSql')")
- checkSqlGeneration("SELECT unhex(41)")
- checkSqlGeneration("SELECT upper('SparkSql')")
- }
-
- test("datetime functions") {
- checkSqlGeneration("SELECT add_months('2001-03-31', 1)")
- checkSqlGeneration("SELECT count(current_date())")
- checkSqlGeneration("SELECT count(current_timestamp())")
- checkSqlGeneration("SELECT datediff('2001-01-02', '2001-01-01')")
- checkSqlGeneration("SELECT date_add('2001-01-02', 1)")
- checkSqlGeneration("SELECT date_format('2001-05-02', 'yyyy-dd')")
- checkSqlGeneration("SELECT date_sub('2001-01-02', 1)")
- checkSqlGeneration("SELECT day('2001-05-02')")
- checkSqlGeneration("SELECT dayofyear('2001-05-02')")
- checkSqlGeneration("SELECT dayofmonth('2001-05-02')")
- checkSqlGeneration("SELECT from_unixtime(1000, 'yyyy-MM-dd HH:mm:ss')")
- checkSqlGeneration("SELECT from_utc_timestamp('2015-07-24 00:00:00', 'PST')")
- checkSqlGeneration("SELECT hour('11:35:55')")
- checkSqlGeneration("SELECT last_day('2001-01-01')")
- checkSqlGeneration("SELECT minute('11:35:55')")
- checkSqlGeneration("SELECT month('2001-05-02')")
- checkSqlGeneration("SELECT months_between('2001-10-30 10:30:00', '1996-10-30')")
- checkSqlGeneration("SELECT next_day('2001-05-02', 'TU')")
- checkSqlGeneration("SELECT count(now())")
- checkSqlGeneration("SELECT quarter('2001-05-02')")
- checkSqlGeneration("SELECT second('11:35:55')")
- checkSqlGeneration("SELECT to_timestamp('2001-10-30 10:30:00', 'yyyy-MM-dd HH:mm:ss')")
- checkSqlGeneration("SELECT to_date('2001-10-30 10:30:00')")
- checkSqlGeneration("SELECT to_unix_timestamp('2015-07-24 00:00:00', 'yyyy-MM-dd HH:mm:ss')")
- checkSqlGeneration("SELECT to_utc_timestamp('2015-07-24 00:00:00', 'PST')")
- checkSqlGeneration("SELECT trunc('2001-10-30 10:30:00', 'YEAR')")
- checkSqlGeneration("SELECT unix_timestamp('2001-10-30 10:30:00')")
- checkSqlGeneration("SELECT weekofyear('2001-05-02')")
- checkSqlGeneration("SELECT year('2001-05-02')")
-
- checkSqlGeneration("SELECT interval 3 years - 3 month 7 week 123 microseconds as i")
- }
-
- test("collection functions") {
- checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)")
- checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))")
- checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))")
- }
-
- test("misc functions") {
- checkSqlGeneration("SELECT crc32('Spark')")
- checkSqlGeneration("SELECT md5('Spark')")
- checkSqlGeneration("SELECT hash('Spark')")
- checkSqlGeneration("SELECT sha('Spark')")
- checkSqlGeneration("SELECT sha1('Spark')")
- checkSqlGeneration("SELECT sha2('Spark', 0)")
- checkSqlGeneration("SELECT spark_partition_id()")
- checkSqlGeneration("SELECT input_file_name()")
- checkSqlGeneration("SELECT monotonically_increasing_id()")
- }
-
- test("subquery") {
- checkSqlGeneration("SELECT 1 + (SELECT 2)")
- checkSqlGeneration("SELECT 1 + (SELECT 2 + (SELECT 3 as a))")
- }
-}
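
The suite above funneled every case through a single helper; a condensed sketch of that checkSqlGeneration round-trip, with the error-reporting boilerplate stripped and assuming `sql` and `checkAnswer` are in scope as in the deleted SQLBuilderTest base class:

    private def checkSqlGeneration(hiveQl: String): Unit = {
      val df = sql(hiveQl)                         // run the original HiveQL query
      val convertedSQL = new SQLBuilder(df).toSQL  // rebuild a SQL string from the analyzed plan
      checkAnswer(sql(convertedSQL), df)           // the regenerated query must return the same rows
    }
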
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
deleted file mode 100644
index fe171a6ee8..0000000000
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/LogicalPlanToSQLSuite.scala
+++ /dev/null
@@ -1,1201 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql.catalyst
-
-import java.nio.charset.StandardCharsets
-import java.nio.file.{Files, NoSuchFileException, Paths}
-
-import scala.io.Source
-import scala.util.control.NonFatal
-
-import org.apache.spark.TestUtils
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.parser.ParseException
-import org.apache.spark.sql.catalyst.plans.logical.LeafNode
-import org.apache.spark.sql.functions._
-import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.test.SQLTestUtils
-
-/**
- * A test suite for LogicalPlan-to-SQL conversion.
- *
- * Each query has a golden generated SQL file in test/resources/sqlgen. The test suite also has
- * built-in functionality to automatically generate these golden files.
- *
- * To re-generate golden files, run:
- * SPARK_GENERATE_GOLDEN_FILES=1 build/sbt "hive/test-only *LogicalPlanToSQLSuite"
- */
-class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
- import testImplicits._
-
- // Whether to regenerate golden answer files instead of checking against them.
- private val regenerateGoldenFiles: Boolean = System.getenv("SPARK_GENERATE_GOLDEN_FILES") == "1"
- private val goldenSQLPath = {
- // If regenerateGoldenFiles is true, we must be running this in SBT and we use hard-coded
- // relative path. Otherwise, we use classloader's getResource to find the location.
- if (regenerateGoldenFiles) {
- java.nio.file.Paths.get("src", "test", "resources", "sqlgen").toFile.getCanonicalPath
- } else {
- getTestResourcePath("sqlgen")
- }
- }
-
- protected override def beforeAll(): Unit = {
- super.beforeAll()
- (0 to 3).foreach { i =>
- sql(s"DROP TABLE IF EXISTS parquet_t$i")
- }
- sql("DROP TABLE IF EXISTS t0")
-
- spark.range(10).write.saveAsTable("parquet_t0")
- sql("CREATE TABLE t0 AS SELECT * FROM parquet_t0")
-
- spark
- .range(10)
- .select('id as 'key, concat(lit("val_"), 'id) as 'value)
- .write
- .saveAsTable("parquet_t1")
-
- spark
- .range(10)
- .select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd)
- .write
- .saveAsTable("parquet_t2")
-
- def createArray(id: Column): Column = {
- when(id % 3 === 0, lit(null)).otherwise(array('id, 'id + 1))
- }
-
- spark
- .range(10)
- .select(
- createArray('id).as("arr"),
- array(array('id), createArray('id)).as("arr2"),
- lit("""{"f1": "1", "f2": "2", "f3": 3}""").as("json"),
- 'id
- )
- .write
- .saveAsTable("parquet_t3")
- }
-
- override protected def afterAll(): Unit = {
- try {
- (0 to 3).foreach { i =>
- sql(s"DROP TABLE IF EXISTS parquet_t$i")
- }
- sql("DROP TABLE IF EXISTS t0")
- } finally {
- super.afterAll()
- }
- }
-
- /**
- * Compare the generated SQL with the expected answer string.
- */
- private def checkSQLStructure(originalSQL: String, convertedSQL: String, answerFile: String) = {
- if (answerFile != null) {
- val separator = "-" * 80
- if (regenerateGoldenFiles) {
- val path = Paths.get(s"$goldenSQLPath/$answerFile.sql")
- val header = "-- This file is automatically generated by LogicalPlanToSQLSuite."
- val answerText = s"$header\n${originalSQL.trim()}\n${separator}\n$convertedSQL\n"
- Files.write(path, answerText.getBytes(StandardCharsets.UTF_8))
- } else {
- val goldenFileName = s"sqlgen/$answerFile.sql"
- val resourceStream = getClass.getClassLoader.getResourceAsStream(goldenFileName)
- if (resourceStream == null) {
- throw new NoSuchFileException(goldenFileName)
- }
- val answerText = try {
- Source.fromInputStream(resourceStream).mkString
- } finally {
- resourceStream.close
- }
- val sqls = answerText.split(separator)
- assert(sqls.length == 2, "Golden sql files should have a separator.")
- val expectedSQL = sqls(1).trim()
- assert(convertedSQL == expectedSQL)
- }
- }
- }
-
- /**
- * 1. Checks if SQL parsing succeeds.
- * 2. Checks if SQL generation succeeds.
- * 3. Checks the generated SQL against golden files.
- * 4. Verifies the execution result stays the same.
- */
- private def checkSQL(sqlString: String, answerFile: String = null): Unit = {
- val df = sql(sqlString)
-
- val convertedSQL = try new SQLBuilder(df).toSQL catch {
- case NonFatal(e) =>
- fail(
- s"""Cannot convert the following SQL query plan back to SQL query string:
- |
- |# Original SQL query string:
- |$sqlString
- |
- |# Resolved query plan:
- |${df.queryExecution.analyzed.treeString}
- """.stripMargin, e)
- }
-
- checkSQLStructure(sqlString, convertedSQL, answerFile)
-
- try {
- checkAnswer(sql(convertedSQL), df)
- } catch { case cause: Throwable =>
- fail(
- s"""Failed to execute converted SQL string or got wrong answer:
- |
- |# Converted SQL query string:
- |$convertedSQL
- |
- |# Original SQL query string:
- |$sqlString
- |
- |# Resolved query plan:
- |${df.queryExecution.analyzed.treeString}
- """.stripMargin, cause)
- }
- }
-
- // When saving golden files, these tests should be ignored to prevent making files.
- if (!regenerateGoldenFiles) {
- test("Test should fail if the SQL query cannot be parsed") {
- val m = intercept[ParseException] {
- checkSQL("SELE", "NOT_A_FILE")
- }.getMessage
- assert(m.contains("mismatched input"))
- }
-
- test("Test should fail if the golden file cannot be found") {
- val m2 = intercept[NoSuchFileException] {
- checkSQL("SELECT 1", "NOT_A_FILE")
- }.getMessage
- assert(m2.contains("NOT_A_FILE"))
- }
-
- test("Test should fail if the SQL query cannot be regenerated") {
- case class Unsupported() extends LeafNode with MultiInstanceRelation {
- override def newInstance(): Unsupported = copy()
- override def output: Seq[Attribute] = Nil
- }
- Unsupported().createOrReplaceTempView("not_sql_gen_supported_table_so_far")
- sql("select * from not_sql_gen_supported_table_so_far")
- val m3 = intercept[org.scalatest.exceptions.TestFailedException] {
- checkSQL("select * from not_sql_gen_supported_table_so_far", "in")
- }.getMessage
- assert(m3.contains("Cannot convert the following SQL query plan back to SQL query string"))
- }
-
- test("Test should fail if the SQL query did not equal to the golden SQL") {
- val m4 = intercept[org.scalatest.exceptions.TestFailedException] {
- checkSQL("SELECT 1", "in")
- }.getMessage
- assert(m4.contains("did not equal"))
- }
- }
-
- test("range") {
- checkSQL("select * from range(100)", "range")
- checkSQL("select * from range(1, 100, 20, 10)", "range_with_splits")
- }
-
- test("in") {
- checkSQL("SELECT id FROM parquet_t0 WHERE id IN (1, 2, 3)", "in")
- }
-
- test("not in") {
- checkSQL("SELECT id FROM t0 WHERE id NOT IN (1, 2, 3)", "not_in")
- }
-
- test("not like") {
- checkSQL("SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%'", "not_like")
- }
-
- test("aggregate function in having clause") {
- checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0", "agg1")
- }
-
- test("aggregate function in order by clause") {
- checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY MAX(key)", "agg2")
- }
-
- // When there are multiple aggregate functions in the ORDER BY clause, all of them are extracted
- // into the Aggregate operator and aliased to the same name "aggOrder". This is OK for normal query
- // execution since these aliases have different expression IDs. But this introduces name collisions
- // when converting resolved plans back to SQL query strings, as expression IDs are stripped.
- test("aggregate function in order by clause with multiple order keys") {
- checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key, MAX(key)", "agg3")
- }
-
- test("order by asc nulls last") {
- checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key nulls last, MAX(key)",
- "sort_asc_nulls_last")
- }
-
- test("order by desc nulls first") {
- checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key ORDER BY key desc nulls first," +
- "MAX(key)", "sort_desc_nulls_first")
- }
-
- test("type widening in union") {
- checkSQL("SELECT id FROM parquet_t0 UNION ALL SELECT CAST(id AS INT) AS id FROM parquet_t0",
- "type_widening")
- }
-
- test("union distinct") {
- checkSQL("SELECT * FROM t0 UNION SELECT * FROM t0", "union_distinct")
- }
-
- test("three-child union") {
- checkSQL(
- """
- |SELECT id FROM parquet_t0
- |UNION ALL SELECT id FROM parquet_t0
- |UNION ALL SELECT id FROM parquet_t0
- """.stripMargin,
- "three_child_union")
- }
-
- test("intersect") {
- checkSQL("SELECT * FROM t0 INTERSECT SELECT * FROM t0", "intersect")
- }
-
- test("except") {
- checkSQL("SELECT * FROM t0 EXCEPT SELECT * FROM t0", "except")
- }
-
- test("self join") {
- checkSQL("SELECT x.key FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key", "self_join")
- }
-
- test("self join with group by") {
- checkSQL(
- "SELECT x.key, COUNT(*) FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key group by x.key",
- "self_join_with_group_by")
- }
-
- test("case") {
- checkSQL("SELECT CASE WHEN id % 2 > 0 THEN 0 WHEN id % 2 = 0 THEN 1 END FROM parquet_t0",
- "case")
- }
-
- test("case with else") {
- checkSQL("SELECT CASE WHEN id % 2 > 0 THEN 0 ELSE 1 END FROM parquet_t0", "case_with_else")
- }
-
- test("case with key") {
- checkSQL("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' END FROM parquet_t0",
- "case_with_key")
- }
-
- test("case with key and else") {
- checkSQL("SELECT CASE id WHEN 0 THEN 'foo' WHEN 1 THEN 'bar' ELSE 'baz' END FROM parquet_t0",
- "case_with_key_and_else")
- }
-
- test("select distinct without aggregate functions") {
- checkSQL("SELECT DISTINCT id FROM parquet_t0", "select_distinct")
- }
-
- test("rollup/cube #1") {
- // Original logical plan:
- // Aggregate [(key#17L % cast(5 as bigint))#47L,grouping__id#46],
- // [(count(1),mode=Complete,isDistinct=false) AS cnt#43L,
- // (key#17L % cast(5 as bigint))#47L AS _c1#45L,
- // grouping__id#46 AS _c2#44]
- // +- Expand [List(key#17L, value#18, (key#17L % cast(5 as bigint))#47L, 0),
- // List(key#17L, value#18, null, 1)],
- // [key#17L,value#18,(key#17L % cast(5 as bigint))#47L,grouping__id#46]
- // +- Project [key#17L,
- // value#18,
- // (key#17L % cast(5 as bigint)) AS (key#17L % cast(5 as bigint))#47L]
- // +- Subquery t1
- // +- Relation[key#17L,value#18] ParquetRelation
- // Converted SQL:
- // SELECT count( 1) AS `cnt`,
- // (`t1`.`key` % CAST(5 AS BIGINT)),
- // grouping_id() AS `_c2`
- // FROM `default`.`t1`
- // GROUP BY (`t1`.`key` % CAST(5 AS BIGINT))
- // GROUPING SETS (((`t1`.`key` % CAST(5 AS BIGINT))), ())
- checkSQL(
- "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH ROLLUP",
- "rollup_cube_1_1")
-
- checkSQL(
- "SELECT count(*) as cnt, key%5, grouping_id() FROM parquet_t1 GROUP BY key % 5 WITH CUBE",
- "rollup_cube_1_2")
- }
-
- test("rollup/cube #2") {
- checkSQL("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH ROLLUP",
- "rollup_cube_2_1")
-
- checkSQL("SELECT key, value, count(value) FROM parquet_t1 GROUP BY key, value WITH CUBE",
- "rollup_cube_2_2")
- }
-
- test("rollup/cube #3") {
- checkSQL(
- "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH ROLLUP",
- "rollup_cube_3_1")
-
- checkSQL(
- "SELECT key, count(value), grouping_id() FROM parquet_t1 GROUP BY key, value WITH CUBE",
- "rollup_cube_3_2")
- }
-
- test("rollup/cube #4") {
- checkSQL(
- s"""
- |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
- |GROUP BY key % 5, key - 5 WITH ROLLUP
- """.stripMargin,
- "rollup_cube_4_1")
-
- checkSQL(
- s"""
- |SELECT count(*) as cnt, key % 5 as k1, key - 5 as k2, grouping_id() FROM parquet_t1
- |GROUP BY key % 5, key - 5 WITH CUBE
- """.stripMargin,
- "rollup_cube_4_2")
- }
-
- test("rollup/cube #5") {
- checkSQL(
- s"""
- |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
- |FROM (SELECT key, key%2, key - 5 FROM parquet_t1) t GROUP BY key%5, key-5
- |WITH ROLLUP
- """.stripMargin,
- "rollup_cube_5_1")
-
- checkSQL(
- s"""
- |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id(key % 5, key - 5) AS k3
- |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
- |WITH CUBE
- """.stripMargin,
- "rollup_cube_5_2")
- }
-
- test("rollup/cube #6") {
- checkSQL("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b",
- "rollup_cube_6_1")
-
- checkSQL("SELECT a, b, sum(c) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b",
- "rollup_cube_6_2")
-
- checkSQL("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY ROLLUP(a, b) ORDER BY a, b",
- "rollup_cube_6_3")
-
- checkSQL("SELECT a, b, sum(a) FROM parquet_t2 GROUP BY CUBE(a, b) ORDER BY a, b",
- "rollup_cube_6_4")
-
- checkSQL("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH ROLLUP",
- "rollup_cube_6_5")
-
- checkSQL("SELECT a + b, b, sum(a - b) FROM parquet_t2 GROUP BY a + b, b WITH CUBE",
- "rollup_cube_6_6")
- }
-
- test("rollup/cube #7") {
- checkSQL("SELECT a, b, grouping_id(a, b) FROM parquet_t2 GROUP BY cube(a, b)",
- "rollup_cube_7_1")
-
- checkSQL("SELECT a, b, grouping(b) FROM parquet_t2 GROUP BY cube(a, b)",
- "rollup_cube_7_2")
-
- checkSQL("SELECT a, b, grouping(a) FROM parquet_t2 GROUP BY cube(a, b)",
- "rollup_cube_7_3")
- }
-
- test("rollup/cube #8") {
- // grouping_id() is part of another expression
- checkSQL(
- s"""
- |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
- |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
- |WITH ROLLUP
- """.stripMargin,
- "rollup_cube_8_1")
-
- checkSQL(
- s"""
- |SELECT hkey AS k1, value - 5 AS k2, hash(grouping_id()) AS hgid
- |FROM (SELECT hash(key) as hkey, key as value FROM parquet_t1) t GROUP BY hkey, value-5
- |WITH CUBE
- """.stripMargin,
- "rollup_cube_8_2")
- }
-
- test("rollup/cube #9") {
- // self join is used as the child node of ROLLUP/CUBE with replaced quantifiers
- checkSQL(
- s"""
- |SELECT t.key - 5, cnt, SUM(cnt)
- |FROM (SELECT x.key, COUNT(*) as cnt
- |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
- |GROUP BY cnt, t.key - 5
- |WITH ROLLUP
- """.stripMargin,
- "rollup_cube_9_1")
-
- checkSQL(
- s"""
- |SELECT t.key - 5, cnt, SUM(cnt)
- |FROM (SELECT x.key, COUNT(*) as cnt
- |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key GROUP BY x.key) t
- |GROUP BY cnt, t.key - 5
- |WITH CUBE
- """.stripMargin,
- "rollup_cube_9_2")
- }
-
- test("grouping sets #1") {
- checkSQL(
- s"""
- |SELECT count(*) AS cnt, key % 5 AS k1, key - 5 AS k2, grouping_id() AS k3
- |FROM (SELECT key, key % 2, key - 5 FROM parquet_t1) t GROUP BY key % 5, key - 5
- |GROUPING SETS (key % 5, key - 5)
- """.stripMargin,
- "grouping_sets_1")
- }
-
- test("grouping sets #2") {
- checkSQL(
- "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a, b) ORDER BY a, b",
- "grouping_sets_2_1")
-
- checkSQL(
- "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (a) ORDER BY a, b",
- "grouping_sets_2_2")
-
- checkSQL(
- "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (b) ORDER BY a, b",
- "grouping_sets_2_3")
-
- checkSQL(
- "SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b GROUPING SETS (()) ORDER BY a, b",
- "grouping_sets_2_4")
-
- checkSQL(
- s"""
- |SELECT a, b, sum(c) FROM parquet_t2 GROUP BY a, b
- |GROUPING SETS ((), (a), (a, b)) ORDER BY a, b
- """.stripMargin,
- "grouping_sets_2_5")
- }
-
- test("cluster by") {
- checkSQL("SELECT id FROM parquet_t0 CLUSTER BY id", "cluster_by")
- }
-
- test("distribute by") {
- checkSQL("SELECT id FROM parquet_t0 DISTRIBUTE BY id", "distribute_by")
- }
-
- test("distribute by with sort by") {
- checkSQL("SELECT id FROM parquet_t0 DISTRIBUTE BY id SORT BY id",
- "distribute_by_with_sort_by")
- }
-
- test("SPARK-13720: sort by after having") {
- checkSQL("SELECT COUNT(value) FROM parquet_t1 GROUP BY key HAVING MAX(key) > 0 SORT BY key",
- "sort_by_after_having")
- }
-
- test("distinct aggregation") {
- checkSQL("SELECT COUNT(DISTINCT id) FROM parquet_t0", "distinct_aggregation")
- }
-
- test("TABLESAMPLE") {
- // Project [id#2L]
- // +- Sample 0.0, 1.0, false, ...
- // +- Subquery s
- // +- Subquery parquet_t0
- // +- Relation[id#2L] ParquetRelation
- checkSQL("SELECT s.id FROM parquet_t0 TABLESAMPLE(100 PERCENT) s", "tablesample_1")
-
- // Project [id#2L]
- // +- Sample 0.0, 1.0, false, ...
- // +- Subquery parquet_t0
- // +- Relation[id#2L] ParquetRelation
- checkSQL("SELECT * FROM parquet_t0 TABLESAMPLE(100 PERCENT)", "tablesample_2")
-
- // Project [id#21L]
- // +- Sample 0.0, 1.0, false, ...
- // +- MetastoreRelation default, t0, Some(s)
- checkSQL("SELECT s.id FROM t0 TABLESAMPLE(100 PERCENT) s", "tablesample_3")
-
- // Project [id#24L]
- // +- Sample 0.0, 1.0, false, ...
- // +- MetastoreRelation default, t0, None
- checkSQL("SELECT * FROM t0 TABLESAMPLE(100 PERCENT)", "tablesample_4")
-
- // When a sampling fraction is not 100%, the returned results are random.
- // Thus, an always-false filter is added here to check if the generated plan can be successfully
- // executed.
- checkSQL("SELECT s.id FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) s WHERE 1=0", "tablesample_5")
- checkSQL("SELECT * FROM parquet_t0 TABLESAMPLE(0.1 PERCENT) WHERE 1=0", "tablesample_6")
- }
-
- test("multi-distinct columns") {
- checkSQL("SELECT a, COUNT(DISTINCT b), COUNT(DISTINCT c), SUM(d) FROM parquet_t2 GROUP BY a",
- "multi_distinct")
- }
-
- test("persisted data source relations") {
- Seq("orc", "json", "parquet").foreach { format =>
- val tableName = s"${format}_parquet_t0"
- withTable(tableName) {
- spark.range(10).write.format(format).saveAsTable(tableName)
- checkSQL(s"SELECT id FROM $tableName", s"data_source_$tableName")
- }
- }
- }
-
- test("script transformation - schemaless") {
- assume(TestUtils.testCommandAvailable("/bin/bash"))
-
- checkSQL("SELECT TRANSFORM (a, b, c, d) USING 'cat' FROM parquet_t2",
- "script_transformation_1")
- checkSQL("SELECT TRANSFORM (*) USING 'cat' FROM parquet_t2",
- "script_transformation_2")
- }
-
- test("script transformation - alias list") {
- assume(TestUtils.testCommandAvailable("/bin/bash"))
-
- checkSQL("SELECT TRANSFORM (a, b, c, d) USING 'cat' AS (d1, d2, d3, d4) FROM parquet_t2",
- "script_transformation_alias_list")
- }
-
- test("script transformation - alias list with type") {
- assume(TestUtils.testCommandAvailable("/bin/bash"))
-
- checkSQL(
- """FROM
- |(FROM parquet_t1 SELECT TRANSFORM(key, value) USING 'cat' AS (thing1 int, thing2 string)) t
- |SELECT thing1 + 1
- """.stripMargin,
- "script_transformation_alias_list_with_type")
- }
-
- test("script transformation - row format delimited clause with only one format property") {
- assume(TestUtils.testCommandAvailable("/bin/bash"))
-
- checkSQL(
- """SELECT TRANSFORM (key) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
- |USING 'cat' AS (tKey) ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
- |FROM parquet_t1
- """.stripMargin,
- "script_transformation_row_format_one")
- }
-
- test("script transformation - row format delimited clause with multiple format properties") {
- assume(TestUtils.testCommandAvailable("/bin/bash"))
-
- checkSQL(
- """SELECT TRANSFORM (key)
- |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t'
- |USING 'cat' AS (tKey)
- |ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t' LINES TERMINATED BY '\t'
- |FROM parquet_t1
- """.stripMargin,
- "script_transformation_row_format_multiple")
- }
-
- test("script transformation - row format serde clauses with SERDEPROPERTIES") {
- assume(TestUtils.testCommandAvailable("/bin/bash"))
-
- checkSQL(
- """SELECT TRANSFORM (key, value)
- |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
- |WITH SERDEPROPERTIES('field.delim' = '|')
- |USING 'cat' AS (tKey, tValue)
- |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
- |WITH SERDEPROPERTIES('field.delim' = '|')
- |FROM parquet_t1
- """.stripMargin,
- "script_transformation_row_format_serde")
- }
-
- test("script transformation - row format serde clauses without SERDEPROPERTIES") {
- assume(TestUtils.testCommandAvailable("/bin/bash"))
-
- checkSQL(
- """SELECT TRANSFORM (key, value)
- |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
- |USING 'cat' AS (tKey, tValue)
- |ROW FORMAT SERDE 'org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe'
- |FROM parquet_t1
- """.stripMargin,
- "script_transformation_row_format_without_serde")
- }
-
- test("plans with non-SQL expressions") {
- spark.udf.register("foo", (_: Int) * 2)
- intercept[UnsupportedOperationException](new SQLBuilder(sql("SELECT foo(id) FROM t0")).toSQL)
- }
-
- test("named expression in column names shouldn't be quoted") {
- def checkColumnNames(query: String, expectedColNames: String*): Unit = {
- checkSQL(query)
- assert(sql(query).columns === expectedColNames)
- }
-
- // Attributes
- checkColumnNames(
- """SELECT * FROM (
- | SELECT 1 AS a, 2 AS b, 3 AS `we``ird`
- |) s
- """.stripMargin,
- "a", "b", "we`ird"
- )
-
- checkColumnNames(
- """SELECT x.a, y.a, x.b, y.b
- |FROM (SELECT 1 AS a, 2 AS b) x
- |CROSS JOIN (SELECT 1 AS a, 2 AS b) y
- |ON x.a = y.a
- """.stripMargin,
- "a", "a", "b", "b"
- )
-
- // String literal
- checkColumnNames(
- "SELECT 'foo', '\"bar\\''",
- "foo", "\"bar\'"
- )
-
- // Numeric literals (should have CAST or suffixes in column names)
- checkColumnNames(
- "SELECT 1Y, 2S, 3, 4L, 5.1, 6.1D",
- "1", "2", "3", "4", "5.1", "6.1"
- )
-
- // Aliases
- checkColumnNames(
- "SELECT 1 AS a",
- "a"
- )
-
- // Complex type extractors
- checkColumnNames(
- """SELECT
- | a.f1, b[0].f1, b.f1, c["foo"], d[0]
- |FROM (
- | SELECT
- | NAMED_STRUCT("f1", 1, "f2", "foo") AS a,
- | ARRAY(NAMED_STRUCT("f1", 1, "f2", "foo")) AS b,
- | MAP("foo", 1) AS c,
- | ARRAY(1) AS d
- |) s
- """.stripMargin,
- "f1", "b[0].f1", "f1", "c[foo]", "d[0]"
- )
- }
-
- test("window basic") {
- checkSQL("SELECT MAX(value) OVER (PARTITION BY key % 3) FROM parquet_t1", "window_basic_1")
-
- checkSQL(
- """
- |SELECT key, value, ROUND(AVG(key) OVER (), 2)
- |FROM parquet_t1 ORDER BY key
- """.stripMargin,
- "window_basic_2")
-
- checkSQL(
- """
- |SELECT value, MAX(key + 1) OVER (PARTITION BY key % 5 ORDER BY key % 7) AS max
- |FROM parquet_t1
- """.stripMargin,
- "window_basic_3")
-
- checkSQL(
- """
- |SELECT key, value, ROUND(AVG(key) OVER (), 2)
- |FROM parquet_t1 ORDER BY key nulls last
- """.stripMargin,
- "window_basic_asc_nulls_last")
-
- checkSQL(
- """
- |SELECT key, value, ROUND(AVG(key) OVER (), 2)
- |FROM parquet_t1 ORDER BY key desc nulls first
- """.stripMargin,
- "window_basic_desc_nulls_first")
- }
-
- test("multiple window functions in one expression") {
- checkSQL(
- """
- |SELECT
- | MAX(key) OVER (ORDER BY key DESC, value) / MIN(key) OVER (PARTITION BY key % 3)
- |FROM parquet_t1
- """.stripMargin)
- }
-
- test("regular expressions and window functions in one expression") {
- checkSQL("SELECT MAX(key) OVER (PARTITION BY key % 3) + key FROM parquet_t1",
- "regular_expressions_and_window")
- }
-
- test("aggregate functions and window functions in one expression") {
- checkSQL("SELECT MAX(c) + COUNT(a) OVER () FROM parquet_t2 GROUP BY a, b",
- "aggregate_functions_and_window")
- }
-
- test("window with different window specification") {
- checkSQL(
- """
- |SELECT key, value,
- |DENSE_RANK() OVER (ORDER BY key, value) AS dr,
- |MAX(value) OVER (PARTITION BY key ORDER BY key ASC) AS max
- |FROM parquet_t1
- """.stripMargin)
- }
-
- test("window with the same window specification with aggregate + having") {
- checkSQL(
- """
- |SELECT key, value,
- |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key DESC) AS max
- |FROM parquet_t1 GROUP BY key, value HAVING key > 5
- """.stripMargin,
- "window_with_the_same_window_with_agg_having")
- }
-
- test("window with the same window specification with aggregate functions") {
- checkSQL(
- """
- |SELECT key, value,
- |MAX(value) OVER (PARTITION BY key % 5 ORDER BY key) AS max
- |FROM parquet_t1 GROUP BY key, value
- """.stripMargin,
- "window_with_the_same_window_with_agg_functions")
- }
-
- test("window with the same window specification with aggregate") {
- checkSQL(
- """
- |SELECT key, value,
- |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr,
- |COUNT(key)
- |FROM parquet_t1 GROUP BY key, value
- """.stripMargin,
- "window_with_the_same_window_with_agg")
- }
-
- test("window with the same window specification without aggregate and filter") {
- checkSQL(
- """
- |SELECT key, value,
- |DENSE_RANK() OVER (DISTRIBUTE BY key SORT BY key, value) AS dr,
- |COUNT(key) OVER(DISTRIBUTE BY key SORT BY key, value) AS ca
- |FROM parquet_t1
- """.stripMargin,
- "window_with_the_same_window_with_agg_filter")
- }
-
- test("window clause") {
- checkSQL(
- """
- |SELECT key, MAX(value) OVER w1 AS MAX, MIN(value) OVER w2 AS min
- |FROM parquet_t1
- |WINDOW w1 AS (PARTITION BY key % 5 ORDER BY key), w2 AS (PARTITION BY key % 6)
- """.stripMargin)
- }
-
- test("special window functions") {
- checkSQL(
- """
- |SELECT
- | RANK() OVER w,
- | PERCENT_RANK() OVER w,
- | DENSE_RANK() OVER w,
- | ROW_NUMBER() OVER w,
- | NTILE(10) OVER w,
- | CUME_DIST() OVER w,
- | LAG(key, 2) OVER w,
- | LEAD(key, 2) OVER w
- |FROM parquet_t1
- |WINDOW w AS (PARTITION BY key % 5 ORDER BY key)
- """.stripMargin)
- }
-
- test("window with join") {
- checkSQL(
- """
- |SELECT x.key, MAX(y.key) OVER (PARTITION BY x.key % 5 ORDER BY x.key)
- |FROM parquet_t1 x JOIN parquet_t1 y ON x.key = y.key
- """.stripMargin,
- "window_with_join")
- }
-
- test("join 2 tables and aggregate function in having clause") {
- checkSQL(
- """
- |SELECT COUNT(a.value), b.KEY, a.KEY
- |FROM parquet_t1 a CROSS JOIN parquet_t1 b
- |GROUP BY a.KEY, b.KEY
- |HAVING MAX(a.KEY) > 0
- """.stripMargin,
- "join_2_tables")
- }
-
- test("generator in project list without FROM clause") {
- checkSQL("SELECT EXPLODE(ARRAY(1,2,3))", "generator_without_from_1")
- checkSQL("SELECT EXPLODE(ARRAY(1,2,3)) AS val", "generator_without_from_2")
- }
-
- test("generator in project list with non-referenced table") {
- checkSQL("SELECT EXPLODE(ARRAY(1,2,3)) FROM t0", "generator_non_referenced_table_1")
- checkSQL("SELECT EXPLODE(ARRAY(1,2,3)) AS val FROM t0", "generator_non_referenced_table_2")
- }
-
- test("generator in project list with referenced table") {
- checkSQL("SELECT EXPLODE(arr) FROM parquet_t3", "generator_referenced_table_1")
- checkSQL("SELECT EXPLODE(arr) AS val FROM parquet_t3", "generator_referenced_table_2")
- }
-
- test("generator in project list with non-UDTF expressions") {
- checkSQL("SELECT EXPLODE(arr), id FROM parquet_t3", "generator_non_udtf_1")
- checkSQL("SELECT EXPLODE(arr) AS val, id as a FROM parquet_t3", "generator_non_udtf_2")
- }
-
- test("generator in lateral view") {
- checkSQL("SELECT val, id FROM parquet_t3 LATERAL VIEW EXPLODE(arr) exp AS val",
- "generator_in_lateral_view_1")
- checkSQL("SELECT val, id FROM parquet_t3 LATERAL VIEW OUTER EXPLODE(arr) exp AS val",
- "generator_in_lateral_view_2")
- }
-
- test("generator in lateral view with ambiguous names") {
- checkSQL(
- """
- |SELECT exp.id, parquet_t3.id
- |FROM parquet_t3
- |LATERAL VIEW EXPLODE(arr) exp AS id
- """.stripMargin,
- "generator_with_ambiguous_names_1")
-
- checkSQL(
- """
- |SELECT exp.id, parquet_t3.id
- |FROM parquet_t3
- |LATERAL VIEW OUTER EXPLODE(arr) exp AS id
- """.stripMargin,
- "generator_with_ambiguous_names_2")
- }
-
- test("use JSON_TUPLE as generator") {
- checkSQL(
- """
- |SELECT c0, c1, c2
- |FROM parquet_t3
- |LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt
- """.stripMargin,
- "json_tuple_generator_1")
-
- checkSQL(
- """
- |SELECT a, b, c
- |FROM parquet_t3
- |LATERAL VIEW JSON_TUPLE(json, 'f1', 'f2', 'f3') jt AS a, b, c
- """.stripMargin,
- "json_tuple_generator_2")
- }
-
- test("nested generator in lateral view") {
- checkSQL(
- """
- |SELECT val, id
- |FROM parquet_t3
- |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array
- |LATERAL VIEW EXPLODE(nested_array) exp1 AS val
- """.stripMargin,
- "nested_generator_in_lateral_view_1")
-
- checkSQL(
- """
- |SELECT val, id
- |FROM parquet_t3
- |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array
- |LATERAL VIEW OUTER EXPLODE(nested_array) exp1 AS val
- """.stripMargin,
- "nested_generator_in_lateral_view_2")
- }
-
- test("generate with other operators") {
- checkSQL(
- """
- |SELECT EXPLODE(arr) AS val, id
- |FROM parquet_t3
- |WHERE id > 2
- |ORDER BY val, id
- |LIMIT 5
- """.stripMargin,
- "generate_with_other_1")
-
- checkSQL(
- """
- |SELECT val, id
- |FROM parquet_t3
- |LATERAL VIEW EXPLODE(arr2) exp1 AS nested_array
- |LATERAL VIEW EXPLODE(nested_array) exp1 AS val
- |WHERE val > 2
- |ORDER BY val, id
- |LIMIT 5
- """.stripMargin,
- "generate_with_other_2")
- }
-
- test("filter after subquery") {
- checkSQL("SELECT a FROM (SELECT key + 1 AS a FROM parquet_t1) t WHERE a > 5",
- "filter_after_subquery")
- }
-
- test("SPARK-14933 - select parquet table") {
- withTable("parquet_t") {
- sql("create table parquet_t stored as parquet as select 1 as c1, 'abc' as c2")
- checkSQL("select * from parquet_t", "select_parquet_table")
- }
- }
-
- test("predicate subquery") {
- withTable("t1") {
- withSQLConf(SQLConf.CROSS_JOINS_ENABLED.key -> "true") {
- sql("CREATE TABLE t1(a int)")
- checkSQL("select * from t1 b where exists (select * from t1 a)", "predicate_subquery")
- }
- }
- }
-
- test("broadcast join") {
- checkSQL(
- """
- |SELECT /*+ MAPJOIN(srcpart) */ subq.key1, z.value
- |FROM (SELECT x.key as key1, x.value as value1, y.key as key2, y.value as value2
- | FROM src1 x JOIN src y ON (x.key = y.key)) subq
- |JOIN srcpart z ON (subq.key1 = z.key and z.ds='2008-04-08' and z.hr=11)
- |ORDER BY subq.key1, z.value
- """.stripMargin,
- "broadcast_join_subquery")
- }
-
- test("subquery using single table") {
- checkSQL(
- """
- |SELECT a.k, a.c
- |FROM (SELECT b.key as k, count(1) as c
- | FROM src b
- | GROUP BY b.key) a
- |WHERE a.k >= 90
- """.stripMargin,
- "subq2")
- }
-
- test("correlated subqueries using EXISTS on where clause") {
- checkSQL(
- """
- |select *
- |from src b
- |where exists (select a.key
- | from src a
- | where b.value = a.value and a.key = b.key and a.value > 'val_9')
- """.stripMargin,
- "subquery_exists_1")
-
- checkSQL(
- """
- |select *
- |from (select *
- | from src b
- | where exists (select a.key
- | from src a
- | where b.value = a.value and a.key = b.key and a.value > 'val_9')) a
- """.stripMargin,
- "subquery_exists_2")
- }
-
- test("correlated subqueries using EXISTS on having clause") {
- checkSQL(
- """
- |select b.key, count(*)
- |from src b
- |group by b.key
- |having exists (select a.key
- | from src a
- | where a.key = b.key and a.value > 'val_9')
- """.stripMargin,
- "subquery_exists_having_1")
-
- checkSQL(
- """
- |select *
- |from (select b.key, count(*)
- | from src b
- | group by b.key
- | having exists (select a.key
- | from src a
- | where a.key = b.key and a.value > 'val_9')) a
- """.stripMargin,
- "subquery_exists_having_2")
-
- checkSQL(
- """
- |select b.key, min(b.value)
- |from src b
- |group by b.key
- |having exists (select a.key
- | from src a
- | where a.value > 'val_9' and a.value = min(b.value))
- """.stripMargin,
- "subquery_exists_having_3")
- }
-
- test("correlated subqueries using NOT EXISTS on where clause") {
- checkSQL(
- """
- |select *
- |from src b
- |where not exists (select a.key
- | from src a
- | where b.value = a.value and a.key = b.key and a.value > 'val_2')
- """.stripMargin,
- "subquery_not_exists_1")
-
- checkSQL(
- """
- |select *
- |from src b
- |where not exists (select a.key
- | from src a
- | where b.value = a.value and a.value > 'val_2')
- """.stripMargin,
- "subquery_not_exists_2")
- }
-
- test("correlated subqueries using NOT EXISTS on having clause") {
- checkSQL(
- """
- |select *
- |from src b
- |group by key, value
- |having not exists (select a.key
- | from src a
- | where b.value = a.value and a.key = b.key and a.value > 'val_12')
- """.stripMargin,
- "subquery_not_exists_having_1")
-
- checkSQL(
- """
- |select *
- |from src b
- |group by key, value
- |having not exists (select distinct a.key
- | from src a
- | where b.value = a.value and a.value > 'val_12')
- """.stripMargin,
- "subquery_not_exists_having_2")
- }
-
- test("subquery using IN on where clause") {
- checkSQL(
- """
- |SELECT key
- |FROM src
- |WHERE key in (SELECT max(key) FROM src)
- """.stripMargin,
- "subquery_in")
- }
-
- test("subquery using IN on having clause") {
- checkSQL(
- """
- |select key, count(*)
- |from src
- |group by key
- |having count(*) in (select count(*) from src s1 where s1.key = '90' group by s1.key)
- |order by key
- """.stripMargin,
- "subquery_in_having_1")
-
- checkSQL(
- """
- |select b.key, min(b.value)
- |from src b
- |group by b.key
- |having b.key in (select a.key
- | from src a
- | where a.value > 'val_9' and a.value = min(b.value))
- |order by b.key
- """.stripMargin,
- "subquery_in_having_2")
- }
-
- test("SPARK-14933 - select orc table") {
- withTable("orc_t") {
- sql("create table orc_t stored as orc as select 1 as c1, 'abc' as c2")
- checkSQL("select * from orc_t", "select_orc_table")
- }
- }
-
- test("inline tables") {
- checkSQL(
- """
- |select * from values ("one", 1), ("two", 2), ("three", null) as data(a, b) where b > 1
- """.stripMargin,
- "inline_tables")
- }
-
- test("SPARK-17750 - interval arithmetic") {
- withTable("dates") {
- sql("create table dates (ts timestamp)")
- checkSQL(
- """
- |select ts + interval 1 day, ts + interval 2 days,
- | ts - interval 1 day, ts - interval 2 days,
- | ts + interval '1' day, ts + interval '2' days,
- | ts - interval '1' day, ts - interval '2' days
- |from dates
- """.stripMargin,
- "interval_arithmetic"
- )
- }
- }
-
- test("SPARK-17982 - limit") {
- withTable("tbl") {
- sql("CREATE TABLE tbl(id INT, name STRING)")
- checkSQL(
- "SELECT * FROM (SELECT id FROM tbl LIMIT 2)",
- "limit"
- )
- }
- }
-}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
index 31755f56ec..157783abc8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/catalyst/SQLBuilderTest.scala
@@ -41,33 +41,4 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton {
""".stripMargin)
}
}
-
- protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = {
- val generatedSQL = try new SQLBuilder(plan).toSQL catch { case NonFatal(e) =>
- fail(
- s"""Cannot convert the following logical query plan to SQL:
- |
- |${plan.treeString}
- """.stripMargin)
- }
-
- try {
- assert(generatedSQL === expectedSQL)
- } catch {
- case cause: Throwable =>
- fail(
- s"""Wrong SQL generated for the following logical query plan:
- |
- |${plan.treeString}
- |
- |$cause
- """.stripMargin)
- }
-
- checkAnswer(spark.sql(generatedSQL), Dataset.ofRows(spark, plan))
- }
-
- protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = {
- checkSQL(df.queryExecution.analyzed, expectedSQL)
- }
}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
index 4772a264d6..f3151d52f2 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala
@@ -27,7 +27,6 @@ import org.scalatest.{BeforeAndAfterAll, GivenWhenThen}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.Dataset
-import org.apache.spark.sql.catalyst.SQLBuilder
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.util._
@@ -343,57 +342,8 @@ abstract class HiveComparisonTest
// Run w/ catalyst
val catalystResults = queryList.zip(hiveResults).map { case (queryString, hive) =>
- var query: TestHiveQueryExecution = null
- try {
- query = {
- val originalQuery = new TestHiveQueryExecution(
- queryString.replace("../../data", testDataPath))
- val containsCommands = originalQuery.analyzed.collectFirst {
- case _: Command => ()
- case _: InsertIntoTable => ()
- }.nonEmpty
-
- if (containsCommands) {
- originalQuery
- } else {
- val convertedSQL = try {
- new SQLBuilder(originalQuery.analyzed).toSQL
- } catch {
- case NonFatal(e) => fail(
- s"""Cannot convert the following HiveQL query plan back to SQL query string:
- |
- |# Original HiveQL query string:
- |$queryString
- |
- |# Resolved query plan:
- |${originalQuery.analyzed.treeString}
- """.stripMargin, e)
- }
-
- try {
- val queryExecution = new TestHiveQueryExecution(convertedSQL)
- // Trigger the analysis of this converted SQL query.
- queryExecution.analyzed
- queryExecution
- } catch {
- case NonFatal(e) => fail(
- s"""Failed to analyze the converted SQL string:
- |
- |# Original HiveQL query string:
- |$queryString
- |
- |# Resolved query plan:
- |${originalQuery.analyzed.treeString}
- |
- |# Converted SQL query string:
- |$convertedSQL
- """.stripMargin, e)
- }
- }
- }
-
- (query, prepareAnswer(query, query.hiveResultString()))
- } catch {
+ val query = new TestHiveQueryExecution(queryString.replace("../../data", testDataPath))
+ try { (query, prepareAnswer(query, query.hiveResultString())) } catch {
case e: Throwable =>
val errorMessage =
s"""