diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-03-16 10:52:36 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-03-16 10:52:36 -0700 |
commit | d9e8f26d0334f393e3b02d7a3b607be54a2a5efe (patch) | |
tree | 9790f502141132cf03c860411031f59cacd366ae /sql | |
parent | eacd9d8eda68260bbda7b0cd07410321dffaf428 (diff) | |
download | spark-d9e8f26d0334f393e3b02d7a3b607be54a2a5efe.tar.gz spark-d9e8f26d0334f393e3b02d7a3b607be54a2a5efe.tar.bz2 spark-d9e8f26d0334f393e3b02d7a3b607be54a2a5efe.zip |
[SPARK-13924][SQL] officially support multi-insert
## What changes were proposed in this pull request?
There is a feature of hive SQL called multi-insert. For example:
```
FROM src
INSERT OVERWRITE TABLE dest1
SELECT key + 1
INSERT OVERWRITE TABLE dest2
SELECT key WHERE key > 2
INSERT OVERWRITE TABLE dest3
SELECT col LATERAL VIEW EXPLODE(arr) exp AS col
...
```
We partially support it currently, with some limitations: 1) WHERE can't reference columns produced by LATERAL VIEW. 2) It's not executed eagerly, i.e. `sql("...multi-insert clause...")` won't execute right away the way other commands do, e.g. CREATE TABLE.
This PR removes these limitations and makes us fully support multi-insert.
## How was this patch tested?
new tests in `SQLQuerySuite`
Author: Wenchen Fan <wenchen@databricks.com>
Closes #11754 from cloud-fan/lateral-view.
Diffstat (limited to 'sql')
3 files changed, 58 insertions, 20 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala index b1b449a0b3..7d5a46873c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala @@ -204,20 +204,20 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C case None => OneRowRelation } + val withLateralView = lateralViewClause.map { lv => + nodeToGenerate(lv.children.head, outer = false, relations) + }.getOrElse(relations) + val withWhere = whereClause.map { whereNode => val Seq(whereExpr) = whereNode.children - Filter(nodeToExpr(whereExpr), relations) - }.getOrElse(relations) + Filter(nodeToExpr(whereExpr), withLateralView) + }.getOrElse(withLateralView) val select = (selectClause orElse selectDistinctClause) .getOrElse(sys.error("No select clause.")) val transformation = nodeToTransformation(select.children.head, withWhere) - val withLateralView = lateralViewClause.map { lv => - nodeToGenerate(lv.children.head, outer = false, withWhere) - }.getOrElse(withWhere) - // The projection of the query can either be a normal projection, an aggregation // (if there is a group by) or a script transformation. val withProject: LogicalPlan = transformation.getOrElse { @@ -227,13 +227,13 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C groupByClause.map(e => e match { case Token("TOK_GROUPBY", children) => // Not a transformation so must be either project or aggregation. 
- Aggregate(children.map(nodeToExpr), selectExpressions, withLateralView) + Aggregate(children.map(nodeToExpr), selectExpressions, withWhere) case _ => sys.error("Expect GROUP BY") }), groupingSetsClause.map(e => e match { case Token("TOK_GROUPING_SETS", children) => val(groupByExprs, masks) = extractGroupingSet(children) - GroupingSets(masks, groupByExprs, withLateralView, selectExpressions) + GroupingSets(masks, groupByExprs, withWhere, selectExpressions) case _ => sys.error("Expect GROUPING SETS") }), rollupGroupByClause.map(e => e match { @@ -241,7 +241,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Aggregate( Seq(Rollup(children.map(nodeToExpr))), selectExpressions, - withLateralView) + withWhere) case _ => sys.error("Expect WITH ROLLUP") }), cubeGroupByClause.map(e => e match { @@ -249,10 +249,10 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C Aggregate( Seq(Cube(children.map(nodeToExpr))), selectExpressions, - withLateralView) + withWhere) case _ => sys.error("Expect WITH CUBE") }), - Some(Project(selectExpressions, withLateralView))).flatten.head + Some(Project(selectExpressions, withWhere))).flatten.head } // Handle HAVING clause. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 969fcdf428..ac2ca3c5a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -134,15 +134,24 @@ class Dataset[T] private[sql]( this(sqlContext, sqlContext.executePlan(logicalPlan), encoder) } - @transient protected[sql] val logicalPlan: LogicalPlan = queryExecution.logical match { - // For various commands (like DDL) and queries with side effects, we force query optimization to - // happen right away to let these side effects take place eagerly. 
- case _: Command | - _: InsertIntoTable | - _: CreateTableUsingAsSelect => - LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) - case _ => - queryExecution.analyzed + @transient protected[sql] val logicalPlan: LogicalPlan = { + def hasSideEffects(plan: LogicalPlan): Boolean = plan match { + case _: Command | + _: InsertIntoTable | + _: CreateTableUsingAsSelect => true + case _ => false + } + + queryExecution.logical match { + // For various commands (like DDL) and queries with side effects, we force query execution + // to happen right away to let these side effects take place eagerly. + case p if hasSideEffects(p) => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case Union(children) if children.forall(hasSideEffects) => + LogicalRDD(queryExecution.analyzed.output, queryExecution.toRdd)(sqlContext) + case _ => + queryExecution.analyzed + } } /** diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 21dfb82876..d6c10d6ed9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1777,4 +1777,33 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { |FROM (SELECT '{"f1": "value1", "f2": 12}' json, 'hello' as str) test """.stripMargin), Row("value1", "12", BigDecimal("3.14"), "hello")) } + + test("multi-insert with lateral view") { + withTempTable("t1") { + sqlContext.range(10) + .select(array($"id", $"id" + 1).as("arr"), $"id") + .registerTempTable("source") + withTable("dest1", "dest2") { + sql("CREATE TABLE dest1 (i INT)") + sql("CREATE TABLE dest2 (i INT)") + sql( + """ + |FROM source + |INSERT OVERWRITE TABLE dest1 + |SELECT id + |WHERE id > 3 + |INSERT OVERWRITE TABLE dest2 + |select col LATERAL VIEW EXPLODE(arr) exp AS col 
+ |WHERE col > 3 + """.stripMargin) + + checkAnswer( + sqlContext.table("dest1"), + sql("SELECT id FROM source WHERE id > 3")) + checkAnswer( + sqlContext.table("dest2"), + sql("SELECT col FROM source LATERAL VIEW EXPLODE(arr) exp AS col WHERE col > 3")) + } + } + } } |