diff options
author | Reynold Xin <rxin@databricks.com> | 2016-02-12 10:08:19 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-02-12 10:08:19 -0800 |
commit | c4d5ad80c8091c961646a82e85ecbc335b8ffe2d (patch) | |
tree | 593586addd182aef21ec026e16d7a4a33e132fe0 /sql | |
parent | 5b805df279d744543851f06e5a0d741354ef485b (diff) | |
download | spark-c4d5ad80c8091c961646a82e85ecbc335b8ffe2d.tar.gz spark-c4d5ad80c8091c961646a82e85ecbc335b8ffe2d.tar.bz2 spark-c4d5ad80c8091c961646a82e85ecbc335b8ffe2d.zip |
[SPARK-13282][SQL] LogicalPlan toSql should just return a String
Previously we were using Option[String] and None to indicate the case when Spark fails to generate SQL. It is easier to just use exceptions to propagate error cases, rather than having for comprehensions everywhere. I also introduced a "build" function that simplifies string concatenation (i.e. no need to reason about whether we have an extra space or not).
Author: Reynold Xin <rxin@databricks.com>
Closes #11171 from rxin/SPARK-13282.
Diffstat (limited to 'sql')
6 files changed, 141 insertions, 156 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala index 4b75e60f8d..d7bae913f8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/SQLBuilder.scala @@ -19,10 +19,12 @@ package org.apache.spark.sql.hive import java.util.concurrent.atomic.AtomicLong +import scala.util.control.NonFatal + import org.apache.spark.Logging import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Attribute, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.{Rule, RuleExecutor} @@ -37,16 +39,10 @@ import org.apache.spark.sql.execution.datasources.LogicalRelation class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Logging { def this(df: DataFrame) = this(df.queryExecution.analyzed, df.sqlContext) - def toSQL: Option[String] = { + def toSQL: String = { val canonicalizedPlan = Canonicalizer.execute(logicalPlan) - val maybeSQL = try { - toSQL(canonicalizedPlan) - } catch { case cause: UnsupportedOperationException => - logInfo(s"Failed to build SQL query string because: ${cause.getMessage}") - None - } - - if (maybeSQL.isDefined) { + try { + val generatedSQL = toSQL(canonicalizedPlan) logDebug( s"""Built SQL query string successfully from given logical plan: | @@ -54,10 +50,11 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi |${logicalPlan.treeString} |# Canonicalized logical plan: |${canonicalizedPlan.treeString} - |# Built SQL query string: - |${maybeSQL.get} + |# Generated SQL: + |$generatedSQL """.stripMargin) - } else { + 
generatedSQL + } catch { case NonFatal(e) => logDebug( s"""Failed to build SQL query string from given logical plan: | @@ -66,128 +63,113 @@ class SQLBuilder(logicalPlan: LogicalPlan, sqlContext: SQLContext) extends Loggi |# Canonicalized logical plan: |${canonicalizedPlan.treeString} """.stripMargin) + throw e } - - maybeSQL } - private def projectToSQL( - projectList: Seq[NamedExpression], - child: LogicalPlan, - isDistinct: Boolean): Option[String] = { - for { - childSQL <- toSQL(child) - listSQL = projectList.map(_.sql).mkString(", ") - maybeFrom = child match { - case OneRowRelation => " " - case _ => " FROM " - } - distinct = if (isDistinct) " DISTINCT " else " " - } yield s"SELECT$distinct$listSQL$maybeFrom$childSQL" - } + private def toSQL(node: LogicalPlan): String = node match { + case Distinct(p: Project) => + projectToSQL(p, isDistinct = true) - private def aggregateToSQL( - groupingExprs: Seq[Expression], - aggExprs: Seq[Expression], - child: LogicalPlan): Option[String] = { - val aggSQL = aggExprs.map(_.sql).mkString(", ") - val groupingSQL = groupingExprs.map(_.sql).mkString(", ") - val maybeGroupBy = if (groupingSQL.isEmpty) "" else " GROUP BY " - val maybeFrom = child match { - case OneRowRelation => " " - case _ => " FROM " - } + case p: Project => + projectToSQL(p, isDistinct = false) - toSQL(child).map { childSQL => - s"SELECT $aggSQL$maybeFrom$childSQL$maybeGroupBy$groupingSQL" - } - } + case p: Aggregate => + aggregateToSQL(p) - private def toSQL(node: LogicalPlan): Option[String] = node match { - case Distinct(Project(list, child)) => - projectToSQL(list, child, isDistinct = true) - - case Project(list, child) => - projectToSQL(list, child, isDistinct = false) - - case Aggregate(groupingExprs, aggExprs, child) => - aggregateToSQL(groupingExprs, aggExprs, child) - - case Limit(limit, child) => - for { - childSQL <- toSQL(child) - limitSQL = limit.sql - } yield s"$childSQL LIMIT $limitSQL" - - case Filter(condition, child) => - for { - childSQL 
<- toSQL(child) - whereOrHaving = child match { - case _: Aggregate => "HAVING" - case _ => "WHERE" - } - conditionSQL = condition.sql - } yield s"$childSQL $whereOrHaving $conditionSQL" - - case Union(children) if children.length > 1 => - val childrenSql = children.map(toSQL(_)) - if (childrenSql.exists(_.isEmpty)) { - None - } else { - Some(childrenSql.map(_.get).mkString(" UNION ALL ")) + case p: Limit => + s"${toSQL(p.child)} LIMIT ${p.limitExpr.sql}" + + case p: Filter => + val whereOrHaving = p.child match { + case _: Aggregate => "HAVING" + case _ => "WHERE" + } + build(toSQL(p.child), whereOrHaving, p.condition.sql) + + case p: Union if p.children.length > 1 => + val childrenSql = p.children.map(toSQL(_)) + childrenSql.mkString(" UNION ALL ") + + case p: Subquery => + p.child match { + // Persisted data source relation + case LogicalRelation(_, _, Some(TableIdentifier(table, Some(database)))) => + s"`$database`.`$table`" + // Parentheses is not used for persisted data source relations + // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1 + case Subquery(_, _: LogicalRelation | _: MetastoreRelation) => + build(toSQL(p.child), "AS", p.alias) + case _ => + build("(" + toSQL(p.child) + ")", "AS", p.alias) } - // Persisted data source relation - case Subquery(alias, LogicalRelation(_, _, Some(TableIdentifier(table, Some(database))))) => - Some(s"`$database`.`$table`") - - case Subquery(alias, child) => - toSQL(child).map( childSQL => - child match { - // Parentheses is not used for persisted data source relations - // e.g., select x.c1 from (t1) as x inner join (t1) as y on x.c1 = y.c1 - case Subquery(_, _: LogicalRelation | _: MetastoreRelation) => - s"$childSQL AS $alias" - case _ => - s"($childSQL) AS $alias" - }) - - case Join(left, right, joinType, condition) => - for { - leftSQL <- toSQL(left) - rightSQL <- toSQL(right) - joinTypeSQL = joinType.sql - conditionSQL = condition.map(" ON " + _.sql).getOrElse("") - } yield s"$leftSQL 
$joinTypeSQL JOIN $rightSQL$conditionSQL" - - case MetastoreRelation(database, table, alias) => - val aliasSQL = alias.map(a => s" AS `$a`").getOrElse("") - Some(s"`$database`.`$table`$aliasSQL") + case p: Join => + build( + toSQL(p.left), + p.joinType.sql, + "JOIN", + toSQL(p.right), + p.condition.map(" ON " + _.sql).getOrElse("")) + + case p: MetastoreRelation => + build( + s"`${p.databaseName}`.`${p.tableName}`", + p.alias.map(a => s" AS `$a`").getOrElse("") + ) case Sort(orders, _, RepartitionByExpression(partitionExprs, child, _)) if orders.map(_.child) == partitionExprs => - for { - childSQL <- toSQL(child) - partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") - } yield s"$childSQL CLUSTER BY $partitionExprsSQL" - - case Sort(orders, global, child) => - for { - childSQL <- toSQL(child) - ordersSQL = orders.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") - orderOrSort = if (global) "ORDER" else "SORT" - } yield s"$childSQL $orderOrSort BY $ordersSQL" - - case RepartitionByExpression(partitionExprs, child, _) => - for { - childSQL <- toSQL(child) - partitionExprsSQL = partitionExprs.map(_.sql).mkString(", ") - } yield s"$childSQL DISTRIBUTE BY $partitionExprsSQL" + build(toSQL(child), "CLUSTER BY", partitionExprs.map(_.sql).mkString(", ")) + + case p: Sort => + build( + toSQL(p.child), + if (p.global) "ORDER BY" else "SORT BY", + p.order.map { case SortOrder(e, dir) => s"${e.sql} ${dir.sql}" }.mkString(", ") + ) + + case p: RepartitionByExpression => + build( + toSQL(p.child), + "DISTRIBUTE BY", + p.partitionExpressions.map(_.sql).mkString(", ") + ) case OneRowRelation => - Some("") + "" - case _ => None + case _ => + throw new UnsupportedOperationException(s"unsupported plan $node") + } + + /** + * Turns a bunch of string segments into a single string and separate each segment by a space. + * The segments are trimmed so only a single space appears in the separation. + * For example, `build("a", " b ", " c")` becomes "a b c". 
+ */ + private def build(segments: String*): String = segments.map(_.trim).mkString(" ") + + private def projectToSQL(plan: Project, isDistinct: Boolean): String = { + build( + "SELECT", + if (isDistinct) "DISTINCT" else "", + plan.projectList.map(_.sql).mkString(", "), + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(plan.child) + ) + } + + private def aggregateToSQL(plan: Aggregate): String = { + val groupingSQL = plan.groupingExpressions.map(_.sql).mkString(", ") + build( + "SELECT", + plan.aggregateExpressions.map(_.sql).mkString(", "), + if (plan.child == OneRowRelation) "" else "FROM", + toSQL(plan.child), + if (groupingSQL.isEmpty) "" else "GROUP BY", + groupingSQL + ) } object Canonicalizer extends RuleExecutor[LogicalPlan] { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 31bda56e8a..5da58a73e1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive.execution +import scala.util.control.NonFatal + import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.expressions.Alias @@ -72,7 +74,9 @@ private[hive] case class CreateViewAsSelect( private def prepareTable(sqlContext: SQLContext): HiveTable = { val expandedText = if (sqlContext.conf.canonicalView) { - rebuildViewQueryString(sqlContext).getOrElse(wrapViewTextWithSelect) + try rebuildViewQueryString(sqlContext) catch { + case NonFatal(e) => wrapViewTextWithSelect + } } else { wrapViewTextWithSelect } @@ -112,7 +116,7 @@ private[hive] case class CreateViewAsSelect( s"SELECT $viewOutput FROM ($viewText) $viewName" } - private def rebuildViewQueryString(sqlContext: SQLContext): 
Option[String] = { + private def rebuildViewQueryString(sqlContext: SQLContext): String = { val logicalPlan = if (tableDesc.schema.isEmpty) { child } else { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala index 3a6eb57add..3fb6543b1a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionSQLBuilderSuite.scala @@ -33,8 +33,7 @@ class ExpressionSQLBuilderSuite extends SQLBuilderTest { checkSQL(Literal(1.5F), "CAST(1.5 AS FLOAT)") checkSQL(Literal(2.5D), "2.5") checkSQL( - Literal(Timestamp.valueOf("2016-01-01 00:00:00")), - "TIMESTAMP('2016-01-01 00:00:00.0')") + Literal(Timestamp.valueOf("2016-01-01 00:00:00")), "TIMESTAMP('2016-01-01 00:00:00.0')") // TODO tests for decimals } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 80ae312d91..dc8ac7e47f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import scala.util.control.NonFatal + import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SQLTestUtils @@ -46,29 +48,28 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { private def checkHiveQl(hiveQl: String): Unit = { val df = sql(hiveQl) - val convertedSQL = new SQLBuilder(df).toSQL - if (convertedSQL.isEmpty) { - fail( - s"""Cannot convert the following HiveQL query plan back to SQL query string: - | - |# Original HiveQL query string: - |$hiveQl - | - |# Resolved query plan: - |${df.queryExecution.analyzed.treeString} - """.stripMargin) + val convertedSQL = try new SQLBuilder(df).toSQL catch { 
+ case NonFatal(e) => + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin) } - val sqlString = convertedSQL.get try { - checkAnswer(sql(sqlString), df) + checkAnswer(sql(convertedSQL), df) } catch { case cause: Throwable => fail( s"""Failed to execute converted SQL string or got wrong answer: | |# Converted SQL query string: - |$sqlString + |$convertedSQL | |# Original HiveQL query string: |$hiveQl diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala index a5e209ac9d..4adc5c1116 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/SQLBuilderTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.hive +import scala.util.control.NonFatal + import org.apache.spark.sql.{DataFrame, QueryTest} import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -40,9 +42,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { } protected def checkSQL(plan: LogicalPlan, expectedSQL: String): Unit = { - val maybeSQL = new SQLBuilder(plan, hiveContext).toSQL - - if (maybeSQL.isEmpty) { + val generatedSQL = try new SQLBuilder(plan, hiveContext).toSQL catch { case NonFatal(e) => fail( s"""Cannot convert the following logical query plan to SQL: | @@ -50,10 +50,8 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { """.stripMargin) } - val actualSQL = maybeSQL.get - try { - assert(actualSQL === expectedSQL) + assert(generatedSQL === expectedSQL) } catch { case cause: Throwable => fail( @@ -65,7 +63,7 @@ abstract class SQLBuilderTest extends QueryTest with TestHiveSingleton { """.stripMargin) } - 
checkAnswer(sqlContext.sql(actualSQL), new DataFrame(sqlContext, plan)) + checkAnswer(sqlContext.sql(generatedSQL), new DataFrame(sqlContext, plan)) } protected def checkSQL(df: DataFrame, expectedSQL: String): Unit = { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 207bb814f0..af4c44e578 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -412,21 +412,22 @@ abstract class HiveComparisonTest originalQuery } else { numTotalQueries += 1 - new SQLBuilder(originalQuery.analyzed, TestHive).toSQL.map { sql => + try { + val sql = new SQLBuilder(originalQuery.analyzed, TestHive).toSQL numConvertibleQueries += 1 logInfo( s""" - |### Running SQL generation round-trip test {{{ - |${originalQuery.analyzed.treeString} - |Original SQL: - |$queryString - | - |Generated SQL: - |$sql - |}}} + |### Running SQL generation round-trip test {{{ + |${originalQuery.analyzed.treeString} + |Original SQL: + |$queryString + | + |Generated SQL: + |$sql + |}}} """.stripMargin.trim) new TestHive.QueryExecution(sql) - }.getOrElse { + } catch { case NonFatal(e) => logInfo( s""" |### Cannot convert the following logical plan back to SQL {{{ |