about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala27
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala29
3 files changed, 58 insertions, 20 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
index b1b449a0b3..7d5a46873c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala
@@ -204,20 +204,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
case None => OneRowRelation
}
+ val withLateralView = lateralViewClause.map { lv =>
+ nodeToGenerate(lv.children.head, outer = false, relations)
+ }.getOrElse(relations)
+
val withWhere = whereClause.map { whereNode =>
val Seq(whereExpr) = whereNode.children
- Filter(nodeToExpr(whereExpr), relations)
- }.getOrElse(relations)
+ Filter(nodeToExpr(whereExpr), withLateralView)
+ }.getOrElse(withLateralView)
val select = (selectClause orElse selectDistinctClause)
.getOrElse(sys.error("No select clause."))
val transformation = nodeToTransformation(select.children.head, withWhere)
- val withLateralView = lateralViewClause.map { lv =>
- nodeToGenerate(lv.children.head, outer = false, withWhere)
- }.getOrElse(withWhere)
-
// The projection of the query can either be a normal projection, an aggregation
// (if there is a group by) or a script transformation.
val withProject: LogicalPlan = transformation.getOrElse {
@@ -227,13 +227,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
groupByClause.map(e => e match {
case Token("TOK_GROUPBY", children) =>
// Not a transformation so must be either project or aggregation.
- Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView)
+ Aggregate(children.map(nodeToExpr), selectExpressions, withWhere)
case _ => sys.error("Expect GROUP BY")
}),
groupingSetsClause.map(e => e match {
case Token("TOK_GROUPING_SETS", children) =>
val(groupByExprs, masks) = extractGroupingSet(children)
- GroupingSets(masks, groupByExprs, withLateralView, selectExpressions)
+ GroupingSets(masks, groupByExprs, withWhere, selectExpressions)
case _ => sys.error("Expect GROUPING SETS")
}),
rollupGroupByClause.map(e => e match {
@@ -241,7 +241,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
Aggregate(
Seq(Rollup(children.map(nodeToExpr))),
selectExpressions,
- withLateralView)
+ withWhere)
case _ => sys.error("Expect WITH ROLLUP")
}),
cubeGroupByClause.map(e => e match {
@@ -249,10 +249,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
Aggregate(
Seq(Cube(children.map(nodeToExpr))),
selectExpressions,
- withLateralView)
+ withWhere)
case _ => sys.error("Expect WITH CUBE")
}),
- Some(Project(selectExpressions, withLateralView))).flatten.head
+ Some(Project(selectExpressions, withWhere))).flatten.head
}
// Handle HAVING clause.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 969fcdf428..ac2ca3c5a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -134,15 +134,24 @@ class Dataset[T] private[sql](
this(sqlContext, sqlContext.executePlan(logicalPlan), encoder)
}
- @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match {
- // For various commands (like DDL) and queries with side effects, we force query optimization to
- // happen right away to let these side effects take place eagerly.
- case _: Command |
- _: InsertIntoTable |
- _: CreateTableUsingAsSelect =>
- LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
- case _ =>
- queryExecution.analyzed
+ @transient protected[sql] val logicalPlan: LogicalPlan = {
+ def hasSideEffects(plan: LogicalPlan): Boolean = plan match {
+ case _: Command |
+ _: InsertIntoTable |
+ _: CreateTableUsingAsSelect => true
+ case _ => false
+ }
+
+ queryExecution.logical match {
+ // For various commands (like DDL) and queries with side effects, we force query execution
+ // to happen right away to let these side effects take place eagerly.
+ case p if hasSideEffects(p) =>
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
+ case Union(children) if children.forall(hasSideEffects) =>
+ LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext)
+ case _ =>
+ queryExecution.analyzed
+ }
}
/**
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 21dfb82876..d6c10d6ed9 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -1777,4 +1777,33 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
|FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test
""".stripMargin), Row("value1", "12", BigDecimal("3.14"), "hello"))
}
+
+ test("multi-insert with lateral view") {
+ withTempTable("t1") {
+ sqlContext.range(10)
+ .select(array($"id", $"id" + 1).as("arr"), $"id")
+ .registerTempTable("source")
+ withTable("dest1", "dest2") {
+ sql("CREATE TABLE dest1 (i INT)")
+ sql("CREATE TABLE dest2 (i INT)")
+ sql(
+ """
+ |FROM source
+ |INSERT OVERWRITE TABLE dest1
+ |SELECT id
+ |WHERE id > 3
+ |INSERT OVERWRITE TABLE dest2
+ |select col LATERAL VIEW EXPLODE(arr) exp AS col
+ |WHERE col > 3
+ """.stripMargin)
+
+ checkAnswer(
+ sqlContext.table("dest1"),
+ sql("SELECT id FROM source WHERE id > 3"))
+ checkAnswer(
+ sqlContext.table("dest2"),
+ sql("SELECT col FROM source LATERAL VIEW EXPLODE(arr) exp AS col WHERE col > 3"))
+ }
+ }
+ }
}