about | summary | refs | log | tree | commit | diff
path: root/sql/core
diff options
context:
space:
mode:
author: Herman van Hovell <hvanhovell@questtec.nl> 2016-05-10 09:56:07 -0700
committer: Davies Liu <davies.liu@gmail.com> 2016-05-10 09:56:07 -0700
commit: 2646265368aab0f0b800d3052e557dea7c40c2d6 (patch)
tree: 42950c53642f66e5c2c198a820cd7f9be015892c /sql/core
parent: 2dfb9cd1f7e7f0438ce571aae7e3a7b77d4082b7 (diff)
download: spark-2646265368aab0f0b800d3052e557dea7c40c2d6.tar.gz
          spark-2646265368aab0f0b800d3052e557dea7c40c2d6.tar.bz2
          spark-2646265368aab0f0b800d3052e557dea7c40c2d6.zip
[SPARK-14773] [SPARK-15179] [SQL] Fix SQL building and enable Hive tests
## What changes were proposed in this pull request?

This PR fixes SQL building for predicate subqueries and correlated scalar subqueries. It also enables most Hive subquery tests.

## How was this patch tested?

Enabled new tests in HiveComparisionSuite.

Author: Herman van Hovell <hvanhovell@questtec.nl>

Closes #12988 from hvanhovell/SPARK-14773.
Diffstat (limited to 'sql/core')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala | 63
1 file changed, 58 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
index 7f26a7e411..9dc367920e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala
@@ -69,8 +69,14 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging {
try {
val replaced = finalPlan.transformAllExpressions {
- case e: SubqueryExpression =>
- SubqueryHolder(new SQLBuilder(e.query).toSQL)
+ case s: SubqueryExpression =>
+ val query = new SQLBuilder(s.query).toSQL
+ val sql = s match {
+ case _: ListQuery => query
+ case _: Exists => s"EXISTS($query)"
+ case _ => s"($query)"
+ }
+ SubqueryHolder(sql)
case e: NonSQLExpression =>
throw new UnsupportedOperationException(
s"Expression $e doesn't have a SQL representation"
@@ -404,7 +410,9 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging {
// that table relation, as we can only convert table sample to standard SQL string.
ResolveSQLTable,
// Insert sub queries on top of operators that need to appear after FROM clause.
- AddSubquery
+ AddSubquery,
+ // Reconstruct subquery expressions.
+ ConstructSubqueryExpressions
)
)
@@ -484,6 +492,52 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging {
}
}
+ object ConstructSubqueryExpressions extends Rule[LogicalPlan] {
+ def apply(tree: LogicalPlan): LogicalPlan = tree transformAllExpressions {
+ case ScalarSubquery(query, conditions, exprId) if conditions.nonEmpty =>
+ def rewriteAggregate(a: Aggregate): Aggregate = {
+ val filter = Filter(conditions.reduce(And), addSubqueryIfNeeded(a.child))
+ Aggregate(Nil, a.aggregateExpressions.take(1), filter)
+ }
+ val cleaned = query match {
+ case Project(_, child) => child
+ case child => child
+ }
+ val rewrite = cleaned match {
+ case a: Aggregate =>
+ rewriteAggregate(a)
+ case Filter(c, a: Aggregate) =>
+ Filter(c, rewriteAggregate(a))
+ }
+ ScalarSubquery(rewrite, Seq.empty, exprId)
+
+ case PredicateSubquery(query, conditions, false, exprId) =>
+ val plan = Project(Seq(Alias(Literal(1), "1")()),
+ Filter(conditions.reduce(And), addSubqueryIfNeeded(query)))
+ Exists(plan, exprId)
+
+ case PredicateSubquery(query, conditions, true, exprId) =>
+ val (in, correlated) = conditions.partition(_.isInstanceOf[EqualTo])
+ val (outer, inner) = in.zipWithIndex.map {
+ case (EqualTo(l, r), i) if query.outputSet.intersect(r.references).nonEmpty =>
+ (l, Alias(r, s"_c$i")())
+ case (EqualTo(r, l), i) =>
+ (l, Alias(r, s"_c$i")())
+ }.unzip
+ val wrapped = addSubqueryIfNeeded(query)
+ val filtered = if (correlated.nonEmpty) {
+ Filter(conditions.reduce(And), wrapped)
+ } else {
+ wrapped
+ }
+ val value = outer match {
+ case Seq(expr) => expr
+ case exprs => CreateStruct(exprs)
+ }
+ In(value, Seq(ListQuery(Project(inner, filtered), exprId)))
+ }
+ }
+
private def addSubquery(plan: LogicalPlan): SubqueryAlias = {
SubqueryAlias(newSubqueryName(), plan)
}
@@ -526,9 +580,8 @@ class SQLBuilder(logicalPlan: LogicalPlan) extends Logging {
/**
* A place holder for generated SQL for subquery expression.
*/
- case class SubqueryHolder(query: String) extends LeafExpression with Unevaluable {
+ case class SubqueryHolder(override val sql: String) extends LeafExpression with Unevaluable {
override def dataType: DataType = NullType
override def nullable: Boolean = true
- override def sql: String = s"($query)"
}
}