From aac13fb48c8aa7d6816ea46c2e40154913477717 Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Sun, 27 Mar 2016 23:50:23 -0700 Subject: [SPARK-14185][SQL][MINOR] Make indentation of debug log for generated code proper ## What changes were proposed in this pull request? The indentation of debug log output by `CodeGenerator` is weird. The first line of the generated code should be put on the next line of the first line of the log message. ``` 16/03/28 11:10:24 DEBUG CodeGenerator: /* 001 */ /* 002 */ public java.lang.Object generate(Object[] references) { /* 003 */ return new SpecificSafeProjection(references); ... ``` After this patch is applied, we get debug log like as follows. ``` 16/03/28 10:45:50 DEBUG CodeGenerator: /* 001 */ /* 002 */ public java.lang.Object generate(Object[] references) { /* 003 */ return new SpecificSafeProjection(references); ... ``` ## How was this patch tested? Ran some jobs and checked debug logs. Author: Kousuke Saruta Closes #11990 from sarutak/fix-debuglog-indentation. --- .../apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b511b4b3a0..cd490dd676 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -661,7 +661,7 @@ object CodeGenerator extends Logging { logDebug({ // Only add extra debugging info to byte code when we are going to print the source code. evaluator.setDebuggingInformation(true, true, false) - formatted + s"\n$formatted" }) try { -- cgit v1.2.3 From 4a7636f2da2121ee8c6fb7e6614820aaf3db8e0f Mon Sep 17 00:00:00 2001 From: Kazuaki Ishizaki Date: Mon, 28 Mar 2016 10:35:48 -0700 Subject: [SPARK-13844] [SQL] Generate better code for filters with a non-nullable column ## What changes were proposed in this pull request? This PR simplifies generated code with a non-nullable column. This PR addresses three items: 1. Generate simplified code for and / or 2. Generate better code for divide and remainder with non-zero dividend 3. Pass nullable information into BoundReference at WholeStageCodegen I have attached the generated code with and without this PR ## How was this patch tested? 
Tested by existing test suites in sql/core Here is a motivating example ```` (0 to 6).map(i => (i.toString, i.toInt)).toDF("k", "v") .filter("v % 2 == 0").filter("v <= 4").filter("v > 1").show() ```` Generated code without this PR ````java /* 032 */ protected void processNext() throws java.io.IOException { /* 033 */ /*** PRODUCE: Project [_1#0 AS k#3,_2#1 AS v#4] */ /* 034 */ /* 035 */ /*** PRODUCE: Filter ((isnotnull((_2#1 % 2)) && ((_2#1 % 2) = 0)) && ((_2#1 <= 4) && (_2#1 > 1))) */ /* 036 */ /* 037 */ /*** PRODUCE: INPUT */ /* 038 */ /* 039 */ while (!shouldStop() && inputadapter_input.hasNext()) { /* 040 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 041 */ /*** CONSUME: Filter ((isnotnull((_2#1 % 2)) && ((_2#1 % 2) = 0)) && ((_2#1 <= 4) && (_2#1 > 1))) */ /* 042 */ /* input[1, int] */ /* 043 */ int filter_value1 = inputadapter_row.getInt(1); /* 044 */ /* 045 */ /* isnotnull((input[1, int] % 2)) */ /* 046 */ /* (input[1, int] % 2) */ /* 047 */ boolean filter_isNull3 = false; /* 048 */ int filter_value3 = -1; /* 049 */ if (false || 2 == 0) { /* 050 */ filter_isNull3 = true; /* 051 */ } else { /* 052 */ if (false) { /* 053 */ filter_isNull3 = true; /* 054 */ } else { /* 055 */ filter_value3 = (int)(filter_value1 % 2); /* 056 */ } /* 057 */ } /* 058 */ if (!(!(filter_isNull3))) continue; /* 059 */ /* 060 */ /* ((input[1, int] % 2) = 0) */ /* 061 */ boolean filter_isNull6 = true; /* 062 */ boolean filter_value6 = false; /* 063 */ /* (input[1, int] % 2) */ /* 064 */ boolean filter_isNull7 = false; /* 065 */ int filter_value7 = -1; /* 066 */ if (false || 2 == 0) { /* 067 */ filter_isNull7 = true; /* 068 */ } else { /* 069 */ if (false) { /* 070 */ filter_isNull7 = true; /* 071 */ } else { /* 072 */ filter_value7 = (int)(filter_value1 % 2); /* 073 */ } /* 074 */ } /* 075 */ if (!filter_isNull7) { /* 076 */ filter_isNull6 = false; // resultCode could change nullability. /* 077 */ filter_value6 = filter_value7 == 0; /* 078 */ /* 079 */ } /* 080 */ if (filter_isNull6 || !filter_value6) continue; /* 081 */ /* 082 */ /* (input[1, int] <= 4) */ /* 083 */ boolean filter_value11 = false; /* 084 */ filter_value11 = filter_value1 <= 4; /* 085 */ if (!filter_value11) continue; /* 086 */ /* 087 */ /* (input[1, int] > 1) */ /* 088 */ boolean filter_value14 = false; /* 089 */ filter_value14 = filter_value1 > 1; /* 090 */ if (!filter_value14) continue; /* 091 */ /* 092 */ filter_metricValue.add(1); /* 093 */ /* 094 */ /*** CONSUME: Project [_1#0 AS k#3,_2#1 AS v#4] */ /* 095 */ /* 096 */ /* input[0, string] */ /* 097 */ /* input[0, string] */ /* 098 */ boolean filter_isNull = inputadapter_row.isNullAt(0); /* 099 */ UTF8String filter_value = filter_isNull ? 
null : (inputadapter_row.getUTF8String(0)); /* 100 */ project_holder.reset(); /* 101 */ /* 102 */ project_rowWriter.zeroOutNullBytes(); /* 103 */ /* 104 */ if (filter_isNull) { /* 105 */ project_rowWriter.setNullAt(0); /* 106 */ } else { /* 107 */ project_rowWriter.write(0, filter_value); /* 108 */ } /* 109 */ /* 110 */ project_rowWriter.write(1, filter_value1); /* 111 */ project_result.setTotalSize(project_holder.totalSize()); /* 112 */ append(project_result.copy()); /* 113 */ } /* 114 */ } /* 115 */ } ```` Generated code with this PR ````java /* 032 */ protected void processNext() throws java.io.IOException { /* 033 */ /*** PRODUCE: Project [_1#0 AS k#3,_2#1 AS v#4] */ /* 034 */ /* 035 */ /*** PRODUCE: Filter (((_2#1 % 2) = 0) && ((_2#1 <= 5) && (_2#1 > 1))) */ /* 036 */ /* 037 */ /*** PRODUCE: INPUT */ /* 038 */ /* 039 */ while (!shouldStop() && inputadapter_input.hasNext()) { /* 040 */ InternalRow inputadapter_row = (InternalRow) inputadapter_input.next(); /* 041 */ /*** CONSUME: Filter (((_2#1 % 2) = 0) && ((_2#1 <= 5) && (_2#1 > 1))) */ /* 042 */ /* input[1, int] */ /* 043 */ int filter_value1 = inputadapter_row.getInt(1); /* 044 */ /* 045 */ /* ((input[1, int] % 2) = 0) */ /* 046 */ /* (input[1, int] % 2) */ /* 047 */ int filter_value3 = (int)(filter_value1 % 2); /* 048 */ /* 049 */ boolean filter_value2 = false; /* 050 */ filter_value2 = filter_value3 == 0; /* 051 */ if (!filter_value2) continue; /* 052 */ /* 053 */ /* (input[1, int] <= 5) */ /* 054 */ boolean filter_value7 = false; /* 055 */ filter_value7 = filter_value1 <= 5; /* 056 */ if (!filter_value7) continue; /* 057 */ /* 058 */ /* (input[1, int] > 1) */ /* 059 */ boolean filter_value10 = false; /* 060 */ filter_value10 = filter_value1 > 1; /* 061 */ if (!filter_value10) continue; /* 062 */ /* 063 */ filter_metricValue.add(1); /* 064 */ /* 065 */ /*** CONSUME: Project [_1#0 AS k#3,_2#1 AS v#4] */ /* 066 */ /* 067 */ /* input[0, string] */ /* 068 */ /* input[0, string] */ /* 069 */ boolean filter_isNull = inputadapter_row.isNullAt(0); /* 070 */ UTF8String filter_value = filter_isNull ? null : (inputadapter_row.getUTF8String(0)); /* 071 */ project_holder.reset(); /* 072 */ /* 073 */ project_rowWriter.zeroOutNullBytes(); /* 074 */ /* 075 */ if (filter_isNull) { /* 076 */ project_rowWriter.setNullAt(0); /* 077 */ } else { /* 078 */ project_rowWriter.write(0, filter_value); /* 079 */ } /* 080 */ /* 081 */ project_rowWriter.write(1, filter_value1); /* 082 */ project_result.setTotalSize(project_holder.totalSize()); /* 083 */ append(project_result.copy()); /* 084 */ } /* 085 */ } /* 086 */ } ```` Author: Kazuaki Ishizaki Closes #11684 from kiszk/SPARK-13844. 
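The pattern behind items 1 and 2 is simple: when both operands are statically non-nullable, the generated code keeps only the divide-by-zero guard and drops the per-operand null checks. The standalone Scala sketch below illustrates that idea; the names (`NullabilityAwareCodegenSketch`, `genRemainder`) are illustrative only and are not Spark codegen APIs, while the real change is applied to `Divide`, `Remainder`, `And` and `Or` in the diff that follows.

````scala
// Illustrative only: a toy "code generator" that emits Java source for
// `value = left % right`, dropping null handling when both inputs are
// known to be non-nullable (the pattern applied to Divide/Remainder below).
object NullabilityAwareCodegenSketch {
  def genRemainder(leftNullable: Boolean, rightNullable: Boolean): String = {
    if (!leftNullable && !rightNullable) {
      // Non-nullable operands: only the division-by-zero guard remains.
      """boolean isNull = false;
        |int value = -1;
        |if (right == 0) {
        |  isNull = true;
        |} else {
        |  value = (int)(left % right);
        |}""".stripMargin
    } else {
      // Nullable operands: null checks on both sides are still required.
      """boolean isNull = false;
        |int value = -1;
        |if (rightIsNull || right == 0) {
        |  isNull = true;
        |} else if (leftIsNull) {
        |  isNull = true;
        |} else {
        |  value = (int)(left % right);
        |}""".stripMargin
    }
  }

  def main(args: Array[String]): Unit = {
    println("--- non-nullable operands ---")
    println(genRemainder(leftNullable = false, rightNullable = false))
    println("--- nullable operands ---")
    println(genRemainder(leftNullable = true, rightNullable = true))
  }
}
````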
--- .../sql/catalyst/expressions/arithmetic.scala | 72 ++++++++++++++------ .../sql/catalyst/expressions/predicates.scala | 78 ++++++++++++++-------- 2 files changed, 102 insertions(+), 48 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index ed812e0679..1e9c971800 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -237,21 +237,35 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { + if (!left.nullable && !right.nullable) { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if ($isZero) { ${ev.isNull} = true; } else { + ${eval1.code} ${ev.value} = $divide; } - } - """ + """ + } else { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = $divide; + } + } + """ + } } } @@ -299,21 +313,35 @@ case class Remainder(left: Expression, right: Expression) extends BinaryArithmet } else { s"($javaType)(${eval1.value} $symbol ${eval2.value})" } - s""" - ${eval2.code} - boolean ${ev.isNull} = false; - $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; - if (${eval2.isNull} || $isZero) { - ${ev.isNull} = true; - } else { - ${eval1.code} - if (${eval1.isNull}) { + if (!left.nullable && !right.nullable) { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if ($isZero) { ${ev.isNull} = true; } else { + ${eval1.code} ${ev.value} = $remainder; } - } - """ + """ + } else { + s""" + ${eval2.code} + boolean ${ev.isNull} = false; + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + if (${eval2.isNull} || $isZero) { + ${ev.isNull} = true; + } else { + ${eval1.code} + if (${eval1.isNull}) { + ${ev.isNull} = true; + } else { + ${ev.value} = $remainder; + } + } + """ + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 20818bfb1a..e23ad5596b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -274,22 +274,35 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with val eval2 = right.gen(ctx) // The result should be `false`, if any of them is `false` whenever the other is null or not. 
- s""" - ${eval1.code} - boolean ${ev.isNull} = false; - boolean ${ev.value} = false; + if (!left.nullable && !right.nullable) { + ev.isNull = "false" + s""" + ${eval1.code} + boolean ${ev.value} = false; - if (!${eval1.isNull} && !${eval1.value}) { - } else { - ${eval2.code} - if (!${eval2.isNull} && !${eval2.value}) { - } else if (!${eval1.isNull} && !${eval2.isNull}) { - ${ev.value} = true; + if (${eval1.value}) { + ${eval2.code} + ${ev.value} = ${eval2.value}; + } + """ + } else { + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.value} = false; + + if (!${eval1.isNull} && !${eval1.value}) { } else { - ${ev.isNull} = true; + ${eval2.code} + if (!${eval2.isNull} && !${eval2.value}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.value} = true; + } else { + ${ev.isNull} = true; + } } - } - """ + """ + } } } @@ -325,22 +338,35 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P val eval2 = right.gen(ctx) // The result should be `true`, if any of them is `true` whenever the other is null or not. - s""" - ${eval1.code} - boolean ${ev.isNull} = false; - boolean ${ev.value} = true; + if (!left.nullable && !right.nullable) { + ev.isNull = "false" + s""" + ${eval1.code} + boolean ${ev.value} = true; - if (!${eval1.isNull} && ${eval1.value}) { - } else { - ${eval2.code} - if (!${eval2.isNull} && ${eval2.value}) { - } else if (!${eval1.isNull} && !${eval2.isNull}) { - ${ev.value} = false; + if (!${eval1.value}) { + ${eval2.code} + ${ev.value} = ${eval2.value}; + } + """ + } else { + s""" + ${eval1.code} + boolean ${ev.isNull} = false; + boolean ${ev.value} = true; + + if (!${eval1.isNull} && ${eval1.value}) { } else { - ${ev.isNull} = true; + ${eval2.code} + if (!${eval2.isNull} && ${eval2.value}) { + } else if (!${eval1.isNull} && !${eval2.isNull}) { + ${ev.value} = false; + } else { + ${ev.isNull} = true; + } } - } - """ + """ + } } } -- cgit v1.2.3 From 600c0b69cab4767e8e5a6f4284777d8b9d4bd40e Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 28 Mar 2016 12:31:12 -0700 Subject: [SPARK-13713][SQL] Migrate parser from ANTLR3 to ANTLR4 ### What changes were proposed in this pull request? The current ANTLR3 parser is quite complex to maintain and suffers from code blow-ups. This PR introduces a new parser that is based on ANTLR4. This parser is based on the [Presto's SQL parser](https://github.com/facebook/presto/blob/master/presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4). The current implementation can parse and create Catalyst and SQL plans. Large parts of the HiveQl DDL and some of the DML functionality is currently missing, the plan is to add this in follow-up PRs. This PR is a work in progress, and work needs to be done in the following area's: - [x] Error handling should be improved. - [x] Documentation should be improved. - [x] Multi-Insert needs to be tested. - [ ] Naming and package locations. ### How was this patch tested? Catalyst and SQL unit tests. Author: Herman van Hovell Closes #11557 from hvanhovell/ngParser. 
--- LICENSE | 1 + dev/deps/spark-deps-hadoop-2.2 | 1 + dev/deps/spark-deps-hadoop-2.3 | 1 + dev/deps/spark-deps-hadoop-2.4 | 1 + dev/deps/spark-deps-hadoop-2.6 | 1 + dev/deps/spark-deps-hadoop-2.7 | 1 + pom.xml | 6 + project/SparkBuild.scala | 8 +- project/plugins.sbt | 6 + python/pyspark/sql/tests.py | 6 +- python/pyspark/sql/utils.py | 8 + sql/catalyst/pom.xml | 4 + .../apache/spark/sql/catalyst/parser/ng/SqlBase.g4 | 911 ++++++++++++ .../apache/spark/sql/catalyst/dsl/package.scala | 34 +- .../spark/sql/catalyst/parser/ng/AstBuilder.scala | 1452 ++++++++++++++++++++ .../spark/sql/catalyst/parser/ng/ParseDriver.scala | 240 ++++ .../spark/sql/catalyst/parser/ng/ParserUtils.scala | 118 ++ .../sql/catalyst/parser/CatalystQlSuite.scala | 52 +- .../sql/catalyst/parser/DataTypeParserSuite.scala | 55 +- .../sql/catalyst/parser/ng/ErrorParserSuite.scala | 67 + .../catalyst/parser/ng/ExpressionParserSuite.scala | 497 +++++++ .../sql/catalyst/parser/ng/PlanParserSuite.scala | 429 ++++++ .../parser/ng/TableIdentifierParserSuite.scala | 42 + .../apache/spark/sql/catalyst/plans/PlanTest.scala | 20 +- .../spark/sql/execution/SparkSqlParser.scala | 219 +++ .../scala/org/apache/spark/sql/functions.scala | 5 +- .../apache/spark/sql/internal/SessionState.scala | 2 +- .../scala/org/apache/spark/sql/JoinSuite.scala | 4 +- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 2 +- 29 files changed, 4127 insertions(+), 66 deletions(-) create mode 100644 sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala (limited to 'sql/catalyst') diff --git a/LICENSE b/LICENSE index d7a790a628..5a8c78b98b 100644 --- a/LICENSE +++ b/LICENSE @@ -238,6 +238,7 @@ The text of each license is also included at licenses/LICENSE-[project].txt. 
(BSD 3 Clause) netlib core (com.github.fommil.netlib:core:1.1.2 - https://github.com/fommil/netlib-java/core) (BSD 3 Clause) JPMML-Model (org.jpmml:pmml-model:1.2.7 - https://github.com/jpmml/jpmml-model) (BSD License) AntLR Parser Generator (antlr:antlr:2.7.7 - http://www.antlr.org/) + (BSD License) ANTLR 4.5.2-1 (org.antlr:antlr4:4.5.2-1 - http://wwww.antlr.org/) (BSD licence) ANTLR ST4 4.0.4 (org.antlr:ST4:4.0.4 - http://www.stringtemplate.org) (BSD licence) ANTLR StringTemplate (org.antlr:stringtemplate:3.2.1 - http://www.stringtemplate.org) (BSD License) Javolution (javolution:javolution:5.5.1 - http://javolution.org) diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 512675a599..7c2f88bdb1 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index 31f8694fed..f4d600038d 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 0fa8bccab0..7c5e2c35bd 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar arpack_combined_all-0.1.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 8d2f6e6e32..03d9a51057 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index a114c4ae8d..5765071a1c 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -3,6 +3,7 @@ RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar antlr-runtime-3.5.2.jar +antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar apacheds-i18n-2.0.0-M15.jar diff --git a/pom.xml b/pom.xml index b4cfa3a598..475f0544bd 100644 --- a/pom.xml +++ b/pom.xml @@ -178,6 +178,7 @@ 1.3.9 0.9.2 3.5.2 + 4.5.2-1 ${java.home} @@ -1759,6 +1760,11 @@ antlr-runtime ${antlr.version} + + org.antlr + antlr4-runtime + ${antlr4.version} + diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index fb229b979d..39a9e16f7e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -25,6 +25,7 @@ import sbt._ import sbt.Classpaths.publishTask import sbt.Keys._ import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion +import com.simplytyped.Antlr4Plugin._ import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys} import com.typesafe.tools.mima.plugin.MimaKeys @@ -401,7 +402,10 @@ object OldDeps { } object Catalyst { - lazy val settings = Seq( + lazy val settings = antlr4Settings ++ Seq( + antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser.ng"), + antlr4GenListener in Antlr4 := true, 
+ antlr4GenVisitor in Antlr4 := true, // ANTLR code-generation step. // // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of @@ -414,7 +418,7 @@ object Catalyst { "SparkSqlLexer.g", "SparkSqlParser.g") val sourceDir = (sourceDirectory in Compile).value / "antlr3" - val targetDir = (sourceManaged in Compile).value + val targetDir = (sourceManaged in Compile).value / "antlr3" // Create default ANTLR Tool. val antlr = new org.antlr.Tool diff --git a/project/plugins.sbt b/project/plugins.sbt index eeca94a47c..d9ed7962bf 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -23,3 +23,9 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" libraryDependencies += "org.antlr" % "antlr" % "3.5.2" + + +// TODO I am not sure we want such a dep. +resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases" + +addSbtPlugin("com.simplytyped" % "sbt-antlr4" % "0.7.10") diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 83ef76c13c..1a5d422af9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -51,7 +51,7 @@ from pyspark.sql.types import UserDefinedType, _infer_type from pyspark.tests import ReusedPySparkTestCase from pyspark.sql.functions import UserDefinedFunction, sha2 from pyspark.sql.window import Window -from pyspark.sql.utils import AnalysisException, IllegalArgumentException +from pyspark.sql.utils import AnalysisException, ParseException, IllegalArgumentException class UTCOffsetTimezone(datetime.tzinfo): @@ -1130,7 +1130,9 @@ class SQLTests(ReusedPySparkTestCase): def test_capture_analysis_exception(self): self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("select abc")) self.assertRaises(AnalysisException, lambda: self.df.selectExpr("a + b")) - self.assertRaises(AnalysisException, lambda: self.sqlCtx.sql("abc")) + + def test_capture_parse_exception(self): + self.assertRaises(ParseException, lambda: self.sqlCtx.sql("abc")) def test_capture_illegalargument_exception(self): self.assertRaisesRegexp(IllegalArgumentException, "Setting negative mapred.reduce.tasks", diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index b0a0373372..b89ea8c6e0 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -33,6 +33,12 @@ class AnalysisException(CapturedException): """ +class ParseException(CapturedException): + """ + Failed to parse a SQL command. + """ + + class IllegalArgumentException(CapturedException): """ Passed an illegal or inappropriate argument. 
@@ -49,6 +55,8 @@ def capture_sql_exception(f): e.java_exception.getStackTrace())) if s.startswith('org.apache.spark.sql.AnalysisException: '): raise AnalysisException(s.split(': ', 1)[1], stackTrace) + if s.startswith('org.apache.spark.sql.catalyst.parser.ng.ParseException: '): + raise ParseException(s.split(': ', 1)[1], stackTrace) if s.startswith('java.lang.IllegalArgumentException: '): raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) raise diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 5d1d9edd25..c834a011f1 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -75,6 +75,10 @@ org.antlr antlr-runtime + + org.antlr + antlr4-runtime + commons-codec commons-codec diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 new file mode 100644 index 0000000000..e46fd9bed5 --- /dev/null +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 @@ -0,0 +1,911 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. + */ + +grammar SqlBase; + +tokens { + DELIMITER +} + +singleStatement + : statement EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +statement + : query #statementDefault + | USE db=identifier #use + | CREATE DATABASE (IF NOT EXISTS)? identifier + (COMMENT comment=STRING)? locationSpec? + (WITH DBPROPERTIES tablePropertyList)? #createDatabase + | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties + | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase + | createTableHeader ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTableUsing + | createTableHeader tableProvider + (OPTIONS tablePropertyList)? AS? query #createTableUsing + | createTableHeader ('(' colTypeList ')')? (COMMENT STRING)? + (PARTITIONED BY identifierList)? bucketSpec? skewSpec? + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + (AS? query)? #createTable + | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq?) #analyze + | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable + | ALTER TABLE tableIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER TABLE tableIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE tableIdentifier (partitionSpec)? 
+ SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER TABLE tableIdentifier bucketSpec #bucketTable + | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable + | ALTER TABLE tableIdentifier NOT SORTED #unsortTable + | ALTER TABLE tableIdentifier skewSpec #skewTable + | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable + | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable + | ALTER TABLE tableIdentifier + SET SKEWED LOCATION skewedLocationList #setTableSkewLocations + | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE tableIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER TABLE from=tableIdentifier + EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition + | ALTER TABLE tableIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition + | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition + | ALTER TABLE tableIdentifier partitionSpec? + SET FILEFORMAT fileFormat #setTableFileFormat + | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation + | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable + | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable + | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable + | ALTER TABLE tableIdentifier partitionSpec? + CHANGE COLUMN? oldName=identifier colType + (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn + | ALTER TABLE tableIdentifier partitionSpec? + ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns + | ALTER TABLE tableIdentifier partitionSpec? + REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns + | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? + (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier + identifierCommentList? (COMMENT STRING)? + (PARTITIONED ON identifierList)? + (TBLPROPERTIES tablePropertyList)? AS query #createView + | ALTER VIEW tableIdentifier AS? query #alterViewQuery + | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction + | EXPLAIN explainOption* statement #explain + | SHOW TABLES ((FROM | IN) db=identifier)? + (LIKE (qualifiedName | pattern=STRING))? #showTables + | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions + | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction + | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? + tableIdentifier partitionSpec? describeColName? #describeTable + | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase + | REFRESH TABLE tableIdentifier #refreshTable + | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable + | UNCACHE TABLE identifier #uncacheTable + | CLEAR CACHE #clearCache + | ADD identifier .*? #addResource + | SET .*? #setConfiguration + | hiveNativeCommands #executeNativeCommand + ; + +hiveNativeCommands + : createTableHeader LIKE tableIdentifier + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + | DELETE FROM tableIdentifier (WHERE booleanExpression)? + | TRUNCATE TABLE tableIdentifier partitionSpec? + (COLUMNS identifierList)? + | ALTER VIEW from=tableIdentifier AS? 
RENAME TO to=tableIdentifier + | ALTER VIEW from=tableIdentifier AS? + SET TBLPROPERTIES tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + ADD (IF NOT EXISTS)? partitionSpecLocation+ + | ALTER VIEW from=tableIdentifier AS? + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? + | DROP VIEW (IF EXISTS)? qualifiedName + | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? + | START TRANSACTION (transactionMode (',' transactionMode)*)? + | COMMIT WORK? + | ROLLBACK WORK? + | SHOW PARTITIONS tableIdentifier partitionSpec? + | DFS .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD) .*? + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +query + : ctes? queryNoWith + ; + +insertInto + : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)? + | INSERT INTO TABLE? tableIdentifier partitionSpec? + ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? + ; + +describeColName + : identifier ('.' (identifier | STRING))* + ; + +ctes + : WITH namedQuery (',' namedQuery)* + ; + +namedQuery + : name=identifier AS? '(' queryNoWith ')' + ; + +tableProvider + : USING qualifiedName + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=STRING)? + ; + +tablePropertyKey + : looseIdentifier ('.' looseIdentifier)* + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +skewedLocation + : (constant | constantList) EQ STRING + ; + +skewedLocationList + : '(' skewedLocation (',' skewedLocation)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? + (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +queryNoWith + : insertInto? queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windows? + (LIMIT limit=expression)? + ; + +multiInsertQueryBody + : insertInto? + querySpecification + queryOrganization + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | TABLE tableIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' queryNoWith ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? 
+ ; + +querySpecification + : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')' + | kind=MAP namedExpressionSeq + | kind=REDUCE namedExpressionSeq)) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + fromClause? + (WHERE where=booleanExpression)?) + | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? + | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) + lateralView* + (WHERE where=booleanExpression)? + aggregation? + (HAVING having=booleanExpression)? + windows?) + ; + +fromClause + : FROM relation (',' relation)* lateralView* + ; + +aggregation + : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : left=relation + ((CROSS | joinType) JOIN right=relation joinCriteria? + | NATURAL joinType JOIN right=relation + ) #joinRelation + | relationPrimary #relationDefault + ; + +joinType + : INNER? + | LEFT OUTER? + | LEFT SEMI + | RIGHT OUTER? + | FULL OUTER? + ; + +joinCriteria + : ON booleanExpression + | USING '(' identifier (',' identifier)* ')' + ; + +sample + : TABLESAMPLE '(' + ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) + | (expression sampleType=ROWS) + | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?)) + ')' + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : identifier (',' identifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : identifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier (COMMENT STRING)? + ; + +relationPrimary + : tableIdentifier sample? (AS? identifier)? #tableName + | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery + | '(' relation ')' sample? (AS? identifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + ; + +inlineTable + : VALUES expression (',' expression)* (AS? identifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +tableIdentifier + : (db=identifier '.')? table=identifier + ; + +namedExpression + : expression (AS? (identifier | identifierList))? 
+ ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +expression + : booleanExpression + ; + +booleanExpression + : predicated #booleanDefault + | NOT booleanExpression #logicalNot + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + | EXISTS '(' query ')' #exists + ; + +// workaround for: +// https://github.com/antlr/antlr4/issues/780 +// https://github.com/antlr/antlr4/issues/781 +predicated + : valueExpression predicate[$valueExpression.ctx]? + ; + +predicate[ParserRuleContext value] + : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between + | NOT? IN '(' expression (',' expression)* ')' #inList + | NOT? IN '(' query ')' #inSubquery + | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like + | IS NOT? NULL #nullPredicate + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : constant #constantDefault + | ASTERISK #star + | qualifiedName '.' ASTERISK #star + | '(' expression (',' expression)+ ')' #rowConstructor + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | '(' query ')' #subqueryExpression + | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL intervalField* + ; + +intervalField + : value=intervalValue unit=identifier (TO to=identifier)? + ; + +intervalValue + : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE) + | STRING + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : identifier ':'? dataType (COMMENT STRING)? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windows + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : identifier AS windowSpec + ; + +windowSpec + : name=identifier #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? 
+ ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + + +explainOption + : LOGICAL | FORMATTED | EXTENDED + ; + +transactionMode + : ISOLATION LEVEL SNAPSHOT #isolationLevel + | READ accessMode=(ONLY | WRITE) #transactionAccessMode + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility). +looseIdentifier + : identifier + | FROM + | TO + | TABLE + | WITH + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : DECIMAL_VALUE #decimalLiteral + | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral + | INTEGER_VALUE #integerLiteral + | BIGINT_LITERAL #bigIntLiteral + | SMALLINT_LITERAL #smallIntLiteral + | TINYINT_LITERAL #tinyIntLiteral + | DOUBLE_LITERAL #doubleLiteral + ; + +nonReserved + : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS + | ADD + | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT + | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER + | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED + | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS + | GROUPING | CUBE | ROLLUP + | EXPLAIN | FORMAT | LOGICAL | FORMATTED + | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF + | SET + | VIEW | REPLACE + | IF + | NO | DATA + | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL + | SNAPSHOT | READ | WRITE | ONLY + | SORT | CLUSTER | DISTRIBUTE UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION + | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST + | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT + | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE + | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT + ; + +SELECT: 'SELECT'; +FROM: 'FROM'; +ADD: 'ADD'; +AS: 'AS'; +ALL: 'ALL'; +DISTINCT: 'DISTINCT'; +WHERE: 'WHERE'; +GROUP: 'GROUP'; +BY: 'BY'; +GROUPING: 'GROUPING'; +SETS: 'SETS'; +CUBE: 'CUBE'; +ROLLUP: 'ROLLUP'; +ORDER: 'ORDER'; +HAVING: 'HAVING'; +LIMIT: 'LIMIT'; +AT: 'AT'; +OR: 'OR'; +AND: 'AND'; +IN: 'IN'; +NOT: 'NOT' | '!'; +NO: 'NO'; +EXISTS: 'EXISTS'; +BETWEEN: 'BETWEEN'; +LIKE: 'LIKE'; +RLIKE: 'RLIKE' | 'REGEXP'; +IS: 'IS'; +NULL: 'NULL'; +TRUE: 'TRUE'; +FALSE: 'FALSE'; +NULLS: 'NULLS'; +ASC: 'ASC'; +DESC: 'DESC'; +FOR: 'FOR'; +INTERVAL: 'INTERVAL'; +CASE: 'CASE'; +WHEN: 'WHEN'; +THEN: 'THEN'; +ELSE: 'ELSE'; +END: 'END'; +JOIN: 'JOIN'; +CROSS: 'CROSS'; +OUTER: 'OUTER'; +INNER: 'INNER'; +LEFT: 'LEFT'; +SEMI: 'SEMI'; +RIGHT: 'RIGHT'; +FULL: 'FULL'; +NATURAL: 'NATURAL'; +ON: 'ON'; +LATERAL: 'LATERAL'; +WINDOW: 'WINDOW'; +OVER: 'OVER'; +PARTITION: 'PARTITION'; +RANGE: 'RANGE'; +ROWS: 'ROWS'; +UNBOUNDED: 'UNBOUNDED'; 
+PRECEDING: 'PRECEDING'; +FOLLOWING: 'FOLLOWING'; +CURRENT: 'CURRENT'; +ROW: 'ROW'; +WITH: 'WITH'; +VALUES: 'VALUES'; +CREATE: 'CREATE'; +TABLE: 'TABLE'; +VIEW: 'VIEW'; +REPLACE: 'REPLACE'; +INSERT: 'INSERT'; +DELETE: 'DELETE'; +INTO: 'INTO'; +DESCRIBE: 'DESCRIBE'; +EXPLAIN: 'EXPLAIN'; +FORMAT: 'FORMAT'; +LOGICAL: 'LOGICAL'; +CAST: 'CAST'; +SHOW: 'SHOW'; +TABLES: 'TABLES'; +COLUMNS: 'COLUMNS'; +COLUMN: 'COLUMN'; +USE: 'USE'; +PARTITIONS: 'PARTITIONS'; +FUNCTIONS: 'FUNCTIONS'; +DROP: 'DROP'; +UNION: 'UNION'; +EXCEPT: 'EXCEPT'; +INTERSECT: 'INTERSECT'; +TO: 'TO'; +TABLESAMPLE: 'TABLESAMPLE'; +STRATIFY: 'STRATIFY'; +ALTER: 'ALTER'; +RENAME: 'RENAME'; +ARRAY: 'ARRAY'; +MAP: 'MAP'; +STRUCT: 'STRUCT'; +COMMENT: 'COMMENT'; +SET: 'SET'; +DATA: 'DATA'; +START: 'START'; +TRANSACTION: 'TRANSACTION'; +COMMIT: 'COMMIT'; +ROLLBACK: 'ROLLBACK'; +WORK: 'WORK'; +ISOLATION: 'ISOLATION'; +LEVEL: 'LEVEL'; +SNAPSHOT: 'SNAPSHOT'; +READ: 'READ'; +WRITE: 'WRITE'; +ONLY: 'ONLY'; + +IF: 'IF'; + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<='; +GT : '>'; +GTE : '>='; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +DIV: 'DIV'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +HAT: '^'; + +PERCENTLIT: 'PERCENT'; +BUCKET: 'BUCKET'; +OUT: 'OUT'; +OF: 'OF'; + +SORT: 'SORT'; +CLUSTER: 'CLUSTER'; +DISTRIBUTE: 'DISTRIBUTE'; +OVERWRITE: 'OVERWRITE'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; +USING: 'USING'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +DELIMITED: 'DELIMITED'; +FIELDS: 'FIELDS'; +TERMINATED: 'TERMINATED'; +COLLECTION: 'COLLECTION'; +ITEMS: 'ITEMS'; +KEYS: 'KEYS'; +ESCAPED: 'ESCAPED'; +LINES: 'LINES'; +SEPARATED: 'SEPARATED'; +FUNCTION: 'FUNCTION'; +EXTENDED: 'EXTENDED'; +REFRESH: 'REFRESH'; +CLEAR: 'CLEAR'; +CACHE: 'CACHE'; +UNCACHE: 'UNCACHE'; +LAZY: 'LAZY'; +FORMATTED: 'FORMATTED'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +OPTIONS: 'OPTIONS'; +UNSET: 'UNSET'; +TBLPROPERTIES: 'TBLPROPERTIES'; +DBPROPERTIES: 'DBPROPERTIES'; +BUCKETS: 'BUCKETS'; +SKEWED: 'SKEWED'; +STORED: 'STORED'; +DIRECTORIES: 'DIRECTORIES'; +LOCATION: 'LOCATION'; +EXCHANGE: 'EXCHANGE'; +ARCHIVE: 'ARCHIVE'; +UNARCHIVE: 'UNARCHIVE'; +FILEFORMAT: 'FILEFORMAT'; +TOUCH: 'TOUCH'; +COMPACT: 'COMPACT'; +CONCATENATE: 'CONCATENATE'; +CHANGE: 'CHANGE'; +FIRST: 'FIRST'; +AFTER: 'AFTER'; +CASCADE: 'CASCADE'; +RESTRICT: 'RESTRICT'; +CLUSTERED: 'CLUSTERED'; +SORTED: 'SORTED'; +PURGE: 'PURGE'; +INPUTFORMAT: 'INPUTFORMAT'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +INPUTDRIVER: 'INPUTDRIVER'; +OUTPUTDRIVER: 'OUTPUTDRIVER'; +DATABASE: 'DATABASE' | 'SCHEMA'; +DFS: 'DFS'; +TRUNCATE: 'TRUNCATE'; +METADATA: 'METADATA'; +REPLICATION: 'REPLICATION'; +ANALYZE: 'ANALYZE'; +COMPUTE: 'COMPUTE'; +STATISTICS: 'STATISTICS'; +PARTITIONED: 'PARTITIONED'; +EXTERNAL: 'EXTERNAL'; +DEFINED: 'DEFINED'; +REVOKE: 'REVOKE'; +GRANT: 'GRANT'; +LOCK: 'LOCK'; +UNLOCK: 'UNLOCK'; +MSCK: 'MSCK'; +EXPORT: 'EXPORT'; +IMPORT: 'IMPORT'; +LOAD: 'LOAD'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +DECIMAL_VALUE + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +SCIENTIFIC_DECIMAL_VALUE + : DIGIT+ ('.' DIGIT*)? EXPONENT + | '.' 
DIGIT+ EXPONENT + ; + +DOUBLE_LITERAL + : + (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' .*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . + ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 3540014c3e..105947028d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,7 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions -import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.{Inner, JoinType} @@ -161,6 +161,10 @@ package object dsl { def lower(e: Expression): Expression = Lower(e) def sqrt(e: Expression): Expression = Sqrt(e) def abs(e: Expression): Expression = Abs(e) + def star(names: String*): Expression = names match { + case Seq() => UnresolvedStar(None) + case target => UnresolvedStar(Option(target)) + } implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? @@ -231,6 +235,12 @@ package object dsl { AttributeReference(s, structType, nullable = true)() def struct(attrs: AttributeReference*): AttributeReference = struct(StructType.fromAttributes(attrs)) + + /** Create a function. 
*/ + def function(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = false) + def distinctFunction(exprs: Expression*): UnresolvedFunction = + UnresolvedFunction(s, exprs, isDistinct = true) } implicit class DslAttribute(a: AttributeReference) { @@ -243,8 +253,20 @@ package object dsl { object expressions extends ExpressionConversions // scalastyle:ignore object plans { // scalastyle:ignore + def table(ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref), None) + + def table(db: String, ref: String): LogicalPlan = + UnresolvedRelation(TableIdentifier(ref, Option(db)), None) + implicit class DslLogicalPlan(val logicalPlan: LogicalPlan) { - def select(exprs: NamedExpression*): LogicalPlan = Project(exprs, logicalPlan) + def select(exprs: Expression*): LogicalPlan = { + val namedExpressions = exprs.map { + case e: NamedExpression => e + case e => UnresolvedAlias(e) + } + Project(namedExpressions, logicalPlan) + } def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) @@ -296,6 +318,14 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) + def as(alias: String): LogicalPlan = logicalPlan match { + case UnresolvedRelation(tbl, _) => UnresolvedRelation(tbl, Option(alias)) + case plan => SubqueryAlias(alias, plan) + } + + def distribute(exprs: Expression*): LogicalPlan = + RepartitionByExpression(exprs, logicalPlan) + def analyze: LogicalPlan = EliminateSubqueryAliases(analysis.SimpleAnalyzer.execute(logicalPlan)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala new file mode 100644 index 0000000000..5a64c414fb --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala @@ -0,0 +1,1452 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser.ng + +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler + +/** + * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or + * TableIdentifier. + */ +class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { + import ParserUtils._ + + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + visit(ctx.dataType).asInstanceOf[DataType] + } + + /* ******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Make sure we do not try to create a plan for a native command. + */ + override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null + + /** + * Create a plan for a SHOW FUNCTIONS command. + */ + override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { + import ctx._ + if (qualifiedName != null) { + val names = qualifiedName().identifier().asScala.map(_.getText).toList + names match { + case db :: name :: Nil => + ShowFunctions(Some(db), Some(name)) + case name :: Nil => + ShowFunctions(None, Some(name)) + case _ => + throw new ParseException("SHOW FUNCTIONS unsupported name", ctx) + } + } else if (pattern != null) { + ShowFunctions(None, Some(string(pattern))) + } else { + ShowFunctions(None, None) + } + } + + /** + * Create a plan for a DESCRIBE FUNCTION command. + */ + override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { + val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".") + DescribeFunction(functionName, ctx.EXTENDED != null) + } + + /** + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryNoWith) + + // Apply CTEs + query.optional(ctx.ctes) { + val ctes = ctx.ctes.namedQuery.asScala.map { + case nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + + // Check for duplicate names. 
+ ctes.groupBy(_._1).filter(_._2.size > 1).foreach { + case (name, _) => + throw new ParseException( + s"Name '$name' is used for multiple common table expressions", ctx) + } + + With(query, ctes.toMap) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith)) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val inserts = ctx.multiInsertQueryBody.asScala.map { + body => + assert(body.querySpecification.fromClause == null, + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", + body) + + withQuerySpecification(body.querySpecification, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(body.insertInto())(withInsertInto) + } + + // If there are multiple INSERTS just UNION them together into one query. + inserts match { + case Seq(query) => query + case queries => Union(queries) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm). + // Add organization statements. + optionalMap(ctx.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(ctx.insertInto())(withInsertInto) + } + + /** + * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan. + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + InsertIntoTable( + UnresolvedRelation(tableIdent, None), + partitionKeys, + query, + ctx.OVERWRITE != null, + ctx.EXISTS != null) + } + + /** + * Create a partition specification map. + */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText.toLowerCase + val value = Option(pVal.constant).map(visitStringConstant) + name -> value + }.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. 
+ */ + protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) { + ctx match { + case s: StringLiteralContext => createString(s) + case o => o.getText + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem), global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem), global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + RepartitionByExpression(expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem), + global = false, + RepartitionByExpression(expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... + val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + RepartitionByExpression(expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windows)(withWindows) + + // LIMIT + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a logical plan using a query specification. + */ + override def visitQuerySpecification( + ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation.optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withQuerySpecification(ctx, from) + } + + /** + * Add a query specification to a logical plan. The query specification is the core of the logical + * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE), + * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). + */ + private def withQuerySpecification( + ctx: QuerySpecificationContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // WHERE + def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx), plan) + } + + // Expressions. + val expressions = Option(namedExpressionSeq).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + + // Create either a transform or a regular query. + val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT) + specType match { + case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM => + // Transform + + // Add where. + val withFilter = relation.optionalMap(where)(filter) + + // Create the attributes. 
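The clause combinations handled above are easiest to see on a concrete query. This sketch (table and column names are placeholders) checks that CLUSTER BY is lowered to a non-global Sort over a RepartitionByExpression:

```scala
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{RepartitionByExpression, Sort}

// CLUSTER BY c becomes Sort(c ASC, global = false, RepartitionByExpression(c, ...)),
// i.e. a local sort inside each partition produced by the repartitioning.
val plan = CatalystSqlParser.parsePlan("SELECT * FROM tbl CLUSTER BY c")
assert(plan.collect { case s: Sort if !s.global => s }.nonEmpty)
assert(plan.collect { case r: RepartitionByExpression => r }.nonEmpty)
```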
+ val (attributes, schemaLess) = if (colTypeList != null) { + // Typed return columns. + (createStructType(colTypeList).toAttributes, false) + } else if (identifierSeq != null) { + // Untyped return columns. + val attrs = visitIdentifierSeq(identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + // Create the transform. + ScriptTransformation( + expressions, + string(script), + attributes, + withFilter, + withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + + case SqlBaseParser.SELECT => + // Regular select + + // Add lateral views. + val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(where)(filter) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + val withProject = if (aggregation != null) { + withAggregation(aggregation, namedExpressions, withFilter) + } else if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + // Having + val withHaving = withProject.optional(having) { + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. + Filter(Cast(expression(having), BooleanType), withProject) + } + + // Distinct + val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { + Distinct(withHaving) + } else { + withHaving + } + + // Window + withDistinct.optionalMap(windows)(withWindows) + } + } + + /** + * Create a (Hive based) [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = null + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. + */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + + /** + * Connect two queries by a Set operator. + * + * Supported Set operators are: + * - UNION [DISTINCT] + * - UNION ALL + * - EXCEPT [DISTINCT] + * - INTERSECT [DISTINCT] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case SqlBaseParser.UNION if all => + Union(left, right) + case SqlBaseParser.UNION => + Distinct(Union(left, right)) + case SqlBaseParser.INTERSECT if all => + throw new ParseException("INTERSECT ALL is not supported.", ctx) + case SqlBaseParser.INTERSECT => + Intersect(left, right) + case SqlBaseParser.EXCEPT if all => + throw new ParseException("EXCEPT ALL is not supported.", ctx) + case SqlBaseParser.EXCEPT => + Except(left, right) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. 
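As a quick check of the set-operator mapping in visitSetOperation above (relation names are made up):

```scala
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.{Distinct, Union}

// UNION ALL maps directly to Union, while a bare UNION (or UNION DISTINCT)
// additionally wraps the Union in Distinct.
val unionAll = CatalystSqlParser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1")
val unionDistinct = CatalystSqlParser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1")

assert(unionAll.collect { case u: Union => u }.nonEmpty)
assert(unionDistinct.collect { case d: Distinct => d }.nonEmpty)
```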
+ */ + private def withWindows( + ctx: WindowsContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowMap = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + }.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity), query) + } + + /** + * Add an [[Aggregate]] to a logical plan. + */ + private def withAggregation( + ctx: AggregationContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + val groupByExpressions = expressionList(groupingExpressions) + + if (GROUPING != null) { + // GROUP BY .... GROUPING SETS (...) + val expressionMap = groupByExpressions.zipWithIndex.toMap + val numExpressions = expressionMap.size + val mask = (1 << numExpressions) - 1 + val masks = ctx.groupingSet.asScala.map { + _.expression.asScala.foldLeft(mask) { + case (bitmap, eCtx) => + // Find the index of the expression. + val e = typedVisit[Expression](eCtx) + val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse( + throw new ParseException( + s"$e doesn't show up in the GROUP BY list", ctx)) + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. + bitmap & ~(1 << (numExpressions - 1 - index)) + } + } + GroupingSets(masks, groupByExpressions, query, selectExpressions) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? + val mappedGroupByExpressions = if (CUBE != null) { + Seq(Cube(groupByExpressions)) + } else if (ROLLUP != null) { + Seq(Rollup(groupByExpressions)) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. + */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + + // Create the generator. + val generator = ctx.qualifiedName.getText.toLowerCase match { + case "explode" if expressions.size == 1 => + Explode(expressions.head) + case "json_tuple" => + JsonTuple(expressions) + case other => + withGenerator(other, expressions, ctx) + } + + Generate( + generator, + join = true, + outer = ctx.OUTER != null, + Some(ctx.tblName.getText.toLowerCase), + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), + query) + } + + /** + * Create a [[Generator]]. Override this method in order to support custom Generators. + */ + protected def withGenerator( + name: String, + expressions: Seq[Expression], + ctx: LateralViewContext): Generator = { + throw new ParseException(s"Generator function '$name' is not supported", ctx) + } + + /** + * Create a joins between two or more logical plans. 
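The GROUPING SETS bit arithmetic above is easier to follow with concrete numbers. The standalone sketch below uses plain strings in place of catalyst expressions and a hypothetical GROUP BY list a, b, c (so mask = 0b111 = 7); it mirrors the fold in withAggregation, where a bit is cleared exactly when the corresponding expression appears in the grouping set.

```scala
// Mirrors the bitmap fold in withAggregation. Bit (numExpressions - 1 - index)
// is cleared when the expression at `index` is listed in the grouping set.
val groupByExprs = Seq("a", "b", "c")
val expressionMap = groupByExprs.zipWithIndex.toMap
val numExpressions = expressionMap.size
val mask = (1 << numExpressions) - 1 // 0b111 = 7

def bitmask(groupingSet: Seq[String]): Int =
  groupingSet.foldLeft(mask) { (bitmap, e) =>
    bitmap & ~(1 << (numExpressions - 1 - expressionMap(e)))
  }

assert(bitmask(Seq("a")) == 3)      // 0b011: only `a` is a grouping column
assert(bitmask(Seq("a", "b")) == 1) // 0b001: `a` and `b` are grouping columns
assert(bitmask(Seq()) == 7)         // 0b111: none of the columns is grouped
```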
+ */ + override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { + /** Build a join between two plans. */ + def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { + val baseJoinType = ctx.joinType match { + case null => Inner + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(ctx.joinCriteria) match { + case Some(c) if c.USING != null => + val columns = c.identifier.asScala.map { column => + UnresolvedAttribute.quoted(column.getText) + } + (UsingJoin(baseJoinType, columns), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if ctx.NATURAL != null => + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, right, joinType, condition) + } + + // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the + // first join clause is at the top. However fields of previously referenced tables can be used + // in following join clauses. The tree needs to be reversed in order to make this work. + var result = plan(ctx.left) + var current = ctx + while (current != null) { + current.right match { + case right: JoinRelationContext => + result = join(current, result, plan(right.left)) + current = right + case right => + result = join(current, result, plan(right)) + current = null + } + } + result + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. + val eps = RandomSampler.roundingEpsilon + assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) + } + + ctx.sampleType.getType match { + case SqlBaseParser.ROWS => + Limit(expression(ctx.expression), query) + + case SqlBaseParser.PERCENTLIT => + val fraction = ctx.percentage.getText.toDouble + sample(fraction / 100.0d) + + case SqlBaseParser.BUCKET if ctx.ON != null => + throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx) + + case SqlBaseParser.BUCKET => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. + */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith) + } + + /** + * Create an un-aliased table reference. 
This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val table = UnresolvedRelation( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.identifier).map(_.getText)) + table.optionalMap(ctx.sample)(withSample) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. + val expressions = ctx.expression.asScala.map { eCtx => + val e = expression(eCtx) + assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) + e + } + + // Validate and evaluate the rows. + val (structType, structConstructor) = expressions.head.dataType match { + case st: StructType => + (st, (e: Expression) => e) + case dt => + val st = CreateStruct(Seq(expressions.head)).dataType + (st, (e: Expression) => CreateStruct(Seq(e))) + } + val rows = expressions.map { + case expression => + val safe = Cast(structConstructor(expression), structType) + safe.eval().asInstanceOf[InternalRow] + } + + // Construct attributes. + val baseAttributes = structType.toAttributes.map(_.withNullability(true)) + val attributes = if (ctx.identifierList != null) { + val aliases = visitIdentifierList(ctx.identifierList) + assert(aliases.size == baseAttributes.size, + "Number of aliases must match the number of fields in an inline table.", ctx) + baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) + } else { + baseAttributes + } + + // Create plan and add an alias if a name has been defined. + LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a LogicalPlan. + */ + private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. 
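To make the sampling translation above concrete, here is a sketch (the table name and exact SQL spelling are illustrative):

```scala
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.Sample

// TABLESAMPLE(30 PERCENT) becomes a Sample node with fraction 30 / 100 = 0.3;
// TABLESAMPLE(n ROWS) is instead translated into a Limit, per withSample.
val plan = CatalystSqlParser.parsePlan("SELECT * FROM tbl TABLESAMPLE(30 PERCENT)")
assert(plan.collect { case s: Sample => s }.nonEmpty)
```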
+ */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText) + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + /** + * Create an expression from the given context. This method just passes the context on to the + * vistor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression) + } + + /** + * Invert a boolean expression if it has a valid NOT clause. + */ + private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { + if (not != null) { + Not(expression) + } else { + expression + } + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. + */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case SqlBaseParser.AND => And.apply _ + case SqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. + val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. 
+ val expressions = contexts.reverse.map(expression) + + // Create a balanced tree. + def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { + case 0 => + expressions(low) + case 1 => + expressionCombiner(expressions(low), expressions(high)) + case x => + val mid = low + x / 2 + expressionCombiner( + reduceToExpressionTree(low, mid), + reduceToExpressionTree(mid + 1, high)) + } + reduceToExpressionTree(0, expressions.size - 1) + } + + /** + * Invert a boolean expression. + */ + override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { + Not(expression(ctx.booleanExpression())) + } + + /** + * Create a filtering correlated sub-query. This is not supported yet. + */ + override def visitExists(ctx: ExistsContext): Expression = { + throw new ParseException("EXISTS clauses are not supported.", ctx) + } + + /** + * Create a comparison expression. This compares two expressions. The following comparison + * operators are supported: + * - Equal: '=' or '==' + * - Null-safe Equal: '<=>' + * - Not Equal: '<>' or '!=' + * - Less than: '<' + * - Less then or Equal: '<=' + * - Greater than: '>' + * - Greater then or Equal: '>=' + */ + override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case SqlBaseParser.EQ => + EqualTo(left, right) + case SqlBaseParser.NSEQ => + EqualNullSafe(left, right) + case SqlBaseParser.NEQ | SqlBaseParser.NEQJ => + Not(EqualTo(left, right)) + case SqlBaseParser.LT => + LessThan(left, right) + case SqlBaseParser.LTE => + LessThanOrEqual(left, right) + case SqlBaseParser.GT => + GreaterThan(left, right) + case SqlBaseParser.GTE => + GreaterThanOrEqual(left, right) + } + } + + /** + * Create a BETWEEN expression. This tests if an expression lies with in the bounds set by two + * other expressions. The inverse can also be created. + */ + override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) { + val value = expression(ctx.value) + val between = And( + GreaterThanOrEqual(value, expression(ctx.lower)), + LessThanOrEqual(value, expression(ctx.upper))) + invertIfNotDefined(between, ctx.NOT) + } + + /** + * Create an IN expression. This tests if the value of the left hand side expression is + * contained by the sequence of expressions on the right hand side. + */ + override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) { + val in = In(expression(ctx.value), ctx.expression().asScala.map(expression)) + invertIfNotDefined(in, ctx.NOT) + } + + /** + * Create an IN expression, where the the right hand side is a query. This is unsupported. + */ + override def visitInSubquery(ctx: InSubqueryContext): Expression = { + throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + } + + /** + * Create a (R)LIKE/REGEXP expression. + */ + override def visitLike(ctx: LikeContext): Expression = { + val left = expression(ctx.value) + val right = expression(ctx.pattern) + val like = ctx.like.getType match { + case SqlBaseParser.LIKE => + Like(left, right) + case SqlBaseParser.RLIKE => + RLike(left, right) + } + invertIfNotDefined(like, ctx.NOT) + } + + /** + * Create an IS (NOT) NULL expression. 
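A quick way to see the effect of the balanced reduction above (column names a through d are placeholders):

```scala
import org.apache.spark.sql.catalyst.expressions.And
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser

// Four conjuncts come back as And(And(a, b), And(c, d)) rather than the
// left-deep And(And(And(a, b), c), d) that a simple fold would produce.
CatalystSqlParser.parseExpression("a AND b AND c AND d") match {
  case And(_: And, _: And) => () // balanced, as expected
  case other => sys.error(s"unexpected shape: $other")
}
```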
+ */ + override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) { + val value = expression(ctx.value) + if (ctx.NOT != null) { + IsNotNull(value) + } else { + IsNull(value) + } + } + + /** + * Create a binary arithmetic expression. The following arithmetic operators are supported: + * - Mulitplication: '*' + * - Division: '/' + * - Hive Long Division: 'DIV' + * - Modulo: '%' + * - Addition: '+' + * - Subtraction: '-' + * - Binary AND: '&' + * - Binary XOR + * - Binary OR: '|' + */ + override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + ctx.operator.getType match { + case SqlBaseParser.ASTERISK => + Multiply(left, right) + case SqlBaseParser.SLASH => + Divide(left, right) + case SqlBaseParser.PERCENT => + Remainder(left, right) + case SqlBaseParser.DIV => + Cast(Divide(left, right), LongType) + case SqlBaseParser.PLUS => + Add(left, right) + case SqlBaseParser.MINUS => + Subtract(left, right) + case SqlBaseParser.AMPERSAND => + BitwiseAnd(left, right) + case SqlBaseParser.HAT => + BitwiseXor(left, right) + case SqlBaseParser.PIPE => + BitwiseOr(left, right) + } + } + + /** + * Create a unary arithmetic expression. The following arithmetic operators are supported: + * - Plus: '+' + * - Minus: '-' + * - Bitwise Not: '~' + */ + override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { + val value = expression(ctx.valueExpression) + ctx.operator.getType match { + case SqlBaseParser.PLUS => + value + case SqlBaseParser.MINUS => + UnaryMinus(value) + case SqlBaseParser.TILDE => + BitwiseNot(value) + } + } + + /** + * Create a [[Cast]] expression. + */ + override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { + Cast(expression(ctx.expression), typedVisit(ctx.dataType)) + } + + /** + * Create a (windowed) Function expression. + */ + override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + // Create the function call. + val name = ctx.qualifiedName.getText + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + val arguments = ctx.expression().asScala.map(expression) match { + case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => + // Transform COUNT(*) into COUNT(1). Move this to analysis? + Seq(Literal(1)) + case expressions => + expressions + } + val function = UnresolvedFunction(name, arguments, isDistinct) + + // Check if the function is evaluated in a windowed context. + ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(function, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(function, visitWindowDef(spec)) + case _ => function + } + } + + /** + * Create a reference to a window frame, i.e. [[WindowSpecReference]]. + */ + override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { + WindowSpecReference(ctx.identifier.getText) + } + + /** + * Create a window definition, i.e. [[WindowSpecDefinition]]. + */ + override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { + // CLUSTER BY ... | PARTITION BY ... ORDER BY ... + val partition = ctx.partition.asScala.map(expression) + val order = ctx.sortItem.asScala.map(visitSortItem) + + // RANGE/ROWS BETWEEN ... 
+ val frameSpecOption = Option(ctx.windowFrame).map { frame => + val frameType = frame.frameType.getType match { + case SqlBaseParser.RANGE => RangeFrame + case SqlBaseParser.ROWS => RowFrame + } + + SpecifiedWindowFrame( + frameType, + visitFrameBound(frame.start), + Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) + } + + WindowSpecDefinition( + partition, + order, + frameSpecOption.getOrElse(UnspecifiedFrame)) + } + + /** + * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value + * Preceding/Following boundaries. These expressions must be constant (foldable) and return an + * integer value. + */ + override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { + // We currently only allow foldable integers. + def value: Int = { + val e = expression(ctx.expression) + assert(e.resolved && e.foldable && e.dataType == IntegerType, + "Frame bound value must be a constant integer.", + ctx) + e.eval().asInstanceOf[Int] + } + + // Create the FrameBoundary + ctx.boundType.getType match { + case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => + UnboundedPreceding + case SqlBaseParser.PRECEDING => + ValuePreceding(value) + case SqlBaseParser.CURRENT => + CurrentRow + case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => + UnboundedFollowing + case SqlBaseParser.FOLLOWING => + ValueFollowing(value) + } + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.expression.asScala.map(expression)) + } + + /** + * Create a [[ScalarSubquery]] expression. + */ + override def visitSubqueryExpression( + ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { + ScalarSubquery(plan(ctx.query)) + } + + /** + * Create a value based [[CaseWhen]] expression. This has the following SQL form: + * {{{ + * CASE [expression] + * WHEN [value] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + */ + override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + val branches = ctx.whenClause.asScala.map { wCtx => + (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: + * {{{ + * CASE + * WHEN [predicate] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + * + * @param ctx the parse tree + * */ + override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { + val branches = ctx.whenClause.asScala.map { wCtx => + (expression(wCtx.condition), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a dereference expression. The return type depends on the type of the parent, this can + * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an + * [[UnresolvedExtractValue]] if the parent is some expression. + */ + override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { + val attr = ctx.fieldName.getText + expression(ctx.base) match { + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ attr) + case e => + UnresolvedExtractValue(e, Literal(attr)) + } + } + + /** + * Create an [[UnresolvedAttribute]] expression. 
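For instance, the COUNT(*) rewrite in visitFunctionCall above can be observed directly on the unresolved expression (a sketch; equality on the unresolved nodes is assumed to behave as for ordinary case classes):

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser

// COUNT(*) is rewritten to COUNT(1) while building the unresolved call;
// COUNT(DISTINCT ...) is left alone because isDistinct is true.
val count = CatalystSqlParser.parseExpression("count(*)")
assert(count == UnresolvedFunction("count", Seq(Literal(1)), isDistinct = false))
```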
+ */ + override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { + UnresolvedAttribute.quoted(ctx.getText) + } + + /** + * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. + */ + override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { + UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) + } + + /** + * Create an expression for an expression between parentheses. This is need because the ANTLR + * visitor cannot automatically convert the nested context into an expression. + */ + override def visitParenthesizedExpression( + ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { + expression(ctx.expression) + } + + /** + * Create a [[SortOrder]] expression. + */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + if (ctx.DESC != null) { + SortOrder(expression(ctx.expression), Descending) + } else { + SortOrder(expression(ctx.expression), Ascending) + } + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date and Timestamp typed literals are supported. + * + * TODO what the added value of this over casting? + */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + ctx.identifier.getText.toUpperCase match { + case "DATE" => + Literal(Date.valueOf(value)) + case "TIMESTAMP" => + Literal(Timestamp.valueOf(value)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + + /** + * Create a NULL literal expression. + */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. + */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue()) + case v if v.isValidLong => + Literal(v.longValue()) + case v => Literal(v.underlying()) + } + } + + /** + * Create a double literal for a number denoted in scientifc notation. + */ + override def visitScientificDecimalLiteral( + ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(ctx.getText.toDouble) + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** Create a numeric literal expression. */ + private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) { + val raw = ctx.getText + try { + Literal(f(raw.substring(0, raw.length - 1))) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) { + _.toByte + } + + /** + * Create a Short Literal expression. 
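The narrowing rule for integral literals is easiest to see on three boundary values (a sketch using the new parseExpression entry point):

```scala
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.sql.types.{DecimalType, IntegerType, LongType}

// The narrowest type that can hold the value wins: Int, then Long, then
// an exact decimal for anything beyond Long.MaxValue.
assert(CatalystSqlParser.parseExpression("123").dataType == IntegerType)
assert(CatalystSqlParser.parseExpression("2147483648").dataType == LongType)
assert(CatalystSqlParser.parseExpression("9223372036854775808").dataType
  .isInstanceOf[DecimalType])
```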
+ */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { + _.toShort + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { + _.toLong + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { + _.toDouble + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + ctx.STRING().asScala.map(string).mkString + } + + /** + * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple + * unit value pairs, for instance: interval 2 months 2 days. + */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + val intervals = ctx.intervalField.asScala.map(visitIntervalField) + assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) + Literal(intervals.reduce(_.add(_))) + } + + /** + * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are + * supported: + * - Single unit. + * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported). + */ + override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { + import ctx._ + val s = value.getText + val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + case (u, None) if u.endsWith("s") => + // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... + CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) + case (u, None) => + CalendarInterval.fromSingleUnitString(u, s) + case ("year", Some("month")) => + CalendarInterval.fromYearMonthString(s) + case ("day", Some("second")) => + CalendarInterval.fromDayTimeString(s) + case (from, Some(t)) => + throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) + } + assert(interval != null, "No interval can be constructed", ctx) + interval + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + /** + * Resolve/create a primitive type. 
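Two behaviours from the literal handling above are worth illustrating; this is a sketch, and the CalendarInterval accessors used here are assumed:

```scala
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.unsafe.types.CalendarInterval

// Consecutive string literals are concatenated into a single value.
assert(CatalystSqlParser.parseExpression("'hello' 'world'") == Literal("helloworld"))

// Multiple unit/value pairs are folded into one CalendarInterval via add().
val Literal(interval: CalendarInterval, _) =
  CatalystSqlParser.parseExpression("interval 2 months 2 days")
assert(interval.months == 2)
```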
+ */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => TimestampType + case ("char" | "varchar" | "string", Nil) => StringType + case ("char" | "varchar", _ :: Nil) => StringType + case ("binary", Nil) => BinaryType + case ("decimal", Nil) => DecimalType.USER_DEFAULT + case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) + case ("decimal", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case (dt, params) => + throw new ParseException( + s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case SqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case SqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case SqlBaseParser.STRUCT => + createStructType(ctx.colTypeList()) + } + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType) + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + // Add the comment to the metadata. + val builder = new MetadataBuilder + if (STRING != null) { + builder.putString("comment", string(STRING)) + } + + StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build()) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala new file mode 100644 index 0000000000..c9a286374c --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala @@ -0,0 +1,240 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.parser.ParserInterface +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType + +/** + * Base SQL parsing infrastructure. + */ +abstract class AbstractSqlParser extends ParserInterface with Logging { + + /** Creates/Resolves DataType for a given SQL string. */ + def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => + // TODO add this to the parser interface. + astBuilder.visitSingleDataType(parser.singleDataType()) + } + + /** Creates Expression for a given SQL string. */ + override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser => + astBuilder.visitSingleExpression(parser.singleExpression()) + } + + /** Creates TableIdentifier for a given SQL string. */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) + } + + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + astBuilder.visitSingleStatement(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => nativeCommand(sqlText) + } + } + + /** Get the builder (visitor) which converts a ParseTree into a AST. */ + protected def astBuilder: AstBuilder + + /** Create a native command, or fail when this is not supported. */ + protected def nativeCommand(sqlText: String): LogicalPlan = { + val position = Origin(None, None) + throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) + } + + protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + logInfo(s"Parsing command: $command") + + val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) + + val tokenStream = new CommonTokenStream(lexer) + val parser = new SqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) + + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) + } + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.reset() // rewind input stream + parser.reset() + + // Try Again. + parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) + } + } + catch { + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) + } + } +} + +/** + * Concrete SQL parser for Catalyst-only SQL statements. 
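A usage sketch of the four entry points (identifiers and types are arbitrary; the struct expectation follows the pattern used in the updated DataTypeParserSuite further down):

```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser
import org.apache.spark.sql.types.{IntegerType, StringType, StructType}

// Every entry point funnels into parse(), which tries the faster SLL
// prediction mode first and re-parses in full LL mode only on failure.
val plan = CatalystSqlParser.parsePlan("SELECT 1")
val expr = CatalystSqlParser.parseExpression("1 + 1")

assert(CatalystSqlParser.parseTableIdentifier("db.tbl") ==
  TableIdentifier("tbl", Some("db")))
assert(CatalystSqlParser.parseDataType("struct<x: INT, y: STRING>") ==
  (new StructType).add("x", IntegerType).add("y", StringType))
```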
+ */ +object CatalystSqlParser extends AbstractSqlParser { + val astBuilder = new AstBuilder +} + +/** + * This string stream provides the lexer with upper case characters only. This greatly simplifies + * lexing the stream, while we can maintain the original command. + * + * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream + * + * The comment below (taken from the original class) describes the rationale for doing this: + * + * This class provides and implementation for a case insensitive token checker for the lexical + * analysis part of antlr. By converting the token stream into upper case at the time when lexical + * rules are checked, this class ensures that the lexical rules need to just match the token with + * upper case letters as opposed to combination of upper case and lower case characters. This is + * purely used for matching lexical rules. The actual token text is stored in the same way as the + * user input without actually converting it into an upper case. The token values are generated by + * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead + * function and is purely used for matching lexical rules. This also means that the grammar will + * only accept capitalized tokens in case it is run from other tools like antlrworks which do not + * have the ANTLRNoCaseStringStream implementation. + */ + +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { + override def LA(i: Int): Int = { + val la = super.LA(i) + if (la == 0 || la == IntStream.EOF) la + else Character.toUpperCase(la) + } +} + +/** + * The ParseErrorListener converts parse errors into AnalysisExceptions. + */ +case object ParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val position = Origin(Some(line), Some(charPositionInLine)) + throw new ParseException(None, msg, position, position) + } +} + +/** + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. + */ +class ParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(ParserUtils.command(ctx)), + message, + ParserUtils.position(ctx.getStart), + ParserUtils.position(ctx.getStop)) + } + + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin(Some(l), Some(p)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString + } + + def withCommand(cmd: String): ParseException = { + new ParseException(Option(cmd), message, start, stop) + } +} + +/** + * The post-processor validates & cleans-up the parse tree during the parse process. 
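The error reporting added here is best shown on a malformed statement; the expected position and caret below match the cases exercised by ErrorParserSuite further down:

```scala
import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException}

// Syntax errors surface as ParseException, carrying the (line, position) of
// the failure; once withCommand is applied, the message also embeds the SQL
// text with a caret marker, e.g.
//   == SQL ==
//   select from tbl
//   -------^^^
try CatalystSqlParser.parsePlan("select from tbl") catch {
  case e: ParseException =>
    assert(e.line == Some(1) && e.startPosition == Some(7))
    assert(e.getMessage.contains("== SQL =="))
}
```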
+ */ +case object PostProcessor extends SqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. */ + override def exitNonReserved(ctx: NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + parent.addChild(f(new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + SqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins))) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala new file mode 100644 index 0000000000..1fbfa763b4 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.TerminalNode + +import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} + +/** + * A collection of utility methods for use during the parsing process. + */ +object ParserUtils { + /** Get the command which created the token. */ + def command(ctx: ParserRuleContext): String = { + command(ctx.getStart.getInputStream) + } + + /** Get the command which created the token. */ + def command(stream: CharStream): String = { + stream.getText(Interval.of(0, stream.size())) + } + + /** Get the code that creates the given node. */ + def source(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) + } + + /** Get all the text which comes after the given rule. */ + def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop) + + /** Get all the text which comes after the given token. 
*/ + def remainder(token: Token): String = { + val stream = token.getInputStream + val interval = Interval.of(token.getStopIndex + 1, stream.size()) + stream.getText(interval) + } + + /** Convert a string token into a string. */ + def string(token: Token): String = unescapeSQLString(token.getText) + + /** Convert a string node into a string. */ + def string(node: TerminalNode): String = unescapeSQLString(node.getText) + + /** Get the origin (line and position) of the token. */ + def position(token: Token): Origin = { + Origin(Option(token.getLine), Option(token.getCharPositionInLine)) + } + + /** Assert if a condition holds. If it doesn't throw a parse exception. */ + def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { + if (!f) { + throw new ParseException(message, ctx) + } + } + + /** + * Register the origin of the context. Any TreeNode created in the closure will be assigned the + * registered origin. This method restores the previously set origin after completion of the + * closure. + */ + def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } + } + + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ + implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { + /** + * Create a plan using the block of code when the given context exists. Otherwise return the + * original plan. + */ + def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f + } else { + plan + } + } + + /** + * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the + * passed function. The original plan is returned when the context does not exist. 
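A small sketch of what withOrigin buys downstream consumers (the query text is arbitrary): every expression built while visiting a parse-tree node records that node's position.

```scala
import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser

// Each expression carries the Origin that was current when it was created,
// so the line/position of `a + 1` in the multi-line statement is preserved.
val plan = CatalystSqlParser.parsePlan("SELECT\n  a + 1\nFROM tbl")
plan.expressions.foreach { e =>
  println(s"line=${e.origin.line.getOrElse(-1)} expr=$e")
}
```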
+ */ + def optionalMap[C <: ParserRuleContext]( + ctx: C)( + f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f(ctx, plan) + } else { + plan + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala index c068e895b6..223485e292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala @@ -21,15 +21,20 @@ import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.unsafe.types.CalendarInterval class CatalystQlSuite extends PlanTest { val parser = new CatalystQl() + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + val star = UnresolvedAlias(UnresolvedStar(None)) test("test case insensitive") { - val result = Project(UnresolvedAlias(Literal(1)):: Nil, OneRowRelation) + val result = OneRowRelation.select(1) assert(result === parser.parsePlan("seLect 1")) assert(result === parser.parsePlan("select 1")) assert(result === parser.parsePlan("SELECT 1")) @@ -37,52 +42,31 @@ class CatalystQlSuite extends PlanTest { test("test NOT operator with comparison operations") { val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE") - val expected = Project( - UnresolvedAlias( - Not( - GreaterThan(Literal(true), Literal(true))) - ) :: Nil, - OneRowRelation) + val expected = OneRowRelation.select(Not(GreaterThan(true, true))) comparePlans(parsed, expected) } test("test Union Distinct operator") { - val parsed1 = parser.parsePlan("SELECT * FROM t0 UNION SELECT * FROM t1") - val parsed2 = parser.parsePlan("SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") - val expected = - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - SubqueryAlias("u_1", - Distinct( - Union( - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t0"), None)), - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t1"), None)))))) + val parsed1 = parser.parsePlan( + "SELECT * FROM t0 UNION SELECT * FROM t1") + val parsed2 = parser.parsePlan( + "SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") + val expected = Distinct(Union(table("t0").select(star), table("t1").select(star))) + .as("u_1").select(star) comparePlans(parsed1, expected) comparePlans(parsed2, expected) } test("test Union All operator") { val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1") - val expected = - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - SubqueryAlias("u_1", - Union( - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t0"), None)), - Project(UnresolvedAlias(UnresolvedStar(None)) :: Nil, - UnresolvedRelation(TableIdentifier("t1"), None))))) + val expected = Union(table("t0").select(star), table("t1").select(star)).as("u_1").select(star) comparePlans(parsed, expected) } test("support hive interval literal") { def checkInterval(sql: String, result: CalendarInterval): Unit = { val parsed = parser.parsePlan(sql) - val 
expected = Project( - UnresolvedAlias( - Literal(result) - ) :: Nil, - OneRowRelation) + val expected = OneRowRelation.select(Literal(result)) comparePlans(parsed, expected) } @@ -129,11 +113,7 @@ class CatalystQlSuite extends PlanTest { test("support scientific notation") { def assertRight(input: String, output: Double): Unit = { val parsed = parser.parsePlan("SELECT " + input) - val expected = Project( - UnresolvedAlias( - Literal(output) - ) :: Nil, - OneRowRelation) + val expected = OneRowRelation.select(Literal(output)) comparePlans(parsed, expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index 7d3608033b..d9bd33c50a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -18,19 +18,24 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException} import org.apache.spark.sql.types._ -class DataTypeParserSuite extends SparkFunSuite { +abstract class AbstractDataTypeParserSuite extends SparkFunSuite { + + def parse(sql: String): DataType def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = { test(s"parse ${dataTypeString.replace("\n", "")}") { - assert(DataTypeParser.parse(dataTypeString) === expectedDataType) + assert(parse(dataTypeString) === expectedDataType) } } + def intercept(sql: String) + def unsupported(dataTypeString: String): Unit = { test(s"$dataTypeString is not supported") { - intercept[DataTypeException](DataTypeParser.parse(dataTypeString)) + intercept(dataTypeString) } } @@ -97,13 +102,6 @@ class DataTypeParserSuite extends SparkFunSuite { StructField("arrAy", ArrayType(DoubleType, true), true) :: StructField("anotherArray", ArrayType(StringType, true), true) :: Nil) ) - // A column name can be a reserved word in our DDL parser and SqlParser. - checkDataType( - "Struct", - StructType( - StructField("TABLE", StringType, true) :: - StructField("CASE", BooleanType, true) :: Nil) - ) // Use backticks to quote column names having special characters. checkDataType( "struct<`x+y`:int, `!@#$%^&*()`:string, `1_2.345<>:\"`:varchar(20)>", @@ -118,6 +116,43 @@ class DataTypeParserSuite extends SparkFunSuite { unsupported("it is not a data type") unsupported("struct") unsupported("struct", + StructType( + StructField("TABLE", StringType, true) :: + StructField("CASE", BooleanType, true) :: Nil) + ) + unsupported("struct") + unsupported("struct<`x``y` int>") } + +class CatalystQlDataTypeParserSuite extends AbstractDataTypeParserSuite { + override def intercept(sql: String): Unit = + intercept[ParseException](CatalystSqlParser.parseDataType(sql)) + + override def parse(sql: String): DataType = + CatalystSqlParser.parseDataType(sql) + + // A column name can be a reserved word in our DDL parser and SqlParser. 
+ unsupported("Struct") + + checkDataType( + "struct", + (new StructType).add("x", IntegerType).add("y", StringType)) + + checkDataType( + "struct<`x``y` int>", + (new StructType).add("x`y", IntegerType)) +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala new file mode 100644 index 0000000000..1963fc368f --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.SparkFunSuite + +/** + * Test various parser errors. + */ +class ErrorParserSuite extends SparkFunSuite { + def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = { + val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) + + // Check position. + assert(e.line.isDefined) + assert(e.line.get === line) + assert(e.startPosition.isDefined) + assert(e.startPosition.get === startPosition) + + // Check messages. 
+ val error = e.getMessage + messages.foreach { message => + assert(error.contains(message)) + } + } + + test("no viable input") { + intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") + intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") + intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") + } + + test("extraneous input") { + intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^") + intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^") + } + + test("mismatched input") { + intercept("select * from r order by q from t", 1, 27, + "mismatched input", + "---------------------------^^^") + intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^") + } + + test("semantic errors") { + intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", + "^^^") + intercept("select * from r where a in (select * from t)", 1, 24, + "IN with a Sub-query is currently not supported", + "------------------------^^^") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala new file mode 100644 index 0000000000..32311a5a66 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Test basic expression parsing. If a type of expression is supported it should be tested here. + * + * Please note that some of the expressions test don't have to be sound expressions, only their + * structure needs to be valid. Unsound expressions should be caught by the Analyzer or + * CheckAnalysis classes. 
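+ * For example, `parseExpression("a + 1")` only has to produce the expected unresolved tree;
+ * whether `a` actually resolves to anything is left to the Analyzer.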
+ */ +class ExpressionParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, e: Expression): Unit = { + compareExpressions(parseExpression(sqlCommand), e) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parseExpression(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("star expressions") { + // Global Star + assertEqual("*", UnresolvedStar(None)) + + // Targeted Star + assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) + } + + // NamedExpression (Alias/Multialias) + test("named expressions") { + // No Alias + val r0 = 'a + assertEqual("a", r0) + + // Single Alias. + val r1 = 'a as "b" + assertEqual("a as b", r1) + assertEqual("a b", r1) + + // Multi-Alias + assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) + assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) + + // Numeric literals without a space between the literal qualifier and the alias, should not be + // interpreted as such. An unresolved reference should be returned instead. + // TODO add the JIRA-ticket number. + assertEqual("1SL", Symbol("1SL")) + + // Aliased star is allowed. + assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) + } + + test("binary logical expressions") { + // And + assertEqual("a and b", 'a && 'b) + + // Or + assertEqual("a or b", 'a || 'b) + + // Combination And/Or check precedence + assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) + assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) + + // Multiple AND/OR get converted into a balanced tree + assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) + assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) + } + + test("long binary logical expressions") { + def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { + val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) + val e = parseExpression(sql) + assert(e.collect { case _: EqualTo => true }.size === 1000) + assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) + } + testVeryBinaryExpression(" AND ", classOf[And]) + testVeryBinaryExpression(" OR ", classOf[Or]) + } + + test("not expressions") { + assertEqual("not a", !'a) + assertEqual("!a", !'a) + assertEqual("not true > true", Not(GreaterThan(true, true))) + } + + test("exists expression") { + intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + } + + test("comparison expressions") { + assertEqual("a = b", 'a === 'b) + assertEqual("a == b", 'a === 'b) + assertEqual("a <=> b", 'a <=> 'b) + assertEqual("a <> b", 'a =!= 'b) + assertEqual("a != b", 'a =!= 'b) + assertEqual("a < b", 'a < 'b) + assertEqual("a <= b", 'a <= 'b) + assertEqual("a > b", 'a > 'b) + assertEqual("a >= b", 'a >= 'b) + } + + test("between expressions") { + assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) + assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) + } + + test("in expressions") { + assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) + assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) + } + + test("in sub-query") { + intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") + } + + test("like expressions") { + assertEqual("a like 'pattern%'", 'a like "pattern%") + assertEqual("a not 
like 'pattern%'", !('a like "pattern%")) + assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") + assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) + assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") + assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) + } + + test("is null expressions") { + assertEqual("a is null", 'a.isNull) + assertEqual("a is not null", 'a.isNotNull) + assertEqual("a = b is null", ('a === 'b).isNull) + assertEqual("a = b is not null", ('a === 'b).isNotNull) + } + + test("binary arithmetic expressions") { + // Simple operations + assertEqual("a * b", 'a * 'b) + assertEqual("a / b", 'a / 'b) + assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a % b", 'a % 'b) + assertEqual("a + b", 'a + 'b) + assertEqual("a - b", 'a - 'b) + assertEqual("a & b", 'a & 'b) + assertEqual("a ^ b", 'a ^ 'b) + assertEqual("a | b", 'a | 'b) + + // Check precedences + assertEqual( + "a * t | b ^ c & d - e + f % g DIV h / i * k", + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + } + + test("unary arithmetic expressions") { + assertEqual("+a", 'a) + assertEqual("-a", -'a) + assertEqual("~a", ~'a) + assertEqual("-+~~a", -(~(~'a))) + } + + test("cast expressions") { + // Note that DataType parsing is tested elsewhere. + assertEqual("cast(a as int)", 'a.cast(IntegerType)) + assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) + assertEqual("cast(a as array)", 'a.cast(ArrayType(IntegerType))) + assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) + } + + test("function expressions") { + assertEqual("foo()", 'foo.function()) + assertEqual("foo.bar()", Symbol("foo.bar").function()) + assertEqual("foo(*)", 'foo.function(star())) + assertEqual("count(*)", 'count.function(1)) + assertEqual("foo(a, b)", 'foo.function('a, 'b)) + assertEqual("foo(all a, b)", 'foo.function('a, 'b)) + assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) + assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) + assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + } + + test("window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } + + // Basic window testing. + assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) + assertEqual("foo(*) over ()", windowed()) + assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) + assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) + + // Test use of expressions in window functions. 
+ assertEqual( + "sum(product + 1) over (partition by ((product) + (1)) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + assertEqual( + "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + + // Range/Row + val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) + val boundaries = Seq( + ("10 preceding", ValuePreceding(10), CurrentRow), + ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing), + ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), + ("between current row and 5 following", CurrentRow, ValueFollowing(5)), + ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ) + frameTypes.foreach { + case (frameTypeSql, frameType) => + boundaries.foreach { + case (boundarySql, begin, end) => + val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" + val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) + assertEqual(query, expr) + } + } + + // We cannot use non integer constants. + intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", + "Frame bound value must be a constant integer.") + + // We cannot use an arbitrary expression. + intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", + "Frame bound value must be a constant integer.") + } + + test("row constructor") { + // Note that '(a)' will be interpreted as a nested expression. + assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) + assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + } + + test("scalar sub-query") { + assertEqual( + "(select max(val) from tbl) > current", + ScalarSubquery(table("tbl").select('max.function('val))) > 'current) + assertEqual( + "a = (select b from s)", + 'a === ScalarSubquery(table("s").select('b))) + } + + test("case when") { + assertEqual("case a when 1 then b when 2 then c else d end", + CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case when a = 1 then b when a = 2 then c else d end", + CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) + } + + test("dereference") { + assertEqual("a.b", UnresolvedAttribute("a.b")) + assertEqual("`select`.b", UnresolvedAttribute("select.b")) + assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. + assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + } + + test("reference") { + // Regular + assertEqual("a", 'a) + + // Starting with a digit. + assertEqual("1a", Symbol("1a")) + + // Quoted using a keyword. + assertEqual("`select`", 'select) + + // Unquoted using an unreserved keyword. + assertEqual("columns", 'columns) + } + + test("subscript") { + assertEqual("a[b]", 'a.getItem('b)) + assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) + assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) + } + + test("parenthesis") { + assertEqual("(a)", 'a) + assertEqual("r * (a + b)", 'r * ('a + 'b)) + } + + test("type constructors") { + // Dates. 
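+    // Type-constructor keywords are case insensitive, hence the mixed-case `dAte` and
+    // `tImEstAmp` below.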
+ assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) + intercept[IllegalArgumentException] { + parseExpression("DAtE 'mar 11 2016'") + } + + // Timestamps. + assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", + Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) + intercept[IllegalArgumentException] { + parseExpression("timestamP '2016-33-11 20:54:00.000'") + } + + // Unsupported datatype. + intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") + } + + test("literals") { + // NULL + assertEqual("null", Literal(null)) + + // Boolean + assertEqual("trUe", Literal(true)) + assertEqual("False", Literal(false)) + + // Integral should have the narrowest possible type + assertEqual("787324", Literal(787324)) + assertEqual("7873247234798249234", Literal(7873247234798249234L)) + assertEqual("78732472347982492793712334", + Literal(BigDecimal("78732472347982492793712334").underlying())) + + // Decimal + assertEqual("7873247234798249279371.2334", + Literal(BigDecimal("7873247234798249279371.2334").underlying())) + + // Scientific Decimal + assertEqual("9.0e1", 90d) + assertEqual(".9e+2", 90d) + assertEqual("0.9e+2", 90d) + assertEqual("900e-1", 90d) + assertEqual("900.0E-1", 90d) + assertEqual("9.e+1", 90d) + intercept(".e3") + + // Tiny Int Literal + assertEqual("10Y", Literal(10.toByte)) + intercept("-1000Y") + + // Small Int Literal + assertEqual("10S", Literal(10.toShort)) + intercept("40000S") + + // Long Int Literal + assertEqual("10L", Literal(10L)) + intercept("78732472347982492793712334L") + + // Double Literal + assertEqual("10.0D", Literal(10.0D)) + // TODO we need to figure out if we should throw an exception here! + assertEqual("1E309", Literal(Double.PositiveInfinity)) + } + + test("strings") { + // Single Strings. + assertEqual("\"hello\"", "hello") + assertEqual("'hello'", "hello") + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld") + assertEqual("'hello' \" \" 'world'", "hello world") + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%") + assertEqual("'no-pattern\\%'", "no-pattern\\%") + assertEqual("'pattern\\\\%'", "pattern\\%") + assertEqual("'pattern\\\\\\%'", "pattern\\\\%") + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') + assertEqual("'\\''", "\'") // Single quote + assertEqual("'\\\"'", "\"") // Double quote + assertEqual("'\\b'", "\b") // Backspace + assertEqual("'\\n'", "\n") // Newline + assertEqual("'\\r'", "\r") // Carriage return + assertEqual("'\\t'", "\t") // Tab character + assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") + + // Unicode + assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + } + + test("intervals") { + def intervalLiteral(u: String, s: String): Literal = { + Literal(CalendarInterval.fromSingleUnitString(u, s)) + } + + // Empty interval statement + intercept("interval", "at least one time unit should be given for interval literal") + + // Single Intervals. 
+ val units = Seq( + "year", + "month", + "week", + "day", + "hour", + "minute", + "second", + "millisecond", + "microsecond") + val forms = Seq("", "s") + val values = Seq("0", "10", "-7", "21") + units.foreach { unit => + forms.foreach { form => + values.foreach { value => + val expected = intervalLiteral(unit, value) + assertEqual(s"interval $value $unit$form", expected) + assertEqual(s"interval '$value' $unit$form", expected) + } + } + } + + // Hive nanosecond notation. + assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) + assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) + + // Non Existing unit + intercept("interval 10 nanoseconds", "No interval can be constructed") + + // Year-Month intervals. + val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") + yearMonthValues.foreach { value => + val result = Literal(CalendarInterval.fromYearMonthString(value)) + assertEqual(s"interval '$value' year to month", result) + } + + // Day-Time intervals. + val datTimeValues = Seq( + "99 11:22:33.123456789", + "-99 11:22:33.123456789", + "10 9:8:7.123456789", + "1 0:0:0", + "-1 0:0:0", + "1 0:0:1") + datTimeValues.foreach { value => + val result = Literal(CalendarInterval.fromDayTimeString(value)) + assertEqual(s"interval '$value' day to second", result) + } + + // Unknown FROM TO intervals + intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") + + // Composed intervals. + assertEqual( + "interval 3 months 22 seconds 1 millisecond", + Literal(new CalendarInterval(3, 22001000L))) + assertEqual( + "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", + Literal(new CalendarInterval(14, + 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) + } + + test("composed expressions") { + assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) + assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) + intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala new file mode 100644 index 0000000000..4206d22ca7 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class PlanParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + comparePlans(parsePlan(sqlCommand), plan) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("case insensitive") { + val plan = table("a").select(star()) + assertEqual("sELEct * FroM a", plan) + assertEqual("select * fRoM a", plan) + assertEqual("SELECT * FROM a", plan) + } + + test("show functions") { + assertEqual("show functions", ShowFunctions(None, None)) + assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) + assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) + assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) + intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") + } + + test("describe function") { + assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) + assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) + assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) + assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) + } + + test("set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + + assertEqual("select * from a union select * from b", Distinct(a.union(b))) + assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) + assertEqual("select * from a union all select * from b", a.union(b)) + assertEqual("select * from a except select * from b", a.except(b)) + intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") + assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a intersect select * from b", a.intersect(b)) + intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") + assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + } + + test("common table expressions") { + def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { + val ctes = namedPlans.map { + case (name, cte) => + name -> SubqueryAlias(name, cte) + }.toMap + With(plan, ctes) + } + assertEqual( + "with cte1 as (select * from a) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) + assertEqual( + "with cte1 (select 1) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) + assertEqual( + "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", + cte(table("cte2").select(star()), + "cte1" -> OneRowRelation.select(1), + "cte2" -> table("cte1").select(star()))) + intercept( + "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", + "Name 'cte1' is used for multiple common table expressions") + } + + test("simple select query") { + 
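+    // Each SQL string should parse to the plan built with the Catalyst DSL on the right-hand side.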
assertEqual("select 1", OneRowRelation.select(1)) + assertEqual("select a, b", OneRowRelation.select('a, 'b)) + assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) + assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) + assertEqual( + "select a, b from db.c having x < 1", + table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) + assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) + } + + test("reverse select query") { + assertEqual("from a", table("a")) + assertEqual("from a select b, c", table("a").select('b, 'c)) + assertEqual( + "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) + assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) + assertEqual( + "from (from a union all from b) c select *", + table("a").union(table("b")).as("c").select(star())) + } + + test("transform query spec") { + val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) + assertEqual("select transform(a, b) using 'func' from e where f < 10", + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + assertEqual("map a, b using 'func' as c, d from e", + p.copy(output = Seq('c.string, 'd.string))) + assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("multi select query") { + assertEqual( + "from a select * select * where s < 10", + table("a").select(star()).union(table("a").where('s < 10).select(star()))) + intercept( + "from a select * select * from x where a.s < 10", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + assertEqual( + "from a insert into tbl1 select * insert into tbl2 select * where s < 10", + table("a").select(star()).insertInto("tbl1").union( + table("a").where('s < 10).select(star()).insertInto("tbl2"))) + } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = "select * from t" + val basePlan = table("t").select(star()) + + val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) + val limitWindowClauses = Seq( + ("", (p: LogicalPlan) => p), + (" limit 10", (p: LogicalPlan) => p.limit(10)), + (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), + (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) + ) + + val orderSortDistrClusterClauses = Seq( + ("", basePlan), + (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), + (" distribute by a, b", basePlan.distribute('a, 'b)), + (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), + (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) + ) + + orderSortDistrClusterClauses.foreach { + case (s1, p1) => + limitWindowClauses.foreach { + case (s2, pf2) => + assertEqual(baseSql + s1 + s2, pf2(p1)) + } + } + + val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" + intercept(s"$baseSql order by a sort by a", msg) + intercept(s"$baseSql cluster by a distribute by a", msg) + intercept(s"$baseSql order by a cluster by a", msg) + intercept(s"$baseSql order by a distribute by a", msg) + } + + test("insert into") { + val sql = "select * from t" 
+ val plan = table("t").select(star()) + def insert( + partition: Map[String, Option[String]], + overwrite: Boolean = false, + ifNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + + // Single inserts + assertEqual(s"insert overwrite table s $sql", + insert(Map.empty, overwrite = true)) + assertEqual(s"insert overwrite table s if not exists $sql", + insert(Map.empty, overwrite = true, ifNotExists = true)) + assertEqual(s"insert into s $sql", + insert(Map.empty)) + assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", + insert(Map("c" -> Option("d"), "e" -> Option("1")))) + assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", + insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) + + // Multi insert + val plan2 = table("t").where('x > 5).select(star()) + assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", + InsertIntoTable( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + InsertIntoTable( + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + } + + test("aggregation") { + val sql = "select a, b, sum(c) as c from d group by a, b" + + // Normal + assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) + + // Cube + assertEqual(s"$sql with cube", + table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Rollup + assertEqual(s"$sql with rollup", + table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Grouping Sets + assertEqual(s"$sql grouping sets((a, b), (a), ())", + GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + intercept(s"$sql grouping sets((a, b), (c), ())", + "c doesn't show up in the GROUP BY list") + } + + test("limit") { + val sql = "select * from t" + val plan = table("t").select(star()) + assertEqual(s"$sql limit 10", plan.limit(10)) + assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) + } + + test("window spec") { + // Note that WindowSpecs are testing in the ExpressionParserSuite + val sql = "select * from t" + val plan = table("t").select(star()) + val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), + SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + + // Test window resolution. + val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) + assertEqual( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w1""".stripMargin, + WithWindowDefinition(ws1, plan)) + + // Fail with no reference. + intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") + + // Fail when resolved reference is not a window spec. 
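+    // `w2` is defined only as a reference to `w1`, so resolving `w3 as w2` is rejected.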
+ intercept( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w2""".stripMargin, + "Window reference 'w2' is not a window specification" + ) + } + + test("lateral view") { + // Single lateral view + assertEqual( + "select * from t lateral view explode(x) expl as x", + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + .select(star())) + + // Multiple lateral views + assertEqual( + """select * + |from t + |lateral view explode(x) expl + |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) + .select(star())) + + // Multi-Insert lateral views. + val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + assertEqual( + """from t1 + |lateral view explode(x) expl as x + |insert into t2 + |select * + |lateral view json_tuple(x, y) jtup q, z + |insert into t3 + |select * + |where s < 10 + """.stripMargin, + Union(from + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) + .select(star()) + .insertInto("t2"), + from.where('s < 10).select(star()).insertInto("t3"))) + + // Unsupported generator. + intercept( + "select * from t lateral view posexplode(x) posexpl as x, y", + "Generator function 'posexplode' is not supported") + } + + test("joins") { + // Test single joins. + val testUnconditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t as tt $sql u", + table("t").as("tt").join(table("u"), jt, None).select(star())) + } + val testConditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u as uu on a = b", + table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) + } + val testNaturalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t tt natural $sql u as uu", + table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) + } + val testUsingJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u using(a, b)", + table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) + } + val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) + + def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { + tests.foreach(_(sql, jt)) + } + test("cross join", Inner, Seq(testUnconditionalJoin)) + test(",", Inner, Seq(testUnconditionalJoin)) + test("join", Inner, testAll) + test("inner join", Inner, testAll) + test("left join", LeftOuter, testAll) + test("left outer join", LeftOuter, testAll) + test("right join", RightOuter, testAll) + test("right outer join", RightOuter, testAll) + test("full join", FullOuter, testAll) + test("full outer join", FullOuter, testAll) + + // Test multiple consecutive joins + assertEqual( + "select * from a join b join c right join d", + table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + } + + test("sampled relations") { + val sql = "select * from t" + assertEqual(s"$sql tablesample(100 rows)", + table("t").limit(100).select(star())) + assertEqual(s"$sql tablesample(43 percent) as x", + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + assertEqual(s"$sql tablesample(bucket 4 out of 10) 
as x", + Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", + "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") + intercept(s"$sql tablesample(bucket 11 out of 10) as x", + s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") + } + + test("sub-query") { + val plan = table("t0").select('id) + assertEqual("select id from (t0)", plan) + assertEqual("select id from ((((((t0))))))", plan) + assertEqual( + "(select * from t1) union distinct (select * from t2)", + Distinct(table("t1").select(star()).union(table("t2").select(star())))) + assertEqual( + "select * from ((select * from t1) union (select * from t2)) t", + Distinct( + table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) + assertEqual( + """select id + |from (((select id from t0) + | union all + | (select id from t0)) + | union all + | (select id from t0)) as u_1 + """.stripMargin, + plan.union(plan).union(plan).as("u_1").select('id)) + } + + test("scalar sub-query") { + assertEqual( + "select (select max(b) from s) ss from t", + table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) + assertEqual( + "select * from t where a = (select b from s)", + table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) + assertEqual( + "select g from t group by g having a > (select b from s)", + table("t") + .groupBy('g)('g) + .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + } + + test("table reference") { + assertEqual("table t", table("t")) + assertEqual("table d.t", table("d", "t")) + } + + test("inline table") { + assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( + Seq('col1.int), + Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual( + "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", + LocalRelation.fromExternalRows( + Seq('a.int, 'b.string), + Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) + intercept("values (a, 'a'), (b, 'b')", + "All expressions in an inline table must be constants.") + intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", + "Number of aliases must match the number of fields in an inline table.") + intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala new file mode 100644 index 0000000000..0874322187 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser.ng + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier + +class TableIdentifierParserSuite extends SparkFunSuite { + import CatalystSqlParser._ + + test("table identifier") { + // Regular names. + assert(TableIdentifier("q") === parseTableIdentifier("q")) + assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) + + // Illegal names. + intercept[ParseException](parseTableIdentifier("")) + intercept[ParseException](parseTableIdentifier("d.q.g")) + + // SQL Keywords. + val keywords = Seq("select", "from", "where", "left", "right") + keywords.foreach { keyword => + intercept[ParseException](parseTableIdentifier(keyword)) + assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) + assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 0541844e0b..aa5d4330d3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} import org.apache.spark.sql.catalyst.util._ /** @@ -32,6 +32,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { */ protected def normalizeExprIds(plan: LogicalPlan) = { plan transformAllExpressions { + case s: ScalarSubquery => + ScalarSubquery(s.query, ExprId(0)) case a: AttributeReference => AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => @@ -40,21 +42,25 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { } /** - * Normalizes the filter conditions that appear in the plan. For instance, - * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) - * etc., will all now be equivalent. + * Normalizes plans: + * - Filter the filter conditions that appear in a plan. For instance, + * ((expr 1 && expr 2) && expr 3), (expr 1 && expr 2 && expr 3), (expr 3 && (expr 1 && expr 2) + * etc., will all now be equivalent. + * - Sample the seed will replaced by 0L. 
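+ *   (i.e. the seed of every Sample operator is normalized to 0L so that otherwise identical
+ *   sampled plans compare equal).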
*/ - private def normalizeFilters(plan: LogicalPlan) = { + private def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => Filter(splitConjunctivePredicates(condition).sortBy(_.hashCode()).reduce(And), child) + case sample: Sample => + sample.copy(seed = 0L)(true) } } /** Fails the test if the two plans do not match */ protected def comparePlans(plan1: LogicalPlan, plan2: LogicalPlan) { - val normalized1 = normalizeFilters(normalizeExprIds(plan1)) - val normalized2 = normalizeFilters(normalizeExprIds(plan2)) + val normalized1 = normalizePlan(normalizeExprIds(plan1)) + val normalized2 = normalizePlan(normalizeExprIds(plan2)) if (normalized1 != normalized2) { fail( s""" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala new file mode 100644 index 0000000000..c098fa99c2 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -0,0 +1,219 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.execution + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder} +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} +import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} +import org.apache.spark.sql.execution.datasources._ + +/** + * Concrete parser for Spark SQL statements. + */ +object SparkSqlParser extends AbstractSqlParser{ + val astBuilder = new SparkSqlAstBuilder +} + +/** + * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. + */ +class SparkSqlAstBuilder extends AstBuilder { + import org.apache.spark.sql.catalyst.parser.ng.ParserUtils._ + + /** + * Create a [[SetCommand]] logical plan. + * + * Note that we assume that everything after the SET keyword is assumed to be a part of the + * key-value pair. The split between key and value is made by searching for the first `=` + * character in the raw string. + */ + override def visitSetConfiguration(ctx: SetConfigurationContext): LogicalPlan = withOrigin(ctx) { + // Construct the command. 
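+    // e.g. `SET spark.sql.shuffle.partitions=10` splits on the first '=' into the key and the
+    // value "10"; `SET foo` yields a key with no value, and a bare `SET` yields SetCommand(None).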
+ val raw = remainder(ctx.SET.getSymbol) + val keyValueSeparatorIndex = raw.indexOf('=') + if (keyValueSeparatorIndex >= 0) { + val key = raw.substring(0, keyValueSeparatorIndex).trim + val value = raw.substring(keyValueSeparatorIndex + 1).trim + SetCommand(Some(key -> Option(value))) + } else if (raw.nonEmpty) { + SetCommand(Some(raw.trim -> None)) + } else { + SetCommand(None) + } + } + + /** + * Create a [[SetDatabaseCommand]] logical plan. + */ + override def visitUse(ctx: UseContext): LogicalPlan = withOrigin(ctx) { + SetDatabaseCommand(ctx.db.getText) + } + + /** + * Create a [[ShowTablesCommand]] logical plan. + */ + override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { + if (ctx.LIKE != null) { + logWarning("SHOW TABLES LIKE option is ignored.") + } + ShowTablesCommand(Option(ctx.db).map(_.getText)) + } + + /** + * Create a [[RefreshTable]] logical plan. + */ + override def visitRefreshTable(ctx: RefreshTableContext): LogicalPlan = withOrigin(ctx) { + RefreshTable(visitTableIdentifier(ctx.tableIdentifier)) + } + + /** + * Create a [[CacheTableCommand]] logical plan. + */ + override def visitCacheTable(ctx: CacheTableContext): LogicalPlan = withOrigin(ctx) { + val query = Option(ctx.query).map(plan) + CacheTableCommand(ctx.identifier.getText, query, ctx.LAZY != null) + } + + /** + * Create an [[UncacheTableCommand]] logical plan. + */ + override def visitUncacheTable(ctx: UncacheTableContext): LogicalPlan = withOrigin(ctx) { + UncacheTableCommand(ctx.identifier.getText) + } + + /** + * Create a [[ClearCacheCommand]] logical plan. + */ + override def visitClearCache(ctx: ClearCacheContext): LogicalPlan = withOrigin(ctx) { + ClearCacheCommand + } + + /** + * Create an [[ExplainCommand]] logical plan. + */ + override def visitExplain(ctx: ExplainContext): LogicalPlan = withOrigin(ctx) { + val options = ctx.explainOption.asScala + if (options.exists(_.FORMATTED != null)) { + logWarning("EXPLAIN FORMATTED option is ignored.") + } + if (options.exists(_.LOGICAL != null)) { + logWarning("EXPLAIN LOGICAL option is ignored.") + } + + // Create the explain comment. + val statement = plan(ctx.statement) + if (isExplainableStatement(statement)) { + ExplainCommand(statement, extended = options.exists(_.EXTENDED != null)) + } else { + ExplainCommand(OneRowRelation) + } + } + + /** + * Determine if a plan should be explained at all. + */ + protected def isExplainableStatement(plan: LogicalPlan): Boolean = plan match { + case _: datasources.DescribeCommand => false + case _ => true + } + + /** + * Create a [[DescribeCommand]] logical plan. + */ + override def visitDescribeTable(ctx: DescribeTableContext): LogicalPlan = withOrigin(ctx) { + // FORMATTED and columns are not supported. Return null and let the parser decide what to do + // with this (create an exception or pass it on to a different system). + if (ctx.describeColName != null || ctx.FORMATTED != null || ctx.partitionSpec != null) { + null + } else { + datasources.DescribeCommand( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXTENDED != null) + } + } + + /** Type to keep track of a table header. */ + type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. 
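+   * The returned [[TableHeader]] is a (TableIdentifier, isTemporary, ifNotExists, isExternal) tuple.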
+ */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + assert(!temporary || !ifNotExists, + "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.", + ctx) + (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Create a [[CreateTableUsing]] or a [[CreateTableUsingAsSelect]] logical plan. + * + * TODO add bucketing and partitioning. + */ + override def visitCreateTableUsing(ctx: CreateTableUsingContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + if (external) { + logWarning("EXTERNAL option is not supported.") + } + val options = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val provider = ctx.tableProvider.qualifiedName.getText + + if (ctx.query != null) { + // Get the backing query. + val query = plan(ctx.query) + + // Determine the storage mode. + val mode = if (ifNotExists) { + SaveMode.Ignore + } else if (temp) { + SaveMode.Overwrite + } else { + SaveMode.ErrorIfExists + } + CreateTableUsingAsSelect(table, provider, temp, Array.empty, None, mode, options, query) + } else { + val struct = Option(ctx.colTypeList).map(createStructType) + CreateTableUsing(table, struct, provider, temp, options, ifNotExists, managedIfNoPath = false) + } + } + + /** + * Convert a table property list into a key-value map. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + ctx.tableProperty.asScala.map { property => + // A key can either be a String or a collection of dot separated elements. We need to treat + // these differently. 
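+      // e.g. a key may be written as the quoted string 'some.key' or as the dotted identifier
+      // some.key; the first branch unescapes the quoted form, the second keeps the raw text.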
+ val key = if (property.key.STRING != null) { + string(property.key.STRING) + } else { + property.key.getText + } + val value = Option(property.value).map(string).orNull + key -> value + }.toMap + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8abb9d7e4a..7ce15e3f35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -27,8 +27,8 @@ import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ -import org.apache.spark.sql.catalyst.parser.CatalystQl import org.apache.spark.sql.catalyst.plans.logical.BroadcastHint +import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -1172,8 +1172,7 @@ object functions { * @group normal_funcs */ def expr(expr: String): Column = { - val parser = SQLContext.getActive().map(_.sessionState.sqlParser).getOrElse(new CatalystQl()) - Column(parser.parseExpression(expr)) + Column(SparkSqlParser.parseExpression(expr)) } ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index e5f02caabc..9bc640763f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -81,7 +81,7 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Parser that extracts expressions, plans, table identifiers etc. from SQL texts. */ - lazy val sqlParser: ParserInterface = new SparkQl(conf) + lazy val sqlParser: ParserInterface = SparkSqlParser /** * Planner that converts optimized logical plans to physical plans. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 5af1a4fcd7..a5a4ff13de 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -329,8 +329,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { } test("full outer join") { - upperCaseData.where('N <= 4).registerTempTable("left") - upperCaseData.where('N >= 3).registerTempTable("right") + upperCaseData.where('N <= 4).registerTempTable("`left`") + upperCaseData.where('N >= 3).registerTempTable("`right`") val left = UnresolvedRelation(TableIdentifier("left"), None) val right = UnresolvedRelation(TableIdentifier("right"), None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index c958eac266..b727e88668 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1656,7 +1656,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e2 = intercept[AnalysisException] { sql("select interval 23 nanosecond") } - assert(e2.message.contains("cannot recognize input near")) + assert(e2.message.contains("No interval can be constructed")) } test("SPARK-8945: add and subtract expressions for interval type") { -- cgit v1.2.3 From 7007f72ba7f0ccdc1d11315757a80f55a93451df Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Mon, 28 Mar 2016 13:50:42 -0700 Subject: [SPARK-13713][SQL][TEST-MAVEN] Add Antlr4 maven plugin. Seems https://github.com/apache/spark/commit/600c0b69cab4767e8e5a6f4284777d8b9d4bd40e is missing the antlr4 maven plugin. This pr adds it. Author: Yin Huai Closes #12010 from yhuai/mavenAntlr4. --- pom.xml | 5 +++++ sql/catalyst/pom.xml | 15 +++++++++++++++ 2 files changed, 20 insertions(+) (limited to 'sql/catalyst') diff --git a/pom.xml b/pom.xml index 475f0544bd..1513a18b71 100644 --- a/pom.xml +++ b/pom.xml @@ -1891,6 +1891,11 @@ antlr3-maven-plugin 3.5.2 + + org.antlr + antlr4-maven-plugin + ${antlr4.version} + org.apache.maven.plugins diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index c834a011f1..9bfe495e90 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -133,6 +133,21 @@ + + org.antlr + antlr4-maven-plugin + + + + antlr4 + + + + + true + ../catalyst/src/main/antlr4 + + -- cgit v1.2.3 From eebc8c1c95fb7752d09a5846b7cac65f7702c8f2 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 28 Mar 2016 16:25:15 -0700 Subject: [SPARK-13923][SPARK-14014][SQL] Session catalog follow-ups ## What changes were proposed in this pull request? This patch addresses the remaining comments left in #11750 and #11918 after they are merged. For a full list of changes in this patch, just trace the commits. ## How was this patch tested? `SessionCatalogSuite` and `CatalogTestCases` Author: Andrew Or Closes #12006 from andrewor14/session-catalog-followup. 
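
As a rough sketch of the resulting `SessionCatalog` API (illustrative only, not code taken from this patch; it follows the signatures introduced below):

```scala
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation

val catalog = new SessionCatalog(new InMemoryCatalog)

// Temporary tables are session-local; `overrideIfExists` (renamed from
// `ignoreIfExists`) controls whether an existing temp table is replaced.
catalog.createTempTable("tmp", OneRowRelation, overrideIfExists = true)

// Without a database qualifier, the temp table shadows any metastore table
// with the same name.
val plan = catalog.lookupRelation(TableIdentifier("tmp"), alias = None)

catalog.dropTable(TableIdentifier("tmp"), ignoreIfNotExists = false)
```

The lookup path checks the session's temp tables first and only consults the external catalog when the name is database-qualified or no temp table matches, which is also why `dropTable` on an unqualified name removes the temp table rather than the metastore one.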
--- .../sql/catalyst/catalog/InMemoryCatalog.scala | 18 +- .../sql/catalyst/catalog/SessionCatalog.scala | 74 ++--- .../spark/sql/catalyst/catalog/interface.scala | 14 +- .../spark/sql/catalyst/analysis/AnalysisTest.scala | 2 +- .../catalyst/analysis/DecimalPrecisionSuite.scala | 2 +- .../sql/catalyst/catalog/CatalogTestCases.scala | 6 +- .../sql/catalyst/catalog/SessionCatalogSuite.scala | 30 +-- .../scala/org/apache/spark/sql/SQLContext.scala | 2 +- .../spark/sql/execution/datasources/ddl.scala | 4 +- .../org/apache/spark/sql/hive/HiveCatalog.scala | 297 -------------------- .../org/apache/spark/sql/hive/HiveContext.scala | 4 +- .../spark/sql/hive/HiveExternalCatalog.scala | 298 +++++++++++++++++++++ .../spark/sql/hive/HiveMetastoreCatalog.scala | 22 +- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 6 +- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 6 +- .../apache/spark/sql/hive/client/HiveClient.scala | 2 +- .../spark/sql/hive/client/HiveClientImpl.scala | 11 +- .../sql/hive/execution/CreateTableAsSelect.scala | 6 +- .../sql/hive/execution/CreateViewAsSelect.scala | 4 +- .../org/apache/spark/sql/hive/test/TestHive.scala | 4 +- .../apache/spark/sql/hive/HiveCatalogSuite.scala | 49 ---- .../spark/sql/hive/HiveExternalCatalogSuite.scala | 49 ++++ .../org/apache/spark/sql/hive/HiveQlSuite.scala | 16 +- .../apache/spark/sql/hive/ListTablesSuite.scala | 2 +- .../spark/sql/hive/MetastoreDataSourcesSuite.scala | 2 +- .../spark/sql/hive/client/VersionsSuite.scala | 2 +- 26 files changed, 469 insertions(+), 463 deletions(-) delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala delete mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index e216fa5528..2bbb970ec9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -155,7 +155,7 @@ class InMemoryCatalog extends ExternalCatalog { tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = synchronized { requireDbExists(db) - val table = tableDefinition.name.table + val table = tableDefinition.identifier.table if (tableExists(db, table)) { if (!ignoreIfExists) { throw new AnalysisException(s"Table '$table' already exists in database '$db'") @@ -182,14 +182,14 @@ class InMemoryCatalog extends ExternalCatalog { override def renameTable(db: String, oldName: String, newName: String): Unit = synchronized { requireTableExists(db, oldName) val oldDesc = catalog(db).tables(oldName) - oldDesc.table = oldDesc.table.copy(name = TableIdentifier(newName, Some(db))) + oldDesc.table = oldDesc.table.copy(identifier = TableIdentifier(newName, Some(db))) catalog(db).tables.put(newName, oldDesc) catalog(db).tables.remove(oldName) } override def alterTable(db: String, tableDefinition: CatalogTable): Unit = synchronized { - requireTableExists(db, tableDefinition.name.table) - catalog(db).tables(tableDefinition.name.table).table = tableDefinition + requireTableExists(db, tableDefinition.identifier.table) + 
catalog(db).tables(tableDefinition.identifier.table).table = tableDefinition } override def getTable(db: String, table: String): CatalogTable = synchronized { @@ -296,10 +296,10 @@ class InMemoryCatalog extends ExternalCatalog { override def createFunction(db: String, func: CatalogFunction): Unit = synchronized { requireDbExists(db) - if (functionExists(db, func.name.funcName)) { + if (functionExists(db, func.identifier.funcName)) { throw new AnalysisException(s"Function '$func' already exists in '$db' database") } else { - catalog(db).functions.put(func.name.funcName, func) + catalog(db).functions.put(func.identifier.funcName, func) } } @@ -310,14 +310,14 @@ class InMemoryCatalog extends ExternalCatalog { override def renameFunction(db: String, oldName: String, newName: String): Unit = synchronized { requireFunctionExists(db, oldName) - val newFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db))) + val newFunc = getFunction(db, oldName).copy(identifier = FunctionIdentifier(newName, Some(db))) catalog(db).functions.remove(oldName) catalog(db).functions.put(newName, newFunc) } override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = synchronized { - requireFunctionExists(db, funcDefinition.name.funcName) - catalog(db).functions.put(funcDefinition.name.funcName, funcDefinition) + requireFunctionExists(db, funcDefinition.identifier.funcName) + catalog(db).functions.put(funcDefinition.identifier.funcName, funcDefinition) } override def getFunction(db: String, funcName: String): CatalogFunction = synchronized { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 34265faa74..a9cf80764d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalyst.catalog -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.JavaConverters._ +import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} @@ -31,6 +29,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} * An internal catalog that is used by a Spark Session. This internal catalog serves as a * proxy to the underlying metastore (e.g. Hive Metastore) and it also manages temporary * tables and functions of the Spark Session that it belongs to. + * + * This class is not thread-safe. */ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { import ExternalCatalog._ @@ -39,8 +39,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { this(externalCatalog, new SimpleCatalystConf(true)) } - protected[this] val tempTables = new ConcurrentHashMap[String, LogicalPlan] - protected[this] val tempFunctions = new ConcurrentHashMap[String, CatalogFunction] + protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan] + protected[this] val tempFunctions = new mutable.HashMap[String, CatalogFunction] // Note: we track current database here because certain operations do not explicitly // specify the database (e.g. DROP TABLE my_table). 
In these cases we must first @@ -122,9 +122,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { * If no such database is specified, create it in the current database. */ def createTable(tableDefinition: CatalogTable, ignoreIfExists: Boolean): Unit = { - val db = tableDefinition.name.database.getOrElse(currentDb) - val table = formatTableName(tableDefinition.name.table) - val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db))) + val db = tableDefinition.identifier.database.getOrElse(currentDb) + val table = formatTableName(tableDefinition.identifier.table) + val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) externalCatalog.createTable(db, newTableDefinition, ignoreIfExists) } @@ -138,9 +138,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { * this becomes a no-op. */ def alterTable(tableDefinition: CatalogTable): Unit = { - val db = tableDefinition.name.database.getOrElse(currentDb) - val table = formatTableName(tableDefinition.name.table) - val newTableDefinition = tableDefinition.copy(name = TableIdentifier(table, Some(db))) + val db = tableDefinition.identifier.database.getOrElse(currentDb) + val table = formatTableName(tableDefinition.identifier.table) + val newTableDefinition = tableDefinition.copy(identifier = TableIdentifier(table, Some(db))) externalCatalog.alterTable(db, newTableDefinition) } @@ -164,9 +164,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { def createTempTable( name: String, tableDefinition: LogicalPlan, - ignoreIfExists: Boolean): Unit = { + overrideIfExists: Boolean): Unit = { val table = formatTableName(name) - if (tempTables.containsKey(table) && !ignoreIfExists) { + if (tempTables.contains(table) && !overrideIfExists) { throw new AnalysisException(s"Temporary table '$name' already exists.") } tempTables.put(table, tableDefinition) @@ -188,10 +188,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { val db = oldName.database.getOrElse(currentDb) val oldTableName = formatTableName(oldName.table) val newTableName = formatTableName(newName.table) - if (oldName.database.isDefined || !tempTables.containsKey(oldTableName)) { + if (oldName.database.isDefined || !tempTables.contains(oldTableName)) { externalCatalog.renameTable(db, oldTableName, newTableName) } else { - val table = tempTables.remove(oldTableName) + val table = tempTables(oldTableName) + tempTables.remove(oldTableName) tempTables.put(newTableName, table) } } @@ -206,7 +207,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { def dropTable(name: TableIdentifier, ignoreIfNotExists: Boolean): Unit = { val db = name.database.getOrElse(currentDb) val table = formatTableName(name.table) - if (name.database.isDefined || !tempTables.containsKey(table)) { + if (name.database.isDefined || !tempTables.contains(table)) { externalCatalog.dropTable(db, table, ignoreIfNotExists) } else { tempTables.remove(table) @@ -224,11 +225,11 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { val db = name.database.getOrElse(currentDb) val table = formatTableName(name.table) val relation = - if (name.database.isDefined || !tempTables.containsKey(table)) { + if (name.database.isDefined || !tempTables.contains(table)) { val metadata = externalCatalog.getTable(db, table) CatalogRelation(db, metadata, alias) } else { - tempTables.get(table) + tempTables(table) } val qualifiedTable = 
SubqueryAlias(table, relation) // If an alias was specified by the lookup, wrap the plan in a subquery so that @@ -247,7 +248,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { def tableExists(name: TableIdentifier): Boolean = { val db = name.database.getOrElse(currentDb) val table = formatTableName(name.table) - if (name.database.isDefined || !tempTables.containsKey(table)) { + if (name.database.isDefined || !tempTables.contains(table)) { externalCatalog.tableExists(db, table) } else { true // it's a temporary table @@ -266,7 +267,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { val dbTables = externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) } val regex = pattern.replaceAll("\\*", ".*").r - val _tempTables = tempTables.keys().asScala + val _tempTables = tempTables.keys.toSeq .filter { t => regex.pattern.matcher(t).matches() } .map { t => TableIdentifier(t) } dbTables ++ _tempTables @@ -290,7 +291,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { * For testing only. */ private[catalog] def getTempTable(name: String): Option[LogicalPlan] = { - Option(tempTables.get(name)) + tempTables.get(name) } // ---------------------------------------------------------------------------- @@ -399,9 +400,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { * If no such database is specified, create it in the current database. */ def createFunction(funcDefinition: CatalogFunction): Unit = { - val db = funcDefinition.name.database.getOrElse(currentDb) + val db = funcDefinition.identifier.database.getOrElse(currentDb) val newFuncDefinition = funcDefinition.copy( - name = FunctionIdentifier(funcDefinition.name.funcName, Some(db))) + identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))) externalCatalog.createFunction(db, newFuncDefinition) } @@ -424,9 +425,9 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { * this becomes a no-op. */ def alterFunction(funcDefinition: CatalogFunction): Unit = { - val db = funcDefinition.name.database.getOrElse(currentDb) + val db = funcDefinition.identifier.database.getOrElse(currentDb) val newFuncDefinition = funcDefinition.copy( - name = FunctionIdentifier(funcDefinition.name.funcName, Some(db))) + identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))) externalCatalog.alterFunction(db, newFuncDefinition) } @@ -439,10 +440,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { * This assumes no database is specified in `funcDefinition`. */ def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { - require(funcDefinition.name.database.isEmpty, + require(funcDefinition.identifier.database.isEmpty, "attempted to create a temporary function while specifying a database") - val name = funcDefinition.name.funcName - if (tempFunctions.containsKey(name) && !ignoreIfExists) { + val name = funcDefinition.identifier.funcName + if (tempFunctions.contains(name) && !ignoreIfExists) { throw new AnalysisException(s"Temporary function '$name' already exists.") } tempFunctions.put(name, funcDefinition) @@ -455,7 +456,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { // Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate // dropFunction and dropTempFunction. 
def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!tempFunctions.containsKey(name) && !ignoreIfNotExists) { + if (!tempFunctions.contains(name) && !ignoreIfNotExists) { throw new AnalysisException( s"Temporary function '$name' cannot be dropped because it does not exist!") } @@ -476,11 +477,12 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { throw new AnalysisException("rename does not support moving functions across databases") } val db = oldName.database.getOrElse(currentDb) - if (oldName.database.isDefined || !tempFunctions.containsKey(oldName.funcName)) { + if (oldName.database.isDefined || !tempFunctions.contains(oldName.funcName)) { externalCatalog.renameFunction(db, oldName.funcName, newName.funcName) } else { - val func = tempFunctions.remove(oldName.funcName) - val newFunc = func.copy(name = func.name.copy(funcName = newName.funcName)) + val func = tempFunctions(oldName.funcName) + val newFunc = func.copy(identifier = func.identifier.copy(funcName = newName.funcName)) + tempFunctions.remove(oldName.funcName) tempFunctions.put(newName.funcName, newFunc) } } @@ -494,10 +496,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { */ def getFunction(name: FunctionIdentifier): CatalogFunction = { val db = name.database.getOrElse(currentDb) - if (name.database.isDefined || !tempFunctions.containsKey(name.funcName)) { + if (name.database.isDefined || !tempFunctions.contains(name.funcName)) { externalCatalog.getFunction(db, name.funcName) } else { - tempFunctions.get(name.funcName) + tempFunctions(name.funcName) } } @@ -510,7 +512,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { val dbFunctions = externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } val regex = pattern.replaceAll("\\*", ".*").r - val _tempFunctions = tempFunctions.keys().asScala + val _tempFunctions = tempFunctions.keys.toSeq .filter { f => regex.pattern.matcher(f).matches() } .map { f => FunctionIdentifier(f) } dbFunctions ++ _tempFunctions @@ -520,7 +522,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { * Return a temporary function. For testing only. */ private[catalog] def getTempFunction(name: String): Option[CatalogFunction] = { - Option(tempFunctions.get(name)) + tempFunctions.get(name) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 34803133f6..8bb8e09a28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -169,10 +169,10 @@ abstract class ExternalCatalog { /** * A function defined in the catalog. * - * @param name name of the function + * @param identifier name of the function * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc" */ -case class CatalogFunction(name: FunctionIdentifier, className: String) +case class CatalogFunction(identifier: FunctionIdentifier, className: String) /** @@ -216,7 +216,7 @@ case class CatalogTablePartition( * future once we have a better understanding of how we want to handle skewed columns. 
*/ case class CatalogTable( - name: TableIdentifier, + identifier: TableIdentifier, tableType: CatalogTableType, storage: CatalogStorageFormat, schema: Seq[CatalogColumn], @@ -230,12 +230,12 @@ case class CatalogTable( viewText: Option[String] = None) { /** Return the database this table was specified to belong to, assuming it exists. */ - def database: String = name.database.getOrElse { - throw new AnalysisException(s"table $name did not specify database") + def database: String = identifier.database.getOrElse { + throw new AnalysisException(s"table $identifier did not specify database") } /** Return the fully qualified name of this table, assuming the database was specified. */ - def qualifiedName: String = name.unquotedString + def qualifiedName: String = identifier.unquotedString /** Syntactic sugar to update a field in `storage`. */ def withNewStorage( @@ -290,6 +290,6 @@ case class CatalogRelation( // TODO: implement this override def output: Seq[Attribute] = Seq.empty - require(metadata.name.database == Some(db), + require(metadata.identifier.database == Some(db), "provided database does not much the one specified in the table definition") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 6fa4beed99..34cb97699d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -31,7 +31,7 @@ trait AnalysisTest extends PlanTest { private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { val conf = new SimpleCatalystConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, conf) - catalog.createTempTable("TaBlE", TestRelations.testRelation, ignoreIfExists = true) + catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true) new Analyzer(catalog, EmptyFunctionRegistry, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 31501864a8..6c08ccc34c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -52,7 +52,7 @@ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { private val b: Expression = UnresolvedAttribute("b") before { - catalog.createTempTable("table", relation, ignoreIfExists = true) + catalog.createTempTable("table", relation, overrideIfExists = true) } private def checkType(expression: Expression, expectedType: DataType): Unit = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 277c2d717e..959bd564d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -210,7 +210,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { } test("get table") { - assert(newBasicCatalog().getTable("db2", "tbl1").name.table == "tbl1") + 
assert(newBasicCatalog().getTable("db2", "tbl1").identifier.table == "tbl1") } test("get table when database/table does not exist") { @@ -452,7 +452,7 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { assert(catalog.getFunction("db2", "func1").className == funcClass) catalog.renameFunction("db2", "func1", newName) intercept[AnalysisException] { catalog.getFunction("db2", "func1") } - assert(catalog.getFunction("db2", newName).name.funcName == newName) + assert(catalog.getFunction("db2", newName).identifier.funcName == newName) assert(catalog.getFunction("db2", newName).className == funcClass) intercept[AnalysisException] { catalog.renameFunction("db2", "does_not_exist", "me") } } @@ -549,7 +549,7 @@ abstract class CatalogTestUtils { def newTable(name: String, database: Option[String] = None): CatalogTable = { CatalogTable( - name = TableIdentifier(name, database), + identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL_TABLE, storage = storageFormat, schema = Seq(CatalogColumn("col1", "int"), CatalogColumn("col2", "string")), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 74e995cc5b..2948c5f8bd 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -197,17 +197,17 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) val tempTable1 = Range(1, 10, 1, 10, Seq()) val tempTable2 = Range(1, 20, 2, 10, Seq()) - catalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false) - catalog.createTempTable("tbl2", tempTable2, ignoreIfExists = false) + catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false) + catalog.createTempTable("tbl2", tempTable2, overrideIfExists = false) assert(catalog.getTempTable("tbl1") == Some(tempTable1)) assert(catalog.getTempTable("tbl2") == Some(tempTable2)) assert(catalog.getTempTable("tbl3") == None) // Temporary table already exists intercept[AnalysisException] { - catalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false) + catalog.createTempTable("tbl1", tempTable1, overrideIfExists = false) } // Temporary table already exists but we override it - catalog.createTempTable("tbl1", tempTable2, ignoreIfExists = true) + catalog.createTempTable("tbl1", tempTable2, overrideIfExists = true) assert(catalog.getTempTable("tbl1") == Some(tempTable2)) } @@ -243,7 +243,7 @@ class SessionCatalogSuite extends SparkFunSuite { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) val tempTable = Range(1, 10, 2, 10, Seq()) - sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false) + sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) @@ -255,7 +255,7 @@ class SessionCatalogSuite extends SparkFunSuite { sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false) assert(externalCatalog.listTables("db2").toSet == Set("tbl2")) // If database is specified, temp tables are never dropped - sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false) + 
sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false) sessionCatalog.createTable(newTable("tbl1", "db2"), ignoreIfExists = false) sessionCatalog.dropTable(TableIdentifier("tbl1", Some("db2")), ignoreIfNotExists = false) assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) @@ -299,7 +299,7 @@ class SessionCatalogSuite extends SparkFunSuite { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) val tempTable = Range(1, 10, 2, 10, Seq()) - sessionCatalog.createTempTable("tbl1", tempTable, ignoreIfExists = false) + sessionCatalog.createTempTable("tbl1", tempTable, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") assert(sessionCatalog.getTempTable("tbl1") == Some(tempTable)) assert(externalCatalog.listTables("db2").toSet == Set("tbl1", "tbl2")) @@ -327,7 +327,7 @@ class SessionCatalogSuite extends SparkFunSuite { assert(newTbl1.properties.get("toh") == Some("frem")) // Alter table without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.alterTable(tbl1.copy(name = TableIdentifier("tbl1"))) + sessionCatalog.alterTable(tbl1.copy(identifier = TableIdentifier("tbl1"))) val newestTbl1 = externalCatalog.getTable("db2", "tbl1") assert(newestTbl1 == tbl1) } @@ -368,7 +368,7 @@ class SessionCatalogSuite extends SparkFunSuite { val sessionCatalog = new SessionCatalog(externalCatalog) val tempTable1 = Range(1, 10, 1, 10, Seq()) val metastoreTable1 = externalCatalog.getTable("db2", "tbl1") - sessionCatalog.createTempTable("tbl1", tempTable1, ignoreIfExists = false) + sessionCatalog.createTempTable("tbl1", tempTable1, overrideIfExists = false) sessionCatalog.setCurrentDatabase("db2") // If we explicitly specify the database, we'll look up the relation in that database assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1", Some("db2"))) @@ -406,7 +406,7 @@ class SessionCatalogSuite extends SparkFunSuite { assert(!catalog.tableExists(TableIdentifier("tbl2", Some("db1")))) // If database is explicitly specified, do not check temporary tables val tempTable = Range(1, 10, 1, 10, Seq()) - catalog.createTempTable("tbl3", tempTable, ignoreIfExists = false) + catalog.createTempTable("tbl3", tempTable, overrideIfExists = false) assert(!catalog.tableExists(TableIdentifier("tbl3", Some("db2")))) // If database is not explicitly specified, check the current database catalog.setCurrentDatabase("db2") @@ -418,8 +418,8 @@ class SessionCatalogSuite extends SparkFunSuite { test("list tables without pattern") { val catalog = new SessionCatalog(newBasicCatalog()) val tempTable = Range(1, 10, 2, 10, Seq()) - catalog.createTempTable("tbl1", tempTable, ignoreIfExists = false) - catalog.createTempTable("tbl4", tempTable, ignoreIfExists = false) + catalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + catalog.createTempTable("tbl4", tempTable, overrideIfExists = false) assert(catalog.listTables("db1").toSet == Set(TableIdentifier("tbl1"), TableIdentifier("tbl4"))) assert(catalog.listTables("db2").toSet == @@ -435,8 +435,8 @@ class SessionCatalogSuite extends SparkFunSuite { test("list tables with pattern") { val catalog = new SessionCatalog(newBasicCatalog()) val tempTable = Range(1, 10, 2, 10, Seq()) - catalog.createTempTable("tbl1", tempTable, ignoreIfExists = false) - catalog.createTempTable("tbl4", tempTable, ignoreIfExists = false) + catalog.createTempTable("tbl1", tempTable, overrideIfExists = false) + catalog.createTempTable("tbl4", tempTable, overrideIfExists = false) 
assert(catalog.listTables("db1", "*").toSet == catalog.listTables("db1").toSet) assert(catalog.listTables("db2", "*").toSet == catalog.listTables("db2").toSet) assert(catalog.listTables("db2", "tbl*").toSet == @@ -826,7 +826,7 @@ class SessionCatalogSuite extends SparkFunSuite { sessionCatalog.createFunction(newFunc("func1", Some("db2"))) sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4")) assert(sessionCatalog.getTempFunction("func4") == - Some(tempFunc.copy(name = FunctionIdentifier("func4")))) + Some(tempFunc.copy(identifier = FunctionIdentifier("func4")))) assert(sessionCatalog.getTempFunction("func1") == None) assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3")) // Then, if no such temporary function exist, rename the function in the current database diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index e413e77bc1..c94600925f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -671,7 +671,7 @@ class SQLContext private[sql]( sessionState.catalog.createTempTable( sessionState.sqlParser.parseTableIdentifier(tableName).table, df.logicalPlan, - ignoreIfExists = true) + overrideIfExists = true) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 24923bbb10..877e159fbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -107,7 +107,7 @@ case class CreateTempTableUsing( sqlContext.sessionState.catalog.createTempTable( tableIdent.table, Dataset.ofRows(sqlContext, LogicalRelation(dataSource.resolveRelation())).logicalPlan, - ignoreIfExists = true) + overrideIfExists = true) Seq.empty[Row] } @@ -138,7 +138,7 @@ case class CreateTempTableUsingAsSelect( sqlContext.sessionState.catalog.createTempTable( tableIdent.table, Dataset.ofRows(sqlContext, LogicalRelation(result)).logicalPlan, - ignoreIfExists = true) + overrideIfExists = true) Seq.empty[Row] } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala deleted file mode 100644 index 0722fb02a8..0000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveCatalog.scala +++ /dev/null @@ -1,297 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.hive - -import scala.util.control.NonFatal - -import org.apache.hadoop.hive.ql.metadata.HiveException -import org.apache.thrift.TException - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchItemException -import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.hive.client.HiveClient - - -/** - * A persistent implementation of the system catalog using Hive. - * All public methods must be synchronized for thread-safety. - */ -private[spark] class HiveCatalog(client: HiveClient) extends ExternalCatalog with Logging { - import ExternalCatalog._ - - // Exceptions thrown by the hive client that we would like to wrap - private val clientExceptions = Set( - classOf[HiveException].getCanonicalName, - classOf[TException].getCanonicalName) - - /** - * Whether this is an exception thrown by the hive client that should be wrapped. - * - * Due to classloader isolation issues, pattern matching won't work here so we need - * to compare the canonical names of the exceptions, which we assume to be stable. - */ - private def isClientException(e: Throwable): Boolean = { - var temp: Class[_] = e.getClass - var found = false - while (temp != null && !found) { - found = clientExceptions.contains(temp.getCanonicalName) - temp = temp.getSuperclass - } - found - } - - /** - * Run some code involving `client` in a [[synchronized]] block and wrap certain - * exceptions thrown in the process in [[AnalysisException]]. - */ - private def withClient[T](body: => T): T = synchronized { - try { - body - } catch { - case e: NoSuchItemException => - throw new AnalysisException(e.getMessage) - case NonFatal(e) if isClientException(e) => - throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage) - } - } - - private def requireDbMatches(db: String, table: CatalogTable): Unit = { - if (table.name.database != Some(db)) { - throw new AnalysisException( - s"Provided database $db does not much the one specified in the " + - s"table definition (${table.name.database.getOrElse("n/a")})") - } - } - - private def requireTableExists(db: String, table: String): Unit = { - withClient { getTable(db, table) } - } - - // -------------------------------------------------------------------------- - // Databases - // -------------------------------------------------------------------------- - - override def createDatabase( - dbDefinition: CatalogDatabase, - ignoreIfExists: Boolean): Unit = withClient { - client.createDatabase(dbDefinition, ignoreIfExists) - } - - override def dropDatabase( - db: String, - ignoreIfNotExists: Boolean, - cascade: Boolean): Unit = withClient { - client.dropDatabase(db, ignoreIfNotExists, cascade) - } - - /** - * Alter a database whose name matches the one specified in `dbDefinition`, - * assuming the database exists. - * - * Note: As of now, this only supports altering database properties! - */ - override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { - val existingDb = getDatabase(dbDefinition.name) - if (existingDb.properties == dbDefinition.properties) { - logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + - s"the provided database properties are the same as the old ones. 
Hive does not " + - s"currently support altering other database fields.") - } - client.alterDatabase(dbDefinition) - } - - override def getDatabase(db: String): CatalogDatabase = withClient { - client.getDatabase(db) - } - - override def databaseExists(db: String): Boolean = withClient { - client.getDatabaseOption(db).isDefined - } - - override def listDatabases(): Seq[String] = withClient { - client.listDatabases("*") - } - - override def listDatabases(pattern: String): Seq[String] = withClient { - client.listDatabases(pattern) - } - - override def setCurrentDatabase(db: String): Unit = withClient { - client.setCurrentDatabase(db) - } - - // -------------------------------------------------------------------------- - // Tables - // -------------------------------------------------------------------------- - - override def createTable( - db: String, - tableDefinition: CatalogTable, - ignoreIfExists: Boolean): Unit = withClient { - requireDbExists(db) - requireDbMatches(db, tableDefinition) - client.createTable(tableDefinition, ignoreIfExists) - } - - override def dropTable( - db: String, - table: String, - ignoreIfNotExists: Boolean): Unit = withClient { - requireDbExists(db) - client.dropTable(db, table, ignoreIfNotExists) - } - - override def renameTable(db: String, oldName: String, newName: String): Unit = withClient { - val newTable = client.getTable(db, oldName).copy(name = TableIdentifier(newName, Some(db))) - client.alterTable(oldName, newTable) - } - - /** - * Alter a table whose name that matches the one specified in `tableDefinition`, - * assuming the table exists. - * - * Note: As of now, this only supports altering table properties, serde properties, - * and num buckets! - */ - override def alterTable(db: String, tableDefinition: CatalogTable): Unit = withClient { - requireDbMatches(db, tableDefinition) - requireTableExists(db, tableDefinition.name.table) - client.alterTable(tableDefinition) - } - - override def getTable(db: String, table: String): CatalogTable = withClient { - client.getTable(db, table) - } - - override def tableExists(db: String, table: String): Boolean = withClient { - client.getTableOption(db, table).isDefined - } - - override def listTables(db: String): Seq[String] = withClient { - requireDbExists(db) - client.listTables(db) - } - - override def listTables(db: String, pattern: String): Seq[String] = withClient { - requireDbExists(db) - client.listTables(db, pattern) - } - - // -------------------------------------------------------------------------- - // Partitions - // -------------------------------------------------------------------------- - - override def createPartitions( - db: String, - table: String, - parts: Seq[CatalogTablePartition], - ignoreIfExists: Boolean): Unit = withClient { - requireTableExists(db, table) - client.createPartitions(db, table, parts, ignoreIfExists) - } - - override def dropPartitions( - db: String, - table: String, - parts: Seq[TablePartitionSpec], - ignoreIfNotExists: Boolean): Unit = withClient { - requireTableExists(db, table) - // Note: Unfortunately Hive does not currently support `ignoreIfNotExists` so we - // need to implement it here ourselves. This is currently somewhat expensive because - // we make multiple synchronous calls to Hive for each partition we want to drop. 
- val partsToDrop = - if (ignoreIfNotExists) { - parts.filter { spec => - try { - getPartition(db, table, spec) - true - } catch { - // Filter out the partitions that do not actually exist - case _: AnalysisException => false - } - } - } else { - parts - } - if (partsToDrop.nonEmpty) { - client.dropPartitions(db, table, partsToDrop) - } - } - - override def renamePartitions( - db: String, - table: String, - specs: Seq[TablePartitionSpec], - newSpecs: Seq[TablePartitionSpec]): Unit = withClient { - client.renamePartitions(db, table, specs, newSpecs) - } - - override def alterPartitions( - db: String, - table: String, - newParts: Seq[CatalogTablePartition]): Unit = withClient { - client.alterPartitions(db, table, newParts) - } - - override def getPartition( - db: String, - table: String, - spec: TablePartitionSpec): CatalogTablePartition = withClient { - client.getPartition(db, table, spec) - } - - override def listPartitions( - db: String, - table: String): Seq[CatalogTablePartition] = withClient { - client.getAllPartitions(db, table) - } - - // -------------------------------------------------------------------------- - // Functions - // -------------------------------------------------------------------------- - - override def createFunction( - db: String, - funcDefinition: CatalogFunction): Unit = withClient { - client.createFunction(db, funcDefinition) - } - - override def dropFunction(db: String, name: String): Unit = withClient { - client.dropFunction(db, name) - } - - override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { - client.renameFunction(db, oldName, newName) - } - - override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = withClient { - client.alterFunction(db, funcDefinition) - } - - override def getFunction(db: String, funcName: String): CatalogFunction = withClient { - client.getFunction(db, funcName) - } - - override def listFunctions(db: String, pattern: String): Seq[String] = withClient { - client.listFunctions(db, pattern) - } - -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index ca3ce43591..c0b6d16d3c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ -86,7 +86,7 @@ class HiveContext private[hive]( @transient private[hive] val executionHive: HiveClientImpl, @transient private[hive] val metadataHive: HiveClient, isRootContext: Boolean, - @transient private[sql] val hiveCatalog: HiveCatalog) + @transient private[sql] val hiveCatalog: HiveExternalCatalog) extends SQLContext(sc, cacheManager, listener, isRootContext, hiveCatalog) with Logging { self => @@ -98,7 +98,7 @@ class HiveContext private[hive]( execHive, metaHive, true, - new HiveCatalog(metaHive)) + new HiveExternalCatalog(metaHive)) } def this(sc: SparkContext) = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala new file mode 100644 index 0000000000..f75509fe80 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -0,0 +1,298 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import scala.util.control.NonFatal + +import org.apache.hadoop.hive.ql.metadata.HiveException +import org.apache.thrift.TException + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.NoSuchItemException +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.hive.client.HiveClient + + +/** + * A persistent implementation of the system catalog using Hive. + * All public methods must be synchronized for thread-safety. + */ +private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCatalog with Logging { + import ExternalCatalog._ + + // Exceptions thrown by the hive client that we would like to wrap + private val clientExceptions = Set( + classOf[HiveException].getCanonicalName, + classOf[TException].getCanonicalName) + + /** + * Whether this is an exception thrown by the hive client that should be wrapped. + * + * Due to classloader isolation issues, pattern matching won't work here so we need + * to compare the canonical names of the exceptions, which we assume to be stable. + */ + private def isClientException(e: Throwable): Boolean = { + var temp: Class[_] = e.getClass + var found = false + while (temp != null && !found) { + found = clientExceptions.contains(temp.getCanonicalName) + temp = temp.getSuperclass + } + found + } + + /** + * Run some code involving `client` in a [[synchronized]] block and wrap certain + * exceptions thrown in the process in [[AnalysisException]]. 
+ */ + private def withClient[T](body: => T): T = synchronized { + try { + body + } catch { + case e: NoSuchItemException => + throw new AnalysisException(e.getMessage) + case NonFatal(e) if isClientException(e) => + throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage) + } + } + + private def requireDbMatches(db: String, table: CatalogTable): Unit = { + if (table.identifier.database != Some(db)) { + throw new AnalysisException( + s"Provided database $db does not much the one specified in the " + + s"table definition (${table.identifier.database.getOrElse("n/a")})") + } + } + + private def requireTableExists(db: String, table: String): Unit = { + withClient { getTable(db, table) } + } + + // -------------------------------------------------------------------------- + // Databases + // -------------------------------------------------------------------------- + + override def createDatabase( + dbDefinition: CatalogDatabase, + ignoreIfExists: Boolean): Unit = withClient { + client.createDatabase(dbDefinition, ignoreIfExists) + } + + override def dropDatabase( + db: String, + ignoreIfNotExists: Boolean, + cascade: Boolean): Unit = withClient { + client.dropDatabase(db, ignoreIfNotExists, cascade) + } + + /** + * Alter a database whose name matches the one specified in `dbDefinition`, + * assuming the database exists. + * + * Note: As of now, this only supports altering database properties! + */ + override def alterDatabase(dbDefinition: CatalogDatabase): Unit = withClient { + val existingDb = getDatabase(dbDefinition.name) + if (existingDb.properties == dbDefinition.properties) { + logWarning(s"Request to alter database ${dbDefinition.name} is a no-op because " + + s"the provided database properties are the same as the old ones. Hive does not " + + s"currently support altering other database fields.") + } + client.alterDatabase(dbDefinition) + } + + override def getDatabase(db: String): CatalogDatabase = withClient { + client.getDatabase(db) + } + + override def databaseExists(db: String): Boolean = withClient { + client.getDatabaseOption(db).isDefined + } + + override def listDatabases(): Seq[String] = withClient { + client.listDatabases("*") + } + + override def listDatabases(pattern: String): Seq[String] = withClient { + client.listDatabases(pattern) + } + + override def setCurrentDatabase(db: String): Unit = withClient { + client.setCurrentDatabase(db) + } + + // -------------------------------------------------------------------------- + // Tables + // -------------------------------------------------------------------------- + + override def createTable( + db: String, + tableDefinition: CatalogTable, + ignoreIfExists: Boolean): Unit = withClient { + requireDbExists(db) + requireDbMatches(db, tableDefinition) + client.createTable(tableDefinition, ignoreIfExists) + } + + override def dropTable( + db: String, + table: String, + ignoreIfNotExists: Boolean): Unit = withClient { + requireDbExists(db) + client.dropTable(db, table, ignoreIfNotExists) + } + + override def renameTable(db: String, oldName: String, newName: String): Unit = withClient { + val newTable = client.getTable(db, oldName) + .copy(identifier = TableIdentifier(newName, Some(db))) + client.alterTable(oldName, newTable) + } + + /** + * Alter a table whose name that matches the one specified in `tableDefinition`, + * assuming the table exists. + * + * Note: As of now, this only supports altering table properties, serde properties, + * and num buckets! 
+ */ + override def alterTable(db: String, tableDefinition: CatalogTable): Unit = withClient { + requireDbMatches(db, tableDefinition) + requireTableExists(db, tableDefinition.identifier.table) + client.alterTable(tableDefinition) + } + + override def getTable(db: String, table: String): CatalogTable = withClient { + client.getTable(db, table) + } + + override def tableExists(db: String, table: String): Boolean = withClient { + client.getTableOption(db, table).isDefined + } + + override def listTables(db: String): Seq[String] = withClient { + requireDbExists(db) + client.listTables(db) + } + + override def listTables(db: String, pattern: String): Seq[String] = withClient { + requireDbExists(db) + client.listTables(db, pattern) + } + + // -------------------------------------------------------------------------- + // Partitions + // -------------------------------------------------------------------------- + + override def createPartitions( + db: String, + table: String, + parts: Seq[CatalogTablePartition], + ignoreIfExists: Boolean): Unit = withClient { + requireTableExists(db, table) + client.createPartitions(db, table, parts, ignoreIfExists) + } + + override def dropPartitions( + db: String, + table: String, + parts: Seq[TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit = withClient { + requireTableExists(db, table) + // Note: Unfortunately Hive does not currently support `ignoreIfNotExists` so we + // need to implement it here ourselves. This is currently somewhat expensive because + // we make multiple synchronous calls to Hive for each partition we want to drop. + val partsToDrop = + if (ignoreIfNotExists) { + parts.filter { spec => + try { + getPartition(db, table, spec) + true + } catch { + // Filter out the partitions that do not actually exist + case _: AnalysisException => false + } + } + } else { + parts + } + if (partsToDrop.nonEmpty) { + client.dropPartitions(db, table, partsToDrop) + } + } + + override def renamePartitions( + db: String, + table: String, + specs: Seq[TablePartitionSpec], + newSpecs: Seq[TablePartitionSpec]): Unit = withClient { + client.renamePartitions(db, table, specs, newSpecs) + } + + override def alterPartitions( + db: String, + table: String, + newParts: Seq[CatalogTablePartition]): Unit = withClient { + client.alterPartitions(db, table, newParts) + } + + override def getPartition( + db: String, + table: String, + spec: TablePartitionSpec): CatalogTablePartition = withClient { + client.getPartition(db, table, spec) + } + + override def listPartitions( + db: String, + table: String): Seq[CatalogTablePartition] = withClient { + client.getAllPartitions(db, table) + } + + // -------------------------------------------------------------------------- + // Functions + // -------------------------------------------------------------------------- + + override def createFunction( + db: String, + funcDefinition: CatalogFunction): Unit = withClient { + client.createFunction(db, funcDefinition) + } + + override def dropFunction(db: String, name: String): Unit = withClient { + client.dropFunction(db, name) + } + + override def renameFunction(db: String, oldName: String, newName: String): Unit = withClient { + client.renameFunction(db, oldName, newName) + } + + override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = withClient { + client.alterFunction(db, funcDefinition) + } + + override def getFunction(db: String, funcName: String): CatalogFunction = withClient { + client.getFunction(db, funcName) + } + + override def listFunctions(db: 
String, pattern: String): Seq[String] = withClient { + client.listFunctions(db, pattern) + } + +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index c7066d7363..eedd12d76a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -102,7 +102,7 @@ private[hive] object HiveSerDe { * Legacy catalog for interacting with the Hive metastore. * * This is still used for things like creating data source tables, but in the future will be - * cleaned up to integrate more nicely with [[HiveCatalog]]. + * cleaned up to integrate more nicely with [[HiveExternalCatalog]]. */ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveContext) extends Logging { @@ -124,8 +124,8 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte private def getQualifiedTableName(t: CatalogTable): QualifiedTableName = { QualifiedTableName( - t.name.database.getOrElse(getCurrentDatabase).toLowerCase, - t.name.table.toLowerCase) + t.identifier.database.getOrElse(getCurrentDatabase).toLowerCase, + t.identifier.table.toLowerCase) } /** A cache of Spark SQL data source tables that have been accessed. */ @@ -299,7 +299,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte def newSparkSQLSpecificMetastoreTable(): CatalogTable = { CatalogTable( - name = TableIdentifier(tblName, Option(dbName)), + identifier = TableIdentifier(tblName, Option(dbName)), tableType = tableType, schema = Nil, storage = CatalogStorageFormat( @@ -319,7 +319,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte assert(relation.partitionSchema.isEmpty) CatalogTable( - name = TableIdentifier(tblName, Option(dbName)), + identifier = TableIdentifier(tblName, Option(dbName)), tableType = tableType, storage = CatalogStorageFormat( locationUri = Some(relation.location.paths.map(_.toUri.toString).head), @@ -431,7 +431,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte alias match { // because hive use things like `_c0` to build the expanded text // currently we cannot support view from "create view v1(c1) as ..." - case None => SubqueryAlias(table.name.table, hive.parseSql(viewText)) + case None => SubqueryAlias(table.identifier.table, hive.parseSql(viewText)) case Some(aliasText) => SubqueryAlias(aliasText, hive.parseSql(viewText)) } } else { @@ -611,7 +611,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) execution.CreateViewAsSelect( - table.copy(name = TableIdentifier(tblName, Some(dbName))), + table.copy(identifier = TableIdentifier(tblName, Some(dbName))), child, allowExisting, replace) @@ -633,7 +633,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte if (hive.convertCTAS && table.storage.serde.isEmpty) { // Do the conversion when spark.sql.hive.convertCTAS is true and the query // does not specify any storage format (file format and storage handler). 
- if (table.name.database.isDefined) { + if (table.identifier.database.isDefined) { throw new AnalysisException( "Cannot specify database name in a CTAS statement " + "when spark.sql.hive.convertCTAS is set to true.") @@ -641,7 +641,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val mode = if (allowExisting) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTableUsingAsSelect( - TableIdentifier(desc.name.table), + TableIdentifier(desc.identifier.table), conf.defaultDataSourceName, temporary = false, Array.empty[String], @@ -662,7 +662,7 @@ private[hive] class HiveMetastoreCatalog(val client: HiveClient, hive: HiveConte val QualifiedTableName(dbName, tblName) = getQualifiedTableName(table) execution.CreateTableAsSelect( - desc.copy(name = TableIdentifier(tblName, Some(dbName))), + desc.copy(identifier = TableIdentifier(tblName, Some(dbName))), child, allowExisting) } @@ -792,7 +792,7 @@ private[hive] case class MetastoreRelation( // We start by constructing an API table as Hive performs several important transformations // internally when converting an API table to a QL table. val tTable = new org.apache.hadoop.hive.metastore.api.Table() - tTable.setTableName(table.name.table) + tTable.setTableName(table.identifier.table) tTable.setDbName(table.database) val tableParameters = new java.util.HashMap[String, String]() diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index e5bcb9b1db..b3ec95fc73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -60,7 +60,7 @@ private[hive] case class CreateTableAsSelect( override def output: Seq[Attribute] = Seq.empty[Attribute] override lazy val resolved: Boolean = - tableDesc.name.database.isDefined && + tableDesc.identifier.database.isDefined && tableDesc.schema.nonEmpty && tableDesc.storage.serde.isDefined && tableDesc.storage.inputFormat.isDefined && @@ -183,7 +183,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging val tableIdentifier = extractTableIdent(viewNameParts) val originalText = query.source val tableDesc = CatalogTable( - name = tableIdentifier, + identifier = tableIdentifier, tableType = CatalogTableType.VIRTUAL_VIEW, schema = schema, storage = CatalogStorageFormat( @@ -352,7 +352,7 @@ private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging // TODO add bucket support var tableDesc: CatalogTable = CatalogTable( - name = tableIdentifier, + identifier = tableIdentifier, tableType = if (externalTable.isDefined) { CatalogTableType.EXTERNAL_TABLE diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index aa44cba4b5..ec7bf61be1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.types.StructType class HiveSessionCatalog( - externalCatalog: HiveCatalog, + externalCatalog: HiveExternalCatalog, client: HiveClient, context: HiveContext, conf: SQLConf) @@ -41,11 +41,11 @@ class HiveSessionCatalog( override def lookupRelation(name: TableIdentifier, alias: Option[String]): LogicalPlan = { val table = formatTableName(name.table) - if (name.database.isDefined || !tempTables.containsKey(table)) { + if 
(name.database.isDefined || !tempTables.contains(table)) { val newName = name.copy(table = table) metastoreCatalog.lookupRelation(newName, alias) } else { - val relation = tempTables.get(table) + val relation = tempTables(table) val tableWithQualifiers = SubqueryAlias(table, relation) // If an alias was specified by the lookup, wrap the plan in a subquery so that // attributes are properly qualified with this alias. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index f4d30358ca..ee56f9d75d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -88,7 +88,7 @@ private[hive] trait HiveClient { def dropTable(dbName: String, tableName: String, ignoreIfNotExists: Boolean): Unit /** Alter a table whose name matches the one specified in `table`, assuming it exists. */ - final def alterTable(table: CatalogTable): Unit = alterTable(table.name.table, table) + final def alterTable(table: CatalogTable): Unit = alterTable(table.identifier.table, table) /** Updates the given table with new metadata, optionally renaming the table. */ def alterTable(tableName: String, table: CatalogTable): Unit diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index e4e15d13df..a31178e347 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -298,7 +298,7 @@ private[hive] class HiveClientImpl( logDebug(s"Looking up $dbName.$tableName") Option(client.getTable(dbName, tableName, false)).map { h => CatalogTable( - name = TableIdentifier(h.getTableName, Option(h.getDbName)), + identifier = TableIdentifier(h.getTableName, Option(h.getDbName)), tableType = h.getTableType match { case HiveTableType.EXTERNAL_TABLE => CatalogTableType.EXTERNAL_TABLE case HiveTableType.MANAGED_TABLE => CatalogTableType.MANAGED_TABLE @@ -544,13 +544,14 @@ private[hive] class HiveClientImpl( } override def renameFunction(db: String, oldName: String, newName: String): Unit = withHiveState { - val catalogFunc = getFunction(db, oldName).copy(name = FunctionIdentifier(newName, Some(db))) + val catalogFunc = getFunction(db, oldName) + .copy(identifier = FunctionIdentifier(newName, Some(db))) val hiveFunc = toHiveFunction(catalogFunc, db) client.alterFunction(db, oldName, hiveFunc) } override def alterFunction(db: String, func: CatalogFunction): Unit = withHiveState { - client.alterFunction(db, func.name.funcName, toHiveFunction(func, db)) + client.alterFunction(db, func.identifier.funcName, toHiveFunction(func, db)) } override def getFunctionOption( @@ -611,7 +612,7 @@ private[hive] class HiveClientImpl( private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = { new HiveFunction( - f.name.funcName, + f.identifier.funcName, db, f.className, null, @@ -639,7 +640,7 @@ private[hive] class HiveClientImpl( } private def toHiveTable(table: CatalogTable): HiveTable = { - val hiveTable = new HiveTable(table.database, table.name.table) + val hiveTable = new HiveTable(table.database, table.identifier.table) hiveTable.setTableType(table.tableType match { case CatalogTableType.EXTERNAL_TABLE => HiveTableType.EXTERNAL_TABLE case CatalogTableType.MANAGED_TABLE => 
HiveTableType.MANAGED_TABLE diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala index 5a61eef0f2..29f7dc2997 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateTableAsSelect.scala @@ -38,7 +38,7 @@ case class CreateTableAsSelect( allowExisting: Boolean) extends RunnableCommand { - private val tableIdentifier = tableDesc.name + private val tableIdentifier = tableDesc.identifier override def children: Seq[LogicalPlan] = Seq(query) @@ -93,6 +93,8 @@ case class CreateTableAsSelect( } override def argString: String = { - s"[Database:${tableDesc.database}}, TableName: ${tableDesc.name.table}, InsertIntoHiveTable]" + s"[Database:${tableDesc.database}}, " + + s"TableName: ${tableDesc.identifier.table}, " + + s"InsertIntoHiveTable]" } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala index 9ff520da1d..33cd8b4480 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/CreateViewAsSelect.scala @@ -44,7 +44,7 @@ private[hive] case class CreateViewAsSelect( assert(tableDesc.schema == Nil || tableDesc.schema.length == childSchema.length) assert(tableDesc.viewText.isDefined) - private val tableIdentifier = tableDesc.name + private val tableIdentifier = tableDesc.identifier override def run(sqlContext: SQLContext): Seq[Row] = { val hiveContext = sqlContext.asInstanceOf[HiveContext] @@ -116,7 +116,7 @@ private[hive] case class CreateViewAsSelect( } val viewText = tableDesc.viewText.get - val viewName = quote(tableDesc.name.table) + val viewName = quote(tableDesc.identifier.table) s"SELECT $viewOutput FROM ($viewText) $viewName" } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index a1785ca038..4afc8d18a6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -78,7 +78,7 @@ class TestHiveContext private[hive]( executionHive: HiveClientImpl, metadataHive: HiveClient, isRootContext: Boolean, - hiveCatalog: HiveCatalog, + hiveCatalog: HiveExternalCatalog, val warehousePath: File, val scratchDirPath: File, metastoreTemporaryConf: Map[String, String]) @@ -110,7 +110,7 @@ class TestHiveContext private[hive]( executionHive, metadataHive, true, - new HiveCatalog(metadataHive), + new HiveExternalCatalog(metadataHive), warehousePath, scratchDirPath, metastoreTemporaryConf) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala deleted file mode 100644 index 427f5747a0..0000000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveCatalogSuite.scala +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.util.VersionInfo - -import org.apache.spark.SparkConf -import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.hive.client.{HiveClient, IsolatedClientLoader} -import org.apache.spark.util.Utils - -/** - * Test suite for the [[HiveCatalog]]. - */ -class HiveCatalogSuite extends CatalogTestCases { - - private val client: HiveClient = { - IsolatedClientLoader.forVersion( - hiveMetastoreVersion = HiveContext.hiveExecutionVersion, - hadoopVersion = VersionInfo.getVersion, - sparkConf = new SparkConf(), - hadoopConf = new Configuration()).createClient() - } - - protected override val utils: CatalogTestUtils = new CatalogTestUtils { - override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat" - override val tableOutputFormat: String = "org.apache.hadoop.mapred.SequenceFileOutputFormat" - override def newEmptyCatalog(): ExternalCatalog = new HiveCatalog(client) - } - - protected override def resetState(): Unit = client.reset() - -} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala new file mode 100644 index 0000000000..3334c16f0b --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.util.VersionInfo + +import org.apache.spark.SparkConf +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.hive.client.{HiveClient, IsolatedClientLoader} +import org.apache.spark.util.Utils + +/** + * Test suite for the [[HiveExternalCatalog]]. 
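+ * It runs the shared [[CatalogTestCases]] against a catalog backed by a real Hive client
+ * created through [[IsolatedClientLoader]].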
+ */ +class HiveExternalCatalogSuite extends CatalogTestCases { + + private val client: HiveClient = { + IsolatedClientLoader.forVersion( + hiveMetastoreVersion = HiveContext.hiveExecutionVersion, + hadoopVersion = VersionInfo.getVersion, + sparkConf = new SparkConf(), + hadoopConf = new Configuration()).createClient() + } + + protected override val utils: CatalogTestUtils = new CatalogTestUtils { + override val tableInputFormat: String = "org.apache.hadoop.mapred.SequenceFileInputFormat" + override val tableOutputFormat: String = "org.apache.hadoop.mapred.SequenceFileOutputFormat" + override def newEmptyCatalog(): ExternalCatalog = new HiveExternalCatalog(client) + } + + protected override def resetState(): Unit = client.reset() + +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 1c775db9b6..0aaf57649c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -54,8 +54,8 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { val (desc, exists) = extractTableDesc(s1) assert(exists) - assert(desc.name.database == Some("mydb")) - assert(desc.name.table == "page_view") + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) assert(desc.storage.locationUri == Some("/user/external/page_view")) assert(desc.schema == @@ -100,8 +100,8 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { val (desc, exists) = extractTableDesc(s2) assert(exists) - assert(desc.name.database == Some("mydb")) - assert(desc.name.table == "page_view") + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) assert(desc.storage.locationUri == Some("/user/external/page_view")) assert(desc.schema == @@ -127,8 +127,8 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { val s3 = """CREATE TABLE page_view AS SELECT * FROM src""" val (desc, exists) = extractTableDesc(s3) assert(exists == false) - assert(desc.name.database == None) - assert(desc.name.table == "page_view") + assert(desc.identifier.database == None) + assert(desc.identifier.table == "page_view") assert(desc.tableType == CatalogTableType.MANAGED_TABLE) assert(desc.storage.locationUri == None) assert(desc.schema == Seq.empty[CatalogColumn]) @@ -162,8 +162,8 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { | ORDER BY key, value""".stripMargin val (desc, exists) = extractTableDesc(s5) assert(exists == false) - assert(desc.name.database == None) - assert(desc.name.table == "ctas2") + assert(desc.identifier.database == None) + assert(desc.identifier.table == "ctas2") assert(desc.tableType == CatalogTableType.MANAGED_TABLE) assert(desc.storage.locationUri == None) assert(desc.schema == Seq.empty[CatalogColumn]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala index 5272f4192e..e8188e5f02 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ListTablesSuite.scala @@ -34,7 +34,7 @@ class ListTablesSuite extends QueryTest with TestHiveSingleton with BeforeAndAft super.beforeAll() // The catalog in HiveContext is a case 
insensitive one. sessionState.catalog.createTempTable( - "ListTablesSuiteTable", df.logicalPlan, ignoreIfExists = true) + "ListTablesSuiteTable", df.logicalPlan, overrideIfExists = true) sql("CREATE TABLE HiveListTablesSuiteTable (key int, value string)") sql("CREATE DATABASE IF NOT EXISTS ListTablesSuiteDB") sql("CREATE TABLE ListTablesSuiteDB.HiveInDBListTablesSuiteTable (key int, value string)") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala index 71652897e6..3c299daa77 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala @@ -722,7 +722,7 @@ class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with TestHiv withTable(tableName) { val schema = StructType(StructField("int", IntegerType, true) :: Nil) val hiveTable = CatalogTable( - name = TableIdentifier(tableName, Some("default")), + identifier = TableIdentifier(tableName, Some("default")), tableType = CatalogTableType.MANAGED_TABLE, schema = Seq.empty, storage = CatalogStorageFormat( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala index d59bca4c7e..8b0719209d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/VersionsSuite.scala @@ -148,7 +148,7 @@ class VersionsSuite extends SparkFunSuite with Logging { test(s"$version: createTable") { val table = CatalogTable( - name = TableIdentifier("src", Some("default")), + identifier = TableIdentifier("src", Some("default")), tableType = CatalogTableType.MANAGED_TABLE, schema = Seq(CatalogColumn("key", "int")), storage = CatalogStorageFormat( -- cgit v1.2.3 From b7836492bb0b5b430539d2bfa20bcc32e3fe3504 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 28 Mar 2016 16:26:32 -0700 Subject: [SPARK-14155][SQL] Hide UserDefinedType interface in Spark 2.0 ## What changes were proposed in this pull request? UserDefinedType is a developer API in Spark 1.x. With very high probability we will create a new API for user-defined type that also works well with column batches as well as encoders (datasets). In Spark 2.0, let's make `UserDefinedType` `private[spark]` first. ## How was this patch tested? Existing unit tests. Author: Reynold Xin Closes #11955 from rxin/SPARK-14155. --- .../src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala index dabf9a2fc0..fb7251d71b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UserDefinedType.scala @@ -23,7 +23,6 @@ import org.json4s.JsonDSL._ import org.apache.spark.annotation.DeveloperApi /** - * ::DeveloperApi:: * The data type for User Defined Types (UDTs). * * This interface allows a user to make their own classes more interoperable with SparkSQL; @@ -35,8 +34,11 @@ import org.apache.spark.annotation.DeveloperApi * * The conversion via `serialize` occurs when instantiating a `DataFrame` from another RDD. 
* The conversion via `deserialize` occurs when reading from a `DataFrame`. + * + * Note: This was previously a developer API in Spark 1.x. We are making this private in Spark 2.0 + * because we will very likely create a new version of this that works better with Datasets. */ -@DeveloperApi +private[spark] abstract class UserDefinedType[UserType >: Null] extends DataType with Serializable { /** Underlying storage type for this UDT */ -- cgit v1.2.3 From 27aab80695cfcf0c0ecf1e98a5a862a8123213a1 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 28 Mar 2016 16:45:02 -0700 Subject: [SPARK-14013][SQL] Proper temp function support in catalog ## What changes were proposed in this pull request? Session catalog was added in #11750. However, it doesn't really support temporary functions properly; right now we only store the metadata in the form of `CatalogFunction`, but this doesn't make sense for temporary functions because there is no class name. This patch moves the `FunctionRegistry` into the `SessionCatalog`. With this, the user can call `catalog.createTempFunction` and `catalog.lookupFunction` to use the function they registered previously. This is currently still dead code, however. ## How was this patch tested? `SessionCatalogSuite`. Author: Andrew Or Closes #11972 from andrewor14/temp-functions. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 12 +++- .../sql/catalyst/analysis/FunctionRegistry.scala | 24 +++++++ .../sql/catalyst/catalog/SessionCatalog.scala | 82 +++++++++++++--------- .../spark/sql/catalyst/analysis/AnalysisTest.scala | 2 +- .../catalyst/analysis/DecimalPrecisionSuite.scala | 2 +- .../sql/catalyst/catalog/SessionCatalogSuite.scala | 57 +++++++-------- .../optimizer/BooleanSimplificationSuite.scala | 2 +- .../catalyst/optimizer/EliminateSortsSuite.scala | 2 +- .../apache/spark/sql/internal/SessionState.scala | 8 +-- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 4 +- .../apache/spark/sql/hive/HiveSessionState.scala | 14 ++-- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 10 +++ 12 files changed, 136 insertions(+), 83 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3b83e68018..8dc0532b3f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -42,9 +42,15 @@ import org.apache.spark.sql.types._ * to resolve attribute references. 
*/ object SimpleAnalyzer - extends SimpleAnalyzer(new SimpleCatalystConf(caseSensitiveAnalysis = true)) -class SimpleAnalyzer(conf: CatalystConf) - extends Analyzer(new SessionCatalog(new InMemoryCatalog, conf), EmptyFunctionRegistry, conf) + extends SimpleAnalyzer( + EmptyFunctionRegistry, + new SimpleCatalystConf(caseSensitiveAnalysis = true)) + +class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf) + extends Analyzer( + new SessionCatalog(new InMemoryCatalog, functionRegistry, conf), + functionRegistry, + conf) /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f584a4b73a..e9788b7e4d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -45,6 +45,13 @@ trait FunctionRegistry { /* Get the class of the registered function by specified name. */ def lookupFunction(name: String): Option[ExpressionInfo] + + /* Get the builder of the registered function by specified name. */ + def lookupFunctionBuilder(name: String): Option[FunctionBuilder] + + /** Drop a function and return whether the function existed. */ + def dropFunction(name: String): Boolean + } class SimpleFunctionRegistry extends FunctionRegistry { @@ -76,6 +83,14 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.get(name).map(_._1) } + override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = synchronized { + functionBuilders.get(name).map(_._2) + } + + override def dropFunction(name: String): Boolean = synchronized { + functionBuilders.remove(name).isDefined + } + def copy(): SimpleFunctionRegistry = synchronized { val registry = new SimpleFunctionRegistry functionBuilders.iterator.foreach { case (name, (info, builder)) => @@ -106,6 +121,15 @@ object EmptyFunctionRegistry extends FunctionRegistry { override def lookupFunction(name: String): Option[ExpressionInfo] = { throw new UnsupportedOperationException } + + override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = { + throw new UnsupportedOperationException + } + + override def dropFunction(name: String): Boolean = { + throw new UnsupportedOperationException + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index a9cf80764d..7165db1d5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -22,6 +22,9 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, SimpleFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} @@ -32,15 +35,22 @@ import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} * * This class is not thread-safe. */ -class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { +class SessionCatalog( + externalCatalog: ExternalCatalog, + functionRegistry: FunctionRegistry, + conf: CatalystConf) { import ExternalCatalog._ + def this(externalCatalog: ExternalCatalog, functionRegistry: FunctionRegistry) { + this(externalCatalog, functionRegistry, new SimpleCatalystConf(true)) + } + + // For testing only. def this(externalCatalog: ExternalCatalog) { - this(externalCatalog, new SimpleCatalystConf(true)) + this(externalCatalog, new SimpleFunctionRegistry) } protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan] - protected[this] val tempFunctions = new mutable.HashMap[String, CatalogFunction] // Note: we track current database here because certain operations do not explicitly // specify the database (e.g. DROP TABLE my_table). In these cases we must first @@ -431,6 +441,18 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { externalCatalog.alterFunction(db, newFuncDefinition) } + /** + * Retrieve the metadata of a metastore function. + * + * If a database is specified in `name`, this will return the function in that database. + * If no database is specified, this will return the function in the current database. + */ + def getFunction(name: FunctionIdentifier): CatalogFunction = { + val db = name.database.getOrElse(currentDb) + externalCatalog.getFunction(db, name.funcName) + } + + // ---------------------------------------------------------------- // | Methods that interact with temporary and metastore functions | // ---------------------------------------------------------------- @@ -439,14 +461,14 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { * Create a temporary function. * This assumes no database is specified in `funcDefinition`. */ - def createTempFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { - require(funcDefinition.identifier.database.isEmpty, - "attempted to create a temporary function while specifying a database") - val name = funcDefinition.identifier.funcName - if (tempFunctions.contains(name) && !ignoreIfExists) { + def createTempFunction( + name: String, + funcDefinition: FunctionBuilder, + ignoreIfExists: Boolean): Unit = { + if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { throw new AnalysisException(s"Temporary function '$name' already exists.") } - tempFunctions.put(name, funcDefinition) + functionRegistry.registerFunction(name, funcDefinition) } /** @@ -456,11 +478,10 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { // Hive has DROP FUNCTION and DROP TEMPORARY FUNCTION. We may want to consolidate // dropFunction and dropTempFunction. 
def dropTempFunction(name: String, ignoreIfNotExists: Boolean): Unit = { - if (!tempFunctions.contains(name) && !ignoreIfNotExists) { + if (!functionRegistry.dropFunction(name) && !ignoreIfNotExists) { throw new AnalysisException( s"Temporary function '$name' cannot be dropped because it does not exist!") } - tempFunctions.remove(name) } /** @@ -477,34 +498,29 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { throw new AnalysisException("rename does not support moving functions across databases") } val db = oldName.database.getOrElse(currentDb) - if (oldName.database.isDefined || !tempFunctions.contains(oldName.funcName)) { + val oldBuilder = functionRegistry.lookupFunctionBuilder(oldName.funcName) + if (oldName.database.isDefined || oldBuilder.isEmpty) { externalCatalog.renameFunction(db, oldName.funcName, newName.funcName) } else { - val func = tempFunctions(oldName.funcName) - val newFunc = func.copy(identifier = func.identifier.copy(funcName = newName.funcName)) - tempFunctions.remove(oldName.funcName) - tempFunctions.put(newName.funcName, newFunc) + val oldExpressionInfo = functionRegistry.lookupFunction(oldName.funcName).get + val newExpressionInfo = new ExpressionInfo( + oldExpressionInfo.getClassName, + newName.funcName, + oldExpressionInfo.getUsage, + oldExpressionInfo.getExtended) + functionRegistry.dropFunction(oldName.funcName) + functionRegistry.registerFunction(newName.funcName, newExpressionInfo, oldBuilder.get) } } /** - * Retrieve the metadata of an existing function. - * - * If a database is specified in `name`, this will return the function in that database. - * If no database is specified, this will first attempt to return a temporary function with - * the same name, then, if that does not exist, return the function in the current database. + * Return an [[Expression]] that represents the specified function, assuming it exists. + * Note: This is currently only used for temporary functions. */ - def getFunction(name: FunctionIdentifier): CatalogFunction = { - val db = name.database.getOrElse(currentDb) - if (name.database.isDefined || !tempFunctions.contains(name.funcName)) { - externalCatalog.getFunction(db, name.funcName) - } else { - tempFunctions(name.funcName) - } + def lookupFunction(name: String, children: Seq[Expression]): Expression = { + functionRegistry.lookupFunction(name, children) } - // TODO: implement lookupFunction that returns something from the registry itself - /** * List all matching functions in the specified database, including temporary functions. */ @@ -512,7 +528,7 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { val dbFunctions = externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } val regex = pattern.replaceAll("\\*", ".*").r - val _tempFunctions = tempFunctions.keys.toSeq + val _tempFunctions = functionRegistry.listFunction() .filter { f => regex.pattern.matcher(f).matches() } .map { f => FunctionIdentifier(f) } dbFunctions ++ _tempFunctions @@ -521,8 +537,8 @@ class SessionCatalog(externalCatalog: ExternalCatalog, conf: CatalystConf) { /** * Return a temporary function. For testing only. 
*/ - private[catalog] def getTempFunction(name: String): Option[CatalogFunction] = { - tempFunctions.get(name) + private[catalog] def getTempFunction(name: String): Option[FunctionBuilder] = { + functionRegistry.lookupFunctionBuilder(name) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 34cb97699d..3ec95ef5b5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -30,7 +30,7 @@ trait AnalysisTest extends PlanTest { private def makeAnalyzer(caseSensitive: Boolean): Analyzer = { val conf = new SimpleCatalystConf(caseSensitive) - val catalog = new SessionCatalog(new InMemoryCatalog, conf) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true) new Analyzer(catalog, EmptyFunctionRegistry, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 6c08ccc34c..1a350bf847 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.types._ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) - private val catalog = new SessionCatalog(new InMemoryCatalog, conf) + private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) private val relation = LocalRelation( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 2948c5f8bd..acd97592b6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias} @@ -682,20 +683,20 @@ class SessionCatalogSuite extends SparkFunSuite { test("create temp function") { val catalog = new SessionCatalog(newBasicCatalog()) - val tempFunc1 = newFunc("temp1") - val tempFunc2 = newFunc("temp2") - catalog.createTempFunction(tempFunc1, ignoreIfExists = false) - catalog.createTempFunction(tempFunc2, ignoreIfExists = false) + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last + catalog.createTempFunction("temp1", tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("temp2", tempFunc2, ignoreIfExists = false) assert(catalog.getTempFunction("temp1") == Some(tempFunc1)) 
assert(catalog.getTempFunction("temp2") == Some(tempFunc2)) assert(catalog.getTempFunction("temp3") == None) + val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) // Temporary function already exists intercept[AnalysisException] { - catalog.createTempFunction(tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = false) } // Temporary function is overridden - val tempFunc3 = tempFunc1.copy(className = "something else") - catalog.createTempFunction(tempFunc3, ignoreIfExists = true) + catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = true) assert(catalog.getTempFunction("temp1") == Some(tempFunc3)) } @@ -725,8 +726,8 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop temp function") { val catalog = new SessionCatalog(newBasicCatalog()) - val tempFunc = newFunc("func1") - catalog.createTempFunction(tempFunc, ignoreIfExists = false) + val tempFunc = (e: Seq[Expression]) => e.head + catalog.createTempFunction("func1", tempFunc, ignoreIfExists = false) assert(catalog.getTempFunction("func1") == Some(tempFunc)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) assert(catalog.getTempFunction("func1") == None) @@ -755,20 +756,15 @@ class SessionCatalogSuite extends SparkFunSuite { } } - test("get temp function") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val metastoreFunc = externalCatalog.getFunction("db2", "func1") - val tempFunc = newFunc("func1").copy(className = "something weird") - sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - // If a database is specified, we'll always return the function in that database - assert(sessionCatalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == metastoreFunc) - // If no database is specified, we'll first return temporary functions - assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == tempFunc) - // Then, if no such temporary function exist, check the current database - sessionCatalog.dropTempFunction("func1", ignoreIfNotExists = false) - assert(sessionCatalog.getFunction(FunctionIdentifier("func1")) == metastoreFunc) + test("lookup temp function") { + val catalog = new SessionCatalog(newBasicCatalog()) + val tempFunc1 = (e: Seq[Expression]) => e.head + catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false) + assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) + catalog.dropTempFunction("func1", ignoreIfNotExists = false) + intercept[AnalysisException] { + catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) + } } test("rename function") { @@ -813,8 +809,8 @@ class SessionCatalogSuite extends SparkFunSuite { test("rename temp function") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) - val tempFunc = newFunc("func1").copy(className = "something weird") - sessionCatalog.createTempFunction(tempFunc, ignoreIfExists = false) + val tempFunc = (e: Seq[Expression]) => e.head + sessionCatalog.createTempFunction("func1", tempFunc, ignoreIfExists = false) sessionCatalog.setCurrentDatabase("db2") // If a database is specified, we'll always rename the function in that database sessionCatalog.renameFunction( @@ -825,8 +821,7 @@ class SessionCatalogSuite extends SparkFunSuite { // If no database is specified, we'll first rename temporary functions sessionCatalog.createFunction(newFunc("func1", Some("db2"))) 
sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4")) - assert(sessionCatalog.getTempFunction("func4") == - Some(tempFunc.copy(identifier = FunctionIdentifier("func4")))) + assert(sessionCatalog.getTempFunction("func4") == Some(tempFunc)) assert(sessionCatalog.getTempFunction("func1") == None) assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3")) // Then, if no such temporary function exist, rename the function in the current database @@ -858,12 +853,12 @@ class SessionCatalogSuite extends SparkFunSuite { test("list functions") { val catalog = new SessionCatalog(newBasicCatalog()) - val tempFunc1 = newFunc("func1").copy(className = "march") - val tempFunc2 = newFunc("yes_me").copy(className = "april") + val tempFunc1 = (e: Seq[Expression]) => e.head + val tempFunc2 = (e: Seq[Expression]) => e.last catalog.createFunction(newFunc("func2", Some("db2"))) catalog.createFunction(newFunc("not_me", Some("db2"))) - catalog.createTempFunction(tempFunc1, ignoreIfExists = false) - catalog.createTempFunction(tempFunc2, ignoreIfExists = false) + catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("yes_me", tempFunc2, ignoreIfExists = false) assert(catalog.listFunctions("db1", "*").toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index e2c76b700f..dd6b5cac28 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -140,7 +140,7 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { private val caseInsensitiveConf = new SimpleCatalystConf(false) private val caseInsensitiveAnalyzer = new Analyzer( - new SessionCatalog(new InMemoryCatalog, caseInsensitiveConf), + new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf), EmptyFunctionRegistry, caseInsensitiveConf) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 3824c67563..009889d5a1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.rules._ class EliminateSortsSuite extends PlanTest { val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) - val catalog = new SessionCatalog(new InMemoryCatalog, conf) + val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) object Optimize extends RuleExecutor[LogicalPlan] { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 9bc640763f..f7fdfacd31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -44,14 +44,14 @@ private[sql] class SessionState(ctx: 
SQLContext) { lazy val experimentalMethods = new ExperimentalMethods /** - * Internal catalog for managing table and database states. + * Internal catalog for managing functions registered by the user. */ - lazy val catalog = new SessionCatalog(ctx.externalCatalog, conf) + lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() /** - * Internal catalog for managing functions registered by the user. + * Internal catalog for managing table and database states. */ - lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() + lazy val catalog = new SessionCatalog(ctx.externalCatalog, functionRegistry, conf) /** * Interface exposed to the user for registering user-defined functions. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index ec7bf61be1..ff12245e8d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.hive import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.catalog.SessionCatalog import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.rules.Rule @@ -31,8 +32,9 @@ class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, client: HiveClient, context: HiveContext, + functionRegistry: FunctionRegistry, conf: SQLConf) - extends SessionCatalog(externalCatalog, conf) { + extends SessionCatalog(externalCatalog, functionRegistry, conf) { override def setCurrentDatabase(db: String): Unit = { super.setCurrentDatabase(db) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index caa7f296ed..c9b6b1dfb6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -34,13 +34,6 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } - /** - * Internal catalog for managing table and database states. - */ - override lazy val catalog = { - new HiveSessionCatalog(ctx.hiveCatalog, ctx.metadataHive, ctx, conf) - } - /** * Internal catalog for managing functions registered by the user. * Note that HiveUDFs will be overridden by functions registered in this context. @@ -49,6 +42,13 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), ctx.executionHive) } + /** + * Internal catalog for managing table and database states. + */ + override lazy val catalog = { + new HiveSessionCatalog(ctx.hiveCatalog, ctx.metadataHive, ctx, functionRegistry, conf) + } + /** * An analyzer that uses the Hive metastore. 
*/ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index efaa052370..c07c428895 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -141,6 +141,16 @@ private[hive] class HiveFunctionRegistry( } }.getOrElse(None)) } + + override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = { + underlying.lookupFunctionBuilder(name) + } + + // Note: This does not drop functions stored in the metastore + override def dropFunction(name: String): Boolean = { + underlying.dropFunction(name) + } + } private[hive] case class HiveSimpleUDF( -- cgit v1.2.3 From 27d4ef0c619501be592366ae8f0be77294c9687d Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Mon, 28 Mar 2016 20:19:21 -0700 Subject: [SPARK-14213][SQL] Migrate HiveQl parsing to ANTLR4 parser ### What changes were proposed in this pull request? This PR migrates all HiveQl parsing to the new ANTLR4 parser. This PR is build on top of https://github.com/apache/spark/pull/12011, and we should wait with merging until that one is in (hence the WIP tag). As soon as this PR is merged we can start removing much of the old parser infrastructure. ### How was this patch tested? Exisiting Hive unit tests. cc rxin andrewor14 yhuai Author: Herman van Hovell Closes #12015 from hvanhovell/SPARK-14213. --- .../apache/spark/sql/catalyst/parser/ng/SqlBase.g4 | 34 +- .../spark/sql/hive/HiveMetastoreCatalog.scala | 13 +- .../apache/spark/sql/hive/HiveSessionState.scala | 3 +- .../spark/sql/hive/execution/HiveSqlParser.scala | 442 +++++++++++++++++++++ 4 files changed, 488 insertions(+), 4 deletions(-) create mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 index e46fd9bed5..4e77b6db25 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 @@ -118,7 +118,9 @@ statement | UNCACHE TABLE identifier #uncacheTable | CLEAR CACHE #clearCache | ADD identifier .*? #addResource + | SET ROLE .*? #failNativeCommand | SET .*? #setConfiguration + | kws=unsupportedHiveNativeCommands .*? #failNativeCommand | hiveNativeCommands #executeNativeCommand ; @@ -145,7 +147,26 @@ hiveNativeCommands | ROLLBACK WORK? | SHOW PARTITIONS tableIdentifier partitionSpec? | DFS .*? - | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD) .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOCK | UNLOCK | MSCK | LOAD) .*? + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? 
+ | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS ; createTableHeader @@ -619,7 +640,8 @@ nonReserved | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE + | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEXES | LOCKS | OPTION ; SELECT: 'SELECT'; @@ -834,6 +856,14 @@ MSCK: 'MSCK'; EXPORT: 'EXPORT'; IMPORT: 'IMPORT'; LOAD: 'LOAD'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +COMPACTIONS: 'COMPACTIONS'; +PRINCIPALS: 'PRINCIPALS'; +TRANSACTIONS: 'TRANSACTIONS'; +INDEXES: 'INDEXES'; +LOCKS: 'LOCKS'; +OPTION: 'OPTION'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index eedd12d76a..9a5ec9880e 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -85,7 +85,18 @@ private[hive] object HiveSerDe { HiveSerDe( inputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), outputFormat = Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat"), - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe"))) + serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")), + + "textfile" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")), + + "avro" -> + HiveSerDe( + inputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat"), + serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe"))) val key = source.toLowerCase match { case s if s.startsWith("org.apache.spark.sql.parquet") => "parquet" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index c9b6b1dfb6..11ef0fd1bb 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -22,6 +22,7 @@ import org.apache.spark.sql.catalyst.analysis.{Analyzer, FunctionRegistry} import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.execution.{python, SparkPlanner} import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.hive.execution.HiveSqlParser import org.apache.spark.sql.internal.{SessionState, SQLConf} @@ -70,7 +71,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) /** * Parser for HiveQl query texts. 
*/ - override lazy val sqlParser: ParserInterface = new HiveQl(conf) + override lazy val sqlParser: ParserInterface = HiveSqlParser /** * Planner that takes into account Hive-specific strategies. diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala new file mode 100644 index 0000000000..d6a08fcc57 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -0,0 +1,442 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.hive.execution + +import scala.collection.JavaConverters._ + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.hadoop.hive.ql.exec.FunctionRegistry +import org.apache.hadoop.hive.ql.parse.EximUtil +import org.apache.hadoop.hive.ql.session.SessionState +import org.apache.hadoop.hive.serde.serdeConstants +import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe + +import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.ng._ +import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkSqlAstBuilder +import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView} +import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveSerDe} +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper + +/** + * Concrete parser for HiveQl statements. + */ +object HiveSqlParser extends AbstractSqlParser { + val astBuilder = new HiveSqlAstBuilder + + override protected def nativeCommand(sqlText: String): LogicalPlan = { + HiveNativeCommand(sqlText) + } +} + +/** + * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. + */ +class HiveSqlAstBuilder extends SparkSqlAstBuilder { + import ParserUtils._ + + /** + * Get the current Hive Configuration. + */ + private[this] def hiveConf: HiveConf = { + var ss = SessionState.get() + // SessionState is lazy initialization, it can be null here + if (ss == null) { + val original = Thread.currentThread().getContextClassLoader + val conf = new HiveConf(classOf[SessionState]) + conf.setClassLoader(original) + ss = new SessionState(conf) + SessionState.start(ss) + } + ss.getConf + } + + /** + * Pass a command to Hive using a [[HiveNativeCommand]]. 
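+   * This covers statements matched by the `hiveNativeCommands` grammar rule, for example
+   * `SHOW PARTITIONS` or `MSCK`; the original SQL text is executed by Hive as-is.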
+ */ + override def visitExecuteNativeCommand( + ctx: ExecuteNativeCommandContext): LogicalPlan = withOrigin(ctx) { + HiveNativeCommand(command(ctx)) + } + + /** + * Fail an unsupported Hive native command. + */ + override def visitFailNativeCommand( + ctx: FailNativeCommandContext): LogicalPlan = withOrigin(ctx) { + val keywords = if (ctx.kws != null) { + Seq(ctx.kws.kw1, ctx.kws.kw2, ctx.kws.kw3).filter(_ != null).map(_.getText).mkString(" ") + } else { + // SET ROLE is the exception to the rule, because we handle this before other SET commands. + "SET ROLE" + } + throw new ParseException(s"Unsupported operation: $keywords", ctx) + } + + /** + * Create an [[AddJar]] or [[AddFile]] command depending on the requested resource. + */ + override def visitAddResource(ctx: AddResourceContext): LogicalPlan = withOrigin(ctx) { + ctx.identifier.getText.toLowerCase match { + case "file" => AddFile(remainder(ctx.identifier).trim) + case "jar" => AddJar(remainder(ctx.identifier).trim) + case other => throw new ParseException(s"Unsupported resource type '$other'.", ctx) + } + } + + /** + * Create a [[DropTable]] command. + */ + override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { + if (ctx.PURGE != null) { + logWarning("PURGE option is ignored.") + } + if (ctx.REPLICATION != null) { + logWarning("REPLICATION clause is ignored.") + } + DropTable(visitTableIdentifier(ctx.tableIdentifier).toString, ctx.EXISTS != null) + } + + /** + * Create an [[AnalyzeTable]] command. This currently only implements the NOSCAN option (other + * options are passed on to Hive) e.g.: + * {{{ + * ANALYZE TABLE table COMPUTE STATISTICS NOSCAN; + * }}} + */ + override def visitAnalyze(ctx: AnalyzeContext): LogicalPlan = withOrigin(ctx) { + if (ctx.partitionSpec == null && + ctx.identifier != null && + ctx.identifier.getText.toLowerCase == "noscan") { + AnalyzeTable(visitTableIdentifier(ctx.tableIdentifier).toString) + } else { + HiveNativeCommand(command(ctx)) + } + } + + /** + * Create a [[CreateTableAsSelect]] command. + */ + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = { + if (ctx.query == null) { + HiveNativeCommand(command(ctx)) + } else { + // Get the table header. + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + val tableType = if (external) { + CatalogTableType.EXTERNAL_TABLE + } else { + CatalogTableType.MANAGED_TABLE + } + + // Unsupported clauses. + if (temp) { + logWarning("TEMPORARY clause is ignored.") + } + if (ctx.bucketSpec != null) { + // TODO add this - we need cluster columns in the CatalogTable for this to work. + logWarning("CLUSTERED BY ... [ORDERED BY ...] INTO ... BUCKETS clause is ignored.") + } + if (ctx.skewSpec != null) { + logWarning("SKEWED BY ... ON ... [STORED AS DIRECTORIES] clause is ignored.") + } + + // Create the schema. + val schema = Option(ctx.colTypeList).toSeq.flatMap(_.colType.asScala).map { col => + CatalogColumn( + col.identifier.getText, + col.dataType.getText.toLowerCase, // TODO validate this? + nullable = true, + Option(col.STRING).map(string)) + } + + // Get the column by which the table is partitioned. + val partitionCols = Option(ctx.identifierList).toSeq.flatMap(visitIdentifierList).map { + CatalogColumn(_, null, nullable = true, None) + } + + // Create the storage. + def format(fmt: ParserRuleContext): CatalogStorageFormat = { + Option(fmt).map(typedVisit[CatalogStorageFormat]).getOrElse(EmptyStorageFormat) + } + // Default storage. 
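+      // Resolve HiveConf's default file format (e.g. "textfile" or "parquet") to a HiveSerDe,
+      // falling back to plain text input/output formats when the name is not recognized.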
+ val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) + val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { + HiveSerDe( + inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), + outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + } + // Defined storage. + val fileStorage = format(ctx.createFileFormat) + val rowStorage = format(ctx.rowFormat) + val storage = CatalogStorageFormat( + Option(ctx.locationSpec).map(visitLocationSpec), + fileStorage.inputFormat.orElse(hiveSerDe.inputFormat), + fileStorage.outputFormat.orElse(hiveSerDe.outputFormat), + rowStorage.serde.orElse(hiveSerDe.serde).orElse(fileStorage.serde), + rowStorage.serdeProperties ++ fileStorage.serdeProperties + ) + + val tableDesc = CatalogTable( + identifier = table, + tableType = tableType, + schema = schema, + partitionColumns = partitionCols, + storage = storage, + properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty), + // TODO support the sql text - have a proper location for this! + viewText = Option(ctx.STRING).map(string)) + CTAS(tableDesc, plan(ctx.query), ifNotExists) + } + } + + /** + * Create or replace a view. This creates a [[CreateViewAsSelect]] command. + */ + override def visitCreateView(ctx: CreateViewContext): LogicalPlan = withOrigin(ctx) { + // Pass a partitioned view on to hive. + if (ctx.identifierList != null) { + HiveNativeCommand(command(ctx)) + } else { + if (ctx.STRING != null) { + logWarning("COMMENT clause is ignored.") + } + val identifiers = Option(ctx.identifierCommentList).toSeq.flatMap(_.identifierComment.asScala) + val schema = identifiers.map { ic => + CatalogColumn(ic.identifier.getText, null, nullable = true, Option(ic.STRING).map(string)) + } + createView( + ctx, + ctx.tableIdentifier, + schema, + ctx.query, + Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty), + ctx.EXISTS != null, + ctx.REPLACE != null + ) + } + } + + /** + * Alter the query of a view. This creates a [[CreateViewAsSelect]] command. + */ + override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { + createView( + ctx, + ctx.tableIdentifier, + Seq.empty, + ctx.query, + Map.empty, + allowExist = false, + replace = true) + } + + /** + * Create a [[CreateViewAsSelect]] command. + */ + private def createView( + ctx: ParserRuleContext, + name: TableIdentifierContext, + schema: Seq[CatalogColumn], + query: QueryContext, + properties: Map[String, String], + allowExist: Boolean, + replace: Boolean): LogicalPlan = { + val sql = Option(source(query)) + val tableDesc = CatalogTable( + identifier = visitTableIdentifier(name), + tableType = CatalogTableType.VIRTUAL_VIEW, + schema = schema, + storage = EmptyStorageFormat, + properties = properties, + viewOriginalText = sql, + viewText = sql) + CreateView(tableDesc, plan(query), allowExist, replace, command(ctx)) + } + + /** + * Create a [[Generator]]. Override this method in order to support custom Generators. + */ + override protected def withGenerator( + name: String, + expressions: Seq[Expression], + ctx: LateralViewContext): Generator = { + val info = Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse { + throw new ParseException(s"Couldn't find Generator function '$name'", ctx) + } + HiveGenericUDTF(name, new HiveFunctionWrapper(info.getFunctionClass.getName), expressions) + } + + /** + * Create a [[HiveScriptIOSchema]]. 
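+   * This describes how rows are serialized to and deserialized from a script transformation
+   * (TRANSFORM clause); user-defined record reader/writer classes are currently ignored.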
+ */ + override protected def withScriptIOSchema( + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): HiveScriptIOSchema = { + if (recordWriter != null || recordReader != null) { + logWarning("Used defined record reader/writer classes are currently ignored.") + } + + // Decode and input/output format. + type Format = (Seq[(String, String)], Option[String], Seq[(String, String)], Option[String]) + def format(fmt: RowFormatContext, confVar: ConfVars): Format = fmt match { + case c: RowFormatDelimitedContext => + // TODO we should use the visitRowFormatDelimited function here. However HiveScriptIOSchema + // expects a seq of pairs in which the old parsers' token names are used as keys. + // Transforming the result of visitRowFormatDelimited would be quite a bit messier than + // retrieving the key value pairs ourselves. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).map(t => key -> t.getText).toSeq + } + val entries = entry("TOK_TABLEROWFORMATFIELD", c.fieldsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATCOLLITEMS", c.collectionItemsTerminatedBy) ++ + entry("TOK_TABLEROWFORMATMAPKEYS", c.keysTerminatedBy) ++ + entry("TOK_TABLEROWFORMATLINES", c.linesSeparatedBy) ++ + entry("TOK_TABLEROWFORMATNULL", c.nullDefinedAs) + + (entries, None, Seq.empty, None) + + case c: RowFormatSerdeContext => + // Use a serde format. + val CatalogStorageFormat(None, None, None, Some(name), props) = visitRowFormatSerde(c) + + // SPARK-10310: Special cases LazySimpleSerDe + val recordHandler = if (name == classOf[LazySimpleSerDe].getCanonicalName) { + Option(hiveConf.getVar(confVar)) + } else { + None + } + (Seq.empty, Option(name), props.toSeq, recordHandler) + + case null => + // Use default (serde) format. + val name = hiveConf.getVar(ConfVars.HIVESCRIPTSERDE) + val props = Seq(serdeConstants.FIELD_DELIM -> "\t") + val recordHandler = Option(hiveConf.getVar(confVar)) + (Nil, Option(name), props, recordHandler) + } + + val (inFormat, inSerdeClass, inSerdeProps, reader) = + format(inRowFormat, ConfVars.HIVESCRIPTRECORDREADER) + + val (outFormat, outSerdeClass, outSerdeProps, writer) = + format(inRowFormat, ConfVars.HIVESCRIPTRECORDWRITER) + + HiveScriptIOSchema( + inFormat, outFormat, + inSerdeClass, outSerdeClass, + inSerdeProps, outSerdeProps, + reader, writer, + schemaLess) + } + + /** + * Create location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = { + EximUtil.relativeToAbsolutePath(hiveConf, super.visitLocationSpec(ctx)) + } + + /** Empty storage format for default values and copies. */ + private val EmptyStorageFormat = CatalogStorageFormat(None, None, None, None, Map.empty) + + /** + * Create a [[CatalogStorageFormat]]. The INPUTDRIVER and OUTPUTDRIVER clauses are currently + * ignored. + */ + override def visitTableFileFormat( + ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + import ctx._ + if (inDriver != null || outDriver != null) { + logWarning("INPUTDRIVER ... OUTPUTDRIVER ... clauses are ignored.") + } + EmptyStorageFormat.copy( + inputFormat = Option(string(inFmt)), + outputFormat = Option(string(outFmt)), + serde = Option(serdeCls).map(string) + ) + } + + /** + * Resolve a [[HiveSerDe]] based on the format name given. 
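+   * For example, `STORED AS PARQUET` resolves to the Parquet input/output formats and serde
+   * registered in [[HiveSerDe.sourceToSerDe]]; an unrecognized name raises a [[ParseException]].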
+ */ + override def visitGenericFileFormat( + ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + val source = ctx.identifier.getText + HiveSerDe.sourceToSerDe(source, hiveConf) match { + case Some(s) => + EmptyStorageFormat.copy( + inputFormat = s.inputFormat, + outputFormat = s.outputFormat, + serde = s.serde) + case None => + throw new ParseException(s"Unrecognized file format in STORED AS clause: $source", ctx) + } + } + + /** + * Storage Handlers are currently not supported in the statements we support (CTAS). + */ + override def visitStorageHandler(ctx: StorageHandlerContext): AnyRef = withOrigin(ctx) { + throw new ParseException("Storage Handlers are currently unsupported.", ctx) + } + + /** + * Create SERDE row format name and properties pair. + */ + override def visitRowFormatSerde( + ctx: RowFormatSerdeContext): CatalogStorageFormat = withOrigin(ctx) { + import ctx._ + EmptyStorageFormat.copy( + serde = Option(string(name)), + serdeProperties = Option(tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)) + } + + /** + * Create a delimited row format properties object. + */ + override def visitRowFormatDelimited( + ctx: RowFormatDelimitedContext): CatalogStorageFormat = withOrigin(ctx) { + // Collect the entries if any. + def entry(key: String, value: Token): Seq[(String, String)] = { + Option(value).toSeq.map(x => key -> string(x)) + } + // TODO we need proper support for the NULL format. + val entries = entry(serdeConstants.FIELD_DELIM, ctx.fieldsTerminatedBy) ++ + entry(serdeConstants.SERIALIZATION_FORMAT, ctx.fieldsTerminatedBy) ++ + entry(serdeConstants.ESCAPE_CHAR, ctx.escapedBy) ++ + entry(serdeConstants.COLLECTION_DELIM, ctx.collectionItemsTerminatedBy) ++ + entry(serdeConstants.MAPKEY_DELIM, ctx.keysTerminatedBy) ++ + Option(ctx.linesSeparatedBy).toSeq.map { token => + val value = string(token) + assert( + value == "\n", + s"LINES TERMINATED BY only supports newline '\\n' right now: $value", + ctx) + serdeConstants.LINE_DELIM -> value + } + EmptyStorageFormat.copy(serdeProperties = entries.toMap) + } +} -- cgit v1.2.3 From d612228eff9ed6589b2a94658986ec06ed833bf5 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 29 Mar 2016 12:45:43 -0700 Subject: [MINOR][SQL] Fix typos by replacing 'much' with 'match'. ## What changes were proposed in this pull request? This PR fixes two trivial typos: 'does not **much**' --> 'does not **match**'. ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #12042 from dongjoon-hyun/fix_typo_by_replacing_much_with_match. 
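
To make the storage-format handling in the Hive DDL parser changes above easier to follow: explicitly written `STORED AS` / `ROW FORMAT` clauses take precedence, and anything left unspecified falls back to the session default `HiveSerDe` via `orElse`. Below is a minimal, self-contained sketch of that precedence rule. It uses a simplified stand-in class (`StorageFormat`, `StorageFormatFallback` and its field set are illustrative, not Spark's real `CatalogStorageFormat`), with the default formats copied from the diff above.

```scala
// Simplified stand-in for CatalogStorageFormat; the field names mirror the diff above,
// but this is an illustrative sketch, not the actual Spark class.
case class StorageFormat(
    locationUri: Option[String],
    inputFormat: Option[String],
    outputFormat: Option[String],
    serde: Option[String],
    serdeProperties: Map[String, String])

object StorageFormatFallback {
  // Defaults that, in the real parser, come from hive.default.fileformat via HiveSerDe.sourceToSerDe.
  val hiveDefault = StorageFormat(
    None,
    Some("org.apache.hadoop.mapred.TextInputFormat"),
    Some("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat"),
    None,
    Map.empty)

  // Explicit clauses win; unspecified pieces fall back to the defaults.
  def merge(
      location: Option[String],
      fileStorage: StorageFormat, // from a STORED AS clause
      rowStorage: StorageFormat   // from a ROW FORMAT clause
    ): StorageFormat = {
    StorageFormat(
      location,
      fileStorage.inputFormat.orElse(hiveDefault.inputFormat),
      fileStorage.outputFormat.orElse(hiveDefault.outputFormat),
      rowStorage.serde.orElse(hiveDefault.serde).orElse(fileStorage.serde),
      rowStorage.serdeProperties ++ fileStorage.serdeProperties)
  }

  def main(args: Array[String]): Unit = {
    val empty = StorageFormat(None, None, None, None, Map.empty)
    // No clauses at all: everything comes from the defaults.
    println(merge(None, empty, empty))
    // An explicit SERDE overrides the default serde but keeps the default file formats.
    val rowFormat = empty.copy(serde = Some("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe"))
    println(merge(None, empty, rowFormat))
  }
}
```

Expressing the fallback with `Option.orElse` keeps each piece of the merge declarative and makes the precedence (row format serde, then default serde, then file-format serde) readable at a glance.
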
--- .../main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala | 2 +- .../src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 8bb8e09a28..9f270236ae 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -291,5 +291,5 @@ case class CatalogRelation( override def output: Seq[Attribute] = Seq.empty require(metadata.identifier.database == Some(db), - "provided database does not much the one specified in the table definition") + "provided database does not match the one specified in the table definition") } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index f75509fe80..11205ae67c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -76,7 +76,7 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat private def requireDbMatches(db: String, table: CatalogTable): Unit = { if (table.identifier.database != Some(db)) { throw new AnalysisException( - s"Provided database $db does not much the one specified in the " + + s"Provided database $db does not match the one specified in the " + s"table definition (${table.identifier.database.getOrElse("n/a")})") } } -- cgit v1.2.3 From 366cac6fb0bb5591a0463c4696f5b9de2a294022 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Tue, 29 Mar 2016 16:46:45 -0700 Subject: [SPARK-14225][SQL] Cap the length of toCommentSafeString at 128 chars ## What changes were proposed in this pull request? Builds on https://github.com/apache/spark/pull/12022 and (a) appends "..." to truncated comment strings and (b) fixes indentation in lines after the commented strings if they happen to have a `(`, `{`, `)` or `}` ## How was this patch tested? Manually examined the generated code. Author: Sameer Agarwal Closes #12044 from sameeragarwal/comment. --- .../expressions/codegen/CodeFormatter.scala | 37 ++++++++++++++++--- .../apache/spark/sql/catalyst/util/package.scala | 9 +++-- .../expressions/codegen/CodeFormatterSuite.scala | 41 +++++++++++++++++++++- 3 files changed, 79 insertions(+), 8 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 9d99bbffbe..8e40754dc3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -43,15 +43,44 @@ object CodeFormatter { private class CodeFormatter { private val code = new StringBuilder - private var indentLevel = 0 private val indentSize = 2 + + // Tracks the level of indentation in the current line. + private var indentLevel = 0 private var indentString = "" private var currentLine = 1 + // Tracks the level of indentation in multi-line comment blocks. 
+ private var inCommentBlock = false + private var indentLevelOutsideCommentBlock = indentLevel + private def addLine(line: String): Unit = { - val indentChange = - line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0) - val newIndentLevel = math.max(0, indentLevel + indentChange) + + // We currently infer the level of indentation of a given line based on a simple heuristic that + // examines the number of parenthesis and braces in that line. This isn't the most robust + // implementation but works for all code that we generate. + val indentChange = line.count(c => "({".indexOf(c) >= 0) - line.count(c => ")}".indexOf(c) >= 0) + var newIndentLevel = math.max(0, indentLevel + indentChange) + + // Please note that while we try to format the comment blocks in exactly the same way as the + // rest of the code, once the block ends, we reset the next line's indentation level to what it + // was immediately before entering the comment block. + if (!inCommentBlock) { + if (line.startsWith("/*")) { + // Handle multi-line comments + inCommentBlock = true + indentLevelOutsideCommentBlock = indentLevel + } else if (line.startsWith("//")) { + // Handle single line comments + newIndentLevel = indentLevel + } + } else { + if (line.endsWith("*/")) { + inCommentBlock = false + newIndentLevel = indentLevelOutsideCommentBlock + } + } + // Lines starting with '}' should be de-indented even if they contain '{' after; // in addition, lines ending with ':' are typically labels val thisLineIndent = if (line.startsWith("}") || line.startsWith(")") || line.endsWith(":")) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index b11365b297..f879b34358 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -155,10 +155,13 @@ package object util { /** * Returns the string representation of this expression that is safe to be put in - * code comments of generated code. + * code comments of generated code. The length is capped at 128 characters. */ - def toCommentSafeString(str: String): String = - str.replace("*/", "\\*\\/").replace("\\u", "\\\\u") + def toCommentSafeString(str: String): String = { + val len = math.min(str.length, 128) + val suffix = if (str.length > len) "..." 
else "" + str.substring(0, len).replace("*/", "\\*\\/").replace("\\u", "\\\\u") + suffix + } /* FIX ME implicit class debugLogging(a: Any) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index 9da1068e9c..d7836aa3b2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -18,13 +18,20 @@ package org.apache.spark.sql.catalyst.expressions.codegen import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.util._ class CodeFormatterSuite extends SparkFunSuite { def testCase(name: String)(input: String)(expected: String): Unit = { test(name) { - assert(CodeFormatter.format(input).trim === expected.trim) + if (CodeFormatter.format(input).trim !== expected.trim) { + fail( + s""" + |== FAIL: Formatted code doesn't match === + |${sideBySide(CodeFormatter.format(input).trim, expected.trim).mkString("\n")} + """.stripMargin) + } } } @@ -93,4 +100,36 @@ class CodeFormatterSuite extends SparkFunSuite { |/* 004 */ c) """.stripMargin } + + testCase("single line comments") { + """// This is a comment about class A { { { ( ( + |class A { + |class body; + |}""".stripMargin + }{ + """ + |/* 001 */ // This is a comment about class A { { { ( ( + |/* 002 */ class A { + |/* 003 */ class body; + |/* 004 */ } + """.stripMargin + } + + testCase("multi-line comments") { + """ /* This is a comment about + |class A { + |class body; ...*/ + |class A { + |class body; + |}""".stripMargin + }{ + """ + |/* 001 */ /* This is a comment about + |/* 002 */ class A { + |/* 003 */ class body; ...*/ + |/* 004 */ class A { + |/* 005 */ class body; + |/* 006 */ } + """.stripMargin + } } -- cgit v1.2.3 From b66b97cd04067e1ec344fa2e28dd91e7ef937af5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 29 Mar 2016 17:39:52 -0700 Subject: [SPARK-14124][SQL] Implement Database-related DDL Commands #### What changes were proposed in this pull request? This PR is to implement the following four Database-related DDL commands: - `CREATE DATABASE|SCHEMA [IF NOT EXISTS] database_name` - `DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]` - `DESCRIBE DATABASE [EXTENDED] db_name` - `ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...)` Another PR will be submitted to handle the unsupported commands. In the Database-related DDL commands, we will issue an error exception for `ALTER (DATABASE|SCHEMA) database_name SET OWNER [USER|ROLE] user_or_role`. cc yhuai andrewor14 rxin Could you review the changes? Is it in the right direction? Thanks! #### How was this patch tested? Added a few test cases in `command/DDLSuite.scala` for testing DDL command execution in `SQLContext`. Since `HiveContext` also shares the same implementation, the existing test cases in `\hive` also verifies the correctness of these commands. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #12009 from gatorsmile/dbDDL. 
--- .../sql/catalyst/catalog/SessionCatalog.scala | 6 + .../spark/sql/catalyst/catalog/interface.scala | 2 +- .../org/apache/spark/sql/execution/SparkQl.scala | 17 +-- .../spark/sql/execution/SparkSqlParser.scala | 10 +- .../apache/spark/sql/execution/command/ddl.scala | 125 ++++++++++++++--- .../sql/execution/command/DDLCommandSuite.scala | 42 ++---- .../spark/sql/execution/command/DDLSuite.scala | 151 +++++++++++++++++++++ .../spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 8 ++ 9 files changed, 302 insertions(+), 61 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 7165db1d5d..569b99e414 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.catalog +import java.io.File + import scala.collection.mutable import org.apache.spark.sql.AnalysisException @@ -114,6 +116,10 @@ class SessionCatalog( currentDb = db } + def getDefaultDBPath(db: String): String = { + System.getProperty("java.io.tmpdir") + File.separator + db + ".db" + } + // ---------------------------------------------------------------------------- // Tables // ---------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 9f270236ae..303846d313 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -39,7 +39,7 @@ abstract class ExternalCatalog { protected def requireDbExists(db: String): Unit = { if (!databaseExists(db)) { - throw new AnalysisException(s"Database $db does not exist") + throw new AnalysisException(s"Database '$db' does not exist") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala index d4d1992d27..6fe04757ba 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala @@ -97,7 +97,8 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly // CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment] // [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)]; - case Token("TOK_CREATEDATABASE", Token(databaseName, Nil) :: args) => + case Token("TOK_CREATEDATABASE", Token(dbName, Nil) :: args) => + val databaseName = cleanIdentifier(dbName) val Seq(ifNotExists, dbLocation, databaseComment, dbprops) = getClauses(Seq( "TOK_IFNOTEXISTS", "TOK_DATABASELOCATION", @@ -126,7 +127,7 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly extractProps(propList, "TOK_TABLEPROPERTY") case _ => parseFailed("Invalid CREATE DATABASE command", node) }.toMap - CreateDatabase(databaseName, ifNotExists.isDefined, location, comment, props)(node.source) + CreateDatabase(databaseName, ifNotExists.isDefined, location, 
comment, props) // DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]; case Token("TOK_DROPDATABASE", Token(dbName, Nil) :: otherArgs) => @@ -136,15 +137,15 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly // :- database_name // :- TOK_IFEXISTS // +- TOK_RESTRICT/TOK_CASCADE - val databaseName = unquoteString(dbName) + val databaseName = cleanIdentifier(dbName) // The default is RESTRICT val Seq(ifExists, _, cascade) = getClauses(Seq( "TOK_IFEXISTS", "TOK_RESTRICT", "TOK_CASCADE"), otherArgs) - DropDatabase(databaseName, ifExists.isDefined, restrict = cascade.isEmpty)(node.source) + DropDatabase(databaseName, ifExists.isDefined, cascade.isDefined) // ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) case Token("TOK_ALTERDATABASE_PROPERTIES", Token(dbName, Nil) :: args) => - val databaseName = unquoteString(dbName) + val databaseName = cleanIdentifier(dbName) val dbprops = getClause("TOK_DATABASEPROPERTIES", args) val props = dbprops match { case Token("TOK_DATABASEPROPERTIES", Token("TOK_DBPROPLIST", propList) :: Nil) => @@ -161,13 +162,13 @@ private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends Cataly extractProps(propList, "TOK_TABLEPROPERTY") case _ => parseFailed("Invalid ALTER DATABASE command", node) } - AlterDatabaseProperties(databaseName, props.toMap)(node.source) + AlterDatabaseProperties(databaseName, props.toMap) // DESCRIBE DATABASE [EXTENDED] db_name case Token("TOK_DESCDATABASE", Token(dbName, Nil) :: describeArgs) => - val databaseName = unquoteString(dbName) + val databaseName = cleanIdentifier(dbName) val extended = getClauseOption("EXTENDED", describeArgs) - DescribeDatabase(databaseName, extended.isDefined)(node.source) + DescribeDatabase(databaseName, extended.isDefined) // CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name // [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri'] ]; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index a8313deeef..8333074eca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -232,8 +232,7 @@ class SparkSqlAstBuilder extends AstBuilder { ctx.EXISTS != null, Option(ctx.locationSpec).map(visitLocationSpec), Option(ctx.comment).map(string), - Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty))( - command(ctx)) + Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty)) } /** @@ -248,8 +247,7 @@ class SparkSqlAstBuilder extends AstBuilder { ctx: SetDatabasePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterDatabaseProperties( ctx.identifier.getText, - visitTablePropertyList(ctx.tablePropertyList))( - command(ctx)) + visitTablePropertyList(ctx.tablePropertyList)) } /** @@ -261,7 +259,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitDropDatabase(ctx: DropDatabaseContext): LogicalPlan = withOrigin(ctx) { - DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE == null)(command(ctx)) + DropDatabase(ctx.identifier.getText, ctx.EXISTS != null, ctx.CASCADE != null) } /** @@ -273,7 +271,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitDescribeDatabase(ctx: DescribeDatabaseContext): LogicalPlan = withOrigin(ctx) { - DescribeDatabase(ctx.identifier.getText, 
ctx.EXTENDED != null)(command(ctx)) + DescribeDatabase(ctx.identifier.getText, ctx.EXTENDED != null) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 0e51abb44b..6c2a67f81c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogDatabase import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.execution.datasources.BucketSpec @@ -45,46 +45,135 @@ abstract class NativeDDLCommand(val sql: String) extends RunnableCommand { } +/** + * A command for users to create a new database. + * + * It will issue an error message when the database with the same name already exists, + * unless 'ifNotExists' is true. + * The syntax of using this command in SQL is: + * {{{ + * CREATE DATABASE|SCHEMA [IF NOT EXISTS] database_name + * }}} + */ case class CreateDatabase( databaseName: String, ifNotExists: Boolean, path: Option[String], comment: Option[String], - props: Map[String, String])(sql: String) - extends NativeDDLCommand(sql) with Logging + props: Map[String, String]) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + catalog.createDatabase( + CatalogDatabase( + databaseName, + comment.getOrElse(""), + path.getOrElse(catalog.getDefaultDBPath(databaseName)), + props), + ifNotExists) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} + /** - * Drop Database: Removes a database from the system. + * A command for users to remove a database from the system. * * 'ifExists': * - true, if database_name does't exist, no action * - false (default), if database_name does't exist, a warning message will be issued - * 'restric': - * - true (default), the database cannot be dropped if it is not empty. The inclusive - * tables must be dropped at first. - * - false, it is in the Cascade mode. The dependent objects are automatically dropped - * before dropping database. + * 'cascade': + * - true, the dependent objects are automatically dropped before dropping database. + * - false (default), it is in the Restrict mode. The database cannot be dropped if + * it is not empty. The inclusive tables must be dropped at first. 
+ * + * The syntax of using this command in SQL is: + * {{{ + * DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]; + * }}} */ case class DropDatabase( databaseName: String, ifExists: Boolean, - restrict: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + cascade: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.dropDatabase(databaseName, ifExists, cascade) + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} -/** ALTER DATABASE: add new (key, value) pairs into DBPROPERTIES */ +/** + * A command for users to add new (key, value) pairs into DBPROPERTIES + * If the database does not exist, an error message will be issued to indicate the database + * does not exist. + * The syntax of using this command in SQL is: + * {{{ + * ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) + * }}} + */ case class AlterDatabaseProperties( databaseName: String, - props: Map[String, String])(sql: String) - extends NativeDDLCommand(sql) with Logging + props: Map[String, String]) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val db: CatalogDatabase = catalog.getDatabase(databaseName) + catalog.alterDatabase(db.copy(properties = db.properties ++ props)) + + Seq.empty[Row] + } + + override val output: Seq[Attribute] = Seq.empty +} /** - * DESCRIBE DATABASE: shows the name of the database, its comment (if one has been set), and its + * A command for users to show the name of the database, its comment (if one has been set), and its * root location on the filesystem. When extended is true, it also shows the database's properties + * If the database does not exist, an error message will be issued to indicate the database + * does not exist. 
+ * The syntax of using this command in SQL is + * {{{ + * DESCRIBE DATABASE [EXTENDED] db_name + * }}} */ case class DescribeDatabase( databaseName: String, - extended: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + extended: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val dbMetadata: CatalogDatabase = sqlContext.sessionState.catalog.getDatabase(databaseName) + val result = + Row("Database Name", dbMetadata.name) :: + Row("Description", dbMetadata.description) :: + Row("Location", dbMetadata.locationUri) :: Nil + + if (extended) { + val properties = + if (dbMetadata.properties.isEmpty) { + "" + } else { + dbMetadata.properties.toSeq.mkString("(", ", ", ")") + } + result :+ Row("Properties", properties) + } else { + result + } + } + + override val output: Seq[Attribute] = { + AttributeReference("database_description_item", StringType, nullable = false)() :: + AttributeReference("database_description_value", StringType, nullable = false)() :: Nil + } +} case class CreateFunction( databaseName: Option[String], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 03079c6890..ccbfd41cca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -39,7 +39,7 @@ class DDLCommandSuite extends PlanTest { ifNotExists = true, Some("/home/user/db"), Some("database_comment"), - Map("a" -> "a", "b" -> "b", "c" -> "c"))(sql) + Map("a" -> "a", "b" -> "b", "c" -> "c")) comparePlans(parsed, expected) } @@ -65,39 +65,27 @@ class DDLCommandSuite extends PlanTest { val expected1 = DropDatabase( "database_name", ifExists = true, - restrict = true)(sql1) + cascade = false) val expected2 = DropDatabase( "database_name", ifExists = true, - restrict = false)(sql2) + cascade = true) val expected3 = DropDatabase( - "database_name", - ifExists = true, - restrict = true)(sql3) - val expected4 = DropDatabase( - "database_name", - ifExists = true, - restrict = false)(sql4) - val expected5 = DropDatabase( - "database_name", - ifExists = true, - restrict = true)(sql5) - val expected6 = DropDatabase( "database_name", ifExists = false, - restrict = true)(sql6) - val expected7 = DropDatabase( + cascade = false) + val expected4 = DropDatabase( "database_name", ifExists = false, - restrict = false)(sql7) + cascade = true) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - comparePlans(parsed4, expected4) - comparePlans(parsed5, expected5) - comparePlans(parsed6, expected6) - comparePlans(parsed7, expected7) + comparePlans(parsed3, expected1) + comparePlans(parsed4, expected2) + comparePlans(parsed5, expected1) + comparePlans(parsed6, expected3) + comparePlans(parsed7, expected4) } test("alter database set dbproperties") { @@ -110,10 +98,10 @@ class DDLCommandSuite extends PlanTest { val expected1 = AlterDatabaseProperties( "database_name", - Map("a" -> "a", "b" -> "b", "c" -> "c"))(sql1) + Map("a" -> "a", "b" -> "b", "c" -> "c")) val expected2 = AlterDatabaseProperties( "database_name", - Map("a" -> "a"))(sql2) + Map("a" -> "a")) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) @@ -129,10 +117,10 @@ class DDLCommandSuite extends PlanTest { val expected1 = DescribeDatabase( "db_name", - extended = true)(sql1) + extended 
= true) val expected2 = DescribeDatabase( "db_name", - extended = false)(sql2) + extended = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala new file mode 100644 index 0000000000..47c9a22acd --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -0,0 +1,151 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import java.io.File + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.catalog.CatalogDatabase +import org.apache.spark.sql.catalyst.parser.ParserUtils._ +import org.apache.spark.sql.test.SharedSQLContext + +class DDLSuite extends QueryTest with SharedSQLContext { + + /** + * Drops database `databaseName` after calling `f`. + */ + private def withDatabase(dbNames: String*)(f: => Unit): Unit = { + try f finally { + dbNames.foreach { name => + sqlContext.sql(s"DROP DATABASE IF EXISTS $name CASCADE") + } + } + } + + test("Create/Drop Database") { + val catalog = sqlContext.sessionState.catalog + + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + withDatabase(dbName) { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabase(dbNameWithoutBackTicks) + assert(db1 == CatalogDatabase( + dbNameWithoutBackTicks, + "", + System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db", + Map.empty)) + sql(s"DROP DATABASE $dbName CASCADE") + assert(!catalog.databaseExists(dbNameWithoutBackTicks)) + } + } + } + + test("Create Database - database already exists") { + val catalog = sqlContext.sessionState.catalog + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + withDatabase(dbName) { + sql(s"CREATE DATABASE $dbName") + val db1 = catalog.getDatabase(dbNameWithoutBackTicks) + assert(db1 == CatalogDatabase( + dbNameWithoutBackTicks, + "", + System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db", + Map.empty)) + + val message = intercept[AnalysisException] { + sql(s"CREATE DATABASE $dbName") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' already exists.")) + } + } + } + + test("Alter/Describe Database") { + val catalog = sqlContext.sessionState.catalog + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + withDatabase(dbName) { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + val location = + 
System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db" + sql(s"CREATE DATABASE $dbName") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "") :: Nil) + + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('a'='a', 'b'='b', 'c'='c')") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "((a,a), (b,b), (c,c))") :: Nil) + + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") + + checkAnswer( + sql(s"DESCRIBE DATABASE EXTENDED $dbName"), + Row("Database Name", dbNameWithoutBackTicks) :: + Row("Description", "") :: + Row("Location", location) :: + Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) + } + } + } + + test("Drop/Alter/Describe Database - database does not exists") { + val databaseNames = Seq("db1", "`database`") + + databaseNames.foreach { dbName => + val dbNameWithoutBackTicks = cleanIdentifier(dbName) + assert(!sqlContext.sessionState.catalog.databaseExists(dbNameWithoutBackTicks)) + + var message = intercept[AnalysisException] { + sql(s"DROP DATABASE $dbName") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + + message = intercept[AnalysisException] { + sql(s"ALTER DATABASE $dbName SET DBPROPERTIES ('d'='d')") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + + message = intercept[AnalysisException] { + sql(s"DESCRIBE DATABASE EXTENDED $dbName") + }.getMessage + assert(message.contains(s"Database '$dbNameWithoutBackTicks' does not exist")) + + sql(s"DROP DATABASE IF EXISTS $dbName") + } + } + + // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext +} diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 8e1ebe2937..7ad7f92bd2 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -183,7 +183,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { test("Single command with --database") { runCliWithin(2.minute)( "CREATE DATABASE hive_test_db;" - -> "OK", + -> "", "USE hive_test_db;" -> "", "CREATE TABLE hive_test(key INT, val STRING);" diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index ff12245e8d..1cd783e63a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.hive +import org.apache.hadoop.fs.Path +import org.apache.hadoop.hive.conf.HiveConf + import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry import org.apache.spark.sql.catalyst.catalog.SessionCatalog @@ -59,6 +62,11 @@ class HiveSessionCatalog( // | Methods and fields for interacting with HiveMetastoreCatalog | // ---------------------------------------------------------------- + override def getDefaultDBPath(db: String): 
String = { + val defaultPath = context.hiveconf.getVar(HiveConf.ConfVars.METASTOREWAREHOUSE) + new Path(new Path(defaultPath), db + ".db").toString + } + // Catalog for handling data source tables. TODO: This really doesn't belong here since it is // essentially a cache for metastore tables. However, it relies on a lot of session-specific // things so it would be a lot of work to split its functionality between HiveSessionCatalog -- cgit v1.2.3 From d46c71b39da92f5cabf6d9057c953c52f7f3f965 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 30 Mar 2016 11:03:15 -0700 Subject: [SPARK-14268][SQL] rename toRowExpressions and fromRowExpression to serializer and deserializer in ExpressionEncoder ## What changes were proposed in this pull request? In `ExpressionEncoder`, we use `constructorFor` to build `fromRowExpression` as the `deserializer` in `ObjectOperator`. It's kind of confusing, we should make the name consistent. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #12058 from cloud-fan/rename. --- .../main/scala/org/apache/spark/sql/Encoder.scala | 4 +- .../spark/sql/catalyst/JavaTypeInference.scala | 32 ++++---- .../spark/sql/catalyst/ScalaReflection.scala | 40 +++++----- .../sql/catalyst/encoders/ExpressionEncoder.scala | 87 +++++++++++----------- .../spark/sql/catalyst/encoders/RowEncoder.scala | 34 ++++----- .../spark/sql/catalyst/plans/logical/object.scala | 14 ++-- .../spark/sql/catalyst/ScalaReflectionSuite.scala | 4 +- .../catalyst/encoders/ExpressionEncoderSuite.scala | 4 +- .../scala/org/apache/spark/sql/QueryTest.scala | 4 +- 9 files changed, 110 insertions(+), 113 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index b19538a23f..1f20e26354 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -245,10 +245,10 @@ object Encoders { ExpressionEncoder[T]( schema = new StructType().add("value", BinaryType), flat = true, - toRowExpressions = Seq( + serializer = Seq( EncodeUsingSerializer( BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - fromRowExpression = + deserializer = DecodeUsingSerializer[T]( BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), clsTag = classTag[T] diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index 59ee41d02f..6f9fbbbead 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -155,16 +155,16 @@ object JavaTypeInference { } /** - * Returns an expression that can be used to construct an object of java bean `T` given an input - * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * Returns an expression that can be used to deserialize an internal row to an object of java bean + * `T` with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. 
*/ - def constructorFor(beanClass: Class[_]): Expression = { - constructorFor(TypeToken.of(beanClass), None) + def deserializerFor(beanClass: Class[_]): Expression = { + deserializerFor(TypeToken.of(beanClass), None) } - private def constructorFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { + private def deserializerFor(typeToken: TypeToken[_], path: Option[Expression]): Expression = { /** Returns the current path with a sub-field extracted. */ def addToPath(part: String): Expression = path .map(p => UnresolvedExtractValue(p, expressions.Literal(part))) @@ -231,7 +231,7 @@ object JavaTypeInference { }.getOrElse { Invoke( MapObjects( - p => constructorFor(typeToken.getComponentType, Some(p)), + p => deserializerFor(typeToken.getComponentType, Some(p)), getPath, inferDataType(elementType)._1), "array", @@ -243,7 +243,7 @@ object JavaTypeInference { val array = Invoke( MapObjects( - p => constructorFor(et, Some(p)), + p => deserializerFor(et, Some(p)), getPath, inferDataType(et)._1), "array", @@ -259,7 +259,7 @@ object JavaTypeInference { val keyData = Invoke( MapObjects( - p => constructorFor(keyType, Some(p)), + p => deserializerFor(keyType, Some(p)), Invoke(getPath, "keyArray", ArrayType(keyDataType)), keyDataType), "array", @@ -268,7 +268,7 @@ object JavaTypeInference { val valueData = Invoke( MapObjects( - p => constructorFor(valueType, Some(p)), + p => deserializerFor(valueType, Some(p)), Invoke(getPath, "valueArray", ArrayType(valueDataType)), valueDataType), "array", @@ -288,7 +288,7 @@ object JavaTypeInference { val fieldName = p.getName val fieldType = typeToken.method(p.getReadMethod).getReturnType val (_, nullable) = inferDataType(fieldType) - val constructor = constructorFor(fieldType, Some(addToPath(fieldName))) + val constructor = deserializerFor(fieldType, Some(addToPath(fieldName))) val setter = if (nullable) { constructor } else { @@ -313,14 +313,14 @@ object JavaTypeInference { } /** - * Returns expressions for extracting all the fields from the given type. + * Returns an expression for serializing an object of the given type to an internal row. 
*/ - def extractorsFor(beanClass: Class[_]): CreateNamedStruct = { + def serializerFor(beanClass: Class[_]): CreateNamedStruct = { val inputObject = BoundReference(0, ObjectType(beanClass), nullable = true) - extractorFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] + serializerFor(inputObject, TypeToken.of(beanClass)).asInstanceOf[CreateNamedStruct] } - private def extractorFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { + private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = { def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = { val (dataType, nullable) = inferDataType(elementType) @@ -330,7 +330,7 @@ object JavaTypeInference { input :: Nil, dataType = ArrayType(dataType, nullable)) } else { - MapObjects(extractorFor(_, elementType), input, ObjectType(elementType.getRawType)) + MapObjects(serializerFor(_, elementType), input, ObjectType(elementType.getRawType)) } } @@ -403,7 +403,7 @@ object JavaTypeInference { inputObject, p.getReadMethod.getName, inferExternalType(fieldType.getRawType)) - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType) :: Nil + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil }) } else { throw new UnsupportedOperationException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index f208401160..d241b8a79b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -110,8 +110,8 @@ object ScalaReflection extends ScalaReflection { } /** - * Returns an expression that can be used to construct an object of type `T` given an input - * row with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes + * Returns an expression that can be used to deserialize an input row to an object of type `T` + * with a compatible schema. Fields of the row will be extracted using UnresolvedAttributes * of the same name as the constructor arguments. Nested classes will have their fields accessed * using UnresolvedExtractValue. * @@ -119,14 +119,14 @@ object ScalaReflection extends ScalaReflection { * from ordinal 0 (since there are no names to map to). The actual location can be moved by * calling resolve/bind with a new schema. */ - def constructorFor[T : TypeTag]: Expression = { + def deserializerFor[T : TypeTag]: Expression = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil - constructorFor(tpe, None, walkedTypePath) + deserializerFor(tpe, None, walkedTypePath) } - private def constructorFor( + private def deserializerFor( tpe: `Type`, path: Option[Expression], walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { @@ -161,7 +161,7 @@ object ScalaReflection extends ScalaReflection { } /** - * When we build the `fromRowExpression` for an encoder, we set up a lot of "unresolved" stuff + * When we build the `deserializer` for an encoder, we set up a lot of "unresolved" stuff * and lost the required data type, which may lead to runtime error if the real type doesn't * match the encoder's schema. 
* For example, we build an encoder for `case class Data(a: Int, b: String)` and the real type @@ -188,7 +188,7 @@ object ScalaReflection extends ScalaReflection { val TypeRef(_, _, Seq(optType)) = t val className = getClassNameFromType(optType) val newTypePath = s"""- option value class: "$className"""" +: walkedTypePath - WrapOption(constructorFor(optType, path, newTypePath), dataTypeFor(optType)) + WrapOption(deserializerFor(optType, path, newTypePath), dataTypeFor(optType)) case t if t <:< localTypeOf[java.lang.Integer] => val boxedType = classOf[java.lang.Integer] @@ -272,7 +272,7 @@ object ScalaReflection extends ScalaReflection { val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath Invoke( MapObjects( - p => constructorFor(elementType, Some(p), newTypePath), + p => deserializerFor(elementType, Some(p), newTypePath), getPath, schemaFor(elementType).dataType), "array", @@ -286,7 +286,7 @@ object ScalaReflection extends ScalaReflection { val newTypePath = s"""- array element class: "$className"""" +: walkedTypePath val mapFunction: Expression => Expression = p => { - val converter = constructorFor(elementType, Some(p), newTypePath) + val converter = deserializerFor(elementType, Some(p), newTypePath) if (nullable) { converter } else { @@ -312,7 +312,7 @@ object ScalaReflection extends ScalaReflection { val keyData = Invoke( MapObjects( - p => constructorFor(keyType, Some(p), walkedTypePath), + p => deserializerFor(keyType, Some(p), walkedTypePath), Invoke(getPath, "keyArray", ArrayType(schemaFor(keyType).dataType)), schemaFor(keyType).dataType), "array", @@ -321,7 +321,7 @@ object ScalaReflection extends ScalaReflection { val valueData = Invoke( MapObjects( - p => constructorFor(valueType, Some(p), walkedTypePath), + p => deserializerFor(valueType, Some(p), walkedTypePath), Invoke(getPath, "valueArray", ArrayType(schemaFor(valueType).dataType)), schemaFor(valueType).dataType), "array", @@ -344,12 +344,12 @@ object ScalaReflection extends ScalaReflection { val newTypePath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath // For tuples, we based grab the inner fields by ordinal instead of name. if (cls.getName startsWith "scala.Tuple") { - constructorFor( + deserializerFor( fieldType, Some(addToPathOrdinal(i, dataType, newTypePath)), newTypePath) } else { - val constructor = constructorFor( + val constructor = deserializerFor( fieldType, Some(addToPath(fieldName, dataType, newTypePath)), newTypePath) @@ -387,7 +387,7 @@ object ScalaReflection extends ScalaReflection { } /** - * Returns expressions for extracting all the fields from the given type. + * Returns an expression for serializing an object of type T to an internal row. * * If the given type is not supported, i.e. 
there is no encoder can be built for this type, * an [[UnsupportedOperationException]] will be thrown with detailed error message to explain @@ -398,18 +398,18 @@ object ScalaReflection extends ScalaReflection { * * the element type of [[Array]] or [[Seq]]: `array element class: "abc.xyz.MyClass"` * * the field of [[Product]]: `field (class: "abc.xyz.MyClass", name: "myField")` */ - def extractorsFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { + def serializerFor[T : TypeTag](inputObject: Expression): CreateNamedStruct = { val tpe = localTypeOf[T] val clsName = getClassNameFromType(tpe) val walkedTypePath = s"""- root class: "${clsName}"""" :: Nil - extractorFor(inputObject, tpe, walkedTypePath) match { + serializerFor(inputObject, tpe, walkedTypePath) match { case expressions.If(_, _, s: CreateNamedStruct) if tpe <:< localTypeOf[Product] => s case other => CreateNamedStruct(expressions.Literal("value") :: other :: Nil) } } /** Helper for extracting internal fields from a case class. */ - private def extractorFor( + private def serializerFor( inputObject: Expression, tpe: `Type`, walkedTypePath: Seq[String]): Expression = ScalaReflectionLock.synchronized { @@ -425,7 +425,7 @@ object ScalaReflection extends ScalaReflection { } else { val clsName = getClassNameFromType(elementType) val newPath = s"""- array element class: "$clsName"""" +: walkedTypePath - MapObjects(extractorFor(_, elementType, newPath), input, externalDataType) + MapObjects(serializerFor(_, elementType, newPath), input, externalDataType) } } @@ -491,7 +491,7 @@ object ScalaReflection extends ScalaReflection { expressions.If( IsNull(unwrapped), expressions.Literal.create(null, silentSchemaFor(optType).dataType), - extractorFor(unwrapped, optType, newPath)) + serializerFor(unwrapped, optType, newPath)) } case t if t <:< localTypeOf[Product] => @@ -500,7 +500,7 @@ object ScalaReflection extends ScalaReflection { val fieldValue = Invoke(inputObject, fieldName, dataTypeFor(fieldType)) val clsName = getClassNameFromType(fieldType) val newPath = s"""- field (class: "$clsName", name: "$fieldName")""" +: walkedTypePath - expressions.Literal(fieldName) :: extractorFor(fieldValue, fieldType, newPath) :: Nil + expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType, newPath) :: Nil }) val nullOutput = expressions.Literal.create(null, nonNullOutput.dataType) expressions.If(IsNull(inputObject), nullOutput, nonNullOutput) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 918233ddcd..1c712fde26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -51,8 +51,8 @@ object ExpressionEncoder { val flat = !classOf[Product].isAssignableFrom(cls) val inputObject = BoundReference(0, ScalaReflection.dataTypeFor[T], nullable = false) - val toRowExpression = ScalaReflection.extractorsFor[T](inputObject) - val fromRowExpression = ScalaReflection.constructorFor[T] + val serializer = ScalaReflection.serializerFor[T](inputObject) + val deserializer = ScalaReflection.deserializerFor[T] val schema = ScalaReflection.schemaFor[T] match { case ScalaReflection.Schema(s: StructType, _) => s @@ -62,8 +62,8 @@ object ExpressionEncoder { new ExpressionEncoder[T]( schema, flat, - toRowExpression.flatten, - fromRowExpression, + 
serializer.flatten, + deserializer, ClassTag[T](cls)) } @@ -72,14 +72,14 @@ object ExpressionEncoder { val schema = JavaTypeInference.inferDataType(beanClass)._1 assert(schema.isInstanceOf[StructType]) - val toRowExpression = JavaTypeInference.extractorsFor(beanClass) - val fromRowExpression = JavaTypeInference.constructorFor(beanClass) + val serializer = JavaTypeInference.serializerFor(beanClass) + val deserializer = JavaTypeInference.deserializerFor(beanClass) new ExpressionEncoder[T]( schema.asInstanceOf[StructType], flat = false, - toRowExpression.flatten, - fromRowExpression, + serializer.flatten, + deserializer, ClassTag[T](beanClass)) } @@ -103,9 +103,9 @@ object ExpressionEncoder { val cls = Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}") - val toRowExpressions = encoders.map { - case e if e.flat => e.toRowExpressions.head - case other => CreateStruct(other.toRowExpressions) + val serializer = encoders.map { + case e if e.flat => e.serializer.head + case other => CreateStruct(other.serializer) }.zipWithIndex.map { case (expr, index) => expr.transformUp { case BoundReference(0, t, _) => @@ -116,14 +116,14 @@ object ExpressionEncoder { } } - val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) => + val childrenDeserializers = encoders.zipWithIndex.map { case (enc, index) => if (enc.flat) { - enc.fromRowExpression.transform { + enc.deserializer.transform { case b: BoundReference => b.copy(ordinal = index) } } else { val input = BoundReference(index, enc.schema, nullable = true) - enc.fromRowExpression.transformUp { + enc.deserializer.transformUp { case UnresolvedAttribute(nameParts) => assert(nameParts.length == 1) UnresolvedExtractValue(input, Literal(nameParts.head)) @@ -132,14 +132,14 @@ object ExpressionEncoder { } } - val fromRowExpression = - NewInstance(cls, fromRowExpressions, ObjectType(cls), propagateNull = false) + val deserializer = + NewInstance(cls, childrenDeserializers, ObjectType(cls), propagateNull = false) new ExpressionEncoder[Any]( schema, flat = false, - toRowExpressions, - fromRowExpression, + serializer, + deserializer, ClassTag(cls)) } @@ -174,29 +174,29 @@ object ExpressionEncoder { * A generic encoder for JVM objects. * * @param schema The schema after converting `T` to a Spark SQL row. - * @param toRowExpressions A set of expressions, one for each top-level field that can be used to - * extract the values from a raw object into an [[InternalRow]]. - * @param fromRowExpression An expression that will construct an object given an [[InternalRow]]. + * @param serializer A set of expressions, one for each top-level field that can be used to + * extract the values from a raw object into an [[InternalRow]]. + * @param deserializer An expression that will construct an object given an [[InternalRow]]. * @param clsTag A classtag for `T`. 
*/ case class ExpressionEncoder[T]( schema: StructType, flat: Boolean, - toRowExpressions: Seq[Expression], - fromRowExpression: Expression, + serializer: Seq[Expression], + deserializer: Expression, clsTag: ClassTag[T]) extends Encoder[T] { - if (flat) require(toRowExpressions.size == 1) + if (flat) require(serializer.size == 1) @transient - private lazy val extractProjection = GenerateUnsafeProjection.generate(toRowExpressions) + private lazy val extractProjection = GenerateUnsafeProjection.generate(serializer) @transient private lazy val inputRow = new GenericMutableRow(1) @transient - private lazy val constructProjection = GenerateSafeProjection.generate(fromRowExpression :: Nil) + private lazy val constructProjection = GenerateSafeProjection.generate(deserializer :: Nil) /** * Returns this encoder where it has been bound to its own output (i.e. no remaping of columns @@ -212,7 +212,7 @@ case class ExpressionEncoder[T]( * Returns a new set (with unique ids) of [[NamedExpression]] that represent the serialized form * of this object. */ - def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(toRowExpressions).map { + def namedExpressions: Seq[NamedExpression] = schema.map(_.name).zip(serializer).map { case (_, ne: NamedExpression) => ne.newInstance() case (name, e) => Alias(e, name)() } @@ -228,7 +228,7 @@ case class ExpressionEncoder[T]( } catch { case e: Exception => throw new RuntimeException( - s"Error while encoding: $e\n${toRowExpressions.map(_.treeString).mkString("\n")}", e) + s"Error while encoding: $e\n${serializer.map(_.treeString).mkString("\n")}", e) } /** @@ -240,7 +240,7 @@ case class ExpressionEncoder[T]( constructProjection(row).get(0, ObjectType(clsTag.runtimeClass)).asInstanceOf[T] } catch { case e: Exception => - throw new RuntimeException(s"Error while decoding: $e\n${fromRowExpression.treeString}", e) + throw new RuntimeException(s"Error while decoding: $e\n${deserializer.treeString}", e) } /** @@ -249,7 +249,7 @@ case class ExpressionEncoder[T]( * has not been done already in places where we plan to do later composition of encoders. */ def assertUnresolved(): Unit = { - (fromRowExpression +: toRowExpressions).foreach(_.foreach { + (deserializer +: serializer).foreach(_.foreach { case a: AttributeReference if a.name != "loopVar" => sys.error(s"Unresolved encoder expected, but $a was found.") case _ => @@ -257,7 +257,7 @@ case class ExpressionEncoder[T]( } /** - * Validates `fromRowExpression` to make sure it can be resolved by given schema, and produce + * Validates `deserializer` to make sure it can be resolved by given schema, and produce * friendly error messages to explain why it fails to resolve if there is something wrong. */ def validate(schema: Seq[Attribute]): Unit = { @@ -271,7 +271,7 @@ case class ExpressionEncoder[T]( // If this is a tuple encoder or tupled encoder, which means its leaf nodes are all // `BoundReference`, make sure their ordinals are all valid. var maxOrdinal = -1 - fromRowExpression.foreach { + deserializer.foreach { case b: BoundReference => if (b.ordinal > maxOrdinal) maxOrdinal = b.ordinal case _ => } @@ -285,7 +285,7 @@ case class ExpressionEncoder[T]( // we unbound it by the given `schema` and propagate the actual type to `GetStructField`, after // we resolve the `fromRowExpression`. 
val resolved = SimpleAnalyzer.resolveExpression( - fromRowExpression, + deserializer, LocalRelation(schema), throws = true) @@ -312,42 +312,39 @@ case class ExpressionEncoder[T]( } /** - * Returns a new copy of this encoder, where the expressions used by `fromRow` are resolved to the - * given schema. + * Returns a new copy of this encoder, where the `deserializer` is resolved to the given schema. */ def resolve( schema: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val deserializer = SimpleAnalyzer.ResolveReferences.resolveDeserializer( - fromRowExpression, schema) + val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema) // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check // analysis, go through optimizer, etc. - val plan = Project(Alias(deserializer, "")() :: Nil, LocalRelation(schema)) + val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) SimpleAnalyzer.checkAnalysis(analyzedPlan) - copy(fromRowExpression = SimplifyCasts(analyzedPlan).expressions.head.children.head) + copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head) } /** - * Returns a copy of this encoder where the expressions used to construct an object from an input - * row have been bound to the ordinals of the given schema. Note that you need to first call - * resolve before bind. + * Returns a copy of this encoder where the `deserializer` has been bound to the + * ordinals of the given schema. Note that you need to first call resolve before bind. */ def bind(schema: Seq[Attribute]): ExpressionEncoder[T] = { - copy(fromRowExpression = BindReferences.bindReference(fromRowExpression, schema)) + copy(deserializer = BindReferences.bindReference(deserializer, schema)) } /** * Returns a new encoder with input columns shifted by `delta` ordinals */ def shift(delta: Int): ExpressionEncoder[T] = { - copy(fromRowExpression = fromRowExpression transform { + copy(deserializer = deserializer transform { case r: BoundReference => r.copy(ordinal = r.ordinal + delta) }) } - protected val attrs = toRowExpressions.flatMap(_.collect { + protected val attrs = serializer.flatMap(_.collect { case _: UnresolvedAttribute => "" case a: Attribute => s"#${a.exprId}" case b: BoundReference => s"[${b.ordinal}]" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index 30f56d8c2f..a8397aa5e5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -36,23 +36,23 @@ object RowEncoder { val cls = classOf[Row] val inputObject = BoundReference(0, ObjectType(cls), nullable = true) // We use an If expression to wrap extractorsFor result of StructType - val extractExpressions = extractorsFor(inputObject, schema).asInstanceOf[If].falseValue - val constructExpression = constructorFor(schema) + val serializer = serializerFor(inputObject, schema).asInstanceOf[If].falseValue + val deserializer = deserializerFor(schema) new ExpressionEncoder[Row]( schema, flat = false, - extractExpressions.asInstanceOf[CreateStruct].children, - constructExpression, + serializer.asInstanceOf[CreateStruct].children, + deserializer, ClassTag(cls)) } - private def extractorsFor( + private def serializerFor( inputObject: 
Expression, inputType: DataType): Expression = inputType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType | CalendarIntervalType => inputObject - case p: PythonUserDefinedType => extractorsFor(inputObject, p.sqlType) + case p: PythonUserDefinedType => serializerFor(inputObject, p.sqlType) case udt: UserDefinedType[_] => val obj = NewInstance( @@ -95,7 +95,7 @@ object RowEncoder { classOf[GenericArrayData], inputObject :: Nil, dataType = t) - case _ => MapObjects(extractorsFor(_, et), inputObject, externalDataTypeForInput(et)) + case _ => MapObjects(serializerFor(_, et), inputObject, externalDataTypeForInput(et)) } case t @ MapType(kt, vt, valueNullable) => @@ -104,14 +104,14 @@ object RowEncoder { Invoke(inputObject, "keysIterator", ObjectType(classOf[scala.collection.Iterator[_]])), "toSeq", ObjectType(classOf[scala.collection.Seq[_]])) - val convertedKeys = extractorsFor(keys, ArrayType(kt, false)) + val convertedKeys = serializerFor(keys, ArrayType(kt, false)) val values = Invoke( Invoke(inputObject, "valuesIterator", ObjectType(classOf[scala.collection.Iterator[_]])), "toSeq", ObjectType(classOf[scala.collection.Seq[_]])) - val convertedValues = extractorsFor(values, ArrayType(vt, valueNullable)) + val convertedValues = serializerFor(values, ArrayType(vt, valueNullable)) NewInstance( classOf[ArrayBasedMapData], @@ -128,7 +128,7 @@ object RowEncoder { If( Invoke(inputObject, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, f.dataType), - extractorsFor( + serializerFor( Invoke(inputObject, method, externalDataTypeForInput(f.dataType), Literal(i) :: Nil), f.dataType)) } @@ -166,7 +166,7 @@ object RowEncoder { case _: NullType => ObjectType(classOf[java.lang.Object]) } - private def constructorFor(schema: StructType): Expression = { + private def deserializerFor(schema: StructType): Expression = { val fields = schema.zipWithIndex.map { case (f, i) => val dt = f.dataType match { case p: PythonUserDefinedType => p.sqlType @@ -176,13 +176,13 @@ object RowEncoder { If( IsNull(field), Literal.create(null, externalDataTypeFor(dt)), - constructorFor(field) + deserializerFor(field) ) } CreateExternalRow(fields, schema) } - private def constructorFor(input: Expression): Expression = input.dataType match { + private def deserializerFor(input: Expression): Expression = input.dataType match { case NullType | BooleanType | ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType | BinaryType | CalendarIntervalType => input @@ -216,7 +216,7 @@ object RowEncoder { case ArrayType(et, nullable) => val arrayData = Invoke( - MapObjects(constructorFor(_), input, et), + MapObjects(deserializerFor(_), input, et), "array", ObjectType(classOf[Array[_]])) StaticInvoke( @@ -227,10 +227,10 @@ object RowEncoder { case MapType(kt, vt, valueNullable) => val keyArrayType = ArrayType(kt, false) - val keyData = constructorFor(Invoke(input, "keyArray", keyArrayType)) + val keyData = deserializerFor(Invoke(input, "keyArray", keyArrayType)) val valueArrayType = ArrayType(vt, valueNullable) - val valueData = constructorFor(Invoke(input, "valueArray", valueArrayType)) + val valueData = deserializerFor(Invoke(input, "valueArray", valueArrayType)) StaticInvoke( ArrayBasedMapData.getClass, @@ -243,7 +243,7 @@ object RowEncoder { If( Invoke(input, "isNullAt", BooleanType, Literal(i) :: Nil), Literal.create(null, externalDataTypeFor(f.dataType)), - constructorFor(GetStructField(input, i))) + deserializerFor(GetStructField(input, 
i))) } If(IsNull(input), Literal.create(null, externalDataTypeFor(input.dataType)), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index da7f81c785..058fb6bff1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -71,7 +71,7 @@ object MapPartitions { child: LogicalPlan): MapPartitions = { MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - encoderFor[T].fromRowExpression, + encoderFor[T].deserializer, encoderFor[U].namedExpressions, child) } @@ -98,7 +98,7 @@ object AppendColumns { child: LogicalPlan): AppendColumns = { new AppendColumns( func.asInstanceOf[Any => Any], - encoderFor[T].fromRowExpression, + encoderFor[T].deserializer, encoderFor[U].namedExpressions, child) } @@ -133,8 +133,8 @@ object MapGroups { child: LogicalPlan): MapGroups = { new MapGroups( func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], - encoderFor[K].fromRowExpression, - encoderFor[T].fromRowExpression, + encoderFor[K].deserializer, + encoderFor[T].deserializer, encoderFor[U].namedExpressions, groupingAttributes, dataAttributes, @@ -178,9 +178,9 @@ object CoGroup { CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], - encoderFor[Key].fromRowExpression, - encoderFor[Left].fromRowExpression, - encoderFor[Right].fromRowExpression, + encoderFor[Key].deserializer, + encoderFor[Left].deserializer, + encoderFor[Right].deserializer, encoderFor[Result].namedExpressions, leftGroup, rightGroup, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala index dd31050bb5..5ca5a72512 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala @@ -248,10 +248,10 @@ class ScalaReflectionSuite extends SparkFunSuite { Seq( ("mirror", () => mirror), ("dataTypeFor", () => dataTypeFor[ComplexData]), - ("constructorFor", () => constructorFor[ComplexData]), + ("constructorFor", () => deserializerFor[ComplexData]), ("extractorsFor", { val inputObject = BoundReference(0, dataTypeForComplexData, nullable = false) - () => extractorsFor[ComplexData](inputObject) + () => serializerFor[ComplexData](inputObject) }), ("getConstructorParameters(cls)", () => getConstructorParameters(classOf[ComplexData])), ("getConstructorParameterNames", () => getConstructorParameterNames(classOf[ComplexData])), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala index f6583bfe42..18752014ea 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala @@ -315,7 +315,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { val attr = AttributeReference("obj", ObjectType(encoder.clsTag.runtimeClass))() val inputPlan = LocalRelation(attr) val plan = - Project(Alias(encoder.fromRowExpression, "obj")() :: Nil, + Project(Alias(encoder.deserializer, "obj")() :: Nil, 
Project(encoder.namedExpressions, inputPlan)) assertAnalysisSuccess(plan) @@ -360,7 +360,7 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest { |${encoder.schema.treeString} | |fromRow Expressions: - |${boundEncoder.fromRowExpression.treeString} + |${boundEncoder.deserializer.treeString} """.stripMargin) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 7ff4ffcaec..854a662cc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -90,7 +90,7 @@ abstract class QueryTest extends PlanTest { s""" |Exception collecting dataset as objects |${ds.resolvedTEncoder} - |${ds.resolvedTEncoder.fromRowExpression.treeString} + |${ds.resolvedTEncoder.deserializer.treeString} |${ds.queryExecution} """.stripMargin, e) } @@ -109,7 +109,7 @@ abstract class QueryTest extends PlanTest { fail( s"""Decoded objects do not match expected objects: |$comparision - |${ds.resolvedTEncoder.fromRowExpression.treeString} + |${ds.resolvedTEncoder.deserializer.treeString} """.stripMargin) } } -- cgit v1.2.3 From 258a2434193aae62999102a8df73ca70bf0cb9f1 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 30 Mar 2016 16:15:37 -0700 Subject: [SPARK-14282][SQL] CodeFormatter should handle oneline comment with /* */ properly ## What changes were proposed in this pull request? This PR improves `CodeFormatter` to fix the following malformed indentations. ```java /* 019 */ public java.lang.Object apply(java.lang.Object _i) { /* 020 */ InternalRow i = (InternalRow) _i; /* 021 */ /* createexternalrow(if (isnull(input[0, double])) null else input[0, double], if (isnull(input[1, int])) null else input[1, int], ... */ /* 022 */ boolean isNull = false; /* 023 */ final Object[] values = new Object[2]; /* 024 */ /* if (isnull(input[0, double])) null else input[0, double] */ /* 025 */ /* isnull(input[0, double]) */ ... /* 053 */ if (!false && false) { /* 054 */ /* null */ /* 055 */ final int value9 = -1; /* 056 */ isNull6 = true; /* 057 */ value6 = value9; /* 058 */ } else { ... /* 077 */ return mutableRow; /* 078 */ } /* 079 */ } /* 080 */ ``` After this PR, the code will be formatted like the following. ```java /* 019 */ public java.lang.Object apply(java.lang.Object _i) { /* 020 */ InternalRow i = (InternalRow) _i; /* 021 */ /* createexternalrow(if (isnull(input[0, double])) null else input[0, double], if (isnull(input[1, int])) null else input[1, int], ... */ /* 022 */ boolean isNull = false; /* 023 */ final Object[] values = new Object[2]; /* 024 */ /* if (isnull(input[0, double])) null else input[0, double] */ /* 025 */ /* isnull(input[0, double]) */ ... /* 053 */ if (!false && false) { /* 054 */ /* null */ /* 055 */ final int value9 = -1; /* 056 */ isNull6 = true; /* 057 */ value6 = value9; /* 058 */ } else { ... /* 077 */ return mutableRow; /* 078 */ } /* 079 */ } /* 080 */ ``` Also, this issue fixes the following too. (Similar with [SPARK-14185](https://issues.apache.org/jira/browse/SPARK-14185)) ```java 16/03/30 12:39:24 DEBUG WholeStageCodegen: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } ``` ```java 16/03/30 12:46:32 DEBUG WholeStageCodegen: /* 001 */ public Object generate(Object[] references) { /* 002 */ return new GeneratedIterator(references); /* 003 */ } ``` ## How was this patch tested? 
Pass the Jenkins tests (including new CodeFormatterSuite testcases.) Author: Dongjoon Hyun Closes #12072 from dongjoon-hyun/SPARK-14282. --- .../sql/catalyst/expressions/codegen/CodeFormatter.scala | 3 ++- .../catalyst/expressions/codegen/CodeFormatterSuite.scala | 14 ++++++++++++++ .../org/apache/spark/sql/execution/WholeStageCodegen.scala | 2 +- 3 files changed, 17 insertions(+), 2 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala index 8e40754dc3..ab4831f7ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatter.scala @@ -74,7 +74,8 @@ private class CodeFormatter { // Handle single line comments newIndentLevel = indentLevel } - } else { + } + if (inCommentBlock) { if (line.endsWith("*/")) { inCommentBlock = false newIndentLevel = indentLevelOutsideCommentBlock diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala index d7836aa3b2..f57b82bb96 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeFormatterSuite.scala @@ -115,6 +115,20 @@ class CodeFormatterSuite extends SparkFunSuite { """.stripMargin } + testCase("single line comments /* */ ") { + """/** This is a comment about class A { { { ( ( */ + |class A { + |class body; + |}""".stripMargin + }{ + """ + |/* 001 */ /** This is a comment about class A { { { ( ( */ + |/* 002 */ class A { + |/* 003 */ class body; + |/* 004 */ } + """.stripMargin + } + testCase("multi-line comments") { """ /* This is a comment about |class A { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index da3ee46b7d..6a779abd40 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -337,7 +337,7 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup // try to compile, helpful for debug val cleanedSource = CodeFormatter.stripExtraNewLines(source) - logDebug(s"${CodeFormatter.format(cleanedSource)}") + logDebug(s"\n${CodeFormatter.format(cleanedSource)}") CodeGenerator.compile(cleanedSource) (ctx, cleanedSource) } -- cgit v1.2.3 From a9b93e07391faede77dde4c0b3c21c9b3f97f8eb Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Thu, 31 Mar 2016 09:25:09 -0700 Subject: [SPARK-14211][SQL] Remove ANTLR3 based parser ### What changes were proposed in this pull request? This PR removes the ANTLR3 based parser, and moves the new ANTLR4 based parser into the `org.apache.spark.sql.catalyst.parser package`. ### How was this patch tested? Existing unit tests. cc rxin andrewor14 yhuai Author: Herman van Hovell Closes #12071 from hvanhovell/SPARK-14211. 
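
For orientation, the parsing entry point after this change follows the standard ANTLR4 pattern: a generated lexer feeds a token stream into a generated parser, and a visitor (the new `AstBuilder`, which extends the generated `SqlBaseBaseVisitor`) turns the resulting parse tree into Catalyst plans. The sketch below is illustrative only, not the patch's actual `ParseDriver`; it assumes the classes ANTLR4 generates from `SqlBase.g4` (`SqlBaseLexer`, `SqlBaseParser`, `SqlBaseBaseVisitor`) are on the classpath and that `singleStatement` is the grammar's top-level rule.

```scala
import org.antlr.v4.runtime.{ANTLRInputStream, CommonTokenStream}

// Illustrative driver only. SqlBaseLexer / SqlBaseParser are the classes that
// ANTLR4 generates from SqlBase.g4; they are not hand-written in this patch.
object Antlr4ParseSketch {
  def parseTree(sqlText: String): String = {
    val lexer = new SqlBaseLexer(new ANTLRInputStream(sqlText))
    val tokens = new CommonTokenStream(lexer)
    val parser = new SqlBaseParser(tokens)
    // Parse one SQL statement using the grammar's top-level rule.
    val tree = parser.singleStatement()
    // A real AstBuilder would extend the generated SqlBaseBaseVisitor and map
    // this tree to a Catalyst LogicalPlan; here we just render the parse tree.
    tree.toStringTree(parser)
  }
}
```

Calling something like `Antlr4ParseSketch.parseTree("SELECT 1")` should yield a LISP-style tree along the lines of `(singleStatement (statement ...))`, which is the structure the new visitor-based `AstBuilder` walks, replacing the ANTLR3 AST rewriting (`-> ^(TOK_...)`) used by the deleted grammars below.
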
--- dev/deps/spark-deps-hadoop-2.2 | 4 +- dev/deps/spark-deps-hadoop-2.3 | 4 +- dev/deps/spark-deps-hadoop-2.4 | 4 +- dev/deps/spark-deps-hadoop-2.6 | 4 +- dev/deps/spark-deps-hadoop-2.7 | 4 +- pom.xml | 6 - project/SparkBuild.scala | 54 +- project/plugins.sbt | 3 - python/pyspark/sql/utils.py | 2 +- sql/catalyst/pom.xml | 22 - .../spark/sql/catalyst/parser/ExpressionParser.g | 400 --- .../spark/sql/catalyst/parser/FromClauseParser.g | 341 --- .../spark/sql/catalyst/parser/IdentifiersParser.g | 184 -- .../spark/sql/catalyst/parser/KeywordParser.g | 244 -- .../spark/sql/catalyst/parser/SelectClauseParser.g | 235 -- .../spark/sql/catalyst/parser/SparkSqlLexer.g | 491 ---- .../spark/sql/catalyst/parser/SparkSqlParser.g | 2596 -------------------- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 943 +++++++ .../apache/spark/sql/catalyst/parser/ng/SqlBase.g4 | 941 ------- .../apache/spark/sql/catalyst/parser/ASTNode.scala | 99 - .../catalyst/parser/AbstractSparkSQLParser.scala | 145 -- .../spark/sql/catalyst/parser/AstBuilder.scala | 1460 +++++++++++ .../spark/sql/catalyst/parser/CatalystQl.scala | 933 ------- .../spark/sql/catalyst/parser/DataTypeParser.scala | 67 + .../spark/sql/catalyst/parser/ParseDriver.scala | 245 +- .../spark/sql/catalyst/parser/ParserConf.scala | 26 - .../spark/sql/catalyst/parser/ParserUtils.scala | 209 +- .../spark/sql/catalyst/parser/ng/AstBuilder.scala | 1452 ----------- .../spark/sql/catalyst/parser/ng/ParseDriver.scala | 240 -- .../spark/sql/catalyst/parser/ng/ParserUtils.scala | 118 - .../spark/sql/catalyst/parser/ASTNodeSuite.scala | 38 - .../sql/catalyst/parser/CatalystQlSuite.scala | 223 -- .../sql/catalyst/parser/DataTypeParserSuite.scala | 1 - .../sql/catalyst/parser/ErrorParserSuite.scala | 67 + .../catalyst/parser/ExpressionParserSuite.scala | 497 ++++ .../sql/catalyst/parser/PlanParserSuite.scala | 429 ++++ .../parser/TableIdentifierParserSuite.scala | 42 + .../sql/catalyst/parser/ng/ErrorParserSuite.scala | 67 - .../catalyst/parser/ng/ExpressionParserSuite.scala | 497 ---- .../sql/catalyst/parser/ng/PlanParserSuite.scala | 429 ---- .../parser/ng/TableIdentifierParserSuite.scala | 42 - .../org/apache/spark/sql/execution/SparkQl.scala | 387 --- .../spark/sql/execution/SparkSqlParser.scala | 6 +- .../command/AlterTableCommandParser.scala | 431 ---- .../org/apache/spark/sql/internal/SQLConf.scala | 20 +- .../spark/sql/execution/command/DDLSuite.scala | 13 +- sql/hive/pom.xml | 19 - .../spark/sql/hive/HiveMetastoreCatalog.scala | 27 +- .../scala/org/apache/spark/sql/hive/HiveQl.scala | 749 ------ .../spark/sql/hive/execution/HiveSqlParser.scala | 31 +- .../apache/spark/sql/hive/ErrorPositionSuite.scala | 4 +- .../org/apache/spark/sql/hive/HiveQlSuite.scala | 4 +- .../apache/spark/sql/hive/StatisticsSuite.scala | 5 +- 53 files changed, 3816 insertions(+), 11688 deletions(-) delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g delete mode 100644 sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g delete mode 100644 
sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g create mode 100644 sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 delete mode 100644 sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala delete mode 100644 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala (limited to 'sql/catalyst') diff --git a/dev/deps/spark-deps-hadoop-2.2 b/dev/deps/spark-deps-hadoop-2.2 index 7c2f88bdb1..0c4e43b9c8 100644 --- a/dev/deps/spark-deps-hadoop-2.2 +++ b/dev/deps/spark-deps-hadoop-2.2 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -173,6 +174,7 @@ spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.3 b/dev/deps/spark-deps-hadoop-2.3 index f4d600038d..a0d62a1c30 100644 --- a/dev/deps/spark-deps-hadoop-2.3 +++ b/dev/deps/spark-deps-hadoop-2.3 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar 
apache-log4j-extras-1.2.17.jar @@ -164,6 +165,7 @@ spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.4 b/dev/deps/spark-deps-hadoop-2.4 index 7c5e2c35bd..cc6e40329c 100644 --- a/dev/deps/spark-deps-hadoop-2.4 +++ b/dev/deps/spark-deps-hadoop-2.4 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -165,6 +166,7 @@ spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.6 b/dev/deps/spark-deps-hadoop-2.6 index 03d9a51057..5c93db5082 100644 --- a/dev/deps/spark-deps-hadoop-2.6 +++ b/dev/deps/spark-deps-hadoop-2.6 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -171,6 +172,7 @@ spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/dev/deps/spark-deps-hadoop-2.7 b/dev/deps/spark-deps-hadoop-2.7 index 5765071a1c..860fd79aad 100644 --- a/dev/deps/spark-deps-hadoop-2.7 +++ b/dev/deps/spark-deps-hadoop-2.7 @@ -2,7 +2,8 @@ JavaEWAH-0.3.2.jar RoaringBitmap-0.5.11.jar ST4-4.0.4.jar activation-1.1.1.jar -antlr-runtime-3.5.2.jar +antlr-2.7.7.jar +antlr-runtime-3.4.jar antlr4-runtime-4.5.2-1.jar aopalliance-1.0.jar apache-log4j-extras-1.2.17.jar @@ -172,6 +173,7 @@ spire_2.11-0.7.4.jar stax-api-1.0-2.jar stax-api-1.0.1.jar stream-2.7.0.jar +stringtemplate-3.2.1.jar super-csv-2.2.0.jar univocity-parsers-1.5.6.jar unused-1.0.0.jar diff --git a/pom.xml b/pom.xml index 1513a18b71..9dab0bca74 100644 --- a/pom.xml +++ b/pom.xml @@ -177,7 +177,6 @@ 3.5.2 1.3.9 0.9.2 - 3.5.2 4.5.2-1 ${java.home} @@ -1755,11 +1754,6 @@ - - org.antlr - antlr-runtime - ${antlr.version} - org.antlr antlr4-runtime diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 39a9e16f7e..5d62b688b9 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -403,59 +403,9 @@ object OldDeps { object Catalyst { lazy val settings = antlr4Settings ++ Seq( - antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser.ng"), + antlr4PackageName in Antlr4 := Some("org.apache.spark.sql.catalyst.parser"), antlr4GenListener in Antlr4 := true, - antlr4GenVisitor in Antlr4 := true, - // ANTLR code-generation step. - // - // This has been heavily inspired by com.github.stefri.sbt-antlr (0.5.3). It fixes a number of - // build errors in the current plugin. - // Create Parser from ANTLR grammar files. - sourceGenerators in Compile += Def.task { - val log = streams.value.log - - val grammarFileNames = Seq( - "SparkSqlLexer.g", - "SparkSqlParser.g") - val sourceDir = (sourceDirectory in Compile).value / "antlr3" - val targetDir = (sourceManaged in Compile).value / "antlr3" - - // Create default ANTLR Tool. - val antlr = new org.antlr.Tool - - // Setup input and output directories. 
- antlr.setInputDirectory(sourceDir.getPath) - antlr.setOutputDirectory(targetDir.getPath) - antlr.setForceRelativeOutput(true) - antlr.setMake(true) - - // Add grammar files. - grammarFileNames.flatMap(gFileName => (sourceDir ** gFileName).get).foreach { gFilePath => - val relGFilePath = (gFilePath relativeTo sourceDir).get.getPath - log.info("ANTLR: Grammar file '%s' detected.".format(relGFilePath)) - antlr.addGrammarFile(relGFilePath) - // We will set library directory multiple times here. However, only the - // last one has effect. Because the grammar files are located under the same directory, - // We assume there is only one library directory. - antlr.setLibDirectory(gFilePath.getParent) - } - - // Generate the parser. - antlr.process() - val errorState = org.antlr.tool.ErrorManager.getErrorState - if (errorState.errors > 0) { - sys.error("ANTLR: Caught %d build errors.".format(errorState.errors)) - } else if (errorState.warnings > 0) { - sys.error("ANTLR: Caught %d build warnings.".format(errorState.warnings)) - } - - // Return all generated java files. - (targetDir ** "*.java").get.toSeq - }.taskValue, - // Include ANTLR tokens files. - resourceGenerators in Compile += Def.task { - ((sourceManaged in Compile).value ** "*.tokens").get.toSeq - }.taskValue + antlr4GenVisitor in Antlr4 := true ) } diff --git a/project/plugins.sbt b/project/plugins.sbt index d9ed7962bf..4929ba3c4d 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -22,9 +22,6 @@ libraryDependencies += "org.ow2.asm" % "asm" % "5.0.3" libraryDependencies += "org.ow2.asm" % "asm-commons" % "5.0.3" -libraryDependencies += "org.antlr" % "antlr" % "3.5.2" - - // TODO I am not sure we want such a dep. resolvers += "simplytyped" at "http://simplytyped.github.io/repo/releases" diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index b89ea8c6e0..7ea0e0d5c9 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -55,7 +55,7 @@ def capture_sql_exception(f): e.java_exception.getStackTrace())) if s.startswith('org.apache.spark.sql.AnalysisException: '): raise AnalysisException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.catalyst.parser.ng.ParseException: '): + if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): raise ParseException(s.split(': ', 1)[1], stackTrace) if s.startswith('java.lang.IllegalArgumentException: '): raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 9bfe495e90..1748fa2778 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -71,10 +71,6 @@ org.codehaus.janino janino - - org.antlr - antlr-runtime - org.antlr antlr4-runtime @@ -115,24 +111,6 @@ - - org.antlr - antlr3-maven-plugin - - - - antlr - - - - - ../catalyst/src/main/antlr3 - - **/SparkSqlLexer.g - **/SparkSqlParser.g - - - org.antlr antlr4-maven-plugin diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g deleted file mode 100644 index 13a6a2d276..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/ExpressionParser.g +++ /dev/null @@ -1,400 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. 
- The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. -*/ - -parser grammar ExpressionParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -// fun(par1, par2, par3) -function -@init { gParent.pushMsg("function specification", state); } -@after { gParent.popMsg(state); } - : - functionName - LPAREN - ( - (STAR) => (star=STAR) - | (dist=KW_DISTINCT)? (selectExpression (COMMA selectExpression)*)? - ) - RPAREN (KW_OVER ws=window_specification)? - -> {$star != null}? ^(TOK_FUNCTIONSTAR functionName $ws?) - -> {$dist == null}? ^(TOK_FUNCTION functionName (selectExpression+)? $ws?) - -> ^(TOK_FUNCTIONDI functionName (selectExpression+)? $ws?) - ; - -functionName -@init { gParent.pushMsg("function name", state); } -@after { gParent.popMsg(state); } - : // Keyword IF is also a function name - (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) => (KW_IF | KW_ARRAY | KW_MAP | KW_STRUCT | KW_UNIONTYPE) - | - (functionIdentifier) => functionIdentifier - | - {!useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsCastFunctionName -> Identifier[$sql11ReservedKeywordsUsedAsCastFunctionName.text] - ; - -castExpression -@init { gParent.pushMsg("cast expression", state); } -@after { gParent.popMsg(state); } - : - KW_CAST - LPAREN - expression - KW_AS - primitiveType - RPAREN -> ^(TOK_FUNCTION primitiveType expression) - ; - -caseExpression -@init { gParent.pushMsg("case expression", state); } -@after { gParent.popMsg(state); } - : - KW_CASE expression - (KW_WHEN expression KW_THEN expression)+ - (KW_ELSE expression)? - KW_END -> ^(TOK_FUNCTION KW_CASE expression*) - ; - -whenExpression -@init { gParent.pushMsg("case expression", state); } -@after { gParent.popMsg(state); } - : - KW_CASE - ( KW_WHEN expression KW_THEN expression)+ - (KW_ELSE expression)? 
- KW_END -> ^(TOK_FUNCTION KW_WHEN expression*) - ; - -constant -@init { gParent.pushMsg("constant", state); } -@after { gParent.popMsg(state); } - : - Number - | dateLiteral - | timestampLiteral - | intervalLiteral - | StringLiteral - | stringLiteralSequence - | BigintLiteral - | SmallintLiteral - | TinyintLiteral - | DoubleLiteral - | booleanValue - ; - -stringLiteralSequence - : - StringLiteral StringLiteral+ -> ^(TOK_STRINGLITERALSEQUENCE StringLiteral StringLiteral+) - ; - -dateLiteral - : - KW_DATE StringLiteral -> - { - // Create DateLiteral token, but with the text of the string value - // This makes the dateLiteral more consistent with the other type literals. - adaptor.create(TOK_DATELITERAL, $StringLiteral.text) - } - | - KW_CURRENT_DATE -> ^(TOK_FUNCTION KW_CURRENT_DATE) - ; - -timestampLiteral - : - KW_TIMESTAMP StringLiteral -> - { - adaptor.create(TOK_TIMESTAMPLITERAL, $StringLiteral.text) - } - | - KW_CURRENT_TIMESTAMP -> ^(TOK_FUNCTION KW_CURRENT_TIMESTAMP) - ; - -intervalLiteral - : - (KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH) => KW_INTERVAL intervalConstant KW_YEAR KW_TO KW_MONTH - -> ^(TOK_INTERVAL_YEAR_MONTH_LITERAL intervalConstant) - | (KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND) => KW_INTERVAL intervalConstant KW_DAY KW_TO KW_SECOND - -> ^(TOK_INTERVAL_DAY_TIME_LITERAL intervalConstant) - | KW_INTERVAL - ((intervalConstant KW_YEAR)=> year=intervalConstant KW_YEAR)? - ((intervalConstant KW_MONTH)=> month=intervalConstant KW_MONTH)? - ((intervalConstant KW_WEEK)=> week=intervalConstant KW_WEEK)? - ((intervalConstant KW_DAY)=> day=intervalConstant KW_DAY)? - ((intervalConstant KW_HOUR)=> hour=intervalConstant KW_HOUR)? - ((intervalConstant KW_MINUTE)=> minute=intervalConstant KW_MINUTE)? - ((intervalConstant KW_SECOND)=> second=intervalConstant KW_SECOND)? - ((intervalConstant KW_MILLISECOND)=> millisecond=intervalConstant KW_MILLISECOND)? - ((intervalConstant KW_MICROSECOND)=> microsecond=intervalConstant KW_MICROSECOND)? - -> ^(TOK_INTERVAL - ^(TOK_INTERVAL_YEAR_LITERAL $year?) - ^(TOK_INTERVAL_MONTH_LITERAL $month?) - ^(TOK_INTERVAL_WEEK_LITERAL $week?) - ^(TOK_INTERVAL_DAY_LITERAL $day?) - ^(TOK_INTERVAL_HOUR_LITERAL $hour?) - ^(TOK_INTERVAL_MINUTE_LITERAL $minute?) - ^(TOK_INTERVAL_SECOND_LITERAL $second?) - ^(TOK_INTERVAL_MILLISECOND_LITERAL $millisecond?) - ^(TOK_INTERVAL_MICROSECOND_LITERAL $microsecond?)) - ; - -intervalConstant - : - sign=(MINUS|PLUS)? value=Number -> { - adaptor.create(Number, ($sign != null ? $sign.getText() : "") + $value.getText()) - } - | StringLiteral - ; - -expression -@init { gParent.pushMsg("expression specification", state); } -@after { gParent.popMsg(state); } - : - precedenceOrExpression - ; - -atomExpression - : - (KW_NULL) => KW_NULL -> TOK_NULL - | (constant) => constant - | castExpression - | caseExpression - | whenExpression - | (functionName LPAREN) => function - | tableOrColumn - | (LPAREN KW_SELECT) => subQueryExpression - -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP) subQueryExpression) - | LPAREN! expression RPAREN! - ; - - -precedenceFieldExpression - : - atomExpression ((LSQUARE^ expression RSQUARE!) 
| (DOT^ identifier))* - ; - -precedenceUnaryOperator - : - PLUS | MINUS | TILDE - ; - -nullCondition - : - KW_NULL -> ^(TOK_ISNULL) - | KW_NOT KW_NULL -> ^(TOK_ISNOTNULL) - ; - -precedenceUnaryPrefixExpression - : - (precedenceUnaryOperator+)=> precedenceUnaryOperator^ precedenceUnaryPrefixExpression - | precedenceFieldExpression - ; - -precedenceUnarySuffixExpression - : - ( - (LPAREN precedenceUnaryPrefixExpression RPAREN) => LPAREN precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? RPAREN - | - precedenceUnaryPrefixExpression (a=KW_IS nullCondition)? - ) - -> {$a != null}? ^(TOK_FUNCTION nullCondition precedenceUnaryPrefixExpression) - -> precedenceUnaryPrefixExpression - ; - - -precedenceBitwiseXorOperator - : - BITWISEXOR - ; - -precedenceBitwiseXorExpression - : - precedenceUnarySuffixExpression (precedenceBitwiseXorOperator^ precedenceUnarySuffixExpression)* - ; - - -precedenceStarOperator - : - STAR | DIVIDE | MOD | DIV - ; - -precedenceStarExpression - : - precedenceBitwiseXorExpression (precedenceStarOperator^ precedenceBitwiseXorExpression)* - ; - - -precedencePlusOperator - : - PLUS | MINUS - ; - -precedencePlusExpression - : - precedenceStarExpression (precedencePlusOperator^ precedenceStarExpression)* - ; - - -precedenceAmpersandOperator - : - AMPERSAND - ; - -precedenceAmpersandExpression - : - precedencePlusExpression (precedenceAmpersandOperator^ precedencePlusExpression)* - ; - - -precedenceBitwiseOrOperator - : - BITWISEOR - ; - -precedenceBitwiseOrExpression - : - precedenceAmpersandExpression (precedenceBitwiseOrOperator^ precedenceAmpersandExpression)* - ; - - -// Equal operators supporting NOT prefix -precedenceEqualNegatableOperator - : - KW_LIKE | KW_RLIKE | KW_REGEXP - ; - -precedenceEqualOperator - : - precedenceEqualNegatableOperator | EQUAL | EQUAL_NS | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -subQueryExpression - : - LPAREN! selectStatement[true] RPAREN! 
- ; - -precedenceEqualExpression - : - (LPAREN precedenceBitwiseOrExpression COMMA) => precedenceEqualExpressionMutiple - | - precedenceEqualExpressionSingle - ; - -precedenceEqualExpressionSingle - : - (left=precedenceBitwiseOrExpression -> $left) - ( - (KW_NOT precedenceEqualNegatableOperator notExpr=precedenceBitwiseOrExpression) - -> ^(KW_NOT ^(precedenceEqualNegatableOperator $precedenceEqualExpressionSingle $notExpr)) - | (precedenceEqualOperator equalExpr=precedenceBitwiseOrExpression) - -> ^(precedenceEqualOperator $precedenceEqualExpressionSingle $equalExpr) - | (KW_NOT KW_IN LPAREN KW_SELECT)=> (KW_NOT KW_IN subQueryExpression) - -> ^(KW_NOT ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle)) - | (KW_NOT KW_IN expressions) - -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions)) - | (KW_IN LPAREN KW_SELECT)=> (KW_IN subQueryExpression) - -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_IN) subQueryExpression $precedenceEqualExpressionSingle) - | (KW_IN expressions) - -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionSingle expressions) - | ( KW_NOT KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) - -> ^(TOK_FUNCTION Identifier["between"] KW_TRUE $left $min $max) - | ( KW_BETWEEN (min=precedenceBitwiseOrExpression) KW_AND (max=precedenceBitwiseOrExpression) ) - -> ^(TOK_FUNCTION Identifier["between"] KW_FALSE $left $min $max) - )* - | (KW_EXISTS LPAREN KW_SELECT)=> (KW_EXISTS subQueryExpression) -> ^(TOK_SUBQUERY_EXPR ^(TOK_SUBQUERY_OP KW_EXISTS) subQueryExpression) - ; - -expressions - : - LPAREN expression (COMMA expression)* RPAREN -> expression+ - ; - -//we transform the (col0, col1) in ((v00,v01),(v10,v11)) into struct(col0, col1) in (struct(v00,v01),struct(v10,v11)) -precedenceEqualExpressionMutiple - : - (LPAREN precedenceBitwiseOrExpression (COMMA precedenceBitwiseOrExpression)+ RPAREN -> ^(TOK_FUNCTION Identifier["struct"] precedenceBitwiseOrExpression+)) - ( (KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) - -> ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+) - | (KW_NOT KW_IN LPAREN expressionsToStruct (COMMA expressionsToStruct)+ RPAREN) - -> ^(KW_NOT ^(TOK_FUNCTION KW_IN $precedenceEqualExpressionMutiple expressionsToStruct+))) - ; - -expressionsToStruct - : - LPAREN expression (COMMA expression)* RPAREN -> ^(TOK_FUNCTION Identifier["struct"] expression+) - ; - -precedenceNotOperator - : - KW_NOT - ; - -precedenceNotExpression - : - (precedenceNotOperator^)* precedenceEqualExpression - ; - - -precedenceAndOperator - : - KW_AND - ; - -precedenceAndExpression - : - precedenceNotExpression (precedenceAndOperator^ precedenceNotExpression)* - ; - - -precedenceOrOperator - : - KW_OR - ; - -precedenceOrExpression - : - precedenceAndExpression (precedenceOrOperator^ precedenceAndExpression)* - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g deleted file mode 100644 index 1bf461c912..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/FromClauseParser.g +++ /dev/null @@ -1,341 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. 
- The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/FromClauseParser.g grammar. -*/ -parser grammar FromClauseParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -//----------------------------------------------------------------------------------- - -tableAllColumns - : STAR - -> ^(TOK_ALLCOLREF) - | tableName DOT STAR - -> ^(TOK_ALLCOLREF tableName) - ; - -// (table|column) -tableOrColumn -@init { gParent.pushMsg("table or column identifier", state); } -@after { gParent.popMsg(state); } - : - identifier -> ^(TOK_TABLE_OR_COL identifier) - ; - -expressionList -@init { gParent.pushMsg("expression list", state); } -@after { gParent.popMsg(state); } - : - expression (COMMA expression)* -> ^(TOK_EXPLIST expression+) - ; - -aliasList -@init { gParent.pushMsg("alias list", state); } -@after { gParent.popMsg(state); } - : - identifier (COMMA identifier)* -> ^(TOK_ALIASLIST identifier+) - ; - -//----------------------- Rules for parsing fromClause ------------------------------ -// from [col1, col2, col3] table1, [col4, col5] table2 -fromClause -@init { gParent.pushMsg("from clause", state); } -@after { gParent.popMsg(state); } - : - KW_FROM joinSource -> ^(TOK_FROM joinSource) - ; - -joinSource -@init { gParent.pushMsg("join source", state); } -@after { gParent.popMsg(state); } - : fromSource ( joinToken^ fromSource ( joinCond {$joinToken.start.getType() != COMMA}? )? )* - | uniqueJoinToken^ uniqueJoinSource (COMMA! uniqueJoinSource)+ - ; - -joinCond -@init { gParent.pushMsg("join expression list", state); } -@after { gParent.popMsg(state); } - : KW_ON! expression - | KW_USING LPAREN columnNameList RPAREN -> ^(TOK_USING columnNameList) - ; - -uniqueJoinSource -@init { gParent.pushMsg("unique join source", state); } -@after { gParent.popMsg(state); } - : KW_PRESERVE? 
fromSource uniqueJoinExpr - ; - -uniqueJoinExpr -@init { gParent.pushMsg("unique join expression list", state); } -@after { gParent.popMsg(state); } - : LPAREN e1+=expression (COMMA e1+=expression)* RPAREN - -> ^(TOK_EXPLIST $e1*) - ; - -uniqueJoinToken -@init { gParent.pushMsg("unique join", state); } -@after { gParent.popMsg(state); } - : KW_UNIQUEJOIN -> TOK_UNIQUEJOIN; - -joinToken -@init { gParent.pushMsg("join type specifier", state); } -@after { gParent.popMsg(state); } - : - KW_JOIN -> TOK_JOIN - | KW_INNER KW_JOIN -> TOK_JOIN - | KW_NATURAL KW_JOIN -> TOK_NATURALJOIN - | KW_NATURAL KW_INNER KW_JOIN -> TOK_NATURALJOIN - | COMMA -> TOK_JOIN - | KW_CROSS KW_JOIN -> TOK_CROSSJOIN - | KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_LEFTOUTERJOIN - | KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_RIGHTOUTERJOIN - | KW_FULL (KW_OUTER)? KW_JOIN -> TOK_FULLOUTERJOIN - | KW_NATURAL KW_LEFT (KW_OUTER)? KW_JOIN -> TOK_NATURALLEFTOUTERJOIN - | KW_NATURAL KW_RIGHT (KW_OUTER)? KW_JOIN -> TOK_NATURALRIGHTOUTERJOIN - | KW_NATURAL KW_FULL (KW_OUTER)? KW_JOIN -> TOK_NATURALFULLOUTERJOIN - | KW_LEFT KW_SEMI KW_JOIN -> TOK_LEFTSEMIJOIN - | KW_ANTI KW_JOIN -> TOK_ANTIJOIN - ; - -lateralView -@init {gParent.pushMsg("lateral view", state); } -@after {gParent.popMsg(state); } - : - (KW_LATERAL KW_VIEW KW_OUTER) => KW_LATERAL KW_VIEW KW_OUTER function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? - -> ^(TOK_LATERAL_VIEW_OUTER ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) - | - KW_LATERAL KW_VIEW function tableAlias (KW_AS identifier ((COMMA)=> COMMA identifier)*)? - -> ^(TOK_LATERAL_VIEW ^(TOK_SELECT ^(TOK_SELEXPR function identifier* tableAlias))) - ; - -tableAlias -@init {gParent.pushMsg("table alias", state); } -@after {gParent.popMsg(state); } - : - identifier -> ^(TOK_TABALIAS identifier) - ; - -fromSource -@init { gParent.pushMsg("from source", state); } -@after { gParent.popMsg(state); } - : - (LPAREN KW_VALUES) => fromSource0 - | fromSource0 - | (LPAREN joinSource) => LPAREN joinSource RPAREN -> joinSource - ; - - -fromSource0 -@init { gParent.pushMsg("from source 0", state); } -@after { gParent.popMsg(state); } - : - ((Identifier LPAREN)=> partitionedTableFunction | tableSource | subQuerySource | virtualTableSource) (lateralView^)* - ; - -tableBucketSample -@init { gParent.pushMsg("table bucket sample specification", state); } -@after { gParent.popMsg(state); } - : - KW_TABLESAMPLE LPAREN KW_BUCKET (numerator=Number) KW_OUT KW_OF (denominator=Number) (KW_ON expr+=expression (COMMA expr+=expression)*)? RPAREN -> ^(TOK_TABLEBUCKETSAMPLE $numerator $denominator $expr*) - ; - -splitSample -@init { gParent.pushMsg("table split sample specification", state); } -@after { gParent.popMsg(state); } - : - KW_TABLESAMPLE LPAREN (numerator=Number) (percent=KW_PERCENT|KW_ROWS) RPAREN - -> {percent != null}? ^(TOK_TABLESPLITSAMPLE TOK_PERCENT $numerator) - -> ^(TOK_TABLESPLITSAMPLE TOK_ROWCOUNT $numerator) - | - KW_TABLESAMPLE LPAREN (numerator=ByteLengthLiteral) RPAREN - -> ^(TOK_TABLESPLITSAMPLE TOK_LENGTH $numerator) - ; - -tableSample -@init { gParent.pushMsg("table sample specification", state); } -@after { gParent.popMsg(state); } - : - tableBucketSample | - splitSample - ; - -tableSource -@init { gParent.pushMsg("table source", state); } -@after { gParent.popMsg(state); } - : tabname=tableName - ((tableProperties) => props=tableProperties)? - ((tableSample) => ts=tableSample)? - ((KW_AS) => (KW_AS alias=Identifier) - | - (Identifier) => (alias=Identifier))? - -> ^(TOK_TABREF $tabname $props? $ts? 
$alias?) - ; - -tableName -@init { gParent.pushMsg("table name", state); } -@after { gParent.popMsg(state); } - : - id1=identifier (DOT id2=identifier)? - -> ^(TOK_TABNAME $id1 $id2?) - ; - -viewName -@init { gParent.pushMsg("view name", state); } -@after { gParent.popMsg(state); } - : - (db=identifier DOT)? view=identifier - -> ^(TOK_TABNAME $db? $view) - ; - -subQuerySource -@init { gParent.pushMsg("subquery source", state); } -@after { gParent.popMsg(state); } - : - LPAREN queryStatementExpression[false] RPAREN KW_AS? identifier -> ^(TOK_SUBQUERY queryStatementExpression identifier) - ; - -//---------------------- Rules for parsing PTF clauses ----------------------------- -partitioningSpec -@init { gParent.pushMsg("partitioningSpec clause", state); } -@after { gParent.popMsg(state); } - : - partitionByClause orderByClause? -> ^(TOK_PARTITIONINGSPEC partitionByClause orderByClause?) | - orderByClause -> ^(TOK_PARTITIONINGSPEC orderByClause) | - distributeByClause sortByClause? -> ^(TOK_PARTITIONINGSPEC distributeByClause sortByClause?) | - sortByClause -> ^(TOK_PARTITIONINGSPEC sortByClause) | - clusterByClause -> ^(TOK_PARTITIONINGSPEC clusterByClause) - ; - -partitionTableFunctionSource -@init { gParent.pushMsg("partitionTableFunctionSource clause", state); } -@after { gParent.popMsg(state); } - : - subQuerySource | - tableSource | - partitionedTableFunction - ; - -partitionedTableFunction -@init { gParent.pushMsg("ptf clause", state); } -@after { gParent.popMsg(state); } - : - name=Identifier LPAREN KW_ON - ((partitionTableFunctionSource) => (ptfsrc=partitionTableFunctionSource spec=partitioningSpec?)) - ((Identifier LPAREN expression RPAREN ) => Identifier LPAREN expression RPAREN ( COMMA Identifier LPAREN expression RPAREN)*)? - ((RPAREN) => (RPAREN)) ((Identifier) => alias=Identifier)? - -> ^(TOK_PTBLFUNCTION $name $alias? $ptfsrc $spec? expression*) - ; - -//----------------------- Rules for parsing whereClause ----------------------------- -// where a=b and ... -whereClause -@init { gParent.pushMsg("where clause", state); } -@after { gParent.popMsg(state); } - : - KW_WHERE searchCondition -> ^(TOK_WHERE searchCondition) - ; - -searchCondition -@init { gParent.pushMsg("search condition", state); } -@after { gParent.popMsg(state); } - : - expression - ; - -//----------------------------------------------------------------------------------- - -//-------- Row Constructor ---------------------------------------------------------- -//in support of SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as FOO(a,b,c) and -// INSERT INTO (col1,col2,...) VALUES(...),(...),... -// INSERT INTO
(col1,col2,...) SELECT * FROM (VALUES(1,2,3),(4,5,6),...) as Foo(a,b,c) -valueRowConstructor -@init { gParent.pushMsg("value row constructor", state); } -@after { gParent.popMsg(state); } - : - LPAREN precedenceUnaryPrefixExpression (COMMA precedenceUnaryPrefixExpression)* RPAREN -> ^(TOK_VALUE_ROW precedenceUnaryPrefixExpression+) - ; - -valuesTableConstructor -@init { gParent.pushMsg("values table constructor", state); } -@after { gParent.popMsg(state); } - : - valueRowConstructor (COMMA valueRowConstructor)* -> ^(TOK_VALUES_TABLE valueRowConstructor+) - ; - -/* -VALUES(1),(2) means 2 rows, 1 column each. -VALUES(1,2),(3,4) means 2 rows, 2 columns each. -VALUES(1,2,3) means 1 row, 3 columns -*/ -valuesClause -@init { gParent.pushMsg("values clause", state); } -@after { gParent.popMsg(state); } - : - KW_VALUES valuesTableConstructor -> valuesTableConstructor - ; - -/* -This represents a clause like this: -(VALUES(1,2),(2,3)) as VirtTable(col1,col2) -*/ -virtualTableSource -@init { gParent.pushMsg("virtual table source", state); } -@after { gParent.popMsg(state); } - : - LPAREN valuesClause RPAREN tableNameColList -> ^(TOK_VIRTUAL_TABLE tableNameColList valuesClause) - ; -/* -e.g. as VirtTable(col1,col2) -Note that we only want literals as column names -*/ -tableNameColList -@init { gParent.pushMsg("from source", state); } -@after { gParent.popMsg(state); } - : - KW_AS? identifier LPAREN identifier (COMMA identifier)* RPAREN -> ^(TOK_VIRTUAL_TABREF ^(TOK_TABNAME identifier) ^(TOK_COL_NAME identifier+)) - ; - -//----------------------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g deleted file mode 100644 index 916eb6a7ac..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/IdentifiersParser.g +++ /dev/null @@ -1,184 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. 
-*/ -parser grammar IdentifiersParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -//----------------------------------------------------------------------------------- - -// group by a,b -groupByClause -@init { gParent.pushMsg("group by clause", state); } -@after { gParent.popMsg(state); } - : - KW_GROUP KW_BY - expression - ( COMMA expression)* - ((rollup=KW_WITH KW_ROLLUP) | (cube=KW_WITH KW_CUBE)) ? - (sets=KW_GROUPING KW_SETS - LPAREN groupingSetExpression ( COMMA groupingSetExpression)* RPAREN ) ? - -> {rollup != null}? ^(TOK_ROLLUP_GROUPBY expression+) - -> {cube != null}? ^(TOK_CUBE_GROUPBY expression+) - -> {sets != null}? ^(TOK_GROUPING_SETS expression+ groupingSetExpression+) - -> ^(TOK_GROUPBY expression+) - ; - -groupingSetExpression -@init {gParent.pushMsg("grouping set expression", state); } -@after {gParent.popMsg(state); } - : - (LPAREN) => groupingSetExpressionMultiple - | - groupingExpressionSingle - ; - -groupingSetExpressionMultiple -@init {gParent.pushMsg("grouping set part expression", state); } -@after {gParent.popMsg(state); } - : - LPAREN - expression? (COMMA expression)* - RPAREN - -> ^(TOK_GROUPING_SETS_EXPRESSION expression*) - ; - -groupingExpressionSingle -@init { gParent.pushMsg("groupingExpression expression", state); } -@after { gParent.popMsg(state); } - : - expression -> ^(TOK_GROUPING_SETS_EXPRESSION expression) - ; - -havingClause -@init { gParent.pushMsg("having clause", state); } -@after { gParent.popMsg(state); } - : - KW_HAVING havingCondition -> ^(TOK_HAVING havingCondition) - ; - -havingCondition -@init { gParent.pushMsg("having condition", state); } -@after { gParent.popMsg(state); } - : - expression - ; - -expressionsInParenthese - : - LPAREN expression (COMMA expression)* RPAREN -> expression+ - ; - -expressionsNotInParenthese - : - expression (COMMA expression)* -> expression+ - ; - -columnRefOrderInParenthese - : - LPAREN columnRefOrder (COMMA columnRefOrder)* RPAREN -> columnRefOrder+ - ; - -columnRefOrderNotInParenthese - : - columnRefOrder (COMMA columnRefOrder)* -> columnRefOrder+ - ; - -// order by a,b -orderByClause -@init { gParent.pushMsg("order by clause", state); } -@after { gParent.popMsg(state); } - : - KW_ORDER KW_BY columnRefOrder ( COMMA columnRefOrder)* -> ^(TOK_ORDERBY columnRefOrder+) - ; - -clusterByClause -@init { gParent.pushMsg("cluster by clause", state); } -@after { gParent.popMsg(state); } - : - KW_CLUSTER KW_BY - ( - (LPAREN) => expressionsInParenthese -> ^(TOK_CLUSTERBY expressionsInParenthese) - | - expressionsNotInParenthese -> ^(TOK_CLUSTERBY expressionsNotInParenthese) - ) - ; - -partitionByClause -@init { gParent.pushMsg("partition by clause", state); } -@after { gParent.popMsg(state); } - : - KW_PARTITION KW_BY - ( - (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) - | - expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) - ) - ; - -distributeByClause -@init { gParent.pushMsg("distribute by clause", 
state); } -@after { gParent.popMsg(state); } - : - KW_DISTRIBUTE KW_BY - ( - (LPAREN) => expressionsInParenthese -> ^(TOK_DISTRIBUTEBY expressionsInParenthese) - | - expressionsNotInParenthese -> ^(TOK_DISTRIBUTEBY expressionsNotInParenthese) - ) - ; - -sortByClause -@init { gParent.pushMsg("sort by clause", state); } -@after { gParent.popMsg(state); } - : - KW_SORT KW_BY - ( - (LPAREN) => columnRefOrderInParenthese -> ^(TOK_SORTBY columnRefOrderInParenthese) - | - columnRefOrderNotInParenthese -> ^(TOK_SORTBY columnRefOrderNotInParenthese) - ) - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g deleted file mode 100644 index 12cd5f54a0..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/KeywordParser.g +++ /dev/null @@ -1,244 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/IdentifiersParser.g grammar. -*/ - -parser grammar KeywordParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -booleanValue - : - KW_TRUE^ | KW_FALSE^ - ; - -booleanValueTok - : - KW_TRUE -> TOK_TRUE - | KW_FALSE -> TOK_FALSE - ; - -tableOrPartition - : - tableName partitionSpec? -> ^(TOK_TAB tableName partitionSpec?) - ; - -partitionSpec - : - KW_PARTITION - LPAREN partitionVal (COMMA partitionVal )* RPAREN -> ^(TOK_PARTSPEC partitionVal +) - ; - -partitionVal - : - identifier (EQUAL constant)? -> ^(TOK_PARTVAL identifier constant?) 
- ; - -dropPartitionSpec - : - KW_PARTITION - LPAREN dropPartitionVal (COMMA dropPartitionVal )* RPAREN -> ^(TOK_PARTSPEC dropPartitionVal +) - ; - -dropPartitionVal - : - identifier dropPartitionOperator constant -> ^(TOK_PARTVAL identifier dropPartitionOperator constant) - ; - -dropPartitionOperator - : - EQUAL | NOTEQUAL | LESSTHANOREQUALTO | LESSTHAN | GREATERTHANOREQUALTO | GREATERTHAN - ; - -sysFuncNames - : - KW_AND - | KW_OR - | KW_NOT - | KW_LIKE - | KW_IF - | KW_CASE - | KW_WHEN - | KW_TINYINT - | KW_SMALLINT - | KW_INT - | KW_BIGINT - | KW_FLOAT - | KW_DOUBLE - | KW_BOOLEAN - | KW_STRING - | KW_BINARY - | KW_ARRAY - | KW_MAP - | KW_STRUCT - | KW_UNIONTYPE - | EQUAL - | EQUAL_NS - | NOTEQUAL - | LESSTHANOREQUALTO - | LESSTHAN - | GREATERTHANOREQUALTO - | GREATERTHAN - | DIVIDE - | PLUS - | MINUS - | STAR - | MOD - | DIV - | AMPERSAND - | TILDE - | BITWISEOR - | BITWISEXOR - | KW_RLIKE - | KW_REGEXP - | KW_IN - | KW_BETWEEN - ; - -descFuncNames - : - (sysFuncNames) => sysFuncNames - | StringLiteral - | functionIdentifier - ; - -//We are allowed to use From and To in CreateTableUsing command's options (actually seems we can use any string as the option key). But we can't simply add them into nonReserved because by doing that we mess other existing rules. So we create a looseIdentifier and looseNonReserved here. -looseIdentifier - : - Identifier - | looseNonReserved -> Identifier[$looseNonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -identifier - : - Identifier - | nonReserved -> Identifier[$nonReserved.text] - // If it decides to support SQL11 reserved keywords, i.e., useSQL11ReservedKeywordsForIdentifier()=false, - // the sql11keywords in existing q tests will NOT be added back. - | {useSQL11ReservedKeywordsForIdentifier()}? sql11ReservedKeywordsUsedAsIdentifier -> Identifier[$sql11ReservedKeywordsUsedAsIdentifier.text] - ; - -functionIdentifier -@init { gParent.pushMsg("function identifier", state); } -@after { gParent.popMsg(state); } - : - identifier (DOT identifier)? -> identifier+ - ; - -principalIdentifier -@init { gParent.pushMsg("identifier for principal spec", state); } -@after { gParent.popMsg(state); } - : identifier - | QuotedIdentifier - ; - -looseNonReserved - : nonReserved | KW_FROM | KW_TO - ; - -//The new version of nonReserved + sql11ReservedKeywordsUsedAsIdentifier = old version of nonReserved -//Non reserved keywords are basically the keywords that can be used as identifiers. -//All the KW_* are automatically not only keywords, but also reserved keywords. -//That means, they can NOT be used as identifiers. -//If you would like to use them as identifiers, put them in the nonReserved list below. 
-//If you are not sure, please refer to the SQL2011 column in -//http://www.postgresql.org/docs/9.5/static/sql-keywords-appendix.html -nonReserved - : - KW_ADD | KW_ADMIN | KW_AFTER | KW_ANALYZE | KW_ARCHIVE | KW_ASC | KW_BEFORE | KW_BUCKET | KW_BUCKETS - | KW_CASCADE | KW_CHANGE | KW_CLUSTER | KW_CLUSTERED | KW_CLUSTERSTATUS | KW_COLLECTION | KW_COLUMNS - | KW_COMMENT | KW_COMPACT | KW_COMPACTIONS | KW_COMPUTE | KW_CONCATENATE | KW_CONTINUE | KW_DATA | KW_DAY - | KW_DATABASES | KW_DATETIME | KW_DBPROPERTIES | KW_DEFERRED | KW_DEFINED | KW_DELIMITED | KW_DEPENDENCY - | KW_DESC | KW_DIRECTORIES | KW_DIRECTORY | KW_DISABLE | KW_DISTRIBUTE | KW_ELEM_TYPE - | KW_ENABLE | KW_ESCAPED | KW_EXCLUSIVE | KW_EXPLAIN | KW_EXPORT | KW_FIELDS | KW_FILE | KW_FILEFORMAT - | KW_FIRST | KW_FORMAT | KW_FORMATTED | KW_FUNCTIONS | KW_HOLD_DDLTIME | KW_HOUR | KW_IDXPROPERTIES | KW_IGNORE - | KW_INDEX | KW_INDEXES | KW_INPATH | KW_INPUTDRIVER | KW_INPUTFORMAT | KW_ITEMS | KW_JAR - | KW_KEYS | KW_KEY_TYPE | KW_LIMIT | KW_LINES | KW_LOAD | KW_LOCATION | KW_LOCK | KW_LOCKS | KW_LOGICAL | KW_LONG - | KW_MAPJOIN | KW_MATERIALIZED | KW_METADATA | KW_MINUS | KW_MINUTE | KW_MONTH | KW_MSCK | KW_NOSCAN | KW_NO_DROP | KW_OFFLINE - | KW_OPTION | KW_OUTPUTDRIVER | KW_OUTPUTFORMAT | KW_OVERWRITE | KW_OWNER | KW_PARTITIONED | KW_PARTITIONS | KW_PLUS | KW_PRETTY - | KW_PRINCIPALS | KW_PROTECTION | KW_PURGE | KW_READ | KW_READONLY | KW_REBUILD | KW_RECORDREADER | KW_RECORDWRITER - | KW_RELOAD | KW_RENAME | KW_REPAIR | KW_REPLACE | KW_REPLICATION | KW_RESTRICT | KW_REWRITE - | KW_ROLE | KW_ROLES | KW_SCHEMA | KW_SCHEMAS | KW_SECOND | KW_SEMI | KW_SERDE | KW_SERDEPROPERTIES | KW_SERVER | KW_SETS | KW_SHARED - | KW_SHOW | KW_SHOW_DATABASE | KW_SKEWED | KW_SORT | KW_SORTED | KW_SSL | KW_STATISTICS | KW_STORED - | KW_STREAMTABLE | KW_STRING | KW_STRUCT | KW_TABLES | KW_TBLPROPERTIES | KW_TEMPORARY | KW_TERMINATED - | KW_TINYINT | KW_TOUCH | KW_TRANSACTIONS | KW_UNARCHIVE | KW_UNDO | KW_UNIONTYPE | KW_UNLOCK | KW_UNSET - | KW_UNSIGNED | KW_URI | KW_USE | KW_UTC | KW_UTCTIMESTAMP | KW_VALUE_TYPE | KW_VIEW | KW_WHILE | KW_YEAR - | KW_WORK - | KW_TRANSACTION - | KW_WRITE - | KW_ISOLATION - | KW_LEVEL - | KW_SNAPSHOT - | KW_AUTOCOMMIT - | KW_ANTI - | KW_WEEK | KW_MILLISECOND | KW_MICROSECOND - | KW_CLEAR | KW_LAZY | KW_CACHE | KW_UNCACHE | KW_DFS -; - -//The following SQL2011 reserved keywords are used as cast function name only, but not as identifiers. -sql11ReservedKeywordsUsedAsCastFunctionName - : - KW_BIGINT | KW_BINARY | KW_BOOLEAN | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_DATE | KW_DOUBLE | KW_FLOAT | KW_INT | KW_SMALLINT | KW_TIMESTAMP - ; - -//The following SQL2011 reserved keywords are used as identifiers in many q tests, they may be added back due to backward compatibility. -//We are planning to remove the following whole list after several releases. -//Thus, please do not change the following list unless you know what to do. 
-sql11ReservedKeywordsUsedAsIdentifier - : - KW_ALL | KW_ALTER | KW_ARRAY | KW_AS | KW_AUTHORIZATION | KW_BETWEEN | KW_BIGINT | KW_BINARY | KW_BOOLEAN - | KW_BOTH | KW_BY | KW_CREATE | KW_CUBE | KW_CURRENT_DATE | KW_CURRENT_TIMESTAMP | KW_CURSOR | KW_DATE | KW_DECIMAL | KW_DELETE | KW_DESCRIBE - | KW_DOUBLE | KW_DROP | KW_EXISTS | KW_EXTERNAL | KW_FALSE | KW_FETCH | KW_FLOAT | KW_FOR | KW_FULL | KW_GRANT - | KW_GROUP | KW_GROUPING | KW_IMPORT | KW_IN | KW_INNER | KW_INSERT | KW_INT | KW_INTERSECT | KW_INTO | KW_IS | KW_LATERAL - | KW_LEFT | KW_LIKE | KW_LOCAL | KW_NONE | KW_NULL | KW_OF | KW_ORDER | KW_OUT | KW_OUTER | KW_PARTITION - | KW_PERCENT | KW_PROCEDURE | KW_RANGE | KW_READS | KW_REVOKE | KW_RIGHT - | KW_ROLLUP | KW_ROW | KW_ROWS | KW_SET | KW_SMALLINT | KW_TABLE | KW_TIMESTAMP | KW_TO | KW_TRIGGER | KW_TRUE - | KW_TRUNCATE | KW_UNION | KW_UPDATE | KW_USER | KW_USING | KW_VALUES | KW_WITH -//The following two keywords come from MySQL. Although they are not keywords in SQL2011, they are reserved keywords in MySQL. - | KW_REGEXP | KW_RLIKE - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g deleted file mode 100644 index f18b6ec496..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SelectClauseParser.g +++ /dev/null @@ -1,235 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/SelectClauseParser.g grammar. -*/ -parser grammar SelectClauseParser; - -options -{ -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} - -@members { - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - @Override - public void displayRecognitionError(String[] tokenNames, - RecognitionException e) { - gParent.displayRecognitionError(tokenNames, e); - } - protected boolean useSQL11ReservedKeywordsForIdentifier() { - return gParent.useSQL11ReservedKeywordsForIdentifier(); - } -} - -@rulecatch { -catch (RecognitionException e) { - throw e; -} -} - -//----------------------- Rules for parsing selectClause ----------------------------- -// select a,b,c ... -selectClause -@init { gParent.pushMsg("select clause", state); } -@after { gParent.popMsg(state); } - : - KW_SELECT hintClause? (((KW_ALL | dist=KW_DISTINCT)? selectList) - | (transform=KW_TRANSFORM selectTrfmClause)) - -> {$transform == null && $dist == null}? ^(TOK_SELECT hintClause? selectList) - -> {$transform == null && $dist != null}? ^(TOK_SELECTDI hintClause? selectList) - -> ^(TOK_SELECT hintClause? 
^(TOK_SELEXPR selectTrfmClause) ) - | - trfmClause ->^(TOK_SELECT ^(TOK_SELEXPR trfmClause)) - ; - -selectList -@init { gParent.pushMsg("select list", state); } -@after { gParent.popMsg(state); } - : - selectItem ( COMMA selectItem )* -> selectItem+ - ; - -selectTrfmClause -@init { gParent.pushMsg("transform clause", state); } -@after { gParent.popMsg(state); } - : - LPAREN selectExpressionList RPAREN - inSerde=rowFormat inRec=recordWriter - KW_USING StringLiteral - ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? - outSerde=rowFormat outRec=recordReader - -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) - ; - -hintClause -@init { gParent.pushMsg("hint clause", state); } -@after { gParent.popMsg(state); } - : - DIVIDE STAR PLUS hintList STAR DIVIDE -> ^(TOK_HINTLIST hintList) - ; - -hintList -@init { gParent.pushMsg("hint list", state); } -@after { gParent.popMsg(state); } - : - hintItem (COMMA hintItem)* -> hintItem+ - ; - -hintItem -@init { gParent.pushMsg("hint item", state); } -@after { gParent.popMsg(state); } - : - hintName (LPAREN hintArgs RPAREN)? -> ^(TOK_HINT hintName hintArgs?) - ; - -hintName -@init { gParent.pushMsg("hint name", state); } -@after { gParent.popMsg(state); } - : - KW_MAPJOIN -> TOK_MAPJOIN - | KW_STREAMTABLE -> TOK_STREAMTABLE - ; - -hintArgs -@init { gParent.pushMsg("hint arguments", state); } -@after { gParent.popMsg(state); } - : - hintArgName (COMMA hintArgName)* -> ^(TOK_HINTARGLIST hintArgName+) - ; - -hintArgName -@init { gParent.pushMsg("hint argument name", state); } -@after { gParent.popMsg(state); } - : - identifier - ; - -selectItem -@init { gParent.pushMsg("selection target", state); } -@after { gParent.popMsg(state); } - : - (tableAllColumns) => tableAllColumns -> ^(TOK_SELEXPR tableAllColumns) - | - namedExpression - ; - -namedExpression -@init { gParent.pushMsg("select named expression", state); } -@after { gParent.popMsg(state); } - : - ( expression - ((KW_AS? identifier) | (KW_AS LPAREN identifier (COMMA identifier)* RPAREN))? - ) -> ^(TOK_SELEXPR expression identifier*) - ; - -trfmClause -@init { gParent.pushMsg("transform clause", state); } -@after { gParent.popMsg(state); } - : - ( KW_MAP selectExpressionList - | KW_REDUCE selectExpressionList ) - inSerde=rowFormat inRec=recordWriter - KW_USING StringLiteral - ( KW_AS ((LPAREN (aliasList | columnNameTypeList) RPAREN) | (aliasList | columnNameTypeList)))? - outSerde=rowFormat outRec=recordReader - -> ^(TOK_TRANSFORM selectExpressionList $inSerde $inRec StringLiteral $outSerde $outRec aliasList? columnNameTypeList?) 
- ; - -selectExpression -@init { gParent.pushMsg("select expression", state); } -@after { gParent.popMsg(state); } - : - (tableAllColumns) => tableAllColumns - | - expression - ; - -selectExpressionList -@init { gParent.pushMsg("select expression list", state); } -@after { gParent.popMsg(state); } - : - selectExpression (COMMA selectExpression)* -> ^(TOK_EXPLIST selectExpression+) - ; - -//---------------------- Rules for windowing clauses ------------------------------- -window_clause -@init { gParent.pushMsg("window_clause", state); } -@after { gParent.popMsg(state); } -: - KW_WINDOW window_defn (COMMA window_defn)* -> ^(KW_WINDOW window_defn+) -; - -window_defn -@init { gParent.pushMsg("window_defn", state); } -@after { gParent.popMsg(state); } -: - Identifier KW_AS window_specification -> ^(TOK_WINDOWDEF Identifier window_specification) -; - -window_specification -@init { gParent.pushMsg("window_specification", state); } -@after { gParent.popMsg(state); } -: - (Identifier | ( LPAREN Identifier? partitioningSpec? window_frame? RPAREN)) -> ^(TOK_WINDOWSPEC Identifier? partitioningSpec? window_frame?) -; - -window_frame : - window_range_expression | - window_value_expression -; - -window_range_expression -@init { gParent.pushMsg("window_range_expression", state); } -@after { gParent.popMsg(state); } -: - KW_ROWS sb=window_frame_start_boundary -> ^(TOK_WINDOWRANGE $sb) | - KW_ROWS KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWRANGE $s $end) -; - -window_value_expression -@init { gParent.pushMsg("window_value_expression", state); } -@after { gParent.popMsg(state); } -: - KW_RANGE sb=window_frame_start_boundary -> ^(TOK_WINDOWVALUES $sb) | - KW_RANGE KW_BETWEEN s=window_frame_boundary KW_AND end=window_frame_boundary -> ^(TOK_WINDOWVALUES $s $end) -; - -window_frame_start_boundary -@init { gParent.pushMsg("windowframestartboundary", state); } -@after { gParent.popMsg(state); } -: - KW_UNBOUNDED KW_PRECEDING -> ^(KW_PRECEDING KW_UNBOUNDED) | - KW_CURRENT KW_ROW -> ^(KW_CURRENT) | - Number KW_PRECEDING -> ^(KW_PRECEDING Number) -; - -window_frame_boundary -@init { gParent.pushMsg("windowframeboundary", state); } -@after { gParent.popMsg(state); } -: - KW_UNBOUNDED (r=KW_PRECEDING|r=KW_FOLLOWING) -> ^($r KW_UNBOUNDED) | - KW_CURRENT KW_ROW -> ^(KW_CURRENT) | - Number (d=KW_PRECEDING | d=KW_FOLLOWING ) -> ^($d Number) -; - diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g deleted file mode 100644 index fd1ad59207..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlLexer.g +++ /dev/null @@ -1,491 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. 
- - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveLexer.g grammar. -*/ -lexer grammar SparkSqlLexer; - -@lexer::header { -package org.apache.spark.sql.catalyst.parser; - -} - -@lexer::members { - private ParserConf parserConf; - private ParseErrorReporter reporter; - - public void configure(ParserConf parserConf, ParseErrorReporter reporter) { - this.parserConf = parserConf; - this.reporter = reporter; - } - - protected boolean allowQuotedId() { - if (parserConf == null) { - return true; - } - return parserConf.supportQuotedId(); - } - - @Override - public void displayRecognitionError(String[] tokenNames, RecognitionException e) { - if (reporter != null) { - reporter.report(this, e, tokenNames); - } - } -} - -// Keywords - -KW_TRUE : 'TRUE'; -KW_FALSE : 'FALSE'; -KW_ALL : 'ALL'; -KW_NONE: 'NONE'; -KW_AND : 'AND'; -KW_OR : 'OR'; -KW_NOT : 'NOT' | '!'; -KW_LIKE : 'LIKE'; - -KW_IF : 'IF'; -KW_EXISTS : 'EXISTS'; - -KW_ASC : 'ASC'; -KW_DESC : 'DESC'; -KW_ORDER : 'ORDER'; -KW_GROUP : 'GROUP'; -KW_BY : 'BY'; -KW_HAVING : 'HAVING'; -KW_WHERE : 'WHERE'; -KW_FROM : 'FROM'; -KW_AS : 'AS'; -KW_SELECT : 'SELECT'; -KW_DISTINCT : 'DISTINCT'; -KW_INSERT : 'INSERT'; -KW_OVERWRITE : 'OVERWRITE'; -KW_OUTER : 'OUTER'; -KW_UNIQUEJOIN : 'UNIQUEJOIN'; -KW_PRESERVE : 'PRESERVE'; -KW_JOIN : 'JOIN'; -KW_LEFT : 'LEFT'; -KW_RIGHT : 'RIGHT'; -KW_FULL : 'FULL'; -KW_ANTI : 'ANTI'; -KW_ON : 'ON'; -KW_PARTITION : 'PARTITION'; -KW_PARTITIONS : 'PARTITIONS'; -KW_TABLE: 'TABLE'; -KW_TABLES: 'TABLES'; -KW_COLUMNS: 'COLUMNS'; -KW_INDEX: 'INDEX'; -KW_INDEXES: 'INDEXES'; -KW_REBUILD: 'REBUILD'; -KW_FUNCTIONS: 'FUNCTIONS'; -KW_SHOW: 'SHOW'; -KW_MSCK: 'MSCK'; -KW_REPAIR: 'REPAIR'; -KW_DIRECTORY: 'DIRECTORY'; -KW_LOCAL: 'LOCAL'; -KW_TRANSFORM : 'TRANSFORM'; -KW_USING: 'USING'; -KW_CLUSTER: 'CLUSTER'; -KW_DISTRIBUTE: 'DISTRIBUTE'; -KW_SORT: 'SORT'; -KW_UNION: 'UNION'; -KW_EXCEPT: 'EXCEPT'; -KW_LOAD: 'LOAD'; -KW_EXPORT: 'EXPORT'; -KW_IMPORT: 'IMPORT'; -KW_REPLICATION: 'REPLICATION'; -KW_METADATA: 'METADATA'; -KW_DATA: 'DATA'; -KW_INPATH: 'INPATH'; -KW_IS: 'IS'; -KW_NULL: 'NULL'; -KW_CREATE: 'CREATE'; -KW_EXTERNAL: 'EXTERNAL'; -KW_ALTER: 'ALTER'; -KW_CHANGE: 'CHANGE'; -KW_COLUMN: 'COLUMN'; -KW_FIRST: 'FIRST'; -KW_AFTER: 'AFTER'; -KW_DESCRIBE: 'DESCRIBE'; -KW_DROP: 'DROP'; -KW_RENAME: 'RENAME'; -KW_TO: 'TO'; -KW_COMMENT: 'COMMENT'; -KW_BOOLEAN: 'BOOLEAN'; -KW_TINYINT: 'TINYINT'; -KW_SMALLINT: 'SMALLINT'; -KW_INT: 'INT'; -KW_BIGINT: 'BIGINT'; -KW_FLOAT: 'FLOAT'; -KW_DOUBLE: 'DOUBLE'; -KW_DATE: 'DATE'; -KW_DATETIME: 'DATETIME'; -KW_TIMESTAMP: 'TIMESTAMP'; -KW_INTERVAL: 'INTERVAL'; -KW_DECIMAL: 'DECIMAL'; -KW_STRING: 'STRING'; -KW_CHAR: 'CHAR'; -KW_VARCHAR: 'VARCHAR'; -KW_ARRAY: 'ARRAY'; -KW_STRUCT: 'STRUCT'; -KW_MAP: 'MAP'; -KW_UNIONTYPE: 'UNIONTYPE'; -KW_REDUCE: 'REDUCE'; -KW_PARTITIONED: 'PARTITIONED'; -KW_CLUSTERED: 'CLUSTERED'; -KW_SORTED: 'SORTED'; -KW_INTO: 'INTO'; -KW_BUCKETS: 'BUCKETS'; -KW_ROW: 'ROW'; -KW_ROWS: 'ROWS'; -KW_FORMAT: 'FORMAT'; -KW_DELIMITED: 'DELIMITED'; -KW_FIELDS: 'FIELDS'; -KW_TERMINATED: 'TERMINATED'; -KW_ESCAPED: 'ESCAPED'; -KW_COLLECTION: 'COLLECTION'; -KW_ITEMS: 'ITEMS'; -KW_KEYS: 'KEYS'; -KW_KEY_TYPE: '$KEY$'; -KW_LINES: 'LINES'; -KW_STORED: 'STORED'; -KW_FILEFORMAT: 'FILEFORMAT'; -KW_INPUTFORMAT: 'INPUTFORMAT'; -KW_OUTPUTFORMAT: 'OUTPUTFORMAT'; -KW_INPUTDRIVER: 'INPUTDRIVER'; -KW_OUTPUTDRIVER: 'OUTPUTDRIVER'; -KW_ENABLE: 'ENABLE'; -KW_DISABLE: 'DISABLE'; -KW_LOCATION: 'LOCATION'; -KW_TABLESAMPLE: 'TABLESAMPLE'; -KW_BUCKET: 'BUCKET'; -KW_OUT: 'OUT'; -KW_OF: 'OF'; -KW_PERCENT: 'PERCENT'; 
-KW_CAST: 'CAST'; -KW_ADD: 'ADD'; -KW_REPLACE: 'REPLACE'; -KW_RLIKE: 'RLIKE'; -KW_REGEXP: 'REGEXP'; -KW_TEMPORARY: 'TEMPORARY'; -KW_FUNCTION: 'FUNCTION'; -KW_MACRO: 'MACRO'; -KW_FILE: 'FILE'; -KW_JAR: 'JAR'; -KW_EXPLAIN: 'EXPLAIN'; -KW_EXTENDED: 'EXTENDED'; -KW_FORMATTED: 'FORMATTED'; -KW_PRETTY: 'PRETTY'; -KW_DEPENDENCY: 'DEPENDENCY'; -KW_LOGICAL: 'LOGICAL'; -KW_SERDE: 'SERDE'; -KW_WITH: 'WITH'; -KW_DEFERRED: 'DEFERRED'; -KW_SERDEPROPERTIES: 'SERDEPROPERTIES'; -KW_DBPROPERTIES: 'DBPROPERTIES'; -KW_LIMIT: 'LIMIT'; -KW_SET: 'SET'; -KW_UNSET: 'UNSET'; -KW_TBLPROPERTIES: 'TBLPROPERTIES'; -KW_IDXPROPERTIES: 'IDXPROPERTIES'; -KW_VALUE_TYPE: '$VALUE$'; -KW_ELEM_TYPE: '$ELEM$'; -KW_DEFINED: 'DEFINED'; -KW_CASE: 'CASE'; -KW_WHEN: 'WHEN'; -KW_THEN: 'THEN'; -KW_ELSE: 'ELSE'; -KW_END: 'END'; -KW_MAPJOIN: 'MAPJOIN'; -KW_STREAMTABLE: 'STREAMTABLE'; -KW_CLUSTERSTATUS: 'CLUSTERSTATUS'; -KW_UTC: 'UTC'; -KW_UTCTIMESTAMP: 'UTC_TMESTAMP'; -KW_LONG: 'LONG'; -KW_DELETE: 'DELETE'; -KW_PLUS: 'PLUS'; -KW_MINUS: 'MINUS'; -KW_FETCH: 'FETCH'; -KW_INTERSECT: 'INTERSECT'; -KW_VIEW: 'VIEW'; -KW_IN: 'IN'; -KW_DATABASE: 'DATABASE'; -KW_DATABASES: 'DATABASES'; -KW_MATERIALIZED: 'MATERIALIZED'; -KW_SCHEMA: 'SCHEMA'; -KW_SCHEMAS: 'SCHEMAS'; -KW_GRANT: 'GRANT'; -KW_REVOKE: 'REVOKE'; -KW_SSL: 'SSL'; -KW_UNDO: 'UNDO'; -KW_LOCK: 'LOCK'; -KW_LOCKS: 'LOCKS'; -KW_UNLOCK: 'UNLOCK'; -KW_SHARED: 'SHARED'; -KW_EXCLUSIVE: 'EXCLUSIVE'; -KW_PROCEDURE: 'PROCEDURE'; -KW_UNSIGNED: 'UNSIGNED'; -KW_WHILE: 'WHILE'; -KW_READ: 'READ'; -KW_READS: 'READS'; -KW_PURGE: 'PURGE'; -KW_RANGE: 'RANGE'; -KW_ANALYZE: 'ANALYZE'; -KW_BEFORE: 'BEFORE'; -KW_BETWEEN: 'BETWEEN'; -KW_BOTH: 'BOTH'; -KW_BINARY: 'BINARY'; -KW_CROSS: 'CROSS'; -KW_CONTINUE: 'CONTINUE'; -KW_CURSOR: 'CURSOR'; -KW_TRIGGER: 'TRIGGER'; -KW_RECORDREADER: 'RECORDREADER'; -KW_RECORDWRITER: 'RECORDWRITER'; -KW_SEMI: 'SEMI'; -KW_LATERAL: 'LATERAL'; -KW_TOUCH: 'TOUCH'; -KW_ARCHIVE: 'ARCHIVE'; -KW_UNARCHIVE: 'UNARCHIVE'; -KW_COMPUTE: 'COMPUTE'; -KW_STATISTICS: 'STATISTICS'; -KW_USE: 'USE'; -KW_OPTION: 'OPTION'; -KW_CONCATENATE: 'CONCATENATE'; -KW_SHOW_DATABASE: 'SHOW_DATABASE'; -KW_UPDATE: 'UPDATE'; -KW_RESTRICT: 'RESTRICT'; -KW_CASCADE: 'CASCADE'; -KW_SKEWED: 'SKEWED'; -KW_ROLLUP: 'ROLLUP'; -KW_CUBE: 'CUBE'; -KW_DIRECTORIES: 'DIRECTORIES'; -KW_FOR: 'FOR'; -KW_WINDOW: 'WINDOW'; -KW_UNBOUNDED: 'UNBOUNDED'; -KW_PRECEDING: 'PRECEDING'; -KW_FOLLOWING: 'FOLLOWING'; -KW_CURRENT: 'CURRENT'; -KW_CURRENT_DATE: 'CURRENT_DATE'; -KW_CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; -KW_LESS: 'LESS'; -KW_MORE: 'MORE'; -KW_OVER: 'OVER'; -KW_GROUPING: 'GROUPING'; -KW_SETS: 'SETS'; -KW_TRUNCATE: 'TRUNCATE'; -KW_NOSCAN: 'NOSCAN'; -KW_PARTIALSCAN: 'PARTIALSCAN'; -KW_USER: 'USER'; -KW_ROLE: 'ROLE'; -KW_ROLES: 'ROLES'; -KW_INNER: 'INNER'; -KW_EXCHANGE: 'EXCHANGE'; -KW_URI: 'URI'; -KW_SERVER : 'SERVER'; -KW_ADMIN: 'ADMIN'; -KW_OWNER: 'OWNER'; -KW_PRINCIPALS: 'PRINCIPALS'; -KW_COMPACT: 'COMPACT'; -KW_COMPACTIONS: 'COMPACTIONS'; -KW_TRANSACTIONS: 'TRANSACTIONS'; -KW_REWRITE : 'REWRITE'; -KW_AUTHORIZATION: 'AUTHORIZATION'; -KW_CONF: 'CONF'; -KW_VALUES: 'VALUES'; -KW_RELOAD: 'RELOAD'; -KW_YEAR: 'YEAR'|'YEARS'; -KW_MONTH: 'MONTH'|'MONTHS'; -KW_DAY: 'DAY'|'DAYS'; -KW_HOUR: 'HOUR'|'HOURS'; -KW_MINUTE: 'MINUTE'|'MINUTES'; -KW_SECOND: 'SECOND'|'SECONDS'; -KW_START: 'START'; -KW_TRANSACTION: 'TRANSACTION'; -KW_COMMIT: 'COMMIT'; -KW_ROLLBACK: 'ROLLBACK'; -KW_WORK: 'WORK'; -KW_ONLY: 'ONLY'; -KW_WRITE: 'WRITE'; -KW_ISOLATION: 'ISOLATION'; -KW_LEVEL: 'LEVEL'; -KW_SNAPSHOT: 'SNAPSHOT'; -KW_AUTOCOMMIT: 'AUTOCOMMIT'; -KW_REFRESH: 'REFRESH'; 
-KW_OPTIONS: 'OPTIONS'; -KW_WEEK: 'WEEK'|'WEEKS'; -KW_MILLISECOND: 'MILLISECOND'|'MILLISECONDS'; -KW_MICROSECOND: 'MICROSECOND'|'MICROSECONDS'; -KW_CLEAR: 'CLEAR'; -KW_LAZY: 'LAZY'; -KW_CACHE: 'CACHE'; -KW_UNCACHE: 'UNCACHE'; -KW_DFS: 'DFS'; - -KW_NATURAL: 'NATURAL'; - -// Operators -// NOTE: if you add a new function/operator, add it to sysFuncNames so that describe function _FUNC_ will work. - -DOT : '.'; // generated as a part of Number rule -COLON : ':' ; -COMMA : ',' ; -SEMICOLON : ';' ; - -LPAREN : '(' ; -RPAREN : ')' ; -LSQUARE : '[' ; -RSQUARE : ']' ; -LCURLY : '{'; -RCURLY : '}'; - -EQUAL : '=' | '=='; -EQUAL_NS : '<=>'; -NOTEQUAL : '<>' | '!='; -LESSTHANOREQUALTO : '<='; -LESSTHAN : '<'; -GREATERTHANOREQUALTO : '>='; -GREATERTHAN : '>'; - -DIVIDE : '/'; -PLUS : '+'; -MINUS : '-'; -STAR : '*'; -MOD : '%'; -DIV : 'DIV'; - -AMPERSAND : '&'; -TILDE : '~'; -BITWISEOR : '|'; -BITWISEXOR : '^'; -QUESTION : '?'; -DOLLAR : '$'; - -// LITERALS -fragment -Letter - : 'a'..'z' | 'A'..'Z' - ; - -fragment -HexDigit - : 'a'..'f' | 'A'..'F' - ; - -fragment -Digit - : - '0'..'9' - ; - -fragment -Exponent - : - ('e' | 'E') ( PLUS|MINUS )? (Digit)+ - ; - -fragment -RegexComponent - : 'a'..'z' | 'A'..'Z' | '0'..'9' | '_' - | PLUS | STAR | QUESTION | MINUS | DOT - | LPAREN | RPAREN | LSQUARE | RSQUARE | LCURLY | RCURLY - | BITWISEXOR | BITWISEOR | DOLLAR | '!' - ; - -StringLiteral - : - ( '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' - | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' - )+ - ; - -BigintLiteral - : - (Digit)+ 'L' - ; - -SmallintLiteral - : - (Digit)+ 'S' - ; - -TinyintLiteral - : - (Digit)+ 'Y' - ; - -DoubleLiteral - : - Number 'D' - ; - -ByteLengthLiteral - : - (Digit)+ ('b' | 'B' | 'k' | 'K' | 'm' | 'M' | 'g' | 'G') - ; - -Number - : - ((Digit+ (DOT Digit*)?) | (DOT Digit+)) Exponent? - ; - -/* -An Identifier can be: -- tableName -- columnName -- select expr alias -- lateral view aliases -- database name -- view name -- subquery alias -- function name -- ptf argument identifier -- index name -- property name for: db,tbl,partition... -- fileFormat -- role name -- privilege name -- principal name -- macro name -- hint name -- window name -*/ -Identifier - : - (Letter | Digit | '_')+ - | {allowQuotedId()}? QuotedIdentifier /* though at the language level we allow all Identifiers to be QuotedIdentifiers; - at the API level only columns are allowed to be of this form */ - | '`' RegexComponent+ '`' - ; - -fragment -QuotedIdentifier - : - '`' ( '``' | ~('`') )* '`' { setText(getText().replaceAll("``", "`")); } - ; - -WS : (' '|'\r'|'\t'|'\n') {$channel=HIDDEN;} - ; - -COMMENT - : '--' (~('\n'|'\r'))* - { $channel=HIDDEN; } - ; - -/* Prevent that the lexer swallows unknown characters. */ -ANY - :. - ; diff --git a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g b/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g deleted file mode 100644 index f0c236859d..0000000000 --- a/sql/catalyst/src/main/antlr3/org/apache/spark/sql/catalyst/parser/SparkSqlParser.g +++ /dev/null @@ -1,2596 +0,0 @@ -/** - Licensed to the Apache Software Foundation (ASF) under one or more - contributor license agreements. See the NOTICE file distributed with - this work for additional information regarding copyright ownership. - The ASF licenses this file to You under the Apache License, Version 2.0 - (the "License"); you may not use this file except in compliance with - the License. 
You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - This file is an adaptation of Hive's org/apache/hadoop/hive/ql/HiveParser.g grammar. -*/ -parser grammar SparkSqlParser; - -options -{ -tokenVocab=SparkSqlLexer; -output=AST; -ASTLabelType=CommonTree; -backtrack=false; -k=3; -} -import SelectClauseParser, FromClauseParser, IdentifiersParser, KeywordParser, ExpressionParser; - -tokens { -TOK_INSERT; -TOK_QUERY; -TOK_SELECT; -TOK_SELECTDI; -TOK_SELEXPR; -TOK_FROM; -TOK_TAB; -TOK_PARTSPEC; -TOK_PARTVAL; -TOK_DIR; -TOK_TABREF; -TOK_SUBQUERY; -TOK_INSERT_INTO; -TOK_DESTINATION; -TOK_ALLCOLREF; -TOK_TABLE_OR_COL; -TOK_FUNCTION; -TOK_FUNCTIONDI; -TOK_FUNCTIONSTAR; -TOK_WHERE; -TOK_OP_EQ; -TOK_OP_NE; -TOK_OP_LE; -TOK_OP_LT; -TOK_OP_GE; -TOK_OP_GT; -TOK_OP_DIV; -TOK_OP_ADD; -TOK_OP_SUB; -TOK_OP_MUL; -TOK_OP_MOD; -TOK_OP_BITAND; -TOK_OP_BITNOT; -TOK_OP_BITOR; -TOK_OP_BITXOR; -TOK_OP_AND; -TOK_OP_OR; -TOK_OP_NOT; -TOK_OP_LIKE; -TOK_TRUE; -TOK_FALSE; -TOK_TRANSFORM; -TOK_SERDE; -TOK_SERDENAME; -TOK_SERDEPROPS; -TOK_EXPLIST; -TOK_ALIASLIST; -TOK_GROUPBY; -TOK_ROLLUP_GROUPBY; -TOK_CUBE_GROUPBY; -TOK_GROUPING_SETS; -TOK_GROUPING_SETS_EXPRESSION; -TOK_HAVING; -TOK_ORDERBY; -TOK_CLUSTERBY; -TOK_DISTRIBUTEBY; -TOK_SORTBY; -TOK_UNIONALL; -TOK_UNIONDISTINCT; -TOK_EXCEPT; -TOK_INTERSECT; -TOK_JOIN; -TOK_LEFTOUTERJOIN; -TOK_RIGHTOUTERJOIN; -TOK_FULLOUTERJOIN; -TOK_UNIQUEJOIN; -TOK_CROSSJOIN; -TOK_NATURALJOIN; -TOK_NATURALLEFTOUTERJOIN; -TOK_NATURALRIGHTOUTERJOIN; -TOK_NATURALFULLOUTERJOIN; -TOK_LOAD; -TOK_EXPORT; -TOK_IMPORT; -TOK_REPLICATION; -TOK_METADATA; -TOK_NULL; -TOK_ISNULL; -TOK_ISNOTNULL; -TOK_TINYINT; -TOK_SMALLINT; -TOK_INT; -TOK_BIGINT; -TOK_BOOLEAN; -TOK_FLOAT; -TOK_DOUBLE; -TOK_DATE; -TOK_DATELITERAL; -TOK_DATETIME; -TOK_TIMESTAMP; -TOK_TIMESTAMPLITERAL; -TOK_INTERVAL; -TOK_INTERVAL_YEAR_MONTH; -TOK_INTERVAL_YEAR_MONTH_LITERAL; -TOK_INTERVAL_DAY_TIME; -TOK_INTERVAL_DAY_TIME_LITERAL; -TOK_INTERVAL_YEAR_LITERAL; -TOK_INTERVAL_MONTH_LITERAL; -TOK_INTERVAL_WEEK_LITERAL; -TOK_INTERVAL_DAY_LITERAL; -TOK_INTERVAL_HOUR_LITERAL; -TOK_INTERVAL_MINUTE_LITERAL; -TOK_INTERVAL_SECOND_LITERAL; -TOK_INTERVAL_MILLISECOND_LITERAL; -TOK_INTERVAL_MICROSECOND_LITERAL; -TOK_STRING; -TOK_CHAR; -TOK_VARCHAR; -TOK_BINARY; -TOK_DECIMAL; -TOK_LIST; -TOK_STRUCT; -TOK_MAP; -TOK_UNIONTYPE; -TOK_COLTYPELIST; -TOK_CREATEDATABASE; -TOK_CREATETABLE; -TOK_CREATETABLEUSING; -TOK_TRUNCATETABLE; -TOK_CREATEINDEX; -TOK_CREATEINDEX_INDEXTBLNAME; -TOK_DEFERRED_REBUILDINDEX; -TOK_DROPINDEX; -TOK_LIKETABLE; -TOK_DESCTABLE; -TOK_DESCFUNCTION; -TOK_ALTERTABLE; -TOK_ALTERTABLE_RENAME; -TOK_ALTERTABLE_ADDCOLS; -TOK_ALTERTABLE_RENAMECOL; -TOK_ALTERTABLE_RENAMEPART; -TOK_ALTERTABLE_REPLACECOLS; -TOK_ALTERTABLE_ADDPARTS; -TOK_ALTERTABLE_DROPPARTS; -TOK_ALTERTABLE_PARTCOLTYPE; -TOK_ALTERTABLE_MERGEFILES; -TOK_ALTERTABLE_TOUCH; -TOK_ALTERTABLE_ARCHIVE; -TOK_ALTERTABLE_UNARCHIVE; -TOK_ALTERTABLE_SERDEPROPERTIES; -TOK_ALTERTABLE_SERIALIZER; -TOK_ALTERTABLE_UPDATECOLSTATS; -TOK_TABLE_PARTITION; -TOK_ALTERTABLE_FILEFORMAT; -TOK_ALTERTABLE_LOCATION; -TOK_ALTERTABLE_PROPERTIES; -TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION; -TOK_ALTERTABLE_DROPPROPERTIES; -TOK_ALTERTABLE_SKEWED; 
-TOK_ALTERTABLE_EXCHANGEPARTITION; -TOK_ALTERTABLE_SKEWED_LOCATION; -TOK_ALTERTABLE_BUCKETS; -TOK_ALTERTABLE_CLUSTER_SORT; -TOK_ALTERTABLE_COMPACT; -TOK_ALTERINDEX_REBUILD; -TOK_ALTERINDEX_PROPERTIES; -TOK_MSCK; -TOK_SHOWDATABASES; -TOK_SHOWTABLES; -TOK_SHOWCOLUMNS; -TOK_SHOWFUNCTIONS; -TOK_SHOWPARTITIONS; -TOK_SHOW_CREATEDATABASE; -TOK_SHOW_CREATETABLE; -TOK_SHOW_TABLESTATUS; -TOK_SHOW_TBLPROPERTIES; -TOK_SHOWLOCKS; -TOK_SHOWCONF; -TOK_LOCKTABLE; -TOK_UNLOCKTABLE; -TOK_LOCKDB; -TOK_UNLOCKDB; -TOK_SWITCHDATABASE; -TOK_DROPDATABASE; -TOK_DROPTABLE; -TOK_DATABASECOMMENT; -TOK_TABCOLLIST; -TOK_TABCOL; -TOK_TABLECOMMENT; -TOK_TABLEPARTCOLS; -TOK_TABLEROWFORMAT; -TOK_TABLEROWFORMATFIELD; -TOK_TABLEROWFORMATCOLLITEMS; -TOK_TABLEROWFORMATMAPKEYS; -TOK_TABLEROWFORMATLINES; -TOK_TABLEROWFORMATNULL; -TOK_TABLEFILEFORMAT; -TOK_FILEFORMAT_GENERIC; -TOK_OFFLINE; -TOK_ENABLE; -TOK_DISABLE; -TOK_READONLY; -TOK_NO_DROP; -TOK_STORAGEHANDLER; -TOK_NOT_CLUSTERED; -TOK_NOT_SORTED; -TOK_TABCOLNAME; -TOK_TABLELOCATION; -TOK_PARTITIONLOCATION; -TOK_TABLEBUCKETSAMPLE; -TOK_TABLESPLITSAMPLE; -TOK_PERCENT; -TOK_LENGTH; -TOK_ROWCOUNT; -TOK_TMP_FILE; -TOK_TABSORTCOLNAMEASC; -TOK_TABSORTCOLNAMEDESC; -TOK_STRINGLITERALSEQUENCE; -TOK_CREATEFUNCTION; -TOK_DROPFUNCTION; -TOK_RELOADFUNCTION; -TOK_CREATEMACRO; -TOK_DROPMACRO; -TOK_TEMPORARY; -TOK_CREATEVIEW; -TOK_DROPVIEW; -TOK_ALTERVIEW; -TOK_ALTERVIEW_PROPERTIES; -TOK_ALTERVIEW_DROPPROPERTIES; -TOK_ALTERVIEW_ADDPARTS; -TOK_ALTERVIEW_DROPPARTS; -TOK_ALTERVIEW_RENAME; -TOK_VIEWPARTCOLS; -TOK_EXPLAIN; -TOK_EXPLAIN_SQ_REWRITE; -TOK_TABLESERIALIZER; -TOK_TABLEPROPERTIES; -TOK_TABLEPROPLIST; -TOK_INDEXPROPERTIES; -TOK_INDEXPROPLIST; -TOK_TABTYPE; -TOK_LIMIT; -TOK_TABLEPROPERTY; -TOK_IFEXISTS; -TOK_IFNOTEXISTS; -TOK_ORREPLACE; -TOK_HINTLIST; -TOK_HINT; -TOK_MAPJOIN; -TOK_STREAMTABLE; -TOK_HINTARGLIST; -TOK_USERSCRIPTCOLNAMES; -TOK_USERSCRIPTCOLSCHEMA; -TOK_RECORDREADER; -TOK_RECORDWRITER; -TOK_LEFTSEMIJOIN; -TOK_ANTIJOIN; -TOK_LATERAL_VIEW; -TOK_LATERAL_VIEW_OUTER; -TOK_TABALIAS; -TOK_ANALYZE; -TOK_CREATEROLE; -TOK_DROPROLE; -TOK_GRANT; -TOK_REVOKE; -TOK_SHOW_GRANT; -TOK_PRIVILEGE_LIST; -TOK_PRIVILEGE; -TOK_PRINCIPAL_NAME; -TOK_USER; -TOK_GROUP; -TOK_ROLE; -TOK_RESOURCE_ALL; -TOK_GRANT_WITH_OPTION; -TOK_GRANT_WITH_ADMIN_OPTION; -TOK_ADMIN_OPTION_FOR; -TOK_GRANT_OPTION_FOR; -TOK_PRIV_ALL; -TOK_PRIV_ALTER_METADATA; -TOK_PRIV_ALTER_DATA; -TOK_PRIV_DELETE; -TOK_PRIV_DROP; -TOK_PRIV_INDEX; -TOK_PRIV_INSERT; -TOK_PRIV_LOCK; -TOK_PRIV_SELECT; -TOK_PRIV_SHOW_DATABASE; -TOK_PRIV_CREATE; -TOK_PRIV_OBJECT; -TOK_PRIV_OBJECT_COL; -TOK_GRANT_ROLE; -TOK_REVOKE_ROLE; -TOK_SHOW_ROLE_GRANT; -TOK_SHOW_ROLES; -TOK_SHOW_SET_ROLE; -TOK_SHOW_ROLE_PRINCIPALS; -TOK_SHOWINDEXES; -TOK_SHOWDBLOCKS; -TOK_INDEXCOMMENT; -TOK_DESCDATABASE; -TOK_DATABASEPROPERTIES; -TOK_DATABASELOCATION; -TOK_DBPROPLIST; -TOK_ALTERDATABASE_PROPERTIES; -TOK_ALTERDATABASE_OWNER; -TOK_TABNAME; -TOK_TABSRC; -TOK_RESTRICT; -TOK_CASCADE; -TOK_TABLESKEWED; -TOK_TABCOLVALUE; -TOK_TABCOLVALUE_PAIR; -TOK_TABCOLVALUES; -TOK_SKEWED_LOCATIONS; -TOK_SKEWED_LOCATION_LIST; -TOK_SKEWED_LOCATION_MAP; -TOK_STOREDASDIRS; -TOK_PARTITIONINGSPEC; -TOK_PTBLFUNCTION; -TOK_WINDOWDEF; -TOK_WINDOWSPEC; -TOK_WINDOWVALUES; -TOK_WINDOWRANGE; -TOK_SUBQUERY_EXPR; -TOK_SUBQUERY_OP; -TOK_SUBQUERY_OP_NOTIN; -TOK_SUBQUERY_OP_NOTEXISTS; -TOK_DB_TYPE; -TOK_TABLE_TYPE; -TOK_CTE; -TOK_ARCHIVE; -TOK_FILE; -TOK_JAR; -TOK_RESOURCE_URI; -TOK_RESOURCE_LIST; -TOK_SHOW_COMPACTIONS; -TOK_SHOW_TRANSACTIONS; -TOK_DELETE_FROM; -TOK_UPDATE_TABLE; -TOK_SET_COLUMNS_CLAUSE; 
-TOK_VALUE_ROW; -TOK_VALUES_TABLE; -TOK_VIRTUAL_TABLE; -TOK_VIRTUAL_TABREF; -TOK_ANONYMOUS; -TOK_COL_NAME; -TOK_URI_TYPE; -TOK_SERVER_TYPE; -TOK_START_TRANSACTION; -TOK_ISOLATION_LEVEL; -TOK_ISOLATION_SNAPSHOT; -TOK_TXN_ACCESS_MODE; -TOK_TXN_READ_ONLY; -TOK_TXN_READ_WRITE; -TOK_COMMIT; -TOK_ROLLBACK; -TOK_SET_AUTOCOMMIT; -TOK_REFRESHTABLE; -TOK_TABLEPROVIDER; -TOK_TABLEOPTIONS; -TOK_TABLEOPTION; -TOK_CACHETABLE; -TOK_UNCACHETABLE; -TOK_CLEARCACHE; -TOK_SETCONFIG; -TOK_DFS; -TOK_ADDFILE; -TOK_ADDJAR; -TOK_USING; -} - - -// Package headers -@header { -package org.apache.spark.sql.catalyst.parser; - -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -} - - -@members { - Stack msgs = new Stack(); - - private static HashMap xlateMap; - static { - //this is used to support auto completion in CLI - xlateMap = new HashMap(); - - // Keywords - xlateMap.put("KW_TRUE", "TRUE"); - xlateMap.put("KW_FALSE", "FALSE"); - xlateMap.put("KW_ALL", "ALL"); - xlateMap.put("KW_NONE", "NONE"); - xlateMap.put("KW_AND", "AND"); - xlateMap.put("KW_OR", "OR"); - xlateMap.put("KW_NOT", "NOT"); - xlateMap.put("KW_LIKE", "LIKE"); - - xlateMap.put("KW_ASC", "ASC"); - xlateMap.put("KW_DESC", "DESC"); - xlateMap.put("KW_ORDER", "ORDER"); - xlateMap.put("KW_BY", "BY"); - xlateMap.put("KW_GROUP", "GROUP"); - xlateMap.put("KW_WHERE", "WHERE"); - xlateMap.put("KW_FROM", "FROM"); - xlateMap.put("KW_AS", "AS"); - xlateMap.put("KW_SELECT", "SELECT"); - xlateMap.put("KW_DISTINCT", "DISTINCT"); - xlateMap.put("KW_INSERT", "INSERT"); - xlateMap.put("KW_OVERWRITE", "OVERWRITE"); - xlateMap.put("KW_OUTER", "OUTER"); - xlateMap.put("KW_JOIN", "JOIN"); - xlateMap.put("KW_LEFT", "LEFT"); - xlateMap.put("KW_RIGHT", "RIGHT"); - xlateMap.put("KW_FULL", "FULL"); - xlateMap.put("KW_ON", "ON"); - xlateMap.put("KW_PARTITION", "PARTITION"); - xlateMap.put("KW_PARTITIONS", "PARTITIONS"); - xlateMap.put("KW_TABLE", "TABLE"); - xlateMap.put("KW_TABLES", "TABLES"); - xlateMap.put("KW_TBLPROPERTIES", "TBLPROPERTIES"); - xlateMap.put("KW_SHOW", "SHOW"); - xlateMap.put("KW_MSCK", "MSCK"); - xlateMap.put("KW_DIRECTORY", "DIRECTORY"); - xlateMap.put("KW_LOCAL", "LOCAL"); - xlateMap.put("KW_TRANSFORM", "TRANSFORM"); - xlateMap.put("KW_USING", "USING"); - xlateMap.put("KW_CLUSTER", "CLUSTER"); - xlateMap.put("KW_DISTRIBUTE", "DISTRIBUTE"); - xlateMap.put("KW_SORT", "SORT"); - xlateMap.put("KW_UNION", "UNION"); - xlateMap.put("KW_LOAD", "LOAD"); - xlateMap.put("KW_DATA", "DATA"); - xlateMap.put("KW_INPATH", "INPATH"); - xlateMap.put("KW_IS", "IS"); - xlateMap.put("KW_NULL", "NULL"); - xlateMap.put("KW_CREATE", "CREATE"); - xlateMap.put("KW_EXTERNAL", "EXTERNAL"); - xlateMap.put("KW_ALTER", "ALTER"); - xlateMap.put("KW_DESCRIBE", "DESCRIBE"); - xlateMap.put("KW_DROP", "DROP"); - xlateMap.put("KW_RENAME", "RENAME"); - xlateMap.put("KW_TO", "TO"); - xlateMap.put("KW_COMMENT", "COMMENT"); - xlateMap.put("KW_BOOLEAN", "BOOLEAN"); - xlateMap.put("KW_TINYINT", "TINYINT"); - xlateMap.put("KW_SMALLINT", "SMALLINT"); - xlateMap.put("KW_INT", "INT"); - xlateMap.put("KW_BIGINT", "BIGINT"); - xlateMap.put("KW_FLOAT", "FLOAT"); - xlateMap.put("KW_DOUBLE", "DOUBLE"); - xlateMap.put("KW_DATE", "DATE"); - xlateMap.put("KW_DATETIME", "DATETIME"); - xlateMap.put("KW_TIMESTAMP", "TIMESTAMP"); - xlateMap.put("KW_STRING", "STRING"); - xlateMap.put("KW_BINARY", "BINARY"); - xlateMap.put("KW_ARRAY", "ARRAY"); - xlateMap.put("KW_MAP", "MAP"); - xlateMap.put("KW_REDUCE", "REDUCE"); - xlateMap.put("KW_PARTITIONED", "PARTITIONED"); - 
xlateMap.put("KW_CLUSTERED", "CLUSTERED"); - xlateMap.put("KW_SORTED", "SORTED"); - xlateMap.put("KW_INTO", "INTO"); - xlateMap.put("KW_BUCKETS", "BUCKETS"); - xlateMap.put("KW_ROW", "ROW"); - xlateMap.put("KW_FORMAT", "FORMAT"); - xlateMap.put("KW_DELIMITED", "DELIMITED"); - xlateMap.put("KW_FIELDS", "FIELDS"); - xlateMap.put("KW_TERMINATED", "TERMINATED"); - xlateMap.put("KW_COLLECTION", "COLLECTION"); - xlateMap.put("KW_ITEMS", "ITEMS"); - xlateMap.put("KW_KEYS", "KEYS"); - xlateMap.put("KW_KEY_TYPE", "\$KEY\$"); - xlateMap.put("KW_LINES", "LINES"); - xlateMap.put("KW_STORED", "STORED"); - xlateMap.put("KW_SEQUENCEFILE", "SEQUENCEFILE"); - xlateMap.put("KW_TEXTFILE", "TEXTFILE"); - xlateMap.put("KW_INPUTFORMAT", "INPUTFORMAT"); - xlateMap.put("KW_OUTPUTFORMAT", "OUTPUTFORMAT"); - xlateMap.put("KW_LOCATION", "LOCATION"); - xlateMap.put("KW_TABLESAMPLE", "TABLESAMPLE"); - xlateMap.put("KW_BUCKET", "BUCKET"); - xlateMap.put("KW_OUT", "OUT"); - xlateMap.put("KW_OF", "OF"); - xlateMap.put("KW_CAST", "CAST"); - xlateMap.put("KW_ADD", "ADD"); - xlateMap.put("KW_REPLACE", "REPLACE"); - xlateMap.put("KW_COLUMNS", "COLUMNS"); - xlateMap.put("KW_RLIKE", "RLIKE"); - xlateMap.put("KW_REGEXP", "REGEXP"); - xlateMap.put("KW_TEMPORARY", "TEMPORARY"); - xlateMap.put("KW_FUNCTION", "FUNCTION"); - xlateMap.put("KW_EXPLAIN", "EXPLAIN"); - xlateMap.put("KW_EXTENDED", "EXTENDED"); - xlateMap.put("KW_SERDE", "SERDE"); - xlateMap.put("KW_WITH", "WITH"); - xlateMap.put("KW_SERDEPROPERTIES", "SERDEPROPERTIES"); - xlateMap.put("KW_LIMIT", "LIMIT"); - xlateMap.put("KW_SET", "SET"); - xlateMap.put("KW_PROPERTIES", "TBLPROPERTIES"); - xlateMap.put("KW_VALUE_TYPE", "\$VALUE\$"); - xlateMap.put("KW_ELEM_TYPE", "\$ELEM\$"); - xlateMap.put("KW_DEFINED", "DEFINED"); - xlateMap.put("KW_SUBQUERY", "SUBQUERY"); - xlateMap.put("KW_REWRITE", "REWRITE"); - xlateMap.put("KW_UPDATE", "UPDATE"); - xlateMap.put("KW_VALUES", "VALUES"); - xlateMap.put("KW_PURGE", "PURGE"); - xlateMap.put("KW_WEEK", "WEEK"); - xlateMap.put("KW_MILLISECOND", "MILLISECOND"); - xlateMap.put("KW_MICROSECOND", "MICROSECOND"); - xlateMap.put("KW_CLEAR", "CLEAR"); - xlateMap.put("KW_LAZY", "LAZY"); - xlateMap.put("KW_CACHE", "CACHE"); - xlateMap.put("KW_UNCACHE", "UNCACHE"); - xlateMap.put("KW_DFS", "DFS"); - - // Operators - xlateMap.put("DOT", "."); - xlateMap.put("COLON", ":"); - xlateMap.put("COMMA", ","); - xlateMap.put("SEMICOLON", ");"); - - xlateMap.put("LPAREN", "("); - xlateMap.put("RPAREN", ")"); - xlateMap.put("LSQUARE", "["); - xlateMap.put("RSQUARE", "]"); - - xlateMap.put("EQUAL", "="); - xlateMap.put("NOTEQUAL", "<>"); - xlateMap.put("EQUAL_NS", "<=>"); - xlateMap.put("LESSTHANOREQUALTO", "<="); - xlateMap.put("LESSTHAN", "<"); - xlateMap.put("GREATERTHANOREQUALTO", ">="); - xlateMap.put("GREATERTHAN", ">"); - - xlateMap.put("DIVIDE", "/"); - xlateMap.put("PLUS", "+"); - xlateMap.put("MINUS", "-"); - xlateMap.put("STAR", "*"); - xlateMap.put("MOD", "\%"); - - xlateMap.put("AMPERSAND", "&"); - xlateMap.put("TILDE", "~"); - xlateMap.put("BITWISEOR", "|"); - xlateMap.put("BITWISEXOR", "^"); - xlateMap.put("CharSetLiteral", "\\'"); - } - - public static Collection getKeywords() { - return xlateMap.values(); - } - - private static String xlate(String name) { - - String ret = xlateMap.get(name); - if (ret == null) { - ret = name; - } - - return ret; - } - - @Override - public Object recoverFromMismatchedSet(IntStream input, - RecognitionException re, BitSet follow) throws RecognitionException { - throw re; - } - - @Override - public void 
displayRecognitionError(String[] tokenNames, RecognitionException e) { - if (reporter != null) { - reporter.report(this, e, tokenNames); - } - } - - @Override - public String getErrorHeader(RecognitionException e) { - String header = null; - if (e.charPositionInLine < 0 && input.LT(-1) != null) { - Token t = input.LT(-1); - header = "line " + t.getLine() + ":" + t.getCharPositionInLine(); - } else { - header = super.getErrorHeader(e); - } - - return header; - } - - @Override - public String getErrorMessage(RecognitionException e, String[] tokenNames) { - String msg = null; - - // Translate the token names to something that the user can understand - String[] xlateNames = new String[tokenNames.length]; - for (int i = 0; i < tokenNames.length; ++i) { - xlateNames[i] = SparkSqlParser.xlate(tokenNames[i]); - } - - if (e instanceof NoViableAltException) { - @SuppressWarnings("unused") - NoViableAltException nvae = (NoViableAltException) e; - // for development, can add - // "decision=<<"+nvae.grammarDecisionDescription+">>" - // and "(decision="+nvae.decisionNumber+") and - // "state "+nvae.stateNumber - msg = "cannot recognize input near" - + (input.LT(1) != null ? " " + getTokenErrorDisplay(input.LT(1)) : "") - + (input.LT(2) != null ? " " + getTokenErrorDisplay(input.LT(2)) : "") - + (input.LT(3) != null ? " " + getTokenErrorDisplay(input.LT(3)) : ""); - } else if (e instanceof MismatchedTokenException) { - MismatchedTokenException mte = (MismatchedTokenException) e; - msg = super.getErrorMessage(e, xlateNames) + (input.LT(-1) == null ? "":" near '" + input.LT(-1).getText()) + "'"; - } else if (e instanceof FailedPredicateException) { - FailedPredicateException fpe = (FailedPredicateException) e; - msg = "Failed to recognize predicate '" + fpe.token.getText() + "'. Failed rule: '" + fpe.ruleName + "'"; - } else { - msg = super.getErrorMessage(e, xlateNames); - } - - if (msgs.size() > 0) { - msg = msg + " in " + msgs.peek(); - } - return msg; - } - - public void pushMsg(String msg, RecognizerSharedState state) { - // ANTLR generated code does not wrap the @init code wit this backtracking check, - // even if the matching @after has it. If we have parser rules with that are doing - // some lookahead with syntactic predicates this can cause the push() and pop() calls - // to become unbalanced, so make sure both push/pop check the backtracking state. - if (state.backtracking == 0) { - msgs.push(msg); - } - } - - public void popMsg(RecognizerSharedState state) { - if (state.backtracking == 0) { - Object o = msgs.pop(); - } - } - - // counter to generate unique union aliases - private int aliasCounter; - private String generateUnionAlias() { - return "u_" + (++aliasCounter); - } - private char [] excludedCharForColumnName = {'.', ':'}; - private boolean containExcludedCharForCreateTableColumnName(String input) { - if (input.length() > 0) { - if (input.charAt(0) == '`' && input.charAt(input.length() - 1) == '`') { - // When column name is backquoted, we don't care about excluded chars. 
- return false; - } - } - for(char c : excludedCharForColumnName) { - if(input.indexOf(c)>-1) { - return true; - } - } - return false; - } - private CommonTree throwSetOpException() throws RecognitionException { - throw new FailedPredicateException(input, "orderByClause clusterByClause distributeByClause sortByClause limitClause can only be applied to the whole union.", ""); - } - private CommonTree throwColumnNameException() throws RecognitionException { - throw new FailedPredicateException(input, Arrays.toString(excludedCharForColumnName) + " can not be used in column name in create table statement.", ""); - } - - private ParserConf parserConf; - private ParseErrorReporter reporter; - - public void configure(ParserConf parserConf, ParseErrorReporter reporter) { - this.parserConf = parserConf; - this.reporter = reporter; - } - - protected boolean useSQL11ReservedKeywordsForIdentifier() { - if (parserConf == null) { - return true; - } - return !parserConf.supportSQL11ReservedKeywords(); - } -} - -@rulecatch { -catch (RecognitionException e) { - reportError(e); - throw e; -} -} - -// starting rule -statement - : explainStatement EOF - | execStatement EOF - | KW_ADD KW_JAR -> ^(TOK_ADDJAR) - | KW_ADD KW_FILE -> ^(TOK_ADDFILE) - | KW_DFS -> ^(TOK_DFS) - | (KW_SET)=> KW_SET -> ^(TOK_SETCONFIG) - ; - -// Rule for expression parsing -singleNamedExpression - : - namedExpression EOF - ; - -// Rule for table name parsing -singleTableName - : - tableName EOF - ; - -explainStatement -@init { pushMsg("explain statement", state); } -@after { popMsg(state); } - : KW_EXPLAIN ( - explainOption* execStatement -> ^(TOK_EXPLAIN execStatement explainOption*) - | - KW_REWRITE queryStatementExpression[true] -> ^(TOK_EXPLAIN_SQ_REWRITE queryStatementExpression)) - ; - -explainOption -@init { msgs.push("explain option"); } -@after { msgs.pop(); } - : KW_EXTENDED|KW_FORMATTED|KW_DEPENDENCY|KW_LOGICAL|KW_AUTHORIZATION - ; - -execStatement -@init { pushMsg("statement", state); } -@after { popMsg(state); } - : queryStatementExpression[true] - | loadStatement - | exportStatement - | importStatement - | ddlStatement - | deleteStatement - | updateStatement - | sqlTransactionStatement - | cacheStatement - ; - -loadStatement -@init { pushMsg("load statement", state); } -@after { popMsg(state); } - : KW_LOAD KW_DATA (islocal=KW_LOCAL)? KW_INPATH (path=StringLiteral) (isoverwrite=KW_OVERWRITE)? KW_INTO KW_TABLE (tab=tableOrPartition) - -> ^(TOK_LOAD $path $tab $islocal? $isoverwrite?) - ; - -replicationClause -@init { pushMsg("replication clause", state); } -@after { popMsg(state); } - : KW_FOR (isMetadataOnly=KW_METADATA)? KW_REPLICATION LPAREN (replId=StringLiteral) RPAREN - -> ^(TOK_REPLICATION $replId $isMetadataOnly?) - ; - -exportStatement -@init { pushMsg("export statement", state); } -@after { popMsg(state); } - : KW_EXPORT - KW_TABLE (tab=tableOrPartition) - KW_TO (path=StringLiteral) - replicationClause? - -> ^(TOK_EXPORT $tab $path replicationClause?) - ; - -importStatement -@init { pushMsg("import statement", state); } -@after { popMsg(state); } - : KW_IMPORT - ((ext=KW_EXTERNAL)? KW_TABLE (tab=tableOrPartition))? - KW_FROM (path=StringLiteral) - tableLocation? - -> ^(TOK_IMPORT $path $tab? $ext? tableLocation?) 
- ; - -ddlStatement -@init { pushMsg("ddl statement", state); } -@after { popMsg(state); } - : createDatabaseStatement - | switchDatabaseStatement - | dropDatabaseStatement - | createTableStatement - | dropTableStatement - | truncateTableStatement - | alterStatement - | descStatement - | refreshStatement - | showStatement - | metastoreCheck - | createViewStatement - | dropViewStatement - | createFunctionStatement - | createMacroStatement - | createIndexStatement - | dropIndexStatement - | dropFunctionStatement - | reloadFunctionStatement - | dropMacroStatement - | analyzeStatement - | lockStatement - | unlockStatement - | lockDatabase - | unlockDatabase - | createRoleStatement - | dropRoleStatement - | (grantPrivileges) => grantPrivileges - | (revokePrivileges) => revokePrivileges - | showGrants - | showRoleGrants - | showRolePrincipals - | showRoles - | grantRole - | revokeRole - | setRole - | showCurrentRole - ; - -ifExists -@init { pushMsg("if exists clause", state); } -@after { popMsg(state); } - : KW_IF KW_EXISTS - -> ^(TOK_IFEXISTS) - ; - -restrictOrCascade -@init { pushMsg("restrict or cascade clause", state); } -@after { popMsg(state); } - : KW_RESTRICT - -> ^(TOK_RESTRICT) - | KW_CASCADE - -> ^(TOK_CASCADE) - ; - -ifNotExists -@init { pushMsg("if not exists clause", state); } -@after { popMsg(state); } - : KW_IF KW_NOT KW_EXISTS - -> ^(TOK_IFNOTEXISTS) - ; - -storedAsDirs -@init { pushMsg("stored as directories", state); } -@after { popMsg(state); } - : KW_STORED KW_AS KW_DIRECTORIES - -> ^(TOK_STOREDASDIRS) - ; - -orReplace -@init { pushMsg("or replace clause", state); } -@after { popMsg(state); } - : KW_OR KW_REPLACE - -> ^(TOK_ORREPLACE) - ; - -createDatabaseStatement -@init { pushMsg("create database statement", state); } -@after { popMsg(state); } - : KW_CREATE (KW_DATABASE|KW_SCHEMA) - ifNotExists? - name=identifier - databaseComment? - dbLocation? - (KW_WITH KW_DBPROPERTIES dbprops=dbProperties)? - -> ^(TOK_CREATEDATABASE $name ifNotExists? dbLocation? databaseComment? $dbprops?) - ; - -dbLocation -@init { pushMsg("database location specification", state); } -@after { popMsg(state); } - : - KW_LOCATION locn=StringLiteral -> ^(TOK_DATABASELOCATION $locn) - ; - -dbProperties -@init { pushMsg("dbproperties", state); } -@after { popMsg(state); } - : - LPAREN dbPropertiesList RPAREN -> ^(TOK_DATABASEPROPERTIES dbPropertiesList) - ; - -dbPropertiesList -@init { pushMsg("database properties list", state); } -@after { popMsg(state); } - : - keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_DBPROPLIST keyValueProperty+) - ; - - -switchDatabaseStatement -@init { pushMsg("switch database statement", state); } -@after { popMsg(state); } - : KW_USE identifier - -> ^(TOK_SWITCHDATABASE identifier) - ; - -dropDatabaseStatement -@init { pushMsg("drop database statement", state); } -@after { popMsg(state); } - : KW_DROP (KW_DATABASE|KW_SCHEMA) ifExists? identifier restrictOrCascade? - -> ^(TOK_DROPDATABASE identifier ifExists? restrictOrCascade?) - ; - -databaseComment -@init { pushMsg("database's comment", state); } -@after { popMsg(state); } - : KW_COMMENT comment=StringLiteral - -> ^(TOK_DATABASECOMMENT $comment) - ; - -createTableStatement -@init { pushMsg("create table statement", state); } -@after { popMsg(state); } - : KW_CREATE (temp=KW_TEMPORARY)? (ext=KW_EXTERNAL)? KW_TABLE ifNotExists? name=tableName - ( - like=KW_LIKE likeName=tableName - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - -> ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? 
- ^(TOK_LIKETABLE $likeName?) - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - ) - | - (tableProvider) => tableProvider - tableOpts? - (KW_AS selectStatementWithCTE)? - -> ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? - tableProvider - tableOpts? - selectStatementWithCTE? - ) - | (LPAREN columnNameTypeList RPAREN)? - (p=tableProvider?) - tableOpts? - tableComment? - tablePartition? - tableBuckets? - tableSkewed? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - (KW_AS selectStatementWithCTE)? - -> {p != null}? - ^(TOK_CREATETABLEUSING $name $temp? ifNotExists? - columnNameTypeList? - $p - tableOpts? - selectStatementWithCTE? - ) - -> - ^(TOK_CREATETABLE $name $temp? $ext? ifNotExists? - ^(TOK_LIKETABLE $likeName?) - columnNameTypeList? - tableComment? - tablePartition? - tableBuckets? - tableSkewed? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - selectStatementWithCTE? - ) - ) - ; - -truncateTableStatement -@init { pushMsg("truncate table statement", state); } -@after { popMsg(state); } - : KW_TRUNCATE KW_TABLE tablePartitionPrefix (KW_COLUMNS LPAREN columnNameList RPAREN)? -> ^(TOK_TRUNCATETABLE tablePartitionPrefix columnNameList?); - -createIndexStatement -@init { pushMsg("create index statement", state);} -@after {popMsg(state);} - : KW_CREATE KW_INDEX indexName=identifier - KW_ON KW_TABLE tab=tableName LPAREN indexedCols=columnNameList RPAREN - KW_AS typeName=StringLiteral - autoRebuild? - indexPropertiesPrefixed? - indexTblName? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - indexComment? - ->^(TOK_CREATEINDEX $indexName $typeName $tab $indexedCols - autoRebuild? - indexPropertiesPrefixed? - indexTblName? - tableRowFormat? - tableFileFormat? - tableLocation? - tablePropertiesPrefixed? - indexComment?) - ; - -indexComment -@init { pushMsg("comment on an index", state);} -@after {popMsg(state);} - : - KW_COMMENT comment=StringLiteral -> ^(TOK_INDEXCOMMENT $comment) - ; - -autoRebuild -@init { pushMsg("auto rebuild index", state);} -@after {popMsg(state);} - : KW_WITH KW_DEFERRED KW_REBUILD - ->^(TOK_DEFERRED_REBUILDINDEX) - ; - -indexTblName -@init { pushMsg("index table name", state);} -@after {popMsg(state);} - : KW_IN KW_TABLE indexTbl=tableName - ->^(TOK_CREATEINDEX_INDEXTBLNAME $indexTbl) - ; - -indexPropertiesPrefixed -@init { pushMsg("table properties with prefix", state); } -@after { popMsg(state); } - : - KW_IDXPROPERTIES! indexProperties - ; - -indexProperties -@init { pushMsg("index properties", state); } -@after { popMsg(state); } - : - LPAREN indexPropertiesList RPAREN -> ^(TOK_INDEXPROPERTIES indexPropertiesList) - ; - -indexPropertiesList -@init { pushMsg("index properties list", state); } -@after { popMsg(state); } - : - keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_INDEXPROPLIST keyValueProperty+) - ; - -dropIndexStatement -@init { pushMsg("drop index statement", state);} -@after {popMsg(state);} - : KW_DROP KW_INDEX ifExists? indexName=identifier KW_ON tab=tableName - ->^(TOK_DROPINDEX $indexName $tab ifExists?) - ; - -dropTableStatement -@init { pushMsg("drop statement", state); } -@after { popMsg(state); } - : KW_DROP KW_TABLE ifExists? tableName KW_PURGE? replicationClause? - -> ^(TOK_DROPTABLE tableName ifExists? KW_PURGE? replicationClause?) 
- ; - -alterStatement -@init { pushMsg("alter statement", state); } -@after { popMsg(state); } - : KW_ALTER KW_TABLE tableName alterTableStatementSuffix -> ^(TOK_ALTERTABLE tableName alterTableStatementSuffix) - | KW_ALTER KW_VIEW tableName KW_AS? alterViewStatementSuffix -> ^(TOK_ALTERVIEW tableName alterViewStatementSuffix) - | KW_ALTER KW_INDEX alterIndexStatementSuffix -> alterIndexStatementSuffix - | KW_ALTER (KW_DATABASE|KW_SCHEMA) alterDatabaseStatementSuffix -> alterDatabaseStatementSuffix - ; - -alterTableStatementSuffix -@init { pushMsg("alter table statement", state); } -@after { popMsg(state); } - : (alterStatementSuffixRename[true]) => alterStatementSuffixRename[true] - | alterStatementSuffixDropPartitions[true] - | alterStatementSuffixAddPartitions[true] - | alterStatementSuffixTouch - | alterStatementSuffixArchive - | alterStatementSuffixUnArchive - | alterStatementSuffixProperties - | alterStatementSuffixSkewedby - | alterStatementSuffixExchangePartition - | alterStatementPartitionKeyType - | partitionSpec? alterTblPartitionStatementSuffix -> alterTblPartitionStatementSuffix partitionSpec? - ; - -alterTblPartitionStatementSuffix -@init {pushMsg("alter table partition statement suffix", state);} -@after {popMsg(state);} - : alterStatementSuffixFileFormat - | alterStatementSuffixLocation - | alterStatementSuffixMergeFiles - | alterStatementSuffixSerdeProperties - | alterStatementSuffixRenamePart - | alterStatementSuffixBucketNum - | alterTblPartitionStatementSuffixSkewedLocation - | alterStatementSuffixClusterbySortby - | alterStatementSuffixCompact - | alterStatementSuffixUpdateStatsCol - | alterStatementSuffixRenameCol - | alterStatementSuffixAddCol - ; - -alterStatementPartitionKeyType -@init {msgs.push("alter partition key type"); } -@after {msgs.pop();} - : KW_PARTITION KW_COLUMN LPAREN columnNameType RPAREN - -> ^(TOK_ALTERTABLE_PARTCOLTYPE columnNameType) - ; - -alterViewStatementSuffix -@init { pushMsg("alter view statement", state); } -@after { popMsg(state); } - : alterViewSuffixProperties - | alterStatementSuffixRename[false] - | alterStatementSuffixAddPartitions[false] - | alterStatementSuffixDropPartitions[false] - | selectStatementWithCTE - ; - -alterIndexStatementSuffix -@init { pushMsg("alter index statement", state); } -@after { popMsg(state); } - : indexName=identifier KW_ON tableName partitionSpec? - ( - KW_REBUILD - ->^(TOK_ALTERINDEX_REBUILD tableName $indexName partitionSpec?) - | - KW_SET KW_IDXPROPERTIES - indexProperties - ->^(TOK_ALTERINDEX_PROPERTIES tableName $indexName indexProperties) - ) - ; - -alterDatabaseStatementSuffix -@init { pushMsg("alter database statement", state); } -@after { popMsg(state); } - : alterDatabaseSuffixProperties - | alterDatabaseSuffixSetOwner - ; - -alterDatabaseSuffixProperties -@init { pushMsg("alter database properties statement", state); } -@after { popMsg(state); } - : name=identifier KW_SET KW_DBPROPERTIES dbProperties - -> ^(TOK_ALTERDATABASE_PROPERTIES $name dbProperties) - ; - -alterDatabaseSuffixSetOwner -@init { pushMsg("alter database set owner", state); } -@after { popMsg(state); } - : dbName=identifier KW_SET KW_OWNER principalName - -> ^(TOK_ALTERDATABASE_OWNER $dbName principalName) - ; - -alterStatementSuffixRename[boolean table] -@init { pushMsg("rename statement", state); } -@after { popMsg(state); } - : KW_RENAME KW_TO tableName - -> { table }? 
^(TOK_ALTERTABLE_RENAME tableName) - -> ^(TOK_ALTERVIEW_RENAME tableName) - ; - -alterStatementSuffixAddCol -@init { pushMsg("add column statement", state); } -@after { popMsg(state); } - : (add=KW_ADD | replace=KW_REPLACE) KW_COLUMNS LPAREN columnNameTypeList RPAREN restrictOrCascade? - -> {$add != null}? ^(TOK_ALTERTABLE_ADDCOLS columnNameTypeList restrictOrCascade?) - -> ^(TOK_ALTERTABLE_REPLACECOLS columnNameTypeList restrictOrCascade?) - ; - -alterStatementSuffixRenameCol -@init { pushMsg("rename column name", state); } -@after { popMsg(state); } - : KW_CHANGE KW_COLUMN? oldName=identifier newName=identifier colType (KW_COMMENT comment=StringLiteral)? alterStatementChangeColPosition? restrictOrCascade? - ->^(TOK_ALTERTABLE_RENAMECOL $oldName $newName colType $comment? alterStatementChangeColPosition? restrictOrCascade?) - ; - -alterStatementSuffixUpdateStatsCol -@init { pushMsg("update column statistics", state); } -@after { popMsg(state); } - : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? - ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) - ; - -alterStatementChangeColPosition - : first=KW_FIRST|KW_AFTER afterCol=identifier - ->{$first != null}? ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION ) - -> ^(TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION $afterCol) - ; - -alterStatementSuffixAddPartitions[boolean table] -@init { pushMsg("add partition statement", state); } -@after { popMsg(state); } - : KW_ADD ifNotExists? alterStatementSuffixAddPartitionsElement+ - -> { table }? ^(TOK_ALTERTABLE_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) - -> ^(TOK_ALTERVIEW_ADDPARTS ifNotExists? alterStatementSuffixAddPartitionsElement+) - ; - -alterStatementSuffixAddPartitionsElement - : partitionSpec partitionLocation? - ; - -alterStatementSuffixTouch -@init { pushMsg("touch statement", state); } -@after { popMsg(state); } - : KW_TOUCH (partitionSpec)* - -> ^(TOK_ALTERTABLE_TOUCH (partitionSpec)*) - ; - -alterStatementSuffixArchive -@init { pushMsg("archive statement", state); } -@after { popMsg(state); } - : KW_ARCHIVE (partitionSpec)* - -> ^(TOK_ALTERTABLE_ARCHIVE (partitionSpec)*) - ; - -alterStatementSuffixUnArchive -@init { pushMsg("unarchive statement", state); } -@after { popMsg(state); } - : KW_UNARCHIVE (partitionSpec)* - -> ^(TOK_ALTERTABLE_UNARCHIVE (partitionSpec)*) - ; - -partitionLocation -@init { pushMsg("partition location", state); } -@after { popMsg(state); } - : - KW_LOCATION locn=StringLiteral -> ^(TOK_PARTITIONLOCATION $locn) - ; - -alterStatementSuffixDropPartitions[boolean table] -@init { pushMsg("drop partition statement", state); } -@after { popMsg(state); } - : KW_DROP ifExists? dropPartitionSpec (COMMA dropPartitionSpec)* KW_PURGE? replicationClause? - -> { table }? ^(TOK_ALTERTABLE_DROPPARTS dropPartitionSpec+ ifExists? KW_PURGE? replicationClause?) - -> ^(TOK_ALTERVIEW_DROPPARTS dropPartitionSpec+ ifExists? replicationClause?) - ; - -alterStatementSuffixProperties -@init { pushMsg("alter properties statement", state); } -@after { popMsg(state); } - : KW_SET KW_TBLPROPERTIES tableProperties - -> ^(TOK_ALTERTABLE_PROPERTIES tableProperties) - | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties - -> ^(TOK_ALTERTABLE_DROPPROPERTIES tableProperties ifExists?) 
- ; - -alterViewSuffixProperties -@init { pushMsg("alter view properties statement", state); } -@after { popMsg(state); } - : KW_SET KW_TBLPROPERTIES tableProperties - -> ^(TOK_ALTERVIEW_PROPERTIES tableProperties) - | KW_UNSET KW_TBLPROPERTIES ifExists? tableProperties - -> ^(TOK_ALTERVIEW_DROPPROPERTIES tableProperties ifExists?) - ; - -alterStatementSuffixSerdeProperties -@init { pushMsg("alter serdes statement", state); } -@after { popMsg(state); } - : KW_SET KW_SERDE serdeName=StringLiteral (KW_WITH KW_SERDEPROPERTIES tableProperties)? - -> ^(TOK_ALTERTABLE_SERIALIZER $serdeName tableProperties?) - | KW_SET KW_SERDEPROPERTIES tableProperties - -> ^(TOK_ALTERTABLE_SERDEPROPERTIES tableProperties) - ; - -tablePartitionPrefix -@init {pushMsg("table partition prefix", state);} -@after {popMsg(state);} - : tableName partitionSpec? - ->^(TOK_TABLE_PARTITION tableName partitionSpec?) - ; - -alterStatementSuffixFileFormat -@init {pushMsg("alter fileformat statement", state); } -@after {popMsg(state);} - : KW_SET KW_FILEFORMAT fileFormat - -> ^(TOK_ALTERTABLE_FILEFORMAT fileFormat) - ; - -alterStatementSuffixClusterbySortby -@init {pushMsg("alter partition cluster by sort by statement", state);} -@after {popMsg(state);} - : KW_NOT KW_CLUSTERED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_CLUSTERED) - | KW_NOT KW_SORTED -> ^(TOK_ALTERTABLE_CLUSTER_SORT TOK_NOT_SORTED) - | tableBuckets -> ^(TOK_ALTERTABLE_CLUSTER_SORT tableBuckets) - ; - -alterTblPartitionStatementSuffixSkewedLocation -@init {pushMsg("alter partition skewed location", state);} -@after {popMsg(state);} - : KW_SET KW_SKEWED KW_LOCATION skewedLocations - -> ^(TOK_ALTERTABLE_SKEWED_LOCATION skewedLocations) - ; - -skewedLocations -@init { pushMsg("skewed locations", state); } -@after { popMsg(state); } - : - LPAREN skewedLocationsList RPAREN -> ^(TOK_SKEWED_LOCATIONS skewedLocationsList) - ; - -skewedLocationsList -@init { pushMsg("skewed locations list", state); } -@after { popMsg(state); } - : - skewedLocationMap (COMMA skewedLocationMap)* -> ^(TOK_SKEWED_LOCATION_LIST skewedLocationMap+) - ; - -skewedLocationMap -@init { pushMsg("specifying skewed location map", state); } -@after { popMsg(state); } - : - key=skewedValueLocationElement EQUAL value=StringLiteral -> ^(TOK_SKEWED_LOCATION_MAP $key $value) - ; - -alterStatementSuffixLocation -@init {pushMsg("alter location", state);} -@after {popMsg(state);} - : KW_SET KW_LOCATION newLoc=StringLiteral - -> ^(TOK_ALTERTABLE_LOCATION $newLoc) - ; - - -alterStatementSuffixSkewedby -@init {pushMsg("alter skewed by statement", state);} -@after{popMsg(state);} - : tableSkewed - ->^(TOK_ALTERTABLE_SKEWED tableSkewed) - | - KW_NOT KW_SKEWED - ->^(TOK_ALTERTABLE_SKEWED) - | - KW_NOT storedAsDirs - ->^(TOK_ALTERTABLE_SKEWED storedAsDirs) - ; - -alterStatementSuffixExchangePartition -@init {pushMsg("alter exchange partition", state);} -@after{popMsg(state);} - : KW_EXCHANGE partitionSpec KW_WITH KW_TABLE exchangename=tableName - -> ^(TOK_ALTERTABLE_EXCHANGEPARTITION partitionSpec $exchangename) - ; - -alterStatementSuffixRenamePart -@init { pushMsg("alter table rename partition statement", state); } -@after { popMsg(state); } - : KW_RENAME KW_TO partitionSpec - ->^(TOK_ALTERTABLE_RENAMEPART partitionSpec) - ; - -alterStatementSuffixStatsPart -@init { pushMsg("alter table stats partition statement", state); } -@after { popMsg(state); } - : KW_UPDATE KW_STATISTICS KW_FOR KW_COLUMN? colName=identifier KW_SET tableProperties (KW_COMMENT comment=StringLiteral)? 
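For orientation, the `ALTER TABLE` suffix rules above cover per-table and per-partition maintenance commands, each mapping to a single TOK_ALTERTABLE_* node. A few statement shapes they are written to accept (purely illustrative, with made-up table and partition names):

```scala
// Illustrative ALTER TABLE shapes for the suffix rules above (hypothetical names).
val alterExamples = Seq(
  // alterStatementSuffixSerdeProperties
  "ALTER TABLE logs SET SERDEPROPERTIES ('field.delim' = '\\t')",
  // alterStatementSuffixLocation, scoped to a partition via the optional partitionSpec
  "ALTER TABLE logs PARTITION (ds = '2016-03-28') SET LOCATION '/data/logs/2016-03-28'",
  // alterStatementSuffixRenamePart
  "ALTER TABLE logs PARTITION (ds = '2016-03-28') RENAME TO PARTITION (ds = '2016-03-29')",
  // alterStatementSuffixExchangePartition
  "ALTER TABLE archive EXCHANGE PARTITION (ds = '2016-03-28') WITH TABLE logs"
)
```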
- ->^(TOK_ALTERTABLE_UPDATECOLSTATS $colName tableProperties $comment?) - ; - -alterStatementSuffixMergeFiles -@init { pushMsg("", state); } -@after { popMsg(state); } - : KW_CONCATENATE - -> ^(TOK_ALTERTABLE_MERGEFILES) - ; - -alterStatementSuffixBucketNum -@init { pushMsg("", state); } -@after { popMsg(state); } - : KW_INTO num=Number KW_BUCKETS - -> ^(TOK_ALTERTABLE_BUCKETS $num) - ; - -alterStatementSuffixCompact -@init { msgs.push("compaction request"); } -@after { msgs.pop(); } - : KW_COMPACT compactType=StringLiteral - -> ^(TOK_ALTERTABLE_COMPACT $compactType) - ; - - -fileFormat -@init { pushMsg("file format specification", state); } -@after { popMsg(state); } - : KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral KW_SERDE serdeCls=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? - -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $serdeCls $inDriver? $outDriver?) - | genericSpec=identifier -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) - ; - -tabTypeExpr -@init { pushMsg("specifying table types", state); } -@after { popMsg(state); } - : identifier (DOT^ identifier)? - (identifier (DOT^ - ( - (KW_ELEM_TYPE) => KW_ELEM_TYPE - | - (KW_KEY_TYPE) => KW_KEY_TYPE - | - (KW_VALUE_TYPE) => KW_VALUE_TYPE - | identifier - ))* - )? - ; - -partTypeExpr -@init { pushMsg("specifying table partitions", state); } -@after { popMsg(state); } - : tabTypeExpr partitionSpec? -> ^(TOK_TABTYPE tabTypeExpr partitionSpec?) - ; - -tabPartColTypeExpr -@init { pushMsg("specifying table partitions columnName", state); } -@after { popMsg(state); } - : tableName partitionSpec? extColumnName? -> ^(TOK_TABTYPE tableName partitionSpec? extColumnName?) - ; - -refreshStatement -@init { pushMsg("refresh statement", state); } -@after { popMsg(state); } - : - KW_REFRESH KW_TABLE tableName -> ^(TOK_REFRESHTABLE tableName) - ; - -descStatement -@init { pushMsg("describe statement", state); } -@after { popMsg(state); } - : - (KW_DESCRIBE|KW_DESC) - ( - (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) KW_EXTENDED? (dbName=identifier) -> ^(TOK_DESCDATABASE $dbName KW_EXTENDED?) - | - (KW_FUNCTION) => KW_FUNCTION KW_EXTENDED? (name=descFuncNames) -> ^(TOK_DESCFUNCTION $name KW_EXTENDED?) - | - (KW_FORMATTED|KW_EXTENDED|KW_PRETTY) => ((descOptions=KW_FORMATTED|descOptions=KW_EXTENDED|descOptions=KW_PRETTY) parttype=tabPartColTypeExpr) -> ^(TOK_DESCTABLE $parttype $descOptions) - | - parttype=tabPartColTypeExpr -> ^(TOK_DESCTABLE $parttype) - ) - ; - -analyzeStatement -@init { pushMsg("analyze statement", state); } -@after { popMsg(state); } - : KW_ANALYZE KW_TABLE (parttype=tableOrPartition) KW_COMPUTE KW_STATISTICS ((noscan=KW_NOSCAN) | (partialscan=KW_PARTIALSCAN) - | (KW_FOR KW_COLUMNS (statsColumnName=columnNameList)?))? - -> ^(TOK_ANALYZE $parttype $noscan? $partialscan? KW_COLUMNS? $statsColumnName?) - ; - -showStatement -@init { pushMsg("show statement", state); } -@after { popMsg(state); } - : KW_SHOW (KW_DATABASES|KW_SCHEMAS) (KW_LIKE showStmtIdentifier)? -> ^(TOK_SHOWDATABASES showStmtIdentifier?) - | KW_SHOW KW_TABLES ((KW_FROM|KW_IN) db_name=identifier)? (KW_LIKE showStmtIdentifier|showStmtIdentifier)? -> ^(TOK_SHOWTABLES ^(TOK_FROM $db_name)? showStmtIdentifier?) - | KW_SHOW KW_COLUMNS (KW_FROM|KW_IN) tableName ((KW_FROM|KW_IN) db_name=identifier)? - -> ^(TOK_SHOWCOLUMNS tableName $db_name?) - | KW_SHOW KW_FUNCTIONS (KW_LIKE showFunctionIdentifier|showFunctionIdentifier)? -> ^(TOK_SHOWFUNCTIONS KW_LIKE? showFunctionIdentifier?) 
- | KW_SHOW KW_PARTITIONS tabName=tableName partitionSpec? -> ^(TOK_SHOWPARTITIONS $tabName partitionSpec?) - | KW_SHOW KW_CREATE ( - (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) db_name=identifier -> ^(TOK_SHOW_CREATEDATABASE $db_name) - | - KW_TABLE tabName=tableName -> ^(TOK_SHOW_CREATETABLE $tabName) - ) - | KW_SHOW KW_TABLE KW_EXTENDED ((KW_FROM|KW_IN) db_name=identifier)? KW_LIKE showStmtIdentifier partitionSpec? - -> ^(TOK_SHOW_TABLESTATUS showStmtIdentifier $db_name? partitionSpec?) - | KW_SHOW KW_TBLPROPERTIES tableName (LPAREN prptyName=StringLiteral RPAREN)? -> ^(TOK_SHOW_TBLPROPERTIES tableName $prptyName?) - | KW_SHOW KW_LOCKS - ( - (KW_DATABASE|KW_SCHEMA) => (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWDBLOCKS $dbName $isExtended?) - | - (parttype=partTypeExpr)? (isExtended=KW_EXTENDED)? -> ^(TOK_SHOWLOCKS $parttype? $isExtended?) - ) - | KW_SHOW (showOptions=KW_FORMATTED)? (KW_INDEX|KW_INDEXES) KW_ON showStmtIdentifier ((KW_FROM|KW_IN) db_name=identifier)? - -> ^(TOK_SHOWINDEXES showStmtIdentifier $showOptions? $db_name?) - | KW_SHOW KW_COMPACTIONS -> ^(TOK_SHOW_COMPACTIONS) - | KW_SHOW KW_TRANSACTIONS -> ^(TOK_SHOW_TRANSACTIONS) - | KW_SHOW KW_CONF StringLiteral -> ^(TOK_SHOWCONF StringLiteral) - ; - -lockStatement -@init { pushMsg("lock statement", state); } -@after { popMsg(state); } - : KW_LOCK KW_TABLE tableName partitionSpec? lockMode -> ^(TOK_LOCKTABLE tableName lockMode partitionSpec?) - ; - -lockDatabase -@init { pushMsg("lock database statement", state); } -@after { popMsg(state); } - : KW_LOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) lockMode -> ^(TOK_LOCKDB $dbName lockMode) - ; - -lockMode -@init { pushMsg("lock mode", state); } -@after { popMsg(state); } - : KW_SHARED | KW_EXCLUSIVE - ; - -unlockStatement -@init { pushMsg("unlock statement", state); } -@after { popMsg(state); } - : KW_UNLOCK KW_TABLE tableName partitionSpec? -> ^(TOK_UNLOCKTABLE tableName partitionSpec?) - ; - -unlockDatabase -@init { pushMsg("unlock database statement", state); } -@after { popMsg(state); } - : KW_UNLOCK (KW_DATABASE|KW_SCHEMA) (dbName=Identifier) -> ^(TOK_UNLOCKDB $dbName) - ; - -createRoleStatement -@init { pushMsg("create role", state); } -@after { popMsg(state); } - : KW_CREATE KW_ROLE roleName=identifier - -> ^(TOK_CREATEROLE $roleName) - ; - -dropRoleStatement -@init {pushMsg("drop role", state);} -@after {popMsg(state);} - : KW_DROP KW_ROLE roleName=identifier - -> ^(TOK_DROPROLE $roleName) - ; - -grantPrivileges -@init {pushMsg("grant privileges", state);} -@after {popMsg(state);} - : KW_GRANT privList=privilegeList - privilegeObject? - KW_TO principalSpecification - withGrantOption? - -> ^(TOK_GRANT $privList principalSpecification privilegeObject? withGrantOption?) - ; - -revokePrivileges -@init {pushMsg("revoke privileges", state);} -@afer {popMsg(state);} - : KW_REVOKE grantOptionFor? privilegeList privilegeObject? KW_FROM principalSpecification - -> ^(TOK_REVOKE privilegeList principalSpecification privilegeObject? grantOptionFor?) - ; - -grantRole -@init {pushMsg("grant role", state);} -@after {popMsg(state);} - : KW_GRANT KW_ROLE? identifier (COMMA identifier)* KW_TO principalSpecification withAdminOption? - -> ^(TOK_GRANT_ROLE principalSpecification withAdminOption? identifier+) - ; - -revokeRole -@init {pushMsg("revoke role", state);} -@after {popMsg(state);} - : KW_REVOKE adminOptionFor? KW_ROLE? 
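The SHOW, lock, and authorization rules above each map one statement shape to one TOK_* node. A few illustrative inputs they are written to accept (table, role, and user names are invented):

```scala
// Illustrative inputs for the SHOW / lock / authorization rules above (hypothetical names).
val showAndGrantExamples = Seq(
  "SHOW CREATE TABLE orders",                               // -> TOK_SHOW_CREATETABLE
  "SHOW TBLPROPERTIES orders ('created.by')",               // -> TOK_SHOW_TBLPROPERTIES
  "SHOW PARTITIONS orders PARTITION (ds = '2016-03-28')",   // -> TOK_SHOWPARTITIONS
  "LOCK TABLE orders SHARED",                               // lockStatement
  "GRANT SELECT ON TABLE orders TO USER alice WITH GRANT OPTION", // grantPrivileges
  "REVOKE SELECT ON TABLE orders FROM USER alice"                 // revokePrivileges
)
```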
identifier (COMMA identifier)* KW_FROM principalSpecification - -> ^(TOK_REVOKE_ROLE principalSpecification adminOptionFor? identifier+) - ; - -showRoleGrants -@init {pushMsg("show role grants", state);} -@after {popMsg(state);} - : KW_SHOW KW_ROLE KW_GRANT principalName - -> ^(TOK_SHOW_ROLE_GRANT principalName) - ; - - -showRoles -@init {pushMsg("show roles", state);} -@after {popMsg(state);} - : KW_SHOW KW_ROLES - -> ^(TOK_SHOW_ROLES) - ; - -showCurrentRole -@init {pushMsg("show current role", state);} -@after {popMsg(state);} - : KW_SHOW KW_CURRENT KW_ROLES - -> ^(TOK_SHOW_SET_ROLE) - ; - -setRole -@init {pushMsg("set role", state);} -@after {popMsg(state);} - : KW_SET KW_ROLE - ( - (KW_ALL) => (all=KW_ALL) -> ^(TOK_SHOW_SET_ROLE Identifier[$all.text]) - | - (KW_NONE) => (none=KW_NONE) -> ^(TOK_SHOW_SET_ROLE Identifier[$none.text]) - | - identifier -> ^(TOK_SHOW_SET_ROLE identifier) - ) - ; - -showGrants -@init {pushMsg("show grants", state);} -@after {popMsg(state);} - : KW_SHOW KW_GRANT principalName? (KW_ON privilegeIncludeColObject)? - -> ^(TOK_SHOW_GRANT principalName? privilegeIncludeColObject?) - ; - -showRolePrincipals -@init {pushMsg("show role principals", state);} -@after {popMsg(state);} - : KW_SHOW KW_PRINCIPALS roleName=identifier - -> ^(TOK_SHOW_ROLE_PRINCIPALS $roleName) - ; - - -privilegeIncludeColObject -@init {pushMsg("privilege object including columns", state);} -@after {popMsg(state);} - : (KW_ALL) => KW_ALL -> ^(TOK_RESOURCE_ALL) - | privObjectCols -> ^(TOK_PRIV_OBJECT_COL privObjectCols) - ; - -privilegeObject -@init {pushMsg("privilege object", state);} -@after {popMsg(state);} - : KW_ON privObject -> ^(TOK_PRIV_OBJECT privObject) - ; - -// database or table type. Type is optional, default type is table -privObject - : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) - | KW_TABLE? tableName partitionSpec? -> ^(TOK_TABLE_TYPE tableName partitionSpec?) - | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) - | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) - ; - -privObjectCols - : (KW_DATABASE|KW_SCHEMA) identifier -> ^(TOK_DB_TYPE identifier) - | KW_TABLE? tableName (LPAREN cols=columnNameList RPAREN)? partitionSpec? -> ^(TOK_TABLE_TYPE tableName $cols? partitionSpec?) - | KW_URI (path=StringLiteral) -> ^(TOK_URI_TYPE $path) - | KW_SERVER identifier -> ^(TOK_SERVER_TYPE identifier) - ; - -privilegeList -@init {pushMsg("grant privilege list", state);} -@after {popMsg(state);} - : privlegeDef (COMMA privlegeDef)* - -> ^(TOK_PRIVILEGE_LIST privlegeDef+) - ; - -privlegeDef -@init {pushMsg("grant privilege", state);} -@after {popMsg(state);} - : privilegeType (LPAREN cols=columnNameList RPAREN)? - -> ^(TOK_PRIVILEGE privilegeType $cols?) 
- ; - -privilegeType -@init {pushMsg("privilege type", state);} -@after {popMsg(state);} - : KW_ALL -> ^(TOK_PRIV_ALL) - | KW_ALTER -> ^(TOK_PRIV_ALTER_METADATA) - | KW_UPDATE -> ^(TOK_PRIV_ALTER_DATA) - | KW_CREATE -> ^(TOK_PRIV_CREATE) - | KW_DROP -> ^(TOK_PRIV_DROP) - | KW_INDEX -> ^(TOK_PRIV_INDEX) - | KW_LOCK -> ^(TOK_PRIV_LOCK) - | KW_SELECT -> ^(TOK_PRIV_SELECT) - | KW_SHOW_DATABASE -> ^(TOK_PRIV_SHOW_DATABASE) - | KW_INSERT -> ^(TOK_PRIV_INSERT) - | KW_DELETE -> ^(TOK_PRIV_DELETE) - ; - -principalSpecification -@init { pushMsg("user/group/role name list", state); } -@after { popMsg(state); } - : principalName (COMMA principalName)* -> ^(TOK_PRINCIPAL_NAME principalName+) - ; - -principalName -@init {pushMsg("user|group|role name", state);} -@after {popMsg(state);} - : KW_USER principalIdentifier -> ^(TOK_USER principalIdentifier) - | KW_GROUP principalIdentifier -> ^(TOK_GROUP principalIdentifier) - | KW_ROLE identifier -> ^(TOK_ROLE identifier) - ; - -withGrantOption -@init {pushMsg("with grant option", state);} -@after {popMsg(state);} - : KW_WITH KW_GRANT KW_OPTION - -> ^(TOK_GRANT_WITH_OPTION) - ; - -grantOptionFor -@init {pushMsg("grant option for", state);} -@after {popMsg(state);} - : KW_GRANT KW_OPTION KW_FOR - -> ^(TOK_GRANT_OPTION_FOR) -; - -adminOptionFor -@init {pushMsg("admin option for", state);} -@after {popMsg(state);} - : KW_ADMIN KW_OPTION KW_FOR - -> ^(TOK_ADMIN_OPTION_FOR) -; - -withAdminOption -@init {pushMsg("with admin option", state);} -@after {popMsg(state);} - : KW_WITH KW_ADMIN KW_OPTION - -> ^(TOK_GRANT_WITH_ADMIN_OPTION) - ; - -metastoreCheck -@init { pushMsg("metastore check statement", state); } -@after { popMsg(state); } - : KW_MSCK (repair=KW_REPAIR)? (KW_TABLE tableName partitionSpec? (COMMA partitionSpec)*)? - -> ^(TOK_MSCK $repair? (tableName partitionSpec*)?) - ; - -resourceList -@init { pushMsg("resource list", state); } -@after { popMsg(state); } - : - resource (COMMA resource)* -> ^(TOK_RESOURCE_LIST resource+) - ; - -resource -@init { pushMsg("resource", state); } -@after { popMsg(state); } - : - resType=resourceType resPath=StringLiteral -> ^(TOK_RESOURCE_URI $resType $resPath) - ; - -resourceType -@init { pushMsg("resource type", state); } -@after { popMsg(state); } - : - KW_JAR -> ^(TOK_JAR) - | - KW_FILE -> ^(TOK_FILE) - | - KW_ARCHIVE -> ^(TOK_ARCHIVE) - ; - -createFunctionStatement -@init { pushMsg("create function statement", state); } -@after { popMsg(state); } - : KW_CREATE (temp=KW_TEMPORARY)? KW_FUNCTION functionIdentifier KW_AS StringLiteral - (KW_USING rList=resourceList)? - -> {$temp != null}? ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList? TOK_TEMPORARY) - -> ^(TOK_CREATEFUNCTION functionIdentifier StringLiteral $rList?) - ; - -dropFunctionStatement -@init { pushMsg("drop function statement", state); } -@after { popMsg(state); } - : KW_DROP (temp=KW_TEMPORARY)? KW_FUNCTION ifExists? functionIdentifier - -> {$temp != null}? ^(TOK_DROPFUNCTION functionIdentifier ifExists? TOK_TEMPORARY) - -> ^(TOK_DROPFUNCTION functionIdentifier ifExists?) - ; - -reloadFunctionStatement -@init { pushMsg("reload function statement", state); } -@after { popMsg(state); } - : KW_RELOAD KW_FUNCTION -> ^(TOK_RELOADFUNCTION); - -createMacroStatement -@init { pushMsg("create macro statement", state); } -@after { popMsg(state); } - : KW_CREATE KW_TEMPORARY KW_MACRO Identifier - LPAREN columnNameTypeList? RPAREN expression - -> ^(TOK_CREATEMACRO Identifier columnNameTypeList? 
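The function and macro rules above are the ones Spark exercises most directly: `CREATE FUNCTION` can pull in jars or files through `resourceList`, and `createMacroStatement` (whose rewrite continues just below) names an expression over typed arguments. Sketch of matching statements, with a placeholder class name and path:

```scala
// Illustrative inputs for the function/macro rules above (hypothetical class and paths).
val functionExamples = Seq(
  // createFunctionStatement with a USING resourceList
  "CREATE TEMPORARY FUNCTION my_lower AS 'com.example.udf.Lower' USING JAR '/tmp/udfs.jar'",
  // dropFunctionStatement
  "DROP TEMPORARY FUNCTION IF EXISTS my_lower",
  // createMacroStatement
  "CREATE TEMPORARY MACRO plus_one(x INT) x + 1"
)
```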
expression) - ; - -dropMacroStatement -@init { pushMsg("drop macro statement", state); } -@after { popMsg(state); } - : KW_DROP KW_TEMPORARY KW_MACRO ifExists? Identifier - -> ^(TOK_DROPMACRO Identifier ifExists?) - ; - -createViewStatement -@init { - pushMsg("create view statement", state); -} -@after { popMsg(state); } - : KW_CREATE (orReplace)? KW_VIEW (ifNotExists)? name=tableName - (LPAREN columnNameCommentList RPAREN)? tableComment? viewPartition? - tablePropertiesPrefixed? - KW_AS - selectStatementWithCTE - -> ^(TOK_CREATEVIEW $name orReplace? - ifNotExists? - columnNameCommentList? - tableComment? - viewPartition? - tablePropertiesPrefixed? - selectStatementWithCTE - ) - ; - -viewPartition -@init { pushMsg("view partition specification", state); } -@after { popMsg(state); } - : KW_PARTITIONED KW_ON LPAREN columnNameList RPAREN - -> ^(TOK_VIEWPARTCOLS columnNameList) - ; - -dropViewStatement -@init { pushMsg("drop view statement", state); } -@after { popMsg(state); } - : KW_DROP KW_VIEW ifExists? viewName -> ^(TOK_DROPVIEW viewName ifExists?) - ; - -showFunctionIdentifier -@init { pushMsg("identifier for show function statement", state); } -@after { popMsg(state); } - : functionIdentifier - | StringLiteral - ; - -showStmtIdentifier -@init { pushMsg("identifier for show statement", state); } -@after { popMsg(state); } - : identifier - | StringLiteral - ; - -tableProvider -@init { pushMsg("table's provider", state); } -@after { popMsg(state); } - : - KW_USING Identifier (DOT Identifier)* - -> ^(TOK_TABLEPROVIDER Identifier+) - ; - -optionKeyValue -@init { pushMsg("table's option specification", state); } -@after { popMsg(state); } - : - (looseIdentifier (DOT looseIdentifier)*) StringLiteral - -> ^(TOK_TABLEOPTION looseIdentifier+ StringLiteral) - ; - -tableOpts -@init { pushMsg("table's options", state); } -@after { popMsg(state); } - : - KW_OPTIONS LPAREN optionKeyValue (COMMA optionKeyValue)* RPAREN - -> ^(TOK_TABLEOPTIONS optionKeyValue+) - ; - -tableComment -@init { pushMsg("table's comment", state); } -@after { popMsg(state); } - : - KW_COMMENT comment=StringLiteral -> ^(TOK_TABLECOMMENT $comment) - ; - -tablePartition -@init { pushMsg("table partition specification", state); } -@after { popMsg(state); } - : KW_PARTITIONED KW_BY LPAREN columnNameTypeList RPAREN - -> ^(TOK_TABLEPARTCOLS columnNameTypeList) - ; - -tableBuckets -@init { pushMsg("table buckets specification", state); } -@after { popMsg(state); } - : - KW_CLUSTERED KW_BY LPAREN bucketCols=columnNameList RPAREN (KW_SORTED KW_BY LPAREN sortCols=columnNameOrderList RPAREN)? KW_INTO num=Number KW_BUCKETS - -> ^(TOK_ALTERTABLE_BUCKETS $bucketCols $sortCols? $num) - ; - -tableSkewed -@init { pushMsg("table skewed specification", state); } -@after { popMsg(state); } - : - KW_SKEWED KW_BY LPAREN skewedCols=columnNameList RPAREN KW_ON LPAREN (skewedValues=skewedValueElement) RPAREN ((storedAsDirs) => storedAsDirs)? - -> ^(TOK_TABLESKEWED $skewedCols $skewedValues storedAsDirs?) 
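The view and physical-layout rules above describe the long tail of Hive `CREATE` syntax: bucketing, skew handling, and partitioned or commented views. A compact, purely illustrative pair of statements combining several of them (schema and values are made up):

```scala
// Illustrative DDL exercising tableBuckets and tableSkewed (hypothetical schema),
// plus a view as accepted by createViewStatement.
val layoutExample =
  """CREATE TABLE clicks (user_id INT, url STRING)
    |CLUSTERED BY (user_id) SORTED BY (user_id ASC) INTO 32 BUCKETS
    |SKEWED BY (url) ON ('http://example.com/') STORED AS DIRECTORIES""".stripMargin

val viewExample =
  """CREATE VIEW IF NOT EXISTS daily_clicks
    |COMMENT 'clicks per day'
    |AS SELECT ds, count(*) FROM clicks GROUP BY ds""".stripMargin
```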
- ; - -rowFormat -@init { pushMsg("serde specification", state); } -@after { popMsg(state); } - : rowFormatSerde -> ^(TOK_SERDE rowFormatSerde) - | rowFormatDelimited -> ^(TOK_SERDE rowFormatDelimited) - | -> ^(TOK_SERDE) - ; - -recordReader -@init { pushMsg("record reader specification", state); } -@after { popMsg(state); } - : KW_RECORDREADER StringLiteral -> ^(TOK_RECORDREADER StringLiteral) - | -> ^(TOK_RECORDREADER) - ; - -recordWriter -@init { pushMsg("record writer specification", state); } -@after { popMsg(state); } - : KW_RECORDWRITER StringLiteral -> ^(TOK_RECORDWRITER StringLiteral) - | -> ^(TOK_RECORDWRITER) - ; - -rowFormatSerde -@init { pushMsg("serde format specification", state); } -@after { popMsg(state); } - : KW_ROW KW_FORMAT KW_SERDE name=StringLiteral (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? - -> ^(TOK_SERDENAME $name $serdeprops?) - ; - -rowFormatDelimited -@init { pushMsg("serde properties specification", state); } -@after { popMsg(state); } - : - KW_ROW KW_FORMAT KW_DELIMITED tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat? - -> ^(TOK_SERDEPROPS tableRowFormatFieldIdentifier? tableRowFormatCollItemsIdentifier? tableRowFormatMapKeysIdentifier? tableRowFormatLinesIdentifier? tableRowNullFormat?) - ; - -tableRowFormat -@init { pushMsg("table row format specification", state); } -@after { popMsg(state); } - : - rowFormatDelimited - -> ^(TOK_TABLEROWFORMAT rowFormatDelimited) - | rowFormatSerde - -> ^(TOK_TABLESERIALIZER rowFormatSerde) - ; - -tablePropertiesPrefixed -@init { pushMsg("table properties with prefix", state); } -@after { popMsg(state); } - : - KW_TBLPROPERTIES! tableProperties - ; - -tableProperties -@init { pushMsg("table properties", state); } -@after { popMsg(state); } - : - LPAREN tablePropertiesList RPAREN -> ^(TOK_TABLEPROPERTIES tablePropertiesList) - ; - -tablePropertiesList -@init { pushMsg("table properties list", state); } -@after { popMsg(state); } - : - keyValueProperty (COMMA keyValueProperty)* -> ^(TOK_TABLEPROPLIST keyValueProperty+) - | - keyProperty (COMMA keyProperty)* -> ^(TOK_TABLEPROPLIST keyProperty+) - ; - -keyValueProperty -@init { pushMsg("specifying key/value property", state); } -@after { popMsg(state); } - : - key=StringLiteral EQUAL value=StringLiteral -> ^(TOK_TABLEPROPERTY $key $value) - ; - -keyProperty -@init { pushMsg("specifying key property", state); } -@after { popMsg(state); } - : - key=StringLiteral -> ^(TOK_TABLEPROPERTY $key TOK_NULL) - ; - -tableRowFormatFieldIdentifier -@init { pushMsg("table row format's field separator", state); } -@after { popMsg(state); } - : - KW_FIELDS KW_TERMINATED KW_BY fldIdnt=StringLiteral (KW_ESCAPED KW_BY fldEscape=StringLiteral)? - -> ^(TOK_TABLEROWFORMATFIELD $fldIdnt $fldEscape?) 
- ; - -tableRowFormatCollItemsIdentifier -@init { pushMsg("table row format's column separator", state); } -@after { popMsg(state); } - : - KW_COLLECTION KW_ITEMS KW_TERMINATED KW_BY collIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATCOLLITEMS $collIdnt) - ; - -tableRowFormatMapKeysIdentifier -@init { pushMsg("table row format's map key separator", state); } -@after { popMsg(state); } - : - KW_MAP KW_KEYS KW_TERMINATED KW_BY mapKeysIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATMAPKEYS $mapKeysIdnt) - ; - -tableRowFormatLinesIdentifier -@init { pushMsg("table row format's line separator", state); } -@after { popMsg(state); } - : - KW_LINES KW_TERMINATED KW_BY linesIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATLINES $linesIdnt) - ; - -tableRowNullFormat -@init { pushMsg("table row format's null specifier", state); } -@after { popMsg(state); } - : - KW_NULL KW_DEFINED KW_AS nullIdnt=StringLiteral - -> ^(TOK_TABLEROWFORMATNULL $nullIdnt) - ; -tableFileFormat -@init { pushMsg("table file format specification", state); } -@after { popMsg(state); } - : - (KW_STORED KW_AS KW_INPUTFORMAT) => KW_STORED KW_AS KW_INPUTFORMAT inFmt=StringLiteral KW_OUTPUTFORMAT outFmt=StringLiteral (KW_INPUTDRIVER inDriver=StringLiteral KW_OUTPUTDRIVER outDriver=StringLiteral)? - -> ^(TOK_TABLEFILEFORMAT $inFmt $outFmt $inDriver? $outDriver?) - | KW_STORED KW_BY storageHandler=StringLiteral - (KW_WITH KW_SERDEPROPERTIES serdeprops=tableProperties)? - -> ^(TOK_STORAGEHANDLER $storageHandler $serdeprops?) - | KW_STORED KW_AS genericSpec=identifier - -> ^(TOK_FILEFORMAT_GENERIC $genericSpec) - ; - -tableLocation -@init { pushMsg("table location specification", state); } -@after { popMsg(state); } - : - KW_LOCATION locn=StringLiteral -> ^(TOK_TABLELOCATION $locn) - ; - -columnNameTypeList -@init { pushMsg("column name type list", state); } -@after { popMsg(state); } - : columnNameType (COMMA columnNameType)* -> ^(TOK_TABCOLLIST columnNameType+) - ; - -columnNameColonTypeList -@init { pushMsg("column name type list", state); } -@after { popMsg(state); } - : columnNameColonType (COMMA columnNameColonType)* -> ^(TOK_TABCOLLIST columnNameColonType+) - ; - -columnNameList -@init { pushMsg("column name list", state); } -@after { popMsg(state); } - : columnName (COMMA columnName)* -> ^(TOK_TABCOLNAME columnName+) - ; - -columnName -@init { pushMsg("column name", state); } -@after { popMsg(state); } - : - identifier - ; - -extColumnName -@init { pushMsg("column name for complex types", state); } -@after { popMsg(state); } - : - identifier (DOT^ ((KW_ELEM_TYPE) => KW_ELEM_TYPE | (KW_KEY_TYPE) => KW_KEY_TYPE | (KW_VALUE_TYPE) => KW_VALUE_TYPE | identifier))* - ; - -columnNameOrderList -@init { pushMsg("column name order list", state); } -@after { popMsg(state); } - : columnNameOrder (COMMA columnNameOrder)* -> ^(TOK_TABCOLNAME columnNameOrder+) - ; - -skewedValueElement -@init { pushMsg("skewed value element", state); } -@after { popMsg(state); } - : - skewedColumnValues - | skewedColumnValuePairList - ; - -skewedColumnValuePairList -@init { pushMsg("column value pair list", state); } -@after { popMsg(state); } - : skewedColumnValuePair (COMMA skewedColumnValuePair)* -> ^(TOK_TABCOLVALUE_PAIR skewedColumnValuePair+) - ; - -skewedColumnValuePair -@init { pushMsg("column value pair", state); } -@after { popMsg(state); } - : - LPAREN colValues=skewedColumnValues RPAREN - -> ^(TOK_TABCOLVALUES $colValues) - ; - -skewedColumnValues -@init { pushMsg("column values", state); } -@after { popMsg(state); } - : skewedColumnValue (COMMA 
skewedColumnValue)* -> ^(TOK_TABCOLVALUE skewedColumnValue+) - ; - -skewedColumnValue -@init { pushMsg("column value", state); } -@after { popMsg(state); } - : - constant - ; - -skewedValueLocationElement -@init { pushMsg("skewed value location element", state); } -@after { popMsg(state); } - : - skewedColumnValue - | skewedColumnValuePair - ; - -columnNameOrder -@init { pushMsg("column name order", state); } -@after { popMsg(state); } - : identifier (asc=KW_ASC | desc=KW_DESC)? - -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC identifier) - -> ^(TOK_TABSORTCOLNAMEDESC identifier) - ; - -columnNameCommentList -@init { pushMsg("column name comment list", state); } -@after { popMsg(state); } - : columnNameComment (COMMA columnNameComment)* -> ^(TOK_TABCOLNAME columnNameComment+) - ; - -columnNameComment -@init { pushMsg("column name comment", state); } -@after { popMsg(state); } - : colName=identifier (KW_COMMENT comment=StringLiteral)? - -> ^(TOK_TABCOL $colName TOK_NULL $comment?) - ; - -columnRefOrder -@init { pushMsg("column order", state); } -@after { popMsg(state); } - : expression (asc=KW_ASC | desc=KW_DESC)? - -> {$desc == null}? ^(TOK_TABSORTCOLNAMEASC expression) - -> ^(TOK_TABSORTCOLNAMEDESC expression) - ; - -columnNameType -@init { pushMsg("column specification", state); } -@after { popMsg(state); } - : colName=identifier colType (KW_COMMENT comment=StringLiteral)? - -> {containExcludedCharForCreateTableColumnName($colName.text)}? {throwColumnNameException()} - -> {$comment == null}? ^(TOK_TABCOL $colName colType) - -> ^(TOK_TABCOL $colName colType $comment) - ; - -columnNameColonType -@init { pushMsg("column specification", state); } -@after { popMsg(state); } - : colName=identifier COLON colType (KW_COMMENT comment=StringLiteral)? - -> {$comment == null}? ^(TOK_TABCOL $colName colType) - -> ^(TOK_TABCOL $colName colType $comment) - ; - -colType -@init { pushMsg("column type", state); } -@after { popMsg(state); } - : type - ; - -colTypeList -@init { pushMsg("column type list", state); } -@after { popMsg(state); } - : colType (COMMA colType)* -> ^(TOK_COLTYPELIST colType+) - ; - -type - : primitiveType - | listType - | structType - | mapType - | unionType; - -primitiveType -@init { pushMsg("primitive type specification", state); } -@after { popMsg(state); } - : KW_TINYINT -> TOK_TINYINT - | KW_SMALLINT -> TOK_SMALLINT - | KW_INT -> TOK_INT - | KW_BIGINT -> TOK_BIGINT - | KW_LONG -> TOK_BIGINT - | KW_BOOLEAN -> TOK_BOOLEAN - | KW_FLOAT -> TOK_FLOAT - | KW_DOUBLE -> TOK_DOUBLE - | KW_DATE -> TOK_DATE - | KW_DATETIME -> TOK_DATETIME - | KW_TIMESTAMP -> TOK_TIMESTAMP - // Uncomment to allow intervals as table column types - //| KW_INTERVAL KW_YEAR KW_TO KW_MONTH -> TOK_INTERVAL_YEAR_MONTH - //| KW_INTERVAL KW_DAY KW_TO KW_SECOND -> TOK_INTERVAL_DAY_TIME - | KW_STRING -> TOK_STRING - | KW_BINARY -> TOK_BINARY - | KW_DECIMAL (LPAREN prec=Number (COMMA scale=Number)? RPAREN)? -> ^(TOK_DECIMAL $prec? $scale?) 
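The row format, table property, and type rules above are exercised by classic delimited-text DDL. One illustrative statement tying them together, including `DECIMAL(precision, scale)`; the file layout and column names are invented:

```scala
// Illustrative DDL using rowFormatDelimited, tableFileFormat, tableLocation and the
// primitive type rules above. Names and paths are hypothetical.
val delimitedTable =
  """CREATE TABLE rates (
    |  currency STRING COMMENT 'ISO code',
    |  rate DECIMAL(10, 4),
    |  updated TIMESTAMP
    |)
    |ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' NULL DEFINED AS '\\N'
    |STORED AS TEXTFILE
    |LOCATION '/data/rates'""".stripMargin
```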
- | KW_VARCHAR LPAREN length=Number RPAREN -> ^(TOK_VARCHAR $length) - | KW_CHAR LPAREN length=Number RPAREN -> ^(TOK_CHAR $length) - ; - -listType -@init { pushMsg("list type", state); } -@after { popMsg(state); } - : KW_ARRAY LESSTHAN type GREATERTHAN -> ^(TOK_LIST type) - ; - -structType -@init { pushMsg("struct type", state); } -@after { popMsg(state); } - : KW_STRUCT LESSTHAN columnNameColonTypeList GREATERTHAN -> ^(TOK_STRUCT columnNameColonTypeList) - ; - -mapType -@init { pushMsg("map type", state); } -@after { popMsg(state); } - : KW_MAP LESSTHAN left=type COMMA right=type GREATERTHAN - -> ^(TOK_MAP $left $right) - ; - -unionType -@init { pushMsg("uniontype type", state); } -@after { popMsg(state); } - : KW_UNIONTYPE LESSTHAN colTypeList GREATERTHAN -> ^(TOK_UNIONTYPE colTypeList) - ; - -setOperator -@init { pushMsg("set operator", state); } -@after { popMsg(state); } - : KW_UNION KW_ALL -> ^(TOK_UNIONALL) - | KW_UNION KW_DISTINCT? -> ^(TOK_UNIONDISTINCT) - | KW_EXCEPT -> ^(TOK_EXCEPT) - | KW_INTERSECT -> ^(TOK_INTERSECT) - ; - -queryStatementExpression[boolean topLevel] - : - /* Would be nice to do this as a gated semantic perdicate - But the predicate gets pushed as a lookahead decision. - Calling rule doesnot know about topLevel - */ - (w=withClause {topLevel}?)? - queryStatementExpressionBody[topLevel] { - if ($w.tree != null) { - $queryStatementExpressionBody.tree.insertChild(0, $w.tree); - } - } - -> queryStatementExpressionBody - ; - -queryStatementExpressionBody[boolean topLevel] - : - fromStatement[topLevel] - | regularBody[topLevel] - ; - -withClause - : - KW_WITH cteStatement (COMMA cteStatement)* -> ^(TOK_CTE cteStatement+) -; - -cteStatement - : - identifier KW_AS LPAREN queryStatementExpression[false] RPAREN - -> ^(TOK_SUBQUERY queryStatementExpression identifier) -; - -fromStatement[boolean topLevel] -: (singleFromStatement -> singleFromStatement) - (u=setOperator r=singleFromStatement - -> ^($u {$fromStatement.tree} $r) - )* - -> {u != null && topLevel}? ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - {$fromStatement.tree} - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) - ) - ) - -> {$fromStatement.tree} - ; - - -singleFromStatement - : - fromClause - ( b+=body )+ -> ^(TOK_QUERY fromClause body+) - ; - -/* -The valuesClause rule below ensures that the parse tree for -"insert into table FOO values (1,2),(3,4)" looks the same as -"insert into table FOO select a,b from (values(1,2),(3,4)) as BAR(a,b)" which itself is made to look -very similar to the tree for "insert into table FOO select a,b from BAR". Since virtual table name -is implicit, it's represented as TOK_ANONYMOUS. -*/ -regularBody[boolean topLevel] - : - i=insertClause - ( - s=selectStatement[topLevel] - {$s.tree.getFirstChildWithType(TOK_INSERT).replaceChildren(0, 0, $i.tree);} -> {$s.tree} - | - valuesClause - -> ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_VIRTUAL_TABLE ^(TOK_VIRTUAL_TABREF ^(TOK_ANONYMOUS)) valuesClause) - ) - ^(TOK_INSERT {$i.tree} ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF))) - ) - ) - | - selectStatement[topLevel] - ; - -selectStatement[boolean topLevel] - : - ( - ( - LPAREN - s=selectClause - f=fromClause? - w=whereClause? - g=groupByClause? - h=havingClause? - o=orderByClause? - c=clusterByClause? - d=distributeByClause? - sort=sortByClause? - win=window_clause? - l=limitClause? - RPAREN - | - s=selectClause - f=fromClause? - w=whereClause? - g=groupByClause? - h=havingClause? 
- o=orderByClause? - c=clusterByClause? - d=distributeByClause? - sort=sortByClause? - win=window_clause? - l=limitClause? - ) - -> ^(TOK_QUERY $f? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - $s $w? $g? $h? $o? $c? - $d? $sort? $win? $l?)) - ) - (set=setOpSelectStatement[$selectStatement.tree, topLevel])? - -> {set == null}? - {$selectStatement.tree} - -> {o==null && c==null && d==null && sort==null && l==null}? - {$set.tree} - -> {throwSetOpException()} - ; - -setOpSelectStatement[CommonTree t, boolean topLevel] - : - (( - u=setOperator LPAREN b=simpleSelectStatement RPAREN - | - u=setOperator b=simpleSelectStatement) - -> {$setOpSelectStatement.tree != null}? - ^($u {$setOpSelectStatement.tree} $b) - -> ^($u {$t} $b) - )+ - o=orderByClause? - c=clusterByClause? - d=distributeByClause? - sort=sortByClause? - win=window_clause? - l=limitClause? - -> {o==null && c==null && d==null && sort==null && win==null && l==null && !topLevel}? - {$setOpSelectStatement.tree} - -> ^(TOK_QUERY - ^(TOK_FROM - ^(TOK_SUBQUERY - {$setOpSelectStatement.tree} - {adaptor.create(Identifier, generateUnionAlias())} - ) - ) - ^(TOK_INSERT - ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - ^(TOK_SELECT ^(TOK_SELEXPR TOK_ALLCOLREF)) - $o? $c? $d? $sort? $win? $l? - ) - ) - ; - -simpleSelectStatement - : - selectClause - fromClause? - whereClause? - groupByClause? - havingClause? - ((window_clause) => window_clause)? - -> ^(TOK_QUERY fromClause? ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - selectClause whereClause? groupByClause? havingClause? window_clause?)) - ; - -selectStatementWithCTE - : - (w=withClause)? - selectStatement[true] { - if ($w.tree != null) { - $selectStatement.tree.insertChild(0, $w.tree); - } - } - -> selectStatement - ; - -body - : - insertClause - selectClause - lateralView? - whereClause? - groupByClause? - havingClause? - orderByClause? - clusterByClause? - distributeByClause? - sortByClause? - window_clause? - limitClause? -> ^(TOK_INSERT insertClause - selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? - distributeByClause? sortByClause? window_clause? limitClause?) - | - selectClause - lateralView? - whereClause? - groupByClause? - havingClause? - orderByClause? - clusterByClause? - distributeByClause? - sortByClause? - window_clause? - limitClause? -> ^(TOK_INSERT ^(TOK_DESTINATION ^(TOK_DIR TOK_TMP_FILE)) - selectClause lateralView? whereClause? groupByClause? havingClause? orderByClause? clusterByClause? - distributeByClause? sortByClause? window_clause? limitClause?) - ; - -insertClause -@init { pushMsg("insert clause", state); } -@after { popMsg(state); } - : - KW_INSERT KW_OVERWRITE destination ifNotExists? -> ^(TOK_DESTINATION destination ifNotExists?) - | KW_INSERT KW_INTO KW_TABLE? tableOrPartition (LPAREN targetCols=columnNameList RPAREN)? - -> ^(TOK_INSERT_INTO tableOrPartition $targetCols?) - ; - -destination -@init { pushMsg("destination specification", state); } -@after { popMsg(state); } - : - (local = KW_LOCAL)? KW_DIRECTORY StringLiteral tableRowFormat? tableFileFormat? - -> ^(TOK_DIR StringLiteral $local? tableRowFormat? tableFileFormat?) - | KW_TABLE tableOrPartition -> tableOrPartition - ; - -limitClause -@init { pushMsg("limit clause", state); } -@after { popMsg(state); } - : - KW_LIMIT num=Number -> ^(TOK_LIMIT $num) - ; - -//DELETE FROM WHERE ...; -deleteStatement -@init { pushMsg("delete statement", state); } -@after { popMsg(state); } - : - KW_DELETE KW_FROM tableName (whereClause)? 
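The query-side rules above compose: a `withClause` can prefix an insert, set operators chain `simpleSelectStatement`s, and ordering/limit clauses attach at the top level. Two illustrative statements of the shapes they accept (table and column names are invented):

```scala
// Illustrative query shapes for withClause, insertClause, setOperator and limitClause.
val insertWithCte =
  """WITH recent AS (SELECT * FROM events WHERE ds >= '2016-03-01')
    |INSERT OVERWRITE TABLE event_summary
    |SELECT category, count(*) FROM recent GROUP BY category
    |LIMIT 100""".stripMargin

val unioned =
  """SELECT id FROM t1
    |UNION ALL
    |SELECT id FROM t2
    |ORDER BY id LIMIT 10""".stripMargin
```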
-> ^(TOK_DELETE_FROM tableName whereClause?) - ; - -/*SET <colName> = (3 + col2)*/ -columnAssignmentClause - : - tableOrColumn EQUAL^ precedencePlusExpression - ; - -/*SET col1 = 5, col2 = (4 + col4), ...*/ -setColumnsClause - : - KW_SET columnAssignmentClause (COMMA columnAssignmentClause)* -> ^(TOK_SET_COLUMNS_CLAUSE columnAssignmentClause* ) - ; - -/* - UPDATE <table>
- SET col1 = val1, col2 = val2... WHERE ... -*/ -updateStatement -@init { pushMsg("update statement", state); } -@after { popMsg(state); } - : - KW_UPDATE tableName setColumnsClause whereClause? -> ^(TOK_UPDATE_TABLE tableName setColumnsClause whereClause?) - ; - -/* -BEGIN user defined transaction boundaries; follows SQL 2003 standard exactly except for addition of -"setAutoCommitStatement" which is not in the standard doc but is supported by most SQL engines. -*/ -sqlTransactionStatement -@init { pushMsg("transaction statement", state); } -@after { popMsg(state); } - : startTransactionStatement - | commitStatement - | rollbackStatement - | setAutoCommitStatement - ; - -startTransactionStatement - : - KW_START KW_TRANSACTION ( transactionMode ( COMMA transactionMode )* )? -> ^(TOK_START_TRANSACTION transactionMode*) - ; - -transactionMode - : - isolationLevel - | transactionAccessMode -> ^(TOK_TXN_ACCESS_MODE transactionAccessMode) - ; - -transactionAccessMode - : - KW_READ KW_ONLY -> TOK_TXN_READ_ONLY - | KW_READ KW_WRITE -> TOK_TXN_READ_WRITE - ; - -isolationLevel - : - KW_ISOLATION KW_LEVEL levelOfIsolation -> ^(TOK_ISOLATION_LEVEL levelOfIsolation) - ; - -/*READ UNCOMMITTED | READ COMMITTED | REPEATABLE READ | SERIALIZABLE may be supported later*/ -levelOfIsolation - : - KW_SNAPSHOT -> TOK_ISOLATION_SNAPSHOT - ; - -commitStatement - : - KW_COMMIT ( KW_WORK )? -> TOK_COMMIT - ; - -rollbackStatement - : - KW_ROLLBACK ( KW_WORK )? -> TOK_ROLLBACK - ; -setAutoCommitStatement - : - KW_SET KW_AUTOCOMMIT booleanValueTok -> ^(TOK_SET_AUTOCOMMIT booleanValueTok) - ; -/* -END user defined transaction boundaries -*/ - -/* -Table Caching statements. - */ -cacheStatement -@init { pushMsg("cache statement", state); } -@after { popMsg(state); } - : - cacheTableStatement - | uncacheTableStatement - | clearCacheStatement - ; - -cacheTableStatement - : - KW_CACHE (lazy=KW_LAZY)? KW_TABLE identifier (KW_AS selectStatementWithCTE)? -> ^(TOK_CACHETABLE identifier $lazy? selectStatementWithCTE?) - ; - -uncacheTableStatement - : - KW_UNCACHE KW_TABLE identifier -> ^(TOK_UNCACHETABLE identifier) - ; - -clearCacheStatement - : - KW_CLEAR KW_CACHE -> ^(TOK_CLEARCACHE) - ; - diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 new file mode 100644 index 0000000000..3b9f82a80f --- /dev/null +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -0,0 +1,943 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. 
+ */ + +grammar SqlBase; + +tokens { + DELIMITER +} + +singleStatement + : statement EOF + ; + +singleExpression + : namedExpression EOF + ; + +singleTableIdentifier + : tableIdentifier EOF + ; + +singleDataType + : dataType EOF + ; + +statement + : query #statementDefault + | USE db=identifier #use + | CREATE DATABASE (IF NOT EXISTS)? identifier + (COMMENT comment=STRING)? locationSpec? + (WITH DBPROPERTIES tablePropertyList)? #createDatabase + | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties + | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase + | createTableHeader ('(' colTypeList ')')? tableProvider + (OPTIONS tablePropertyList)? #createTableUsing + | createTableHeader tableProvider + (OPTIONS tablePropertyList)? AS? query #createTableUsing + | createTableHeader ('(' columns=colTypeList ')')? + (COMMENT STRING)? + (PARTITIONED BY '(' partitionColumns=colTypeList ')')? + bucketSpec? skewSpec? + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + (AS? query)? #createTable + | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS + (identifier | FOR COLUMNS identifierSeq?)? #analyze + | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable + | ALTER TABLE tableIdentifier + SET TBLPROPERTIES tablePropertyList #setTableProperties + | ALTER TABLE tableIdentifier + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe + | ALTER TABLE tableIdentifier (partitionSpec)? + SET SERDEPROPERTIES tablePropertyList #setTableSerDe + | ALTER TABLE tableIdentifier bucketSpec #bucketTable + | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable + | ALTER TABLE tableIdentifier NOT SORTED #unsortTable + | ALTER TABLE tableIdentifier skewSpec #skewTable + | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable + | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable + | ALTER TABLE tableIdentifier + SET SKEWED LOCATION skewedLocationList #setTableSkewLocations + | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? + partitionSpecLocation+ #addTablePartition + | ALTER TABLE tableIdentifier + from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition + | ALTER TABLE from=tableIdentifier + EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition + | ALTER TABLE tableIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition + | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition + | ALTER TABLE tableIdentifier partitionSpec? + SET FILEFORMAT fileFormat #setTableFileFormat + | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation + | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable + | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable + | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable + | ALTER TABLE tableIdentifier partitionSpec? + CHANGE COLUMN? oldName=identifier colType + (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn + | ALTER TABLE tableIdentifier partitionSpec? + ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns + | ALTER TABLE tableIdentifier partitionSpec? + REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? 
#replaceColumns + | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? + (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier + identifierCommentList? (COMMENT STRING)? + (PARTITIONED ON identifierList)? + (TBLPROPERTIES tablePropertyList)? AS query #createView + | ALTER VIEW tableIdentifier AS? query #alterViewQuery + | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING + (USING resource (',' resource)*)? #createFunction + | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction + | EXPLAIN explainOption* statement #explain + | SHOW TABLES ((FROM | IN) db=identifier)? + (LIKE (qualifiedName | pattern=STRING))? #showTables + | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions + | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction + | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? + tableIdentifier partitionSpec? describeColName? #describeTable + | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase + | REFRESH TABLE tableIdentifier #refreshTable + | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable + | UNCACHE TABLE identifier #uncacheTable + | CLEAR CACHE #clearCache + | ADD identifier .*? #addResource + | SET ROLE .*? #failNativeCommand + | SET .*? #setConfiguration + | kws=unsupportedHiveNativeCommands .*? #failNativeCommand + | hiveNativeCommands #executeNativeCommand + ; + +hiveNativeCommands + : createTableHeader LIKE tableIdentifier + rowFormat? createFileFormat? locationSpec? + (TBLPROPERTIES tablePropertyList)? + | DELETE FROM tableIdentifier (WHERE booleanExpression)? + | TRUNCATE TABLE tableIdentifier partitionSpec? + (COLUMNS identifierList)? + | ALTER VIEW from=tableIdentifier AS? RENAME TO to=tableIdentifier + | ALTER VIEW from=tableIdentifier AS? + SET TBLPROPERTIES tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList + | ALTER VIEW from=tableIdentifier AS? + ADD (IF NOT EXISTS)? partitionSpecLocation+ + | ALTER VIEW from=tableIdentifier AS? + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? + | DROP VIEW (IF EXISTS)? qualifiedName + | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? + | START TRANSACTION (transactionMode (',' transactionMode)*)? + | COMMIT WORK? + | ROLLBACK WORK? + | SHOW PARTITIONS tableIdentifier partitionSpec? + | DFS .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOCK | UNLOCK | MSCK | LOAD) .*? + ; + +unsupportedHiveNativeCommands + : kw1=CREATE kw2=ROLE + | kw1=DROP kw2=ROLE + | kw1=GRANT kw2=ROLE? + | kw1=REVOKE kw2=ROLE? + | kw1=SHOW kw2=GRANT + | kw1=SHOW kw2=ROLE kw3=GRANT? + | kw1=SHOW kw2=PRINCIPALS + | kw1=SHOW kw2=ROLES + | kw1=SHOW kw2=CURRENT kw3=ROLES + | kw1=EXPORT kw2=TABLE + | kw1=IMPORT kw2=TABLE + | kw1=SHOW kw2=COMPACTIONS + | kw1=SHOW kw2=CREATE kw3=TABLE + | kw1=SHOW kw2=TRANSACTIONS + | kw1=SHOW kw2=INDEXES + | kw1=SHOW kw2=LOCKS + ; + +createTableHeader + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier + ; + +bucketSpec + : CLUSTERED BY identifierList + (SORTED BY orderedIdentifierList)? + INTO INTEGER_VALUE BUCKETS + ; + +skewSpec + : SKEWED BY identifierList + ON (constantList | nestedConstantList) + (STORED AS DIRECTORIES)? + ; + +locationSpec + : LOCATION STRING + ; + +query + : ctes? queryNoWith + ; + +insertInto + : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)? + | INSERT INTO TABLE? tableIdentifier partitionSpec? 
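In the new ANTLR 4 grammar, each labeled alternative (the `#name` suffixes above) gets its own context class and visitor method in the generated parser, which is how the Scala side can dispatch on statement type. For example, statements of the following shapes should land in the `#createTableUsing` and `#renameTable` alternatives; the names and path are illustrative:

```scala
// Illustrative statement shapes for two labeled alternatives above (hypothetical names).
val createUsing =
  """CREATE TABLE points
    |USING org.apache.spark.sql.parquet
    |OPTIONS (path '/tmp/points')""".stripMargin   // #createTableUsing

val renameTable = "ALTER TABLE points RENAME TO points_backup"  // #renameTable
```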
+ ; + +partitionSpecLocation + : partitionSpec locationSpec? + ; + +partitionSpec + : PARTITION '(' partitionVal (',' partitionVal)* ')' + ; + +partitionVal + : identifier (EQ constant)? + ; + +describeColName + : identifier ('.' (identifier | STRING))* + ; + +ctes + : WITH namedQuery (',' namedQuery)* + ; + +namedQuery + : name=identifier AS? '(' queryNoWith ')' + ; + +tableProvider + : USING qualifiedName + ; + +tablePropertyList + : '(' tableProperty (',' tableProperty)* ')' + ; + +tableProperty + : key=tablePropertyKey (EQ? value=STRING)? + ; + +tablePropertyKey + : looseIdentifier ('.' looseIdentifier)* + | STRING + ; + +constantList + : '(' constant (',' constant)* ')' + ; + +nestedConstantList + : '(' constantList (',' constantList)* ')' + ; + +skewedLocation + : (constant | constantList) EQ STRING + ; + +skewedLocationList + : '(' skewedLocation (',' skewedLocation)* ')' + ; + +createFileFormat + : STORED AS fileFormat + | STORED BY storageHandler + ; + +fileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? + (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat + | identifier #genericFileFormat + ; + +storageHandler + : STRING (WITH SERDEPROPERTIES tablePropertyList)? + ; + +resource + : identifier STRING + ; + +queryNoWith + : insertInto? queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery + ; + +queryOrganization + : (ORDER BY order+=sortItem (',' order+=sortItem)*)? + (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? + (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? + (SORT BY sort+=sortItem (',' sort+=sortItem)*)? + windows? + (LIMIT limit=expression)? + ; + +multiInsertQueryBody + : insertInto? + querySpecification + queryOrganization + ; + +queryTerm + : queryPrimary #queryTermDefault + | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation + ; + +queryPrimary + : querySpecification #queryPrimaryDefault + | TABLE tableIdentifier #table + | inlineTable #inlineTableDefault1 + | '(' queryNoWith ')' #subquery + ; + +sortItem + : expression ordering=(ASC | DESC)? + ; + +querySpecification + : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')' + | kind=MAP namedExpressionSeq + | kind=REDUCE namedExpressionSeq)) + inRowFormat=rowFormat? + (RECORDWRITER recordWriter=STRING)? + USING script=STRING + (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? + outRowFormat=rowFormat? + (RECORDREADER recordReader=STRING)? + fromClause? + (WHERE where=booleanExpression)?) + | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? + | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) + lateralView* + (WHERE where=booleanExpression)? + aggregation? + (HAVING having=booleanExpression)? + windows?) + ; + +fromClause + : FROM relation (',' relation)* lateralView* + ; + +aggregation + : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( + WITH kind=ROLLUP + | WITH kind=CUBE + | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? + ; + +groupingSet + : '(' (expression (',' expression)*)? ')' + | expression + ; + +lateralView + : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? + ; + +setQuantifier + : DISTINCT + | ALL + ; + +relation + : left=relation + ((CROSS | joinType) JOIN right=relation joinCriteria? 
+ | NATURAL joinType JOIN right=relation + ) #joinRelation + | relationPrimary #relationDefault + ; + +joinType + : INNER? + | LEFT OUTER? + | LEFT SEMI + | RIGHT OUTER? + | FULL OUTER? + ; + +joinCriteria + : ON booleanExpression + | USING '(' identifier (',' identifier)* ')' + ; + +sample + : TABLESAMPLE '(' + ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) + | (expression sampleType=ROWS) + | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?)) + ')' + ; + +identifierList + : '(' identifierSeq ')' + ; + +identifierSeq + : identifier (',' identifier)* + ; + +orderedIdentifierList + : '(' orderedIdentifier (',' orderedIdentifier)* ')' + ; + +orderedIdentifier + : identifier ordering=(ASC | DESC)? + ; + +identifierCommentList + : '(' identifierComment (',' identifierComment)* ')' + ; + +identifierComment + : identifier (COMMENT STRING)? + ; + +relationPrimary + : tableIdentifier sample? (AS? identifier)? #tableName + | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery + | '(' relation ')' sample? (AS? identifier)? #aliasedRelation + | inlineTable #inlineTableDefault2 + ; + +inlineTable + : VALUES expression (',' expression)* (AS? identifier identifierList?)? + ; + +rowFormat + : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde + | ROW FORMAT DELIMITED + (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? + (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? + (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? + (LINES TERMINATED BY linesSeparatedBy=STRING)? + (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited + ; + +tableIdentifier + : (db=identifier '.')? table=identifier + ; + +namedExpression + : expression (AS? (identifier | identifierList))? + ; + +namedExpressionSeq + : namedExpression (',' namedExpression)* + ; + +expression + : booleanExpression + ; + +booleanExpression + : predicated #booleanDefault + | NOT booleanExpression #logicalNot + | left=booleanExpression operator=AND right=booleanExpression #logicalBinary + | left=booleanExpression operator=OR right=booleanExpression #logicalBinary + | EXISTS '(' query ')' #exists + ; + +// workaround for: +// https://github.com/antlr/antlr4/issues/780 +// https://github.com/antlr/antlr4/issues/781 +predicated + : valueExpression predicate[$valueExpression.ctx]? + ; + +predicate[ParserRuleContext value] + : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between + | NOT? IN '(' expression (',' expression)* ')' #inList + | NOT? IN '(' query ')' #inSubquery + | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like + | IS NOT? NULL #nullPredicate + ; + +valueExpression + : primaryExpression #valueExpressionDefault + | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary + | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary + | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary + | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary + | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary + | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary + | left=valueExpression comparisonOperator right=valueExpression #comparison + ; + +primaryExpression + : constant #constantDefault + | ASTERISK #star + | qualifiedName '.' 
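The `predicated` indirection above (the workaround noted for ANTLR issues #780/#781) lets comparison operators and `BETWEEN`/`IN`/`LIKE`/`IS NULL` all hang off the same `valueExpression` on the left. An expression of the shape these rules accept, with invented column names:

```scala
// Illustrative boolean expression exercising the predicate/valueExpression rules above
// (hypothetical columns): arithmetic, comparison, BETWEEN, LIKE and IS NOT NULL.
val filterExpr =
  "amount % 2 = 0 AND ds BETWEEN '2016-03-01' AND '2016-03-28' " +
  "AND name LIKE 'a%' AND comment IS NOT NULL"
```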
ASTERISK #star + | '(' expression (',' expression)+ ')' #rowConstructor + | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall + | '(' query ')' #subqueryExpression + | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase + | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase + | CAST '(' expression AS dataType ')' #cast + | value=primaryExpression '[' index=valueExpression ']' #subscript + | identifier #columnReference + | base=primaryExpression '.' fieldName=identifier #dereference + | '(' expression ')' #parenthesizedExpression + ; + +constant + : NULL #nullLiteral + | interval #intervalLiteral + | identifier STRING #typeConstructor + | number #numericLiteral + | booleanValue #booleanLiteral + | STRING+ #stringLiteral + ; + +comparisonOperator + : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ + ; + +booleanValue + : TRUE | FALSE + ; + +interval + : INTERVAL intervalField* + ; + +intervalField + : value=intervalValue unit=identifier (TO to=identifier)? + ; + +intervalValue + : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE) + | STRING + ; + +dataType + : complex=ARRAY '<' dataType '>' #complexDataType + | complex=MAP '<' dataType ',' dataType '>' #complexDataType + | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType + | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType + ; + +colTypeList + : colType (',' colType)* + ; + +colType + : identifier ':'? dataType (COMMENT STRING)? + ; + +whenClause + : WHEN condition=expression THEN result=expression + ; + +windows + : WINDOW namedWindow (',' namedWindow)* + ; + +namedWindow + : identifier AS windowSpec + ; + +windowSpec + : name=identifier #windowRef + | '(' + ( CLUSTER BY partition+=expression (',' partition+=expression)* + | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? + ((ORDER | SORT) BY sortItem (',' sortItem)*)?) + windowFrame? + ')' #windowDef + ; + +windowFrame + : frameType=RANGE start=frameBound + | frameType=ROWS start=frameBound + | frameType=RANGE BETWEEN start=frameBound AND end=frameBound + | frameType=ROWS BETWEEN start=frameBound AND end=frameBound + ; + +frameBound + : UNBOUNDED boundType=(PRECEDING | FOLLOWING) + | boundType=CURRENT ROW + | expression boundType=(PRECEDING | FOLLOWING) + ; + + +explainOption + : LOGICAL | FORMATTED | EXTENDED + ; + +transactionMode + : ISOLATION LEVEL SNAPSHOT #isolationLevel + | READ accessMode=(ONLY | WRITE) #transactionAccessMode + ; + +qualifiedName + : identifier ('.' identifier)* + ; + +// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility). 
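For orientation, the query below is one concrete statement accepted by the expression, windowing, and CASE rules above; the table and column names (T, K, V) are made up, and the inline SQL comments name the grammar alternatives that match each fragment. This is an illustrative sketch only, not part of the patch.

```scala
// Hypothetical query exercising the rules defined above (names T, K, V are made up).
val windowedQuery: String =
  """SELECT K,
    |       CAST(V AS DOUBLE),                          -- #cast
    |       SUM(V) OVER W,                              -- #functionCall with OVER + #windowRef
    |       CASE WHEN V > 0 THEN 'POS' ELSE 'NEG' END   -- #searchedCase
    |FROM T
    |WINDOW W AS (PARTITION BY K ORDER BY V
    |             ROWS BETWEEN 1 PRECEDING AND CURRENT ROW)  -- windowSpec + windowFrame
    |""".stripMargin
```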
+looseIdentifier + : identifier + | FROM + | TO + | TABLE + | WITH + ; + +identifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | nonReserved #unquotedIdentifier + ; + +quotedIdentifier + : BACKQUOTED_IDENTIFIER + ; + +number + : DECIMAL_VALUE #decimalLiteral + | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral + | INTEGER_VALUE #integerLiteral + | BIGINT_LITERAL #bigIntLiteral + | SMALLINT_LITERAL #smallIntLiteral + | TINYINT_LITERAL #tinyIntLiteral + | DOUBLE_LITERAL #doubleLiteral + ; + +nonReserved + : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS + | ADD + | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT + | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER + | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED + | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS + | GROUPING | CUBE | ROLLUP + | EXPLAIN | FORMAT | LOGICAL | FORMATTED + | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF + | SET + | VIEW | REPLACE + | IF + | NO | DATA + | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL + | SNAPSHOT | READ | WRITE | ONLY + | SORT | CLUSTER | DISTRIBUTE UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION + | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST + | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT + | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE + | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE + | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEXES | LOCKS | OPTION + ; + +SELECT: 'SELECT'; +FROM: 'FROM'; +ADD: 'ADD'; +AS: 'AS'; +ALL: 'ALL'; +DISTINCT: 'DISTINCT'; +WHERE: 'WHERE'; +GROUP: 'GROUP'; +BY: 'BY'; +GROUPING: 'GROUPING'; +SETS: 'SETS'; +CUBE: 'CUBE'; +ROLLUP: 'ROLLUP'; +ORDER: 'ORDER'; +HAVING: 'HAVING'; +LIMIT: 'LIMIT'; +AT: 'AT'; +OR: 'OR'; +AND: 'AND'; +IN: 'IN'; +NOT: 'NOT' | '!'; +NO: 'NO'; +EXISTS: 'EXISTS'; +BETWEEN: 'BETWEEN'; +LIKE: 'LIKE'; +RLIKE: 'RLIKE' | 'REGEXP'; +IS: 'IS'; +NULL: 'NULL'; +TRUE: 'TRUE'; +FALSE: 'FALSE'; +NULLS: 'NULLS'; +ASC: 'ASC'; +DESC: 'DESC'; +FOR: 'FOR'; +INTERVAL: 'INTERVAL'; +CASE: 'CASE'; +WHEN: 'WHEN'; +THEN: 'THEN'; +ELSE: 'ELSE'; +END: 'END'; +JOIN: 'JOIN'; +CROSS: 'CROSS'; +OUTER: 'OUTER'; +INNER: 'INNER'; +LEFT: 'LEFT'; +SEMI: 'SEMI'; +RIGHT: 'RIGHT'; +FULL: 'FULL'; +NATURAL: 'NATURAL'; +ON: 'ON'; +LATERAL: 'LATERAL'; +WINDOW: 'WINDOW'; +OVER: 'OVER'; +PARTITION: 'PARTITION'; +RANGE: 'RANGE'; +ROWS: 'ROWS'; +UNBOUNDED: 'UNBOUNDED'; +PRECEDING: 'PRECEDING'; +FOLLOWING: 'FOLLOWING'; +CURRENT: 'CURRENT'; +ROW: 'ROW'; +WITH: 'WITH'; +VALUES: 'VALUES'; +CREATE: 'CREATE'; +TABLE: 'TABLE'; +VIEW: 'VIEW'; +REPLACE: 'REPLACE'; +INSERT: 'INSERT'; +DELETE: 'DELETE'; +INTO: 'INTO'; +DESCRIBE: 'DESCRIBE'; +EXPLAIN: 'EXPLAIN'; +FORMAT: 'FORMAT'; +LOGICAL: 'LOGICAL'; +CAST: 'CAST'; +SHOW: 'SHOW'; +TABLES: 'TABLES'; +COLUMNS: 'COLUMNS'; +COLUMN: 'COLUMN'; +USE: 'USE'; +PARTITIONS: 'PARTITIONS'; +FUNCTIONS: 'FUNCTIONS'; +DROP: 'DROP'; +UNION: 'UNION'; +EXCEPT: 'EXCEPT'; +INTERSECT: 'INTERSECT'; +TO: 'TO'; +TABLESAMPLE: 'TABLESAMPLE'; +STRATIFY: 'STRATIFY'; +ALTER: 'ALTER'; +RENAME: 'RENAME'; +ARRAY: 'ARRAY'; +MAP: 'MAP'; +STRUCT: 'STRUCT'; 
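A minimal sketch of how the ANTLR-generated classes for this grammar could be exercised. SqlBaseLexer, SqlBaseParser, and the per-rule singleStatement() method follow standard ANTLR 4 code generation; the ParseSmokeTest object, the query string, and the table and column names are hypothetical. Keywords and identifiers are kept upper case because this sketch feeds the lexer directly, without any case-insensitive wrapping, and note that SORT parses as a plain column name because it is listed in nonReserved above.

```scala
import org.antlr.v4.runtime.{ANTLRInputStream, CommonTokenStream}

// Hypothetical driver for the ANTLR-generated classes (SqlBaseLexer / SqlBaseParser);
// names follow standard ANTLR 4 code generation, with one parse method per rule.
object ParseSmokeTest {
  def main(args: Array[String]): Unit = {
    // SORT appears in nonReserved above, so it can double as a column name.
    val sql = "SELECT SORT FROM T WHERE V > 1"
    val lexer = new SqlBaseLexer(new ANTLRInputStream(sql))
    val parser = new SqlBaseParser(new CommonTokenStream(lexer))
    val tree = parser.singleStatement()   // entry rule defined at the top of the grammar
    println(tree.toStringTree(parser))    // LISP-style view of the parse tree
  }
}
```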
+COMMENT: 'COMMENT'; +SET: 'SET'; +DATA: 'DATA'; +START: 'START'; +TRANSACTION: 'TRANSACTION'; +COMMIT: 'COMMIT'; +ROLLBACK: 'ROLLBACK'; +WORK: 'WORK'; +ISOLATION: 'ISOLATION'; +LEVEL: 'LEVEL'; +SNAPSHOT: 'SNAPSHOT'; +READ: 'READ'; +WRITE: 'WRITE'; +ONLY: 'ONLY'; + +IF: 'IF'; + +EQ : '=' | '=='; +NSEQ: '<=>'; +NEQ : '<>'; +NEQJ: '!='; +LT : '<'; +LTE : '<='; +GT : '>'; +GTE : '>='; + +PLUS: '+'; +MINUS: '-'; +ASTERISK: '*'; +SLASH: '/'; +PERCENT: '%'; +DIV: 'DIV'; +TILDE: '~'; +AMPERSAND: '&'; +PIPE: '|'; +HAT: '^'; + +PERCENTLIT: 'PERCENT'; +BUCKET: 'BUCKET'; +OUT: 'OUT'; +OF: 'OF'; + +SORT: 'SORT'; +CLUSTER: 'CLUSTER'; +DISTRIBUTE: 'DISTRIBUTE'; +OVERWRITE: 'OVERWRITE'; +TRANSFORM: 'TRANSFORM'; +REDUCE: 'REDUCE'; +USING: 'USING'; +SERDE: 'SERDE'; +SERDEPROPERTIES: 'SERDEPROPERTIES'; +RECORDREADER: 'RECORDREADER'; +RECORDWRITER: 'RECORDWRITER'; +DELIMITED: 'DELIMITED'; +FIELDS: 'FIELDS'; +TERMINATED: 'TERMINATED'; +COLLECTION: 'COLLECTION'; +ITEMS: 'ITEMS'; +KEYS: 'KEYS'; +ESCAPED: 'ESCAPED'; +LINES: 'LINES'; +SEPARATED: 'SEPARATED'; +FUNCTION: 'FUNCTION'; +EXTENDED: 'EXTENDED'; +REFRESH: 'REFRESH'; +CLEAR: 'CLEAR'; +CACHE: 'CACHE'; +UNCACHE: 'UNCACHE'; +LAZY: 'LAZY'; +FORMATTED: 'FORMATTED'; +TEMPORARY: 'TEMPORARY' | 'TEMP'; +OPTIONS: 'OPTIONS'; +UNSET: 'UNSET'; +TBLPROPERTIES: 'TBLPROPERTIES'; +DBPROPERTIES: 'DBPROPERTIES'; +BUCKETS: 'BUCKETS'; +SKEWED: 'SKEWED'; +STORED: 'STORED'; +DIRECTORIES: 'DIRECTORIES'; +LOCATION: 'LOCATION'; +EXCHANGE: 'EXCHANGE'; +ARCHIVE: 'ARCHIVE'; +UNARCHIVE: 'UNARCHIVE'; +FILEFORMAT: 'FILEFORMAT'; +TOUCH: 'TOUCH'; +COMPACT: 'COMPACT'; +CONCATENATE: 'CONCATENATE'; +CHANGE: 'CHANGE'; +FIRST: 'FIRST'; +AFTER: 'AFTER'; +CASCADE: 'CASCADE'; +RESTRICT: 'RESTRICT'; +CLUSTERED: 'CLUSTERED'; +SORTED: 'SORTED'; +PURGE: 'PURGE'; +INPUTFORMAT: 'INPUTFORMAT'; +OUTPUTFORMAT: 'OUTPUTFORMAT'; +INPUTDRIVER: 'INPUTDRIVER'; +OUTPUTDRIVER: 'OUTPUTDRIVER'; +DATABASE: 'DATABASE' | 'SCHEMA'; +DFS: 'DFS'; +TRUNCATE: 'TRUNCATE'; +METADATA: 'METADATA'; +REPLICATION: 'REPLICATION'; +ANALYZE: 'ANALYZE'; +COMPUTE: 'COMPUTE'; +STATISTICS: 'STATISTICS'; +PARTITIONED: 'PARTITIONED'; +EXTERNAL: 'EXTERNAL'; +DEFINED: 'DEFINED'; +REVOKE: 'REVOKE'; +GRANT: 'GRANT'; +LOCK: 'LOCK'; +UNLOCK: 'UNLOCK'; +MSCK: 'MSCK'; +EXPORT: 'EXPORT'; +IMPORT: 'IMPORT'; +LOAD: 'LOAD'; +ROLE: 'ROLE'; +ROLES: 'ROLES'; +COMPACTIONS: 'COMPACTIONS'; +PRINCIPALS: 'PRINCIPALS'; +TRANSACTIONS: 'TRANSACTIONS'; +INDEXES: 'INDEXES'; +LOCKS: 'LOCKS'; +OPTION: 'OPTION'; + +STRING + : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' + | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' + ; + +BIGINT_LITERAL + : DIGIT+ 'L' + ; + +SMALLINT_LITERAL + : DIGIT+ 'S' + ; + +TINYINT_LITERAL + : DIGIT+ 'Y' + ; + +INTEGER_VALUE + : DIGIT+ + ; + +DECIMAL_VALUE + : DIGIT+ '.' DIGIT* + | '.' DIGIT+ + ; + +SCIENTIFIC_DECIMAL_VALUE + : DIGIT+ ('.' DIGIT*)? EXPONENT + | '.' DIGIT+ EXPONENT + ; + +DOUBLE_LITERAL + : + (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' + ; + +IDENTIFIER + : (LETTER | DIGIT | '_')+ + ; + +BACKQUOTED_IDENTIFIER + : '`' ( ~'`' | '``' )* '`' + ; + +fragment EXPONENT + : 'E' [+-]? DIGIT+ + ; + +fragment DIGIT + : [0-9] + ; + +fragment LETTER + : [A-Z] + ; + +SIMPLE_COMMENT + : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) + ; + +BRACKETED_COMMENT + : '/*' .*? '*/' -> channel(HIDDEN) + ; + +WS + : [ \r\n\t]+ -> channel(HIDDEN) + ; + +// Catch-all for anything we can't recognize. +// We use this to be able to ignore and recover all the text +// when splitting statements with DelimiterLexer +UNRECOGNIZED + : . 
+ ; diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 deleted file mode 100644 index 4e77b6db25..0000000000 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/ng/SqlBase.g4 +++ /dev/null @@ -1,941 +0,0 @@ -/* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - * This file is an adaptation of Presto's presto-parser/src/main/antlr4/com/facebook/presto/sql/parser/SqlBase.g4 grammar. - */ - -grammar SqlBase; - -tokens { - DELIMITER -} - -singleStatement - : statement EOF - ; - -singleExpression - : namedExpression EOF - ; - -singleTableIdentifier - : tableIdentifier EOF - ; - -singleDataType - : dataType EOF - ; - -statement - : query #statementDefault - | USE db=identifier #use - | CREATE DATABASE (IF NOT EXISTS)? identifier - (COMMENT comment=STRING)? locationSpec? - (WITH DBPROPERTIES tablePropertyList)? #createDatabase - | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties - | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase - | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS tablePropertyList)? #createTableUsing - | createTableHeader tableProvider - (OPTIONS tablePropertyList)? AS? query #createTableUsing - | createTableHeader ('(' colTypeList ')')? (COMMENT STRING)? - (PARTITIONED BY identifierList)? bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? - (AS? query)? #createTable - | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS - (identifier | FOR COLUMNS identifierSeq?) #analyze - | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable - | ALTER TABLE tableIdentifier - SET TBLPROPERTIES tablePropertyList #setTableProperties - | ALTER TABLE tableIdentifier - UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties - | ALTER TABLE tableIdentifier (partitionSpec)? - SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)? #setTableSerDe - | ALTER TABLE tableIdentifier (partitionSpec)? - SET SERDEPROPERTIES tablePropertyList #setTableSerDe - | ALTER TABLE tableIdentifier bucketSpec #bucketTable - | ALTER TABLE tableIdentifier NOT CLUSTERED #unclusterTable - | ALTER TABLE tableIdentifier NOT SORTED #unsortTable - | ALTER TABLE tableIdentifier skewSpec #skewTable - | ALTER TABLE tableIdentifier NOT SKEWED #unskewTable - | ALTER TABLE tableIdentifier NOT STORED AS DIRECTORIES #unstoreTable - | ALTER TABLE tableIdentifier - SET SKEWED LOCATION skewedLocationList #setTableSkewLocations - | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? - partitionSpecLocation+ #addTablePartition - | ALTER TABLE tableIdentifier - from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition - | ALTER TABLE from=tableIdentifier - EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition - | ALTER TABLE tableIdentifier - DROP (IF EXISTS)? 
partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions - | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition - | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition - | ALTER TABLE tableIdentifier partitionSpec? - SET FILEFORMAT fileFormat #setTableFileFormat - | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation - | ALTER TABLE tableIdentifier TOUCH partitionSpec? #touchTable - | ALTER TABLE tableIdentifier partitionSpec? COMPACT STRING #compactTable - | ALTER TABLE tableIdentifier partitionSpec? CONCATENATE #concatenateTable - | ALTER TABLE tableIdentifier partitionSpec? - CHANGE COLUMN? oldName=identifier colType - (FIRST | AFTER after=identifier)? (CASCADE | RESTRICT)? #changeColumn - | ALTER TABLE tableIdentifier partitionSpec? - ADD COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #addColumns - | ALTER TABLE tableIdentifier partitionSpec? - REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns - | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? - (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable - | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier - identifierCommentList? (COMMENT STRING)? - (PARTITIONED ON identifierList)? - (TBLPROPERTIES tablePropertyList)? AS query #createView - | ALTER VIEW tableIdentifier AS? query #alterViewQuery - | CREATE TEMPORARY? FUNCTION qualifiedName AS className=STRING - (USING resource (',' resource)*)? #createFunction - | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction - | EXPLAIN explainOption* statement #explain - | SHOW TABLES ((FROM | IN) db=identifier)? - (LIKE (qualifiedName | pattern=STRING))? #showTables - | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions - | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction - | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? - tableIdentifier partitionSpec? describeColName? #describeTable - | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase - | REFRESH TABLE tableIdentifier #refreshTable - | CACHE LAZY? TABLE identifier (AS? query)? #cacheTable - | UNCACHE TABLE identifier #uncacheTable - | CLEAR CACHE #clearCache - | ADD identifier .*? #addResource - | SET ROLE .*? #failNativeCommand - | SET .*? #setConfiguration - | kws=unsupportedHiveNativeCommands .*? #failNativeCommand - | hiveNativeCommands #executeNativeCommand - ; - -hiveNativeCommands - : createTableHeader LIKE tableIdentifier - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? - | DELETE FROM tableIdentifier (WHERE booleanExpression)? - | TRUNCATE TABLE tableIdentifier partitionSpec? - (COLUMNS identifierList)? - | ALTER VIEW from=tableIdentifier AS? RENAME TO to=tableIdentifier - | ALTER VIEW from=tableIdentifier AS? - SET TBLPROPERTIES tablePropertyList - | ALTER VIEW from=tableIdentifier AS? - UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList - | ALTER VIEW from=tableIdentifier AS? - ADD (IF NOT EXISTS)? partitionSpecLocation+ - | ALTER VIEW from=tableIdentifier AS? - DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? - | DROP VIEW (IF EXISTS)? qualifiedName - | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? - | START TRANSACTION (transactionMode (',' transactionMode)*)? - | COMMIT WORK? - | ROLLBACK WORK? - | SHOW PARTITIONS tableIdentifier partitionSpec? - | DFS .*? - | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOCK | UNLOCK | MSCK | LOAD) .*? 
- ; - -unsupportedHiveNativeCommands - : kw1=CREATE kw2=ROLE - | kw1=DROP kw2=ROLE - | kw1=GRANT kw2=ROLE? - | kw1=REVOKE kw2=ROLE? - | kw1=SHOW kw2=GRANT - | kw1=SHOW kw2=ROLE kw3=GRANT? - | kw1=SHOW kw2=PRINCIPALS - | kw1=SHOW kw2=ROLES - | kw1=SHOW kw2=CURRENT kw3=ROLES - | kw1=EXPORT kw2=TABLE - | kw1=IMPORT kw2=TABLE - | kw1=SHOW kw2=COMPACTIONS - | kw1=SHOW kw2=CREATE kw3=TABLE - | kw1=SHOW kw2=TRANSACTIONS - | kw1=SHOW kw2=INDEXES - | kw1=SHOW kw2=LOCKS - ; - -createTableHeader - : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier - ; - -bucketSpec - : CLUSTERED BY identifierList - (SORTED BY orderedIdentifierList)? - INTO INTEGER_VALUE BUCKETS - ; - -skewSpec - : SKEWED BY identifierList - ON (constantList | nestedConstantList) - (STORED AS DIRECTORIES)? - ; - -locationSpec - : LOCATION STRING - ; - -query - : ctes? queryNoWith - ; - -insertInto - : INSERT OVERWRITE TABLE tableIdentifier partitionSpec? (IF NOT EXISTS)? - | INSERT INTO TABLE? tableIdentifier partitionSpec? - ; - -partitionSpecLocation - : partitionSpec locationSpec? - ; - -partitionSpec - : PARTITION '(' partitionVal (',' partitionVal)* ')' - ; - -partitionVal - : identifier (EQ constant)? - ; - -describeColName - : identifier ('.' (identifier | STRING))* - ; - -ctes - : WITH namedQuery (',' namedQuery)* - ; - -namedQuery - : name=identifier AS? '(' queryNoWith ')' - ; - -tableProvider - : USING qualifiedName - ; - -tablePropertyList - : '(' tableProperty (',' tableProperty)* ')' - ; - -tableProperty - : key=tablePropertyKey (EQ? value=STRING)? - ; - -tablePropertyKey - : looseIdentifier ('.' looseIdentifier)* - | STRING - ; - -constantList - : '(' constant (',' constant)* ')' - ; - -nestedConstantList - : '(' constantList (',' constantList)* ')' - ; - -skewedLocation - : (constant | constantList) EQ STRING - ; - -skewedLocationList - : '(' skewedLocation (',' skewedLocation)* ')' - ; - -createFileFormat - : STORED AS fileFormat - | STORED BY storageHandler - ; - -fileFormat - : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? - (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat - | identifier #genericFileFormat - ; - -storageHandler - : STRING (WITH SERDEPROPERTIES tablePropertyList)? - ; - -resource - : identifier STRING - ; - -queryNoWith - : insertInto? queryTerm queryOrganization #singleInsertQuery - | fromClause multiInsertQueryBody+ #multiInsertQuery - ; - -queryOrganization - : (ORDER BY order+=sortItem (',' order+=sortItem)*)? - (CLUSTER BY clusterBy+=expression (',' clusterBy+=expression)*)? - (DISTRIBUTE BY distributeBy+=expression (',' distributeBy+=expression)*)? - (SORT BY sort+=sortItem (',' sort+=sortItem)*)? - windows? - (LIMIT limit=expression)? - ; - -multiInsertQueryBody - : insertInto? - querySpecification - queryOrganization - ; - -queryTerm - : queryPrimary #queryTermDefault - | left=queryTerm operator=(INTERSECT | UNION | EXCEPT) setQuantifier? right=queryTerm #setOperation - ; - -queryPrimary - : querySpecification #queryPrimaryDefault - | TABLE tableIdentifier #table - | inlineTable #inlineTableDefault1 - | '(' queryNoWith ')' #subquery - ; - -sortItem - : expression ordering=(ASC | DESC)? - ; - -querySpecification - : (((SELECT kind=TRANSFORM '(' namedExpressionSeq ')' - | kind=MAP namedExpressionSeq - | kind=REDUCE namedExpressionSeq)) - inRowFormat=rowFormat? - (RECORDWRITER recordWriter=STRING)? - USING script=STRING - (AS (identifierSeq | colTypeList | ('(' (identifierSeq | colTypeList) ')')))? 
- outRowFormat=rowFormat? - (RECORDREADER recordReader=STRING)? - fromClause? - (WHERE where=booleanExpression)?) - | ((kind=SELECT setQuantifier? namedExpressionSeq fromClause? - | fromClause (kind=SELECT setQuantifier? namedExpressionSeq)?) - lateralView* - (WHERE where=booleanExpression)? - aggregation? - (HAVING having=booleanExpression)? - windows?) - ; - -fromClause - : FROM relation (',' relation)* lateralView* - ; - -aggregation - : GROUP BY groupingExpressions+=expression (',' groupingExpressions+=expression)* ( - WITH kind=ROLLUP - | WITH kind=CUBE - | kind=GROUPING SETS '(' groupingSet (',' groupingSet)* ')')? - ; - -groupingSet - : '(' (expression (',' expression)*)? ')' - | expression - ; - -lateralView - : LATERAL VIEW (OUTER)? qualifiedName '(' (expression (',' expression)*)? ')' tblName=identifier (AS? colName+=identifier (',' colName+=identifier)*)? - ; - -setQuantifier - : DISTINCT - | ALL - ; - -relation - : left=relation - ((CROSS | joinType) JOIN right=relation joinCriteria? - | NATURAL joinType JOIN right=relation - ) #joinRelation - | relationPrimary #relationDefault - ; - -joinType - : INNER? - | LEFT OUTER? - | LEFT SEMI - | RIGHT OUTER? - | FULL OUTER? - ; - -joinCriteria - : ON booleanExpression - | USING '(' identifier (',' identifier)* ')' - ; - -sample - : TABLESAMPLE '(' - ( (percentage=(INTEGER_VALUE | DECIMAL_VALUE) sampleType=PERCENTLIT) - | (expression sampleType=ROWS) - | (sampleType=BUCKET numerator=INTEGER_VALUE OUT OF denominator=INTEGER_VALUE (ON identifier)?)) - ')' - ; - -identifierList - : '(' identifierSeq ')' - ; - -identifierSeq - : identifier (',' identifier)* - ; - -orderedIdentifierList - : '(' orderedIdentifier (',' orderedIdentifier)* ')' - ; - -orderedIdentifier - : identifier ordering=(ASC | DESC)? - ; - -identifierCommentList - : '(' identifierComment (',' identifierComment)* ')' - ; - -identifierComment - : identifier (COMMENT STRING)? - ; - -relationPrimary - : tableIdentifier sample? (AS? identifier)? #tableName - | '(' queryNoWith ')' sample? (AS? identifier)? #aliasedQuery - | '(' relation ')' sample? (AS? identifier)? #aliasedRelation - | inlineTable #inlineTableDefault2 - ; - -inlineTable - : VALUES expression (',' expression)* (AS? identifier identifierList?)? - ; - -rowFormat - : ROW FORMAT SERDE name=STRING (WITH SERDEPROPERTIES props=tablePropertyList)? #rowFormatSerde - | ROW FORMAT DELIMITED - (FIELDS TERMINATED BY fieldsTerminatedBy=STRING (ESCAPED BY escapedBy=STRING)?)? - (COLLECTION ITEMS TERMINATED BY collectionItemsTerminatedBy=STRING)? - (MAP KEYS TERMINATED BY keysTerminatedBy=STRING)? - (LINES TERMINATED BY linesSeparatedBy=STRING)? - (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited - ; - -tableIdentifier - : (db=identifier '.')? table=identifier - ; - -namedExpression - : expression (AS? (identifier | identifierList))? - ; - -namedExpressionSeq - : namedExpression (',' namedExpression)* - ; - -expression - : booleanExpression - ; - -booleanExpression - : predicated #booleanDefault - | NOT booleanExpression #logicalNot - | left=booleanExpression operator=AND right=booleanExpression #logicalBinary - | left=booleanExpression operator=OR right=booleanExpression #logicalBinary - | EXISTS '(' query ')' #exists - ; - -// workaround for: -// https://github.com/antlr/antlr4/issues/780 -// https://github.com/antlr/antlr4/issues/781 -predicated - : valueExpression predicate[$valueExpression.ctx]? - ; - -predicate[ParserRuleContext value] - : NOT? 
BETWEEN lower=valueExpression AND upper=valueExpression #between - | NOT? IN '(' expression (',' expression)* ')' #inList - | NOT? IN '(' query ')' #inSubquery - | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like - | IS NOT? NULL #nullPredicate - ; - -valueExpression - : primaryExpression #valueExpressionDefault - | operator=(MINUS | PLUS | TILDE) valueExpression #arithmeticUnary - | left=valueExpression operator=(ASTERISK | SLASH | PERCENT | DIV) right=valueExpression #arithmeticBinary - | left=valueExpression operator=(PLUS | MINUS) right=valueExpression #arithmeticBinary - | left=valueExpression operator=AMPERSAND right=valueExpression #arithmeticBinary - | left=valueExpression operator=HAT right=valueExpression #arithmeticBinary - | left=valueExpression operator=PIPE right=valueExpression #arithmeticBinary - | left=valueExpression comparisonOperator right=valueExpression #comparison - ; - -primaryExpression - : constant #constantDefault - | ASTERISK #star - | qualifiedName '.' ASTERISK #star - | '(' expression (',' expression)+ ')' #rowConstructor - | qualifiedName '(' (setQuantifier? expression (',' expression)*)? ')' (OVER windowSpec)? #functionCall - | '(' query ')' #subqueryExpression - | CASE valueExpression whenClause+ (ELSE elseExpression=expression)? END #simpleCase - | CASE whenClause+ (ELSE elseExpression=expression)? END #searchedCase - | CAST '(' expression AS dataType ')' #cast - | value=primaryExpression '[' index=valueExpression ']' #subscript - | identifier #columnReference - | base=primaryExpression '.' fieldName=identifier #dereference - | '(' expression ')' #parenthesizedExpression - ; - -constant - : NULL #nullLiteral - | interval #intervalLiteral - | identifier STRING #typeConstructor - | number #numericLiteral - | booleanValue #booleanLiteral - | STRING+ #stringLiteral - ; - -comparisonOperator - : EQ | NEQ | NEQJ | LT | LTE | GT | GTE | NSEQ - ; - -booleanValue - : TRUE | FALSE - ; - -interval - : INTERVAL intervalField* - ; - -intervalField - : value=intervalValue unit=identifier (TO to=identifier)? - ; - -intervalValue - : (PLUS | MINUS)? (INTEGER_VALUE | DECIMAL_VALUE) - | STRING - ; - -dataType - : complex=ARRAY '<' dataType '>' #complexDataType - | complex=MAP '<' dataType ',' dataType '>' #complexDataType - | complex=STRUCT ('<' colTypeList? '>' | NEQ) #complexDataType - | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType - ; - -colTypeList - : colType (',' colType)* - ; - -colType - : identifier ':'? dataType (COMMENT STRING)? - ; - -whenClause - : WHEN condition=expression THEN result=expression - ; - -windows - : WINDOW namedWindow (',' namedWindow)* - ; - -namedWindow - : identifier AS windowSpec - ; - -windowSpec - : name=identifier #windowRef - | '(' - ( CLUSTER BY partition+=expression (',' partition+=expression)* - | ((PARTITION | DISTRIBUTE) BY partition+=expression (',' partition+=expression)*)? - ((ORDER | SORT) BY sortItem (',' sortItem)*)?) - windowFrame? 
- ')' #windowDef - ; - -windowFrame - : frameType=RANGE start=frameBound - | frameType=ROWS start=frameBound - | frameType=RANGE BETWEEN start=frameBound AND end=frameBound - | frameType=ROWS BETWEEN start=frameBound AND end=frameBound - ; - -frameBound - : UNBOUNDED boundType=(PRECEDING | FOLLOWING) - | boundType=CURRENT ROW - | expression boundType=(PRECEDING | FOLLOWING) - ; - - -explainOption - : LOGICAL | FORMATTED | EXTENDED - ; - -transactionMode - : ISOLATION LEVEL SNAPSHOT #isolationLevel - | READ accessMode=(ONLY | WRITE) #transactionAccessMode - ; - -qualifiedName - : identifier ('.' identifier)* - ; - -// Identifier that also allows the use of a number of SQL keywords (mainly for backwards compatibility). -looseIdentifier - : identifier - | FROM - | TO - | TABLE - | WITH - ; - -identifier - : IDENTIFIER #unquotedIdentifier - | quotedIdentifier #quotedIdentifierAlternative - | nonReserved #unquotedIdentifier - ; - -quotedIdentifier - : BACKQUOTED_IDENTIFIER - ; - -number - : DECIMAL_VALUE #decimalLiteral - | SCIENTIFIC_DECIMAL_VALUE #scientificDecimalLiteral - | INTEGER_VALUE #integerLiteral - | BIGINT_LITERAL #bigIntLiteral - | SMALLINT_LITERAL #smallIntLiteral - | TINYINT_LITERAL #tinyIntLiteral - | DOUBLE_LITERAL #doubleLiteral - ; - -nonReserved - : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS - | ADD - | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT - | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER - | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED - | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS - | GROUPING | CUBE | ROLLUP - | EXPLAIN | FORMAT | LOGICAL | FORMATTED - | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF - | SET - | VIEW | REPLACE - | IF - | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL - | SNAPSHOT | READ | WRITE | ONLY - | SORT | CLUSTER | DISTRIBUTE UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION - | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST - | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT - | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE - | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE - | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEXES | LOCKS | OPTION - ; - -SELECT: 'SELECT'; -FROM: 'FROM'; -ADD: 'ADD'; -AS: 'AS'; -ALL: 'ALL'; -DISTINCT: 'DISTINCT'; -WHERE: 'WHERE'; -GROUP: 'GROUP'; -BY: 'BY'; -GROUPING: 'GROUPING'; -SETS: 'SETS'; -CUBE: 'CUBE'; -ROLLUP: 'ROLLUP'; -ORDER: 'ORDER'; -HAVING: 'HAVING'; -LIMIT: 'LIMIT'; -AT: 'AT'; -OR: 'OR'; -AND: 'AND'; -IN: 'IN'; -NOT: 'NOT' | '!'; -NO: 'NO'; -EXISTS: 'EXISTS'; -BETWEEN: 'BETWEEN'; -LIKE: 'LIKE'; -RLIKE: 'RLIKE' | 'REGEXP'; -IS: 'IS'; -NULL: 'NULL'; -TRUE: 'TRUE'; -FALSE: 'FALSE'; -NULLS: 'NULLS'; -ASC: 'ASC'; -DESC: 'DESC'; -FOR: 'FOR'; -INTERVAL: 'INTERVAL'; -CASE: 'CASE'; -WHEN: 'WHEN'; -THEN: 'THEN'; -ELSE: 'ELSE'; -END: 'END'; -JOIN: 'JOIN'; -CROSS: 'CROSS'; -OUTER: 'OUTER'; -INNER: 'INNER'; -LEFT: 'LEFT'; -SEMI: 'SEMI'; -RIGHT: 'RIGHT'; -FULL: 'FULL'; -NATURAL: 'NATURAL'; -ON: 'ON'; -LATERAL: 'LATERAL'; -WINDOW: 'WINDOW'; -OVER: 'OVER'; -PARTITION: 'PARTITION'; -RANGE: 'RANGE'; -ROWS: 'ROWS'; 
-UNBOUNDED: 'UNBOUNDED'; -PRECEDING: 'PRECEDING'; -FOLLOWING: 'FOLLOWING'; -CURRENT: 'CURRENT'; -ROW: 'ROW'; -WITH: 'WITH'; -VALUES: 'VALUES'; -CREATE: 'CREATE'; -TABLE: 'TABLE'; -VIEW: 'VIEW'; -REPLACE: 'REPLACE'; -INSERT: 'INSERT'; -DELETE: 'DELETE'; -INTO: 'INTO'; -DESCRIBE: 'DESCRIBE'; -EXPLAIN: 'EXPLAIN'; -FORMAT: 'FORMAT'; -LOGICAL: 'LOGICAL'; -CAST: 'CAST'; -SHOW: 'SHOW'; -TABLES: 'TABLES'; -COLUMNS: 'COLUMNS'; -COLUMN: 'COLUMN'; -USE: 'USE'; -PARTITIONS: 'PARTITIONS'; -FUNCTIONS: 'FUNCTIONS'; -DROP: 'DROP'; -UNION: 'UNION'; -EXCEPT: 'EXCEPT'; -INTERSECT: 'INTERSECT'; -TO: 'TO'; -TABLESAMPLE: 'TABLESAMPLE'; -STRATIFY: 'STRATIFY'; -ALTER: 'ALTER'; -RENAME: 'RENAME'; -ARRAY: 'ARRAY'; -MAP: 'MAP'; -STRUCT: 'STRUCT'; -COMMENT: 'COMMENT'; -SET: 'SET'; -DATA: 'DATA'; -START: 'START'; -TRANSACTION: 'TRANSACTION'; -COMMIT: 'COMMIT'; -ROLLBACK: 'ROLLBACK'; -WORK: 'WORK'; -ISOLATION: 'ISOLATION'; -LEVEL: 'LEVEL'; -SNAPSHOT: 'SNAPSHOT'; -READ: 'READ'; -WRITE: 'WRITE'; -ONLY: 'ONLY'; - -IF: 'IF'; - -EQ : '=' | '=='; -NSEQ: '<=>'; -NEQ : '<>'; -NEQJ: '!='; -LT : '<'; -LTE : '<='; -GT : '>'; -GTE : '>='; - -PLUS: '+'; -MINUS: '-'; -ASTERISK: '*'; -SLASH: '/'; -PERCENT: '%'; -DIV: 'DIV'; -TILDE: '~'; -AMPERSAND: '&'; -PIPE: '|'; -HAT: '^'; - -PERCENTLIT: 'PERCENT'; -BUCKET: 'BUCKET'; -OUT: 'OUT'; -OF: 'OF'; - -SORT: 'SORT'; -CLUSTER: 'CLUSTER'; -DISTRIBUTE: 'DISTRIBUTE'; -OVERWRITE: 'OVERWRITE'; -TRANSFORM: 'TRANSFORM'; -REDUCE: 'REDUCE'; -USING: 'USING'; -SERDE: 'SERDE'; -SERDEPROPERTIES: 'SERDEPROPERTIES'; -RECORDREADER: 'RECORDREADER'; -RECORDWRITER: 'RECORDWRITER'; -DELIMITED: 'DELIMITED'; -FIELDS: 'FIELDS'; -TERMINATED: 'TERMINATED'; -COLLECTION: 'COLLECTION'; -ITEMS: 'ITEMS'; -KEYS: 'KEYS'; -ESCAPED: 'ESCAPED'; -LINES: 'LINES'; -SEPARATED: 'SEPARATED'; -FUNCTION: 'FUNCTION'; -EXTENDED: 'EXTENDED'; -REFRESH: 'REFRESH'; -CLEAR: 'CLEAR'; -CACHE: 'CACHE'; -UNCACHE: 'UNCACHE'; -LAZY: 'LAZY'; -FORMATTED: 'FORMATTED'; -TEMPORARY: 'TEMPORARY' | 'TEMP'; -OPTIONS: 'OPTIONS'; -UNSET: 'UNSET'; -TBLPROPERTIES: 'TBLPROPERTIES'; -DBPROPERTIES: 'DBPROPERTIES'; -BUCKETS: 'BUCKETS'; -SKEWED: 'SKEWED'; -STORED: 'STORED'; -DIRECTORIES: 'DIRECTORIES'; -LOCATION: 'LOCATION'; -EXCHANGE: 'EXCHANGE'; -ARCHIVE: 'ARCHIVE'; -UNARCHIVE: 'UNARCHIVE'; -FILEFORMAT: 'FILEFORMAT'; -TOUCH: 'TOUCH'; -COMPACT: 'COMPACT'; -CONCATENATE: 'CONCATENATE'; -CHANGE: 'CHANGE'; -FIRST: 'FIRST'; -AFTER: 'AFTER'; -CASCADE: 'CASCADE'; -RESTRICT: 'RESTRICT'; -CLUSTERED: 'CLUSTERED'; -SORTED: 'SORTED'; -PURGE: 'PURGE'; -INPUTFORMAT: 'INPUTFORMAT'; -OUTPUTFORMAT: 'OUTPUTFORMAT'; -INPUTDRIVER: 'INPUTDRIVER'; -OUTPUTDRIVER: 'OUTPUTDRIVER'; -DATABASE: 'DATABASE' | 'SCHEMA'; -DFS: 'DFS'; -TRUNCATE: 'TRUNCATE'; -METADATA: 'METADATA'; -REPLICATION: 'REPLICATION'; -ANALYZE: 'ANALYZE'; -COMPUTE: 'COMPUTE'; -STATISTICS: 'STATISTICS'; -PARTITIONED: 'PARTITIONED'; -EXTERNAL: 'EXTERNAL'; -DEFINED: 'DEFINED'; -REVOKE: 'REVOKE'; -GRANT: 'GRANT'; -LOCK: 'LOCK'; -UNLOCK: 'UNLOCK'; -MSCK: 'MSCK'; -EXPORT: 'EXPORT'; -IMPORT: 'IMPORT'; -LOAD: 'LOAD'; -ROLE: 'ROLE'; -ROLES: 'ROLES'; -COMPACTIONS: 'COMPACTIONS'; -PRINCIPALS: 'PRINCIPALS'; -TRANSACTIONS: 'TRANSACTIONS'; -INDEXES: 'INDEXES'; -LOCKS: 'LOCKS'; -OPTION: 'OPTION'; - -STRING - : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' - | '\"' ( ~('\"'|'\\') | ('\\' .) )* '\"' - ; - -BIGINT_LITERAL - : DIGIT+ 'L' - ; - -SMALLINT_LITERAL - : DIGIT+ 'S' - ; - -TINYINT_LITERAL - : DIGIT+ 'Y' - ; - -INTEGER_VALUE - : DIGIT+ - ; - -DECIMAL_VALUE - : DIGIT+ '.' DIGIT* - | '.' 
DIGIT+ - ; - -SCIENTIFIC_DECIMAL_VALUE - : DIGIT+ ('.' DIGIT*)? EXPONENT - | '.' DIGIT+ EXPONENT - ; - -DOUBLE_LITERAL - : - (INTEGER_VALUE | DECIMAL_VALUE | SCIENTIFIC_DECIMAL_VALUE) 'D' - ; - -IDENTIFIER - : (LETTER | DIGIT | '_')+ - ; - -BACKQUOTED_IDENTIFIER - : '`' ( ~'`' | '``' )* '`' - ; - -fragment EXPONENT - : 'E' [+-]? DIGIT+ - ; - -fragment DIGIT - : [0-9] - ; - -fragment LETTER - : [A-Z] - ; - -SIMPLE_COMMENT - : '--' ~[\r\n]* '\r'? '\n'? -> channel(HIDDEN) - ; - -BRACKETED_COMMENT - : '/*' .*? '*/' -> channel(HIDDEN) - ; - -WS - : [ \r\n\t]+ -> channel(HIDDEN) - ; - -// Catch-all for anything we can't recognize. -// We use this to be able to ignore and recover all the text -// when splitting statements with DelimiterLexer -UNRECOGNIZED - : . - ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala deleted file mode 100644 index 28f7b10ed6..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ASTNode.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser - -import org.antlr.runtime.{Token, TokenRewriteStream} - -import org.apache.spark.sql.catalyst.trees.{Origin, TreeNode} - -case class ASTNode( - token: Token, - startIndex: Int, - stopIndex: Int, - children: List[ASTNode], - stream: TokenRewriteStream) extends TreeNode[ASTNode] { - /** Cache the number of children. */ - val numChildren: Int = children.size - - /** tuple used in pattern matching. */ - val pattern: Some[(String, List[ASTNode])] = Some((token.getText, children)) - - /** Line in which the ASTNode starts. */ - lazy val line: Int = { - val line = token.getLine - if (line == 0) { - if (children.nonEmpty) children.head.line - else 0 - } else { - line - } - } - - /** Position of the Character at which ASTNode starts. */ - lazy val positionInLine: Int = { - val line = token.getCharPositionInLine - if (line == -1) { - if (children.nonEmpty) children.head.positionInLine - else 0 - } else { - line - } - } - - /** Origin of the ASTNode. */ - override val origin: Origin = Origin(Some(line), Some(positionInLine)) - - /** Source text. */ - lazy val source: String = stream.toOriginalString(startIndex, stopIndex) - - /** Get the source text that remains after this token. */ - lazy val remainder: String = { - stream.fill() - stream.toOriginalString(stopIndex + 1, stream.size() - 1).trim() - } - - def text: String = token.getText - - def tokenType: Int = token.getType - - /** - * Checks if this node is equal to another node. - * - * Right now this function only checks the name, type, text and children of the node - * for equality. 
- */ - def treeEquals(other: ASTNode): Boolean = { - def check(f: ASTNode => Any): Boolean = { - val l = f(this) - val r = f(other) - (l == null && r == null) || l.equals(r) - } - if (other == null) { - false - } else if (!check(_.token.getType) - || !check(_.token.getText) - || !check(_.numChildren)) { - false - } else { - children.zip(other.children).forall { - case (l, r) => l treeEquals r - } - } - } - - override def simpleString: String = s"$text $line, $startIndex, $stopIndex, $positionInLine " -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala deleted file mode 100644 index 7b456a6de3..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AbstractSparkSQLParser.scala +++ /dev/null @@ -1,145 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser - -import scala.language.implicitConversions -import scala.util.parsing.combinator.lexical.StdLexical -import scala.util.parsing.combinator.syntactical.StandardTokenParsers -import scala.util.parsing.combinator.PackratParsers -import scala.util.parsing.input.CharArrayReader.EofCh - -import org.apache.spark.sql.catalyst.plans.logical._ - -private[sql] abstract class AbstractSparkSQLParser - extends StandardTokenParsers with PackratParsers with ParserInterface { - - def parsePlan(input: String): LogicalPlan = synchronized { - // Initialize the Keywords. - initLexical - phrase(start)(new lexical.Scanner(input)) match { - case Success(plan, _) => plan - case failureOrError => sys.error(failureOrError.toString) - } - } - /* One time initialization of lexical.This avoid reinitialization of lexical in parse method */ - protected lazy val initLexical: Unit = lexical.initialize(reservedWords) - - protected case class Keyword(str: String) { - def normalize: String = lexical.normalizeKeyword(str) - def parser: Parser[String] = normalize - } - - protected implicit def asParser(k: Keyword): Parser[String] = k.parser - - // By default, use Reflection to find the reserved words defined in the sub class. - // NOTICE, Since the Keyword properties defined by sub class, we couldn't call this - // method during the parent class instantiation, because the sub class instance - // isn't created yet. - protected lazy val reservedWords: Seq[String] = - this - .getClass - .getMethods - .filter(_.getReturnType == classOf[Keyword]) - .map(_.invoke(this).asInstanceOf[Keyword].normalize) - - // Set the keywords as empty by default, will change that later. 
- override val lexical = new SqlLexical - - protected def start: Parser[LogicalPlan] - - // Returns the whole input string - protected lazy val wholeInput: Parser[String] = new Parser[String] { - def apply(in: Input): ParseResult[String] = - Success(in.source.toString, in.drop(in.source.length())) - } - - // Returns the rest of the input string that are not parsed yet - protected lazy val restInput: Parser[String] = new Parser[String] { - def apply(in: Input): ParseResult[String] = - Success( - in.source.subSequence(in.offset, in.source.length()).toString, - in.drop(in.source.length())) - } -} - -class SqlLexical extends StdLexical { - case class DecimalLit(chars: String) extends Token { - override def toString: String = chars - } - - /* This is a work around to support the lazy setting */ - def initialize(keywords: Seq[String]): Unit = { - reserved.clear() - reserved ++= keywords - } - - /* Normal the keyword string */ - def normalizeKeyword(str: String): String = str.toLowerCase - - delimiters += ( - "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", - ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" - ) - - protected override def processIdent(name: String) = { - val token = normalizeKeyword(name) - if (reserved contains token) Keyword(token) else Identifier(name) - } - - override lazy val token: Parser[Token] = - ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) } - | '.' ~> (rep1(digit) ~ scientificNotation) ^^ - { case i ~ s => DecimalLit("0." + i.mkString + s) } - | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^ - { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." + i2.mkString + s) } - | digit.* ~ identChar ~ (identChar | digit).* ^^ - { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } - | rep1(digit) ~ ('.' ~> digit.*).? ^^ { - case i ~ None => NumericLit(i.mkString) - case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString) - } - | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ - { case chars => StringLit(chars mkString "") } - | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ - { case chars => StringLit(chars mkString "") } - | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^ - { case chars => Identifier(chars mkString "") } - | EofCh ^^^ EOF - | '\'' ~> failure("unclosed string literal") - | '"' ~> failure("unclosed string literal") - | delim - | failure("illegal character") - ) - - override def identChar: Parser[Elem] = letter | elem('_') - - private lazy val scientificNotation: Parser[String] = - (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ { - case s ~ rest => "e" + s.mkString + rest.mkString - } - - override def whitespace: Parser[Any] = - ( whitespaceChar - | '/' ~ '*' ~ comment - | '/' ~ '/' ~ chrExcept(EofCh, '\n').* - | '#' ~ chrExcept(EofCh, '\n').* - | '-' ~ '-' ~ chrExcept(EofCh, '\n').* - | '/' ~ '*' ~ failure("unclosed comment") - ).* -} - diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala new file mode 100644 index 0000000000..c350f3049f --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -0,0 +1,1460 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import java.sql.{Date, Timestamp} + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.antlr.v4.runtime.{ParserRuleContext, Token} +import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval +import org.apache.spark.util.random.RandomSampler + +/** + * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or + * TableIdentifier. + */ +class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { + import ParserUtils._ + + protected def typedVisit[T](ctx: ParseTree): T = { + ctx.accept(this).asInstanceOf[T] + } + + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { + visit(ctx.statement).asInstanceOf[LogicalPlan] + } + + override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { + visitNamedExpression(ctx.namedExpression) + } + + override def visitSingleTableIdentifier( + ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { + visitTableIdentifier(ctx.tableIdentifier) + } + + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { + visit(ctx.dataType).asInstanceOf[DataType] + } + + /* ******************************************************************************************** + * Plan parsing + * ******************************************************************************************** */ + protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) + + /** + * Make sure we do not try to create a plan for a native command. + */ + override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null + + /** + * Create a plan for a SHOW FUNCTIONS command. + */ + override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { + import ctx._ + if (qualifiedName != null) { + val names = qualifiedName().identifier().asScala.map(_.getText).toList + names match { + case db :: name :: Nil => + ShowFunctions(Some(db), Some(name)) + case name :: Nil => + ShowFunctions(None, Some(name)) + case _ => + throw new ParseException("SHOW FUNCTIONS unsupported name", ctx) + } + } else if (pattern != null) { + ShowFunctions(None, Some(string(pattern))) + } else { + ShowFunctions(None, None) + } + } + + /** + * Create a plan for a DESCRIBE FUNCTION command. 
+ */ + override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { + val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".") + DescribeFunction(functionName, ctx.EXTENDED != null) + } + + /** + * Create a top-level plan with Common Table Expressions. + */ + override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { + val query = plan(ctx.queryNoWith) + + // Apply CTEs + query.optional(ctx.ctes) { + val ctes = ctx.ctes.namedQuery.asScala.map { + case nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + + // Check for duplicate names. + ctes.groupBy(_._1).filter(_._2.size > 1).foreach { + case (name, _) => + throw new ParseException( + s"Name '$name' is used for multiple common table expressions", ctx) + } + + With(query, ctes.toMap) + } + } + + /** + * Create a named logical plan. + * + * This is only used for Common Table Expressions. + */ + override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { + SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith)) + } + + /** + * Create a logical plan which allows for multiple inserts using one 'from' statement. These + * queries have the following SQL form: + * {{{ + * [WITH cte...]? + * FROM src + * [INSERT INTO tbl1 SELECT *]+ + * }}} + * For example: + * {{{ + * FROM db.tbl1 A + * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 + * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 + * }}} + * This (Hive) feature cannot be combined with set-operators. + */ + override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val inserts = ctx.multiInsertQueryBody.asScala.map { + body => + assert(body.querySpecification.fromClause == null, + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", + body) + + withQuerySpecification(body.querySpecification, from). + // Add organization statements. + optionalMap(body.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(body.insertInto())(withInsertInto) + } + + // If there are multiple INSERTS just UNION them together into one query. + inserts match { + case Seq(query) => query + case queries => Union(queries) + } + } + + /** + * Create a logical plan for a regular (single-insert) query. + */ + override def visitSingleInsertQuery( + ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm). + // Add organization statements. + optionalMap(ctx.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(ctx.insertInto())(withInsertInto) + } + + /** + * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan. + */ + private def withInsertInto( + ctx: InsertIntoContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + val tableIdent = visitTableIdentifier(ctx.tableIdentifier) + val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) + + InsertIntoTable( + UnresolvedRelation(tableIdent, None), + partitionKeys, + query, + ctx.OVERWRITE != null, + ctx.EXISTS != null) + } + + /** + * Create a partition specification map. 
+ */ + override def visitPartitionSpec( + ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { + ctx.partitionVal.asScala.map { pVal => + val name = pVal.identifier.getText.toLowerCase + val value = Option(pVal.constant).map(visitStringConstant) + name -> value + }.toMap + } + + /** + * Create a partition specification map without optional values. + */ + protected def visitNonOptionalPartitionSpec( + ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { + visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) + } + + /** + * Convert a constant of any type into a string. This is typically used in DDL commands, and its + * main purpose is to prevent slight differences due to back to back conversions i.e.: + * String -> Literal -> String. + */ + protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) { + ctx match { + case s: StringLiteralContext => createString(s) + case o => o.getText + } + } + + /** + * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These + * clauses determine the shape (ordering/partitioning/rows) of the query result. + */ + private def withQueryResultClauses( + ctx: QueryOrganizationContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. + val withOrder = if ( + !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // ORDER BY ... + Sort(order.asScala.map(visitSortItem), global = true, query) + } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... + Sort(sort.asScala.map(visitSortItem), global = false, query) + } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // DISTRIBUTE BY ... + RepartitionByExpression(expressionList(distributeBy), query) + } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { + // SORT BY ... DISTRIBUTE BY ... + Sort( + sort.asScala.map(visitSortItem), + global = false, + RepartitionByExpression(expressionList(distributeBy), query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { + // CLUSTER BY ... + val expressions = expressionList(clusterBy) + Sort( + expressions.map(SortOrder(_, Ascending)), + global = false, + RepartitionByExpression(expressions, query)) + } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { + // [EMPTY] + query + } else { + throw new ParseException( + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) + } + + // WINDOWS + val withWindow = withOrder.optionalMap(windows)(withWindows) + + // LIMIT + withWindow.optional(limit) { + Limit(typedVisit(limit), withWindow) + } + } + + /** + * Create a logical plan using a query specification. + */ + override def visitQuerySpecification( + ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) { + val from = OneRowRelation.optional(ctx.fromClause) { + visitFromClause(ctx.fromClause) + } + withQuerySpecification(ctx, from) + } + + /** + * Add a query specification to a logical plan. The query specification is the core of the logical + * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE), + * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. + * + * Note that query hints are ignored (both by the parser and the builder). 
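The CLUSTER BY branch in withQueryResultClauses above is just DISTRIBUTE BY plus an ascending SORT BY on the same expressions. A rough sketch of the plan shape it builds, using a placeholder child plan and a hypothetical column K:

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, RepartitionByExpression, Sort}

// Sketch of the plan shape for: SELECT * FROM T CLUSTER BY K
// (illustration only; the real code works on parsed contexts, not literals).
val expressions = Seq(UnresolvedAttribute("K"))
val clusteredByK =
  Sort(
    expressions.map(SortOrder(_, Ascending)),              // the implicit SORT BY K
    global = false,
    RepartitionByExpression(expressions, OneRowRelation))  // the implicit DISTRIBUTE BY K
```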
+ */ + private def withQuerySpecification( + ctx: QuerySpecificationContext, + relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + + // WHERE + def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { + Filter(expression(ctx), plan) + } + + // Expressions. + val expressions = Option(namedExpressionSeq).toSeq + .flatMap(_.namedExpression.asScala) + .map(typedVisit[Expression]) + + // Create either a transform or a regular query. + val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT) + specType match { + case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM => + // Transform + + // Add where. + val withFilter = relation.optionalMap(where)(filter) + + // Create the attributes. + val (attributes, schemaLess) = if (colTypeList != null) { + // Typed return columns. + (createStructType(colTypeList).toAttributes, false) + } else if (identifierSeq != null) { + // Untyped return columns. + val attrs = visitIdentifierSeq(identifierSeq).map { name => + AttributeReference(name, StringType, nullable = true)() + } + (attrs, false) + } else { + (Seq(AttributeReference("key", StringType)(), + AttributeReference("value", StringType)()), true) + } + + // Create the transform. + ScriptTransformation( + expressions, + string(script), + attributes, + withFilter, + withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + + case SqlBaseParser.SELECT => + // Regular select + + // Add lateral views. + val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate) + + // Add where. + val withFilter = withLateralView.optionalMap(where)(filter) + + // Add aggregation or a project. + val namedExpressions = expressions.map { + case e: NamedExpression => e + case e: Expression => UnresolvedAlias(e) + } + val withProject = if (aggregation != null) { + withAggregation(aggregation, namedExpressions, withFilter) + } else if (namedExpressions.nonEmpty) { + Project(namedExpressions, withFilter) + } else { + withFilter + } + + // Having + val withHaving = withProject.optional(having) { + // Note that we added a cast to boolean. If the expression itself is already boolean, + // the optimizer will get rid of the unnecessary cast. + Filter(Cast(expression(having), BooleanType), withProject) + } + + // Distinct + val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { + Distinct(withHaving) + } else { + withHaving + } + + // Window + withDistinct.optionalMap(windows)(withWindows) + } + } + + /** + * Create a (Hive based) [[ScriptInputOutputSchema]]. + */ + protected def withScriptIOSchema( + inRowFormat: RowFormatContext, + recordWriter: Token, + outRowFormat: RowFormatContext, + recordReader: Token, + schemaLess: Boolean): ScriptInputOutputSchema = null + + /** + * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma + * separated) relations here, these get converted into a single plan by condition-less inner join. + */ + override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { + val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) + ctx.lateralView.asScala.foldLeft(from)(withGenerate) + } + + /** + * Connect two queries by a Set operator. 
+ * + * Supported Set operators are: + * - UNION [DISTINCT] + * - UNION ALL + * - EXCEPT [DISTINCT] + * - INTERSECT [DISTINCT] + */ + override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { + val left = plan(ctx.left) + val right = plan(ctx.right) + val all = Option(ctx.setQuantifier()).exists(_.ALL != null) + ctx.operator.getType match { + case SqlBaseParser.UNION if all => + Union(left, right) + case SqlBaseParser.UNION => + Distinct(Union(left, right)) + case SqlBaseParser.INTERSECT if all => + throw new ParseException("INTERSECT ALL is not supported.", ctx) + case SqlBaseParser.INTERSECT => + Intersect(left, right) + case SqlBaseParser.EXCEPT if all => + throw new ParseException("EXCEPT ALL is not supported.", ctx) + case SqlBaseParser.EXCEPT => + Except(left, right) + } + } + + /** + * Add a [[WithWindowDefinition]] operator to a logical plan. + */ + private def withWindows( + ctx: WindowsContext, + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Collect all window specifications defined in the WINDOW clause. + val baseWindowMap = ctx.namedWindow.asScala.map { + wCtx => + (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec)) + }.toMap + + // Handle cases like + // window w1 as (partition by p_mfgr order by p_name + // range between 2 preceding and 2 following), + // w2 as w1 + val windowMapView = baseWindowMap.mapValues { + case WindowSpecReference(name) => + baseWindowMap.get(name) match { + case Some(spec: WindowSpecDefinition) => + spec + case Some(ref) => + throw new ParseException(s"Window reference '$name' is not a window specification", ctx) + case None => + throw new ParseException(s"Cannot resolve window reference '$name'", ctx) + } + case spec: WindowSpecDefinition => spec + } + + // Note that mapValues creates a view instead of materialized map. We force materialization by + // mapping over identity. + WithWindowDefinition(windowMapView.map(identity), query) + } + + /** + * Add an [[Aggregate]] to a logical plan. + */ + private def withAggregation( + ctx: AggregationContext, + selectExpressions: Seq[NamedExpression], + query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + import ctx._ + val groupByExpressions = expressionList(groupingExpressions) + + if (GROUPING != null) { + // GROUP BY .... GROUPING SETS (...) + val expressionMap = groupByExpressions.zipWithIndex.toMap + val numExpressions = expressionMap.size + val mask = (1 << numExpressions) - 1 + val masks = ctx.groupingSet.asScala.map { + _.expression.asScala.foldLeft(mask) { + case (bitmap, eCtx) => + // Find the index of the expression. + val e = typedVisit[Expression](eCtx) + val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse( + throw new ParseException( + s"$e doesn't show up in the GROUP BY list", ctx)) + // 0 means that the column at the given index is a grouping column, 1 means it is not, + // so we unset the bit in bitmap. + bitmap & ~(1 << (numExpressions - 1 - index)) + } + } + GroupingSets(masks, groupByExpressions, query, selectExpressions) + } else { + // GROUP BY .... (WITH CUBE | WITH ROLLUP)? + val mappedGroupByExpressions = if (CUBE != null) { + Seq(Cube(groupByExpressions)) + } else if (ROLLUP != null) { + Seq(Rollup(groupByExpressions)) + } else { + groupByExpressions + } + Aggregate(mappedGroupByExpressions, selectExpressions, query) + } + } + + /** + * Add a [[Generate]] (Lateral View) to a logical plan. 
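[Editor's note] The GROUPING SETS branch of `withAggregation` above computes one bitmask per grouping set by starting from an all-ones mask and clearing the bit of every GROUP BY expression that appears in the set. A self-contained sketch of just that computation, using plain strings instead of Catalyst expressions (`groupingSetMasks` is a hypothetical helper):

```scala
// Bit (n - 1 - i) corresponds to the i-th GROUP BY expression; a cleared bit means
// the column is a grouping column in that grouping set.
def groupingSetMasks(groupBy: Seq[String], sets: Seq[Seq[String]]): Seq[Int] = {
  val index = groupBy.zipWithIndex.toMap
  val n = groupBy.size
  val allBitsSet = (1 << n) - 1
  sets.map(_.foldLeft(allBitsSet) { (bitmap, e) =>
    val i = index.getOrElse(e, sys.error(s"$e doesn't show up in the GROUP BY list"))
    bitmap & ~(1 << (n - 1 - i))
  })
}

// GROUP BY k1, k2, k3 GROUPING SETS ((k1, k2), (k2))  =>  Seq(1, 5),
// the same grouping ids used as an example in the old parser removed further down.
// groupingSetMasks(Seq("k1", "k2", "k3"), Seq(Seq("k1", "k2"), Seq("k2")))
```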
+ */ + private def withGenerate( + query: LogicalPlan, + ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { + val expressions = expressionList(ctx.expression) + + // Create the generator. + val generator = ctx.qualifiedName.getText.toLowerCase match { + case "explode" if expressions.size == 1 => + Explode(expressions.head) + case "json_tuple" => + JsonTuple(expressions) + case other => + withGenerator(other, expressions, ctx) + } + + Generate( + generator, + join = true, + outer = ctx.OUTER != null, + Some(ctx.tblName.getText.toLowerCase), + ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), + query) + } + + /** + * Create a [[Generator]]. Override this method in order to support custom Generators. + */ + protected def withGenerator( + name: String, + expressions: Seq[Expression], + ctx: LateralViewContext): Generator = { + throw new ParseException(s"Generator function '$name' is not supported", ctx) + } + + /** + * Create a joins between two or more logical plans. + */ + override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { + /** Build a join between two plans. */ + def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { + val baseJoinType = ctx.joinType match { + case null => Inner + case jt if jt.FULL != null => FullOuter + case jt if jt.SEMI != null => LeftSemi + case jt if jt.LEFT != null => LeftOuter + case jt if jt.RIGHT != null => RightOuter + case _ => Inner + } + + // Resolve the join type and join condition + val (joinType, condition) = Option(ctx.joinCriteria) match { + case Some(c) if c.USING != null => + val columns = c.identifier.asScala.map { column => + UnresolvedAttribute.quoted(column.getText) + } + (UsingJoin(baseJoinType, columns), None) + case Some(c) if c.booleanExpression != null => + (baseJoinType, Option(expression(c.booleanExpression))) + case None if ctx.NATURAL != null => + (NaturalJoin(baseJoinType), None) + case None => + (baseJoinType, None) + } + Join(left, right, joinType, condition) + } + + // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the + // first join clause is at the top. However fields of previously referenced tables can be used + // in following join clauses. The tree needs to be reversed in order to make this work. + var result = plan(ctx.left) + var current = ctx + while (current != null) { + current.right match { + case right: JoinRelationContext => + result = join(current, result, plan(right.left)) + current = right + case right => + result = join(current, result, plan(right)) + current = null + } + } + result + } + + /** + * Add a [[Sample]] to a logical plan. + * + * This currently supports the following sampling methods: + * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. + * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages + * are defined as a number between 0 and 100. + * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. + */ + private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { + // Create a sampled plan if we need one. + def sample(fraction: Double): Sample = { + // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling + // function takes X PERCENT as the input and the range of X is [0, 100], we need to + // adjust the fraction. 
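[Editor's note] `visitJoinRelation` above rebuilds ANTLR's right-nested join tree into a left-deep plan so that columns of earlier tables remain visible to later join conditions. A stand-alone sketch of that reversal with simplified context and plan types (`Ctx`, `RelCtx`, `JoinCtx`, `toLeftDeep` are illustrative, and join types and conditions are ignored here):

```scala
// Stand-ins for the parse contexts: `right` either nests the next join clause or
// is the final relation of the chain.
sealed trait Ctx
case class RelCtx(name: String) extends Ctx
case class JoinCtx(left: RelCtx, right: Ctx) extends Ctx

sealed trait Rel
case class Table(name: String) extends Rel
case class Join(left: Rel, right: Rel) extends Rel

// Walk down the right spine, joining each newly seen relation onto the result so far.
def toLeftDeep(top: JoinCtx): Rel = {
  var result: Rel = Table(top.left.name)
  var current: Ctx = top.right
  var done = false
  while (!done) {
    current match {
      case JoinCtx(l, r) => result = Join(result, Table(l.name)); current = r
      case RelCtx(n)     => result = Join(result, Table(n)); done = true
    }
  }
  result
}

// "a JOIN b JOIN c" parses as JoinCtx(a, JoinCtx(b, c)) and flattens to ((a JOIN b) JOIN c):
// toLeftDeep(JoinCtx(RelCtx("a"), JoinCtx(RelCtx("b"), RelCtx("c"))))
//   == Join(Join(Table("a"), Table("b")), Table("c"))
```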
+ val eps = RandomSampler.roundingEpsilon + assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, + s"Sampling fraction ($fraction) must be on interval [0, 1]", + ctx) + Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) + } + + ctx.sampleType.getType match { + case SqlBaseParser.ROWS => + Limit(expression(ctx.expression), query) + + case SqlBaseParser.PERCENTLIT => + val fraction = ctx.percentage.getText.toDouble + sample(fraction / 100.0d) + + case SqlBaseParser.BUCKET if ctx.ON != null => + throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx) + + case SqlBaseParser.BUCKET => + sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) + } + } + + /** + * Create a logical plan for a sub-query. + */ + override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith) + } + + /** + * Create an un-aliased table reference. This is typically used for top-level table references, + * for example: + * {{{ + * INSERT INTO db.tbl2 + * TABLE db.tbl1 + * }}} + */ + override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { + UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) + } + + /** + * Create an aliased table reference. This is typically used in FROM clauses. + */ + override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { + val table = UnresolvedRelation( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.identifier).map(_.getText)) + table.optionalMap(ctx.sample)(withSample) + } + + /** + * Create an inline table (a virtual table in Hive parlance). + */ + override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { + // Get the backing expressions. + val expressions = ctx.expression.asScala.map { eCtx => + val e = expression(eCtx) + assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) + e + } + + // Validate and evaluate the rows. + val (structType, structConstructor) = expressions.head.dataType match { + case st: StructType => + (st, (e: Expression) => e) + case dt => + val st = CreateStruct(Seq(expressions.head)).dataType + (st, (e: Expression) => CreateStruct(Seq(e))) + } + val rows = expressions.map { + case expression => + val safe = Cast(structConstructor(expression), structType) + safe.eval().asInstanceOf[InternalRow] + } + + // Construct attributes. + val baseAttributes = structType.toAttributes.map(_.withNullability(true)) + val attributes = if (ctx.identifierList != null) { + val aliases = visitIdentifierList(ctx.identifierList) + assert(aliases.size == baseAttributes.size, + "Number of aliases must match the number of fields in an inline table.", ctx) + baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) + } else { + baseAttributes + } + + // Create plan and add an alias if a name has been defined. + LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a join relation. This is practically the same as + * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a sub-query. 
This is practically the same as + * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different + * hooks. + */ + override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) + } + + /** + * Create an alias (SubqueryAlias) for a LogicalPlan. + */ + private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = { + SubqueryAlias(alias.getText, plan) + } + + /** + * Create a Sequence of Strings for a parenthesis enclosed alias list. + */ + override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { + visitIdentifierSeq(ctx.identifierSeq) + } + + /** + * Create a Sequence of Strings for an identifier list. + */ + override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText) + } + + /* ******************************************************************************************** + * Table Identifier parsing + * ******************************************************************************************** */ + /** + * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. + */ + override def visitTableIdentifier( + ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { + TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) + } + + /* ******************************************************************************************** + * Expression parsing + * ******************************************************************************************** */ + /** + * Create an expression from the given context. This method just passes the context on to the + * vistor and only takes care of typing (We assume that the visitor returns an Expression here). + */ + protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) + + /** + * Create sequence of expressions from the given sequence of contexts. + */ + private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { + trees.asScala.map(expression) + } + + /** + * Invert a boolean expression if it has a valid NOT clause. + */ + private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { + if (not != null) { + Not(expression) + } else { + expression + } + } + + /** + * Create a star (i.e. all) expression; this selects all elements (in the specified object). + * Both un-targeted (global) and targeted aliases are supported. + */ + override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { + UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) + } + + /** + * Create an aliased expression if an alias is specified. Both single and multi-aliases are + * supported. + */ + override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { + val e = expression(ctx.expression) + if (ctx.identifier != null) { + Alias(e, ctx.identifier.getText)() + } else if (ctx.identifierList != null) { + MultiAlias(e, visitIdentifierList(ctx.identifierList)) + } else { + e + } + } + + /** + * Combine a number of boolean expressions into a balanced expression tree. These expressions are + * either combined by a logical [[And]] or a logical [[Or]]. + * + * A balanced binary tree is created because regular left recursive trees cause considerable + * performance degradations and can cause stack overflows. 
+ */ + override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { + val expressionType = ctx.operator.getType + val expressionCombiner = expressionType match { + case SqlBaseParser.AND => And.apply _ + case SqlBaseParser.OR => Or.apply _ + } + + // Collect all similar left hand contexts. + val contexts = ArrayBuffer(ctx.right) + var current = ctx.left + def collectContexts: Boolean = current match { + case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => + contexts += lbc.right + current = lbc.left + true + case _ => + contexts += current + false + } + while (collectContexts) { + // No body - all updates take place in the collectContexts. + } + + // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them + // into expressions. + val expressions = contexts.reverse.map(expression) + + // Create a balanced tree. + def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { + case 0 => + expressions(low) + case 1 => + expressionCombiner(expressions(low), expressions(high)) + case x => + val mid = low + x / 2 + expressionCombiner( + reduceToExpressionTree(low, mid), + reduceToExpressionTree(mid + 1, high)) + } + reduceToExpressionTree(0, expressions.size - 1) + } + + /** + * Invert a boolean expression. + */ + override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { + Not(expression(ctx.booleanExpression())) + } + + /** + * Create a filtering correlated sub-query. This is not supported yet. + */ + override def visitExists(ctx: ExistsContext): Expression = { + throw new ParseException("EXISTS clauses are not supported.", ctx) + } + + /** + * Create a comparison expression. This compares two expressions. The following comparison + * operators are supported: + * - Equal: '=' or '==' + * - Null-safe Equal: '<=>' + * - Not Equal: '<>' or '!=' + * - Less than: '<' + * - Less then or Equal: '<=' + * - Greater than: '>' + * - Greater then or Equal: '>=' + */ + override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] + operator.getSymbol.getType match { + case SqlBaseParser.EQ => + EqualTo(left, right) + case SqlBaseParser.NSEQ => + EqualNullSafe(left, right) + case SqlBaseParser.NEQ | SqlBaseParser.NEQJ => + Not(EqualTo(left, right)) + case SqlBaseParser.LT => + LessThan(left, right) + case SqlBaseParser.LTE => + LessThanOrEqual(left, right) + case SqlBaseParser.GT => + GreaterThan(left, right) + case SqlBaseParser.GTE => + GreaterThanOrEqual(left, right) + } + } + + /** + * Create a BETWEEN expression. This tests if an expression lies with in the bounds set by two + * other expressions. The inverse can also be created. + */ + override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) { + val value = expression(ctx.value) + val between = And( + GreaterThanOrEqual(value, expression(ctx.lower)), + LessThanOrEqual(value, expression(ctx.upper))) + invertIfNotDefined(between, ctx.NOT) + } + + /** + * Create an IN expression. This tests if the value of the left hand side expression is + * contained by the sequence of expressions on the right hand side. 
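[Editor's note] `visitLogicalBinary` above first flattens a chain of same-operator contexts and then recombines them as a balanced tree, so a long `a AND b AND c AND ...` does not become a deeply nested plan that risks stack overflows. A small sketch of the divide-and-conquer step on plain values (`balanced` is a hypothetical helper; `combine` plays the role of `And.apply` / `Or.apply`):

```scala
// Combine a flat operand list into a balanced binary tree.
def balanced[A](operands: IndexedSeq[A])(combine: (A, A) => A): A = {
  def build(low: Int, high: Int): A = high - low match {
    case 0 => operands(low)
    case 1 => combine(operands(low), operands(high))
    case x =>
      val mid = low + x / 2
      combine(build(low, mid), build(mid + 1, high))
  }
  build(0, operands.size - 1)
}

// balanced(Vector("a", "b", "c", "d"))((l, r) => s"($l AND $r)")
//   == "((a AND b) AND (c AND d))"   // depth 2 instead of 3 for a left-deep chain
```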
+ */ + override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) { + val in = In(expression(ctx.value), ctx.expression().asScala.map(expression)) + invertIfNotDefined(in, ctx.NOT) + } + + /** + * Create an IN expression, where the the right hand side is a query. This is unsupported. + */ + override def visitInSubquery(ctx: InSubqueryContext): Expression = { + throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + } + + /** + * Create a (R)LIKE/REGEXP expression. + */ + override def visitLike(ctx: LikeContext): Expression = { + val left = expression(ctx.value) + val right = expression(ctx.pattern) + val like = ctx.like.getType match { + case SqlBaseParser.LIKE => + Like(left, right) + case SqlBaseParser.RLIKE => + RLike(left, right) + } + invertIfNotDefined(like, ctx.NOT) + } + + /** + * Create an IS (NOT) NULL expression. + */ + override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) { + val value = expression(ctx.value) + if (ctx.NOT != null) { + IsNotNull(value) + } else { + IsNull(value) + } + } + + /** + * Create a binary arithmetic expression. The following arithmetic operators are supported: + * - Mulitplication: '*' + * - Division: '/' + * - Hive Long Division: 'DIV' + * - Modulo: '%' + * - Addition: '+' + * - Subtraction: '-' + * - Binary AND: '&' + * - Binary XOR + * - Binary OR: '|' + */ + override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { + val left = expression(ctx.left) + val right = expression(ctx.right) + ctx.operator.getType match { + case SqlBaseParser.ASTERISK => + Multiply(left, right) + case SqlBaseParser.SLASH => + Divide(left, right) + case SqlBaseParser.PERCENT => + Remainder(left, right) + case SqlBaseParser.DIV => + Cast(Divide(left, right), LongType) + case SqlBaseParser.PLUS => + Add(left, right) + case SqlBaseParser.MINUS => + Subtract(left, right) + case SqlBaseParser.AMPERSAND => + BitwiseAnd(left, right) + case SqlBaseParser.HAT => + BitwiseXor(left, right) + case SqlBaseParser.PIPE => + BitwiseOr(left, right) + } + } + + /** + * Create a unary arithmetic expression. The following arithmetic operators are supported: + * - Plus: '+' + * - Minus: '-' + * - Bitwise Not: '~' + */ + override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { + val value = expression(ctx.valueExpression) + ctx.operator.getType match { + case SqlBaseParser.PLUS => + value + case SqlBaseParser.MINUS => + UnaryMinus(value) + case SqlBaseParser.TILDE => + BitwiseNot(value) + } + } + + /** + * Create a [[Cast]] expression. + */ + override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { + Cast(expression(ctx.expression), typedVisit(ctx.dataType)) + } + + /** + * Create a (windowed) Function expression. + */ + override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { + // Create the function call. + val name = ctx.qualifiedName.getText + val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) + val arguments = ctx.expression().asScala.map(expression) match { + case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => + // Transform COUNT(*) into COUNT(1). Move this to analysis? + Seq(Literal(1)) + case expressions => + expressions + } + val function = UnresolvedFunction(name, arguments, isDistinct) + + // Check if the function is evaluated in a windowed context. 
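[Editor's note] `visitFunctionCall` above normalizes a non-distinct `COUNT(*)` into `COUNT(1)` before building the `UnresolvedFunction`. A tiny sketch of that argument rewrite, with plain strings standing in for the parsed expressions (`normalizeArguments` is a hypothetical helper):

```scala
// "*" stands in for UnresolvedStar(None); "1" for Literal(1).
def normalizeArguments(name: String, isDistinct: Boolean, args: Seq[String]): Seq[String] =
  args match {
    case Seq("*") if name.toLowerCase == "count" && !isDistinct => Seq("1") // COUNT(*) => COUNT(1)
    case other                                                  => other
  }

// normalizeArguments("COUNT", isDistinct = false, Seq("*")) == Seq("1")
// normalizeArguments("COUNT", isDistinct = true,  Seq("*")) == Seq("*")  // COUNT(DISTINCT *) untouched
```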
+ ctx.windowSpec match { + case spec: WindowRefContext => + UnresolvedWindowExpression(function, visitWindowRef(spec)) + case spec: WindowDefContext => + WindowExpression(function, visitWindowDef(spec)) + case _ => function + } + } + + /** + * Create a reference to a window frame, i.e. [[WindowSpecReference]]. + */ + override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { + WindowSpecReference(ctx.identifier.getText) + } + + /** + * Create a window definition, i.e. [[WindowSpecDefinition]]. + */ + override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { + // CLUSTER BY ... | PARTITION BY ... ORDER BY ... + val partition = ctx.partition.asScala.map(expression) + val order = ctx.sortItem.asScala.map(visitSortItem) + + // RANGE/ROWS BETWEEN ... + val frameSpecOption = Option(ctx.windowFrame).map { frame => + val frameType = frame.frameType.getType match { + case SqlBaseParser.RANGE => RangeFrame + case SqlBaseParser.ROWS => RowFrame + } + + SpecifiedWindowFrame( + frameType, + visitFrameBound(frame.start), + Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) + } + + WindowSpecDefinition( + partition, + order, + frameSpecOption.getOrElse(UnspecifiedFrame)) + } + + /** + * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value + * Preceding/Following boundaries. These expressions must be constant (foldable) and return an + * integer value. + */ + override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { + // We currently only allow foldable integers. + def value: Int = { + val e = expression(ctx.expression) + assert(e.resolved && e.foldable && e.dataType == IntegerType, + "Frame bound value must be a constant integer.", + ctx) + e.eval().asInstanceOf[Int] + } + + // Create the FrameBoundary + ctx.boundType.getType match { + case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => + UnboundedPreceding + case SqlBaseParser.PRECEDING => + ValuePreceding(value) + case SqlBaseParser.CURRENT => + CurrentRow + case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => + UnboundedFollowing + case SqlBaseParser.FOLLOWING => + ValueFollowing(value) + } + } + + /** + * Create a [[CreateStruct]] expression. + */ + override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { + CreateStruct(ctx.expression.asScala.map(expression)) + } + + /** + * Create a [[ScalarSubquery]] expression. + */ + override def visitSubqueryExpression( + ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { + ScalarSubquery(plan(ctx.query)) + } + + /** + * Create a value based [[CaseWhen]] expression. This has the following SQL form: + * {{{ + * CASE [expression] + * WHEN [value] THEN [expression] + * ... + * ELSE [expression] + * END + * }}} + */ + override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + val branches = ctx.whenClause.asScala.map { wCtx => + (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: + * {{{ + * CASE + * WHEN [predicate] THEN [expression] + * ... 
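[Editor's note] `visitSimpleCase` above desugars the value-based CASE form into the condition-based form by wrapping every WHEN value in an equality test against the CASE operand. A compact sketch over strings (`desugarSimpleCase` is an illustrative helper, not the builder itself):

```scala
// CASE e WHEN v1 THEN r1 ... ELSE d END  ==>  CASE WHEN e = v1 THEN r1 ... ELSE d END
def desugarSimpleCase(
    operand: String,
    branches: Seq[(String, String)],
    elseValue: Option[String]): (Seq[(String, String)], Option[String]) = {
  val searched = branches.map { case (value, result) => (s"($operand = $value)", result) }
  (searched, elseValue)
}

// desugarSimpleCase("x", Seq("1" -> "'one'", "2" -> "'two'"), Some("'other'"))
//   == (Seq("(x = 1)" -> "'one'", "(x = 2)" -> "'two'"), Some("'other'"))
```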
+ * ELSE [expression] + * END + * }}} + * + * @param ctx the parse tree + * */ + override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { + val branches = ctx.whenClause.asScala.map { wCtx => + (expression(wCtx.condition), expression(wCtx.result)) + } + CaseWhen(branches, Option(ctx.elseExpression).map(expression)) + } + + /** + * Create a dereference expression. The return type depends on the type of the parent, this can + * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an + * [[UnresolvedExtractValue]] if the parent is some expression. + */ + override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { + val attr = ctx.fieldName.getText + expression(ctx.base) match { + case UnresolvedAttribute(nameParts) => + UnresolvedAttribute(nameParts :+ attr) + case e => + UnresolvedExtractValue(e, Literal(attr)) + } + } + + /** + * Create an [[UnresolvedAttribute]] expression. + */ + override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { + UnresolvedAttribute.quoted(ctx.getText) + } + + /** + * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. + */ + override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { + UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) + } + + /** + * Create an expression for an expression between parentheses. This is need because the ANTLR + * visitor cannot automatically convert the nested context into an expression. + */ + override def visitParenthesizedExpression( + ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { + expression(ctx.expression) + } + + /** + * Create a [[SortOrder]] expression. + */ + override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { + if (ctx.DESC != null) { + SortOrder(expression(ctx.expression), Descending) + } else { + SortOrder(expression(ctx.expression), Ascending) + } + } + + /** + * Create a typed Literal expression. A typed literal has the following SQL syntax: + * {{{ + * [TYPE] '[VALUE]' + * }}} + * Currently Date and Timestamp typed literals are supported. + * + * TODO what the added value of this over casting? + */ + override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { + val value = string(ctx.STRING) + ctx.identifier.getText.toUpperCase match { + case "DATE" => + Literal(Date.valueOf(value)) + case "TIMESTAMP" => + Literal(Timestamp.valueOf(value)) + case other => + throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) + } + } + + /** + * Create a NULL literal expression. + */ + override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { + Literal(null) + } + + /** + * Create a Boolean literal expression. + */ + override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { + if (ctx.getText.toBoolean) { + Literal.TrueLiteral + } else { + Literal.FalseLiteral + } + } + + /** + * Create an integral literal expression. The code selects the most narrow integral type + * possible, either a BigDecimal, a Long or an Integer is returned. 
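[Editor's note] `visitTypeConstructor` above accepts only DATE and TIMESTAMP typed literals and delegates to the `java.sql` parsers. A minimal sketch of that dispatch, assuming an `Either` result instead of throwing `ParseException` (`typedLiteral` is a hypothetical helper):

```scala
import java.sql.{Date, Timestamp}

// Only DATE and TIMESTAMP typed literals are accepted; anything else is a parse error.
def typedLiteral(typeName: String, value: String): Either[String, Any] =
  typeName.toUpperCase match {
    case "DATE"      => Right(Date.valueOf(value))        // DATE '2016-03-28'
    case "TIMESTAMP" => Right(Timestamp.valueOf(value))   // TIMESTAMP '2016-03-28 10:45:50'
    case other       => Left(s"Literals of type '$other' are currently not supported.")
  }

// typedLiteral("date", "2016-03-28")  =>  Right(2016-03-28)
```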
+ */ + override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { + BigDecimal(ctx.getText) match { + case v if v.isValidInt => + Literal(v.intValue()) + case v if v.isValidLong => + Literal(v.longValue()) + case v => Literal(v.underlying()) + } + } + + /** + * Create a double literal for a number denoted in scientifc notation. + */ + override def visitScientificDecimalLiteral( + ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(ctx.getText.toDouble) + } + + /** + * Create a decimal literal for a regular decimal number. + */ + override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { + Literal(BigDecimal(ctx.getText).underlying()) + } + + /** Create a numeric literal expression. */ + private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) { + val raw = ctx.getText + try { + Literal(f(raw.substring(0, raw.length - 1))) + } catch { + case e: NumberFormatException => + throw new ParseException(e.getMessage, ctx) + } + } + + /** + * Create a Byte Literal expression. + */ + override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) { + _.toByte + } + + /** + * Create a Short Literal expression. + */ + override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { + _.toShort + } + + /** + * Create a Long Literal expression. + */ + override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { + _.toLong + } + + /** + * Create a Double Literal expression. + */ + override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { + _.toDouble + } + + /** + * Create a String literal expression. + */ + override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { + Literal(createString(ctx)) + } + + /** + * Create a String from a string literal context. This supports multiple consecutive string + * literals, these are concatenated, for example this expression "'hello' 'world'" will be + * converted into "helloworld". + * + * Special characters can be escaped by using Hive/C-style escaping. + */ + private def createString(ctx: StringLiteralContext): String = { + ctx.STRING().asScala.map(string).mkString + } + + /** + * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple + * unit value pairs, for instance: interval 2 months 2 days. + */ + override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { + val intervals = ctx.intervalField.asScala.map(visitIntervalField) + assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) + Literal(intervals.reduce(_.add(_))) + } + + /** + * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are + * supported: + * - Single unit. + * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported). + */ + override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { + import ctx._ + val s = value.getText + try { + val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { + case (u, None) if u.endsWith("s") => + // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... 
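[Editor's note] `visitIntegerLiteral` above picks the narrowest type that can hold the literal: Int first, then Long, falling back to an exact BigDecimal. A stand-alone sketch of that selection (`narrowIntegral` is an illustrative helper):

```scala
// Narrow an integral literal: Int if it fits, else Long, else the exact BigDecimal.
def narrowIntegral(text: String): Any = BigDecimal(text) match {
  case v if v.isValidInt  => v.intValue()
  case v if v.isValidLong => v.longValue()
  case v                  => v.underlying() // java.math.BigDecimal
}

// narrowIntegral("42")                   == 42            (Int)
// narrowIntegral("3000000000")           == 3000000000L   (Long)
// narrowIntegral("99999999999999999999")  -- kept as a BigDecimal
```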
+ CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) + case (u, None) => + CalendarInterval.fromSingleUnitString(u, s) + case ("year", Some("month")) => + CalendarInterval.fromYearMonthString(s) + case ("day", Some("second")) => + CalendarInterval.fromDayTimeString(s) + case (from, Some(t)) => + throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) + } + assert(interval != null, "No interval can be constructed", ctx) + interval + } catch { + // Handle Exceptions thrown by CalendarInterval + case e: IllegalArgumentException => + val pe = new ParseException(e.getMessage, ctx) + pe.setStackTrace(e.getStackTrace) + throw pe + } + } + + /* ******************************************************************************************** + * DataType parsing + * ******************************************************************************************** */ + /** + * Resolve/create a primitive type. + */ + override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { + (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { + case ("boolean", Nil) => BooleanType + case ("tinyint" | "byte", Nil) => ByteType + case ("smallint" | "short", Nil) => ShortType + case ("int" | "integer", Nil) => IntegerType + case ("bigint" | "long", Nil) => LongType + case ("float", Nil) => FloatType + case ("double", Nil) => DoubleType + case ("date", Nil) => DateType + case ("timestamp", Nil) => TimestampType + case ("char" | "varchar" | "string", Nil) => StringType + case ("char" | "varchar", _ :: Nil) => StringType + case ("binary", Nil) => BinaryType + case ("decimal", Nil) => DecimalType.USER_DEFAULT + case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) + case ("decimal", precision :: scale :: Nil) => + DecimalType(precision.getText.toInt, scale.getText.toInt) + case (dt, params) => + throw new ParseException( + s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) + } + } + + /** + * Create a complex DataType. Arrays, Maps and Structures are supported. + */ + override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { + ctx.complex.getType match { + case SqlBaseParser.ARRAY => + ArrayType(typedVisit(ctx.dataType(0))) + case SqlBaseParser.MAP => + MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) + case SqlBaseParser.STRUCT => + createStructType(ctx.colTypeList()) + } + } + + /** + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ + protected def createStructType(ctx: ColTypeListContext): StructType = { + StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) + } + + /** + * Create a [[StructType]] from a number of column definitions. + */ + override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { + ctx.colType().asScala.map(visitColType) + } + + /** + * Create a [[StructField]] from a column definition. + */ + override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { + import ctx._ + + // Add the comment to the metadata. 
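[Editor's note] `visitPrimitiveDataType` above resolves a type name plus optional integer parameters, with `decimal` defaulting its precision (and a scale of 0) when parameters are omitted. A trimmed-down sketch of that dispatch, returning descriptive strings instead of Catalyst `DataType`s and covering only a few representative cases (`primitiveType` is a hypothetical helper):

```scala
// The real builder handles the full list of type names and raises a ParseException
// for unknown (name, parameter) combinations.
def primitiveType(name: String, params: List[Int]): String =
  (name.toLowerCase, params) match {
    case ("int" | "integer", Nil)    => "IntegerType"
    case ("bigint" | "long", Nil)    => "LongType"
    case ("string" | "varchar", Nil) => "StringType"
    case ("decimal", Nil)            => "DecimalType.USER_DEFAULT"
    case ("decimal", p :: Nil)       => s"DecimalType($p, 0)"   // scale defaults to 0
    case ("decimal", p :: s :: Nil)  => s"DecimalType($p, $s)"
    case (dt, ps) =>
      sys.error(s"DataType $dt${ps.mkString("(", ",", ")")} is not supported.")
  }

// primitiveType("DECIMAL", List(20, 2)) == "DecimalType(20, 2)"
```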
+ val builder = new MetadataBuilder + if (STRING != null) { + builder.putString("comment", string(STRING)) + } + + StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build()) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala deleted file mode 100644 index c188c5b108..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/CatalystQl.scala +++ /dev/null @@ -1,933 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser - -import java.sql.Date - -import scala.collection.mutable.ArrayBuffer -import scala.util.matching.Regex - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.Count -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.random.RandomSampler - - -/** - * This class translates SQL to Catalyst [[LogicalPlan]]s or [[Expression]]s. - */ -private[sql] class CatalystQl(val conf: ParserConf = SimpleParserConf()) extends ParserInterface { - import ParserUtils._ - - /** - * The safeParse method allows a user to focus on the parsing/AST transformation logic. This - * method will take care of possible errors during the parsing process. - */ - protected def safeParse[T](sql: String, ast: ASTNode)(toResult: ASTNode => T): T = { - try { - toResult(ast) - } catch { - case e: MatchError => throw e - case e: AnalysisException => throw e - case e: Exception => - throw new AnalysisException(e.getMessage) - case e: NotImplementedError => - throw new AnalysisException( - s"""Unsupported language features in query - |== SQL == - |$sql - |== AST == - |${ast.treeString} - |== Error == - |$e - |== Stacktrace == - |${e.getStackTrace.head} - """.stripMargin) - } - } - - /** Creates LogicalPlan for a given SQL string. */ - def parsePlan(sql: String): LogicalPlan = - safeParse(sql, ParseDriver.parsePlan(sql, conf))(nodeToPlan) - - /** Creates Expression for a given SQL string. */ - def parseExpression(sql: String): Expression = - safeParse(sql, ParseDriver.parseExpression(sql, conf))(selExprNodeToExpr(_).get) - - /** Creates TableIdentifier for a given SQL string. 
*/ - def parseTableIdentifier(sql: String): TableIdentifier = - safeParse(sql, ParseDriver.parseTableName(sql, conf))(extractTableIdent) - - /** - * SELECT MAX(value) FROM src GROUP BY k1, k2, k3 GROUPING SETS((k1, k2), (k2)) - * is equivalent to - * SELECT MAX(value) FROM src GROUP BY k1, k2 UNION SELECT MAX(value) FROM src GROUP BY k2 - * Check the following link for details. - * -https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C+Grouping+and+Rollup - * - * The bitmask denotes the grouping expressions validity for a grouping set, - * the bitmask also be called as grouping id (`GROUPING__ID`, the virtual column in Hive) - * e.g. In superset (k1, k2, k3), (bit 2: k1, bit 1: k2, and bit 0: k3), the grouping id of - * GROUPING SETS (k1, k2) and (k2) should be 1 and 5 respectively. - */ - protected def extractGroupingSet(children: Seq[ASTNode]): (Seq[Expression], Seq[Int]) = { - val (keyASTs, setASTs) = children.partition { - case Token("TOK_GROUPING_SETS_EXPRESSION", _) => false // grouping sets - case _ => true // grouping keys - } - - val keys = keyASTs.map(nodeToExpr) - val keyMap = keyASTs.zipWithIndex.toMap - - val mask = (1 << keys.length) - 1 - val bitmasks: Seq[Int] = setASTs.map { - case Token("TOK_GROUPING_SETS_EXPRESSION", columns) => - columns.foldLeft(mask)((bitmap, col) => { - val keyIndex = keyMap.find(_._1.treeEquals(col)).map(_._2).getOrElse( - throw new AnalysisException(s"${col.treeString} doesn't show up in the GROUP BY list")) - // 0 means that the column at the given index is a grouping column, 1 means it is not, - // so we unset the bit in bitmap. - bitmap & ~(1 << (keys.length - 1 - keyIndex)) - }) - case _ => sys.error("Expect GROUPING SETS clause") - } - - (keys, bitmasks) - } - - protected def nodeToPlan(node: ASTNode): LogicalPlan = node match { - case Token("TOK_SHOWFUNCTIONS", args) => - // Skip LIKE. - val pattern = args match { - case like :: nodes if like.text.toUpperCase == "LIKE" => nodes - case nodes => nodes - } - - // Extract Database and Function name - pattern match { - case Nil => - ShowFunctions(None, None) - case Token(name, Nil) :: Nil => - ShowFunctions(None, Some(unquoteString(cleanIdentifier(name)))) - case Token(db, Nil) :: Token(name, Nil) :: Nil => - ShowFunctions(Some(unquoteString(cleanIdentifier(db))), - Some(unquoteString(cleanIdentifier(name)))) - case _ => - noParseRule("SHOW FUNCTIONS", node) - } - - case Token("TOK_DESCFUNCTION", Token(functionName, Nil) :: isExtended) => - DescribeFunction(cleanIdentifier(functionName), isExtended.nonEmpty) - - case Token("TOK_QUERY", queryArgs @ Token("TOK_CTE" | "TOK_FROM" | "TOK_INSERT", _) :: _) => - val (fromClause: Option[ASTNode], insertClauses, cteRelations) = - queryArgs match { - case Token("TOK_CTE", ctes) :: Token("TOK_FROM", from) :: inserts => - val cteRelations = ctes.map { node => - val relation = nodeToRelation(node).asInstanceOf[SubqueryAlias] - relation.alias -> relation - } - (Some(from.head), inserts, Some(cteRelations.toMap)) - case Token("TOK_FROM", from) :: inserts => - (Some(from.head), inserts, None) - case Token("TOK_INSERT", _) :: Nil => - (None, queryArgs, None) - } - - // Return one query for each insert clause. 
- val queries = insertClauses.map { - case Token("TOK_INSERT", singleInsert) => - val ( - intoClause :: - destClause :: - selectClause :: - selectDistinctClause :: - whereClause :: - groupByClause :: - rollupGroupByClause :: - cubeGroupByClause :: - groupingSetsClause :: - orderByClause :: - havingClause :: - sortByClause :: - clusterByClause :: - distributeByClause :: - limitClause :: - lateralViewClause :: - windowClause :: Nil) = { - getClauses( - Seq( - "TOK_INSERT_INTO", - "TOK_DESTINATION", - "TOK_SELECT", - "TOK_SELECTDI", - "TOK_WHERE", - "TOK_GROUPBY", - "TOK_ROLLUP_GROUPBY", - "TOK_CUBE_GROUPBY", - "TOK_GROUPING_SETS", - "TOK_ORDERBY", - "TOK_HAVING", - "TOK_SORTBY", - "TOK_CLUSTERBY", - "TOK_DISTRIBUTEBY", - "TOK_LIMIT", - "TOK_LATERAL_VIEW", - "WINDOW"), - singleInsert) - } - - val relations = fromClause match { - case Some(f) => nodeToRelation(f) - case None => OneRowRelation - } - - val withLateralView = lateralViewClause.map { lv => - nodeToGenerate(lv.children.head, outer = false, relations) - }.getOrElse(relations) - - val withWhere = whereClause.map { whereNode => - val Seq(whereExpr) = whereNode.children - Filter(nodeToExpr(whereExpr), withLateralView) - }.getOrElse(withLateralView) - - val select = (selectClause orElse selectDistinctClause) - .getOrElse(sys.error("No select clause.")) - - val transformation = nodeToTransformation(select.children.head, withWhere) - - // The projection of the query can either be a normal projection, an aggregation - // (if there is a group by) or a script transformation. - val withProject: LogicalPlan = transformation.getOrElse { - val selectExpressions = - select.children.flatMap(selExprNodeToExpr).map(UnresolvedAlias(_)) - Seq( - groupByClause.map(e => e match { - case Token("TOK_GROUPBY", children) => - // Not a transformation so must be either project or aggregation. - Aggregate(children.map(nodeToExpr), selectExpressions, withWhere) - case _ => sys.error("Expect GROUP BY") - }), - groupingSetsClause.map(e => e match { - case Token("TOK_GROUPING_SETS", children) => - val(groupByExprs, masks) = extractGroupingSet(children) - GroupingSets(masks, groupByExprs, withWhere, selectExpressions) - case _ => sys.error("Expect GROUPING SETS") - }), - rollupGroupByClause.map(e => e match { - case Token("TOK_ROLLUP_GROUPBY", children) => - Aggregate( - Seq(Rollup(children.map(nodeToExpr))), - selectExpressions, - withWhere) - case _ => sys.error("Expect WITH ROLLUP") - }), - cubeGroupByClause.map(e => e match { - case Token("TOK_CUBE_GROUPBY", children) => - Aggregate( - Seq(Cube(children.map(nodeToExpr))), - selectExpressions, - withWhere) - case _ => sys.error("Expect WITH CUBE") - }), - Some(Project(selectExpressions, withWhere))).flatten.head - } - - // Handle HAVING clause. - val withHaving = havingClause.map { h => - val havingExpr = h.children match { case Seq(hexpr) => nodeToExpr(hexpr) } - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(havingExpr, BooleanType), withProject) - }.getOrElse(withProject) - - // Handle SELECT DISTINCT - val withDistinct = - if (selectDistinctClause.isDefined) Distinct(withHaving) else withHaving - - // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. 
- val withSort = - (orderByClause, sortByClause, distributeByClause, clusterByClause) match { - case (Some(totalOrdering), None, None, None) => - Sort(totalOrdering.children.map(nodeToSortOrder), global = true, withDistinct) - case (None, Some(perPartitionOrdering), None, None) => - Sort( - perPartitionOrdering.children.map(nodeToSortOrder), - global = false, withDistinct) - case (None, None, Some(partitionExprs), None) => - RepartitionByExpression( - partitionExprs.children.map(nodeToExpr), withDistinct) - case (None, Some(perPartitionOrdering), Some(partitionExprs), None) => - Sort( - perPartitionOrdering.children.map(nodeToSortOrder), global = false, - RepartitionByExpression( - partitionExprs.children.map(nodeToExpr), - withDistinct)) - case (None, None, None, Some(clusterExprs)) => - Sort( - clusterExprs.children.map(nodeToExpr).map(SortOrder(_, Ascending)), - global = false, - RepartitionByExpression( - clusterExprs.children.map(nodeToExpr), - withDistinct)) - case (None, None, None, None) => withDistinct - case _ => sys.error("Unsupported set of ordering / distribution clauses.") - } - - val withLimit = - limitClause.map(l => nodeToExpr(l.children.head)) - .map(Limit(_, withSort)) - .getOrElse(withSort) - - // Collect all window specifications defined in the WINDOW clause. - val windowDefinitions = windowClause.map(_.children.collect { - case Token("TOK_WINDOWDEF", - Token(windowName, Nil) :: Token("TOK_WINDOWSPEC", spec) :: Nil) => - windowName -> nodesToWindowSpecification(spec) - }.toMap) - // Handle cases like - // window w1 as (partition by p_mfgr order by p_name - // range between 2 preceding and 2 following), - // w2 as w1 - val resolvedCrossReference = windowDefinitions.map { - windowDefMap => windowDefMap.map { - case (windowName, WindowSpecReference(other)) => - (windowName, windowDefMap(other).asInstanceOf[WindowSpecDefinition]) - case o => o.asInstanceOf[(String, WindowSpecDefinition)] - } - } - - val withWindowDefinitions = - resolvedCrossReference.map(WithWindowDefinition(_, withLimit)).getOrElse(withLimit) - - // TOK_INSERT_INTO means to add files to the table. - // TOK_DESTINATION means to overwrite the table. - val resultDestination = - (intoClause orElse destClause).getOrElse(sys.error("No destination found.")) - val overwrite = intoClause.isEmpty - nodeToDest( - resultDestination, - withWindowDefinitions, - overwrite) - } - - // If there are multiple INSERTS just UNION them together into one query. 
- val query = if (queries.length == 1) queries.head else Union(queries) - - // return With plan if there is CTE - cteRelations.map(With(query, _)).getOrElse(query) - - case Token("TOK_UNIONALL", left :: right :: Nil) => - Union(nodeToPlan(left), nodeToPlan(right)) - case Token("TOK_UNIONDISTINCT", left :: right :: Nil) => - Distinct(Union(nodeToPlan(left), nodeToPlan(right))) - case Token("TOK_EXCEPT", left :: right :: Nil) => - Except(nodeToPlan(left), nodeToPlan(right)) - case Token("TOK_INTERSECT", left :: right :: Nil) => - Intersect(nodeToPlan(left), nodeToPlan(right)) - - case _ => - noParseRule("Plan", node) - } - - val allJoinTokens = "(TOK_.*JOIN)".r - val laterViewToken = "TOK_LATERAL_VIEW(.*)".r - protected def nodeToRelation(node: ASTNode): LogicalPlan = { - node match { - case Token("TOK_SUBQUERY", query :: Token(alias, Nil) :: Nil) => - SubqueryAlias(cleanIdentifier(alias), nodeToPlan(query)) - - case Token(laterViewToken(isOuter), selectClause :: relationClause :: Nil) => - nodeToGenerate( - selectClause, - outer = isOuter.nonEmpty, - nodeToRelation(relationClause)) - - /* All relations, possibly with aliases or sampling clauses. */ - case Token("TOK_TABREF", clauses) => - // If the last clause is not a token then it's the alias of the table. - val (nonAliasClauses, aliasClause) = - if (clauses.last.text.startsWith("TOK")) { - (clauses, None) - } else { - (clauses.dropRight(1), Some(clauses.last)) - } - - val (Some(tableNameParts) :: - splitSampleClause :: - bucketSampleClause :: Nil) = { - getClauses(Seq("TOK_TABNAME", "TOK_TABLESPLITSAMPLE", "TOK_TABLEBUCKETSAMPLE"), - nonAliasClauses) - } - - val tableIdent = extractTableIdent(tableNameParts) - val alias = aliasClause.map { case Token(a, Nil) => cleanIdentifier(a) } - val relation = UnresolvedRelation(tableIdent, alias) - - // Apply sampling if requested. - (bucketSampleClause orElse splitSampleClause).map { - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_ROWCOUNT", Nil) :: Token(count, Nil) :: Nil) => - Limit(Literal(count.toInt), relation) - case Token("TOK_TABLESPLITSAMPLE", - Token("TOK_PERCENT", Nil) :: Token(fraction, Nil) :: Nil) => - // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling - // function takes X PERCENT as the input and the range of X is [0, 100], we need to - // adjust the fraction. 
- require( - fraction.toDouble >= (0.0 - RandomSampler.roundingEpsilon) - && fraction.toDouble <= (100.0 + RandomSampler.roundingEpsilon), - s"Sampling fraction ($fraction) must be on interval [0, 100]") - Sample(0.0, fraction.toDouble / 100, withReplacement = false, - (math.random * 1000).toInt, - relation)( - isTableSample = true) - case Token("TOK_TABLEBUCKETSAMPLE", - Token(numerator, Nil) :: - Token(denominator, Nil) :: Nil) => - val fraction = numerator.toDouble / denominator.toDouble - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, relation)( - isTableSample = true) - case a => - noParseRule("Sampling", a) - }.getOrElse(relation) - - case Token(allJoinTokens(joinToken), relation1 :: relation2 :: other) => - if (!(other.size <= 1)) { - sys.error(s"Unsupported join operation: $other") - } - - val (joinType, joinCondition) = getJoinInfo(joinToken, other, node) - - Join(nodeToRelation(relation1), - nodeToRelation(relation2), - joinType, - joinCondition) - case _ => - noParseRule("Relation", node) - } - } - - protected def getJoinInfo( - joinToken: String, - joinConditionToken: Seq[ASTNode], - node: ASTNode): (JoinType, Option[Expression]) = { - val joinType = joinToken match { - case "TOK_JOIN" => Inner - case "TOK_CROSSJOIN" => Inner - case "TOK_RIGHTOUTERJOIN" => RightOuter - case "TOK_LEFTOUTERJOIN" => LeftOuter - case "TOK_FULLOUTERJOIN" => FullOuter - case "TOK_LEFTSEMIJOIN" => LeftSemi - case "TOK_UNIQUEJOIN" => noParseRule("Unique Join", node) - case "TOK_ANTIJOIN" => noParseRule("Anti Join", node) - case "TOK_NATURALJOIN" => NaturalJoin(Inner) - case "TOK_NATURALRIGHTOUTERJOIN" => NaturalJoin(RightOuter) - case "TOK_NATURALLEFTOUTERJOIN" => NaturalJoin(LeftOuter) - case "TOK_NATURALFULLOUTERJOIN" => NaturalJoin(FullOuter) - } - - joinConditionToken match { - case Token("TOK_USING", columnList :: Nil) :: Nil => - val colNames = columnList.children.collect { - case Token(name, Nil) => UnresolvedAttribute(name) - } - (UsingJoin(joinType, colNames), None) - /* Join expression specified using ON clause */ - case _ => (joinType, joinConditionToken.headOption.map(nodeToExpr)) - } - } - - protected def nodeToSortOrder(node: ASTNode): SortOrder = node match { - case Token("TOK_TABSORTCOLNAMEASC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Ascending) - case Token("TOK_TABSORTCOLNAMEDESC", sortExpr :: Nil) => - SortOrder(nodeToExpr(sortExpr), Descending) - case _ => - noParseRule("SortOrder", node) - } - - val destinationToken = "TOK_DESTINATION|TOK_INSERT_INTO".r - protected def nodeToDest( - node: ASTNode, - query: LogicalPlan, - overwrite: Boolean): LogicalPlan = node match { - case Token(destinationToken(), - Token("TOK_DIR", - Token("TOK_TMP_FILE", Nil) :: Nil) :: Nil) => - query - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.children.map { - // Parse partitions. We also make keys case insensitive. 
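[Editor's note] Both the new `visitPartitionSpec` earlier in this patch and the TOK_PARTVAL handling below build the same shape: a map from lower-cased partition column names to optional values. A minimal sketch of that normalization (`partitionSpec` is an illustrative helper over pre-extracted name/value pairs):

```scala
// PARTITION (ds = '2016-03-28', hr)  =>  Map("ds" -> Some("2016-03-28"), "hr" -> None)
def partitionSpec(pairs: Seq[(String, Option[String])]): Map[String, Option[String]] =
  pairs.map { case (name, value) => name.toLowerCase -> value }.toMap

// partitionSpec(Seq("DS" -> Some("2016-03-28"), "HR" -> None))
//   == Map("ds" -> Some("2016-03-28"), "hr" -> None)
```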
- case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable( - UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = false) - - case Token(destinationToken(), - Token("TOK_TAB", - tableArgs) :: - Token("TOK_IFNOTEXISTS", - ifNotExists) :: Nil) => - val Some(tableNameParts) :: partitionClause :: Nil = - getClauses(Seq("TOK_TABNAME", "TOK_PARTSPEC"), tableArgs) - - val tableIdent = extractTableIdent(tableNameParts) - - val partitionKeys = partitionClause.map(_.children.map { - // Parse partitions. We also make keys case insensitive. - case Token("TOK_PARTVAL", Token(key, Nil) :: Token(value, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> Some(unquoteString(value)) - case Token("TOK_PARTVAL", Token(key, Nil) :: Nil) => - cleanIdentifier(key.toLowerCase) -> None - }.toMap).getOrElse(Map.empty) - - InsertIntoTable( - UnresolvedRelation(tableIdent, None), partitionKeys, query, overwrite, ifNotExists = true) - - case _ => - noParseRule("Destination", node) - } - - protected def selExprNodeToExpr(node: ASTNode): Option[Expression] = node match { - case Token("TOK_SELEXPR", e :: Nil) => - Some(nodeToExpr(e)) - - case Token("TOK_SELEXPR", e :: Token(alias, Nil) :: Nil) => - Some(Alias(nodeToExpr(e), cleanIdentifier(alias))()) - - case Token("TOK_SELEXPR", e :: aliasChildren) => - val aliasNames = aliasChildren.collect { - case Token(name, Nil) => cleanIdentifier(name) - } - Some(MultiAlias(nodeToExpr(e), aliasNames)) - - /* Hints are ignored */ - case Token("TOK_HINTLIST", _) => None - - case _ => - noParseRule("Select", node) - } - - /** - * Flattens the left deep tree with the specified pattern into a list. - */ - private def flattenLeftDeepTree(node: ASTNode, pattern: Regex): Seq[ASTNode] = { - val collected = ArrayBuffer[ASTNode]() - var rest = node - while (rest match { - case Token(pattern(), l :: r :: Nil) => - collected += r - rest = l - true - case _ => false - }) { - // do nothing - } - collected += rest - // keep them in the same order as in SQL - collected.reverse - } - - /** - * Creates a balanced tree that has similar number of nodes on left and right. - * - * This help to reduce the depth of the tree to prevent StackOverflow in analyzer/optimizer. - */ - private def balancedTree( - expr: Seq[Expression], - f: (Expression, Expression) => Expression): Expression = expr.length match { - case 1 => expr.head - case 2 => f(expr.head, expr(1)) - case l => f(balancedTree(expr.slice(0, l / 2), f), balancedTree(expr.slice(l / 2, l), f)) - } - - protected def nodeToExpr(node: ASTNode): Expression = node match { - /* Attribute References */ - case Token("TOK_TABLE_OR_COL", Token(name, Nil) :: Nil) => - UnresolvedAttribute.quoted(cleanIdentifier(name)) - case Token(".", qualifier :: Token(attr, Nil) :: Nil) => - nodeToExpr(qualifier) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ cleanIdentifier(attr)) - case other => UnresolvedExtractValue(other, Literal(cleanIdentifier(attr))) - } - case Token("TOK_SUBQUERY_EXPR", Token("TOK_SUBQUERY_OP", Nil) :: subquery :: Nil) => - ScalarSubquery(nodeToPlan(subquery)) - - /* Stars (*) */ - case Token("TOK_ALLCOLREF", Nil) => UnresolvedStar(None) - // The format of dbName.tableName.* cannot be parsed by HiveParser. 
TOK_TABNAME will only - // has a single child which is tableName. - case Token("TOK_ALLCOLREF", Token("TOK_TABNAME", target) :: Nil) if target.nonEmpty => - UnresolvedStar(Some(target.map(x => cleanIdentifier(x.text)))) - - /* Aggregate Functions */ - case Token("TOK_FUNCTIONDI", Token(COUNT(), Nil) :: args) => - Count(args.map(nodeToExpr)).toAggregateExpression(isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(COUNT(), Nil) :: Nil) => - Count(Literal(1)).toAggregateExpression() - - /* Casts */ - case Token("TOK_FUNCTION", Token("TOK_STRING", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_VARCHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_CHAR", _) :: arg :: Nil) => - Cast(nodeToExpr(arg), StringType) - case Token("TOK_FUNCTION", Token("TOK_INT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), IntegerType) - case Token("TOK_FUNCTION", Token("TOK_BIGINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), LongType) - case Token("TOK_FUNCTION", Token("TOK_FLOAT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), FloatType) - case Token("TOK_FUNCTION", Token("TOK_DOUBLE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DoubleType) - case Token("TOK_FUNCTION", Token("TOK_SMALLINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ShortType) - case Token("TOK_FUNCTION", Token("TOK_TINYINT", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), ByteType) - case Token("TOK_FUNCTION", Token("TOK_BINARY", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BinaryType) - case Token("TOK_FUNCTION", Token("TOK_BOOLEAN", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), BooleanType) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: scale :: nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, scale.text.toInt)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", precision :: Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType(precision.text.toInt, 0)) - case Token("TOK_FUNCTION", Token("TOK_DECIMAL", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DecimalType.USER_DEFAULT) - case Token("TOK_FUNCTION", Token("TOK_TIMESTAMP", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), TimestampType) - case Token("TOK_FUNCTION", Token("TOK_DATE", Nil) :: arg :: Nil) => - Cast(nodeToExpr(arg), DateType) - - /* Arithmetic */ - case Token("+", child :: Nil) => nodeToExpr(child) - case Token("-", child :: Nil) => UnaryMinus(nodeToExpr(child)) - case Token("~", child :: Nil) => BitwiseNot(nodeToExpr(child)) - case Token("+", left :: right:: Nil) => Add(nodeToExpr(left), nodeToExpr(right)) - case Token("-", left :: right:: Nil) => Subtract(nodeToExpr(left), nodeToExpr(right)) - case Token("*", left :: right:: Nil) => Multiply(nodeToExpr(left), nodeToExpr(right)) - case Token("/", left :: right:: Nil) => Divide(nodeToExpr(left), nodeToExpr(right)) - case Token(DIV(), left :: right:: Nil) => - Cast(Divide(nodeToExpr(left), nodeToExpr(right)), LongType) - case Token("%", left :: right:: Nil) => Remainder(nodeToExpr(left), nodeToExpr(right)) - case Token("&", left :: right:: Nil) => BitwiseAnd(nodeToExpr(left), nodeToExpr(right)) - case Token("|", left :: right:: Nil) => BitwiseOr(nodeToExpr(left), nodeToExpr(right)) - case Token("^", left :: right:: Nil) => BitwiseXor(nodeToExpr(left), nodeToExpr(right)) - - /* Comparisons */ - case Token("=", left :: right:: Nil) => EqualTo(nodeToExpr(left), nodeToExpr(right)) - case Token("==", left :: right:: Nil) => EqualTo(nodeToExpr(left), 
nodeToExpr(right)) - case Token("<=>", left :: right:: Nil) => EqualNullSafe(nodeToExpr(left), nodeToExpr(right)) - case Token("!=", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token("<>", left :: right:: Nil) => Not(EqualTo(nodeToExpr(left), nodeToExpr(right))) - case Token(">", left :: right:: Nil) => GreaterThan(nodeToExpr(left), nodeToExpr(right)) - case Token(">=", left :: right:: Nil) => GreaterThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token("<", left :: right:: Nil) => LessThan(nodeToExpr(left), nodeToExpr(right)) - case Token("<=", left :: right:: Nil) => LessThanOrEqual(nodeToExpr(left), nodeToExpr(right)) - case Token(LIKE(), left :: right:: Nil) => Like(nodeToExpr(left), nodeToExpr(right)) - case Token(RLIKE(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token(REGEXP(), left :: right:: Nil) => RLike(nodeToExpr(left), nodeToExpr(right)) - case Token("TOK_FUNCTION", Token("TOK_ISNOTNULL", Nil) :: child :: Nil) => - IsNotNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token("TOK_ISNULL", Nil) :: child :: Nil) => - IsNull(nodeToExpr(child)) - case Token("TOK_FUNCTION", Token(IN(), Nil) :: value :: list) => - In(nodeToExpr(value), list.map(nodeToExpr)) - case Token("TOK_FUNCTION", - Token(BETWEEN(), Nil) :: - kw :: - target :: - minValue :: - maxValue :: Nil) => - - val targetExpression = nodeToExpr(target) - val betweenExpr = - And( - GreaterThanOrEqual(targetExpression, nodeToExpr(minValue)), - LessThanOrEqual(targetExpression, nodeToExpr(maxValue))) - kw match { - case Token("KW_FALSE", Nil) => betweenExpr - case Token("KW_TRUE", Nil) => Not(betweenExpr) - } - - /* Boolean Logic */ - case Token(AND(), left :: right:: Nil) => - balancedTree(flattenLeftDeepTree(node, AND).map(nodeToExpr), And) - case Token(OR(), left :: right:: Nil) => - balancedTree(flattenLeftDeepTree(node, OR).map(nodeToExpr), Or) - case Token(NOT(), child :: Nil) => Not(nodeToExpr(child)) - case Token("!", child :: Nil) => Not(nodeToExpr(child)) - - /* Case statements */ - case Token("TOK_FUNCTION", Token(WHEN(), Nil) :: branches) => - CaseWhen.createFromParser(branches.map(nodeToExpr)) - case Token("TOK_FUNCTION", Token(CASE(), Nil) :: branches) => - val keyExpr = nodeToExpr(branches.head) - CaseKeyWhen(keyExpr, branches.drop(1).map(nodeToExpr)) - - /* Complex datatype manipulation */ - case Token("[", child :: ordinal :: Nil) => - UnresolvedExtractValue(nodeToExpr(child), nodeToExpr(ordinal)) - - /* Window Functions */ - case Token(text, args :+ Token("TOK_WINDOWSPEC", spec)) => - val function = nodeToExpr(node.copy(children = node.children.init)) - nodesToWindowSpecification(spec) match { - case reference: WindowSpecReference => - UnresolvedWindowExpression(function, reference) - case definition: WindowSpecDefinition => - WindowExpression(function, definition) - } - - /* UDFs - Must be last otherwise will preempt built in functions */ - case Token("TOK_FUNCTION", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = false) - // Aggregate function with DISTINCT keyword. 
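The AND/OR cases above rebuild long predicate chains with `flattenLeftDeepTree` and `balancedTree` so the resulting expression tree stays shallow; a left-deep chain of n predicates would otherwise be n levels deep and can overflow the stack in the analyzer/optimizer. A minimal, Spark-free sketch of the same flatten-then-rebalance idea (the `Expr`, `Leaf`, and `And` types here are illustrative stand-ins, not Catalyst's):

```scala
import scala.collection.mutable.ArrayBuffer

sealed trait Expr
case class Leaf(name: String) extends Expr
case class And(left: Expr, right: Expr) extends Expr

// Collect the operands of a left-deep AND chain iteratively, preserving SQL order.
def flatten(e: Expr): Seq[Expr] = {
  val collected = ArrayBuffer[Expr]()
  var rest = e
  var done = false
  while (!done) {
    rest match {
      case And(l, r) => collected += r; rest = l
      case other     => collected += other; done = true
    }
  }
  collected.reverse.toSeq
}

// Recombine the operands into a balanced tree so the depth stays around log2(n).
def balanced(es: Seq[Expr]): Expr = es.length match {
  case 1 => es.head
  case 2 => And(es.head, es(1))
  case n => And(balanced(es.take(n / 2)), balanced(es.drop(n / 2)))
}

// A chain of 10000 predicates ends up roughly 14 levels deep instead of 10000.
val deep = (2 to 10000).foldLeft[Expr](Leaf("p1"))((acc, i) => And(acc, Leaf(s"p$i")))
val shallow = balanced(flatten(deep))
```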
- case Token("TOK_FUNCTIONDI", Token(name, Nil) :: args) => - UnresolvedFunction(name, args.map(nodeToExpr), isDistinct = true) - case Token("TOK_FUNCTIONSTAR", Token(name, Nil) :: args) => - UnresolvedFunction(name, UnresolvedStar(None) :: Nil, isDistinct = false) - - /* Literals */ - case Token("TOK_NULL", Nil) => Literal.create(null, NullType) - case Token(TRUE(), Nil) => Literal.create(true, BooleanType) - case Token(FALSE(), Nil) => Literal.create(false, BooleanType) - case Token("TOK_STRINGLITERALSEQUENCE", strings) => - Literal(strings.map(s => ParseUtils.unescapeSQLString(s.text)).mkString) - - case ast if ast.tokenType == SparkSqlParser.TinyintLiteral => - Literal.create(ast.text.substring(0, ast.text.length() - 1).toByte, ByteType) - - case ast if ast.tokenType == SparkSqlParser.SmallintLiteral => - Literal.create(ast.text.substring(0, ast.text.length() - 1).toShort, ShortType) - - case ast if ast.tokenType == SparkSqlParser.BigintLiteral => - Literal.create(ast.text.substring(0, ast.text.length() - 1).toLong, LongType) - - case ast if ast.tokenType == SparkSqlParser.DoubleLiteral => - Literal(ast.text.toDouble) - - case ast if ast.tokenType == SparkSqlParser.Number => - val text = ast.text - text match { - case INTEGRAL() => - BigDecimal(text) match { - case v if v.isValidInt => - Literal(v.intValue()) - case v if v.isValidLong => - Literal(v.longValue()) - case v => Literal(v.underlying()) - } - case DECIMAL(_*) => - Literal(BigDecimal(text).underlying()) - case _ => - // Convert a scientifically notated decimal into a double. - Literal(text.toDouble) - } - case ast if ast.tokenType == SparkSqlParser.StringLiteral => - Literal(ParseUtils.unescapeSQLString(ast.text)) - - case ast if ast.tokenType == SparkSqlParser.TOK_DATELITERAL => - Literal(Date.valueOf(ast.text.substring(1, ast.text.length - 1))) - - case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_YEAR_MONTH_LITERAL => - Literal(CalendarInterval.fromYearMonthString(ast.children.head.text)) - - case ast if ast.tokenType == SparkSqlParser.TOK_INTERVAL_DAY_TIME_LITERAL => - Literal(CalendarInterval.fromDayTimeString(ast.children.head.text)) - - case Token("TOK_INTERVAL", elements) => - var interval = new CalendarInterval(0, 0) - var updated = false - elements.foreach { - // The interval node will always contain children for all possible time units. A child node - // is only useful when it contains exactly one (numeric) child. 
- case e @ Token(name, Token(value, Nil) :: Nil) => - val unit = name match { - case "TOK_INTERVAL_YEAR_LITERAL" => "year" - case "TOK_INTERVAL_MONTH_LITERAL" => "month" - case "TOK_INTERVAL_WEEK_LITERAL" => "week" - case "TOK_INTERVAL_DAY_LITERAL" => "day" - case "TOK_INTERVAL_HOUR_LITERAL" => "hour" - case "TOK_INTERVAL_MINUTE_LITERAL" => "minute" - case "TOK_INTERVAL_SECOND_LITERAL" => "second" - case "TOK_INTERVAL_MILLISECOND_LITERAL" => "millisecond" - case "TOK_INTERVAL_MICROSECOND_LITERAL" => "microsecond" - case _ => noParseRule(s"Interval($name)", e) - } - interval = interval.add(CalendarInterval.fromSingleUnitString(unit, value)) - updated = true - case _ => - } - if (!updated) { - throw new AnalysisException("at least one time unit should be given for interval literal") - } - Literal(interval) - - case _ => - noParseRule("Expression", node) - } - - /* Case insensitive matches for Window Specification */ - val PRECEDING = "(?i)preceding".r - val FOLLOWING = "(?i)following".r - val CURRENT = "(?i)current".r - protected def nodesToWindowSpecification(nodes: Seq[ASTNode]): WindowSpec = nodes match { - case Token(windowName, Nil) :: Nil => - // Refer to a window spec defined in the window clause. - WindowSpecReference(windowName) - case Nil => - // OVER() - WindowSpecDefinition( - partitionSpec = Nil, - orderSpec = Nil, - frameSpecification = UnspecifiedFrame) - case spec => - val (partitionClause :: rowFrame :: rangeFrame :: Nil) = - getClauses( - Seq( - "TOK_PARTITIONINGSPEC", - "TOK_WINDOWRANGE", - "TOK_WINDOWVALUES"), - spec) - - // Handle Partition By and Order By. - val (partitionSpec, orderSpec) = partitionClause.map { partitionAndOrdering => - val (partitionByClause :: orderByClause :: sortByClause :: clusterByClause :: Nil) = - getClauses( - Seq("TOK_DISTRIBUTEBY", "TOK_ORDERBY", "TOK_SORTBY", "TOK_CLUSTERBY"), - partitionAndOrdering.children) - - (partitionByClause, orderByClause.orElse(sortByClause), clusterByClause) match { - case (Some(partitionByExpr), Some(orderByExpr), None) => - (partitionByExpr.children.map(nodeToExpr), - orderByExpr.children.map(nodeToSortOrder)) - case (Some(partitionByExpr), None, None) => - (partitionByExpr.children.map(nodeToExpr), Nil) - case (None, Some(orderByExpr), None) => - (Nil, orderByExpr.children.map(nodeToSortOrder)) - case (None, None, Some(clusterByExpr)) => - val expressions = clusterByExpr.children.map(nodeToExpr) - (expressions, expressions.map(SortOrder(_, Ascending))) - case _ => - noParseRule("Partition & Ordering", partitionAndOrdering) - } - }.getOrElse { - (Nil, Nil) - } - - // Handle Window Frame - val windowFrame = - if (rowFrame.isEmpty && rangeFrame.isEmpty) { - UnspecifiedFrame - } else { - val frameType = rowFrame.map(_ => RowFrame).getOrElse(RangeFrame) - def nodeToBoundary(node: ASTNode): FrameBoundary = node match { - case Token(PRECEDING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedPreceding - } else { - ValuePreceding(count.toInt) - } - case Token(FOLLOWING(), Token(count, Nil) :: Nil) => - if (count.toLowerCase() == "unbounded") { - UnboundedFollowing - } else { - ValueFollowing(count.toInt) - } - case Token(CURRENT(), Nil) => CurrentRow - case _ => - noParseRule("Window Frame Boundary", node) - } - - rowFrame.orElse(rangeFrame).map { frame => - frame.children match { - case precedingNode :: followingNode :: Nil => - SpecifiedWindowFrame( - frameType, - nodeToBoundary(precedingNode), - nodeToBoundary(followingNode)) - case precedingNode :: Nil => - 
SpecifiedWindowFrame(frameType, nodeToBoundary(precedingNode), CurrentRow) - case _ => - noParseRule("Window Frame", frame) - } - }.getOrElse(sys.error(s"If you see this, please file a bug report with your query.")) - } - - WindowSpecDefinition(partitionSpec, orderSpec, windowFrame) - } - - protected def nodeToTransformation( - node: ASTNode, - child: LogicalPlan): Option[ScriptTransformation] = None - - val explode = "(?i)explode".r - val jsonTuple = "(?i)json_tuple".r - protected def nodeToGenerate(node: ASTNode, outer: Boolean, child: LogicalPlan): Generate = { - val Token("TOK_SELECT", Token("TOK_SELEXPR", clauses) :: Nil) = node - - val alias = cleanIdentifier(getClause("TOK_TABALIAS", clauses).children.head.text) - - val generator = clauses.head match { - case Token("TOK_FUNCTION", Token(explode(), Nil) :: childNode :: Nil) => - Explode(nodeToExpr(childNode)) - case Token("TOK_FUNCTION", Token(jsonTuple(), Nil) :: children) => - JsonTuple(children.map(nodeToExpr)) - case other => - nodeToGenerator(other) - } - - val attributes = clauses.collect { - case Token(a, Nil) => UnresolvedAttribute(cleanIdentifier(a.toLowerCase)) - } - - Generate( - generator, - join = true, - outer = outer, - Some(cleanIdentifier(alias.toLowerCase)), - attributes, - child) - } - - protected def nodeToGenerator(node: ASTNode): Generator = noParseRule("Generator", node) - -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala index 21deb82107..0b570c9e42 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeParser.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.catalyst.parser import scala.language.implicitConversions import scala.util.matching.Regex import scala.util.parsing.combinator.syntactical.StandardTokenParsers +import scala.util.parsing.input.CharArrayReader._ import org.apache.spark.sql.types._ @@ -117,3 +118,69 @@ private[sql] object DataTypeParser { /** The exception thrown from the [[DataTypeParser]]. */ private[sql] class DataTypeException(message: String) extends Exception(message) + +class SqlLexical extends scala.util.parsing.combinator.lexical.StdLexical { + case class DecimalLit(chars: String) extends Token { + override def toString: String = chars + } + + /* This is a work around to support the lazy setting */ + def initialize(keywords: Seq[String]): Unit = { + reserved.clear() + reserved ++= keywords + } + + /* Normal the keyword string */ + def normalizeKeyword(str: String): String = str.toLowerCase + + delimiters += ( + "@", "*", "+", "-", "<", "=", "<>", "!=", "<=", ">=", ">", "/", "(", ")", + ",", ";", "%", "{", "}", ":", "[", "]", ".", "&", "|", "^", "~", "<=>" + ) + + protected override def processIdent(name: String) = { + val token = normalizeKeyword(name) + if (reserved contains token) Keyword(token) else Identifier(name) + } + + override lazy val token: Parser[Token] = + ( rep1(digit) ~ scientificNotation ^^ { case i ~ s => DecimalLit(i.mkString + s) } + | '.' ~> (rep1(digit) ~ scientificNotation) ^^ + { case i ~ s => DecimalLit("0." + i.mkString + s) } + | rep1(digit) ~ ('.' ~> digit.*) ~ scientificNotation ^^ + { case i1 ~ i2 ~ s => DecimalLit(i1.mkString + "." 
+ i2.mkString + s) } + | digit.* ~ identChar ~ (identChar | digit).* ^^ + { case first ~ middle ~ rest => processIdent((first ++ (middle :: rest)).mkString) } + | rep1(digit) ~ ('.' ~> digit.*).? ^^ { + case i ~ None => NumericLit(i.mkString) + case i ~ Some(d) => DecimalLit(i.mkString + "." + d.mkString) + } + | '\'' ~> chrExcept('\'', '\n', EofCh).* <~ '\'' ^^ + { case chars => StringLit(chars mkString "") } + | '"' ~> chrExcept('"', '\n', EofCh).* <~ '"' ^^ + { case chars => StringLit(chars mkString "") } + | '`' ~> chrExcept('`', '\n', EofCh).* <~ '`' ^^ + { case chars => Identifier(chars mkString "") } + | EofCh ^^^ EOF + | '\'' ~> failure("unclosed string literal") + | '"' ~> failure("unclosed string literal") + | delim + | failure("illegal character") + ) + + override def identChar: Parser[Elem] = letter | elem('_') + + private lazy val scientificNotation: Parser[String] = + (elem('e') | elem('E')) ~> (elem('+') | elem('-')).? ~ rep1(digit) ^^ { + case s ~ rest => "e" + s.mkString + rest.mkString + } + + override def whitespace: Parser[Any] = + ( whitespaceChar + | '/' ~ '*' ~ comment + | '/' ~ '/' ~ chrExcept(EofCh, '\n').* + | '#' ~ chrExcept(EofCh, '\n').* + | '-' ~ '-' ~ chrExcept(EofCh, '\n').* + | '/' ~ '*' ~ failure("unclosed comment") + ).* +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 51cfc50130..d0132529f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -16,90 +16,105 @@ */ package org.apache.spark.sql.catalyst.parser -import scala.annotation.tailrec - -import org.antlr.runtime._ -import org.antlr.runtime.tree.CommonTree +import org.antlr.v4.runtime._ +import org.antlr.v4.runtime.atn.PredictionMode +import org.antlr.v4.runtime.misc.ParseCancellationException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.Origin +import org.apache.spark.sql.types.DataType /** - * The ParseDriver takes a SQL command and turns this into an AST. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver + * Base SQL parsing infrastructure. */ -object ParseDriver extends Logging { - /** Create an LogicalPlan ASTNode from a SQL command. */ - def parsePlan(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.statement().getTree - } +abstract class AbstractSqlParser extends ParserInterface with Logging { - /** Create an Expression ASTNode from a SQL command. */ - def parseExpression(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.singleNamedExpression().getTree + /** Creates/Resolves DataType for a given SQL string. */ + def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => + // TODO add this to the parser interface. + astBuilder.visitSingleDataType(parser.singleDataType()) } - /** Create an TableIdentifier ASTNode from a SQL command. */ - def parseTableName(command: String, conf: ParserConf): ASTNode = parse(command, conf) { parser => - parser.singleTableName().getTree + /** Creates Expression for a given SQL string. 
*/ + override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser => + astBuilder.visitSingleExpression(parser.singleExpression()) } - private def parse( - command: String, - conf: ParserConf)( - toTree: SparkSqlParser => CommonTree): ASTNode = { - logInfo(s"Parsing command: $command") + /** Creates TableIdentifier for a given SQL string. */ + override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser => + astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) + } - // Setup error collection. - val reporter = new ParseErrorReporter() + /** Creates LogicalPlan for a given SQL string. */ + override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => + astBuilder.visitSingleStatement(parser.singleStatement()) match { + case plan: LogicalPlan => plan + case _ => nativeCommand(sqlText) + } + } - // Create lexer. - val lexer = new SparkSqlLexer(new ANTLRNoCaseStringStream(command)) - val tokens = new TokenRewriteStream(lexer) - lexer.configure(conf, reporter) + /** Get the builder (visitor) which converts a ParseTree into a AST. */ + protected def astBuilder: AstBuilder - // Create the parser. - val parser = new SparkSqlParser(tokens) - parser.configure(conf, reporter) + /** Create a native command, or fail when this is not supported. */ + protected def nativeCommand(sqlText: String): LogicalPlan = { + val position = Origin(None, None) + throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) + } - try { - val result = toTree(parser) + protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { + logInfo(s"Parsing command: $command") - // Check errors. - reporter.checkForErrors() + val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) + lexer.removeErrorListeners() + lexer.addErrorListener(ParseErrorListener) - // Return the AST node from the result. - logInfo(s"Parse completed.") + val tokenStream = new CommonTokenStream(lexer) + val parser = new SqlBaseParser(tokenStream) + parser.addParseListener(PostProcessor) + parser.removeErrorListeners() + parser.addErrorListener(ParseErrorListener) - // Find the non null token tree in the result. - @tailrec - def nonNullToken(tree: CommonTree): CommonTree = { - if (tree.token != null || tree.getChildCount == 0) tree - else nonNullToken(tree.getChild(0).asInstanceOf[CommonTree]) + try { + try { + // first, try parsing with potentially faster SLL mode + parser.getInterpreter.setPredictionMode(PredictionMode.SLL) + toResult(parser) } - val tree = nonNullToken(result) - - // Make sure all boundaries are set. - tree.setUnknownTokenBoundaries() - - // Construct the immutable AST. - def createASTNode(tree: CommonTree): ASTNode = { - val children = (0 until tree.getChildCount).map { i => - createASTNode(tree.getChild(i).asInstanceOf[CommonTree]) - }.toList - ASTNode(tree.token, tree.getTokenStartIndex, tree.getTokenStopIndex, children, tokens) + catch { + case e: ParseCancellationException => + // if we fail, parse with LL mode + tokenStream.reset() // rewind input stream + parser.reset() + + // Try Again. 
+ parser.getInterpreter.setPredictionMode(PredictionMode.LL) + toResult(parser) } - createASTNode(tree) } catch { - case e: RecognitionException => - logInfo(s"Parse failed.") - reporter.throwError(e) + case e: ParseException if e.command.isDefined => + throw e + case e: ParseException => + throw e.withCommand(command) + case e: AnalysisException => + val position = Origin(e.line, e.startPosition) + throw new ParseException(Option(command), e.message, position, position) } } } +/** + * Concrete SQL parser for Catalyst-only SQL statements. + */ +object CatalystSqlParser extends AbstractSqlParser { + val astBuilder = new AstBuilder +} + /** * This string stream provides the lexer with upper case characters only. This greatly simplifies * lexing the stream, while we can maintain the original command. @@ -120,58 +135,104 @@ object ParseDriver extends Logging { * have the ANTLRNoCaseStringStream implementation. */ -private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRStringStream(input) { +private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { override def LA(i: Int): Int = { val la = super.LA(i) - if (la == 0 || la == CharStream.EOF) la + if (la == 0 || la == IntStream.EOF) la else Character.toUpperCase(la) } } /** - * Utility used by the Parser and the Lexer for error collection and reporting. + * The ParseErrorListener converts parse errors into AnalysisExceptions. */ -private[parser] class ParseErrorReporter { - val errors = scala.collection.mutable.Buffer.empty[ParseError] - - def report(br: BaseRecognizer, re: RecognitionException, tokenNames: Array[String]): Unit = { - errors += ParseError(br, re, tokenNames) +case object ParseErrorListener extends BaseErrorListener { + override def syntaxError( + recognizer: Recognizer[_, _], + offendingSymbol: scala.Any, + line: Int, + charPositionInLine: Int, + msg: String, + e: RecognitionException): Unit = { + val position = Origin(Some(line), Some(charPositionInLine)) + throw new ParseException(None, msg, position, position) } +} - def checkForErrors(): Unit = { - if (errors.nonEmpty) { - val first = errors.head - val e = first.re - throwError(e.line, e.charPositionInLine, first.buildMessage().toString, errors.tail) - } +/** + * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It + * contains fields and an extended error message that make reporting and diagnosing errors easier. 
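The `parse` method added to `AbstractSqlParser` above first tries ANTLR's faster SLL prediction mode and retries with full LL only when SLL gives up. The same two-phase strategy, written as a generic helper over any ANTLR4 parser, is sketched below; it assumes the parser is configured so that an SLL failure surfaces as a `ParseCancellationException` (for example via a bail-out error strategy), and `parseTwoPhase` is an illustrative name rather than Spark's API:

```scala
import org.antlr.v4.runtime.{Parser, TokenStream}
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.ParseCancellationException

// Try the cheap SLL prediction mode first; rewind and retry with full LL on failure.
def parseTwoPhase[P <: Parser, T](parser: P, tokens: TokenStream)(rule: P => T): T = {
  try {
    parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
    rule(parser)
  } catch {
    case _: ParseCancellationException =>
      tokens.seek(0)  // rewind the token stream
      parser.reset()  // clear parser state
      parser.getInterpreter.setPredictionMode(PredictionMode.LL)
      rule(parser)
  }
}
```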
+ */ +class ParseException( + val command: Option[String], + message: String, + val start: Origin, + val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { + + def this(message: String, ctx: ParserRuleContext) = { + this(Option(ParserUtils.command(ctx)), + message, + ParserUtils.position(ctx.getStart), + ParserUtils.position(ctx.getStop)) } - def throwError(e: RecognitionException): Nothing = { - throwError(e.line, e.charPositionInLine, e.toString, errors) + override def getMessage: String = { + val builder = new StringBuilder + builder ++= "\n" ++= message + start match { + case Origin(Some(l), Some(p)) => + builder ++= s"(line $l, pos $p)\n" + command.foreach { cmd => + val (above, below) = cmd.split("\n").splitAt(l) + builder ++= "\n== SQL ==\n" + above.foreach(builder ++= _ += '\n') + builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" + below.foreach(builder ++= _ += '\n') + } + case _ => + command.foreach { cmd => + builder ++= "\n== SQL ==\n" ++= cmd + } + } + builder.toString } - private def throwError( - line: Int, - startPosition: Int, - msg: String, - errors: Seq[ParseError]): Nothing = { - val b = new StringBuilder - b.append(msg).append("\n") - errors.foreach(error => error.buildMessage(b).append("\n")) - throw new AnalysisException(b.toString, Option(line), Option(startPosition)) + def withCommand(cmd: String): ParseException = { + new ParseException(Option(cmd), message, start, stop) } } /** - * Error collected during the parsing process. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseError + * The post-processor validates & cleans-up the parse tree during the parse process. */ -private[parser] case class ParseError( - br: BaseRecognizer, - re: RecognitionException, - tokenNames: Array[String]) { - def buildMessage(s: StringBuilder = new StringBuilder): StringBuilder = { - s.append(br.getErrorHeader(re)).append(" ").append(br.getErrorMessage(re, tokenNames)) +case object PostProcessor extends SqlBaseBaseListener { + + /** Remove the back ticks from an Identifier. */ + override def exitQuotedIdentifier(ctx: SqlBaseParser.QuotedIdentifierContext): Unit = { + replaceTokenByIdentifier(ctx, 1) { token => + // Remove the double back ticks in the string. + token.setText(token.getText.replace("``", "`")) + token + } + } + + /** Treat non-reserved keywords as Identifiers. */ + override def exitNonReserved(ctx: SqlBaseParser.NonReservedContext): Unit = { + replaceTokenByIdentifier(ctx, 0)(identity) + } + + private def replaceTokenByIdentifier( + ctx: ParserRuleContext, + stripMargins: Int)( + f: CommonToken => CommonToken = identity): Unit = { + val parent = ctx.getParent + parent.removeLastChild() + val token = ctx.getChild(0).getPayload.asInstanceOf[Token] + parent.addChild(f(new CommonToken( + new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), + SqlBaseParser.IDENTIFIER, + token.getChannel, + token.getStartIndex + stripMargins, + token.getStopIndex - stripMargins))) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala deleted file mode 100644 index ce449b1143..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserConf.scala +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser - -trait ParserConf { - def supportQuotedId: Boolean - def supportSQL11ReservedKeywords: Boolean -} - -case class SimpleParserConf( - supportQuotedId: Boolean = true, - supportSQL11ReservedKeywords: Boolean = false) extends ParserConf diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 0c2e481954..90b76dc314 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -14,166 +14,105 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.sql.catalyst.parser -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.trees.CurrentOrigin -import org.apache.spark.sql.types._ +import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} +import org.antlr.v4.runtime.misc.Interval +import org.antlr.v4.runtime.tree.TerminalNode +import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} /** - * A collection of utility methods and patterns for parsing query texts. + * A collection of utility methods for use during the parsing process. */ -// TODO: merge with ParseUtils object ParserUtils { - - object Token { - // Match on (text, children) - def unapply(node: ASTNode): Some[(String, List[ASTNode])] = { - CurrentOrigin.setPosition(node.line, node.positionInLine) - node.pattern - } + /** Get the command which created the token. */ + def command(ctx: ParserRuleContext): String = { + command(ctx.getStart.getInputStream) } - private val escapedIdentifier = "`(.+)`".r - private val doubleQuotedString = "\"([^\"]+)\"".r - private val singleQuotedString = "'([^']+)'".r - - // Token patterns - val COUNT = "(?i)COUNT".r - val SUM = "(?i)SUM".r - val AND = "(?i)AND".r - val OR = "(?i)OR".r - val NOT = "(?i)NOT".r - val TRUE = "(?i)TRUE".r - val FALSE = "(?i)FALSE".r - val LIKE = "(?i)LIKE".r - val RLIKE = "(?i)RLIKE".r - val REGEXP = "(?i)REGEXP".r - val IN = "(?i)IN".r - val DIV = "(?i)DIV".r - val BETWEEN = "(?i)BETWEEN".r - val WHEN = "(?i)WHEN".r - val CASE = "(?i)CASE".r - val INTEGRAL = "[+-]?\\d+".r - val DECIMAL = "[+-]?((\\d+(\\.\\d*)?)|(\\.\\d+))".r - - /** - * Strip quotes, if any, from the string. - */ - def unquoteString(str: String): String = { - str match { - case singleQuotedString(s) => s - case doubleQuotedString(s) => s - case other => other - } + /** Get the command which created the token. 
*/ + def command(stream: CharStream): String = { + stream.getText(Interval.of(0, stream.size())) } - /** - * Strip backticks, if any, from the string. - */ - def cleanIdentifier(ident: String): String = { - ident match { - case escapedIdentifier(i) => i - case plainIdent => plainIdent - } + /** Get the code that creates the given node. */ + def source(ctx: ParserRuleContext): String = { + val stream = ctx.getStart.getInputStream + stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) } - def getClauses( - clauseNames: Seq[String], - nodeList: Seq[ASTNode]): Seq[Option[ASTNode]] = { - var remainingNodes = nodeList - val clauses = clauseNames.map { clauseName => - val (matches, nonMatches) = remainingNodes.partition(_.text.toUpperCase == clauseName) - remainingNodes = nonMatches ++ (if (matches.nonEmpty) matches.tail else Nil) - matches.headOption - } + /** Get all the text which comes after the given rule. */ + def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop) - if (remainingNodes.nonEmpty) { - sys.error( - s"""Unhandled clauses: ${remainingNodes.map(_.treeString).mkString("\n")}. - |You are likely trying to use an unsupported Hive feature."""".stripMargin) - } - clauses + /** Get all the text which comes after the given token. */ + def remainder(token: Token): String = { + val stream = token.getInputStream + val interval = Interval.of(token.getStopIndex + 1, stream.size()) + stream.getText(interval) } - def getClause(clauseName: String, nodeList: Seq[ASTNode]): ASTNode = { - getClauseOption(clauseName, nodeList).getOrElse(sys.error( - s"Expected clause $clauseName missing from ${nodeList.map(_.treeString).mkString("\n")}")) - } + /** Convert a string token into a string. */ + def string(token: Token): String = unescapeSQLString(token.getText) - def getClauseOption(clauseName: String, nodeList: Seq[ASTNode]): Option[ASTNode] = { - nodeList.filter { case ast: ASTNode => ast.text == clauseName } match { - case Seq(oneMatch) => Some(oneMatch) - case Seq() => None - case _ => sys.error(s"Found multiple instances of clause $clauseName") - } - } + /** Convert a string node into a string. */ + def string(node: TerminalNode): String = unescapeSQLString(node.getText) - def extractTableIdent(tableNameParts: ASTNode): TableIdentifier = { - tableNameParts.children.map { - case Token(part, Nil) => cleanIdentifier(part) - } match { - case Seq(tableOnly) => TableIdentifier(tableOnly) - case Seq(databaseName, table) => TableIdentifier(table, Some(databaseName)) - case other => sys.error("Hive only supports tables names like 'tableName' " + - s"or 'databaseName.tableName', found '$other'") - } + /** Get the origin (line and position) of the token. 
*/ + def position(token: Token): Origin = { + Origin(Option(token.getLine), Option(token.getCharPositionInLine)) } - def nodeToDataType(node: ASTNode): DataType = node match { - case Token("TOK_DECIMAL", precision :: scale :: Nil) => - DecimalType(precision.text.toInt, scale.text.toInt) - case Token("TOK_DECIMAL", precision :: Nil) => - DecimalType(precision.text.toInt, 0) - case Token("TOK_DECIMAL", Nil) => DecimalType.USER_DEFAULT - case Token("TOK_BIGINT", Nil) => LongType - case Token("TOK_INT", Nil) => IntegerType - case Token("TOK_TINYINT", Nil) => ByteType - case Token("TOK_SMALLINT", Nil) => ShortType - case Token("TOK_BOOLEAN", Nil) => BooleanType - case Token("TOK_STRING", Nil) => StringType - case Token("TOK_VARCHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_CHAR", Token(_, Nil) :: Nil) => StringType - case Token("TOK_FLOAT", Nil) => FloatType - case Token("TOK_DOUBLE", Nil) => DoubleType - case Token("TOK_DATE", Nil) => DateType - case Token("TOK_TIMESTAMP", Nil) => TimestampType - case Token("TOK_BINARY", Nil) => BinaryType - case Token("TOK_LIST", elementType :: Nil) => ArrayType(nodeToDataType(elementType)) - case Token("TOK_STRUCT", Token("TOK_TABCOLLIST", fields) :: Nil) => - StructType(fields.map(nodeToStructField)) - case Token("TOK_MAP", keyType :: valueType :: Nil) => - MapType(nodeToDataType(keyType), nodeToDataType(valueType)) - case _ => - noParseRule("DataType", node) - } - - def nodeToStructField(node: ASTNode): StructField = node match { - case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: Nil) => - StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true) - case Token("TOK_TABCOL", Token(fieldName, Nil) :: dataType :: comment :: Nil) => - val meta = new MetadataBuilder().putString("comment", unquoteString(comment.text)).build() - StructField(cleanIdentifier(fieldName), nodeToDataType(dataType), nullable = true, meta) - case _ => - noParseRule("StructField", node) + /** Assert if a condition holds. If it doesn't throw a parse exception. */ + def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { + if (!f) { + throw new ParseException(message, ctx) + } } /** - * Throw an exception because we cannot parse the given node for some unexpected reason. + * Register the origin of the context. Any TreeNode created in the closure will be assigned the + * registered origin. This method restores the previously set origin after completion of the + * closure. */ - def parseFailed(msg: String, node: ASTNode): Nothing = { - throw new AnalysisException(s"$msg: '${node.source}") + def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { + val current = CurrentOrigin.get + CurrentOrigin.set(position(ctx.getStart)) + try { + f + } finally { + CurrentOrigin.set(current) + } } - /** - * Throw an exception because there are no rules to parse the node. - */ - def noParseRule(msg: String, node: ASTNode): Nothing = { - throw new NotImplementedError( - s"[$msg]: No parse rules for ASTNode type: ${node.tokenType}, tree:\n${node.treeString}") - } + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ + implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { + /** + * Create a plan using the block of code when the given context exists. Otherwise return the + * original plan. 
+ */ + def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f + } else { + plan + } + } + /** + * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the + * passed function. The original plan is returned when the context does not exist. + */ + def optionalMap[C <: ParserRuleContext]( + ctx: C)( + f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { + if (ctx != null) { + f(ctx, plan) + } else { + plan + } + } + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala deleted file mode 100644 index 5a64c414fb..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/AstBuilder.scala +++ /dev/null @@ -1,1452 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import java.sql.{Date, Timestamp} - -import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuffer - -import org.antlr.v4.runtime.{ParserRuleContext, Token} -import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval -import org.apache.spark.util.random.RandomSampler - -/** - * The AstBuilder converts an ANTLR4 ParseTree into a catalyst Expression, LogicalPlan or - * TableIdentifier. 
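The `optional`/`optionalMap` sugar added to `ParserUtils` above lets a visitor apply a clause handler only when the corresponding (nullable) ANTLR context is present. A Spark-free restatement of the same pattern, generalized to any value type (`OptionalSyntax` and `Optionally` are illustrative names):

```scala
object OptionalSyntax {
  implicit class Optionally[A](value: A) {
    // Run the block only when the (possibly null) ANTLR context is present.
    def optional(ctx: AnyRef)(f: => A): A =
      if (ctx != null) f else value
    // Map the value with the context only when the context is present.
    def optionalMap[C <: AnyRef](ctx: C)(f: (C, A) => A): A =
      if (ctx != null) f(ctx, value) else value
  }
}

// With this in scope a plan builder can chain clauses without explicit null checks, e.g.
//   plan.optionalMap(ctx.queryOrganization)(withQueryResultClauses)
//       .optionalMap(ctx.insertInto())(withInsertInto)
```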
- */ -class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { - import ParserUtils._ - - protected def typedVisit[T](ctx: ParseTree): T = { - ctx.accept(this).asInstanceOf[T] - } - - override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { - visit(ctx.statement).asInstanceOf[LogicalPlan] - } - - override def visitSingleExpression(ctx: SingleExpressionContext): Expression = withOrigin(ctx) { - visitNamedExpression(ctx.namedExpression) - } - - override def visitSingleTableIdentifier( - ctx: SingleTableIdentifierContext): TableIdentifier = withOrigin(ctx) { - visitTableIdentifier(ctx.tableIdentifier) - } - - override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { - visit(ctx.dataType).asInstanceOf[DataType] - } - - /* ******************************************************************************************** - * Plan parsing - * ******************************************************************************************** */ - protected def plan(tree: ParserRuleContext): LogicalPlan = typedVisit(tree) - - /** - * Make sure we do not try to create a plan for a native command. - */ - override def visitExecuteNativeCommand(ctx: ExecuteNativeCommandContext): LogicalPlan = null - - /** - * Create a plan for a SHOW FUNCTIONS command. - */ - override def visitShowFunctions(ctx: ShowFunctionsContext): LogicalPlan = withOrigin(ctx) { - import ctx._ - if (qualifiedName != null) { - val names = qualifiedName().identifier().asScala.map(_.getText).toList - names match { - case db :: name :: Nil => - ShowFunctions(Some(db), Some(name)) - case name :: Nil => - ShowFunctions(None, Some(name)) - case _ => - throw new ParseException("SHOW FUNCTIONS unsupported name", ctx) - } - } else if (pattern != null) { - ShowFunctions(None, Some(string(pattern))) - } else { - ShowFunctions(None, None) - } - } - - /** - * Create a plan for a DESCRIBE FUNCTION command. - */ - override def visitDescribeFunction(ctx: DescribeFunctionContext): LogicalPlan = withOrigin(ctx) { - val functionName = ctx.qualifiedName().identifier().asScala.map(_.getText).mkString(".") - DescribeFunction(functionName, ctx.EXTENDED != null) - } - - /** - * Create a top-level plan with Common Table Expressions. - */ - override def visitQuery(ctx: QueryContext): LogicalPlan = withOrigin(ctx) { - val query = plan(ctx.queryNoWith) - - // Apply CTEs - query.optional(ctx.ctes) { - val ctes = ctx.ctes.namedQuery.asScala.map { - case nCtx => - val namedQuery = visitNamedQuery(nCtx) - (namedQuery.alias, namedQuery) - } - - // Check for duplicate names. - ctes.groupBy(_._1).filter(_._2.size > 1).foreach { - case (name, _) => - throw new ParseException( - s"Name '$name' is used for multiple common table expressions", ctx) - } - - With(query, ctes.toMap) - } - } - - /** - * Create a named logical plan. - * - * This is only used for Common Table Expressions. - */ - override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { - SubqueryAlias(ctx.name.getText, plan(ctx.queryNoWith)) - } - - /** - * Create a logical plan which allows for multiple inserts using one 'from' statement. These - * queries have the following SQL form: - * {{{ - * [WITH cte...]? - * FROM src - * [INSERT INTO tbl1 SELECT *]+ - * }}} - * For example: - * {{{ - * FROM db.tbl1 A - * INSERT INTO dbo.tbl1 SELECT * WHERE A.value = 10 LIMIT 5 - * INSERT INTO dbo.tbl2 SELECT * WHERE A.value = 12 - * }}} - * This (Hive) feature cannot be combined with set-operators. 
- */ - override def visitMultiInsertQuery(ctx: MultiInsertQueryContext): LogicalPlan = withOrigin(ctx) { - val from = visitFromClause(ctx.fromClause) - - // Build the insert clauses. - val inserts = ctx.multiInsertQueryBody.asScala.map { - body => - assert(body.querySpecification.fromClause == null, - "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", - body) - - withQuerySpecification(body.querySpecification, from). - // Add organization statements. - optionalMap(body.queryOrganization)(withQueryResultClauses). - // Add insert. - optionalMap(body.insertInto())(withInsertInto) - } - - // If there are multiple INSERTS just UNION them together into one query. - inserts match { - case Seq(query) => query - case queries => Union(queries) - } - } - - /** - * Create a logical plan for a regular (single-insert) query. - */ - override def visitSingleInsertQuery( - ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryTerm). - // Add organization statements. - optionalMap(ctx.queryOrganization)(withQueryResultClauses). - // Add insert. - optionalMap(ctx.insertInto())(withInsertInto) - } - - /** - * Add an INSERT INTO [TABLE]/INSERT OVERWRITE TABLE operation to the logical plan. - */ - private def withInsertInto( - ctx: InsertIntoContext, - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - val tableIdent = visitTableIdentifier(ctx.tableIdentifier) - val partitionKeys = Option(ctx.partitionSpec).map(visitPartitionSpec).getOrElse(Map.empty) - - InsertIntoTable( - UnresolvedRelation(tableIdent, None), - partitionKeys, - query, - ctx.OVERWRITE != null, - ctx.EXISTS != null) - } - - /** - * Create a partition specification map. - */ - override def visitPartitionSpec( - ctx: PartitionSpecContext): Map[String, Option[String]] = withOrigin(ctx) { - ctx.partitionVal.asScala.map { pVal => - val name = pVal.identifier.getText.toLowerCase - val value = Option(pVal.constant).map(visitStringConstant) - name -> value - }.toMap - } - - /** - * Create a partition specification map without optional values. - */ - protected def visitNonOptionalPartitionSpec( - ctx: PartitionSpecContext): Map[String, String] = withOrigin(ctx) { - visitPartitionSpec(ctx).mapValues(_.orNull).map(identity) - } - - /** - * Convert a constant of any type into a string. This is typically used in DDL commands, and its - * main purpose is to prevent slight differences due to back to back conversions i.e.: - * String -> Literal -> String. - */ - protected def visitStringConstant(ctx: ConstantContext): String = withOrigin(ctx) { - ctx match { - case s: StringLiteralContext => createString(s) - case o => o.getText - } - } - - /** - * Add ORDER BY/SORT BY/CLUSTER BY/DISTRIBUTE BY/LIMIT/WINDOWS clauses to the logical plan. These - * clauses determine the shape (ordering/partitioning/rows) of the query result. - */ - private def withQueryResultClauses( - ctx: QueryOrganizationContext, - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - import ctx._ - - // Handle ORDER BY, SORT BY, DISTRIBUTE BY, and CLUSTER BY clause. - val withOrder = if ( - !order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { - // ORDER BY ... - Sort(order.asScala.map(visitSortItem), global = true, query) - } else if (order.isEmpty && !sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { - // SORT BY ... 
- Sort(sort.asScala.map(visitSortItem), global = false, query) - } else if (order.isEmpty && sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { - // DISTRIBUTE BY ... - RepartitionByExpression(expressionList(distributeBy), query) - } else if (order.isEmpty && !sort.isEmpty && !distributeBy.isEmpty && clusterBy.isEmpty) { - // SORT BY ... DISTRIBUTE BY ... - Sort( - sort.asScala.map(visitSortItem), - global = false, - RepartitionByExpression(expressionList(distributeBy), query)) - } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && !clusterBy.isEmpty) { - // CLUSTER BY ... - val expressions = expressionList(clusterBy) - Sort( - expressions.map(SortOrder(_, Ascending)), - global = false, - RepartitionByExpression(expressions, query)) - } else if (order.isEmpty && sort.isEmpty && distributeBy.isEmpty && clusterBy.isEmpty) { - // [EMPTY] - query - } else { - throw new ParseException( - "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", ctx) - } - - // WINDOWS - val withWindow = withOrder.optionalMap(windows)(withWindows) - - // LIMIT - withWindow.optional(limit) { - Limit(typedVisit(limit), withWindow) - } - } - - /** - * Create a logical plan using a query specification. - */ - override def visitQuerySpecification( - ctx: QuerySpecificationContext): LogicalPlan = withOrigin(ctx) { - val from = OneRowRelation.optional(ctx.fromClause) { - visitFromClause(ctx.fromClause) - } - withQuerySpecification(ctx, from) - } - - /** - * Add a query specification to a logical plan. The query specification is the core of the logical - * plan, this is where sourcing (FROM clause), transforming (SELECT TRANSFORM/MAP/REDUCE), - * projection (SELECT), aggregation (GROUP BY ... HAVING ...) and filtering (WHERE) takes place. - * - * Note that query hints are ignored (both by the parser and the builder). - */ - private def withQuerySpecification( - ctx: QuerySpecificationContext, - relation: LogicalPlan): LogicalPlan = withOrigin(ctx) { - import ctx._ - - // WHERE - def filter(ctx: BooleanExpressionContext, plan: LogicalPlan): LogicalPlan = { - Filter(expression(ctx), plan) - } - - // Expressions. - val expressions = Option(namedExpressionSeq).toSeq - .flatMap(_.namedExpression.asScala) - .map(typedVisit[Expression]) - - // Create either a transform or a regular query. - val specType = Option(kind).map(_.getType).getOrElse(SqlBaseParser.SELECT) - specType match { - case SqlBaseParser.MAP | SqlBaseParser.REDUCE | SqlBaseParser.TRANSFORM => - // Transform - - // Add where. - val withFilter = relation.optionalMap(where)(filter) - - // Create the attributes. - val (attributes, schemaLess) = if (colTypeList != null) { - // Typed return columns. - (createStructType(colTypeList).toAttributes, false) - } else if (identifierSeq != null) { - // Untyped return columns. - val attrs = visitIdentifierSeq(identifierSeq).map { name => - AttributeReference(name, StringType, nullable = true)() - } - (attrs, false) - } else { - (Seq(AttributeReference("key", StringType)(), - AttributeReference("value", StringType)()), true) - } - - // Create the transform. - ScriptTransformation( - expressions, - string(script), - attributes, - withFilter, - withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) - - case SqlBaseParser.SELECT => - // Regular select - - // Add lateral views. - val withLateralView = ctx.lateralView.asScala.foldLeft(relation)(withGenerate) - - // Add where. 
- val withFilter = withLateralView.optionalMap(where)(filter) - - // Add aggregation or a project. - val namedExpressions = expressions.map { - case e: NamedExpression => e - case e: Expression => UnresolvedAlias(e) - } - val withProject = if (aggregation != null) { - withAggregation(aggregation, namedExpressions, withFilter) - } else if (namedExpressions.nonEmpty) { - Project(namedExpressions, withFilter) - } else { - withFilter - } - - // Having - val withHaving = withProject.optional(having) { - // Note that we added a cast to boolean. If the expression itself is already boolean, - // the optimizer will get rid of the unnecessary cast. - Filter(Cast(expression(having), BooleanType), withProject) - } - - // Distinct - val withDistinct = if (setQuantifier() != null && setQuantifier().DISTINCT() != null) { - Distinct(withHaving) - } else { - withHaving - } - - // Window - withDistinct.optionalMap(windows)(withWindows) - } - } - - /** - * Create a (Hive based) [[ScriptInputOutputSchema]]. - */ - protected def withScriptIOSchema( - inRowFormat: RowFormatContext, - recordWriter: Token, - outRowFormat: RowFormatContext, - recordReader: Token, - schemaLess: Boolean): ScriptInputOutputSchema = null - - /** - * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma - * separated) relations here, these get converted into a single plan by condition-less inner join. - */ - override def visitFromClause(ctx: FromClauseContext): LogicalPlan = withOrigin(ctx) { - val from = ctx.relation.asScala.map(plan).reduceLeft(Join(_, _, Inner, None)) - ctx.lateralView.asScala.foldLeft(from)(withGenerate) - } - - /** - * Connect two queries by a Set operator. - * - * Supported Set operators are: - * - UNION [DISTINCT] - * - UNION ALL - * - EXCEPT [DISTINCT] - * - INTERSECT [DISTINCT] - */ - override def visitSetOperation(ctx: SetOperationContext): LogicalPlan = withOrigin(ctx) { - val left = plan(ctx.left) - val right = plan(ctx.right) - val all = Option(ctx.setQuantifier()).exists(_.ALL != null) - ctx.operator.getType match { - case SqlBaseParser.UNION if all => - Union(left, right) - case SqlBaseParser.UNION => - Distinct(Union(left, right)) - case SqlBaseParser.INTERSECT if all => - throw new ParseException("INTERSECT ALL is not supported.", ctx) - case SqlBaseParser.INTERSECT => - Intersect(left, right) - case SqlBaseParser.EXCEPT if all => - throw new ParseException("EXCEPT ALL is not supported.", ctx) - case SqlBaseParser.EXCEPT => - Except(left, right) - } - } - - /** - * Add a [[WithWindowDefinition]] operator to a logical plan. - */ - private def withWindows( - ctx: WindowsContext, - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - // Collect all window specifications defined in the WINDOW clause. - val baseWindowMap = ctx.namedWindow.asScala.map { - wCtx => - (wCtx.identifier.getText, typedVisit[WindowSpec](wCtx.windowSpec)) - }.toMap - - // Handle cases like - // window w1 as (partition by p_mfgr order by p_name - // range between 2 preceding and 2 following), - // w2 as w1 - val windowMapView = baseWindowMap.mapValues { - case WindowSpecReference(name) => - baseWindowMap.get(name) match { - case Some(spec: WindowSpecDefinition) => - spec - case Some(ref) => - throw new ParseException(s"Window reference '$name' is not a window specification", ctx) - case None => - throw new ParseException(s"Cannot resolve window reference '$name'", ctx) - } - case spec: WindowSpecDefinition => spec - } - - // Note that mapValues creates a view instead of materialized map. 
We force materialization by - // mapping over identity. - WithWindowDefinition(windowMapView.map(identity), query) - } - - /** - * Add an [[Aggregate]] to a logical plan. - */ - private def withAggregation( - ctx: AggregationContext, - selectExpressions: Seq[NamedExpression], - query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - import ctx._ - val groupByExpressions = expressionList(groupingExpressions) - - if (GROUPING != null) { - // GROUP BY .... GROUPING SETS (...) - val expressionMap = groupByExpressions.zipWithIndex.toMap - val numExpressions = expressionMap.size - val mask = (1 << numExpressions) - 1 - val masks = ctx.groupingSet.asScala.map { - _.expression.asScala.foldLeft(mask) { - case (bitmap, eCtx) => - // Find the index of the expression. - val e = typedVisit[Expression](eCtx) - val index = expressionMap.find(_._1.semanticEquals(e)).map(_._2).getOrElse( - throw new ParseException( - s"$e doesn't show up in the GROUP BY list", ctx)) - // 0 means that the column at the given index is a grouping column, 1 means it is not, - // so we unset the bit in bitmap. - bitmap & ~(1 << (numExpressions - 1 - index)) - } - } - GroupingSets(masks, groupByExpressions, query, selectExpressions) - } else { - // GROUP BY .... (WITH CUBE | WITH ROLLUP)? - val mappedGroupByExpressions = if (CUBE != null) { - Seq(Cube(groupByExpressions)) - } else if (ROLLUP != null) { - Seq(Rollup(groupByExpressions)) - } else { - groupByExpressions - } - Aggregate(mappedGroupByExpressions, selectExpressions, query) - } - } - - /** - * Add a [[Generate]] (Lateral View) to a logical plan. - */ - private def withGenerate( - query: LogicalPlan, - ctx: LateralViewContext): LogicalPlan = withOrigin(ctx) { - val expressions = expressionList(ctx.expression) - - // Create the generator. - val generator = ctx.qualifiedName.getText.toLowerCase match { - case "explode" if expressions.size == 1 => - Explode(expressions.head) - case "json_tuple" => - JsonTuple(expressions) - case other => - withGenerator(other, expressions, ctx) - } - - Generate( - generator, - join = true, - outer = ctx.OUTER != null, - Some(ctx.tblName.getText.toLowerCase), - ctx.colName.asScala.map(_.getText).map(UnresolvedAttribute.apply), - query) - } - - /** - * Create a [[Generator]]. Override this method in order to support custom Generators. - */ - protected def withGenerator( - name: String, - expressions: Seq[Expression], - ctx: LateralViewContext): Generator = { - throw new ParseException(s"Generator function '$name' is not supported", ctx) - } - - /** - * Create a joins between two or more logical plans. - */ - override def visitJoinRelation(ctx: JoinRelationContext): LogicalPlan = withOrigin(ctx) { - /** Build a join between two plans. 
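The bitmap arithmetic in `withAggregation` above encodes each GROUPING SETS entry as an integer: start from a mask with one bit per GROUP BY column set, then clear the bit of every column that appears in the grouping set (0 means the column is a grouping column, 1 means it is not). A standalone sketch of that computation, assuming every grouping-set column also appears in the GROUP BY list (`groupingSetMasks` is an illustrative name):

```scala
def groupingSetMasks(groupBy: Seq[String], sets: Seq[Seq[String]]): Seq[Int] = {
  val index = groupBy.zipWithIndex.toMap
  val n = groupBy.length
  val allOnes = (1 << n) - 1
  sets.map { set =>
    set.foldLeft(allOnes) { (bitmap, col) =>
      // Unset the bit of a column that participates in this grouping set.
      bitmap & ~(1 << (n - 1 - index(col)))
    }
  }
}

// groupingSetMasks(Seq("a", "b"), Seq(Seq("a", "b"), Seq("a"), Seq())) == Seq(0, 1, 3)
```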
*/ - def join(ctx: JoinRelationContext, left: LogicalPlan, right: LogicalPlan): Join = { - val baseJoinType = ctx.joinType match { - case null => Inner - case jt if jt.FULL != null => FullOuter - case jt if jt.SEMI != null => LeftSemi - case jt if jt.LEFT != null => LeftOuter - case jt if jt.RIGHT != null => RightOuter - case _ => Inner - } - - // Resolve the join type and join condition - val (joinType, condition) = Option(ctx.joinCriteria) match { - case Some(c) if c.USING != null => - val columns = c.identifier.asScala.map { column => - UnresolvedAttribute.quoted(column.getText) - } - (UsingJoin(baseJoinType, columns), None) - case Some(c) if c.booleanExpression != null => - (baseJoinType, Option(expression(c.booleanExpression))) - case None if ctx.NATURAL != null => - (NaturalJoin(baseJoinType), None) - case None => - (baseJoinType, None) - } - Join(left, right, joinType, condition) - } - - // Handle all consecutive join clauses. ANTLR produces a right nested tree in which the the - // first join clause is at the top. However fields of previously referenced tables can be used - // in following join clauses. The tree needs to be reversed in order to make this work. - var result = plan(ctx.left) - var current = ctx - while (current != null) { - current.right match { - case right: JoinRelationContext => - result = join(current, result, plan(right.left)) - current = right - case right => - result = join(current, result, plan(right)) - current = null - } - } - result - } - - /** - * Add a [[Sample]] to a logical plan. - * - * This currently supports the following sampling methods: - * - TABLESAMPLE(x ROWS): Sample the table down to the given number of rows. - * - TABLESAMPLE(x PERCENT): Sample the table down to the given percentage. Note that percentages - * are defined as a number between 0 and 100. - * - TABLESAMPLE(BUCKET x OUT OF y): Sample the table down to a 'x' divided by 'y' fraction. - */ - private def withSample(ctx: SampleContext, query: LogicalPlan): LogicalPlan = withOrigin(ctx) { - // Create a sampled plan if we need one. - def sample(fraction: Double): Sample = { - // The range of fraction accepted by Sample is [0, 1]. Because Hive's block sampling - // function takes X PERCENT as the input and the range of X is [0, 100], we need to - // adjust the fraction. - val eps = RandomSampler.roundingEpsilon - assert(fraction >= 0.0 - eps && fraction <= 1.0 + eps, - s"Sampling fraction ($fraction) must be on interval [0, 1]", - ctx) - Sample(0.0, fraction, withReplacement = false, (math.random * 1000).toInt, query)(true) - } - - ctx.sampleType.getType match { - case SqlBaseParser.ROWS => - Limit(expression(ctx.expression), query) - - case SqlBaseParser.PERCENTLIT => - val fraction = ctx.percentage.getText.toDouble - sample(fraction / 100.0d) - - case SqlBaseParser.BUCKET if ctx.ON != null => - throw new ParseException("TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported", ctx) - - case SqlBaseParser.BUCKET => - sample(ctx.numerator.getText.toDouble / ctx.denominator.getText.toDouble) - } - } - - /** - * Create a logical plan for a sub-query. - */ - override def visitSubquery(ctx: SubqueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryNoWith) - } - - /** - * Create an un-aliased table reference. 
This is typically used for top-level table references, - * for example: - * {{{ - * INSERT INTO db.tbl2 - * TABLE db.tbl1 - * }}} - */ - override def visitTable(ctx: TableContext): LogicalPlan = withOrigin(ctx) { - UnresolvedRelation(visitTableIdentifier(ctx.tableIdentifier), None) - } - - /** - * Create an aliased table reference. This is typically used in FROM clauses. - */ - override def visitTableName(ctx: TableNameContext): LogicalPlan = withOrigin(ctx) { - val table = UnresolvedRelation( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.identifier).map(_.getText)) - table.optionalMap(ctx.sample)(withSample) - } - - /** - * Create an inline table (a virtual table in Hive parlance). - */ - override def visitInlineTable(ctx: InlineTableContext): LogicalPlan = withOrigin(ctx) { - // Get the backing expressions. - val expressions = ctx.expression.asScala.map { eCtx => - val e = expression(eCtx) - assert(e.foldable, "All expressions in an inline table must be constants.", eCtx) - e - } - - // Validate and evaluate the rows. - val (structType, structConstructor) = expressions.head.dataType match { - case st: StructType => - (st, (e: Expression) => e) - case dt => - val st = CreateStruct(Seq(expressions.head)).dataType - (st, (e: Expression) => CreateStruct(Seq(e))) - } - val rows = expressions.map { - case expression => - val safe = Cast(structConstructor(expression), structType) - safe.eval().asInstanceOf[InternalRow] - } - - // Construct attributes. - val baseAttributes = structType.toAttributes.map(_.withNullability(true)) - val attributes = if (ctx.identifierList != null) { - val aliases = visitIdentifierList(ctx.identifierList) - assert(aliases.size == baseAttributes.size, - "Number of aliases must match the number of fields in an inline table.", ctx) - baseAttributes.zip(aliases).map(p => p._1.withName(p._2)) - } else { - baseAttributes - } - - // Create plan and add an alias if a name has been defined. - LocalRelation(attributes, rows).optionalMap(ctx.identifier)(aliasPlan) - } - - /** - * Create an alias (SubqueryAlias) for a join relation. This is practically the same as - * visitAliasedQuery and visitNamedExpression, ANTLR4 however requires us to use 3 different - * hooks. - */ - override def visitAliasedRelation(ctx: AliasedRelationContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.relation).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) - } - - /** - * Create an alias (SubqueryAlias) for a sub-query. This is practically the same as - * visitAliasedRelation and visitNamedExpression, ANTLR4 however requires us to use 3 different - * hooks. - */ - override def visitAliasedQuery(ctx: AliasedQueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryNoWith).optionalMap(ctx.sample)(withSample).optionalMap(ctx.identifier)(aliasPlan) - } - - /** - * Create an alias (SubqueryAlias) for a LogicalPlan. - */ - private def aliasPlan(alias: IdentifierContext, plan: LogicalPlan): LogicalPlan = { - SubqueryAlias(alias.getText, plan) - } - - /** - * Create a Sequence of Strings for a parenthesis enclosed alias list. - */ - override def visitIdentifierList(ctx: IdentifierListContext): Seq[String] = withOrigin(ctx) { - visitIdentifierSeq(ctx.identifierSeq) - } - - /** - * Create a Sequence of Strings for an identifier list. 
- */ - override def visitIdentifierSeq(ctx: IdentifierSeqContext): Seq[String] = withOrigin(ctx) { - ctx.identifier.asScala.map(_.getText) - } - - /* ******************************************************************************************** - * Table Identifier parsing - * ******************************************************************************************** */ - /** - * Create a [[TableIdentifier]] from a 'tableName' or 'databaseName'.'tableName' pattern. - */ - override def visitTableIdentifier( - ctx: TableIdentifierContext): TableIdentifier = withOrigin(ctx) { - TableIdentifier(ctx.table.getText, Option(ctx.db).map(_.getText)) - } - - /* ******************************************************************************************** - * Expression parsing - * ******************************************************************************************** */ - /** - * Create an expression from the given context. This method just passes the context on to the - * vistor and only takes care of typing (We assume that the visitor returns an Expression here). - */ - protected def expression(ctx: ParserRuleContext): Expression = typedVisit(ctx) - - /** - * Create sequence of expressions from the given sequence of contexts. - */ - private def expressionList(trees: java.util.List[ExpressionContext]): Seq[Expression] = { - trees.asScala.map(expression) - } - - /** - * Invert a boolean expression if it has a valid NOT clause. - */ - private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { - if (not != null) { - Not(expression) - } else { - expression - } - } - - /** - * Create a star (i.e. all) expression; this selects all elements (in the specified object). - * Both un-targeted (global) and targeted aliases are supported. - */ - override def visitStar(ctx: StarContext): Expression = withOrigin(ctx) { - UnresolvedStar(Option(ctx.qualifiedName()).map(_.identifier.asScala.map(_.getText))) - } - - /** - * Create an aliased expression if an alias is specified. Both single and multi-aliases are - * supported. - */ - override def visitNamedExpression(ctx: NamedExpressionContext): Expression = withOrigin(ctx) { - val e = expression(ctx.expression) - if (ctx.identifier != null) { - Alias(e, ctx.identifier.getText)() - } else if (ctx.identifierList != null) { - MultiAlias(e, visitIdentifierList(ctx.identifierList)) - } else { - e - } - } - - /** - * Combine a number of boolean expressions into a balanced expression tree. These expressions are - * either combined by a logical [[And]] or a logical [[Or]]. - * - * A balanced binary tree is created because regular left recursive trees cause considerable - * performance degradations and can cause stack overflows. - */ - override def visitLogicalBinary(ctx: LogicalBinaryContext): Expression = withOrigin(ctx) { - val expressionType = ctx.operator.getType - val expressionCombiner = expressionType match { - case SqlBaseParser.AND => And.apply _ - case SqlBaseParser.OR => Or.apply _ - } - - // Collect all similar left hand contexts. - val contexts = ArrayBuffer(ctx.right) - var current = ctx.left - def collectContexts: Boolean = current match { - case lbc: LogicalBinaryContext if lbc.operator.getType == expressionType => - contexts += lbc.right - current = lbc.left - true - case _ => - contexts += current - false - } - while (collectContexts) { - // No body - all updates take place in the collectContexts. - } - - // Reverse the contexts to have them in the same sequence as in the SQL statement & turn them - // into expressions. 
- val expressions = contexts.reverse.map(expression) - - // Create a balanced tree. - def reduceToExpressionTree(low: Int, high: Int): Expression = high - low match { - case 0 => - expressions(low) - case 1 => - expressionCombiner(expressions(low), expressions(high)) - case x => - val mid = low + x / 2 - expressionCombiner( - reduceToExpressionTree(low, mid), - reduceToExpressionTree(mid + 1, high)) - } - reduceToExpressionTree(0, expressions.size - 1) - } - - /** - * Invert a boolean expression. - */ - override def visitLogicalNot(ctx: LogicalNotContext): Expression = withOrigin(ctx) { - Not(expression(ctx.booleanExpression())) - } - - /** - * Create a filtering correlated sub-query. This is not supported yet. - */ - override def visitExists(ctx: ExistsContext): Expression = { - throw new ParseException("EXISTS clauses are not supported.", ctx) - } - - /** - * Create a comparison expression. This compares two expressions. The following comparison - * operators are supported: - * - Equal: '=' or '==' - * - Null-safe Equal: '<=>' - * - Not Equal: '<>' or '!=' - * - Less than: '<' - * - Less then or Equal: '<=' - * - Greater than: '>' - * - Greater then or Equal: '>=' - */ - override def visitComparison(ctx: ComparisonContext): Expression = withOrigin(ctx) { - val left = expression(ctx.left) - val right = expression(ctx.right) - val operator = ctx.comparisonOperator().getChild(0).asInstanceOf[TerminalNode] - operator.getSymbol.getType match { - case SqlBaseParser.EQ => - EqualTo(left, right) - case SqlBaseParser.NSEQ => - EqualNullSafe(left, right) - case SqlBaseParser.NEQ | SqlBaseParser.NEQJ => - Not(EqualTo(left, right)) - case SqlBaseParser.LT => - LessThan(left, right) - case SqlBaseParser.LTE => - LessThanOrEqual(left, right) - case SqlBaseParser.GT => - GreaterThan(left, right) - case SqlBaseParser.GTE => - GreaterThanOrEqual(left, right) - } - } - - /** - * Create a BETWEEN expression. This tests if an expression lies with in the bounds set by two - * other expressions. The inverse can also be created. - */ - override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - val between = And( - GreaterThanOrEqual(value, expression(ctx.lower)), - LessThanOrEqual(value, expression(ctx.upper))) - invertIfNotDefined(between, ctx.NOT) - } - - /** - * Create an IN expression. This tests if the value of the left hand side expression is - * contained by the sequence of expressions on the right hand side. - */ - override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) { - val in = In(expression(ctx.value), ctx.expression().asScala.map(expression)) - invertIfNotDefined(in, ctx.NOT) - } - - /** - * Create an IN expression, where the the right hand side is a query. This is unsupported. - */ - override def visitInSubquery(ctx: InSubqueryContext): Expression = { - throw new ParseException("IN with a Sub-query is currently not supported.", ctx) - } - - /** - * Create a (R)LIKE/REGEXP expression. - */ - override def visitLike(ctx: LikeContext): Expression = { - val left = expression(ctx.value) - val right = expression(ctx.pattern) - val like = ctx.like.getType match { - case SqlBaseParser.LIKE => - Like(left, right) - case SqlBaseParser.RLIKE => - RLike(left, right) - } - invertIfNotDefined(like, ctx.NOT) - } - - /** - * Create an IS (NOT) NULL expression. 
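The balanced AND/OR reduction above (`reduceToExpressionTree`) exists because deep left-recursive trees degrade performance and can overflow the stack; here is a self-contained sketch of the same divide-and-conquer reduction over plain strings.

````scala
// Standalone sketch of the balanced reduction used by visitLogicalBinary above:
// combine a flat operand list with logarithmic, not linear, tree depth.
object BalancedReduce {
  def reduce[A](xs: IndexedSeq[A])(combine: (A, A) => A): A = {
    def go(low: Int, high: Int): A = high - low match {
      case 0 => xs(low)
      case 1 => combine(xs(low), xs(high))
      case n =>
        val mid = low + n / 2
        combine(go(low, mid), go(mid + 1, high))
    }
    go(0, xs.size - 1)
  }

  def main(args: Array[String]): Unit = {
    // Prints (((a AND b) AND c) AND ((d AND e) AND f)) - a balanced tree of six operands.
    println(reduce(Vector("a", "b", "c", "d", "e", "f"))((l, r) => s"($l AND $r)"))
  }
}
````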
- */ - override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - if (ctx.NOT != null) { - IsNotNull(value) - } else { - IsNull(value) - } - } - - /** - * Create a binary arithmetic expression. The following arithmetic operators are supported: - * - Mulitplication: '*' - * - Division: '/' - * - Hive Long Division: 'DIV' - * - Modulo: '%' - * - Addition: '+' - * - Subtraction: '-' - * - Binary AND: '&' - * - Binary XOR - * - Binary OR: '|' - */ - override def visitArithmeticBinary(ctx: ArithmeticBinaryContext): Expression = withOrigin(ctx) { - val left = expression(ctx.left) - val right = expression(ctx.right) - ctx.operator.getType match { - case SqlBaseParser.ASTERISK => - Multiply(left, right) - case SqlBaseParser.SLASH => - Divide(left, right) - case SqlBaseParser.PERCENT => - Remainder(left, right) - case SqlBaseParser.DIV => - Cast(Divide(left, right), LongType) - case SqlBaseParser.PLUS => - Add(left, right) - case SqlBaseParser.MINUS => - Subtract(left, right) - case SqlBaseParser.AMPERSAND => - BitwiseAnd(left, right) - case SqlBaseParser.HAT => - BitwiseXor(left, right) - case SqlBaseParser.PIPE => - BitwiseOr(left, right) - } - } - - /** - * Create a unary arithmetic expression. The following arithmetic operators are supported: - * - Plus: '+' - * - Minus: '-' - * - Bitwise Not: '~' - */ - override def visitArithmeticUnary(ctx: ArithmeticUnaryContext): Expression = withOrigin(ctx) { - val value = expression(ctx.valueExpression) - ctx.operator.getType match { - case SqlBaseParser.PLUS => - value - case SqlBaseParser.MINUS => - UnaryMinus(value) - case SqlBaseParser.TILDE => - BitwiseNot(value) - } - } - - /** - * Create a [[Cast]] expression. - */ - override def visitCast(ctx: CastContext): Expression = withOrigin(ctx) { - Cast(expression(ctx.expression), typedVisit(ctx.dataType)) - } - - /** - * Create a (windowed) Function expression. - */ - override def visitFunctionCall(ctx: FunctionCallContext): Expression = withOrigin(ctx) { - // Create the function call. - val name = ctx.qualifiedName.getText - val isDistinct = Option(ctx.setQuantifier()).exists(_.DISTINCT != null) - val arguments = ctx.expression().asScala.map(expression) match { - case Seq(UnresolvedStar(None)) if name.toLowerCase == "count" && !isDistinct => - // Transform COUNT(*) into COUNT(1). Move this to analysis? - Seq(Literal(1)) - case expressions => - expressions - } - val function = UnresolvedFunction(name, arguments, isDistinct) - - // Check if the function is evaluated in a windowed context. - ctx.windowSpec match { - case spec: WindowRefContext => - UnresolvedWindowExpression(function, visitWindowRef(spec)) - case spec: WindowDefContext => - WindowExpression(function, visitWindowDef(spec)) - case _ => function - } - } - - /** - * Create a reference to a window frame, i.e. [[WindowSpecReference]]. - */ - override def visitWindowRef(ctx: WindowRefContext): WindowSpecReference = withOrigin(ctx) { - WindowSpecReference(ctx.identifier.getText) - } - - /** - * Create a window definition, i.e. [[WindowSpecDefinition]]. - */ - override def visitWindowDef(ctx: WindowDefContext): WindowSpecDefinition = withOrigin(ctx) { - // CLUSTER BY ... | PARTITION BY ... ORDER BY ... - val partition = ctx.partition.asScala.map(expression) - val order = ctx.sortItem.asScala.map(visitSortItem) - - // RANGE/ROWS BETWEEN ... 
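Assuming the relocated parser is exposed as `org.apache.spark.sql.catalyst.parser.CatalystSqlParser` (the object the new test suites in this patch reference), the arithmetic rules above can be exercised directly; a small usage sketch follows.

````scala
// Usage sketch only: exercises visitArithmeticBinary/visitArithmeticUnary through the
// public parser object used by the new test suites in this patch.
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser

object ArithmeticParsingDemo {
  def main(args: Array[String]): Unit = {
    // '%' maps to Remainder, DIV to a Divide wrapped in a cast to bigint,
    // '^' to BitwiseXor, and a unary '+' is simply dropped.
    Seq("a % b", "a DIV b", "a ^ b", "-+~~a").foreach { sql =>
      println(s"$sql  =>  ${CatalystSqlParser.parseExpression(sql)}")
    }
  }
}
````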
- val frameSpecOption = Option(ctx.windowFrame).map { frame => - val frameType = frame.frameType.getType match { - case SqlBaseParser.RANGE => RangeFrame - case SqlBaseParser.ROWS => RowFrame - } - - SpecifiedWindowFrame( - frameType, - visitFrameBound(frame.start), - Option(frame.end).map(visitFrameBound).getOrElse(CurrentRow)) - } - - WindowSpecDefinition( - partition, - order, - frameSpecOption.getOrElse(UnspecifiedFrame)) - } - - /** - * Create or resolve a [[FrameBoundary]]. Simple math expressions are allowed for Value - * Preceding/Following boundaries. These expressions must be constant (foldable) and return an - * integer value. - */ - override def visitFrameBound(ctx: FrameBoundContext): FrameBoundary = withOrigin(ctx) { - // We currently only allow foldable integers. - def value: Int = { - val e = expression(ctx.expression) - assert(e.resolved && e.foldable && e.dataType == IntegerType, - "Frame bound value must be a constant integer.", - ctx) - e.eval().asInstanceOf[Int] - } - - // Create the FrameBoundary - ctx.boundType.getType match { - case SqlBaseParser.PRECEDING if ctx.UNBOUNDED != null => - UnboundedPreceding - case SqlBaseParser.PRECEDING => - ValuePreceding(value) - case SqlBaseParser.CURRENT => - CurrentRow - case SqlBaseParser.FOLLOWING if ctx.UNBOUNDED != null => - UnboundedFollowing - case SqlBaseParser.FOLLOWING => - ValueFollowing(value) - } - } - - /** - * Create a [[CreateStruct]] expression. - */ - override def visitRowConstructor(ctx: RowConstructorContext): Expression = withOrigin(ctx) { - CreateStruct(ctx.expression.asScala.map(expression)) - } - - /** - * Create a [[ScalarSubquery]] expression. - */ - override def visitSubqueryExpression( - ctx: SubqueryExpressionContext): Expression = withOrigin(ctx) { - ScalarSubquery(plan(ctx.query)) - } - - /** - * Create a value based [[CaseWhen]] expression. This has the following SQL form: - * {{{ - * CASE [expression] - * WHEN [value] THEN [expression] - * ... - * ELSE [expression] - * END - * }}} - */ - override def visitSimpleCase(ctx: SimpleCaseContext): Expression = withOrigin(ctx) { - val e = expression(ctx.valueExpression) - val branches = ctx.whenClause.asScala.map { wCtx => - (EqualTo(e, expression(wCtx.condition)), expression(wCtx.result)) - } - CaseWhen(branches, Option(ctx.elseExpression).map(expression)) - } - - /** - * Create a condition based [[CaseWhen]] expression. This has the following SQL syntax: - * {{{ - * CASE - * WHEN [predicate] THEN [expression] - * ... - * ELSE [expression] - * END - * }}} - * - * @param ctx the parse tree - * */ - override def visitSearchedCase(ctx: SearchedCaseContext): Expression = withOrigin(ctx) { - val branches = ctx.whenClause.asScala.map { wCtx => - (expression(wCtx.condition), expression(wCtx.result)) - } - CaseWhen(branches, Option(ctx.elseExpression).map(expression)) - } - - /** - * Create a dereference expression. The return type depends on the type of the parent, this can - * either be a [[UnresolvedAttribute]] (if the parent is an [[UnresolvedAttribute]]), or an - * [[UnresolvedExtractValue]] if the parent is some expression. - */ - override def visitDereference(ctx: DereferenceContext): Expression = withOrigin(ctx) { - val attr = ctx.fieldName.getText - expression(ctx.base) match { - case UnresolvedAttribute(nameParts) => - UnresolvedAttribute(nameParts :+ attr) - case e => - UnresolvedExtractValue(e, Literal(attr)) - } - } - - /** - * Create an [[UnresolvedAttribute]] expression. 
- */ - override def visitColumnReference(ctx: ColumnReferenceContext): Expression = withOrigin(ctx) { - UnresolvedAttribute.quoted(ctx.getText) - } - - /** - * Create an [[UnresolvedExtractValue]] expression, this is used for subscript access to an array. - */ - override def visitSubscript(ctx: SubscriptContext): Expression = withOrigin(ctx) { - UnresolvedExtractValue(expression(ctx.value), expression(ctx.index)) - } - - /** - * Create an expression for an expression between parentheses. This is need because the ANTLR - * visitor cannot automatically convert the nested context into an expression. - */ - override def visitParenthesizedExpression( - ctx: ParenthesizedExpressionContext): Expression = withOrigin(ctx) { - expression(ctx.expression) - } - - /** - * Create a [[SortOrder]] expression. - */ - override def visitSortItem(ctx: SortItemContext): SortOrder = withOrigin(ctx) { - if (ctx.DESC != null) { - SortOrder(expression(ctx.expression), Descending) - } else { - SortOrder(expression(ctx.expression), Ascending) - } - } - - /** - * Create a typed Literal expression. A typed literal has the following SQL syntax: - * {{{ - * [TYPE] '[VALUE]' - * }}} - * Currently Date and Timestamp typed literals are supported. - * - * TODO what the added value of this over casting? - */ - override def visitTypeConstructor(ctx: TypeConstructorContext): Literal = withOrigin(ctx) { - val value = string(ctx.STRING) - ctx.identifier.getText.toUpperCase match { - case "DATE" => - Literal(Date.valueOf(value)) - case "TIMESTAMP" => - Literal(Timestamp.valueOf(value)) - case other => - throw new ParseException(s"Literals of type '$other' are currently not supported.", ctx) - } - } - - /** - * Create a NULL literal expression. - */ - override def visitNullLiteral(ctx: NullLiteralContext): Literal = withOrigin(ctx) { - Literal(null) - } - - /** - * Create a Boolean literal expression. - */ - override def visitBooleanLiteral(ctx: BooleanLiteralContext): Literal = withOrigin(ctx) { - if (ctx.getText.toBoolean) { - Literal.TrueLiteral - } else { - Literal.FalseLiteral - } - } - - /** - * Create an integral literal expression. The code selects the most narrow integral type - * possible, either a BigDecimal, a Long or an Integer is returned. - */ - override def visitIntegerLiteral(ctx: IntegerLiteralContext): Literal = withOrigin(ctx) { - BigDecimal(ctx.getText) match { - case v if v.isValidInt => - Literal(v.intValue()) - case v if v.isValidLong => - Literal(v.longValue()) - case v => Literal(v.underlying()) - } - } - - /** - * Create a double literal for a number denoted in scientifc notation. - */ - override def visitScientificDecimalLiteral( - ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { - Literal(ctx.getText.toDouble) - } - - /** - * Create a decimal literal for a regular decimal number. - */ - override def visitDecimalLiteral(ctx: DecimalLiteralContext): Literal = withOrigin(ctx) { - Literal(BigDecimal(ctx.getText).underlying()) - } - - /** Create a numeric literal expression. */ - private def numericLiteral(ctx: NumberContext)(f: String => Any): Literal = withOrigin(ctx) { - val raw = ctx.getText - try { - Literal(f(raw.substring(0, raw.length - 1))) - } catch { - case e: NumberFormatException => - throw new ParseException(e.getMessage, ctx) - } - } - - /** - * Create a Byte Literal expression. - */ - override def visitTinyIntLiteral(ctx: TinyIntLiteralContext): Literal = numericLiteral(ctx) { - _.toByte - } - - /** - * Create a Short Literal expression. 
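The integral-literal rule above picks the narrowest type that can represent the value; the following standalone sketch isolates just that narrowing.

````scala
// Standalone sketch of the integral-literal narrowing above: choose Int, then Long,
// then an unlimited-precision decimal, depending on what can hold the value.
object NarrowIntegral {
  def narrow(text: String): Any = BigDecimal(text) match {
    case v if v.isValidInt  => v.intValue()
    case v if v.isValidLong => v.longValue()
    case v                  => v.underlying()   // java.math.BigDecimal
  }

  def main(args: Array[String]): Unit = {
    // Prints Integer, Long and BigDecimal respectively.
    Seq("7", "3000000000", "90000000000000000000").foreach { s =>
      println(s -> narrow(s).getClass)
    }
  }
}
````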
- */ - override def visitSmallIntLiteral(ctx: SmallIntLiteralContext): Literal = numericLiteral(ctx) { - _.toShort - } - - /** - * Create a Long Literal expression. - */ - override def visitBigIntLiteral(ctx: BigIntLiteralContext): Literal = numericLiteral(ctx) { - _.toLong - } - - /** - * Create a Double Literal expression. - */ - override def visitDoubleLiteral(ctx: DoubleLiteralContext): Literal = numericLiteral(ctx) { - _.toDouble - } - - /** - * Create a String literal expression. - */ - override def visitStringLiteral(ctx: StringLiteralContext): Literal = withOrigin(ctx) { - Literal(createString(ctx)) - } - - /** - * Create a String from a string literal context. This supports multiple consecutive string - * literals, these are concatenated, for example this expression "'hello' 'world'" will be - * converted into "helloworld". - * - * Special characters can be escaped by using Hive/C-style escaping. - */ - private def createString(ctx: StringLiteralContext): String = { - ctx.STRING().asScala.map(string).mkString - } - - /** - * Create a [[CalendarInterval]] literal expression. An interval expression can contain multiple - * unit value pairs, for instance: interval 2 months 2 days. - */ - override def visitInterval(ctx: IntervalContext): Literal = withOrigin(ctx) { - val intervals = ctx.intervalField.asScala.map(visitIntervalField) - assert(intervals.nonEmpty, "at least one time unit should be given for interval literal", ctx) - Literal(intervals.reduce(_.add(_))) - } - - /** - * Create a [[CalendarInterval]] for a unit value pair. Two unit configuration types are - * supported: - * - Single unit. - * - From-To unit (only 'YEAR TO MONTH' and 'DAY TO SECOND' are supported). - */ - override def visitIntervalField(ctx: IntervalFieldContext): CalendarInterval = withOrigin(ctx) { - import ctx._ - val s = value.getText - val interval = (unit.getText.toLowerCase, Option(to).map(_.getText.toLowerCase)) match { - case (u, None) if u.endsWith("s") => - // Handle plural forms, e.g: yearS/monthS/weekS/dayS/hourS/minuteS/hourS/... - CalendarInterval.fromSingleUnitString(u.substring(0, u.length - 1), s) - case (u, None) => - CalendarInterval.fromSingleUnitString(u, s) - case ("year", Some("month")) => - CalendarInterval.fromYearMonthString(s) - case ("day", Some("second")) => - CalendarInterval.fromDayTimeString(s) - case (from, Some(t)) => - throw new ParseException(s"Intervals FROM $from TO $t are not supported.", ctx) - } - assert(interval != null, "No interval can be constructed", ctx) - interval - } - - /* ******************************************************************************************** - * DataType parsing - * ******************************************************************************************** */ - /** - * Resolve/create a primitive type. 
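`visitInterval`/`visitIntervalField` above build one `CalendarInterval` per unit-value pair and add the pieces together; a small usage sketch relying only on the `CalendarInterval` helpers already called above:

````scala
// Usage sketch of the interval handling above, using the same CalendarInterval helpers
// invoked by visitIntervalField.
import org.apache.spark.unsafe.types.CalendarInterval

object IntervalDemo {
  def main(args: Array[String]): Unit = {
    // "interval 2 months 2 days" is parsed unit by unit and the parts are added together.
    val twoMonths = CalendarInterval.fromSingleUnitString("month", "2")
    val twoDays   = CalendarInterval.fromSingleUnitString("day", "2")
    println(twoMonths.add(twoDays))

    // From-To forms are limited to YEAR TO MONTH and DAY TO SECOND.
    println(CalendarInterval.fromYearMonthString("123-10"))
    println(CalendarInterval.fromDayTimeString("10 9:8:7.123456789"))
  }
}
````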
- */ - override def visitPrimitiveDataType(ctx: PrimitiveDataTypeContext): DataType = withOrigin(ctx) { - (ctx.identifier.getText.toLowerCase, ctx.INTEGER_VALUE().asScala.toList) match { - case ("boolean", Nil) => BooleanType - case ("tinyint" | "byte", Nil) => ByteType - case ("smallint" | "short", Nil) => ShortType - case ("int" | "integer", Nil) => IntegerType - case ("bigint" | "long", Nil) => LongType - case ("float", Nil) => FloatType - case ("double", Nil) => DoubleType - case ("date", Nil) => DateType - case ("timestamp", Nil) => TimestampType - case ("char" | "varchar" | "string", Nil) => StringType - case ("char" | "varchar", _ :: Nil) => StringType - case ("binary", Nil) => BinaryType - case ("decimal", Nil) => DecimalType.USER_DEFAULT - case ("decimal", precision :: Nil) => DecimalType(precision.getText.toInt, 0) - case ("decimal", precision :: scale :: Nil) => - DecimalType(precision.getText.toInt, scale.getText.toInt) - case (dt, params) => - throw new ParseException( - s"DataType $dt${params.mkString("(", ",", ")")} is not supported.", ctx) - } - } - - /** - * Create a complex DataType. Arrays, Maps and Structures are supported. - */ - override def visitComplexDataType(ctx: ComplexDataTypeContext): DataType = withOrigin(ctx) { - ctx.complex.getType match { - case SqlBaseParser.ARRAY => - ArrayType(typedVisit(ctx.dataType(0))) - case SqlBaseParser.MAP => - MapType(typedVisit(ctx.dataType(0)), typedVisit(ctx.dataType(1))) - case SqlBaseParser.STRUCT => - createStructType(ctx.colTypeList()) - } - } - - /** - * Create a [[StructType]] from a sequence of [[StructField]]s. - */ - protected def createStructType(ctx: ColTypeListContext): StructType = { - StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) - } - - /** - * Create a [[StructType]] from a number of column definitions. - */ - override def visitColTypeList(ctx: ColTypeListContext): Seq[StructField] = withOrigin(ctx) { - ctx.colType().asScala.map(visitColType) - } - - /** - * Create a [[StructField]] from a column definition. - */ - override def visitColType(ctx: ColTypeContext): StructField = withOrigin(ctx) { - import ctx._ - - // Add the comment to the metadata. - val builder = new MetadataBuilder - if (STRING != null) { - builder.putString("comment", string(STRING)) - } - - StructField(identifier.getText, typedVisit(dataType), nullable = true, builder.build()) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala deleted file mode 100644 index c9a286374c..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParseDriver.scala +++ /dev/null @@ -1,240 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.antlr.v4.runtime._ -import org.antlr.v4.runtime.atn.PredictionMode -import org.antlr.v4.runtime.misc.ParseCancellationException - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.parser.ParserInterface -import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees.Origin -import org.apache.spark.sql.types.DataType - -/** - * Base SQL parsing infrastructure. - */ -abstract class AbstractSqlParser extends ParserInterface with Logging { - - /** Creates/Resolves DataType for a given SQL string. */ - def parseDataType(sqlText: String): DataType = parse(sqlText) { parser => - // TODO add this to the parser interface. - astBuilder.visitSingleDataType(parser.singleDataType()) - } - - /** Creates Expression for a given SQL string. */ - override def parseExpression(sqlText: String): Expression = parse(sqlText) { parser => - astBuilder.visitSingleExpression(parser.singleExpression()) - } - - /** Creates TableIdentifier for a given SQL string. */ - override def parseTableIdentifier(sqlText: String): TableIdentifier = parse(sqlText) { parser => - astBuilder.visitSingleTableIdentifier(parser.singleTableIdentifier()) - } - - /** Creates LogicalPlan for a given SQL string. */ - override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { parser => - astBuilder.visitSingleStatement(parser.singleStatement()) match { - case plan: LogicalPlan => plan - case _ => nativeCommand(sqlText) - } - } - - /** Get the builder (visitor) which converts a ParseTree into a AST. */ - protected def astBuilder: AstBuilder - - /** Create a native command, or fail when this is not supported. */ - protected def nativeCommand(sqlText: String): LogicalPlan = { - val position = Origin(None, None) - throw new ParseException(Option(sqlText), "Unsupported SQL statement", position, position) - } - - protected def parse[T](command: String)(toResult: SqlBaseParser => T): T = { - logInfo(s"Parsing command: $command") - - val lexer = new SqlBaseLexer(new ANTLRNoCaseStringStream(command)) - lexer.removeErrorListeners() - lexer.addErrorListener(ParseErrorListener) - - val tokenStream = new CommonTokenStream(lexer) - val parser = new SqlBaseParser(tokenStream) - parser.addParseListener(PostProcessor) - parser.removeErrorListeners() - parser.addErrorListener(ParseErrorListener) - - try { - try { - // first, try parsing with potentially faster SLL mode - parser.getInterpreter.setPredictionMode(PredictionMode.SLL) - toResult(parser) - } - catch { - case e: ParseCancellationException => - // if we fail, parse with LL mode - tokenStream.reset() // rewind input stream - parser.reset() - - // Try Again. - parser.getInterpreter.setPredictionMode(PredictionMode.LL) - toResult(parser) - } - } - catch { - case e: ParseException if e.command.isDefined => - throw e - case e: ParseException => - throw e.withCommand(command) - case e: AnalysisException => - val position = Origin(e.line, e.startPosition) - throw new ParseException(Option(command), e.message, position, position) - } - } -} - -/** - * Concrete SQL parser for Catalyst-only SQL statements. 
- */ -object CatalystSqlParser extends AbstractSqlParser { - val astBuilder = new AstBuilder -} - -/** - * This string stream provides the lexer with upper case characters only. This greatly simplifies - * lexing the stream, while we can maintain the original command. - * - * This is based on Hive's org.apache.hadoop.hive.ql.parse.ParseDriver.ANTLRNoCaseStringStream - * - * The comment below (taken from the original class) describes the rationale for doing this: - * - * This class provides and implementation for a case insensitive token checker for the lexical - * analysis part of antlr. By converting the token stream into upper case at the time when lexical - * rules are checked, this class ensures that the lexical rules need to just match the token with - * upper case letters as opposed to combination of upper case and lower case characters. This is - * purely used for matching lexical rules. The actual token text is stored in the same way as the - * user input without actually converting it into an upper case. The token values are generated by - * the consume() function of the super class ANTLRStringStream. The LA() function is the lookahead - * function and is purely used for matching lexical rules. This also means that the grammar will - * only accept capitalized tokens in case it is run from other tools like antlrworks which do not - * have the ANTLRNoCaseStringStream implementation. - */ - -private[parser] class ANTLRNoCaseStringStream(input: String) extends ANTLRInputStream(input) { - override def LA(i: Int): Int = { - val la = super.LA(i) - if (la == 0 || la == IntStream.EOF) la - else Character.toUpperCase(la) - } -} - -/** - * The ParseErrorListener converts parse errors into AnalysisExceptions. - */ -case object ParseErrorListener extends BaseErrorListener { - override def syntaxError( - recognizer: Recognizer[_, _], - offendingSymbol: scala.Any, - line: Int, - charPositionInLine: Int, - msg: String, - e: RecognitionException): Unit = { - val position = Origin(Some(line), Some(charPositionInLine)) - throw new ParseException(None, msg, position, position) - } -} - -/** - * A [[ParseException]] is an [[AnalysisException]] that is thrown during the parse process. It - * contains fields and an extended error message that make reporting and diagnosing errors easier. - */ -class ParseException( - val command: Option[String], - message: String, - val start: Origin, - val stop: Origin) extends AnalysisException(message, start.line, start.startPosition) { - - def this(message: String, ctx: ParserRuleContext) = { - this(Option(ParserUtils.command(ctx)), - message, - ParserUtils.position(ctx.getStart), - ParserUtils.position(ctx.getStop)) - } - - override def getMessage: String = { - val builder = new StringBuilder - builder ++= "\n" ++= message - start match { - case Origin(Some(l), Some(p)) => - builder ++= s"(line $l, pos $p)\n" - command.foreach { cmd => - val (above, below) = cmd.split("\n").splitAt(l) - builder ++= "\n== SQL ==\n" - above.foreach(builder ++= _ += '\n') - builder ++= (0 until p).map(_ => "-").mkString("") ++= "^^^\n" - below.foreach(builder ++= _ += '\n') - } - case _ => - command.foreach { cmd => - builder ++= "\n== SQL ==\n" ++= cmd - } - } - builder.toString - } - - def withCommand(cmd: String): ParseException = { - new ParseException(Option(cmd), message, start, stop) - } -} - -/** - * The post-processor validates & cleans-up the parse tree during the parse process. 
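`ANTLRNoCaseStringStream` above upper-cases only the lexer's lookahead while the stored command text keeps its original casing; here is a dependency-free sketch of that `LA()` behaviour.

````scala
// Standalone sketch of the case-insensitive lookahead idea above: keyword rules only ever
// see upper-case characters, yet the original command text is preserved unchanged.
object NoCaseLookahead {
  final val Eof = -1

  def la(input: String, pos: Int): Int =
    if (pos >= input.length) Eof else Character.toUpperCase(input.charAt(pos).toInt)

  def main(args: Array[String]): Unit = {
    val cmd = "seLect 1"
    // The lookahead stream reads "SELECT 1"; the stored command stays "seLect 1".
    println((0 until cmd.length).map(i => la(cmd, i).toChar).mkString)
    println(cmd)
  }
}
````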
- */ -case object PostProcessor extends SqlBaseBaseListener { - - /** Remove the back ticks from an Identifier. */ - override def exitQuotedIdentifier(ctx: QuotedIdentifierContext): Unit = { - replaceTokenByIdentifier(ctx, 1) { token => - // Remove the double back ticks in the string. - token.setText(token.getText.replace("``", "`")) - token - } - } - - /** Treat non-reserved keywords as Identifiers. */ - override def exitNonReserved(ctx: NonReservedContext): Unit = { - replaceTokenByIdentifier(ctx, 0)(identity) - } - - private def replaceTokenByIdentifier( - ctx: ParserRuleContext, - stripMargins: Int)( - f: CommonToken => CommonToken = identity): Unit = { - val parent = ctx.getParent - parent.removeLastChild() - val token = ctx.getChild(0).getPayload.asInstanceOf[Token] - parent.addChild(f(new CommonToken( - new org.antlr.v4.runtime.misc.Pair(token.getTokenSource, token.getInputStream), - SqlBaseParser.IDENTIFIER, - token.getChannel, - token.getStartIndex + stripMargins, - token.getStopIndex - stripMargins))) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala deleted file mode 100644 index 1fbfa763b4..0000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ng/ParserUtils.scala +++ /dev/null @@ -1,118 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} -import org.antlr.v4.runtime.misc.Interval -import org.antlr.v4.runtime.tree.TerminalNode - -import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} - -/** - * A collection of utility methods for use during the parsing process. - */ -object ParserUtils { - /** Get the command which created the token. */ - def command(ctx: ParserRuleContext): String = { - command(ctx.getStart.getInputStream) - } - - /** Get the command which created the token. */ - def command(stream: CharStream): String = { - stream.getText(Interval.of(0, stream.size())) - } - - /** Get the code that creates the given node. */ - def source(ctx: ParserRuleContext): String = { - val stream = ctx.getStart.getInputStream - stream.getText(Interval.of(ctx.getStart.getStartIndex, ctx.getStop.getStopIndex)) - } - - /** Get all the text which comes after the given rule. */ - def remainder(ctx: ParserRuleContext): String = remainder(ctx.getStop) - - /** Get all the text which comes after the given token. 
*/ - def remainder(token: Token): String = { - val stream = token.getInputStream - val interval = Interval.of(token.getStopIndex + 1, stream.size()) - stream.getText(interval) - } - - /** Convert a string token into a string. */ - def string(token: Token): String = unescapeSQLString(token.getText) - - /** Convert a string node into a string. */ - def string(node: TerminalNode): String = unescapeSQLString(node.getText) - - /** Get the origin (line and position) of the token. */ - def position(token: Token): Origin = { - Origin(Option(token.getLine), Option(token.getCharPositionInLine)) - } - - /** Assert if a condition holds. If it doesn't throw a parse exception. */ - def assert(f: => Boolean, message: String, ctx: ParserRuleContext): Unit = { - if (!f) { - throw new ParseException(message, ctx) - } - } - - /** - * Register the origin of the context. Any TreeNode created in the closure will be assigned the - * registered origin. This method restores the previously set origin after completion of the - * closure. - */ - def withOrigin[T](ctx: ParserRuleContext)(f: => T): T = { - val current = CurrentOrigin.get - CurrentOrigin.set(position(ctx.getStart)) - try { - f - } finally { - CurrentOrigin.set(current) - } - } - - /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ - implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { - /** - * Create a plan using the block of code when the given context exists. Otherwise return the - * original plan. - */ - def optional(ctx: AnyRef)(f: => LogicalPlan): LogicalPlan = { - if (ctx != null) { - f - } else { - plan - } - } - - /** - * Map a [[LogicalPlan]] to another [[LogicalPlan]] if the passed context exists using the - * passed function. The original plan is returned when the context does not exist. - */ - def optionalMap[C <: ParserRuleContext]( - ctx: C)( - f: (C, LogicalPlan) => LogicalPlan): LogicalPlan = { - if (ctx != null) { - f(ctx, plan) - } else { - plan - } - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala deleted file mode 100644 index 8b05f9e33d..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ASTNodeSuite.scala +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
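The `optional`/`optionalMap` sugar above is what lets plan construction chain optional clauses such as samples and aliases; below is a standalone sketch of the pattern with a stand-in `Plan` type (illustrative only).

````scala
// Standalone sketch of the optionalMap sugar above: apply a transformation only when the
// (possibly null) parser context is present, otherwise return the plan unchanged.
object OptionalMapSketch {
  final case class Plan(desc: String)

  implicit class EnhancedPlan(plan: Plan) {
    def optionalMap[C <: AnyRef](ctx: C)(f: (C, Plan) => Plan): Plan =
      if (ctx != null) f(ctx, plan) else plan
  }

  def main(args: Array[String]): Unit = {
    val base = Plan("UnresolvedRelation(t)")
    val sampleCtx: String = "TABLESAMPLE(10 PERCENT)"   // stands in for a SampleContext
    println(base.optionalMap(sampleCtx)((c, p) => Plan(s"Sample(${p.desc}, $c)")))
    println(base.optionalMap(null: String)((c, p) => Plan(s"Sample(${p.desc}, $c)")))
  }
}
````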
- */ -package org.apache.spark.sql.catalyst.parser - -import org.apache.spark.SparkFunSuite - -class ASTNodeSuite extends SparkFunSuite { - test("SPARK-13157 - remainder must return all input chars") { - val inputs = Seq( - ("add jar", "file:///tmp/ab/TestUDTF.jar"), - ("add jar", "file:///tmp/a@b/TestUDTF.jar"), - ("add jar", "c:\\windows32\\TestUDTF.jar"), - ("add jar", "some \nbad\t\tfile\r\n.\njar"), - ("ADD JAR", "@*#&@(!#@$^*!@^@#(*!@#"), - ("SET", "foo=bar"), - ("SET", "foo*)(@#^*@&!#^=bar") - ) - inputs.foreach { - case (command, arguments) => - val node = ParseDriver.parsePlan(s"$command $arguments", null) - assert(node.remainder === arguments) - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala deleted file mode 100644 index 223485e292..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/CatalystQlSuite.scala +++ /dev/null @@ -1,223 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.parser - -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.ng.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.unsafe.types.CalendarInterval - -class CatalystQlSuite extends PlanTest { - val parser = new CatalystQl() - import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.dsl.plans._ - - val star = UnresolvedAlias(UnresolvedStar(None)) - - test("test case insensitive") { - val result = OneRowRelation.select(1) - assert(result === parser.parsePlan("seLect 1")) - assert(result === parser.parsePlan("select 1")) - assert(result === parser.parsePlan("SELECT 1")) - } - - test("test NOT operator with comparison operations") { - val parsed = parser.parsePlan("SELECT NOT TRUE > TRUE") - val expected = OneRowRelation.select(Not(GreaterThan(true, true))) - comparePlans(parsed, expected) - } - - test("test Union Distinct operator") { - val parsed1 = parser.parsePlan( - "SELECT * FROM t0 UNION SELECT * FROM t1") - val parsed2 = parser.parsePlan( - "SELECT * FROM t0 UNION DISTINCT SELECT * FROM t1") - val expected = Distinct(Union(table("t0").select(star), table("t1").select(star))) - .as("u_1").select(star) - comparePlans(parsed1, expected) - comparePlans(parsed2, expected) - } - - test("test Union All operator") { - val parsed = parser.parsePlan("SELECT * FROM t0 UNION ALL SELECT * FROM t1") - val expected = Union(table("t0").select(star), table("t1").select(star)).as("u_1").select(star) - comparePlans(parsed, expected) - } - - test("support hive interval literal") { - def checkInterval(sql: String, result: CalendarInterval): Unit = { - val parsed = parser.parsePlan(sql) - val expected = OneRowRelation.select(Literal(result)) - comparePlans(parsed, expected) - } - - def checkYearMonth(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' YEAR TO MONTH", - CalendarInterval.fromYearMonthString(lit)) - } - - def checkDayTime(lit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' DAY TO SECOND", - CalendarInterval.fromDayTimeString(lit)) - } - - def checkSingleUnit(lit: String, unit: String): Unit = { - checkInterval( - s"SELECT INTERVAL '$lit' $unit", - CalendarInterval.fromSingleUnitString(unit, lit)) - } - - checkYearMonth("123-10") - checkYearMonth("496-0") - checkYearMonth("-2-3") - checkYearMonth("-123-0") - - checkDayTime("99 11:22:33.123456789") - checkDayTime("-99 11:22:33.123456789") - checkDayTime("10 9:8:7.123456789") - checkDayTime("1 0:0:0") - checkDayTime("-1 0:0:0") - checkDayTime("1 0:0:1") - - for (unit <- Seq("year", "month", "day", "hour", "minute", "second")) { - checkSingleUnit("7", unit) - checkSingleUnit("-7", unit) - checkSingleUnit("0", unit) - } - - checkSingleUnit("13.123456789", "second") - checkSingleUnit("-13.123456789", "second") - } - - test("support scientific notation") { - def assertRight(input: String, output: Double): Unit = { - val parsed = parser.parsePlan("SELECT " + input) - val expected = OneRowRelation.select(Literal(output)) - comparePlans(parsed, expected) - } - - assertRight("9.0e1", 90) - assertRight(".9e+2", 90) - assertRight("0.9e+2", 90) - assertRight("900e-1", 90) - assertRight("900.0E-1", 90) - assertRight("9.e+1", 90) - - 
intercept[AnalysisException](parser.parsePlan("SELECT .e3")) - } - - test("parse expressions") { - compareExpressions( - parser.parseExpression("prinln('hello', 'world')"), - UnresolvedFunction( - "prinln", Literal("hello") :: Literal("world") :: Nil, false)) - - compareExpressions( - parser.parseExpression("1 + r.r As q"), - Alias(Add(Literal(1), UnresolvedAttribute("r.r")), "q")()) - - compareExpressions( - parser.parseExpression("1 - f('o', o(bar))"), - Subtract(Literal(1), - UnresolvedFunction("f", - Literal("o") :: - UnresolvedFunction("o", UnresolvedAttribute("bar") :: Nil, false) :: - Nil, false))) - - intercept[AnalysisException](parser.parseExpression("1 - f('o', o(bar)) hello * world")) - } - - test("table identifier") { - assert(TableIdentifier("q") === parser.parseTableIdentifier("q")) - assert(TableIdentifier("q", Some("d")) === parser.parseTableIdentifier("d.q")) - intercept[AnalysisException](parser.parseTableIdentifier("")) - intercept[AnalysisException](parser.parseTableIdentifier("d.q.g")) - } - - test("parse union/except/intersect") { - parser.parsePlan("select * from t1 union all select * from t2") - parser.parsePlan("select * from t1 union distinct select * from t2") - parser.parsePlan("select * from t1 union select * from t2") - parser.parsePlan("select * from t1 except select * from t2") - parser.parsePlan("select * from t1 intersect select * from t2") - parser.parsePlan("(select * from t1) union all (select * from t2)") - parser.parsePlan("(select * from t1) union distinct (select * from t2)") - parser.parsePlan("(select * from t1) union (select * from t2)") - parser.parsePlan("select * from ((select * from t1) union (select * from t2)) t") - } - - test("window function: better support of parentheses") { - parser.parsePlan("select sum(product + 1) over (partition by ((1) + (product / 2)) " + - "order by 2) from windowData") - parser.parsePlan("select sum(product + 1) over (partition by (1 + (product / 2)) " + - "order by 2) from windowData") - parser.parsePlan("select sum(product + 1) over (partition by ((product / 2) + 1) " + - "order by 2) from windowData") - - parser.parsePlan("select sum(product + 1) over (partition by ((product) + (1)) order by 2) " + - "from windowData") - parser.parsePlan("select sum(product + 1) over (partition by ((product) + 1) order by 2) " + - "from windowData") - parser.parsePlan("select sum(product + 1) over (partition by (product + (1)) order by 2) " + - "from windowData") - } - - test("very long AND/OR expression") { - val equals = (1 to 1000).map(x => s"$x == $x") - val expr = parser.parseExpression(equals.mkString(" AND ")) - assert(expr.isInstanceOf[And]) - assert(expr.collect( { case EqualTo(_, _) => true } ).size == 1000) - - val expr2 = parser.parseExpression(equals.mkString(" OR ")) - assert(expr2.isInstanceOf[Or]) - assert(expr2.collect( { case EqualTo(_, _) => true } ).size == 1000) - } - - test("subquery") { - parser.parsePlan("select (select max(b) from s) ss from t") - parser.parsePlan("select * from t where a = (select b from s)") - parser.parsePlan("select * from t group by g having a > (select b from s)") - } - - test("using clause in JOIN") { - // Tests parsing of using clause for different join types. 
- parser.parsePlan("select * from t1 join t2 using (c1)") - parser.parsePlan("select * from t1 join t2 using (c1, c2)") - parser.parsePlan("select * from t1 left join t2 using (c1, c2)") - parser.parsePlan("select * from t1 right join t2 using (c1, c2)") - parser.parsePlan("select * from t1 full outer join t2 using (c1, c2)") - parser.parsePlan("select * from t1 join t2 using (c1) join t3 using (c2)") - // Tests errors - // (1) Empty using clause - // (2) Qualified columns in using - // (3) Both on and using clause - var error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using ()")) - assert(error.message.contains("cannot recognize input near ')'")) - error = intercept[AnalysisException](parser.parsePlan("select * from t1 join t2 using (t1.c1)")) - assert(error.message.contains("mismatched input '.'")) - error = intercept[AnalysisException](parser.parsePlan("select * from t1" + - " join t2 using (c1) on t1.c1 = t2.c1")) - assert(error.message.contains("missing EOF at 'on' near ')'")) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala index d9bd33c50a..07b89cb61f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DataTypeParserSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.parser.ng.{CatalystSqlParser, ParseException} import org.apache.spark.sql.types._ abstract class AbstractDataTypeParserSuite extends SparkFunSuite { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala new file mode 100644 index 0000000000..db96bfb652 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ErrorParserSuite.scala @@ -0,0 +1,67 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +/** + * Test various parser errors. + */ +class ErrorParserSuite extends SparkFunSuite { + def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = { + val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) + + // Check position. + assert(e.line.isDefined) + assert(e.line.get === line) + assert(e.startPosition.isDefined) + assert(e.startPosition.get === startPosition) + + // Check messages. 
+ val error = e.getMessage + messages.foreach { message => + assert(error.contains(message)) + } + } + + test("no viable input") { + intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") + intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") + intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") + } + + test("extraneous input") { + intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^") + intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^") + } + + test("mismatched input") { + intercept("select * from r order by q from t", 1, 27, + "mismatched input", + "---------------------------^^^") + intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^") + } + + test("semantic errors") { + intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, + "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", + "^^^") + intercept("select * from r where a in (select * from t)", 1, 24, + "IN with a Sub-query is currently not supported", + "------------------------^^^") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala new file mode 100644 index 0000000000..a80d29ce5d --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -0,0 +1,497 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import java.sql.{Date, Timestamp} + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +/** + * Test basic expression parsing. If a type of expression is supported it should be tested here. + * + * Please note that some of the expressions test don't have to be sound expressions, only their + * structure needs to be valid. Unsound expressions should be caught by the Analyzer or + * CheckAnalysis classes. 
+ */ +class ExpressionParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, e: Expression): Unit = { + compareExpressions(parseExpression(sqlCommand), e) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parseExpression(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("star expressions") { + // Global Star + assertEqual("*", UnresolvedStar(None)) + + // Targeted Star + assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) + } + + // NamedExpression (Alias/Multialias) + test("named expressions") { + // No Alias + val r0 = 'a + assertEqual("a", r0) + + // Single Alias. + val r1 = 'a as "b" + assertEqual("a as b", r1) + assertEqual("a b", r1) + + // Multi-Alias + assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) + assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) + + // Numeric literals without a space between the literal qualifier and the alias, should not be + // interpreted as such. An unresolved reference should be returned instead. + // TODO add the JIRA-ticket number. + assertEqual("1SL", Symbol("1SL")) + + // Aliased star is allowed. + assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) + } + + test("binary logical expressions") { + // And + assertEqual("a and b", 'a && 'b) + + // Or + assertEqual("a or b", 'a || 'b) + + // Combination And/Or check precedence + assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) + assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) + + // Multiple AND/OR get converted into a balanced tree + assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) + assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) + } + + test("long binary logical expressions") { + def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { + val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) + val e = parseExpression(sql) + assert(e.collect { case _: EqualTo => true }.size === 1000) + assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) + } + testVeryBinaryExpression(" AND ", classOf[And]) + testVeryBinaryExpression(" OR ", classOf[Or]) + } + + test("not expressions") { + assertEqual("not a", !'a) + assertEqual("!a", !'a) + assertEqual("not true > true", Not(GreaterThan(true, true))) + } + + test("exists expression") { + intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") + } + + test("comparison expressions") { + assertEqual("a = b", 'a === 'b) + assertEqual("a == b", 'a === 'b) + assertEqual("a <=> b", 'a <=> 'b) + assertEqual("a <> b", 'a =!= 'b) + assertEqual("a != b", 'a =!= 'b) + assertEqual("a < b", 'a < 'b) + assertEqual("a <= b", 'a <= 'b) + assertEqual("a > b", 'a > 'b) + assertEqual("a >= b", 'a >= 'b) + } + + test("between expressions") { + assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) + assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) + } + + test("in expressions") { + assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) + assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) + } + + test("in sub-query") { + intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") + } + + test("like expressions") { + assertEqual("a like 'pattern%'", 'a like "pattern%") + assertEqual("a not 
like 'pattern%'", !('a like "pattern%")) + assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") + assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) + assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") + assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) + } + + test("is null expressions") { + assertEqual("a is null", 'a.isNull) + assertEqual("a is not null", 'a.isNotNull) + assertEqual("a = b is null", ('a === 'b).isNull) + assertEqual("a = b is not null", ('a === 'b).isNotNull) + } + + test("binary arithmetic expressions") { + // Simple operations + assertEqual("a * b", 'a * 'b) + assertEqual("a / b", 'a / 'b) + assertEqual("a DIV b", ('a / 'b).cast(LongType)) + assertEqual("a % b", 'a % 'b) + assertEqual("a + b", 'a + 'b) + assertEqual("a - b", 'a - 'b) + assertEqual("a & b", 'a & 'b) + assertEqual("a ^ b", 'a ^ 'b) + assertEqual("a | b", 'a | 'b) + + // Check precedences + assertEqual( + "a * t | b ^ c & d - e + f % g DIV h / i * k", + 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) + } + + test("unary arithmetic expressions") { + assertEqual("+a", 'a) + assertEqual("-a", -'a) + assertEqual("~a", ~'a) + assertEqual("-+~~a", -(~(~'a))) + } + + test("cast expressions") { + // Note that DataType parsing is tested elsewhere. + assertEqual("cast(a as int)", 'a.cast(IntegerType)) + assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) + assertEqual("cast(a as array)", 'a.cast(ArrayType(IntegerType))) + assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) + } + + test("function expressions") { + assertEqual("foo()", 'foo.function()) + assertEqual("foo.bar()", Symbol("foo.bar").function()) + assertEqual("foo(*)", 'foo.function(star())) + assertEqual("count(*)", 'count.function(1)) + assertEqual("foo(a, b)", 'foo.function('a, 'b)) + assertEqual("foo(all a, b)", 'foo.function('a, 'b)) + assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) + assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) + assertEqual("`select`(all a, b)", 'select.function('a, 'b)) + } + + test("window function expressions") { + val func = 'foo.function(star()) + def windowed( + partitioning: Seq[Expression] = Seq.empty, + ordering: Seq[SortOrder] = Seq.empty, + frame: WindowFrame = UnspecifiedFrame): Expression = { + WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) + } + + // Basic window testing. + assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) + assertEqual("foo(*) over ()", windowed()) + assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) + assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) + assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) + assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) + + // Test use of expressions in window functions. 
+ assertEqual( + "sum(product + 1) over (partition by ((product) + (1)) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + assertEqual( + "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", + WindowExpression('sum.function('product + 1), + WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) + + // Range/Row + val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) + val boundaries = Seq( + ("10 preceding", ValuePreceding(10), CurrentRow), + ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis + ("unbounded preceding", UnboundedPreceding, CurrentRow), + ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis + ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), + ("between unbounded preceding and unbounded following", + UnboundedPreceding, UnboundedFollowing), + ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), + ("between current row and 5 following", CurrentRow, ValueFollowing(5)), + ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) + ) + frameTypes.foreach { + case (frameTypeSql, frameType) => + boundaries.foreach { + case (boundarySql, begin, end) => + val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" + val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) + assertEqual(query, expr) + } + } + + // We cannot use non integer constants. + intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", + "Frame bound value must be a constant integer.") + + // We cannot use an arbitrary expression. + intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", + "Frame bound value must be a constant integer.") + } + + test("row constructor") { + // Note that '(a)' will be interpreted as a nested expression. + assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) + assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) + } + + test("scalar sub-query") { + assertEqual( + "(select max(val) from tbl) > current", + ScalarSubquery(table("tbl").select('max.function('val))) > 'current) + assertEqual( + "a = (select b from s)", + 'a === ScalarSubquery(table("s").select('b))) + } + + test("case when") { + assertEqual("case a when 1 then b when 2 then c else d end", + CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) + assertEqual("case when a = 1 then b when a = 2 then c else d end", + CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) + } + + test("dereference") { + assertEqual("a.b", UnresolvedAttribute("a.b")) + assertEqual("`select`.b", UnresolvedAttribute("select.b")) + assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. + assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) + } + + test("reference") { + // Regular + assertEqual("a", 'a) + + // Starting with a digit. + assertEqual("1a", Symbol("1a")) + + // Quoted using a keyword. + assertEqual("`select`", 'select) + + // Unquoted using an unreserved keyword. + assertEqual("columns", 'columns) + } + + test("subscript") { + assertEqual("a[b]", 'a.getItem('b)) + assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) + assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) + } + + test("parenthesis") { + assertEqual("(a)", 'a) + assertEqual("r * (a + b)", 'r * ('a + 'b)) + } + + test("type constructors") { + // Dates. 
+ assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) + intercept[IllegalArgumentException] { + parseExpression("DAtE 'mar 11 2016'") + } + + // Timestamps. + assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", + Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) + intercept[IllegalArgumentException] { + parseExpression("timestamP '2016-33-11 20:54:00.000'") + } + + // Unsupported datatype. + intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") + } + + test("literals") { + // NULL + assertEqual("null", Literal(null)) + + // Boolean + assertEqual("trUe", Literal(true)) + assertEqual("False", Literal(false)) + + // Integral should have the narrowest possible type + assertEqual("787324", Literal(787324)) + assertEqual("7873247234798249234", Literal(7873247234798249234L)) + assertEqual("78732472347982492793712334", + Literal(BigDecimal("78732472347982492793712334").underlying())) + + // Decimal + assertEqual("7873247234798249279371.2334", + Literal(BigDecimal("7873247234798249279371.2334").underlying())) + + // Scientific Decimal + assertEqual("9.0e1", 90d) + assertEqual(".9e+2", 90d) + assertEqual("0.9e+2", 90d) + assertEqual("900e-1", 90d) + assertEqual("900.0E-1", 90d) + assertEqual("9.e+1", 90d) + intercept(".e3") + + // Tiny Int Literal + assertEqual("10Y", Literal(10.toByte)) + intercept("-1000Y") + + // Small Int Literal + assertEqual("10S", Literal(10.toShort)) + intercept("40000S") + + // Long Int Literal + assertEqual("10L", Literal(10L)) + intercept("78732472347982492793712334L") + + // Double Literal + assertEqual("10.0D", Literal(10.0D)) + // TODO we need to figure out if we should throw an exception here! + assertEqual("1E309", Literal(Double.PositiveInfinity)) + } + + test("strings") { + // Single Strings. + assertEqual("\"hello\"", "hello") + assertEqual("'hello'", "hello") + + // Multi-Strings. + assertEqual("\"hello\" 'world'", "helloworld") + assertEqual("'hello' \" \" 'world'", "hello world") + + // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a + // regular '%'; to get the correct result you need to add another escaped '\'. + // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? + assertEqual("'pattern%'", "pattern%") + assertEqual("'no-pattern\\%'", "no-pattern\\%") + assertEqual("'pattern\\\\%'", "pattern\\%") + assertEqual("'pattern\\\\\\%'", "pattern\\\\%") + + // Escaped characters. + // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html + assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') + assertEqual("'\\''", "\'") // Single quote + assertEqual("'\\\"'", "\"") // Double quote + assertEqual("'\\b'", "\b") // Backspace + assertEqual("'\\n'", "\n") // Newline + assertEqual("'\\r'", "\r") // Carriage return + assertEqual("'\\t'", "\t") // Tab character + assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) + + // Octals + assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") + + // Unicode + assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + } + + test("intervals") { + def intervalLiteral(u: String, s: String): Literal = { + Literal(CalendarInterval.fromSingleUnitString(u, s)) + } + + // Empty interval statement + intercept("interval", "at least one time unit should be given for interval literal") + + // Single Intervals. 
+ val units = Seq( + "year", + "month", + "week", + "day", + "hour", + "minute", + "second", + "millisecond", + "microsecond") + val forms = Seq("", "s") + val values = Seq("0", "10", "-7", "21") + units.foreach { unit => + forms.foreach { form => + values.foreach { value => + val expected = intervalLiteral(unit, value) + assertEqual(s"interval $value $unit$form", expected) + assertEqual(s"interval '$value' $unit$form", expected) + } + } + } + + // Hive nanosecond notation. + assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) + assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) + + // Non Existing unit + intercept("interval 10 nanoseconds", "No interval can be constructed") + + // Year-Month intervals. + val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") + yearMonthValues.foreach { value => + val result = Literal(CalendarInterval.fromYearMonthString(value)) + assertEqual(s"interval '$value' year to month", result) + } + + // Day-Time intervals. + val datTimeValues = Seq( + "99 11:22:33.123456789", + "-99 11:22:33.123456789", + "10 9:8:7.123456789", + "1 0:0:0", + "-1 0:0:0", + "1 0:0:1") + datTimeValues.foreach { value => + val result = Literal(CalendarInterval.fromDayTimeString(value)) + assertEqual(s"interval '$value' day to second", result) + } + + // Unknown FROM TO intervals + intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") + + // Composed intervals. + assertEqual( + "interval 3 months 22 seconds 1 millisecond", + Literal(new CalendarInterval(3, 22001000L))) + assertEqual( + "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", + Literal(new CalendarInterval(14, + 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) + } + + test("composed expressions") { + assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) + assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) + intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala new file mode 100644 index 0000000000..23f05ce846 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -0,0 +1,429 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.{BooleanType, IntegerType} + +class PlanParserSuite extends PlanTest { + import CatalystSqlParser._ + import org.apache.spark.sql.catalyst.dsl.expressions._ + import org.apache.spark.sql.catalyst.dsl.plans._ + + def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { + comparePlans(parsePlan(sqlCommand), plan) + } + + def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("case insensitive") { + val plan = table("a").select(star()) + assertEqual("sELEct * FroM a", plan) + assertEqual("select * fRoM a", plan) + assertEqual("SELECT * FROM a", plan) + } + + test("show functions") { + assertEqual("show functions", ShowFunctions(None, None)) + assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) + assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) + assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) + intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") + } + + test("describe function") { + assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) + assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) + assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) + assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) + } + + test("set operations") { + val a = table("a").select(star()) + val b = table("b").select(star()) + + assertEqual("select * from a union select * from b", Distinct(a.union(b))) + assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) + assertEqual("select * from a union all select * from b", a.union(b)) + assertEqual("select * from a except select * from b", a.except(b)) + intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") + assertEqual("select * from a except distinct select * from b", a.except(b)) + assertEqual("select * from a intersect select * from b", a.intersect(b)) + intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") + assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) + } + + test("common table expressions") { + def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { + val ctes = namedPlans.map { + case (name, cte) => + name -> SubqueryAlias(name, cte) + }.toMap + With(plan, ctes) + } + assertEqual( + "with cte1 as (select * from a) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) + assertEqual( + "with cte1 (select 1) select * from cte1", + cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) + assertEqual( + "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", + cte(table("cte2").select(star()), + "cte1" -> OneRowRelation.select(1), + "cte2" -> table("cte1").select(star()))) + intercept( + "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", + "Name 'cte1' is used for multiple common table expressions") + } + + test("simple select query") { + 
assertEqual("select 1", OneRowRelation.select(1)) + assertEqual("select a, b", OneRowRelation.select('a, 'b)) + assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) + assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) + assertEqual( + "select a, b from db.c having x < 1", + table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) + assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) + assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) + } + + test("reverse select query") { + assertEqual("from a", table("a")) + assertEqual("from a select b, c", table("a").select('b, 'c)) + assertEqual( + "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) + assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) + assertEqual( + "from (from a union all from b) c select *", + table("a").union(table("b")).as("c").select(star())) + } + + test("transform query spec") { + val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) + assertEqual("select transform(a, b) using 'func' from e where f < 10", + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + assertEqual("map a, b using 'func' as c, d from e", + p.copy(output = Seq('c.string, 'd.string))) + assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + + test("multi select query") { + assertEqual( + "from a select * select * where s < 10", + table("a").select(star()).union(table("a").where('s < 10).select(star()))) + intercept( + "from a select * select * from x where a.s < 10", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + assertEqual( + "from a insert into tbl1 select * insert into tbl2 select * where s < 10", + table("a").select(star()).insertInto("tbl1").union( + table("a").where('s < 10).select(star()).insertInto("tbl2"))) + } + + test("query organization") { + // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows + val baseSql = "select * from t" + val basePlan = table("t").select(star()) + + val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) + val limitWindowClauses = Seq( + ("", (p: LogicalPlan) => p), + (" limit 10", (p: LogicalPlan) => p.limit(10)), + (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), + (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) + ) + + val orderSortDistrClusterClauses = Seq( + ("", basePlan), + (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), + (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), + (" distribute by a, b", basePlan.distribute('a, 'b)), + (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), + (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) + ) + + orderSortDistrClusterClauses.foreach { + case (s1, p1) => + limitWindowClauses.foreach { + case (s2, pf2) => + assertEqual(baseSql + s1 + s2, pf2(p1)) + } + } + + val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" + intercept(s"$baseSql order by a sort by a", msg) + intercept(s"$baseSql cluster by a distribute by a", msg) + intercept(s"$baseSql order by a cluster by a", msg) + intercept(s"$baseSql order by a distribute by a", msg) + } + + test("insert into") { + val sql = "select * from t" 
+ val plan = table("t").select(star()) + def insert( + partition: Map[String, Option[String]], + overwrite: Boolean = false, + ifNotExists: Boolean = false): LogicalPlan = + InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) + + // Single inserts + assertEqual(s"insert overwrite table s $sql", + insert(Map.empty, overwrite = true)) + assertEqual(s"insert overwrite table s if not exists $sql", + insert(Map.empty, overwrite = true, ifNotExists = true)) + assertEqual(s"insert into s $sql", + insert(Map.empty)) + assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", + insert(Map("c" -> Option("d"), "e" -> Option("1")))) + assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", + insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) + + // Multi insert + val plan2 = table("t").where('x > 5).select(star()) + assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", + InsertIntoTable( + table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( + InsertIntoTable( + table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) + } + + test("aggregation") { + val sql = "select a, b, sum(c) as c from d group by a, b" + + // Normal + assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) + + // Cube + assertEqual(s"$sql with cube", + table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Rollup + assertEqual(s"$sql with rollup", + table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) + + // Grouping Sets + assertEqual(s"$sql grouping sets((a, b), (a), ())", + GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) + intercept(s"$sql grouping sets((a, b), (c), ())", + "c doesn't show up in the GROUP BY list") + } + + test("limit") { + val sql = "select * from t" + val plan = table("t").select(star()) + assertEqual(s"$sql limit 10", plan.limit(10)) + assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) + } + + test("window spec") { + // Note that WindowSpecs are testing in the ExpressionParserSuite + val sql = "select * from t" + val plan = table("t").select(star()) + val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), + SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) + + // Test window resolution. + val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) + assertEqual( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w1""".stripMargin, + WithWindowDefinition(ws1, plan)) + + // Fail with no reference. + intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") + + // Fail when resolved reference is not a window spec. 
+ intercept( + s"""$sql + |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), + | w2 as w1, + | w3 as w2""".stripMargin, + "Window reference 'w2' is not a window specification" + ) + } + + test("lateral view") { + // Single lateral view + assertEqual( + "select * from t lateral view explode(x) expl as x", + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + .select(star())) + + // Multiple lateral views + assertEqual( + """select * + |from t + |lateral view explode(x) expl + |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, + table("t") + .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) + .select(star())) + + // Multi-Insert lateral views. + val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) + assertEqual( + """from t1 + |lateral view explode(x) expl as x + |insert into t2 + |select * + |lateral view json_tuple(x, y) jtup q, z + |insert into t3 + |select * + |where s < 10 + """.stripMargin, + Union(from + .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) + .select(star()) + .insertInto("t2"), + from.where('s < 10).select(star()).insertInto("t3"))) + + // Unsupported generator. + intercept( + "select * from t lateral view posexplode(x) posexpl as x, y", + "Generator function 'posexplode' is not supported") + } + + test("joins") { + // Test single joins. + val testUnconditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t as tt $sql u", + table("t").as("tt").join(table("u"), jt, None).select(star())) + } + val testConditionalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u as uu on a = b", + table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) + } + val testNaturalJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t tt natural $sql u as uu", + table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) + } + val testUsingJoin = (sql: String, jt: JoinType) => { + assertEqual( + s"select * from t $sql u using(a, b)", + table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) + } + val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) + + def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { + tests.foreach(_(sql, jt)) + } + test("cross join", Inner, Seq(testUnconditionalJoin)) + test(",", Inner, Seq(testUnconditionalJoin)) + test("join", Inner, testAll) + test("inner join", Inner, testAll) + test("left join", LeftOuter, testAll) + test("left outer join", LeftOuter, testAll) + test("right join", RightOuter, testAll) + test("right outer join", RightOuter, testAll) + test("full join", FullOuter, testAll) + test("full outer join", FullOuter, testAll) + + // Test multiple consecutive joins + assertEqual( + "select * from a join b join c right join d", + table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) + } + + test("sampled relations") { + val sql = "select * from t" + assertEqual(s"$sql tablesample(100 rows)", + table("t").limit(100).select(star())) + assertEqual(s"$sql tablesample(43 percent) as x", + Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + assertEqual(s"$sql tablesample(bucket 4 out of 10) 
as x", + Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) + intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", + "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") + intercept(s"$sql tablesample(bucket 11 out of 10) as x", + s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") + } + + test("sub-query") { + val plan = table("t0").select('id) + assertEqual("select id from (t0)", plan) + assertEqual("select id from ((((((t0))))))", plan) + assertEqual( + "(select * from t1) union distinct (select * from t2)", + Distinct(table("t1").select(star()).union(table("t2").select(star())))) + assertEqual( + "select * from ((select * from t1) union (select * from t2)) t", + Distinct( + table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) + assertEqual( + """select id + |from (((select id from t0) + | union all + | (select id from t0)) + | union all + | (select id from t0)) as u_1 + """.stripMargin, + plan.union(plan).union(plan).as("u_1").select('id)) + } + + test("scalar sub-query") { + assertEqual( + "select (select max(b) from s) ss from t", + table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) + assertEqual( + "select * from t where a = (select b from s)", + table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) + assertEqual( + "select g from t group by g having a > (select b from s)", + table("t") + .groupBy('g)('g) + .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) + } + + test("table reference") { + assertEqual("table t", table("t")) + assertEqual("table d.t", table("d", "t")) + } + + test("inline table") { + assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( + Seq('col1.int), + Seq(1, 2, 3, 4).map(x => Row(x)))) + assertEqual( + "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", + LocalRelation.fromExternalRows( + Seq('a.int, 'b.string), + Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) + intercept("values (a, 'a'), (b, 'b')", + "All expressions in an inline table must be constants.") + intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", + "Number of aliases must match the number of fields in an inline table.") + intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala new file mode 100644 index 0000000000..297b1931a9 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.TableIdentifier + +class TableIdentifierParserSuite extends SparkFunSuite { + import CatalystSqlParser._ + + test("table identifier") { + // Regular names. + assert(TableIdentifier("q") === parseTableIdentifier("q")) + assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) + + // Illegal names. + intercept[ParseException](parseTableIdentifier("")) + intercept[ParseException](parseTableIdentifier("d.q.g")) + + // SQL Keywords. + val keywords = Seq("select", "from", "where", "left", "right") + keywords.foreach { keyword => + intercept[ParseException](parseTableIdentifier(keyword)) + assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) + assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala deleted file mode 100644 index 1963fc368f..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ErrorParserSuite.scala +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.apache.spark.SparkFunSuite - -/** - * Test various parser errors. - */ -class ErrorParserSuite extends SparkFunSuite { - def intercept(sql: String, line: Int, startPosition: Int, messages: String*): Unit = { - val e = intercept[ParseException](CatalystSqlParser.parsePlan(sql)) - - // Check position. - assert(e.line.isDefined) - assert(e.line.get === line) - assert(e.startPosition.isDefined) - assert(e.startPosition.get === startPosition) - - // Check messages. 
- val error = e.getMessage - messages.foreach { message => - assert(error.contains(message)) - } - } - - test("no viable input") { - intercept("select from tbl", 1, 7, "no viable alternative at input", "-------^^^") - intercept("select\nfrom tbl", 2, 0, "no viable alternative at input", "^^^") - intercept("select ((r + 1) ", 1, 16, "no viable alternative at input", "----------------^^^") - } - - test("extraneous input") { - intercept("select 1 1", 1, 9, "extraneous input '1' expecting", "---------^^^") - intercept("select *\nfrom r as q t", 2, 12, "extraneous input", "------------^^^") - } - - test("mismatched input") { - intercept("select * from r order by q from t", 1, 27, - "mismatched input", - "---------------------------^^^") - intercept("select *\nfrom r\norder by q\nfrom t", 4, 0, "mismatched input", "^^^") - } - - test("semantic errors") { - intercept("select *\nfrom r\norder by q\ncluster by q", 3, 0, - "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported", - "^^^") - intercept("select * from r where a in (select * from t)", 1, 24, - "IN with a Sub-query is currently not supported", - "------------------------^^^") - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala deleted file mode 100644 index 32311a5a66..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/ExpressionParserSuite.scala +++ /dev/null @@ -1,497 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import java.sql.{Date, Timestamp} - -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, _} -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.CalendarInterval - -/** - * Test basic expression parsing. If a type of expression is supported it should be tested here. - * - * Please note that some of the expressions test don't have to be sound expressions, only their - * structure needs to be valid. Unsound expressions should be caught by the Analyzer or - * CheckAnalysis classes. 
- */ -class ExpressionParserSuite extends PlanTest { - import CatalystSqlParser._ - import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.dsl.plans._ - - def assertEqual(sqlCommand: String, e: Expression): Unit = { - compareExpressions(parseExpression(sqlCommand), e) - } - - def intercept(sqlCommand: String, messages: String*): Unit = { - val e = intercept[ParseException](parseExpression(sqlCommand)) - messages.foreach { message => - assert(e.message.contains(message)) - } - } - - test("star expressions") { - // Global Star - assertEqual("*", UnresolvedStar(None)) - - // Targeted Star - assertEqual("a.b.*", UnresolvedStar(Option(Seq("a", "b")))) - } - - // NamedExpression (Alias/Multialias) - test("named expressions") { - // No Alias - val r0 = 'a - assertEqual("a", r0) - - // Single Alias. - val r1 = 'a as "b" - assertEqual("a as b", r1) - assertEqual("a b", r1) - - // Multi-Alias - assertEqual("a as (b, c)", MultiAlias('a, Seq("b", "c"))) - assertEqual("a() (b, c)", MultiAlias('a.function(), Seq("b", "c"))) - - // Numeric literals without a space between the literal qualifier and the alias, should not be - // interpreted as such. An unresolved reference should be returned instead. - // TODO add the JIRA-ticket number. - assertEqual("1SL", Symbol("1SL")) - - // Aliased star is allowed. - assertEqual("a.* b", UnresolvedStar(Option(Seq("a"))) as 'b) - } - - test("binary logical expressions") { - // And - assertEqual("a and b", 'a && 'b) - - // Or - assertEqual("a or b", 'a || 'b) - - // Combination And/Or check precedence - assertEqual("a and b or c and d", ('a && 'b) || ('c && 'd)) - assertEqual("a or b or c and d", 'a || 'b || ('c && 'd)) - - // Multiple AND/OR get converted into a balanced tree - assertEqual("a or b or c or d or e or f", (('a || 'b) || 'c) || (('d || 'e) || 'f)) - assertEqual("a and b and c and d and e and f", (('a && 'b) && 'c) && (('d && 'e) && 'f)) - } - - test("long binary logical expressions") { - def testVeryBinaryExpression(op: String, clazz: Class[_]): Unit = { - val sql = (1 to 1000).map(x => s"$x == $x").mkString(op) - val e = parseExpression(sql) - assert(e.collect { case _: EqualTo => true }.size === 1000) - assert(e.collect { case x if clazz.isInstance(x) => true }.size === 999) - } - testVeryBinaryExpression(" AND ", classOf[And]) - testVeryBinaryExpression(" OR ", classOf[Or]) - } - - test("not expressions") { - assertEqual("not a", !'a) - assertEqual("!a", !'a) - assertEqual("not true > true", Not(GreaterThan(true, true))) - } - - test("exists expression") { - intercept("exists (select 1 from b where b.x = a.x)", "EXISTS clauses are not supported") - } - - test("comparison expressions") { - assertEqual("a = b", 'a === 'b) - assertEqual("a == b", 'a === 'b) - assertEqual("a <=> b", 'a <=> 'b) - assertEqual("a <> b", 'a =!= 'b) - assertEqual("a != b", 'a =!= 'b) - assertEqual("a < b", 'a < 'b) - assertEqual("a <= b", 'a <= 'b) - assertEqual("a > b", 'a > 'b) - assertEqual("a >= b", 'a >= 'b) - } - - test("between expressions") { - assertEqual("a between b and c", 'a >= 'b && 'a <= 'c) - assertEqual("a not between b and c", !('a >= 'b && 'a <= 'c)) - } - - test("in expressions") { - assertEqual("a in (b, c, d)", 'a in ('b, 'c, 'd)) - assertEqual("a not in (b, c, d)", !('a in ('b, 'c, 'd))) - } - - test("in sub-query") { - intercept("a in (select b from c)", "IN with a Sub-query is currently not supported") - } - - test("like expressions") { - assertEqual("a like 'pattern%'", 'a like "pattern%") - assertEqual("a not 
like 'pattern%'", !('a like "pattern%")) - assertEqual("a rlike 'pattern%'", 'a rlike "pattern%") - assertEqual("a not rlike 'pattern%'", !('a rlike "pattern%")) - assertEqual("a regexp 'pattern%'", 'a rlike "pattern%") - assertEqual("a not regexp 'pattern%'", !('a rlike "pattern%")) - } - - test("is null expressions") { - assertEqual("a is null", 'a.isNull) - assertEqual("a is not null", 'a.isNotNull) - assertEqual("a = b is null", ('a === 'b).isNull) - assertEqual("a = b is not null", ('a === 'b).isNotNull) - } - - test("binary arithmetic expressions") { - // Simple operations - assertEqual("a * b", 'a * 'b) - assertEqual("a / b", 'a / 'b) - assertEqual("a DIV b", ('a / 'b).cast(LongType)) - assertEqual("a % b", 'a % 'b) - assertEqual("a + b", 'a + 'b) - assertEqual("a - b", 'a - 'b) - assertEqual("a & b", 'a & 'b) - assertEqual("a ^ b", 'a ^ 'b) - assertEqual("a | b", 'a | 'b) - - // Check precedences - assertEqual( - "a * t | b ^ c & d - e + f % g DIV h / i * k", - 'a * 't | ('b ^ ('c & ('d - 'e + (('f % 'g / 'h).cast(LongType) / 'i * 'k))))) - } - - test("unary arithmetic expressions") { - assertEqual("+a", 'a) - assertEqual("-a", -'a) - assertEqual("~a", ~'a) - assertEqual("-+~~a", -(~(~'a))) - } - - test("cast expressions") { - // Note that DataType parsing is tested elsewhere. - assertEqual("cast(a as int)", 'a.cast(IntegerType)) - assertEqual("cast(a as timestamp)", 'a.cast(TimestampType)) - assertEqual("cast(a as array)", 'a.cast(ArrayType(IntegerType))) - assertEqual("cast(cast(a as int) as long)", 'a.cast(IntegerType).cast(LongType)) - } - - test("function expressions") { - assertEqual("foo()", 'foo.function()) - assertEqual("foo.bar()", Symbol("foo.bar").function()) - assertEqual("foo(*)", 'foo.function(star())) - assertEqual("count(*)", 'count.function(1)) - assertEqual("foo(a, b)", 'foo.function('a, 'b)) - assertEqual("foo(all a, b)", 'foo.function('a, 'b)) - assertEqual("foo(distinct a, b)", 'foo.distinctFunction('a, 'b)) - assertEqual("grouping(distinct a, b)", 'grouping.distinctFunction('a, 'b)) - assertEqual("`select`(all a, b)", 'select.function('a, 'b)) - } - - test("window function expressions") { - val func = 'foo.function(star()) - def windowed( - partitioning: Seq[Expression] = Seq.empty, - ordering: Seq[SortOrder] = Seq.empty, - frame: WindowFrame = UnspecifiedFrame): Expression = { - WindowExpression(func, WindowSpecDefinition(partitioning, ordering, frame)) - } - - // Basic window testing. - assertEqual("foo(*) over w1", UnresolvedWindowExpression(func, WindowSpecReference("w1"))) - assertEqual("foo(*) over ()", windowed()) - assertEqual("foo(*) over (partition by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (distribute by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (cluster by a, b)", windowed(Seq('a, 'b))) - assertEqual("foo(*) over (order by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (sort by a desc, b asc)", windowed(Seq.empty, Seq('a.desc, 'b.asc ))) - assertEqual("foo(*) over (partition by a, b order by c)", windowed(Seq('a, 'b), Seq('c.asc))) - assertEqual("foo(*) over (distribute by a, b sort by c)", windowed(Seq('a, 'b), Seq('c.asc))) - - // Test use of expressions in window functions. 
- assertEqual( - "sum(product + 1) over (partition by ((product) + (1)) order by 2)", - WindowExpression('sum.function('product + 1), - WindowSpecDefinition(Seq('product + 1), Seq(Literal(2).asc), UnspecifiedFrame))) - assertEqual( - "sum(product + 1) over (partition by ((product / 2) + 1) order by 2)", - WindowExpression('sum.function('product + 1), - WindowSpecDefinition(Seq('product / 2 + 1), Seq(Literal(2).asc), UnspecifiedFrame))) - - // Range/Row - val frameTypes = Seq(("rows", RowFrame), ("range", RangeFrame)) - val boundaries = Seq( - ("10 preceding", ValuePreceding(10), CurrentRow), - ("3 + 1 following", ValueFollowing(4), CurrentRow), // Will fail during analysis - ("unbounded preceding", UnboundedPreceding, CurrentRow), - ("unbounded following", UnboundedFollowing, CurrentRow), // Will fail during analysis - ("between unbounded preceding and current row", UnboundedPreceding, CurrentRow), - ("between unbounded preceding and unbounded following", - UnboundedPreceding, UnboundedFollowing), - ("between 10 preceding and current row", ValuePreceding(10), CurrentRow), - ("between current row and 5 following", CurrentRow, ValueFollowing(5)), - ("between 10 preceding and 5 following", ValuePreceding(10), ValueFollowing(5)) - ) - frameTypes.foreach { - case (frameTypeSql, frameType) => - boundaries.foreach { - case (boundarySql, begin, end) => - val query = s"foo(*) over (partition by a order by b $frameTypeSql $boundarySql)" - val expr = windowed(Seq('a), Seq('b.asc), SpecifiedWindowFrame(frameType, begin, end)) - assertEqual(query, expr) - } - } - - // We cannot use non integer constants. - intercept("foo(*) over (partition by a order by b rows 10.0 preceding)", - "Frame bound value must be a constant integer.") - - // We cannot use an arbitrary expression. - intercept("foo(*) over (partition by a order by b rows exp(b) preceding)", - "Frame bound value must be a constant integer.") - } - - test("row constructor") { - // Note that '(a)' will be interpreted as a nested expression. - assertEqual("(a, b)", CreateStruct(Seq('a, 'b))) - assertEqual("(a, b, c)", CreateStruct(Seq('a, 'b, 'c))) - } - - test("scalar sub-query") { - assertEqual( - "(select max(val) from tbl) > current", - ScalarSubquery(table("tbl").select('max.function('val))) > 'current) - assertEqual( - "a = (select b from s)", - 'a === ScalarSubquery(table("s").select('b))) - } - - test("case when") { - assertEqual("case a when 1 then b when 2 then c else d end", - CaseKeyWhen('a, Seq(1, 'b, 2, 'c, 'd))) - assertEqual("case when a = 1 then b when a = 2 then c else d end", - CaseWhen(Seq(('a === 1, 'b.expr), ('a === 2, 'c.expr)), 'd)) - } - - test("dereference") { - assertEqual("a.b", UnresolvedAttribute("a.b")) - assertEqual("`select`.b", UnresolvedAttribute("select.b")) - assertEqual("(a + b).b", ('a + 'b).getField("b")) // This will fail analysis. - assertEqual("struct(a, b).b", 'struct.function('a, 'b).getField("b")) - } - - test("reference") { - // Regular - assertEqual("a", 'a) - - // Starting with a digit. - assertEqual("1a", Symbol("1a")) - - // Quoted using a keyword. - assertEqual("`select`", 'select) - - // Unquoted using an unreserved keyword. - assertEqual("columns", 'columns) - } - - test("subscript") { - assertEqual("a[b]", 'a.getItem('b)) - assertEqual("a[1 + 1]", 'a.getItem(Literal(1) + 1)) - assertEqual("`c`.a[b]", UnresolvedAttribute("c.a").getItem('b)) - } - - test("parenthesis") { - assertEqual("(a)", 'a) - assertEqual("r * (a + b)", 'r * ('a + 'b)) - } - - test("type constructors") { - // Dates. 
- assertEqual("dAte '2016-03-11'", Literal(Date.valueOf("2016-03-11"))) - intercept[IllegalArgumentException] { - parseExpression("DAtE 'mar 11 2016'") - } - - // Timestamps. - assertEqual("tImEstAmp '2016-03-11 20:54:00.000'", - Literal(Timestamp.valueOf("2016-03-11 20:54:00.000"))) - intercept[IllegalArgumentException] { - parseExpression("timestamP '2016-33-11 20:54:00.000'") - } - - // Unsupported datatype. - intercept("GEO '(10,-6)'", "Literals of type 'GEO' are currently not supported.") - } - - test("literals") { - // NULL - assertEqual("null", Literal(null)) - - // Boolean - assertEqual("trUe", Literal(true)) - assertEqual("False", Literal(false)) - - // Integral should have the narrowest possible type - assertEqual("787324", Literal(787324)) - assertEqual("7873247234798249234", Literal(7873247234798249234L)) - assertEqual("78732472347982492793712334", - Literal(BigDecimal("78732472347982492793712334").underlying())) - - // Decimal - assertEqual("7873247234798249279371.2334", - Literal(BigDecimal("7873247234798249279371.2334").underlying())) - - // Scientific Decimal - assertEqual("9.0e1", 90d) - assertEqual(".9e+2", 90d) - assertEqual("0.9e+2", 90d) - assertEqual("900e-1", 90d) - assertEqual("900.0E-1", 90d) - assertEqual("9.e+1", 90d) - intercept(".e3") - - // Tiny Int Literal - assertEqual("10Y", Literal(10.toByte)) - intercept("-1000Y") - - // Small Int Literal - assertEqual("10S", Literal(10.toShort)) - intercept("40000S") - - // Long Int Literal - assertEqual("10L", Literal(10L)) - intercept("78732472347982492793712334L") - - // Double Literal - assertEqual("10.0D", Literal(10.0D)) - // TODO we need to figure out if we should throw an exception here! - assertEqual("1E309", Literal(Double.PositiveInfinity)) - } - - test("strings") { - // Single Strings. - assertEqual("\"hello\"", "hello") - assertEqual("'hello'", "hello") - - // Multi-Strings. - assertEqual("\"hello\" 'world'", "helloworld") - assertEqual("'hello' \" \" 'world'", "hello world") - - // 'LIKE' string literals. Notice that an escaped '%' is the same as an escaped '\' and a - // regular '%'; to get the correct result you need to add another escaped '\'. - // TODO figure out if we shouldn't change the ParseUtils.unescapeSQLString method? - assertEqual("'pattern%'", "pattern%") - assertEqual("'no-pattern\\%'", "no-pattern\\%") - assertEqual("'pattern\\\\%'", "pattern\\%") - assertEqual("'pattern\\\\\\%'", "pattern\\\\%") - - // Escaped characters. - // See: http://dev.mysql.com/doc/refman/5.7/en/string-literals.html - assertEqual("'\\0'", "\u0000") // ASCII NUL (X'00') - assertEqual("'\\''", "\'") // Single quote - assertEqual("'\\\"'", "\"") // Double quote - assertEqual("'\\b'", "\b") // Backspace - assertEqual("'\\n'", "\n") // Newline - assertEqual("'\\r'", "\r") // Carriage return - assertEqual("'\\t'", "\t") // Tab character - assertEqual("'\\Z'", "\u001A") // ASCII 26 - CTRL + Z (EOF on windows) - - // Octals - assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") - - // Unicode - assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") - } - - test("intervals") { - def intervalLiteral(u: String, s: String): Literal = { - Literal(CalendarInterval.fromSingleUnitString(u, s)) - } - - // Empty interval statement - intercept("interval", "at least one time unit should be given for interval literal") - - // Single Intervals. 
- val units = Seq( - "year", - "month", - "week", - "day", - "hour", - "minute", - "second", - "millisecond", - "microsecond") - val forms = Seq("", "s") - val values = Seq("0", "10", "-7", "21") - units.foreach { unit => - forms.foreach { form => - values.foreach { value => - val expected = intervalLiteral(unit, value) - assertEqual(s"interval $value $unit$form", expected) - assertEqual(s"interval '$value' $unit$form", expected) - } - } - } - - // Hive nanosecond notation. - assertEqual("interval 13.123456789 seconds", intervalLiteral("second", "13.123456789")) - assertEqual("interval -13.123456789 second", intervalLiteral("second", "-13.123456789")) - - // Non Existing unit - intercept("interval 10 nanoseconds", "No interval can be constructed") - - // Year-Month intervals. - val yearMonthValues = Seq("123-10", "496-0", "-2-3", "-123-0") - yearMonthValues.foreach { value => - val result = Literal(CalendarInterval.fromYearMonthString(value)) - assertEqual(s"interval '$value' year to month", result) - } - - // Day-Time intervals. - val datTimeValues = Seq( - "99 11:22:33.123456789", - "-99 11:22:33.123456789", - "10 9:8:7.123456789", - "1 0:0:0", - "-1 0:0:0", - "1 0:0:1") - datTimeValues.foreach { value => - val result = Literal(CalendarInterval.fromDayTimeString(value)) - assertEqual(s"interval '$value' day to second", result) - } - - // Unknown FROM TO intervals - intercept("interval 10 month to second", "Intervals FROM month TO second are not supported.") - - // Composed intervals. - assertEqual( - "interval 3 months 22 seconds 1 millisecond", - Literal(new CalendarInterval(3, 22001000L))) - assertEqual( - "interval 3 years '-1-10' year to month 3 weeks '1 0:0:2' day to second", - Literal(new CalendarInterval(14, - 22 * CalendarInterval.MICROS_PER_DAY + 2 * CalendarInterval.MICROS_PER_SECOND))) - } - - test("composed expressions") { - assertEqual("1 + r.r As q", (Literal(1) + UnresolvedAttribute("r.r")).as("q")) - assertEqual("1 - f('o', o(bar))", Literal(1) - 'f.function("o", 'o.function('bar))) - intercept("1 - f('o', o(bar)) hello * world", "mismatched input '*'") - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala deleted file mode 100644 index 4206d22ca7..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/PlanParserSuite.scala +++ /dev/null @@ -1,429 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.{BooleanType, IntegerType} - -class PlanParserSuite extends PlanTest { - import CatalystSqlParser._ - import org.apache.spark.sql.catalyst.dsl.expressions._ - import org.apache.spark.sql.catalyst.dsl.plans._ - - def assertEqual(sqlCommand: String, plan: LogicalPlan): Unit = { - comparePlans(parsePlan(sqlCommand), plan) - } - - def intercept(sqlCommand: String, messages: String*): Unit = { - val e = intercept[ParseException](parsePlan(sqlCommand)) - messages.foreach { message => - assert(e.message.contains(message)) - } - } - - test("case insensitive") { - val plan = table("a").select(star()) - assertEqual("sELEct * FroM a", plan) - assertEqual("select * fRoM a", plan) - assertEqual("SELECT * FROM a", plan) - } - - test("show functions") { - assertEqual("show functions", ShowFunctions(None, None)) - assertEqual("show functions foo", ShowFunctions(None, Some("foo"))) - assertEqual("show functions foo.bar", ShowFunctions(Some("foo"), Some("bar"))) - assertEqual("show functions 'foo\\\\.*'", ShowFunctions(None, Some("foo\\.*"))) - intercept("show functions foo.bar.baz", "SHOW FUNCTIONS unsupported name") - } - - test("describe function") { - assertEqual("describe function bar", DescribeFunction("bar", isExtended = false)) - assertEqual("describe function extended bar", DescribeFunction("bar", isExtended = true)) - assertEqual("describe function foo.bar", DescribeFunction("foo.bar", isExtended = false)) - assertEqual("describe function extended f.bar", DescribeFunction("f.bar", isExtended = true)) - } - - test("set operations") { - val a = table("a").select(star()) - val b = table("b").select(star()) - - assertEqual("select * from a union select * from b", Distinct(a.union(b))) - assertEqual("select * from a union distinct select * from b", Distinct(a.union(b))) - assertEqual("select * from a union all select * from b", a.union(b)) - assertEqual("select * from a except select * from b", a.except(b)) - intercept("select * from a except all select * from b", "EXCEPT ALL is not supported.") - assertEqual("select * from a except distinct select * from b", a.except(b)) - assertEqual("select * from a intersect select * from b", a.intersect(b)) - intercept("select * from a intersect all select * from b", "INTERSECT ALL is not supported.") - assertEqual("select * from a intersect distinct select * from b", a.intersect(b)) - } - - test("common table expressions") { - def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { - val ctes = namedPlans.map { - case (name, cte) => - name -> SubqueryAlias(name, cte) - }.toMap - With(plan, ctes) - } - assertEqual( - "with cte1 as (select * from a) select * from cte1", - cte(table("cte1").select(star()), "cte1" -> table("a").select(star()))) - assertEqual( - "with cte1 (select 1) select * from cte1", - cte(table("cte1").select(star()), "cte1" -> OneRowRelation.select(1))) - assertEqual( - "with cte1 (select 1), cte2 as (select * from cte1) select * from cte2", - cte(table("cte2").select(star()), - "cte1" -> OneRowRelation.select(1), - "cte2" -> table("cte1").select(star()))) - intercept( - "with cte1 (select 1), cte1 as (select 1 from cte1) select * from cte1", - "Name 'cte1' is used for multiple common table expressions") - } - - test("simple select query") { - 
assertEqual("select 1", OneRowRelation.select(1)) - assertEqual("select a, b", OneRowRelation.select('a, 'b)) - assertEqual("select a, b from db.c", table("db", "c").select('a, 'b)) - assertEqual("select a, b from db.c where x < 1", table("db", "c").where('x < 1).select('a, 'b)) - assertEqual( - "select a, b from db.c having x < 1", - table("db", "c").select('a, 'b).where(('x < 1).cast(BooleanType))) - assertEqual("select distinct a, b from db.c", Distinct(table("db", "c").select('a, 'b))) - assertEqual("select all a, b from db.c", table("db", "c").select('a, 'b)) - } - - test("reverse select query") { - assertEqual("from a", table("a")) - assertEqual("from a select b, c", table("a").select('b, 'c)) - assertEqual( - "from db.a select b, c where d < 1", table("db", "a").where('d < 1).select('b, 'c)) - assertEqual("from a select distinct b, c", Distinct(table("a").select('b, 'c))) - assertEqual( - "from (from a union all from b) c select *", - table("a").union(table("b")).as("c").select(star())) - } - - test("transform query spec") { - val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) - assertEqual("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - assertEqual("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) - assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - - test("multi select query") { - assertEqual( - "from a select * select * where s < 10", - table("a").select(star()).union(table("a").where('s < 10).select(star()))) - intercept( - "from a select * select * from x where a.s < 10", - "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") - assertEqual( - "from a insert into tbl1 select * insert into tbl2 select * where s < 10", - table("a").select(star()).insertInto("tbl1").union( - table("a").where('s < 10).select(star()).insertInto("tbl2"))) - } - - test("query organization") { - // Test all valid combinations of order by/sort by/distribute by/cluster by/limit/windows - val baseSql = "select * from t" - val basePlan = table("t").select(star()) - - val ws = Map("w1" -> WindowSpecDefinition(Seq.empty, Seq.empty, UnspecifiedFrame)) - val limitWindowClauses = Seq( - ("", (p: LogicalPlan) => p), - (" limit 10", (p: LogicalPlan) => p.limit(10)), - (" window w1 as ()", (p: LogicalPlan) => WithWindowDefinition(ws, p)), - (" window w1 as () limit 10", (p: LogicalPlan) => WithWindowDefinition(ws, p).limit(10)) - ) - - val orderSortDistrClusterClauses = Seq( - ("", basePlan), - (" order by a, b desc", basePlan.orderBy('a.asc, 'b.desc)), - (" sort by a, b desc", basePlan.sortBy('a.asc, 'b.desc)), - (" distribute by a, b", basePlan.distribute('a, 'b)), - (" distribute by a sort by b", basePlan.distribute('a).sortBy('b.asc)), - (" cluster by a, b", basePlan.distribute('a, 'b).sortBy('a.asc, 'b.asc)) - ) - - orderSortDistrClusterClauses.foreach { - case (s1, p1) => - limitWindowClauses.foreach { - case (s2, pf2) => - assertEqual(baseSql + s1 + s2, pf2(p1)) - } - } - - val msg = "Combination of ORDER BY/SORT BY/DISTRIBUTE BY/CLUSTER BY is not supported" - intercept(s"$baseSql order by a sort by a", msg) - intercept(s"$baseSql cluster by a distribute by a", msg) - intercept(s"$baseSql order by a cluster by a", msg) - intercept(s"$baseSql order by a distribute by a", msg) - } - - test("insert into") { - val sql = "select * from t" 
- val plan = table("t").select(star()) - def insert( - partition: Map[String, Option[String]], - overwrite: Boolean = false, - ifNotExists: Boolean = false): LogicalPlan = - InsertIntoTable(table("s"), partition, plan, overwrite, ifNotExists) - - // Single inserts - assertEqual(s"insert overwrite table s $sql", - insert(Map.empty, overwrite = true)) - assertEqual(s"insert overwrite table s if not exists $sql", - insert(Map.empty, overwrite = true, ifNotExists = true)) - assertEqual(s"insert into s $sql", - insert(Map.empty)) - assertEqual(s"insert into table s partition (c = 'd', e = 1) $sql", - insert(Map("c" -> Option("d"), "e" -> Option("1")))) - assertEqual(s"insert overwrite table s partition (c = 'd', x) if not exists $sql", - insert(Map("c" -> Option("d"), "x" -> None), overwrite = true, ifNotExists = true)) - - // Multi insert - val plan2 = table("t").where('x > 5).select(star()) - assertEqual("from t insert into s select * limit 1 insert into u select * where x > 5", - InsertIntoTable( - table("s"), Map.empty, plan.limit(1), overwrite = false, ifNotExists = false).union( - InsertIntoTable( - table("u"), Map.empty, plan2, overwrite = false, ifNotExists = false))) - } - - test("aggregation") { - val sql = "select a, b, sum(c) as c from d group by a, b" - - // Normal - assertEqual(sql, table("d").groupBy('a, 'b)('a, 'b, 'sum.function('c).as("c"))) - - // Cube - assertEqual(s"$sql with cube", - table("d").groupBy(Cube(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) - - // Rollup - assertEqual(s"$sql with rollup", - table("d").groupBy(Rollup(Seq('a, 'b)))('a, 'b, 'sum.function('c).as("c"))) - - // Grouping Sets - assertEqual(s"$sql grouping sets((a, b), (a), ())", - GroupingSets(Seq(0, 1, 3), Seq('a, 'b), table("d"), Seq('a, 'b, 'sum.function('c).as("c")))) - intercept(s"$sql grouping sets((a, b), (c), ())", - "c doesn't show up in the GROUP BY list") - } - - test("limit") { - val sql = "select * from t" - val plan = table("t").select(star()) - assertEqual(s"$sql limit 10", plan.limit(10)) - assertEqual(s"$sql limit cast(9 / 4 as int)", plan.limit(Cast(Literal(9) / 4, IntegerType))) - } - - test("window spec") { - // Note that WindowSpecs are testing in the ExpressionParserSuite - val sql = "select * from t" - val plan = table("t").select(star()) - val spec = WindowSpecDefinition(Seq('a, 'b), Seq('c.asc), - SpecifiedWindowFrame(RowFrame, ValuePreceding(1), ValueFollowing(1))) - - // Test window resolution. - val ws1 = Map("w1" -> spec, "w2" -> spec, "w3" -> spec) - assertEqual( - s"""$sql - |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), - | w2 as w1, - | w3 as w1""".stripMargin, - WithWindowDefinition(ws1, plan)) - - // Fail with no reference. - intercept(s"$sql window w2 as w1", "Cannot resolve window reference 'w1'") - - // Fail when resolved reference is not a window spec. 
- intercept( - s"""$sql - |window w1 as (partition by a, b order by c rows between 1 preceding and 1 following), - | w2 as w1, - | w3 as w2""".stripMargin, - "Window reference 'w2' is not a window specification" - ) - } - - test("lateral view") { - // Single lateral view - assertEqual( - "select * from t lateral view explode(x) expl as x", - table("t") - .generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) - .select(star())) - - // Multiple lateral views - assertEqual( - """select * - |from t - |lateral view explode(x) expl - |lateral view outer json_tuple(x, y) jtup q, z""".stripMargin, - table("t") - .generate(Explode('x), join = true, outer = false, Some("expl"), Seq.empty) - .generate(JsonTuple(Seq('x, 'y)), join = true, outer = true, Some("jtup"), Seq("q", "z")) - .select(star())) - - // Multi-Insert lateral views. - val from = table("t1").generate(Explode('x), join = true, outer = false, Some("expl"), Seq("x")) - assertEqual( - """from t1 - |lateral view explode(x) expl as x - |insert into t2 - |select * - |lateral view json_tuple(x, y) jtup q, z - |insert into t3 - |select * - |where s < 10 - """.stripMargin, - Union(from - .generate(JsonTuple(Seq('x, 'y)), join = true, outer = false, Some("jtup"), Seq("q", "z")) - .select(star()) - .insertInto("t2"), - from.where('s < 10).select(star()).insertInto("t3"))) - - // Unsupported generator. - intercept( - "select * from t lateral view posexplode(x) posexpl as x, y", - "Generator function 'posexplode' is not supported") - } - - test("joins") { - // Test single joins. - val testUnconditionalJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t as tt $sql u", - table("t").as("tt").join(table("u"), jt, None).select(star())) - } - val testConditionalJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t $sql u as uu on a = b", - table("t").join(table("u").as("uu"), jt, Option('a === 'b)).select(star())) - } - val testNaturalJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t tt natural $sql u as uu", - table("t").as("tt").join(table("u").as("uu"), NaturalJoin(jt), None).select(star())) - } - val testUsingJoin = (sql: String, jt: JoinType) => { - assertEqual( - s"select * from t $sql u using(a, b)", - table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) - } - val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) - - def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { - tests.foreach(_(sql, jt)) - } - test("cross join", Inner, Seq(testUnconditionalJoin)) - test(",", Inner, Seq(testUnconditionalJoin)) - test("join", Inner, testAll) - test("inner join", Inner, testAll) - test("left join", LeftOuter, testAll) - test("left outer join", LeftOuter, testAll) - test("right join", RightOuter, testAll) - test("right outer join", RightOuter, testAll) - test("full join", FullOuter, testAll) - test("full outer join", FullOuter, testAll) - - // Test multiple consecutive joins - assertEqual( - "select * from a join b join c right join d", - table("a").join(table("b")).join(table("c")).join(table("d"), RightOuter).select(star())) - } - - test("sampled relations") { - val sql = "select * from t" - assertEqual(s"$sql tablesample(100 rows)", - table("t").limit(100).select(star())) - assertEqual(s"$sql tablesample(43 percent) as x", - Sample(0, .43d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) - assertEqual(s"$sql tablesample(bucket 4 out of 10) 
as x", - Sample(0, .4d, withReplacement = false, 10L, table("t").as("x"))(true).select(star())) - intercept(s"$sql tablesample(bucket 4 out of 10 on x) as x", - "TABLESAMPLE(BUCKET x OUT OF y ON id) is not supported") - intercept(s"$sql tablesample(bucket 11 out of 10) as x", - s"Sampling fraction (${11.0/10.0}) must be on interval [0, 1]") - } - - test("sub-query") { - val plan = table("t0").select('id) - assertEqual("select id from (t0)", plan) - assertEqual("select id from ((((((t0))))))", plan) - assertEqual( - "(select * from t1) union distinct (select * from t2)", - Distinct(table("t1").select(star()).union(table("t2").select(star())))) - assertEqual( - "select * from ((select * from t1) union (select * from t2)) t", - Distinct( - table("t1").select(star()).union(table("t2").select(star()))).as("t").select(star())) - assertEqual( - """select id - |from (((select id from t0) - | union all - | (select id from t0)) - | union all - | (select id from t0)) as u_1 - """.stripMargin, - plan.union(plan).union(plan).as("u_1").select('id)) - } - - test("scalar sub-query") { - assertEqual( - "select (select max(b) from s) ss from t", - table("t").select(ScalarSubquery(table("s").select('max.function('b))).as("ss"))) - assertEqual( - "select * from t where a = (select b from s)", - table("t").where('a === ScalarSubquery(table("s").select('b))).select(star())) - assertEqual( - "select g from t group by g having a > (select b from s)", - table("t") - .groupBy('g)('g) - .where(('a > ScalarSubquery(table("s").select('b))).cast(BooleanType))) - } - - test("table reference") { - assertEqual("table t", table("t")) - assertEqual("table d.t", table("d", "t")) - } - - test("inline table") { - assertEqual("values 1, 2, 3, 4", LocalRelation.fromExternalRows( - Seq('col1.int), - Seq(1, 2, 3, 4).map(x => Row(x)))) - assertEqual( - "values (1, 'a'), (2, 'b'), (3, 'c') as tbl(a, b)", - LocalRelation.fromExternalRows( - Seq('a.int, 'b.string), - Seq((1, "a"), (2, "b"), (3, "c")).map(x => Row(x._1, x._2))).as("tbl")) - intercept("values (a, 'a'), (b, 'b')", - "All expressions in an inline table must be constants.") - intercept("values (1, 'a'), (2, 'b') as tbl(a, b, c)", - "Number of aliases must match the number of fields in an inline table.") - intercept[ArrayIndexOutOfBoundsException](parsePlan("values (1, 'a'), (2, 'b', 5Y)")) - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala deleted file mode 100644 index 0874322187..0000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ng/TableIdentifierParserSuite.scala +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
- * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.parser.ng - -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.TableIdentifier - -class TableIdentifierParserSuite extends SparkFunSuite { - import CatalystSqlParser._ - - test("table identifier") { - // Regular names. - assert(TableIdentifier("q") === parseTableIdentifier("q")) - assert(TableIdentifier("q", Option("d")) === parseTableIdentifier("d.q")) - - // Illegal names. - intercept[ParseException](parseTableIdentifier("")) - intercept[ParseException](parseTableIdentifier("d.q.g")) - - // SQL Keywords. - val keywords = Seq("select", "from", "where", "left", "right") - keywords.foreach { keyword => - intercept[ParseException](parseTableIdentifier(keyword)) - assert(TableIdentifier(keyword) === parseTableIdentifier(s"`$keyword`")) - assert(TableIdentifier(keyword, Option("db")) === parseTableIdentifier(s"db.`$keyword`")) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala deleted file mode 100644 index 6fe04757ba..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkQl.scala +++ /dev/null @@ -1,387 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution - -import org.apache.spark.sql.{AnalysisException, SaveMode} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.parser._ -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} -import org.apache.spark.sql.execution.command._ -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.types.StructType - -private[sql] class SparkQl(conf: ParserConf = SimpleParserConf()) extends CatalystQl(conf) { - import ParserUtils._ - - /** Check if a command should not be explained. */ - protected def isNoExplainCommand(command: String): Boolean = { - "TOK_DESCTABLE" == command || "TOK_ALTERTABLE" == command - } - - /** - * For each node, extract properties in the form of a list - * ['key_part1', 'key_part2', 'key_part3', 'value'] - * into a pair (key_part1.key_part2.key_part3, value). 
- * - * Example format: - * - * TOK_TABLEPROPERTY - * :- 'k1' - * +- 'v1' - * TOK_TABLEPROPERTY - * :- 'k2' - * +- 'v2' - * TOK_TABLEPROPERTY - * :- 'k3' - * +- 'v3' - */ - private def extractProps( - props: Seq[ASTNode], - expectedNodeText: String): Seq[(String, String)] = { - props.map { - case Token(x, keysAndValue) if x == expectedNodeText => - val key = keysAndValue.init.map { x => unquoteString(x.text) }.mkString(".") - val value = unquoteString(keysAndValue.last.text) - (key, value) - case p => - parseFailed(s"Expected property '$expectedNodeText' in command", p) - } - } - - protected override def nodeToPlan(node: ASTNode): LogicalPlan = { - node match { - case Token("TOK_SETCONFIG", Nil) => - val keyValueSeparatorIndex = node.remainder.indexOf('=') - if (keyValueSeparatorIndex >= 0) { - val key = node.remainder.substring(0, keyValueSeparatorIndex).trim - val value = node.remainder.substring(keyValueSeparatorIndex + 1).trim - SetCommand(Some(key -> Option(value))) - } else if (node.remainder.nonEmpty) { - SetCommand(Some(node.remainder -> None)) - } else { - SetCommand(None) - } - - // Just fake explain for any of the native commands. - case Token("TOK_EXPLAIN", explainArgs) if isNoExplainCommand(explainArgs.head.text) => - ExplainCommand(OneRowRelation) - - case Token("TOK_EXPLAIN", explainArgs) if "TOK_CREATETABLE" == explainArgs.head.text => - val Some(crtTbl) :: _ :: extended :: Nil = - getClauses(Seq("TOK_CREATETABLE", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand(nodeToPlan(crtTbl), extended = extended.isDefined) - - case Token("TOK_EXPLAIN", explainArgs) => - // Ignore FORMATTED if present. - val Some(query) :: _ :: extended :: Nil = - getClauses(Seq("TOK_QUERY", "FORMATTED", "EXTENDED"), explainArgs) - ExplainCommand(nodeToPlan(query), extended = extended.isDefined) - - case Token("TOK_REFRESHTABLE", nameParts :: Nil) => - val tableIdent = extractTableIdent(nameParts) - RefreshTable(tableIdent) - - // CREATE DATABASE [IF NOT EXISTS] database_name [COMMENT database_comment] - // [LOCATION path] [WITH DBPROPERTIES (key1=val1, key2=val2, ...)]; - case Token("TOK_CREATEDATABASE", Token(dbName, Nil) :: args) => - val databaseName = cleanIdentifier(dbName) - val Seq(ifNotExists, dbLocation, databaseComment, dbprops) = getClauses(Seq( - "TOK_IFNOTEXISTS", - "TOK_DATABASELOCATION", - "TOK_DATABASECOMMENT", - "TOK_DATABASEPROPERTIES"), args) - val location = dbLocation.map { - case Token("TOK_DATABASELOCATION", Token(loc, Nil) :: Nil) => unquoteString(loc) - case _ => parseFailed("Invalid CREATE DATABASE command", node) - } - val comment = databaseComment.map { - case Token("TOK_DATABASECOMMENT", Token(com, Nil) :: Nil) => unquoteString(com) - case _ => parseFailed("Invalid CREATE DATABASE command", node) - } - val props = dbprops.toSeq.flatMap { - case Token("TOK_DATABASEPROPERTIES", Token("TOK_DBPROPLIST", propList) :: Nil) => - // Example format: - // - // TOK_DATABASEPROPERTIES - // +- TOK_DBPROPLIST - // :- TOK_TABLEPROPERTY - // : :- 'k1' - // : +- 'v1' - // :- TOK_TABLEPROPERTY - // :- 'k2' - // +- 'v2' - extractProps(propList, "TOK_TABLEPROPERTY") - case _ => parseFailed("Invalid CREATE DATABASE command", node) - }.toMap - CreateDatabase(databaseName, ifNotExists.isDefined, location, comment, props) - - // DROP DATABASE [IF EXISTS] database_name [RESTRICT|CASCADE]; - case Token("TOK_DROPDATABASE", Token(dbName, Nil) :: otherArgs) => - // Example format: - // - // TOK_DROPDATABASE - // :- database_name - // :- TOK_IFEXISTS - // +- TOK_RESTRICT/TOK_CASCADE - val 
databaseName = cleanIdentifier(dbName) - // The default is RESTRICT - val Seq(ifExists, _, cascade) = getClauses(Seq( - "TOK_IFEXISTS", "TOK_RESTRICT", "TOK_CASCADE"), otherArgs) - DropDatabase(databaseName, ifExists.isDefined, cascade.isDefined) - - // ALTER (DATABASE|SCHEMA) database_name SET DBPROPERTIES (property_name=property_value, ...) - case Token("TOK_ALTERDATABASE_PROPERTIES", Token(dbName, Nil) :: args) => - val databaseName = cleanIdentifier(dbName) - val dbprops = getClause("TOK_DATABASEPROPERTIES", args) - val props = dbprops match { - case Token("TOK_DATABASEPROPERTIES", Token("TOK_DBPROPLIST", propList) :: Nil) => - // Example format: - // - // TOK_DATABASEPROPERTIES - // +- TOK_DBPROPLIST - // :- TOK_TABLEPROPERTY - // : :- 'k1' - // : +- 'v1' - // :- TOK_TABLEPROPERTY - // :- 'k2' - // +- 'v2' - extractProps(propList, "TOK_TABLEPROPERTY") - case _ => parseFailed("Invalid ALTER DATABASE command", node) - } - AlterDatabaseProperties(databaseName, props.toMap) - - // DESCRIBE DATABASE [EXTENDED] db_name - case Token("TOK_DESCDATABASE", Token(dbName, Nil) :: describeArgs) => - val databaseName = cleanIdentifier(dbName) - val extended = getClauseOption("EXTENDED", describeArgs) - DescribeDatabase(databaseName, extended.isDefined) - - // CREATE [TEMPORARY] FUNCTION [db_name.]function_name AS class_name - // [USING JAR|FILE|ARCHIVE 'file_uri' [, JAR|FILE|ARCHIVE 'file_uri'] ]; - case Token("TOK_CREATEFUNCTION", args) => - // Example format: - // - // TOK_CREATEFUNCTION - // :- db_name - // :- func_name - // :- alias - // +- TOK_RESOURCE_LIST - // :- TOK_RESOURCE_URI - // : :- TOK_JAR - // : +- '/path/to/jar' - // +- TOK_RESOURCE_URI - // :- TOK_FILE - // +- 'path/to/file' - val (funcNameArgs, otherArgs) = args.partition { - case Token("TOK_RESOURCE_LIST", _) => false - case Token("TOK_TEMPORARY", _) => false - case Token(_, Nil) => true - case _ => parseFailed("Invalid CREATE FUNCTION command", node) - } - // If database name is specified, there are 3 tokens, otherwise 2. 
- val (dbName, funcName, alias) = funcNameArgs match { - case Token(dbName, Nil) :: Token(fname, Nil) :: Token(aname, Nil) :: Nil => - (Some(unquoteString(dbName)), unquoteString(fname), unquoteString(aname)) - case Token(fname, Nil) :: Token(aname, Nil) :: Nil => - (None, unquoteString(fname), unquoteString(aname)) - case _ => - parseFailed("Invalid CREATE FUNCTION command", node) - } - // Extract other keywords, if they exist - val Seq(rList, temp) = getClauses(Seq("TOK_RESOURCE_LIST", "TOK_TEMPORARY"), otherArgs) - val resources: Seq[(String, String)] = rList.toSeq.flatMap { - case Token("TOK_RESOURCE_LIST", resList) => - resList.map { - case Token("TOK_RESOURCE_URI", rType :: Token(rPath, Nil) :: Nil) => - val resourceType = rType match { - case Token("TOK_JAR", Nil) => "jar" - case Token("TOK_FILE", Nil) => "file" - case Token("TOK_ARCHIVE", Nil) => "archive" - case Token(f, _) => parseFailed(s"Unexpected resource format '$f'", node) - } - (resourceType, unquoteString(rPath)) - case _ => parseFailed("Invalid CREATE FUNCTION command", node) - } - case _ => parseFailed("Invalid CREATE FUNCTION command", node) - } - CreateFunction(dbName, funcName, alias, resources, temp.isDefined)(node.source) - - // DROP [TEMPORARY] FUNCTION [IF EXISTS] function_name; - case Token("TOK_DROPFUNCTION", args) => - // Example format: - // - // TOK_DROPFUNCTION - // :- db_name - // :- func_name - // :- TOK_IFEXISTS - // +- TOK_TEMPORARY - val (funcNameArgs, otherArgs) = args.partition { - case Token("TOK_IFEXISTS", _) => false - case Token("TOK_TEMPORARY", _) => false - case Token(_, Nil) => true - case _ => parseFailed("Invalid DROP FUNCTION command", node) - } - // If database name is specified, there are 2 tokens, otherwise 1. - val (dbName, funcName) = funcNameArgs match { - case Token(dbName, Nil) :: Token(fname, Nil) :: Nil => - (Some(unquoteString(dbName)), unquoteString(fname)) - case Token(fname, Nil) :: Nil => - (None, unquoteString(fname)) - case _ => - parseFailed("Invalid DROP FUNCTION command", node) - } - - val Seq(ifExists, temp) = getClauses(Seq( - "TOK_IFEXISTS", "TOK_TEMPORARY"), otherArgs) - - DropFunction(dbName, funcName, ifExists.isDefined, temp.isDefined)(node.source) - - case Token("TOK_ALTERTABLE", alterTableArgs) => - AlterTableCommandParser.parse(node) - - case Token("TOK_CREATETABLEUSING", createTableArgs) => - val Seq( - temp, - ifNotExists, - Some(tabName), - tableCols, - Some(Token("TOK_TABLEPROVIDER", providerNameParts)), - tableOpts, - tableAs) = getClauses(Seq( - "TEMPORARY", - "TOK_IFNOTEXISTS", - "TOK_TABNAME", "TOK_TABCOLLIST", - "TOK_TABLEPROVIDER", - "TOK_TABLEOPTIONS", - "TOK_QUERY"), createTableArgs) - val tableIdent: TableIdentifier = extractTableIdent(tabName) - val columns = tableCols.map { - case Token("TOK_TABCOLLIST", fields) => StructType(fields.map(nodeToStructField)) - case _ => parseFailed("Invalid CREATE TABLE command", node) - } - val provider = providerNameParts.map { - case Token(name, Nil) => name - case _ => parseFailed("Invalid CREATE TABLE command", node) - }.mkString(".") - val options = tableOpts.toSeq.flatMap { - case Token("TOK_TABLEOPTIONS", opts) => extractProps(opts, "TOK_TABLEOPTION") - case _ => parseFailed("Invalid CREATE TABLE command", node) - }.toMap - val asClause = tableAs.map(nodeToPlan) - - if (temp.isDefined && ifNotExists.isDefined) { - throw new AnalysisException( - "a CREATE TEMPORARY TABLE statement does not allow IF NOT EXISTS clause.") - } - - if (asClause.isDefined) { - if (columns.isDefined) { - throw new 
AnalysisException( - "a CREATE TABLE AS SELECT statement does not allow column definitions.") - } - - val mode = if (ifNotExists.isDefined) { - SaveMode.Ignore - } else if (temp.isDefined) { - SaveMode.Overwrite - } else { - SaveMode.ErrorIfExists - } - - CreateTableUsingAsSelect(tableIdent, - provider, - temp.isDefined, - Array.empty[String], - bucketSpec = None, - mode, - options, - asClause.get) - } else { - CreateTableUsing( - tableIdent, - columns, - provider, - temp.isDefined, - options, - ifNotExists.isDefined, - managedIfNoPath = false) - } - - case Token("TOK_SWITCHDATABASE", Token(database, Nil) :: Nil) => - SetDatabaseCommand(cleanIdentifier(database)) - - case Token("TOK_DESCTABLE", describeArgs) => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val Some(tableType) :: formatted :: extended :: pretty :: Nil = - getClauses(Seq("TOK_TABTYPE", "FORMATTED", "EXTENDED", "PRETTY"), describeArgs) - if (formatted.isDefined || pretty.isDefined) { - // FORMATTED and PRETTY are not supported and this statement will be treated as - // a Hive native command. - nodeToDescribeFallback(node) - } else { - tableType match { - case Token("TOK_TABTYPE", Token("TOK_TABNAME", nameParts) :: Nil) => - nameParts match { - case Token(dbName, Nil) :: Token(tableName, Nil) :: Nil => - // It is describing a table with the format like "describe db.table". - // TODO: Actually, a user may mean tableName.columnName. Need to resolve this - // issue. - val tableIdent = TableIdentifier( - cleanIdentifier(tableName), Some(cleanIdentifier(dbName))) - datasources.DescribeCommand(tableIdent, isExtended = extended.isDefined) - case Token(dbName, Nil) :: Token(tableName, Nil) :: Token(colName, Nil) :: Nil => - // It is describing a column with the format like "describe db.table column". - nodeToDescribeFallback(node) - case tableName :: Nil => - // It is describing a table with the format like "describe table". - datasources.DescribeCommand( - TableIdentifier(cleanIdentifier(tableName.text)), - isExtended = extended.isDefined) - case _ => - nodeToDescribeFallback(node) - } - // All other cases. 
- case _ => - nodeToDescribeFallback(node) - } - } - - case Token("TOK_CACHETABLE", Token(tableName, Nil) :: args) => - val Seq(lzy, selectAst) = getClauses(Seq("LAZY", "TOK_QUERY"), args) - CacheTableCommand(tableName, selectAst.map(nodeToPlan), lzy.isDefined) - - case Token("TOK_UNCACHETABLE", Token(tableName, Nil) :: Nil) => - UncacheTableCommand(tableName) - - case Token("TOK_CLEARCACHE", Nil) => - ClearCacheCommand - - case Token("TOK_SHOWTABLES", args) => - val databaseName = args match { - case Nil => None - case Token("TOK_FROM", Token(dbName, Nil) :: Nil) :: Nil => Option(dbName) - case _ => noParseRule("SHOW TABLES", node) - } - ShowTablesCommand(databaseName) - - case _ => - super.nodeToPlan(node) - } - } - - protected def nodeToDescribeFallback(node: ASTNode): LogicalPlan = noParseRule("Describe", node) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8333074eca..b4687c985d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -20,8 +20,8 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.parser.ng.{AbstractSqlParser, AstBuilder, ParseException} -import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.parser.{AbstractSqlParser, AstBuilder, ParseException} +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, OneRowRelation} import org.apache.spark.sql.execution.command.{DescribeCommand => _, _} import org.apache.spark.sql.execution.datasources._ @@ -37,7 +37,7 @@ object SparkSqlParser extends AbstractSqlParser{ * Builder that converts an ANTLR ParseTree into a LogicalPlan/Expression/TableIdentifier. */ class SparkSqlAstBuilder extends AstBuilder { - import org.apache.spark.sql.catalyst.parser.ng.ParserUtils._ + import org.apache.spark.sql.catalyst.parser.ParserUtils._ /** * Create a [[SetCommand]] logical plan. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala deleted file mode 100644 index 9fbe6db467..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/AlterTableCommandParser.scala +++ /dev/null @@ -1,431 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.command - -import scala.collection.mutable.ArrayBuffer - -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec -import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, SortDirection} -import org.apache.spark.sql.catalyst.parser._ -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.types.StructType - - -/** - * Helper object to parse alter table commands. - */ -object AlterTableCommandParser { - import ParserUtils._ - - /** - * Parse the given node assuming it is an alter table command. - */ - def parse(node: ASTNode): LogicalPlan = { - node.children match { - case (tabName @ Token("TOK_TABNAME", _)) :: otherNodes => - val tableIdent = extractTableIdent(tabName) - val partSpec = getClauseOption("TOK_PARTSPEC", node.children).map(parsePartitionSpec) - matchAlterTableCommands(node, otherNodes, tableIdent, partSpec) - case _ => - parseFailed("Could not parse ALTER TABLE command", node) - } - } - - private def cleanAndUnquoteString(s: String): String = { - cleanIdentifier(unquoteString(s)) - } - - /** - * Extract partition spec from the given [[ASTNode]] as a map, assuming it exists. - * - * Example format: - * - * TOK_PARTSPEC - * :- TOK_PARTVAL - * : :- dt - * : +- '2008-08-08' - * +- TOK_PARTVAL - * :- country - * +- 'us' - */ - private def parsePartitionSpec(node: ASTNode): Map[String, String] = { - node match { - case Token("TOK_PARTSPEC", partitions) => - partitions.map { - // Note: sometimes there's a "=", "<" or ">" between the key and the value - // (e.g. when dropping all partitions with value > than a certain constant) - case Token("TOK_PARTVAL", ident :: conj :: constant :: Nil) => - (cleanAndUnquoteString(ident.text), cleanAndUnquoteString(constant.text)) - case Token("TOK_PARTVAL", ident :: constant :: Nil) => - (cleanAndUnquoteString(ident.text), cleanAndUnquoteString(constant.text)) - case Token("TOK_PARTVAL", ident :: Nil) => - (cleanAndUnquoteString(ident.text), null) - case _ => - parseFailed("Invalid ALTER TABLE command", node) - }.toMap - case _ => - parseFailed("Expected partition spec in ALTER TABLE command", node) - } - } - - /** - * Extract table properties from the given [[ASTNode]] as a map, assuming it exists. - * - * Example format: - * - * TOK_TABLEPROPERTIES - * +- TOK_TABLEPROPLIST - * :- TOK_TABLEPROPERTY - * : :- 'test' - * : +- 'value' - * +- TOK_TABLEPROPERTY - * :- 'comment' - * +- 'new_comment' - */ - private def extractTableProps(node: ASTNode): Map[String, String] = { - node match { - case Token("TOK_TABLEPROPERTIES", propsList) => - propsList.flatMap { - case Token("TOK_TABLEPROPLIST", props) => - props.map { case Token("TOK_TABLEPROPERTY", key :: value :: Nil) => - val k = cleanAndUnquoteString(key.text) - val v = value match { - case Token("TOK_NULL", Nil) => null - case _ => cleanAndUnquoteString(value.text) - } - (k, v) - } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - }.toMap - case _ => - parseFailed("Expected table properties in ALTER TABLE command", node) - } - } - - /** - * Parse an alter table command from a [[ASTNode]] into a [[LogicalPlan]]. - * This follows https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL. - * - * @param node the original [[ASTNode]] to parse. - * @param otherNodes the other [[ASTNode]]s after the first one containing the table name. 
- * @param tableIdent identifier of the table, parsed from the first [[ASTNode]]. - * @param partition spec identifying the partition this command is concerned with, if any. - */ - // TODO: This method is massive. Break it down. - private def matchAlterTableCommands( - node: ASTNode, - otherNodes: Seq[ASTNode], - tableIdent: TableIdentifier, - partition: Option[TablePartitionSpec]): LogicalPlan = { - otherNodes match { - // ALTER TABLE table_name RENAME TO new_table_name; - case Token("TOK_ALTERTABLE_RENAME", renameArgs) :: _ => - val tableNameClause = getClause("TOK_TABNAME", renameArgs) - val newTableIdent = extractTableIdent(tableNameClause) - AlterTableRename(tableIdent, newTableIdent)(node.source) - - // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); - case Token("TOK_ALTERTABLE_PROPERTIES", args) :: _ => - val properties = extractTableProps(args.head) - AlterTableSetProperties(tableIdent, properties)(node.source) - - // ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); - case Token("TOK_ALTERTABLE_DROPPROPERTIES", args) :: _ => - val properties = extractTableProps(args.head) - val ifExists = getClauseOption("TOK_IFEXISTS", args).isDefined - AlterTableUnsetProperties(tableIdent, properties, ifExists)(node.source) - - // ALTER TABLE table_name [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props]; - case Token("TOK_ALTERTABLE_SERIALIZER", Token(serdeClassName, Nil) :: serdeArgs) :: _ => - AlterTableSerDeProperties( - tableIdent, - Some(cleanAndUnquoteString(serdeClassName)), - serdeArgs.headOption.map(extractTableProps), - partition)(node.source) - - // ALTER TABLE table_name [PARTITION spec] SET SERDEPROPERTIES serde_properties; - case Token("TOK_ALTERTABLE_SERDEPROPERTIES", args) :: _ => - AlterTableSerDeProperties( - tableIdent, - None, - Some(extractTableProps(args.head)), - partition)(node.source) - - // ALTER TABLE table_name CLUSTERED BY (col, ...) [SORTED BY (col, ...)] INTO n BUCKETS; - case Token("TOK_ALTERTABLE_CLUSTER_SORT", Token("TOK_ALTERTABLE_BUCKETS", b) :: Nil) :: _ => - val clusterCols: Seq[String] = b.head match { - case Token("TOK_TABCOLNAME", children) => children.map(_.text) - case _ => parseFailed("Invalid ALTER TABLE command", node) - } - // If sort columns are specified, num buckets should be the third arg. - // If sort columns are not specified, num buckets should be the second arg. 
- // TODO: actually use `sortDirections` once we actually store that in the metastore - val (sortCols: Seq[String], sortDirections: Seq[SortDirection], numBuckets: Int) = { - b.tail match { - case Token("TOK_TABCOLNAME", children) :: numBucketsNode :: Nil => - val (cols, directions) = children.map { - case Token("TOK_TABSORTCOLNAMEASC", Token(col, Nil) :: Nil) => (col, Ascending) - case Token("TOK_TABSORTCOLNAMEDESC", Token(col, Nil) :: Nil) => (col, Descending) - }.unzip - (cols, directions, numBucketsNode.text.toInt) - case numBucketsNode :: Nil => - (Nil, Nil, numBucketsNode.text.toInt) - case _ => - parseFailed("Invalid ALTER TABLE command", node) - } - } - AlterTableStorageProperties( - tableIdent, - BucketSpec(numBuckets, clusterCols, sortCols))(node.source) - - // ALTER TABLE table_name NOT CLUSTERED - case Token("TOK_ALTERTABLE_CLUSTER_SORT", Token("TOK_NOT_CLUSTERED", Nil) :: Nil) :: _ => - AlterTableNotClustered(tableIdent)(node.source) - - // ALTER TABLE table_name NOT SORTED - case Token("TOK_ALTERTABLE_CLUSTER_SORT", Token("TOK_NOT_SORTED", Nil) :: Nil) :: _ => - AlterTableNotSorted(tableIdent)(node.source) - - // ALTER TABLE table_name SKEWED BY (col1, col2) - // ON ((col1_value, col2_value) [, (col1_value, col2_value), ...]) - // [STORED AS DIRECTORIES]; - case Token("TOK_ALTERTABLE_SKEWED", - Token("TOK_TABLESKEWED", - Token("TOK_TABCOLNAME", colNames) :: colValues :: rest) :: Nil) :: _ => - // Example format: - // - // TOK_ALTERTABLE_SKEWED - // :- TOK_TABLESKEWED - // : :- TOK_TABCOLNAME - // : : :- dt - // : : +- country - // :- TOK_TABCOLVALUE_PAIR - // : :- TOK_TABCOLVALUES - // : : :- TOK_TABCOLVALUE - // : : : :- '2008-08-08' - // : : : +- 'us' - // : :- TOK_TABCOLVALUES - // : : :- TOK_TABCOLVALUE - // : : : :- '2009-09-09' - // : : : +- 'uk' - // +- TOK_STOREASDIR - val names = colNames.map { n => cleanAndUnquoteString(n.text) } - val values = colValues match { - case Token("TOK_TABCOLVALUE", vals) => - Seq(vals.map { n => cleanAndUnquoteString(n.text) }) - case Token("TOK_TABCOLVALUE_PAIR", pairs) => - pairs.map { - case Token("TOK_TABCOLVALUES", Token("TOK_TABCOLVALUE", vals) :: Nil) => - vals.map { n => cleanAndUnquoteString(n.text) } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - } - val storedAsDirs = rest match { - case Token("TOK_STOREDASDIRS", Nil) :: Nil => true - case _ => false - } - AlterTableSkewed( - tableIdent, - names, - values, - storedAsDirs)(node.source) - - // ALTER TABLE table_name NOT SKEWED - case Token("TOK_ALTERTABLE_SKEWED", Nil) :: _ => - AlterTableNotSkewed(tableIdent)(node.source) - - // ALTER TABLE table_name NOT STORED AS DIRECTORIES - case Token("TOK_ALTERTABLE_SKEWED", Token("TOK_STOREDASDIRS", Nil) :: Nil) :: _ => - AlterTableNotStoredAsDirs(tableIdent)(node.source) - - // ALTER TABLE table_name SET SKEWED LOCATION (col1="loc1" [, (col2, col3)="loc2", ...] 
); - case Token("TOK_ALTERTABLE_SKEWED_LOCATION", - Token("TOK_SKEWED_LOCATIONS", - Token("TOK_SKEWED_LOCATION_LIST", locationMaps) :: Nil) :: Nil) :: _ => - // Example format: - // - // TOK_ALTERTABLE_SKEWED_LOCATION - // +- TOK_SKEWED_LOCATIONS - // +- TOK_SKEWED_LOCATION_LIST - // :- TOK_SKEWED_LOCATION_MAP - // : :- 'col1' - // : +- 'loc1' - // +- TOK_SKEWED_LOCATION_MAP - // :- TOK_TABCOLVALUES - // : +- TOK_TABCOLVALUE - // : :- 'col2' - // : +- 'col3' - // +- 'loc2' - val skewedMaps = locationMaps.flatMap { - case Token("TOK_SKEWED_LOCATION_MAP", col :: loc :: Nil) => - col match { - case Token(const, Nil) => - Seq((cleanAndUnquoteString(const), cleanAndUnquoteString(loc.text))) - case Token("TOK_TABCOLVALUES", Token("TOK_TABCOLVALUE", keys) :: Nil) => - keys.map { k => (cleanAndUnquoteString(k.text), cleanAndUnquoteString(loc.text)) } - } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - }.toMap - AlterTableSkewedLocation(tableIdent, skewedMaps)(node.source) - - // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] - // spec [LOCATION 'loc2'] ...; - case Token("TOK_ALTERTABLE_ADDPARTS", args) :: _ => - val (ifNotExists, parts) = args.head match { - case Token("TOK_IFNOTEXISTS", Nil) => (true, args.tail) - case _ => (false, args) - } - // List of (spec, location) to describe partitions to add - // Each partition spec may or may not be followed by a location - val parsedParts = new ArrayBuffer[(TablePartitionSpec, Option[String])] - parts.foreach { - case t @ Token("TOK_PARTSPEC", _) => - parsedParts += ((parsePartitionSpec(t), None)) - case Token("TOK_PARTITIONLOCATION", loc :: Nil) => - // Update the location of the last partition we just added - if (parsedParts.nonEmpty) { - val (spec, _) = parsedParts.remove(parsedParts.length - 1) - parsedParts += ((spec, Some(unquoteString(loc.text)))) - } - case _ => - parseFailed("Invalid ALTER TABLE command", node) - } - AlterTableAddPartition(tableIdent, parsedParts, ifNotExists)(node.source) - - // ALTER TABLE table_name PARTITION spec1 RENAME TO PARTITION spec2; - case Token("TOK_ALTERTABLE_RENAMEPART", spec :: Nil) :: _ => - val newPartition = parsePartitionSpec(spec) - val oldPartition = partition.getOrElse { - parseFailed("Expected old partition spec in ALTER TABLE rename partition command", node) - } - AlterTableRenamePartition(tableIdent, oldPartition, newPartition)(node.source) - - // ALTER TABLE table_name_1 EXCHANGE PARTITION spec WITH TABLE table_name_2; - case Token("TOK_ALTERTABLE_EXCHANGEPARTITION", spec :: newTable :: Nil) :: _ => - val parsedSpec = parsePartitionSpec(spec) - val newTableIdent = extractTableIdent(newTable) - AlterTableExchangePartition(tableIdent, newTableIdent, parsedSpec)(node.source) - - // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] 
[PURGE]; - case Token("TOK_ALTERTABLE_DROPPARTS", args) :: _ => - val parts = args.collect { case p @ Token("TOK_PARTSPEC", _) => parsePartitionSpec(p) } - val ifExists = getClauseOption("TOK_IFEXISTS", args).isDefined - val purge = getClauseOption("PURGE", args).isDefined - AlterTableDropPartition(tableIdent, parts, ifExists, purge)(node.source) - - // ALTER TABLE table_name ARCHIVE PARTITION spec; - case Token("TOK_ALTERTABLE_ARCHIVE", spec :: Nil) :: _ => - AlterTableArchivePartition(tableIdent, parsePartitionSpec(spec))(node.source) - - // ALTER TABLE table_name UNARCHIVE PARTITION spec; - case Token("TOK_ALTERTABLE_UNARCHIVE", spec :: Nil) :: _ => - AlterTableUnarchivePartition(tableIdent, parsePartitionSpec(spec))(node.source) - - // ALTER TABLE table_name [PARTITION spec] SET FILEFORMAT file_format; - case Token("TOK_ALTERTABLE_FILEFORMAT", args) :: _ => - val Seq(fileFormat, genericFormat) = - getClauses(Seq("TOK_TABLEFILEFORMAT", "TOK_FILEFORMAT_GENERIC"), args) - // Note: the AST doesn't contain information about which file format is being set here. - // E.g. we can't differentiate between INPUTFORMAT and OUTPUTFORMAT if either is set. - // Right now this just stores the values, but we should figure out how to get the keys. - val fFormat = fileFormat - .map { _.children.map { n => cleanAndUnquoteString(n.text) }} - .getOrElse(Seq()) - val gFormat = genericFormat.map { f => cleanAndUnquoteString(f.children(0).text) } - AlterTableSetFileFormat(tableIdent, partition, fFormat, gFormat)(node.source) - - // ALTER TABLE table_name [PARTITION spec] SET LOCATION "loc"; - case Token("TOK_ALTERTABLE_LOCATION", Token(loc, Nil) :: Nil) :: _ => - AlterTableSetLocation(tableIdent, partition, cleanAndUnquoteString(loc))(node.source) - - // ALTER TABLE table_name TOUCH [PARTITION spec]; - case Token("TOK_ALTERTABLE_TOUCH", args) :: _ => - // Note: the partition spec, if it exists, comes after TOUCH, so `partition` should - // always be None here. Instead, we need to parse it from the TOUCH node's children. 
- val part = getClauseOption("TOK_PARTSPEC", args).map(parsePartitionSpec) - AlterTableTouch(tableIdent, part)(node.source) - - // ALTER TABLE table_name [PARTITION spec] COMPACT 'compaction_type'; - case Token("TOK_ALTERTABLE_COMPACT", Token(compactType, Nil) :: Nil) :: _ => - AlterTableCompact(tableIdent, partition, cleanAndUnquoteString(compactType))(node.source) - - // ALTER TABLE table_name [PARTITION spec] CONCATENATE; - case Token("TOK_ALTERTABLE_MERGEFILES", _) :: _ => - AlterTableMerge(tableIdent, partition)(node.source) - - // ALTER TABLE table_name [PARTITION spec] CHANGE [COLUMN] col_old_name col_new_name - // column_type [COMMENT col_comment] [FIRST|AFTER column_name] [CASCADE|RESTRICT]; - case Token("TOK_ALTERTABLE_RENAMECOL", oldName :: newName :: dataType :: args) :: _ => - val afterColName: Option[String] = - getClauseOption("TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION", args).map { ap => - ap.children match { - case Token(col, Nil) :: Nil => col - case _ => parseFailed("Invalid ALTER TABLE command", node) - } - } - val restrict = getClauseOption("TOK_RESTRICT", args).isDefined - val cascade = getClauseOption("TOK_CASCADE", args).isDefined - val comment = args.headOption.map { - case Token("TOK_ALTERTABLE_CHANGECOL_AFTER_POSITION", _) => null - case Token("TOK_RESTRICT", _) => null - case Token("TOK_CASCADE", _) => null - case Token(commentStr, Nil) => cleanAndUnquoteString(commentStr) - case _ => parseFailed("Invalid ALTER TABLE command", node) - } - AlterTableChangeCol( - tableIdent, - partition, - oldName.text, - newName.text, - nodeToDataType(dataType), - comment, - afterColName, - restrict, - cascade)(node.source) - - // ALTER TABLE table_name [PARTITION spec] ADD COLUMNS (name type [COMMENT comment], ...) - // [CASCADE|RESTRICT] - case Token("TOK_ALTERTABLE_ADDCOLS", args) :: _ => - val columnNodes = getClause("TOK_TABCOLLIST", args).children - val columns = StructType(columnNodes.map(nodeToStructField)) - val restrict = getClauseOption("TOK_RESTRICT", args).isDefined - val cascade = getClauseOption("TOK_CASCADE", args).isDefined - AlterTableAddCol(tableIdent, partition, columns, restrict, cascade)(node.source) - - // ALTER TABLE table_name [PARTITION spec] REPLACE COLUMNS (name type [COMMENT comment], ...) 
- // [CASCADE|RESTRICT] - case Token("TOK_ALTERTABLE_REPLACECOLS", args) :: _ => - val columnNodes = getClause("TOK_TABCOLLIST", args).children - val columns = StructType(columnNodes.map(nodeToStructField)) - val restrict = getClauseOption("TOK_RESTRICT", args).isDefined - val cascade = getClauseOption("TOK_CASCADE", args).isDefined - AlterTableReplaceCol(tableIdent, partition, columns, restrict, cascade)(node.source) - - case _ => - parseFailed("Unsupported ALTER TABLE command", node) - } - } - -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index d06e9086a3..6cc72fba48 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -26,7 +26,6 @@ import org.apache.parquet.hadoop.ParquetOutputCommitter import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.parser.ParserConf import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -500,19 +499,6 @@ object SQLConf { doc = "When true, we could use `datasource`.`path` as table in SQL query." ) - val PARSER_SUPPORT_QUOTEDID = booleanConf("spark.sql.parser.supportQuotedIdentifiers", - defaultValue = Some(true), - isPublic = false, - doc = "Whether to use quoted identifier.\n false: default(past) behavior. Implies only" + - "alphaNumeric and underscore are valid characters in identifiers.\n" + - " true: implies column names can contain any character.") - - val PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS = booleanConf( - "spark.sql.parser.supportSQL11ReservedKeywords", - defaultValue = Some(false), - isPublic = false, - doc = "This flag should be set to true to enable support for SQL2011 reserved keywords.") - val WHOLESTAGE_CODEGEN_ENABLED = booleanConf("spark.sql.codegen.wholeStage", defaultValue = Some(true), doc = "When true, the whole stage (of multiple operators) will be compiled into single java" + @@ -573,7 +559,7 @@ object SQLConf { * * SQLConf is thread-safe (internally synchronized, so safe to be used in multiple threads). */ -class SQLConf extends Serializable with CatalystConf with ParserConf with Logging { +class SQLConf extends Serializable with CatalystConf with Logging { import SQLConf._ /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. 
 */
@@ -674,10 +660,6 @@ class SQLConf extends Serializable with CatalystConf with ParserConf with Loggin
 
   def runSQLOnFile: Boolean = getConf(RUN_SQL_ON_FILES)
 
-  def supportQuotedId: Boolean = getConf(PARSER_SUPPORT_QUOTEDID)
-
-  def supportSQL11ReservedKeywords: Boolean = getConf(PARSER_SUPPORT_SQL11_RESERVED_KEYWORDS)
-
   override def orderByOrdinal: Boolean = getConf(ORDER_BY_ORDINAL)
 
   override def groupByOrdinal: Boolean = getConf(GROUP_BY_ORDINAL)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
index 47c9a22acd..f148f2d4ea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala
@@ -21,11 +21,22 @@ import java.io.File
 
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.catalog.CatalogDatabase
-import org.apache.spark.sql.catalyst.parser.ParserUtils._
 import org.apache.spark.sql.test.SharedSQLContext
 
 class DDLSuite extends QueryTest with SharedSQLContext {
 
+  private val escapedIdentifier = "`(.+)`".r
+
+  /**
+   * Strip backticks, if any, from the string.
+   */
+  def cleanIdentifier(ident: String): String = {
+    ident match {
+      case escapedIdentifier(i) => i
+      case plainIdent => plainIdent
+    }
+  }
+
   /**
    * Drops database `databaseName` after calling `f`.
    */
diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml
index 22bad93e6d..58efd80512 100644
--- a/sql/hive/pom.xml
+++ b/sql/hive/pom.xml
@@ -225,25 +225,6 @@
           <argLine>-da -Xmx3g -XX:MaxPermSize=${MaxPermGen} -XX:ReservedCodeCacheSize=512m</argLine>
-      <plugin>
-        <groupId>org.codehaus.mojo</groupId>
-        <artifactId>build-helper-maven-plugin</artifactId>
-        <executions>
-          <execution>
-            <id>add-default-sources</id>
-            <phase>generate-sources</phase>
-            <goals>
-              <goal>add-source</goal>
-            </goals>
-            <configuration>
-              <sources>
-                <source>v${hive.version.short}/src/main/scala</source>
-                <source>${project.build.directory/generated-sources/antlr</source>
-              </sources>
-            </configuration>
-          </execution>
-        </executions>
-      </plugin>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
index 9a5ec9880e..2cdc931c3f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala
@@ -25,7 +25,7 @@ import com.google.common.cache.{CacheBuilder, CacheLoader, LoadingCache}
 import org.apache.hadoop.fs.{FileStatus, Path}
 import org.apache.hadoop.hive.common.StatsSetupConst
 import org.apache.hadoop.hive.conf.HiveConf
-import org.apache.hadoop.hive.metastore.{TableType => HiveTableType, Warehouse}
+import org.apache.hadoop.hive.metastore.{TableType => HiveTableType}
 import org.apache.hadoop.hive.metastore.api.FieldSchema
 import org.apache.hadoop.hive.ql.metadata.{Table => HiveTable, _}
 import org.apache.hadoop.hive.ql.plan.TableDesc
@@ -988,3 +988,28 @@ private[hive] object HiveMetastoreTypes {
     case udt: UserDefinedType[_] => toMetastoreType(udt.sqlType)
   }
 }
+
+private[hive] case class CreateTableAsSelect(
+    tableDesc: CatalogTable,
+    child: LogicalPlan,
+    allowExisting: Boolean) extends UnaryNode with Command {
+
+  override def output: Seq[Attribute] = Seq.empty[Attribute]
+  override lazy val resolved: Boolean =
+    tableDesc.identifier.database.isDefined &&
+    tableDesc.schema.nonEmpty &&
+    tableDesc.storage.serde.isDefined &&
+    tableDesc.storage.inputFormat.isDefined &&
+    tableDesc.storage.outputFormat.isDefined &&
+    childrenResolved
+}
+
+private[hive] case class CreateViewAsSelect(
+    tableDesc: CatalogTable,
+    child: LogicalPlan,
+
allowExisting: Boolean, + replace: Boolean, + sql: String) extends UnaryNode with Command { + override def output: Seq[Attribute] = Seq.empty[Attribute] + override lazy val resolved: Boolean = false +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala deleted file mode 100644 index 052c43a3ce..0000000000 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ /dev/null @@ -1,749 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import java.util.Locale - -import scala.collection.JavaConverters._ - -import org.apache.hadoop.hive.common.`type`.HiveDecimal -import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.exec.{FunctionInfo, FunctionRegistry} -import org.apache.hadoop.hive.ql.parse.EximUtil -import org.apache.hadoop.hive.ql.session.SessionState -import org.apache.hadoop.hive.serde.serdeConstants -import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe - -import org.apache.spark.internal.Logging -import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser._ -import org.apache.spark.sql.catalyst.plans._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.SparkQl -import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper -import org.apache.spark.sql.hive.execution._ -import org.apache.spark.sql.types._ -import org.apache.spark.sql.AnalysisException - -/** - * Used when we need to start parsing the AST before deciding that we are going to pass the command - * back for Hive to execute natively. Will be replaced with a native command that contains the - * cmd string. 
- */ -private[hive] case object NativePlaceholder extends LogicalPlan { - override def children: Seq[LogicalPlan] = Seq.empty - override def output: Seq[Attribute] = Seq.empty -} - -private[hive] case class CreateTableAsSelect( - tableDesc: CatalogTable, - child: LogicalPlan, - allowExisting: Boolean) extends UnaryNode with Command { - - override def output: Seq[Attribute] = Seq.empty[Attribute] - override lazy val resolved: Boolean = - tableDesc.identifier.database.isDefined && - tableDesc.schema.nonEmpty && - tableDesc.storage.serde.isDefined && - tableDesc.storage.inputFormat.isDefined && - tableDesc.storage.outputFormat.isDefined && - childrenResolved -} - -private[hive] case class CreateViewAsSelect( - tableDesc: CatalogTable, - child: LogicalPlan, - allowExisting: Boolean, - replace: Boolean, - sql: String) extends UnaryNode with Command { - override def output: Seq[Attribute] = Seq.empty[Attribute] - override lazy val resolved: Boolean = false -} - -/** Provides a mapping from HiveQL statements to catalyst logical plans and expression trees. */ -private[hive] class HiveQl(conf: ParserConf) extends SparkQl(conf) with Logging { - import ParseUtils._ - import ParserUtils._ - - // Token text -> human readable text - private val hiveUnsupportedCommands = Map( - "TOK_CREATEROLE" -> "CREATE ROLE", - "TOK_DROPROLE" -> "DROP ROLE", - "TOK_EXPORT" -> "EXPORT TABLE", - "TOK_GRANT" -> "GRANT", - "TOK_GRANT_ROLE" -> "GRANT", - "TOK_IMPORT" -> "IMPORT TABLE", - "TOK_REVOKE" -> "REVOKE", - "TOK_REVOKE_ROLE" -> "REVOKE", - "TOK_SHOW_COMPACTIONS" -> "SHOW COMPACTIONS", - "TOK_SHOW_CREATETABLE" -> "SHOW CREATE TABLE", - "TOK_SHOW_GRANT" -> "SHOW GRANT", - "TOK_SHOW_ROLE_GRANT" -> "SHOW ROLE GRANT", - "TOK_SHOW_ROLE_PRINCIPALS" -> "SHOW PRINCIPALS", - "TOK_SHOW_ROLES" -> "SHOW ROLES", - "TOK_SHOW_SET_ROLE" -> "SHOW CURRENT ROLES / SET ROLE", - "TOK_SHOW_TRANSACTIONS" -> "SHOW TRANSACTIONS", - "TOK_SHOWINDEXES" -> "SHOW INDEXES", - "TOK_SHOWLOCKS" -> "SHOW LOCKS") - - private val nativeCommands = Set( - "TOK_ALTERDATABASE_OWNER", - "TOK_ALTERINDEX_PROPERTIES", - "TOK_ALTERINDEX_REBUILD", - "TOK_ALTERTABLE_ALTERPARTS", - "TOK_ALTERTABLE_PARTITION", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_AS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME", - - "TOK_CREATEINDEX", - "TOK_CREATEMACRO", - - "TOK_DROPINDEX", - "TOK_DROPMACRO", - "TOK_DROPTABLE_PROPERTIES", - "TOK_DROPVIEW", - "TOK_DROPVIEW_PROPERTIES", - - "TOK_LOAD", - - "TOK_LOCKTABLE", - - "TOK_MSCK", - - "TOK_SHOW_TABLESTATUS", - "TOK_SHOW_TBLPROPERTIES", - "TOK_SHOWCOLUMNS", - "TOK_SHOWDATABASES", - "TOK_SHOWPARTITIONS", - - "TOK_UNLOCKTABLE" - ) - - // Commands that we do not need to explain. - private val noExplainCommands = Set( - "TOK_DESCTABLE", - "TOK_SHOWTABLES", - "TOK_TRUNCATETABLE", // truncate table" is a NativeCommand, does not need to explain. 
- "TOK_ALTERTABLE" - ) ++ nativeCommands - - /** - * Returns the HiveConf - */ - private[this] def hiveConf: HiveConf = { - var ss = SessionState.get() - // SessionState is lazy initialization, it can be null here - if (ss == null) { - val original = Thread.currentThread().getContextClassLoader - val conf = new HiveConf(classOf[SessionState]) - conf.setClassLoader(original) - ss = new SessionState(conf) - SessionState.start(ss) - } - ss.getConf - } - - protected def getProperties(node: ASTNode): Seq[(String, String)] = node match { - case Token("TOK_TABLEPROPLIST", list) => - list.map { - case Token("TOK_TABLEPROPERTY", Token(key, Nil) :: Token(value, Nil) :: Nil) => - unquoteString(key) -> unquoteString(value) - } - } - - private def createView( - view: ASTNode, - viewNameParts: ASTNode, - query: ASTNode, - schema: Seq[CatalogColumn], - properties: Map[String, String], - allowExist: Boolean, - replace: Boolean): CreateViewAsSelect = { - val tableIdentifier = extractTableIdent(viewNameParts) - val originalText = query.source - val tableDesc = CatalogTable( - identifier = tableIdentifier, - tableType = CatalogTableType.VIRTUAL_VIEW, - schema = schema, - storage = CatalogStorageFormat( - locationUri = None, - inputFormat = None, - outputFormat = None, - serde = None, - serdeProperties = Map.empty[String, String] - ), - properties = properties, - viewOriginalText = Some(originalText), - viewText = Some(originalText)) - - // We need to keep the original SQL string so that if `spark.sql.nativeView` is - // false, we can fall back to use hive native command later. - // We can remove this when parser is configurable(can access SQLConf) in the future. - val sql = view.source - CreateViewAsSelect(tableDesc, nodeToPlan(query), allowExist, replace, sql) - } - - /** Creates LogicalPlan for a given SQL string. */ - override def parsePlan(sql: String): LogicalPlan = { - safeParse(sql, ParseDriver.parsePlan(sql, conf)) { ast => - if (nativeCommands.contains(ast.text)) { - HiveNativeCommand(sql) - } else if (hiveUnsupportedCommands.contains(ast.text)) { - val humanReadableText = hiveUnsupportedCommands(ast.text) - throw new AnalysisException("Unsupported operation: " + humanReadableText) - } else { - nodeToPlan(ast) match { - case NativePlaceholder => HiveNativeCommand(sql) - case plan => plan - } - } - } - } - - protected override def isNoExplainCommand(command: String): Boolean = - noExplainCommands.contains(command) - - protected override def nodeToPlan(node: ASTNode): LogicalPlan = { - node match { - case Token("TOK_DFS", Nil) => - HiveNativeCommand(node.source + " " + node.remainder) - - case Token("TOK_ADDFILE", Nil) => - AddFile(node.remainder) - - case Token("TOK_ADDJAR", Nil) => - AddJar(node.remainder) - - // Special drop table that also uncaches. - case Token("TOK_DROPTABLE", Token("TOK_TABNAME", tableNameParts) :: ifExists) => - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - DropTable(tableName, ifExists.nonEmpty) - - // Support "ANALYZE TABLE tableName COMPUTE STATISTICS noscan" - case Token("TOK_ANALYZE", - Token("TOK_TAB", Token("TOK_TABNAME", tableNameParts) :: partitionSpec) :: isNoscan) => - // Reference: - // https://cwiki.apache.org/confluence/display/Hive/StatsDev#StatsDev-ExistingTables - if (partitionSpec.nonEmpty) { - // Analyze partitions will be treated as a Hive native command. - NativePlaceholder - } else if (isNoscan.isEmpty) { - // If users do not specify "noscan", it will be treated as a Hive native command. 
- NativePlaceholder - } else { - val tableName = tableNameParts.map { case Token(p, Nil) => p }.mkString(".") - AnalyzeTable(tableName) - } - - case view @ Token("TOK_ALTERVIEW", children) => - val Some(nameParts) :: maybeQuery :: _ = - getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_ALTERVIEW_ADDPARTS", - "TOK_ALTERVIEW_DROPPARTS", - "TOK_ALTERVIEW_PROPERTIES", - "TOK_ALTERVIEW_RENAME"), children) - - // if ALTER VIEW doesn't have query part, let hive to handle it. - maybeQuery.map { query => - createView(view, nameParts, query, Nil, Map(), allowExist = false, replace = true) - }.getOrElse(NativePlaceholder) - - case view @ Token("TOK_CREATEVIEW", children) - if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - val Seq( - Some(viewNameParts), - Some(query), - maybeComment, - replace, - allowExisting, - maybeProperties, - maybeColumns, - maybePartCols - ) = getClauses(Seq( - "TOK_TABNAME", - "TOK_QUERY", - "TOK_TABLECOMMENT", - "TOK_ORREPLACE", - "TOK_IFNOTEXISTS", - "TOK_TABLEPROPERTIES", - "TOK_TABCOLNAME", - "TOK_VIEWPARTCOLS"), children) - - // If the view is partitioned, we let hive handle it. - if (maybePartCols.isDefined) { - NativePlaceholder - } else { - val schema = maybeColumns.map { cols => - // We can't specify column types when create view, so fill it with null first, and - // update it after the schema has been resolved later. - nodeToColumns(cols, lowerCase = true).map(_.copy(dataType = null)) - }.getOrElse(Seq.empty[CatalogColumn]) - - val properties = scala.collection.mutable.Map.empty[String, String] - - maybeProperties.foreach { - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - properties ++= getProperties(list) - } - - maybeComment.foreach { - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = unescapeSQLString(child.text) - if (comment ne null) { - properties += ("comment" -> comment) - } - } - - createView(view, viewNameParts, query, schema, properties.toMap, - allowExisting.isDefined, replace.isDefined) - } - - case Token("TOK_CREATETABLE", children) - if children.collect { case t @ Token("TOK_QUERY", _) => t }.nonEmpty => - // Reference: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL - val ( - Some(tableNameParts) :: - _ /* likeTable */ :: - externalTable :: - Some(query) :: - allowExisting +: - _) = - getClauses( - Seq( - "TOK_TABNAME", - "TOK_LIKETABLE", - "EXTERNAL", - "TOK_QUERY", - "TOK_IFNOTEXISTS", - "TOK_TABLECOMMENT", - "TOK_TABCOLLIST", - "TOK_TABLEPARTCOLS", // Partitioned by - "TOK_TABLEBUCKETS", // Clustered by - "TOK_TABLESKEWED", // Skewed by - "TOK_TABLEROWFORMAT", - "TOK_TABLESERIALIZER", - "TOK_FILEFORMAT_GENERIC", - "TOK_TABLEFILEFORMAT", // User-provided InputFormat and OutputFormat - "TOK_STORAGEHANDLER", // Storage handler - "TOK_TABLELOCATION", - "TOK_TABLEPROPERTIES"), - children) - val tableIdentifier = extractTableIdent(tableNameParts) - - // TODO add bucket support - var tableDesc: CatalogTable = CatalogTable( - identifier = tableIdentifier, - tableType = - if (externalTable.isDefined) { - CatalogTableType.EXTERNAL_TABLE - } else { - CatalogTableType.MANAGED_TABLE - }, - storage = CatalogStorageFormat( - locationUri = None, - inputFormat = None, - outputFormat = None, - serde = None, - serdeProperties = Map.empty[String, String] - ), - schema = Seq.empty[CatalogColumn]) - - // default storage type abbreviation (e.g. RCFile, ORC, PARQUET etc.) 
- val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - // handle the default format for the storage type abbreviation - val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } - - tableDesc = tableDesc.withNewStorage( - inputFormat = hiveSerDe.inputFormat.orElse(tableDesc.storage.inputFormat), - outputFormat = hiveSerDe.outputFormat.orElse(tableDesc.storage.outputFormat), - serde = hiveSerDe.serde.orElse(tableDesc.storage.serde)) - - children.collect { - case list @ Token("TOK_TABCOLLIST", _) => - val cols = nodeToColumns(list, lowerCase = true) - if (cols != null) { - tableDesc = tableDesc.copy(schema = cols) - } - case Token("TOK_TABLECOMMENT", child :: Nil) => - val comment = unescapeSQLString(child.text) - // TODO support the sql text - tableDesc = tableDesc.copy(viewText = Option(comment)) - case Token("TOK_TABLEPARTCOLS", list @ Token("TOK_TABCOLLIST", _) :: Nil) => - val cols = nodeToColumns(list.head, lowerCase = false) - if (cols != null) { - tableDesc = tableDesc.copy(partitionColumns = cols) - } - case Token("TOK_TABLEROWFORMAT", Token("TOK_SERDEPROPS", child :: Nil) :: Nil) => - val serdeParams = new java.util.HashMap[String, String]() - child match { - case Token("TOK_TABLEROWFORMATFIELD", rowChild1 :: rowChild2) => - val fieldDelim = unescapeSQLString (rowChild1.text) - serdeParams.put(serdeConstants.FIELD_DELIM, fieldDelim) - serdeParams.put(serdeConstants.SERIALIZATION_FORMAT, fieldDelim) - if (rowChild2.length > 1) { - val fieldEscape = unescapeSQLString (rowChild2.head.text) - serdeParams.put(serdeConstants.ESCAPE_CHAR, fieldEscape) - } - case Token("TOK_TABLEROWFORMATCOLLITEMS", rowChild :: Nil) => - val collItemDelim = unescapeSQLString(rowChild.text) - serdeParams.put(serdeConstants.COLLECTION_DELIM, collItemDelim) - case Token("TOK_TABLEROWFORMATMAPKEYS", rowChild :: Nil) => - val mapKeyDelim = unescapeSQLString(rowChild.text) - serdeParams.put(serdeConstants.MAPKEY_DELIM, mapKeyDelim) - case Token("TOK_TABLEROWFORMATLINES", rowChild :: Nil) => - val lineDelim = unescapeSQLString(rowChild.text) - if (!(lineDelim == "\n") && !(lineDelim == "10")) { - throw new AnalysisException( - s"LINES TERMINATED BY only supports newline '\\n' right now: $rowChild") - } - serdeParams.put(serdeConstants.LINE_DELIM, lineDelim) - case Token("TOK_TABLEROWFORMATNULL", rowChild :: Nil) => - val nullFormat = unescapeSQLString(rowChild.text) - // TODO support the nullFormat - case _ => assert(false) - } - tableDesc = tableDesc.withNewStorage( - serdeProperties = tableDesc.storage.serdeProperties ++ serdeParams.asScala) - case Token("TOK_TABLELOCATION", child :: Nil) => - val location = EximUtil.relativeToAbsolutePath(hiveConf, unescapeSQLString(child.text)) - tableDesc = tableDesc.withNewStorage(locationUri = Option(location)) - case Token("TOK_TABLESERIALIZER", child :: Nil) => - tableDesc = tableDesc.withNewStorage( - serde = Option(unescapeSQLString(child.children.head.text))) - if (child.numChildren == 2) { - // This is based on the readProps(..) 
method in - // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: - val serdeParams = child.children(1).children.head.children.map { - case Token(_, Token(prop, Nil) :: valueNode) => - val value = valueNode.headOption - .map(_.text) - .map(unescapeSQLString) - .orNull - (unescapeSQLString(prop), value) - }.toMap - tableDesc = tableDesc.withNewStorage( - serdeProperties = tableDesc.storage.serdeProperties ++ serdeParams) - } - case Token("TOK_FILEFORMAT_GENERIC", child :: Nil) => - child.text.toLowerCase(Locale.ENGLISH) match { - case "orc" => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) - if (tableDesc.storage.serde.isEmpty) { - tableDesc = tableDesc.withNewStorage( - serde = Option("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) - } - - case "parquet" => - tableDesc = tableDesc.withNewStorage( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat")) - if (tableDesc.storage.serde.isEmpty) { - tableDesc = tableDesc.withNewStorage( - serde = Option("org.apache.hadoop.hive.ql.io.parquet.serde.ParquetHiveSerDe")) - } - - case "rcfile" => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - if (tableDesc.storage.serde.isEmpty) { - tableDesc = tableDesc.withNewStorage( - serde = - Option("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - } - - case "textfile" => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - - case "sequencefile" => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option("org.apache.hadoop.mapred.SequenceFileInputFormat"), - outputFormat = Option("org.apache.hadoop.mapred.SequenceFileOutputFormat")) - - case "avro" => - tableDesc = tableDesc.withNewStorage( - inputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerInputFormat"), - outputFormat = - Option("org.apache.hadoop.hive.ql.io.avro.AvroContainerOutputFormat")) - if (tableDesc.storage.serde.isEmpty) { - tableDesc = tableDesc.withNewStorage( - serde = Option("org.apache.hadoop.hive.serde2.avro.AvroSerDe")) - } - - case _ => - throw new AnalysisException( - s"Unrecognized file format in STORED AS clause: ${child.text}") - } - - case Token("TOK_TABLESERIALIZER", - Token("TOK_SERDENAME", Token(serdeName, Nil) :: otherProps) :: Nil) => - tableDesc = tableDesc.withNewStorage(serde = Option(unquoteString(serdeName))) - - otherProps match { - case Token("TOK_TABLEPROPERTIES", list :: Nil) :: Nil => - tableDesc = tableDesc.withNewStorage( - serdeProperties = tableDesc.storage.serdeProperties ++ getProperties(list)) - case _ => - } - - case Token("TOK_TABLEPROPERTIES", list :: Nil) => - tableDesc = tableDesc.copy(properties = tableDesc.properties ++ getProperties(list)) - case list @ Token("TOK_TABLEFILEFORMAT", _) => - tableDesc = tableDesc.withNewStorage( - inputFormat = Option(unescapeSQLString(list.children.head.text)), - outputFormat = Option(unescapeSQLString(list.children(1).text))) - case Token("TOK_STORAGEHANDLER", _) => - throw new AnalysisException( - "CREATE TABLE AS SELECT cannot be used for a non-native table") - 
case _ => // Unsupported features - } - - CreateTableAsSelect(tableDesc, nodeToPlan(query), allowExisting.isDefined) - - // If its not a "CTAS" like above then take it as a native command - case Token("TOK_CREATETABLE", _) => - NativePlaceholder - - // Support "TRUNCATE TABLE table_name [PARTITION partition_spec]" - case Token("TOK_TRUNCATETABLE", Token("TOK_TABLE_PARTITION", table) :: Nil) => - NativePlaceholder - - case _ => - super.nodeToPlan(node) - } - } - - protected override def nodeToDescribeFallback(node: ASTNode): LogicalPlan = NativePlaceholder - - protected override def nodeToTransformation( - node: ASTNode, - child: LogicalPlan): Option[logical.ScriptTransformation] = node match { - case Token("TOK_SELEXPR", - Token("TOK_TRANSFORM", - Token("TOK_EXPLIST", inputExprs) :: - Token("TOK_SERDE", inputSerdeClause) :: - Token("TOK_RECORDWRITER", writerClause) :: - // TODO: Need to support other types of (in/out)put - Token(script, Nil) :: - Token("TOK_SERDE", outputSerdeClause) :: - Token("TOK_RECORDREADER", readerClause) :: - outputClause) :: Nil) => - - val (output, schemaLess) = outputClause match { - case Token("TOK_ALIASLIST", aliases) :: Nil => - (aliases.map { case Token(name, Nil) => - AttributeReference(cleanIdentifier(name), StringType)() }, false) - case Token("TOK_TABCOLLIST", attributes) :: Nil => - (attributes.map { case Token("TOK_TABCOL", Token(name, Nil) :: dataType :: Nil) => - AttributeReference(cleanIdentifier(name), nodeToDataType(dataType))() }, false) - case Nil => - (List(AttributeReference("key", StringType)(), - AttributeReference("value", StringType)()), true) - case _ => - noParseRule("Transform", node) - } - - type SerDeInfo = ( - Seq[(String, String)], // Input row format information - Option[String], // Optional input SerDe class - Seq[(String, String)], // Input SerDe properties - Boolean // Whether to use default record reader/writer - ) - - def matchSerDe(clause: Seq[ASTNode]): SerDeInfo = clause match { - case Token("TOK_SERDEPROPS", propsClause) :: Nil => - val rowFormat = propsClause.map { - case Token(name, Token(value, Nil) :: Nil) => (name, value) - } - (rowFormat, None, Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: Nil) :: Nil => - (Nil, Some(unescapeSQLString(serdeClass)), Nil, false) - - case Token("TOK_SERDENAME", Token(serdeClass, Nil) :: - Token("TOK_TABLEPROPERTIES", - Token("TOK_TABLEPROPLIST", propsClause) :: Nil) :: Nil) :: Nil => - val serdeProps = propsClause.map { - case Token("TOK_TABLEPROPERTY", Token(name, Nil) :: Token(value, Nil) :: Nil) => - (unescapeSQLString(name), unescapeSQLString(value)) - } - - // SPARK-10310: Special cases LazySimpleSerDe - // TODO Fully supports user-defined record reader/writer classes - val unescapedSerDeClass = unescapeSQLString(serdeClass) - val useDefaultRecordReaderWriter = - unescapedSerDeClass == classOf[LazySimpleSerDe].getCanonicalName - (Nil, Some(unescapedSerDeClass), serdeProps, useDefaultRecordReaderWriter) - - case Nil => - // Uses default TextRecordReader/TextRecordWriter, sets field delimiter here - val serdeProps = Seq(serdeConstants.FIELD_DELIM -> "\t") - (Nil, Option(hiveConf.getVar(ConfVars.HIVESCRIPTSERDE)), serdeProps, true) - } - - val (inRowFormat, inSerdeClass, inSerdeProps, useDefaultRecordReader) = - matchSerDe(inputSerdeClause) - - val (outRowFormat, outSerdeClass, outSerdeProps, useDefaultRecordWriter) = - matchSerDe(outputSerdeClause) - - val unescapedScript = unescapeSQLString(script) - - // TODO Adds support for user-defined record reader/writer 
classes - val recordReaderClass = if (useDefaultRecordReader) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDREADER)) - } else { - None - } - - val recordWriterClass = if (useDefaultRecordWriter) { - Option(hiveConf.getVar(ConfVars.HIVESCRIPTRECORDWRITER)) - } else { - None - } - - val schema = HiveScriptIOSchema( - inRowFormat, outRowFormat, - inSerdeClass, outSerdeClass, - inSerdeProps, outSerdeProps, - recordReaderClass, recordWriterClass, - schemaLess) - - Some( - logical.ScriptTransformation( - inputExprs.map(nodeToExpr), - unescapedScript, - output, - child, schema)) - case _ => None - } - - protected override def nodeToGenerator(node: ASTNode): Generator = node match { - case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) => - val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(functionName.toLowerCase)).getOrElse( - sys.error(s"Couldn't find function $functionName")) - val functionClassName = functionInfo.getFunctionClass.getName - HiveGenericUDTF( - functionName, new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)) - case other => super.nodeToGenerator(node) - } - - // This is based the getColumns methods in - // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java - protected def nodeToColumns(node: ASTNode, lowerCase: Boolean): Seq[CatalogColumn] = { - node.children.map(_.children).collect { - case Token(rawColName, Nil) :: colTypeNode :: comment => - val colName = if (!lowerCase) rawColName else rawColName.toLowerCase - CatalogColumn( - name = cleanIdentifier(colName), - dataType = nodeToTypeString(colTypeNode), - nullable = true, - comment.headOption.map(n => unescapeSQLString(n.text))) - } - } - - // This is based on the following methods in - // ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java: - // getTypeStringFromAST - // getStructTypeStringFromAST - // getUnionTypeStringFromAST - protected def nodeToTypeString(node: ASTNode): String = node.tokenType match { - case SparkSqlParser.TOK_LIST => - val listType :: Nil = node.children - val listTypeString = nodeToTypeString(listType) - s"${serdeConstants.LIST_TYPE_NAME}<$listTypeString>" - - case SparkSqlParser.TOK_MAP => - val keyType :: valueType :: Nil = node.children - val keyTypeString = nodeToTypeString(keyType) - val valueTypeString = nodeToTypeString(valueType) - s"${serdeConstants.MAP_TYPE_NAME}<$keyTypeString,$valueTypeString>" - - case SparkSqlParser.TOK_STRUCT => - val typeNode = node.children.head - require(typeNode.children.nonEmpty, "Struct must have one or more columns.") - val structColStrings = typeNode.children.map { columnNode => - val Token(colName, Nil) :: colTypeNode :: Nil = columnNode.children - cleanIdentifier(colName) + ":" + nodeToTypeString(colTypeNode) - } - s"${serdeConstants.STRUCT_TYPE_NAME}<${structColStrings.mkString(",")}>" - - case SparkSqlParser.TOK_UNIONTYPE => - val typeNode = node.children.head - val unionTypesString = typeNode.children.map(nodeToTypeString).mkString(",") - s"${serdeConstants.UNION_TYPE_NAME}<$unionTypesString>" - - case SparkSqlParser.TOK_CHAR => - val Token(size, Nil) :: Nil = node.children - s"${serdeConstants.CHAR_TYPE_NAME}($size)" - - case SparkSqlParser.TOK_VARCHAR => - val Token(size, Nil) :: Nil = node.children - s"${serdeConstants.VARCHAR_TYPE_NAME}($size)" - - case SparkSqlParser.TOK_DECIMAL => - val precisionAndScale = node.children match { - case Token(precision, Nil) :: Token(scale, Nil) :: Nil => - precision + "," + scale - case Token(precision, Nil) :: Nil => - 
precision + "," + HiveDecimal.USER_DEFAULT_SCALE - case Nil => - HiveDecimal.USER_DEFAULT_PRECISION + "," + HiveDecimal.USER_DEFAULT_SCALE - case _ => - noParseRule("Decimal", node) - } - s"${serdeConstants.DECIMAL_TYPE_NAME}($precisionAndScale)" - - // Simple data types. - case SparkSqlParser.TOK_BOOLEAN => serdeConstants.BOOLEAN_TYPE_NAME - case SparkSqlParser.TOK_TINYINT => serdeConstants.TINYINT_TYPE_NAME - case SparkSqlParser.TOK_SMALLINT => serdeConstants.SMALLINT_TYPE_NAME - case SparkSqlParser.TOK_INT => serdeConstants.INT_TYPE_NAME - case SparkSqlParser.TOK_BIGINT => serdeConstants.BIGINT_TYPE_NAME - case SparkSqlParser.TOK_FLOAT => serdeConstants.FLOAT_TYPE_NAME - case SparkSqlParser.TOK_DOUBLE => serdeConstants.DOUBLE_TYPE_NAME - case SparkSqlParser.TOK_STRING => serdeConstants.STRING_TYPE_NAME - case SparkSqlParser.TOK_BINARY => serdeConstants.BINARY_TYPE_NAME - case SparkSqlParser.TOK_DATE => serdeConstants.DATE_TYPE_NAME - case SparkSqlParser.TOK_TIMESTAMP => serdeConstants.TIMESTAMP_TYPE_NAME - case SparkSqlParser.TOK_INTERVAL_YEAR_MONTH => serdeConstants.INTERVAL_YEAR_MONTH_TYPE_NAME - case SparkSqlParser.TOK_INTERVAL_DAY_TIME => serdeConstants.INTERVAL_DAY_TIME_TYPE_NAME - case SparkSqlParser.TOK_DATETIME => serdeConstants.DATETIME_TYPE_NAME - case _ => null - } - -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index d6a08fcc57..12e4f49756 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -29,8 +29,8 @@ import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.parser.ng._ -import org.apache.spark.sql.catalyst.parser.ng.SqlBaseParser._ +import org.apache.spark.sql.catalyst.parser._ +import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkSqlAstBuilder import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView} @@ -161,18 +161,10 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } // Create the schema. - val schema = Option(ctx.colTypeList).toSeq.flatMap(_.colType.asScala).map { col => - CatalogColumn( - col.identifier.getText, - col.dataType.getText.toLowerCase, // TODO validate this? - nullable = true, - Option(col.STRING).map(string)) - } + val schema = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns(_, _.toLowerCase)) // Get the column by which the table is partitioned. - val partitionCols = Option(ctx.identifierList).toSeq.flatMap(visitIdentifierList).map { - CatalogColumn(_, null, nullable = true, None) - } + val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns(_)) // Create the storage. 
def format(fmt: ParserRuleContext): CatalogStorageFormat = { @@ -439,4 +431,19 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } EmptyStorageFormat.copy(serdeProperties = entries.toMap) } + + /** + * Create a sequence of [[CatalogColumn]]s from a column list + */ + private def visitCatalogColumns( + ctx: ColTypeListContext, + formatter: String => String = identity): Seq[CatalogColumn] = withOrigin(ctx) { + ctx.colType.asScala.map { col => + CatalogColumn( + formatter(col.identifier.getText), + col.dataType.getText.toLowerCase, // TODO validate this? + nullable = true, + Option(col.STRING).map(string)) + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala index 4b6da7cd33..d9664680f4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ErrorPositionSuite.scala @@ -22,8 +22,8 @@ import scala.util.Try import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.{AnalysisException, QueryTest} -import org.apache.spark.sql.catalyst.parser.ParseDriver import org.apache.spark.sql.catalyst.util.quietly +import org.apache.spark.sql.hive.execution.HiveSqlParser import org.apache.spark.sql.hive.test.TestHiveSingleton class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAndAfterEach { @@ -131,7 +131,7 @@ class ErrorPositionSuite extends QueryTest with TestHiveSingleton with BeforeAnd * @param token a unique token in the string that should be indicated by the exception */ def positionTest(name: String, query: String, token: String): Unit = { - def ast = ParseDriver.parsePlan(query, hiveContext.conf) + def ast = HiveSqlParser.parsePlan(query) def parseTree = Try(quietly(ast.treeString)).getOrElse("") test(name) { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 0aaf57649c..75108c6d47 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -24,11 +24,11 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.parser.SimpleParserConf import org.apache.spark.sql.catalyst.plans.logical.Generate +import org.apache.spark.sql.hive.execution.HiveSqlParser class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { - val parser = new HiveQl(SimpleParserConf()) + val parser = HiveSqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parser.parsePlan(sql).collect { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala index ae026ed496..05318f51af 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/StatisticsSuite.scala @@ -21,7 +21,6 @@ import scala.reflect.ClassTag import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.parser.SimpleParserConf import org.apache.spark.sql.execution.joins._ import org.apache.spark.sql.hive.execution._ import 
org.apache.spark.sql.hive.test.TestHiveSingleton @@ -30,11 +29,9 @@ import org.apache.spark.sql.internal.SQLConf class StatisticsSuite extends QueryTest with TestHiveSingleton { import hiveContext.sql - val parser = new HiveQl(SimpleParserConf()) - test("parse analyze commands") { def assertAnalyzeCommand(analyzeCommand: String, c: Class[_]) { - val parsed = parser.parsePlan(analyzeCommand) + val parsed = HiveSqlParser.parsePlan(analyzeCommand) val operators = parsed.collect { case a: AnalyzeTable => a case o => o -- cgit v1.2.3 From 446c45bd87035e20653394fcaf9dc8caa4299038 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 31 Mar 2016 12:03:55 -0700 Subject: [SPARK-14182][SQL] Parse DDL Command: Alter View This PR is to provide native parsing support for the DDL command `Alter View`. Since its AST trees are highly similar to those of `Alter Table`, both implementations are integrated into the same one. Based on the Hive DDL document: https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL and https://cwiki.apache.org/confluence/display/Hive/PartitionedViews **Syntax:** ```SQL ALTER VIEW view_name RENAME TO new_view_name ``` - to change the name of a view to a different name **Syntax:** ```SQL ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); ``` - to add metadata to a view **Syntax:** ```SQL ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key') ``` - to remove metadata from a view **Syntax:** ```SQL ALTER VIEW view_name ADD [IF NOT EXISTS] PARTITION spec1[, PARTITION spec2, ...] ``` - to add the partitioning metadata for a view. - the syntax of partition spec in `ALTER VIEW` is identical to `ALTER TABLE`, **EXCEPT** that it is **ILLEGAL** to specify a `LOCATION` clause. **Syntax:** ```SQL ALTER VIEW view_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] ``` - to drop the related partition metadata for a view. Added the related test cases to `DDLCommandSuite`. Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #11987 from gatorsmile/parseAlterView. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 20 +-- .../spark/sql/execution/SparkSqlParser.scala | 20 ++- .../apache/spark/sql/execution/command/ddl.scala | 17 ++ .../sql/execution/command/DDLCommandSuite.scala | 175 ++++++++++++++++----- 4 files changed, 177 insertions(+), 55 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 3b9f82a80f..a857e670da 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -57,10 +57,11 @@ statement (AS? query)? #createTable | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq?)? #analyze - | ALTER TABLE from=tableIdentifier RENAME TO to=tableIdentifier #renameTable - | ALTER TABLE tableIdentifier + | ALTER (TABLE | VIEW) from=tableIdentifier + RENAME TO to=tableIdentifier #renameTable + | ALTER (TABLE | VIEW) tableIdentifier SET TBLPROPERTIES tablePropertyList #setTableProperties - | ALTER TABLE tableIdentifier + | ALTER (TABLE | VIEW) tableIdentifier UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties | ALTER TABLE tableIdentifier (partitionSpec)? SET SERDE STRING (WITH SERDEPROPERTIES tablePropertyList)?
#setTableSerDe @@ -76,12 +77,16 @@ statement SET SKEWED LOCATION skewedLocationList #setTableSkewLocations | ALTER TABLE tableIdentifier ADD (IF NOT EXISTS)? partitionSpecLocation+ #addTablePartition + | ALTER VIEW tableIdentifier ADD (IF NOT EXISTS)? + partitionSpec+ #addTablePartition | ALTER TABLE tableIdentifier from=partitionSpec RENAME TO to=partitionSpec #renameTablePartition | ALTER TABLE from=tableIdentifier EXCHANGE partitionSpec WITH TABLE to=tableIdentifier #exchangeTablePartition | ALTER TABLE tableIdentifier DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions + | ALTER VIEW tableIdentifier + DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions | ALTER TABLE tableIdentifier ARCHIVE partitionSpec #archiveTablePartition | ALTER TABLE tableIdentifier UNARCHIVE partitionSpec #unarchiveTablePartition | ALTER TABLE tableIdentifier partitionSpec? @@ -133,15 +138,6 @@ hiveNativeCommands | DELETE FROM tableIdentifier (WHERE booleanExpression)? | TRUNCATE TABLE tableIdentifier partitionSpec? (COLUMNS identifierList)? - | ALTER VIEW from=tableIdentifier AS? RENAME TO to=tableIdentifier - | ALTER VIEW from=tableIdentifier AS? - SET TBLPROPERTIES tablePropertyList - | ALTER VIEW from=tableIdentifier AS? - UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList - | ALTER VIEW from=tableIdentifier AS? - ADD (IF NOT EXISTS)? partitionSpecLocation+ - | ALTER VIEW from=tableIdentifier AS? - DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? | DROP VIEW (IF EXISTS)? qualifiedName | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? | START TRANSACTION (transactionMode (',' transactionMode)*)? diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index b4687c985d..16a899e01f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -335,6 +335,7 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; * }}} */ override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { @@ -350,6 +351,7 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment); + * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment); * }}} */ override def visitSetTableProperties( @@ -366,6 +368,7 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); + * ALTER VIEW view UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); * }}} */ override def visitUnsetTableProperties( @@ -510,16 +513,22 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] + * ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec * }}} */ override def visitAddTablePartition( ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { // Create partition spec to location mapping. 
- val specsAndLocs = ctx.partitionSpecLocation.asScala.map { - splCtx => - val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec) - val location = Option(splCtx.locationSpec).map(visitLocationSpec) - spec -> location + val specsAndLocs = if (ctx.partitionSpec.isEmpty) { + ctx.partitionSpecLocation.asScala.map { + splCtx => + val spec = visitNonOptionalPartitionSpec(splCtx.partitionSpec) + val location = Option(splCtx.locationSpec).map(visitLocationSpec) + spec -> location + } + } else { + // Alter View: the location clauses are not allowed. + ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec(_) -> None) } AlterTableAddPartition( visitTableIdentifier(ctx.tableIdentifier), @@ -568,6 +577,7 @@ class SparkSqlAstBuilder extends AstBuilder { * For example: * {{{ * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; + * ALTER VIEW view DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...]; * }}} */ override def visitDropTablePartitions( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 6c2a67f81c..cd7e0eed65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -195,16 +195,19 @@ case class DropFunction( isTemp: Boolean)(sql: String) extends NativeDDLCommand(sql) with Logging +/** Rename in ALTER TABLE/VIEW: change the name of a table/view to a different name. */ case class AlterTableRename( oldName: TableIdentifier, newName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging +/** Set Properties in ALTER TABLE/VIEW: add metadata to a table/view. */ case class AlterTableSetProperties( tableName: TableIdentifier, properties: Map[String, String])(sql: String) extends NativeDDLCommand(sql) with Logging +/** Unset Properties in ALTER TABLE/VIEW: remove metadata from a table/view. */ case class AlterTableUnsetProperties( tableName: TableIdentifier, properties: Map[String, String], @@ -253,6 +256,12 @@ case class AlterTableSkewedLocation( skewedMap: Map[String, String])(sql: String) extends NativeDDLCommand(sql) with Logging +/** + * Add Partition in ALTER TABLE/VIEW: add the table/view partitions. + * 'partitionSpecsAndLocs': the syntax of ALTER VIEW is identical to ALTER TABLE, + * EXCEPT that it is ILLEGAL to specify a LOCATION clause. + * An error message will be issued if the partition exists, unless 'ifNotExists' is true. + */ case class AlterTableAddPartition( tableName: TableIdentifier, partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])], @@ -271,6 +280,14 @@ case class AlterTableExchangePartition( spec: TablePartitionSpec)(sql: String) extends NativeDDLCommand(sql) with Logging +/** + * Drop Partition in ALTER TABLE/VIEW: to drop a particular partition for a table/view. + * This removes the data and metadata for this partition. + * The data is actually moved to the .Trash/Current directory if Trash is configured, + * unless 'purge' is true, but the metadata is completely lost. + * An error message will be issued if the partition does not exist, unless 'ifExists' is true. + * Note: purge is always false when the target is a view. 
+ */ case class AlterTableDropPartition( tableName: TableIdentifier, specs: Seq[TablePartitionSpec], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index ccbfd41cca..cebf9c856d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -195,33 +195,60 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed4, expected4) } - test("alter table: rename table") { - val sql = "ALTER TABLE table_name RENAME TO new_table_name" - val parsed = parser.parsePlan(sql) - val expected = AlterTableRename( + // ALTER TABLE table_name RENAME TO new_table_name; + // ALTER VIEW view_name RENAME TO new_view_name; + test("alter table/view: rename table/view") { + val sql_table = "ALTER TABLE table_name RENAME TO new_table_name" + val sql_view = sql_table.replace("TABLE", "VIEW") + val parsed_table = parser.parsePlan(sql_table) + val parsed_view = parser.parsePlan(sql_view) + val expected_table = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None))(sql) - comparePlans(parsed, expected) + TableIdentifier("new_table_name", None))(sql_table) + val expected_view = AlterTableRename( + TableIdentifier("table_name", None), + TableIdentifier("new_table_name", None))(sql_view) + comparePlans(parsed_table, expected_table) + comparePlans(parsed_view, expected_view) } - test("alter table: alter table properties") { - val sql1 = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + + // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + test("alter table/view: alter table/view properties") { + val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + "'comment' = 'new_comment')" - val sql2 = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" - val sql3 = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) + val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" + val sql3_table = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" + val sql1_view = sql1_table.replace("TABLE", "VIEW") + val sql2_view = sql2_table.replace("TABLE", "VIEW") + val sql3_view = sql3_table.replace("TABLE", "VIEW") + + val parsed1_table = parser.parsePlan(sql1_table) + val parsed2_table = parser.parsePlan(sql2_table) + val parsed3_table = parser.parsePlan(sql3_table) + val parsed1_view = parser.parsePlan(sql1_view) + val parsed2_view = parser.parsePlan(sql2_view) + val parsed3_view = parser.parsePlan(sql3_view) + val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSetProperties( - tableIdent, Map("test" -> "test", "comment" -> "new_comment"))(sql1) - val expected2 = AlterTableUnsetProperties( - tableIdent, Map("comment" -> null, "test" -> null), ifExists = false)(sql2) - val expected3 = AlterTableUnsetProperties( - tableIdent, Map("comment" -> null, "test" -> null), ifExists = true)(sql3) - comparePlans(parsed1, expected1) - 
comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) + val expected1_table = AlterTableSetProperties( + tableIdent, Map("test" -> "test", "comment" -> "new_comment"))(sql1_table) + val expected2_table = AlterTableUnsetProperties( + tableIdent, Map("comment" -> null, "test" -> null), ifExists = false)(sql2_table) + val expected3_table = AlterTableUnsetProperties( + tableIdent, Map("comment" -> null, "test" -> null), ifExists = true)(sql3_table) + val expected1_view = expected1_table.copy()(sql = sql1_view) + val expected2_view = expected2_table.copy()(sql = sql2_view) + val expected3_view = expected3_table.copy()(sql = sql3_view) + + comparePlans(parsed1_table, expected1_table) + comparePlans(parsed2_table, expected2_table) + comparePlans(parsed3_table, expected3_table) + comparePlans(parsed1_view, expected1_view) + comparePlans(parsed2_view, expected2_view) + comparePlans(parsed3_view, expected3_view) } test("alter table: SerDe properties") { @@ -376,21 +403,66 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION partition_spec + // [LOCATION 'location1'] partition_spec [LOCATION 'location2'] ...; test("alter table: add partition") { - val sql = + val sql1 = """ |ALTER TABLE table_name ADD IF NOT EXISTS PARTITION |(dt='2008-08-08', country='us') LOCATION 'location1' PARTITION |(dt='2009-09-09', country='uk') """.stripMargin - val parsed = parser.parsePlan(sql) - val expected = AlterTableAddPartition( + val sql2 = "ALTER TABLE table_name ADD PARTITION (dt='2008-08-08') LOCATION 'loc'" + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + + val expected1 = AlterTableAddPartition( TableIdentifier("table_name", None), Seq( (Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")), (Map("dt" -> "2009-09-09", "country" -> "uk"), None)), - ifNotExists = true)(sql) - comparePlans(parsed, expected) + ifNotExists = true)(sql1) + val expected2 = AlterTableAddPartition( + TableIdentifier("table_name", None), + Seq((Map("dt" -> "2008-08-08"), Some("loc"))), + ifNotExists = false)(sql2) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + + // ALTER VIEW view_name ADD [IF NOT EXISTS] PARTITION partition_spec PARTITION partition_spec ...; + test("alter view: add partition") { + val sql1 = + """ + |ALTER VIEW view_name ADD IF NOT EXISTS PARTITION + |(dt='2008-08-08', country='us') PARTITION + |(dt='2009-09-09', country='uk') + """.stripMargin + // different constant types in partitioning spec + val sql2 = + """ + |ALTER VIEW view_name ADD PARTITION + |(col1=NULL, cOL2='f', col3=5, COL4=true) + """.stripMargin + + val parsed1 = parser.parsePlan(sql1) + val parsed2 = parser.parsePlan(sql2) + + val expected1 = AlterTableAddPartition( + TableIdentifier("view_name", None), + Seq( + (Map("dt" -> "2008-08-08", "country" -> "us"), None), + (Map("dt" -> "2009-09-09", "country" -> "uk"), None)), + ifNotExists = true)(sql1) + val expected2 = AlterTableAddPartition( + TableIdentifier("view_name", None), + Seq((Map("col1" -> "NULL", "col2" -> "f", "col3" -> "5", "col4" -> "true"), None)), + ifNotExists = false)(sql2) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) } test("alter table: rename partition") { @@ -421,36 +493,63 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed, expected) } - test("alter table: drop partitions") { - val sql1 = + // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] 
[PURGE] + // ALTER VIEW table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] + test("alter table/view: drop partitions") { + val sql1_table = """ |ALTER TABLE table_name DROP IF EXISTS PARTITION |(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') """.stripMargin - val sql2 = + val sql2_table = """ |ALTER TABLE table_name DROP PARTITION |(dt='2008-08-08', country='us'), PARTITION (dt='2009-09-09', country='uk') PURGE """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) + val sql1_view = sql1_table.replace("TABLE", "VIEW") + // Note: ALTER VIEW DROP PARTITION does not support PURGE + val sql2_view = sql2_table.replace("TABLE", "VIEW").replace("PURGE", "") + + val parsed1_table = parser.parsePlan(sql1_table) + val parsed2_table = parser.parsePlan(sql2_table) + val parsed1_view = parser.parsePlan(sql1_view) + val parsed2_view = parser.parsePlan(sql2_view) + val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableDropPartition( + val expected1_table = AlterTableDropPartition( tableIdent, Seq( Map("dt" -> "2008-08-08", "country" -> "us"), Map("dt" -> "2009-09-09", "country" -> "uk")), ifExists = true, - purge = false)(sql1) - val expected2 = AlterTableDropPartition( + purge = false)(sql1_table) + val expected2_table = AlterTableDropPartition( tableIdent, Seq( Map("dt" -> "2008-08-08", "country" -> "us"), Map("dt" -> "2009-09-09", "country" -> "uk")), ifExists = false, - purge = true)(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + purge = true)(sql2_table) + + val expected1_view = AlterTableDropPartition( + tableIdent, + Seq( + Map("dt" -> "2008-08-08", "country" -> "us"), + Map("dt" -> "2009-09-09", "country" -> "uk")), + ifExists = true, + purge = false)(sql1_view) + val expected2_view = AlterTableDropPartition( + tableIdent, + Seq( + Map("dt" -> "2008-08-08", "country" -> "us"), + Map("dt" -> "2009-09-09", "country" -> "uk")), + ifExists = false, + purge = false)(sql2_table) + + comparePlans(parsed1_table, expected1_table) + comparePlans(parsed2_table, expected2_table) + comparePlans(parsed1_view, expected1_view) + comparePlans(parsed2_view, expected2_view) } test("alter table: archive partition") { -- cgit v1.2.3 From 0b04f8fdf1614308cb3e7e0c7282f7365cc3d1bb Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 1 Apr 2016 18:27:11 +0200 Subject: [SPARK-14184][SQL] Support native execution of SHOW DATABASE command and fix SHOW TABLE to use table identifier pattern ## What changes were proposed in this pull request? This PR addresses the following 1. Supports native execution of SHOW DATABASES command 2. Fixes SHOW TABLES to apply the identifier_with_wildcards pattern if supplied. SHOW TABLE syntax ``` SHOW TABLES [IN database_name] ['identifier_with_wildcards']; ``` SHOW DATABASES syntax ``` SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; ``` ## How was this patch tested? Tests added in SQLQuerySuite (both hive and sql contexts) and DDLCommandSuite Note: Since the table name pattern was not working , tests are added in both SQLQuerySuite to verify the application of the table pattern. Author: Dilip Biswal Closes #11991 from dilipbiswal/dkb_show_database. 
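A minimal usage sketch of the new syntax; the database and table names below are hypothetical and mirror the fixtures in the tests added in this patch (SHOW DATABASES requires LIKE before a pattern, while SHOW TABLES accepts the pattern with or without LIKE):
```SQL
-- List databases, optionally filtered by a wildcard pattern.
SHOW DATABASES;
SHOW DATABASES LIKE 'showdb*';

-- List tables in the current or a named database, optionally filtered by a pattern.
SHOW TABLES LIKE 'show1*';
SHOW TABLES IN default 'show1*|show2*';
```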
--- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 6 +- .../scala/org/apache/spark/sql/SQLContext.scala | 4 +- .../spark/sql/execution/SparkSqlParser.scala | 22 +++++-- .../spark/sql/execution/command/commands.scala | 42 ++++++++++--- .../sql/execution/command/DDLCommandSuite.scala | 11 ++++ .../spark/sql/execution/command/DDLSuite.scala | 72 ++++++++++++++++++++++ .../spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 22 +++++++ 8 files changed, 163 insertions(+), 18 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index a857e670da..5513bbdc7f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -114,7 +114,8 @@ statement | DROP TEMPORARY? FUNCTION (IF EXISTS)? qualifiedName #dropFunction | EXPLAIN explainOption* statement #explain | SHOW TABLES ((FROM | IN) db=identifier)? - (LIKE (qualifiedName | pattern=STRING))? #showTables + (LIKE? pattern=STRING)? #showTables + | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? @@ -618,7 +619,7 @@ number ; nonReserved - : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS + : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES | ADD | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | MAP | ARRAY | STRUCT | LATERAL | WINDOW | REDUCE | TRANSFORM | USING | SERDE | SERDEPROPERTIES | RECORDREADER @@ -836,6 +837,7 @@ OUTPUTFORMAT: 'OUTPUTFORMAT'; INPUTDRIVER: 'INPUTDRIVER'; OUTPUTDRIVER: 'OUTPUTDRIVER'; DATABASE: 'DATABASE' | 'SCHEMA'; +DATABASES: 'DATABASES' | 'SCHEMAS'; DFS: 'DFS'; TRUNCATE: 'TRUNCATE'; METADATA: 'METADATA'; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 0576a1a178..221782ee8f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -781,7 +781,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(): DataFrame = { - Dataset.ofRows(this, ShowTablesCommand(None)) + Dataset.ofRows(this, ShowTablesCommand(None, None)) } /** @@ -793,7 +793,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(databaseName: String): DataFrame = { - Dataset.ofRows(this, ShowTablesCommand(Some(databaseName))) + Dataset.ofRows(this, ShowTablesCommand(Some(databaseName), None)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 16a899e01f..7efe98dd18 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -70,12 +70,26 @@ class SparkSqlAstBuilder extends AstBuilder { /** * Create a [[ShowTablesCommand]] logical plan. 
+ * Example SQL : + * {{{ + * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; + * }}} */ override def visitShowTables(ctx: ShowTablesContext): LogicalPlan = withOrigin(ctx) { - if (ctx.LIKE != null) { - logWarning("SHOW TABLES LIKE option is ignored.") - } - ShowTablesCommand(Option(ctx.db).map(_.getText)) + ShowTablesCommand( + Option(ctx.db).map(_.getText), + Option(ctx.pattern).map(string)) + } + + /** + * Create a [[ShowDatabasesCommand]] logical plan. + * Example SQL: + * {{{ + * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; + * }}} + */ + override def visitShowDatabases(ctx: ShowDatabasesContext): LogicalPlan = withOrigin(ctx) { + ShowDatabasesCommand(Option(ctx.pattern).map(string)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 964f0a7a7b..f90d8717ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -322,18 +322,17 @@ case class DescribeCommand( * If a databaseName is not given, the current database will be used. * The syntax of using this command in SQL is: * {{{ - * SHOW TABLES [IN databaseName] + * SHOW TABLES [(IN|FROM) database_name] [[LIKE] 'identifier_with_wildcards']; * }}} */ -case class ShowTablesCommand(databaseName: Option[String]) extends RunnableCommand { +case class ShowTablesCommand( + databaseName: Option[String], + tableIdentifierPattern: Option[String]) extends RunnableCommand { // The result of SHOW TABLES has two columns, tableName and isTemporary. override val output: Seq[Attribute] = { - val schema = StructType( - StructField("tableName", StringType, false) :: - StructField("isTemporary", BooleanType, false) :: Nil) - - schema.toAttributes + AttributeReference("tableName", StringType, nullable = false)() :: + AttributeReference("isTemporary", BooleanType, nullable = false)() :: Nil } override def run(sqlContext: SQLContext): Seq[Row] = { @@ -341,11 +340,36 @@ case class ShowTablesCommand(databaseName: Option[String]) extends RunnableComma // instead of calling tables in sqlContext. val catalog = sqlContext.sessionState.catalog val db = databaseName.getOrElse(catalog.getCurrentDatabase) - val rows = catalog.listTables(db).map { t => + val tables = + tableIdentifierPattern.map(catalog.listTables(db, _)).getOrElse(catalog.listTables(db)) + tables.map { t => val isTemp = t.database.isEmpty Row(t.table, isTemp) } - rows + } +} + +/** + * A command for users to list the databases/schemas. + * If a databasePattern is supplied then the databases that only matches the + * pattern would be listed. 
+ * The syntax of using this command in SQL is: + * {{{ + * SHOW (DATABASES|SCHEMAS) [LIKE 'identifier_with_wildcards']; + * }}} + */ +case class ShowDatabasesCommand(databasePattern: Option[String]) extends RunnableCommand { + + // The result of SHOW DATABASES has one column called 'result' + override val output: Seq[Attribute] = { + AttributeReference("result", StringType, nullable = false)() :: Nil + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val databases = + databasePattern.map(catalog.listDatabases(_)).getOrElse(catalog.listDatabases()) + databases.map { d => Row(d) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index cebf9c856d..458f36e832 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -762,4 +762,15 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + test("show databases") { + val sql1 = "SHOW DATABASES" + val sql2 = "SHOW DATABASES LIKE 'defau*'" + val parsed1 = parser.parsePlan(sql1) + val expected1 = ShowDatabasesCommand(None) + val parsed2 = parser.parsePlan(sql2) + val expected2 = ShowDatabasesCommand(Some("defau*")) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } + } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index f148f2d4ea..885a04af59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -45,6 +45,7 @@ class DDLSuite extends QueryTest with SharedSQLContext { dbNames.foreach { name => sqlContext.sql(s"DROP DATABASE IF EXISTS $name CASCADE") } + sqlContext.sessionState.catalog.setCurrentDatabase("default") } } @@ -159,4 +160,75 @@ class DDLSuite extends QueryTest with SharedSQLContext { } // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext + + test("show tables") { + withTempTable("show1a", "show2b") { + sql( + """ + |CREATE TEMPORARY TABLE show1a + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + | + |) + """.stripMargin) + sql( + """ + |CREATE TEMPORARY TABLE show2b + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + checkAnswer( + sql("SHOW TABLES IN default 'show1*'"), + Row("show1a", true) :: Nil) + + checkAnswer( + sql("SHOW TABLES IN default 'show1*|show2*'"), + Row("show1a", true) :: + Row("show2b", true) :: Nil) + + checkAnswer( + sql("SHOW TABLES 'show1*|show2*'"), + Row("show1a", true) :: + Row("show2b", true) :: Nil) + + assert( + sql("SHOW TABLES").count() >= 2) + assert( + sql("SHOW TABLES IN default").count() >= 2) + } + } + + test("show databases") { + withDatabase("showdb1A", "showdb2B") { + sql("CREATE DATABASE showdb1A") + sql("CREATE DATABASE showdb2B") + + assert( + sql("SHOW DATABASES").count() >= 2) + + checkAnswer( + sql("SHOW DATABASES LIKE '*db1A'"), + Row("showdb1A") :: Nil) + + checkAnswer( + sql("SHOW DATABASES LIKE 'showdb1A'"), + Row("showdb1A") :: Nil) + + checkAnswer( + sql("SHOW DATABASES LIKE 
'*db1A|*db2B'"), + Row("showdb1A") :: + Row("showdb2B") :: Nil) + + checkAnswer( + sql("SHOW DATABASES LIKE 'non-existentdb'"), + Nil) + } + } } diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 7ad7f92bd2..e93b0c145f 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -177,7 +177,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { } test("Single command with -e") { - runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "OK") + runCliWithin(2.minute, Seq("-e", "SHOW DATABASES;"))("" -> "") } test("Single command with --database") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6199253d34..c203518fdd 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1811,4 +1811,26 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } + + test("show tables") { + withTable("show1a", "show2b") { + sql("CREATE TABLE show1a(c1 int)") + sql("CREATE TABLE show2b(c2 int)") + checkAnswer( + sql("SHOW TABLES IN default 'show1*'"), + Row("show1a", false) :: Nil) + checkAnswer( + sql("SHOW TABLES IN default 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + checkAnswer( + sql("SHOW TABLES 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + assert( + sql("SHOW TABLES").count() >= 2) + assert( + sql("SHOW TABLES IN default").count() >= 2) + } + } } -- cgit v1.2.3 From a471c7f9eaa59d55dfff5b9d1a858f304a6b3a84 Mon Sep 17 00:00:00 2001 From: sureshthalamati Date: Fri, 1 Apr 2016 18:33:31 +0200 Subject: [SPARK-14133][SQL] Throws exception for unsupported create/drop/alter index , and lock/unlock operations. ## What changes were proposed in this pull request? This PR throws Unsupported Operation exception for create index, drop index, alter index , lock table , lock database, unlock table, and unlock database operations that are not supported in Spark SQL. Currently these operations are executed executed by Hive. Error: spark-sql> drop index my_index on my_table; Error in query: Unsupported operation: drop index(line 1, pos 0) ## How was this patch tested? Added test cases to HiveQuerySuite yhuai hvanhovell andrewor14 Author: sureshthalamati Closes #12069 from sureshthalamati/unsupported_ddl_spark-14133. --- .../org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 12 ++++++++++-- .../spark/sql/hive/execution/HiveCompatibilitySuite.scala | 10 ++++++---- .../apache/spark/sql/hive/execution/HiveQuerySuite.scala | 15 +++++++++++++++ 3 files changed, 31 insertions(+), 6 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 5513bbdc7f..d1747b9915 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -146,7 +146,7 @@ hiveNativeCommands | ROLLBACK WORK? 
| SHOW PARTITIONS tableIdentifier partitionSpec? | DFS .*? - | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOCK | UNLOCK | MSCK | LOAD) .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | MSCK | LOAD) .*? ; unsupportedHiveNativeCommands @@ -166,6 +166,13 @@ unsupportedHiveNativeCommands | kw1=SHOW kw2=TRANSACTIONS | kw1=SHOW kw2=INDEXES | kw1=SHOW kw2=LOCKS + | kw1=CREATE kw2=INDEX + | kw1=DROP kw2=INDEX + | kw1=ALTER kw2=INDEX + | kw1=LOCK kw2=TABLE + | kw1=LOCK kw2=DATABASE + | kw1=UNLOCK kw2=TABLE + | kw1=UNLOCK kw2=DATABASE ; createTableHeader @@ -640,7 +647,7 @@ nonReserved | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE - | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEXES | LOCKS | OPTION + | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION ; SELECT: 'SELECT'; @@ -861,6 +868,7 @@ ROLES: 'ROLES'; COMPACTIONS: 'COMPACTIONS'; PRINCIPALS: 'PRINCIPALS'; TRANSACTIONS: 'TRANSACTIONS'; +INDEX: 'INDEX'; INDEXES: 'INDEXES'; LOCKS: 'LOCKS'; OPTION: 'OPTION'; diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index bedbf9ae17..695b5ef733 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -352,7 +352,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "show_create_table_index", "show_create_table_partitioned", "show_create_table_serde", - "show_create_table_view" + "show_create_table_view", + + // Index commands are not supported + "drop_index", + "drop_index_removes_partition_dirs", + "alter_index" ) /** @@ -369,7 +374,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter3", "alter4", "alter5", - "alter_index", "alter_merge_2", "alter_partition_format_loc", "alter_partition_with_whitelist", @@ -496,8 +500,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "distinct_stats", "drop_database_removes_partition_dirs", "drop_function", - "drop_index", - "drop_index_removes_partition_dirs", "drop_multi_partitions", "drop_partitions_filter", "drop_partitions_filter2", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 79774f5913..58259060bf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1280,6 +1280,21 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { assertUnsupportedFeature { sql("SHOW LOCKS my_table") } } + test("lock/unlock table and database commands are not supported") { + assertUnsupportedFeature { sql("LOCK TABLE my_table SHARED") } + assertUnsupportedFeature { sql("UNLOCK TABLE my_table") } + assertUnsupportedFeature { sql("LOCK DATABASE my_db SHARED") } + assertUnsupportedFeature { sql("UNLOCK DATABASE my_db") } + } + + test("create/drop/alter index commands are not supported") { + assertUnsupportedFeature { + 
sql("CREATE INDEX my_index ON TABLE my_table(a) as 'COMPACT' WITH DEFERRED REBUILD")} + assertUnsupportedFeature { sql("DROP INDEX my_index ON my_table") } + assertUnsupportedFeature { sql("ALTER INDEX my_index ON my_table REBUILD")} + assertUnsupportedFeature { + sql("ALTER INDEX my_index ON my_table set IDXPROPERTIES (\"prop1\"=\"val1_new\")")} + } } // for SPARK-2180 test -- cgit v1.2.3 From df68beb85de59bb6d35b2a8a3b85dbc447798bf5 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 Apr 2016 13:00:55 -0700 Subject: [SPARK-13995][SQL] Extract correct IsNotNull constraints for Expression ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-13995 We infer relative `IsNotNull` constraints from logical plan's expressions in `constructIsNotNullConstraints` now. However, we don't consider the case of (nested) `Cast`. For example: val tr = LocalRelation('a.int, 'b.long) val plan = tr.where('a.attr === 'b.attr).analyze Then, the plan's constraints will have `IsNotNull(Cast(resolveColumn(tr, "a"), LongType))`, instead of `IsNotNull(resolveColumn(tr, "a"))`. This PR fixes it. Besides, as `IsNotNull` constraints are most useful for `Attribute`, we should do recursing through any `Expression` that is null intolerant and construct `IsNotNull` constraints for all `Attribute`s under these Expressions. For example, consider the following constraints: val df = Seq((1,2,3)).toDF("a", "b", "c") df.where("a + b = c").queryExecution.analyzed.constraints The inferred isnotnull constraints should be isnotnull(a), isnotnull(b), isnotnull(c), instead of isnotnull(a + c) and isnotnull(c). ## How was this patch tested? Test is added into `ConstraintPropagationSuite`. Author: Liang-Chi Hsieh Closes #11809 from viirya/constraint-cast. --- .../spark/sql/catalyst/expressions/Cast.scala | 2 +- .../sql/catalyst/expressions/arithmetic.scala | 25 ++++--- .../catalyst/expressions/namedExpressions.scala | 2 +- .../spark/sql/catalyst/expressions/package.scala | 7 ++ .../sql/catalyst/expressions/predicates.scala | 17 +++-- .../spark/sql/catalyst/plans/QueryPlan.scala | 33 ++++----- .../plans/ConstraintPropagationSuite.scala | 85 +++++++++++++++++++++- 7 files changed, 134 insertions(+), 37 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a965cc8d53..d842ffdc66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -112,7 +112,7 @@ object Cast { } /** Cast the child expression to the target data type. 
*/ -case class Cast(child: Expression, dataType: DataType) extends UnaryExpression { +case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant { override def toString: String = s"cast($child as ${dataType.simpleString})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 1e9c971800..b388091538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -24,7 +24,8 @@ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval -case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnaryMinus(child: Expression) extends UnaryExpression + with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -58,7 +59,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInp override def sql: String = s"(-${child.sql})" } -case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class UnaryPositive(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def prettyName: String = "positive" override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval) @@ -79,7 +81,8 @@ case class UnaryPositive(child: Expression) extends UnaryExpression with Expects @ExpressionDescription( usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", extended = "> SELECT _FUNC_('-1');\n1") -case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class Abs(child: Expression) + extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def inputTypes: Seq[AbstractDataType] = Seq(NumericType) @@ -123,7 +126,7 @@ private[sql] object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } -case class Add(left: Expression, right: Expression) extends BinaryArithmetic { +case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -152,7 +155,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic { } } -case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic { +case class Subtract(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -181,7 +185,8 @@ case class Subtract(left: Expression, right: Expression) extends BinaryArithmeti } } -case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic { +case class Multiply(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -193,7 +198,8 @@ case class Multiply(left: Expression, right: Expression) extends BinaryArithmeti protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } -case class Divide(left: Expression, right: Expression) extends BinaryArithmetic { +case class Divide(left: Expression, right: Expression) + extends 
BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -269,7 +275,8 @@ case class Divide(left: Expression, right: Expression) extends BinaryArithmetic } } -case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic { +case class Remainder(left: Expression, right: Expression) + extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = NumericType @@ -457,7 +464,7 @@ case class MinOf(left: Expression, right: Expression) override def symbol: String = "min" } -case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic { +case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index a5b5758167..262582ca5d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -97,7 +97,7 @@ trait NamedExpression extends Expression { } } -abstract class Attribute extends LeafExpression with NamedExpression { +abstract class Attribute extends LeafExpression with NamedExpression with NullIntolerant { override def references: AttributeSet = AttributeSet(this) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala index f1fa13daa7..23baa6f783 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala @@ -92,4 +92,11 @@ package object expressions { StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable))) } } + + /** + * When an expression inherits this, meaning the expression is null intolerant (i.e. any null + * input will result in null output). We will use this information during constructing IsNotNull + * constraints. 
+ */ + trait NullIntolerant } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index e23ad5596b..4eb33258ac 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -90,7 +90,7 @@ trait PredicateHelper { case class Not(child: Expression) - extends UnaryExpression with Predicate with ImplicitCastInputTypes { + extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant { override def toString: String = s"NOT $child" @@ -402,7 +402,8 @@ private[sql] object Equality { } -case class EqualTo(left: Expression, right: Expression) extends BinaryComparison { +case class EqualTo(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = AnyDataType @@ -467,7 +468,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } -case class LessThan(left: Expression, right: Expression) extends BinaryComparison { +case class LessThan(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -479,7 +481,8 @@ case class LessThan(left: Expression, right: Expression) extends BinaryCompariso } -case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { +case class LessThanOrEqual(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -491,7 +494,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryCo } -case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison { +case class GreaterThan(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered @@ -503,7 +507,8 @@ case class GreaterThan(left: Expression, right: Expression) extends BinaryCompar } -case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison { +case class GreaterThanOrEqual(left: Expression, right: Expression) + extends BinaryComparison with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.Ordered diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index d31164fe94..22a4461e66 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -44,25 +44,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT * returns a constraint of the form `isNotNull(a)` */ private def constructIsNotNullConstraints(constraints: Set[Expression]): Set[Expression] = { - var isNotNullConstraints = Set.empty[Expression] - - // First, we propagate constraints if the condition consists of equality and ranges. 
For all - // other cases, we return an empty set of constraints - constraints.foreach { - case EqualTo(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case GreaterThan(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case GreaterThanOrEqual(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case LessThan(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case LessThanOrEqual(l, r) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case Not(EqualTo(l, r)) => - isNotNullConstraints ++= Set(IsNotNull(l), IsNotNull(r)) - case _ => // No inference - } + // First, we propagate constraints from the null intolerant expressions. + var isNotNullConstraints: Set[Expression] = + constraints.flatMap(scanNullIntolerantExpr).map(IsNotNull(_)) // Second, we infer additional constraints from non-nullable attributes that are part of the // operator's output @@ -72,6 +56,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT isNotNullConstraints -- constraints } + /** + * Recursively explores the expressions which are null intolerant and returns all attributes + * in these expressions. + */ + private def scanNullIntolerantExpr(expr: Expression): Seq[Attribute] = expr match { + case a: Attribute => Seq(a) + case _: NullIntolerant | IsNotNull(_: NullIntolerant) => + expr.children.flatMap(scanNullIntolerantExpr) + case _ => Seq.empty[Attribute] + } + /** * Infers an additional set of constraints from a given set of equality constraints. * For e.g., if an operator has constraints of the form (`a = 5`, `a = b`), this returns an diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index e5063599a3..5cbb889f8e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.types.{IntegerType, StringType} +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType} class ConstraintPropagationSuite extends SparkFunSuite { @@ -219,6 +219,89 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(tr, "b"))))) } + test("infer constraints on cast") { + val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) + verifyConstraints( + tr.where('a.attr === 'b.attr && + 'c.attr + 100 > 'd.attr && + IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))).analyze.constraints, + ExpressionSet(Seq(Cast(resolveColumn(tr, "a"), LongType) === resolveColumn(tr, "b"), + Cast(resolveColumn(tr, "c") + 100, LongType) > resolveColumn(tr, "d"), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e")), + IsNotNull(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType))))) + } + + test("infer isnotnull constraints from compound expressions") { + val tr = LocalRelation('a.int, 'b.long, 'c.int, 'd.long, 'e.int) + verifyConstraints( + tr.where('a.attr + 'b.attr === 'c.attr && + 
IsNotNull( + Cast( + Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b") === + Cast(resolveColumn(tr, "c"), LongType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "e")), + IsNotNull(Cast(Cast(Cast(resolveColumn(tr, "e"), LongType), LongType), LongType))))) + + verifyConstraints( + tr.where(('a.attr * 'b.attr + 100) === 'c.attr && 'd / 10 === 'e).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") + Cast(100, LongType) === + Cast(resolveColumn(tr, "c"), LongType), + Cast(resolveColumn(tr, "d"), DoubleType) / + Cast(Cast(10, LongType), DoubleType) === + Cast(resolveColumn(tr, "e"), DoubleType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + verifyConstraints( + tr.where(('a.attr * 'b.attr - 10) >= 'c.attr && 'd / 10 < 'e).analyze.constraints, + ExpressionSet(Seq( + Cast(resolveColumn(tr, "a"), LongType) * resolveColumn(tr, "b") - Cast(10, LongType) >= + Cast(resolveColumn(tr, "c"), LongType), + Cast(resolveColumn(tr, "d"), DoubleType) / + Cast(Cast(10, LongType), DoubleType) < + Cast(resolveColumn(tr, "e"), DoubleType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + verifyConstraints( + tr.where('a.attr + 'b.attr - 'c.attr * 'd.attr > 'e.attr * 1000).analyze.constraints, + ExpressionSet(Seq( + (Cast(resolveColumn(tr, "a"), LongType) + resolveColumn(tr, "b")) - + (Cast(resolveColumn(tr, "c"), LongType) * resolveColumn(tr, "d")) > + Cast(resolveColumn(tr, "e") * 1000, LongType), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "b")), + IsNotNull(resolveColumn(tr, "c")), + IsNotNull(resolveColumn(tr, "d")), + IsNotNull(resolveColumn(tr, "e"))))) + + // The constraint IsNotNull(IsNotNull(expr)) doesn't guarantee expr is not null. + verifyConstraints( + tr.where('a.attr === 'c.attr && + IsNotNull(IsNotNull(resolveColumn(tr, "b")))).analyze.constraints, + ExpressionSet(Seq( + resolveColumn(tr, "a") === resolveColumn(tr, "c"), + IsNotNull(IsNotNull(resolveColumn(tr, "b"))), + IsNotNull(resolveColumn(tr, "a")), + IsNotNull(resolveColumn(tr, "c"))))) + } + test("infer IsNotNull constraints from non-nullable attributes") { val tr = LocalRelation('a.int, AttributeReference("b", IntegerType, nullable = false)(), AttributeReference("c", StringType, nullable = false)()) -- cgit v1.2.3 From a884daad805a701494e87393dc307937472a985d Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Fri, 1 Apr 2016 13:03:27 -0700 Subject: [SPARK-14191][SQL] Remove invalid Expand operator constraints `Expand` operator now uses its child plan's constraints as its valid constraints (i.e., the base of constraints). This is not correct because `Expand` will set its group by attributes to null values. So the nullability of these attributes should be true. 
E.g., for an `Expand` operator like: val input = LocalRelation('a.int, 'b.int, 'c.int).where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2) Expand( Seq( Seq('c, Literal.create(null, StringType), 1), Seq('c, 'a, 2)), Seq('c, 'a, 'gid.int), Project(Seq('a, 'c), input)) The `Project` operator has the constraints `IsNotNull('a)`, `IsNotNull('b)` and `IsNotNull('c)`. But the `Expand` should not have `IsNotNull('a)` in its constraints. This PR is the first step for this issue and remove invalid constraints of `Expand` operator. A test is added to `ConstraintPropagationSuite`. Author: Liang-Chi Hsieh Author: Michael Armbrust Closes #11995 from viirya/fix-expand-constraints. --- .../catalyst/plans/logical/basicOperators.scala | 5 +++- .../plans/ConstraintPropagationSuite.scala | 27 ++++++++++++++++++++++ 2 files changed, 31 insertions(+), 1 deletion(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index 09c200fa83..a18efc90ab 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -519,7 +519,6 @@ case class Expand( projections: Seq[Seq[Expression]], output: Seq[Attribute], child: LogicalPlan) extends UnaryNode { - override def references: AttributeSet = AttributeSet(projections.flatten.flatMap(_.references)) @@ -527,6 +526,10 @@ case class Expand( val sizeInBytes = super.statistics.sizeInBytes * projections.length Statistics(sizeInBytes = sizeInBytes) } + + // This operator can reuse attributes (for example making them null when doing a roll up) so + // the contraints of the child may no longer be valid. + override protected def validConstraints: Set[Expression] = Set.empty[Expression] } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 5cbb889f8e..49c1353efb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -88,6 +88,33 @@ class ConstraintPropagationSuite extends SparkFunSuite { IsNotNull(resolveColumn(aliasedRelation.analyze, "a"))))) } + test("propagating constraints in expand") { + val tr = LocalRelation('a.int, 'b.int, 'c.int) + + assert(tr.analyze.constraints.isEmpty) + + // We add IsNotNull constraints for 'a, 'b and 'c into LocalRelation + // by creating notNullRelation. 
+ val notNullRelation = tr.where('c.attr > 10 && 'a.attr < 5 && 'b.attr > 2) + verifyConstraints(notNullRelation.analyze.constraints, + ExpressionSet(Seq(resolveColumn(notNullRelation.analyze, "c") > 10, + IsNotNull(resolveColumn(notNullRelation.analyze, "c")), + resolveColumn(notNullRelation.analyze, "a") < 5, + IsNotNull(resolveColumn(notNullRelation.analyze, "a")), + resolveColumn(notNullRelation.analyze, "b") > 2, + IsNotNull(resolveColumn(notNullRelation.analyze, "b"))))) + + val expand = Expand( + Seq( + Seq('c, Literal.create(null, StringType), 1), + Seq('c, 'a, 2)), + Seq('c, 'a, 'gid.int), + Project(Seq('a, 'c), + notNullRelation)) + verifyConstraints(expand.analyze.constraints, + ExpressionSet(Seq.empty[Expression])) + } + test("propagating constraints in aliases") { val tr = LocalRelation('a.int, 'b.string, 'c.int) -- cgit v1.2.3 From 1b829ce13990b40fd8d7c9efcc2ae55c4dbc861c Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Fri, 1 Apr 2016 13:19:24 -0700 Subject: [SPARK-14160] Time Windowing functions for Datasets MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? This PR adds the function `window` as a column expression. `window` can be used to bucket rows into time windows given a time column. With this expression, performing time series analysis on batch data, as well as streaming data should become much more simpler. ### Usage Assume the following schema: `sensor_id, measurement, timestamp` To average 5 minute data every 1 minute (window length of 5 minutes, slide duration of 1 minute), we will use: ```scala df.groupBy(window("timestamp", “5 minutes”, “1 minute”), "sensor_id") .agg(mean("measurement").as("avg_meas")) ``` This will generate windows such as: ``` 09:00:00-09:05:00 09:01:00-09:06:00 09:02:00-09:07:00 ... ``` Intervals will start at every `slideDuration` starting at the unix epoch (1970-01-01 00:00:00 UTC). To start intervals at a different point of time, e.g. 30 seconds after a minute, the `startTime` parameter can be used. ```scala df.groupBy(window("timestamp", “5 minutes”, “1 minute”, "30 second"), "sensor_id") .agg(mean("measurement").as("avg_meas")) ``` This will generate windows such as: ``` 09:00:30-09:05:30 09:01:30-09:06:30 09:02:30-09:07:30 ... ``` Support for Python will be made in a follow up PR after this. ## How was this patch tested? This patch has some basic unit tests for the `TimeWindow` expression testing that the parameters pass validation, and it also has some unit/integration tests testing the correctness of the windowing and usability in complex operations (multi-column grouping, multi-column projections, joins). Author: Burak Yavuz Author: Michael Armbrust Closes #12008 from brkyvz/df-time-window. 
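For reference, tumbling (non-overlapping) windows only need the single-duration overload added in this patch. The sketch below is illustrative rather than taken from the patch; `df`, `time`, `sensor_id` and `measurement` are assumed names, and it presumes `org.apache.spark.sql.functions._` plus the session implicits for `$` are in scope.

```scala
// Tumbling windows: the slide duration equals the window duration, so each row falls in
// exactly one bucket. The result is a struct column named "window" with `start` and `end`
// timestamp fields that can be selected directly.
df.groupBy(window($"time", "1 hour"), $"sensor_id")
  .agg(mean("measurement").as("avg_meas"))
  .select($"window.start", $"window.end", $"sensor_id", $"avg_meas")
```
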
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 90 ++++++++ .../sql/catalyst/analysis/FunctionRegistry.scala | 1 + .../sql/catalyst/expressions/TimeWindow.scala | 133 +++++++++++ .../sql/catalyst/analysis/AnalysisErrorSuite.scala | 56 +++++ .../sql/catalyst/expressions/TimeWindowSuite.scala | 76 +++++++ .../scala/org/apache/spark/sql/functions.scala | 137 ++++++++++++ .../spark/sql/DataFrameTimeWindowingSuite.scala | 242 +++++++++++++++++++++ 7 files changed, 735 insertions(+) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 8dc0532b3f..d82ee3a205 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -102,6 +102,7 @@ class Analyzer( ExtractWindowExpressions :: GlobalAggregates :: ResolveAggregateFunctions :: + TimeWindowing :: HiveTypeCoercion.typeCoercionRules ++ extendedResolutionRules : _*), Batch("Nondeterministic", Once, @@ -1591,3 +1592,92 @@ object ResolveUpCast extends Rule[LogicalPlan] { } } } + +/** + * Maps a time column to multiple time windows using the Expand operator. Since it's non-trivial to + * figure out how many windows a time column can map to, we over-estimate the number of windows and + * filter out the rows where the time column is not inside the time window. + */ +object TimeWindowing extends Rule[LogicalPlan] { + import org.apache.spark.sql.catalyst.dsl.expressions._ + + private final val WINDOW_START = "start" + private final val WINDOW_END = "end" + + /** + * Generates the logical plan for generating window ranges on a timestamp column. Without + * knowing what the timestamp value is, it's non-trivial to figure out deterministically how many + * window ranges a timestamp will map to given all possible combinations of a window duration, + * slide duration and start time (offset). Therefore, we express and over-estimate the number of + * windows there may be, and filter the valid windows. We use last Project operator to group + * the window columns into a struct so they can be accessed as `window.start` and `window.end`. + * + * The windows are calculated as below: + * maxNumOverlapping <- ceil(windowDuration / slideDuration) + * for (i <- 0 until maxNumOverlapping) + * windowId <- ceil((timestamp - startTime) / slideDuration) + * windowStart <- windowId * slideDuration + (i - maxNumOverlapping) * slideDuration + startTime + * windowEnd <- windowStart + windowDuration + * return windowStart, windowEnd + * + * This behaves as follows for the given parameters for the time: 12:05. The valid windows are + * marked with a +, and invalid ones are marked with a x. The invalid ones are filtered using the + * Filter operator. 
+ * window: 12m, slide: 5m, start: 0m :: window: 12m, slide: 5m, start: 2m + * 11:55 - 12:07 + 11:52 - 12:04 x + * 12:00 - 12:12 + 11:57 - 12:09 + + * 12:05 - 12:17 + 12:02 - 12:14 + + * + * @param plan The logical plan + * @return the logical plan that will generate the time windows using the Expand operator, with + * the Filter operator for correctness and Project for usability. + */ + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p: LogicalPlan if p.children.size == 1 => + val child = p.children.head + val windowExpressions = + p.expressions.flatMap(_.collect { case t: TimeWindow => t }).distinct.toList // Not correct. + + // Only support a single window expression for now + if (windowExpressions.size == 1 && + windowExpressions.head.timeColumn.resolved && + windowExpressions.head.checkInputDataTypes().isSuccess) { + val window = windowExpressions.head + val windowAttr = AttributeReference("window", window.dataType)() + + val maxNumOverlapping = math.ceil(window.windowDuration * 1.0 / window.slideDuration).toInt + val windows = Seq.tabulate(maxNumOverlapping + 1) { i => + val windowId = Ceil((PreciseTimestamp(window.timeColumn) - window.startTime) / + window.slideDuration) + val windowStart = (windowId + i - maxNumOverlapping) * + window.slideDuration + window.startTime + val windowEnd = windowStart + window.windowDuration + + CreateNamedStruct( + Literal(WINDOW_START) :: windowStart :: + Literal(WINDOW_END) :: windowEnd :: Nil) + } + + val projections = windows.map(_ +: p.children.head.output) + + val filterExpr = + window.timeColumn >= windowAttr.getField(WINDOW_START) && + window.timeColumn < windowAttr.getField(WINDOW_END) + + val expandedPlan = + Filter(filterExpr, + Expand(projections, windowAttr +: child.output, child)) + + val substitutedPlan = p transformExpressions { + case t: TimeWindow => windowAttr + } + + substitutedPlan.withNewChildren(expandedPlan :: Nil) + } else if (windowExpressions.size > 1) { + p.failAnalysis("Multiple time window expressions would result in a cartesian product " + + "of rows, therefore they are not currently not supported.") + } else { + p // Return unchanged. Analyzer will throw exception later + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index e9788b7e4d..ca8db3cbc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -297,6 +297,7 @@ object FunctionRegistry { expression[UnixTimestamp]("unix_timestamp"), expression[WeekOfYear]("weekofyear"), expression[Year]("year"), + expression[TimeWindow]("window"), // collection functions expression[ArrayContains]("array_contains"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala new file mode 100644 index 0000000000..8e13833486 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.commons.lang.StringUtils + +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult +import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.CalendarInterval + +case class TimeWindow( + timeColumn: Expression, + windowDuration: Long, + slideDuration: Long, + startTime: Long) extends UnaryExpression + with ImplicitCastInputTypes + with Unevaluable + with NonSQLExpression { + + override def child: Expression = timeColumn + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = new StructType() + .add(StructField("start", TimestampType)) + .add(StructField("end", TimestampType)) + + // This expression is replaced in the analyzer. + override lazy val resolved = false + + /** + * Validate the inputs for the window duration, slide duration, and start time in addition to + * the input data type. + */ + override def checkInputDataTypes(): TypeCheckResult = { + val dataTypeCheck = super.checkInputDataTypes() + if (dataTypeCheck.isSuccess) { + if (windowDuration <= 0) { + return TypeCheckFailure(s"The window duration ($windowDuration) must be greater than 0.") + } + if (slideDuration <= 0) { + return TypeCheckFailure(s"The slide duration ($slideDuration) must be greater than 0.") + } + if (startTime < 0) { + return TypeCheckFailure(s"The start time ($startTime) must be greater than or equal to 0.") + } + if (slideDuration > windowDuration) { + return TypeCheckFailure(s"The slide duration ($slideDuration) must be less than or equal" + + s" to the windowDuration ($windowDuration).") + } + if (startTime >= slideDuration) { + return TypeCheckFailure(s"The start time ($startTime) must be less than the " + + s"slideDuration ($slideDuration).") + } + } + dataTypeCheck + } +} + +object TimeWindow { + /** + * Parses the interval string for a valid time duration. CalendarInterval expects interval + * strings to start with the string `interval`. For usability, we prepend `interval` to the string + * if the user omitted it. + * + * @param interval The interval string + * @return The interval duration in microseconds. SparkSQL casts TimestampType has microsecond + * precision. 
+ */ + private def getIntervalInMicroSeconds(interval: String): Long = { + if (StringUtils.isBlank(interval)) { + throw new IllegalArgumentException( + "The window duration, slide duration and start time cannot be null or blank.") + } + val intervalString = if (interval.startsWith("interval")) { + interval + } else { + "interval " + interval + } + val cal = CalendarInterval.fromString(intervalString) + if (cal == null) { + throw new IllegalArgumentException( + s"The provided interval ($interval) did not correspond to a valid interval string.") + } + if (cal.months > 0) { + throw new IllegalArgumentException( + s"Intervals greater than a month is not supported ($interval).") + } + cal.microseconds + } + + def apply( + timeColumn: Expression, + windowDuration: String, + slideDuration: String, + startTime: String): TimeWindow = { + TimeWindow(timeColumn, + getIntervalInMicroSeconds(windowDuration), + getIntervalInMicroSeconds(slideDuration), + getIntervalInMicroSeconds(startTime)) + } +} + +/** + * Expression used internally to convert the TimestampType to Long without losing + * precision, i.e. in microseconds. Used in time windowing. + */ +case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes { + override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) + override def dataType: DataType = LongType + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val eval = child.gen(ctx) + eval.code + + s"""boolean ${ev.isNull} = ${eval.isNull}; + |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; + """.stripMargin + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index a90dfc5039..ad101d1c40 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -272,6 +272,62 @@ class AnalysisErrorSuite extends AnalysisTest { testRelation2.where('bad_column > 1).groupBy('a)(UnresolvedAlias(max('b))), "cannot resolve '`bad_column`'" :: Nil) + errorTest( + "slide duration greater than window in time window", + testRelation2.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "2 second", "0 second").as("window")), + s"The slide duration " :: " must be less than or equal to the windowDuration " :: Nil + ) + + errorTest( + "start time greater than slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 minute").as("window")), + "The start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "start time equal to slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "1 second").as("window")), + "The start time " :: " must be less than the slideDuration " :: Nil + ) + + errorTest( + "negative window duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "-1 second", "1 second", "0 second").as("window")), + "The window duration " :: " must be greater than 0." :: Nil + ) + + errorTest( + "zero window duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "0 second", "1 second", "0 second").as("window")), + "The window duration " :: " must be greater than 0." 
:: Nil + ) + + errorTest( + "negative slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "-1 second", "0 second").as("window")), + "The slide duration " :: " must be greater than 0." :: Nil + ) + + errorTest( + "zero slide duration in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "0 second", "0 second").as("window")), + "The slide duration" :: " must be greater than 0." :: Nil + ) + + errorTest( + "negative start time in time window", + testRelation.select( + TimeWindow(Literal("2016-01-01 01:01:01"), "1 second", "1 second", "-5 second").as("window")), + "The start time" :: "must be greater than or equal to 0." :: Nil + ) + test("SPARK-6452 regression test") { // CheckAnalysis should throw AnalysisException when Aggregate contains missing attribute(s) // Since we manually construct the logical plan at here and Sum only accept diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala new file mode 100644 index 0000000000..71f969aee2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.expressions + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.AnalysisException + +class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper { + + test("time window is unevaluable") { + intercept[UnsupportedOperationException] { + evaluate(TimeWindow(Literal(10L), "1 second", "1 second", "0 second")) + } + } + + private def checkErrorMessage(msg: String, value: String): Unit = { + val validDuration = "10 second" + val validTime = "5 second" + val e1 = intercept[IllegalArgumentException] { + TimeWindow(Literal(10L), value, validDuration, validTime).windowDuration + } + val e2 = intercept[IllegalArgumentException] { + TimeWindow(Literal(10L), validDuration, value, validTime).slideDuration + } + val e3 = intercept[IllegalArgumentException] { + TimeWindow(Literal(10L), validDuration, validDuration, value).startTime + } + Seq(e1, e2, e3).foreach { e => + e.getMessage.contains(msg) + } + } + + test("blank intervals throw exception") { + for (blank <- Seq(null, " ", "\n", "\t")) { + checkErrorMessage( + "The window duration, slide duration and start time cannot be null or blank.", blank) + } + } + + test("invalid intervals throw exception") { + checkErrorMessage( + "did not correspond to a valid interval string.", "2 apples") + } + + test("intervals greater than a month throws exception") { + checkErrorMessage( + "Intervals greater than or equal to a month is not supported (1 month).", "1 month") + } + + test("interval strings work with and without 'interval' prefix and return microseconds") { + val validDuration = "10 second" + for ((text, seconds) <- Seq( + ("1 second", 1000000), // 1e6 + ("1 minute", 60000000), // 6e7 + ("2 hours", 7200000000L))) { // 72e9 + assert(TimeWindow(Literal(10L), text, validDuration, "0 seconds").windowDuration === seconds) + assert(TimeWindow(Literal(10L), "interval " + text, validDuration, "0 seconds").windowDuration + === seconds) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 7ce15e3f35..74906050ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2550,6 +2550,143 @@ object functions { ToUTCTimestamp(ts.expr, Literal(tz)) } + /** + * Bucketize rows into one or more time windows given a timestamp specifying column. Window + * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + * the order of months are not supported. The following example takes the average stock price for + * a one minute window every 10 seconds starting 5 seconds after the hour: + * + * {{{ + * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType + * df.groupBy(window($"time", "1 minute", "10 seconds", "5 seconds"), $"stockId") + * .agg(mean("price")) + * }}} + * + * The windows will look like: + * + * {{{ + * 09:00:05-09:01:05 + * 09:00:15-09:01:15 + * 09:00:25-09:01:25 ... + * }}} + * + * For a continuous query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time can be as TimestampType or LongType, however when using LongType, + * the time must be given in seconds. 
+ * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, + * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for + * valid duration identifiers. + * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. + * A new window will be generated every `slideDuration`. Must be less than + * or equal to the `windowDuration`. Check + * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration + * identifiers. + * @param startTime The offset with respect to 1970-01-01 00:00:00 UTC with which to start + * window intervals. For example, in order to have hourly tumbling windows that + * start 15 minutes past the hour, e.g. 12:15-13:15, 13:15-14:15... provide + * `startTime` as `15 minutes`. + * + * @group datetime_funcs + * @since 2.0.0 + */ + @Experimental + def window( + timeColumn: Column, + windowDuration: String, + slideDuration: String, + startTime: String): Column = { + withExpr { + TimeWindow(timeColumn.expr, windowDuration, slideDuration, startTime) + }.as("window") + } + + + /** + * Bucketize rows into one or more time windows given a timestamp specifying column. Window + * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. + * The following example takes the average stock price for a one minute window every 10 seconds: + * + * {{{ + * val df = ... // schema => timestamp: TimestampType, stockId: StringType, price: DoubleType + * df.groupBy(window($"time", "1 minute", "10 seconds"), $"stockId") + * .agg(mean("price")) + * }}} + * + * The windows will look like: + * + * {{{ + * 09:00:00-09:01:00 + * 09:00:10-09:01:10 + * 09:00:20-09:01:20 ... + * }}} + * + * For a continuous query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time can be as TimestampType or LongType, however when using LongType, + * the time must be given in seconds. + * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, + * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for + * valid duration identifiers. + * @param slideDuration A string specifying the sliding interval of the window, e.g. `1 minute`. + * A new window will be generated every `slideDuration`. Must be less than + * or equal to the `windowDuration`. Check + * [[org.apache.spark.unsafe.types.CalendarInterval]] for valid duration. + * + * @group datetime_funcs + * @since 2.0.0 + */ + @Experimental + def window(timeColumn: Column, windowDuration: String, slideDuration: String): Column = { + window(timeColumn, windowDuration, slideDuration, "0 second") + } + + /** + * Generates tumbling time windows given a timestamp specifying column. Window + * starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + * [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + * the order of months are not supported. The windows start beginning at 1970-01-01 00:00:00 UTC. + * The following example takes the average stock price for a one minute tumbling window: + * + * {{{ + * val df = ... 
// schema => timestamp: TimestampType, stockId: StringType, price: DoubleType + * df.groupBy(window($"time", "1 minute"), $"stockId") + * .agg(mean("price")) + * }}} + * + * The windows will look like: + * + * {{{ + * 09:00:00-09:01:00 + * 09:01:00-09:02:00 + * 09:02:00-09:03:00 ... + * }}} + * + * For a continuous query, you may use the function `current_timestamp` to generate windows on + * processing time. + * + * @param timeColumn The column or the expression to use as the timestamp for windowing by time. + * The time can be as TimestampType or LongType, however when using LongType, + * the time must be given in seconds. + * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, + * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for + * valid duration identifiers. + * + * @group datetime_funcs + * @since 2.0.0 + */ + @Experimental + def window(timeColumn: Column, windowDuration: String): Column = { + window(timeColumn, windowDuration, windowDuration, "0 second") + } + ////////////////////////////////////////////////////////////////////////////////////////////// // Collection functions ////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala new file mode 100644 index 0000000000..e8103a31d5 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -0,0 +1,242 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import java.util.TimeZone + +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.StringType + +class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { + + import testImplicits._ + + override def beforeEach(): Unit = { + super.beforeEach() + TimeZone.setDefault(TimeZone.getTimeZone("UTC")) + } + + override def afterEach(): Unit = { + super.beforeEach() + TimeZone.setDefault(null) + } + + test("tumbling window groupBy statement") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + checkAnswer( + df.groupBy(window($"time", "10 seconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1)) + ) + } + + test("tumbling window groupBy statement with startTime") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"id") + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select("counts"), + Seq(Row(1), Row(1), Row(1))) + } + + test("tumbling window with multi-column projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"value"), + Seq( + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2) + ) + ) + } + + test("sliding window grouping") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.groupBy(window($"time", "10 seconds", "3 seconds", "0 second")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast("string"), $"window.end".cast("string"), $"counts"), + // 2016-03-27 19:39:27 UTC -> 4 bins + // 2016-03-27 19:39:34 UTC -> 3 bins + // 2016-03-27 19:39:56 UTC -> 3 bins + Seq( + Row("2016-03-27 19:39:18", "2016-03-27 19:39:28", 1), + Row("2016-03-27 19:39:21", "2016-03-27 19:39:31", 1), + Row("2016-03-27 19:39:24", "2016-03-27 19:39:34", 1), + Row("2016-03-27 19:39:27", "2016-03-27 19:39:37", 2), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:33", "2016-03-27 19:39:43", 1), + Row("2016-03-27 19:39:48", "2016-03-27 19:39:58", 1), + Row("2016-03-27 19:39:51", "2016-03-27 19:40:01", 1), + Row("2016-03-27 19:39:54", "2016-03-27 19:40:04", 1)) + ) + } + + test("sliding window projection") { + val df = Seq( + ("2016-03-27 19:39:34", 1, "a"), + ("2016-03-27 19:39:56", 2, "a"), + ("2016-03-27 19:39:27", 4, "b")).toDF("time", "value", "id") + + checkAnswer( + df.select(window($"time", "10 seconds", "3 seconds", "0 second"), $"value") + .orderBy($"window.start".asc, $"value".desc).select("value"), + // 2016-03-27 19:39:27 UTC -> 4 bins + // 2016-03-27 19:39:34 UTC -> 3 bins + // 2016-03-27 19:39:56 UTC -> 3 bins + Seq(Row(4), Row(4), Row(4), Row(4), Row(1), Row(1), 
Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("windowing combined with explode expression") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(window($"time", "10 seconds"), $"value", explode($"ids")) + .orderBy($"window.start".asc).select("value"), + // first window exploded to two rows for "a", and "b", second window exploded to 3 rows + Seq(Row(1), Row(1), Row(2), Row(2), Row(2)) + ) + } + + test("null timestamps") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + checkDataset( + df.select(window($"time", "10 seconds"), $"value") + .orderBy($"window.start".asc) + .select("value") + .as[Int], + 1, 2) // null columns are dropped + } + + test("time window joins") { + val df = Seq( + ("2016-03-27 09:00:05", 1), + ("2016-03-27 09:00:32", 2), + (null, 3), + (null, 4)).toDF("time", "value") + + val df2 = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "othervalue") + + checkAnswer( + df.select(window($"time", "10 seconds"), $"value").join( + df2.select(window($"time", "10 seconds"), $"othervalue"), Seq("window")) + .groupBy("window") + .agg((sum("value") + sum("othervalue")).as("total")) + .orderBy($"window.start".asc).select("total"), + Seq(Row(4), Row(8))) + } + + test("negative timestamps") { + val df4 = Seq( + ("1970-01-01 00:00:02", 1), + ("1970-01-01 00:00:12", 2)).toDF("time", "value") + checkAnswer( + df4.select(window($"time", "10 seconds", "10 seconds", "5 seconds"), $"value") + .orderBy($"window.start".asc) + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("1969-12-31 23:59:55", "1970-01-01 00:00:05", 1), + Row("1970-01-01 00:00:05", "1970-01-01 00:00:15", 2)) + ) + } + + test("multiple time windows in a single operator throws nice exception") { + val df = Seq( + ("2016-03-27 09:00:02", 3), + ("2016-03-27 09:00:35", 6)).toDF("time", "value") + val e = intercept[AnalysisException] { + df.select(window($"time", "10 second"), window($"time", "15 second")).collect() + } + assert(e.getMessage.contains( + "Multiple time window expressions would result in a cartesian product")) + } + + test("aliased windows") { + val df = Seq( + ("2016-03-27 19:39:34", 1, Seq("a", "b")), + ("2016-03-27 19:39:56", 2, Seq("a", "c", "d"))).toDF("time", "value", "ids") + + checkAnswer( + df.select(window($"time", "10 seconds").as("time_window"), $"value") + .orderBy($"time_window.start".asc) + .select("value"), + Seq(Row(1), Row(2)) + ) + } + + test("millisecond precision sliding windows") { + val df = Seq( + ("2016-03-27 09:00:00.41", 3), + ("2016-03-27 09:00:00.62", 6), + ("2016-03-27 09:00:00.715", 8)).toDF("time", "value") + checkAnswer( + df.groupBy(window($"time", "200 milliseconds", "40 milliseconds", "0 milliseconds")) + .agg(count("*").as("counts")) + .orderBy($"window.start".asc) + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"counts"), + Seq( + Row("2016-03-27 09:00:00.24", "2016-03-27 09:00:00.44", 1), + Row("2016-03-27 09:00:00.28", "2016-03-27 09:00:00.48", 1), + Row("2016-03-27 09:00:00.32", "2016-03-27 09:00:00.52", 1), + Row("2016-03-27 09:00:00.36", "2016-03-27 09:00:00.56", 1), + Row("2016-03-27 09:00:00.4", "2016-03-27 09:00:00.6", 1), + Row("2016-03-27 09:00:00.44", "2016-03-27 09:00:00.64", 1), + Row("2016-03-27 09:00:00.48", "2016-03-27 09:00:00.68", 1), + Row("2016-03-27 
09:00:00.52", "2016-03-27 09:00:00.72", 2), + Row("2016-03-27 09:00:00.56", "2016-03-27 09:00:00.76", 2), + Row("2016-03-27 09:00:00.6", "2016-03-27 09:00:00.8", 2), + Row("2016-03-27 09:00:00.64", "2016-03-27 09:00:00.84", 1), + Row("2016-03-27 09:00:00.68", "2016-03-27 09:00:00.88", 1)) + ) + } +} -- cgit v1.2.3 From 0fc4aaa71c5b4531b3a7c8ac71d62ea8e66b6f0c Mon Sep 17 00:00:00 2001 From: Michael Armbrust Date: Fri, 1 Apr 2016 15:15:16 -0700 Subject: [SPARK-14255][SQL] Streaming Aggregation This PR adds the ability to perform aggregations inside of a `ContinuousQuery`. In order to implement this feature, the planning of aggregation has augmented with a new `StatefulAggregationStrategy`. Unlike batch aggregation, stateful-aggregation uses the `StateStore` (introduced in #11645) to persist the results of partial aggregation across different invocations. The resulting physical plan performs the aggregation using the following progression: - Partial Aggregation - Shuffle - Partial Merge (now there is at most 1 tuple per group) - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) - Partial Merge (now there is at most 1 tuple per group) - StateStoreSave (saves the tuple for the next batch) - Complete (output the current result of the aggregation) The following refactoring was also performed to allow us to plug into existing code: - The get/put implementation is taken from #12013 - The logic for breaking down and de-duping the physical execution of aggregation has been move into a new pattern `PhysicalAggregation` - The `AttributeReference` used to identify the result of an `AggregateFunction` as been moved into the `AggregateExpression` container. This change moves the reference into the same object as the other intermediate references used in aggregation and eliminates the need to pass around a `Map[(AggregateFunction, Boolean), Attribute]`. Further clean up (using a different aggregation container for logical/physical plans) is deferred to a followup. - Some planning logic is moved from the `SessionState` into the `QueryExecution` to make it easier to override in the streaming case. - The ability to write a `StreamTest` that checks only the output of the last batch has been added to simulate the future addition of output modes. Author: Michael Armbrust Closes #12048 from marmbrus/statefulAgg. 
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 9 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../apache/spark/sql/catalyst/errors/package.scala | 7 +- .../expressions/aggregate/interfaces.scala | 37 ++++- .../catalyst/expressions/namedExpressions.scala | 2 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 14 +- .../spark/sql/catalyst/planning/patterns.scala | 73 ++++++++++ .../apache/spark/sql/catalyst/plans/PlanTest.scala | 3 + .../spark/sql/execution/QueryExecution.scala | 24 +++- .../org/apache/spark/sql/execution/SparkPlan.scala | 7 + .../apache/spark/sql/execution/SparkPlanner.scala | 4 +- .../spark/sql/execution/SparkStrategies.scala | 92 ++++--------- .../org/apache/spark/sql/execution/Window.scala | 2 +- .../aggregate/TungstenAggregationIterator.scala | 4 +- .../spark/sql/execution/aggregate/utils.scala | 121 ++++++++++++---- .../execution/streaming/IncrementalExecution.scala | 72 ++++++++++ .../execution/streaming/StatefulAggregate.scala | 119 ++++++++++++++++ .../sql/execution/streaming/StreamExecution.scala | 12 +- .../spark/sql/execution/streaming/memory.scala | 4 +- .../state/HDFSBackedStateStoreProvider.scala | 36 +++-- .../sql/execution/streaming/state/StateStore.scala | 19 ++- .../execution/streaming/state/StateStoreConf.scala | 4 +- .../execution/streaming/state/StateStoreRDD.scala | 17 +-- .../sql/execution/streaming/state/package.scala | 21 ++- .../org/apache/spark/sql/execution/subquery.scala | 11 +- .../apache/spark/sql/expressions/Aggregator.scala | 4 +- .../apache/spark/sql/internal/SessionState.scala | 16 +-- .../scala/org/apache/spark/sql/StreamTest.scala | 36 +++-- .../apache/spark/sql/execution/SparkPlanTest.scala | 10 +- .../streaming/state/StateStoreRDDSuite.scala | 152 ++++++++++++--------- .../streaming/state/StateStoreSuite.scala | 61 +++++---- .../sql/streaming/StreamingAggregationSuite.scala | 132 ++++++++++++++++++ .../apache/spark/sql/hive/HiveSessionState.scala | 5 +- 33 files changed, 827 insertions(+), 305 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index d82ee3a205..05e2b9a447 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -336,6 +336,11 @@ class Analyzer( Last(ifExpr(expr), Literal(true)) case a: AggregateFunction => a.withNewChildren(a.children.map(ifExpr)) + }.transform { + // We are duplicating aggregates that are now computing a different value for each + // pivot value. + // TODO: Don't construct the physical container until after analysis. 
+ case ae: AggregateExpression => ae.copy(resultId = NamedExpression.newExprId) } if (filteredAggregate.fastEquals(aggregate)) { throw new AnalysisException( @@ -1153,11 +1158,11 @@ class Analyzer( // Extract Windowed AggregateExpression case we @ WindowExpression( - AggregateExpression(function, mode, isDistinct), + ae @ AggregateExpression(function, _, _, _), spec: WindowSpecDefinition) => val newChildren = function.children.map(extractExpr) val newFunction = function.withNewChildren(newChildren).asInstanceOf[AggregateFunction] - val newAgg = AggregateExpression(newFunction, mode, isDistinct) + val newAgg = ae.copy(aggregateFunction = newFunction) seenWindowAggregates += newAgg WindowExpression(newAgg, spec) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 1d1e892e32..4880502398 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -76,7 +76,7 @@ trait CheckAnalysis { case g: GroupingID => failAnalysis(s"grouping_id() can only be used with GroupingSets/Cube/Rollup") - case w @ WindowExpression(AggregateExpression(_, _, true), _) => + case w @ WindowExpression(AggregateExpression(_, _, true, _), _) => failAnalysis(s"Distinct window functions are not supported: $w") case w @ WindowExpression(_: OffsetWindowFunction, WindowSpecDefinition(_, order, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala index 0d44d1dd96..0420b4b538 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/errors/package.scala @@ -25,15 +25,18 @@ import org.apache.spark.sql.catalyst.trees.TreeNode package object errors { class TreeNodeException[TreeType <: TreeNode[_]]( - tree: TreeType, msg: String, cause: Throwable) + @transient val tree: TreeType, + msg: String, + cause: Throwable) extends Exception(msg, cause) { + val treeString = tree.toString + // Yes, this is the same as a default parameter, but... those don't seem to work with SBT // external project dependencies for some reason. 
def this(tree: TreeType, msg: String) = this(tree, msg, null) override def getMessage: String = { - val treeString = tree.toString s"${super.getMessage}, tree:${if (treeString contains "\n") "\n" else " "}$tree" } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index ff3064ac66..d31ccf9985 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ @@ -66,6 +67,19 @@ private[sql] case object NoOp extends Expression with Unevaluable { override def children: Seq[Expression] = Nil } +object AggregateExpression { + def apply( + aggregateFunction: AggregateFunction, + mode: AggregateMode, + isDistinct: Boolean): AggregateExpression = { + AggregateExpression( + aggregateFunction, + mode, + isDistinct, + NamedExpression.newExprId) + } +} + /** * A container for an [[AggregateFunction]] with its [[AggregateMode]] and a field * (`isDistinct`) indicating if DISTINCT keyword is specified for this function. @@ -73,10 +87,31 @@ private[sql] case object NoOp extends Expression with Unevaluable { private[sql] case class AggregateExpression( aggregateFunction: AggregateFunction, mode: AggregateMode, - isDistinct: Boolean) + isDistinct: Boolean, + resultId: ExprId) extends Expression with Unevaluable { + lazy val resultAttribute: Attribute = if (aggregateFunction.resolved) { + AttributeReference( + aggregateFunction.toString, + aggregateFunction.dataType, + aggregateFunction.nullable)(exprId = resultId) + } else { + // This is a bit of a hack. Really we should not be constructing this container and reasoning + // about datatypes / aggregation mode until after we have finished analysis and made it to + // planning. + UnresolvedAttribute(aggregateFunction.toString) + } + + // We compute the same thing regardless of our final result. 
+ override lazy val canonicalized: Expression = + AggregateExpression( + aggregateFunction.canonicalized.asInstanceOf[AggregateFunction], + mode, + isDistinct, + ExprId(0)) + override def children: Seq[Expression] = aggregateFunction :: Nil override def dataType: DataType = aggregateFunction.dataType override def foldable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 262582ca5d..2307122ea1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -329,7 +329,7 @@ case class PrettyAttribute( override def withName(newName: String): Attribute = throw new UnsupportedOperationException override def qualifier: Option[String] = throw new UnsupportedOperationException override def exprId: ExprId = throw new UnsupportedOperationException - override def nullable: Boolean = throw new UnsupportedOperationException + override def nullable: Boolean = true } object VirtualColumn { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a7a948ef1b..326933ec9e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -534,7 +534,7 @@ object NullPropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { - case e @ AggregateExpression(Count(exprs), _, _) if !exprs.exists(nonNullLiteral) => + case e @ AggregateExpression(Count(exprs), _, _, _) if !exprs.exists(nonNullLiteral) => Cast(Literal(0L), e.dataType) case e @ IsNull(c) if !c.nullable => Literal.create(false, BooleanType) case e @ IsNotNull(c) if !c.nullable => Literal.create(true, BooleanType) @@ -547,9 +547,9 @@ object NullPropagation extends Rule[LogicalPlan] { Literal.create(null, e.dataType) case e @ EqualNullSafe(Literal(null, _), r) => IsNull(r) case e @ EqualNullSafe(l, Literal(null, _)) => IsNull(l) - case e @ AggregateExpression(Count(exprs), mode, false) if !exprs.exists(_.nullable) => + case ae @ AggregateExpression(Count(exprs), _, false, _) if !exprs.exists(_.nullable) => // This rule should be only triggered when isDistinct field is false. - AggregateExpression(Count(Literal(1)), mode, isDistinct = false) + ae.copy(aggregateFunction = Count(Literal(1))) // For Coalesce, remove null literals. 
case e @ Coalesce(children) => @@ -1225,13 +1225,13 @@ object DecimalAggregates extends Rule[LogicalPlan] { private val MAX_DOUBLE_DIGITS = 15 def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _) if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(AggregateExpression(Sum(UnscaledValue(e)), mode, isDistinct), prec + 10, scale) + MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - case AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), mode, isDistinct) + case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _) if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr = AggregateExpression(Average(UnscaledValue(e)), mode, isDistinct) + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) Cast( Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), DecimalType(prec + 4, scale + 4)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 9c927077d0..28d2c445b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -22,6 +22,7 @@ import scala.collection.mutable import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.types.IntegerType @@ -216,3 +217,75 @@ object IntegerIndex { case _ => None } } + +/** + * An extractor used when planning the physical execution of an aggregation. Compared with a logical + * aggregation, the following transformations are performed: + * - Unnamed grouping expressions are named so that they can be referred to across phases of + * aggregation + * - Aggregations that appear multiple times are deduplicated. + * - The compution of the aggregations themselves is separated from the final result. For example, + * the `count` in `count + 1` will be split into an [[AggregateExpression]] and a final + * computation that computes `count.resultAttribute + 1`. + */ +object PhysicalAggregation { + // groupingExpressions, aggregateExpressions, resultExpressions, child + type ReturnType = + (Seq[NamedExpression], Seq[AggregateExpression], Seq[NamedExpression], LogicalPlan) + + def unapply(a: Any): Option[ReturnType] = a match { + case logical.Aggregate(groupingExpressions, resultExpressions, child) => + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. 
+ val aggregateExpressions = resultExpressions.flatMap { expr => + expr.collect { + case agg: AggregateExpression => agg + } + }.distinct + + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. + // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case ae: AggregateExpression => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + ae.resultAttribute + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + + Some(( + namedGroupingExpressions.map(_._2), + aggregateExpressions, + rewrittenResultExpressions, + child)) + + case _ => None + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index aa5d4330d3..7191936699 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan, OneRowRelation, Sample} import org.apache.spark.sql.catalyst.util._ @@ -38,6 +39,8 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { AttributeReference(a.name, a.dataType, a.nullable)(exprId = ExprId(0)) case a: Alias => Alias(a.child, a.name)(exprId = ExprId(0)) + case ae: AggregateExpression => + ae.copy(resultId = ExprId(0)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala index 912b84abc1..4843553211 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala @@ -21,6 +21,8 @@ import org.apache.spark.rdd.RDD import 
org.apache.spark.sql.{AnalysisException, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.exchange.{EnsureRequirements, ReuseExchange} /** * The primary workflow for executing relational queries using Spark. Designed to allow easy @@ -31,6 +33,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} */ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { + // TODO: Move the planner an optimizer into here from SessionState. + protected def planner = sqlContext.sessionState.planner + def assertAnalyzed(): Unit = try sqlContext.sessionState.analyzer.checkAnalysis(analyzed) catch { case e: AnalysisException => val ae = new AnalysisException(e.message, e.line, e.startPosition, Some(analyzed)) @@ -49,16 +54,31 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) { lazy val sparkPlan: SparkPlan = { SQLContext.setActive(sqlContext) - sqlContext.sessionState.planner.plan(ReturnAnswer(optimizedPlan)).next() + planner.plan(ReturnAnswer(optimizedPlan)).next() } // executedPlan should not be used to initialize any SparkPlan. It should be // only used for execution. - lazy val executedPlan: SparkPlan = sqlContext.sessionState.prepareForExecution.execute(sparkPlan) + lazy val executedPlan: SparkPlan = prepareForExecution(sparkPlan) /** Internal version of the RDD. Avoids copies and has no schema */ lazy val toRdd: RDD[InternalRow] = executedPlan.execute() + /** + * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal + * row format conversions as needed. + */ + protected def prepareForExecution(plan: SparkPlan): SparkPlan = { + preparations.foldLeft(plan) { case (sp, rule) => rule.apply(sp) } + } + + /** A sequence of rules that will be applied in order to the physical plan before execution. 
*/ + protected def preparations: Seq[Rule[SparkPlan]] = Seq( + PlanSubqueries(sqlContext), + EnsureRequirements(sqlContext.conf), + CollapseCodegenStages(sqlContext.conf), + ReuseExchange(sqlContext.conf)) + protected def stringOrError[A](f: => A): String = try f.toString catch { case e: Throwable => e.toString } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 010ed7f500..b1b3d4ac81 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -379,6 +379,13 @@ private[sql] trait LeafNode extends SparkPlan { override def producedAttributes: AttributeSet = outputSet } +object UnaryNode { + def unapply(a: Any): Option[(SparkPlan, SparkPlan)] = a match { + case s: SparkPlan if s.children.size == 1 => Some((s, s.children.head)) + case _ => None + } +} + private[sql] trait UnaryNode extends SparkPlan { def child: SparkPlan diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index 9da2c74c62..ac8072f3ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -26,13 +26,13 @@ import org.apache.spark.sql.internal.SQLConf class SparkPlanner( val sparkContext: SparkContext, val conf: SQLConf, - val experimentalMethods: ExperimentalMethods) + val extraStrategies: Seq[Strategy]) extends SparkStrategies { def numPartitions: Int = conf.numShufflePartitions def strategies: Seq[Strategy] = - experimentalMethods.extraStrategies ++ ( + extraStrategies ++ ( FileSourceStrategy :: DataSourceStrategy :: DDLStrategy :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 7a2e2b7382..5bcc172ca7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical.{BroadcastHint, LogicalPlan} @@ -203,29 +202,33 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } + /** + * Used to plan aggregation queries that are computed incrementally as part of a + * [[org.apache.spark.sql.ContinuousQuery]]. 
Currently this rule is injected into the planner + * on-demand, only when planning in a [[org.apache.spark.sql.execution.streaming.StreamExecution]] + */ + object StatefulAggregationStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case PhysicalAggregation( + namedGroupingExpressions, aggregateExpressions, rewrittenResultExpressions, child) => + + aggregate.Utils.planStreamingAggregation( + namedGroupingExpressions, + aggregateExpressions, + rewrittenResultExpressions, + planLater(child)) + + case _ => Nil + } + } + /** * Used to plan the aggregate operator for expressions based on the AggregateFunction2 interface. */ object Aggregation extends Strategy { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case logical.Aggregate(groupingExpressions, resultExpressions, child) => - // A single aggregate expression might appear multiple times in resultExpressions. - // In order to avoid evaluating an individual aggregate function multiple times, we'll - // build a set of the distinct aggregate expressions and build a function which can - // be used to re-write expressions so that they reference the single copy of the - // aggregate function which actually gets computed. - val aggregateExpressions = resultExpressions.flatMap { expr => - expr.collect { - case agg: AggregateExpression => agg - } - }.distinct - // For those distinct aggregate expressions, we create a map from the - // aggregate function to the corresponding attribute of the function. - val aggregateFunctionToAttribute = aggregateExpressions.map { agg => - val aggregateFunction = agg.aggregateFunction - val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> attribute - }.toMap + case PhysicalAggregation( + groupingExpressions, aggregateExpressions, resultExpressions, child) => val (functionsWithDistinct, functionsWithoutDistinct) = aggregateExpressions.partition(_.isDistinct) @@ -233,41 +236,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { // This is a sanity check. We should not reach here when we have multiple distinct // column sets. Our MultipleDistinctRewriter should take care this case. sys.error("You hit a query analyzer bug. Please report your query to " + - "Spark user mailing list.") - } - - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - - // The original `resultExpressions` are a set of expressions which may reference - // aggregate expressions, grouping column values, and constants. When aggregate operator - // emits output rows, we will use `resultExpressions` to generate an output projection - // which takes the grouping columns and final aggregate result buffer as input. 
- // Thus, we must re-write the result expressions so that their attributes match up with - // the attributes of the final result projection's input row: - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case AggregateExpression(aggregateFunction, _, isDistinct) => - // The final aggregation buffer's attributes will be `finalAggregationAttributes`, - // so replace each aggregate expression by its corresponding attribute in the set: - aggregateFunctionToAttribute(aggregateFunction, isDistinct) - case expression => - // Since we're using `namedGroupingAttributes` to extract the grouping key - // columns, we need to replace grouping key expressions with their corresponding - // attributes. We do not rely on the equality check at here since attributes may - // differ cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + "Spark user mailing list.") } val aggregateOperator = @@ -277,26 +246,23 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "aggregate functions which don't support partial aggregation.") } else { aggregate.Utils.planAggregateWithoutPartial( - namedGroupingExpressions.map(_._2), + groupingExpressions, aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } } else if (functionsWithDistinct.isEmpty) { aggregate.Utils.planAggregateWithoutDistinct( - namedGroupingExpressions.map(_._2), + groupingExpressions, aggregateExpressions, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } else { aggregate.Utils.planAggregateWithOneDistinct( - namedGroupingExpressions.map(_._2), + groupingExpressions, functionsWithDistinct, functionsWithoutDistinct, - aggregateFunctionToAttribute, - rewrittenResultExpressions, + resultExpressions, planLater(child)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 270c09aff3..7acf020b28 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -177,7 +177,7 @@ case class Window( case e @ WindowExpression(function, spec) => val frame = spec.frameSpecification.asInstanceOf[SpecifiedWindowFrame] function match { - case AggregateExpression(f, _, _) => collect("AGGREGATE", frame, e, f) + case AggregateExpression(f, _, _, _) => collect("AGGREGATE", frame, e, f) case f: AggregateWindowFunction => collect("AGGREGATE", frame, e, f) case f: OffsetWindowFunction => collect("OFFSET", frame, e, f) case f => sys.error(s"Unsupported window function: $f") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 213bca907b..ce504e20e6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -242,9 +242,9 @@ class TungstenAggregationIterator( // Basically the value of the KVIterator returned by externalSorter // will be just aggregation buffer, so we rewrite the aggregateExpressions to reflect it. 
val newExpressions = aggregateExpressions.map { - case agg @ AggregateExpression(_, Partial, _) => + case agg @ AggregateExpression(_, Partial, _, _) => agg.copy(mode = PartialMerge) - case agg @ AggregateExpression(_, Complete, _) => + case agg @ AggregateExpression(_, Complete, _, _) => agg.copy(mode = Final) case other => other } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index 1e113ccd4e..4682949fa1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.streaming.{StateStoreRestore, StateStoreSave} /** * Utility functions used by the query planner to convert our plan to new aggregation code path. @@ -29,15 +30,11 @@ object Utils { def planAggregateWithoutPartial( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = completeAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } - + val completeAggregateAttributes = completeAggregateExpressions.map(_.resultAttribute) SortBasedAggregate( requiredChildDistributionExpressions = Some(groupingExpressions), groupingExpressions = groupingExpressions, @@ -83,7 +80,6 @@ object Utils { def planAggregateWithoutDistinct( groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. @@ -111,9 +107,7 @@ object Utils { val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) val finalAggregate = createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes), @@ -131,7 +125,6 @@ object Utils { groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression], functionsWithoutDistinct: Seq[AggregateExpression], - aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -151,9 +144,7 @@ object Utils { // 1. Create an Aggregate Operator for partial aggregations. 
val partialAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) - val aggregateAttributes = aggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) // We will group by the original grouping expression, plus an additional expression for the // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping // expressions will be [key, value]. @@ -169,9 +160,7 @@ object Utils { // 2. Create an Aggregate Operator for partial merge aggregations. val partialMergeAggregate: SparkPlan = { val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) - val aggregateAttributes = aggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) createAggregate( requiredChildDistributionExpressions = Some(groupingAttributes ++ distinctAttributes), @@ -190,7 +179,7 @@ object Utils { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children // to AttributeReferences. - case agg @ AggregateExpression(aggregateFunction, mode, true) => + case agg @ AggregateExpression(aggregateFunction, mode, true, _) => aggregateFunction.transformDown(distinctColumnAttributeLookup) .asInstanceOf[AggregateFunction] } @@ -199,9 +188,7 @@ object Utils { val mergeAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val mergeAggregateAttributes = mergeAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val mergeAggregateAttributes = mergeAggregateExpressions.map(_.resultAttribute) val (distinctAggregateExpressions, distinctAggregateAttributes) = rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => // We rewrite the aggregate function to a non-distinct aggregation because @@ -211,7 +198,7 @@ object Utils { val expr = AggregateExpression(func, Partial, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute - val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + val attr = functionsWithDistinct(i).resultAttribute (expr, attr) }.unzip @@ -232,9 +219,7 @@ object Utils { val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) // The attributes of the final aggregation buffer, which is presented as input to the result // projection: - val finalAggregateAttributes = finalAggregateExpressions.map { - expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) - } + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) val (distinctAggregateExpressions, distinctAggregateAttributes) = rewrittenDistinctFunctions.zipWithIndex.map { case (func, i) => @@ -245,7 +230,7 @@ object Utils { val expr = AggregateExpression(func, Final, isDistinct = true) // Use original AggregationFunction to lookup attributes, which is used to build // aggregateFunctionToAttribute - val attr = aggregateFunctionToAttribute(functionsWithDistinct(i).aggregateFunction, true) + val attr = functionsWithDistinct(i).resultAttribute (expr, attr) }.unzip @@ -261,4 +246,90 @@ 
object Utils { finalAndCompleteAggregate :: Nil } + + /** + * Plans a streaming aggregation using the following progression: + * - Partial Aggregation + * - Shuffle + * - Partial Merge (now there is at most 1 tuple per group) + * - StateStoreRestore (now there is 1 tuple from this batch + optionally one from the previous) + * - PartialMerge (now there is at most 1 tuple per group) + * - StateStoreSave (saves the tuple for the next batch) + * - Complete (output the current result of the aggregation) + */ + def planStreamingAggregation( + groupingExpressions: Seq[NamedExpression], + functionsWithoutDistinct: Seq[AggregateExpression], + resultExpressions: Seq[NamedExpression], + child: SparkPlan): Seq[SparkPlan] = { + + val groupingAttributes = groupingExpressions.map(_.toAttribute) + + val partialAggregate: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Partial)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + // We will group by the original grouping expression, plus an additional expression for the + // DISTINCT column. For example, for AVG(DISTINCT value) GROUP BY key, the grouping + // expressions will be [key, value]. + createAggregate( + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = child) + } + + val partialMerged1: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = partialAggregate) + } + + val restored = StateStoreRestore(groupingAttributes, None, partialMerged1) + + val partialMerged2: SparkPlan = { + val aggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = PartialMerge)) + val aggregateAttributes = aggregateExpressions.map(_.resultAttribute) + createAggregate( + requiredChildDistributionExpressions = + Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = groupingAttributes ++ + aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes), + child = restored) + } + + val saved = StateStoreSave(groupingAttributes, None, partialMerged2) + + val finalAndCompleteAggregate: SparkPlan = { + val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map(_.resultAttribute) + + createAggregate( + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, + aggregateExpressions = finalAggregateExpressions, + aggregateAttributes = finalAggregateAttributes, + initialInputBufferOffset = groupingAttributes.length, + resultExpressions = 
resultExpressions, + child = saved) + } + + finalAndCompleteAggregate :: Nil + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala new file mode 100644 index 0000000000..aaced49dd1 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/IncrementalExecution.scala @@ -0,0 +1,72 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.{QueryExecution, SparkPlan, SparkPlanner, UnaryNode} + +/** + * A variant of [[QueryExecution]] that allows the execution of the given [[LogicalPlan]] + * plan incrementally. Possibly preserving state in between each execution. + */ +class IncrementalExecution( + ctx: SQLContext, + logicalPlan: LogicalPlan, + checkpointLocation: String, + currentBatchId: Long) extends QueryExecution(ctx, logicalPlan) { + + // TODO: make this always part of planning. + val stateStrategy = sqlContext.sessionState.planner.StatefulAggregationStrategy :: Nil + + // Modified planner with stateful operations. + override def planner: SparkPlanner = + new SparkPlanner( + sqlContext.sparkContext, + sqlContext.conf, + stateStrategy) + + /** + * Records the current id for a given stateful operator in the query plan as the `state` + * preperation walks the query plan. + */ + private var operatorId = 0 + + /** Locates save/restore pairs surrounding aggregation. */ + val state = new Rule[SparkPlan] { + override def apply(plan: SparkPlan): SparkPlan = plan transform { + case StateStoreSave(keys, None, + UnaryNode(agg, + StateStoreRestore(keys2, None, child))) => + val stateId = OperatorStateId(checkpointLocation, operatorId, currentBatchId - 1) + operatorId += 1 + + StateStoreSave( + keys, + Some(stateId), + agg.withNewChildren( + StateStoreRestore( + keys, + Some(stateId), + child) :: Nil)) + } + } + + override def preparations: Seq[Rule[SparkPlan]] = state +: super.preparations +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala new file mode 100644 index 0000000000..595774761c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StatefulAggregate.scala @@ -0,0 +1,119 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.errors._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection +import org.apache.spark.sql.execution +import org.apache.spark.sql.execution.streaming.state._ +import org.apache.spark.sql.execution.SparkPlan + +/** Used to identify the state store for a given operator. */ +case class OperatorStateId( + checkpointLocation: String, + operatorId: Long, + batchId: Long) + +/** + * An operator that saves or restores state from the [[StateStore]]. The [[OperatorStateId]] should + * be filled in by `prepareForExecution` in [[IncrementalExecution]]. + */ +trait StatefulOperator extends SparkPlan { + def stateId: Option[OperatorStateId] + + protected def getStateId: OperatorStateId = attachTree(this) { + stateId.getOrElse { + throw new IllegalStateException("State location not present for execution") + } + } +} + +/** + * For each input tuple, the key is calculated and the value from the [[StateStore]] is added + * to the stream (in addition to the input tuple) if present. + */ +case class StateStoreRestore( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + iter.flatMap { row => + val key = getKey(row) + val savedState = store.get(key) + row +: savedState.toSeq + } + } + } + override def output: Seq[Attribute] = child.output +} + +/** + * For each input tuple, the key is calculated and the tuple is `put` into the [[StateStore]]. 
+ */ +case class StateStoreSave( + keyExpressions: Seq[Attribute], + stateId: Option[OperatorStateId], + child: SparkPlan) extends execution.UnaryNode with StatefulOperator { + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsWithStateStore( + getStateId.checkpointLocation, + operatorId = getStateId.operatorId, + storeVersion = getStateId.batchId, + keyExpressions.toStructType, + child.output.toStructType, + new StateStoreConf(sqlContext.conf), + Some(sqlContext.streams.stateStoreCoordinator)) { case (store, iter) => + new Iterator[InternalRow] { + private[this] val baseIterator = iter + private[this] val getKey = GenerateUnsafeProjection.generate(keyExpressions, child.output) + + override def hasNext: Boolean = { + if (!baseIterator.hasNext) { + store.commit() + false + } else { + true + } + } + + override def next(): InternalRow = { + val row = baseIterator.next().asInstanceOf[UnsafeRow] + val key = getKey(row) + store.put(key.copy(), row.copy()) + row + } + } + } + } + + override def output: Seq[Attribute] = child.output +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index c4e410d92c..511e30c70c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.internal.Logging import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} import org.apache.spark.sql.catalyst.util._ @@ -272,6 +273,8 @@ class StreamExecution( private def runBatch(): Unit = { val startTime = System.nanoTime() + // TODO: Move this to IncrementalExecution. + // Request unprocessed data from all sources. val newData = availableOffsets.flatMap { case (source, available) if committedOffsets.get(source).map(_ < available).getOrElse(true) => @@ -305,13 +308,14 @@ class StreamExecution( } val optimizerStart = System.nanoTime() - - lastExecution = new QueryExecution(sqlContext, newPlan) - val executedPlan = lastExecution.executedPlan + lastExecution = + new IncrementalExecution(sqlContext, newPlan, checkpointFile("state"), currentBatchId) + lastExecution.executedPlan val optimizerTime = (System.nanoTime() - optimizerStart).toDouble / 1000000 logDebug(s"Optimized batch in ${optimizerTime}ms") - val nextBatch = Dataset.ofRows(sqlContext, newPlan) + val nextBatch = + new Dataset(sqlContext, lastExecution, RowEncoder(lastExecution.analyzed.schema)) sink.addBatch(currentBatchId - 1, nextBatch) awaitBatchLock.synchronized { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 0f91e59e04..7d97f81b0f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -108,7 +108,7 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. 
*/ -class MemorySink(schema: StructType) extends Sink with Logging { +class MemorySink(val schema: StructType) extends Sink with Logging { /** An order list of batches that have been written to this [[Sink]]. */ private val batches = new ArrayBuffer[Array[Row]]() @@ -117,6 +117,8 @@ class MemorySink(schema: StructType) extends Sink with Logging { batches.flatten } + def lastBatch: Seq[Row] = batches.last + def toDebugString: String = synchronized { batches.zipWithIndex.map { case (b, i) => val dataStr = try b.mkString(" ") catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index ee015baf3f..998eb82de1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -81,7 +81,7 @@ private[state] class HDFSBackedStateStoreProvider( trait STATE case object UPDATING extends STATE case object COMMITTED extends STATE - case object CANCELLED extends STATE + case object ABORTED extends STATE private val newVersion = version + 1 private val tempDeltaFile = new Path(baseDir, s"temp-${Random.nextLong}") @@ -94,15 +94,14 @@ private[state] class HDFSBackedStateStoreProvider( override def id: StateStoreId = HDFSBackedStateStoreProvider.this.id - /** - * Update the value of a key using the value generated by the update function. - * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous - * versions of the store data. - */ - override def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit = { - verify(state == UPDATING, "Cannot update after already committed or cancelled") - val oldValueOption = Option(mapToUpdate.get(key)) - val value = updateFunc(oldValueOption) + override def get(key: UnsafeRow): Option[UnsafeRow] = { + Option(mapToUpdate.get(key)) + } + + override def put(key: UnsafeRow, value: UnsafeRow): Unit = { + verify(state == UPDATING, "Cannot remove after already committed or cancelled") + + val isNewKey = !mapToUpdate.containsKey(key) mapToUpdate.put(key, value) Option(allUpdates.get(key)) match { @@ -115,8 +114,7 @@ private[state] class HDFSBackedStateStoreProvider( case None => // There was no prior update, so mark this as added or updated according to its presence // in previous version. - val update = - if (oldValueOption.nonEmpty) ValueUpdated(key, value) else ValueAdded(key, value) + val update = if (isNewKey) ValueAdded(key, value) else ValueUpdated(key, value) allUpdates.put(key, update) } writeToDeltaFile(tempDeltaFileStream, ValueUpdated(key, value)) @@ -148,7 +146,7 @@ private[state] class HDFSBackedStateStoreProvider( /** Commit all the updates that have been made to the store, and return the new version. */ override def commit(): Long = { - verify(state == UPDATING, "Cannot commit again after already committed or cancelled") + verify(state == UPDATING, "Cannot commit after already committed or cancelled") try { finalizeDeltaFile(tempDeltaFileStream) @@ -164,8 +162,8 @@ private[state] class HDFSBackedStateStoreProvider( } /** Cancel all the updates made on this store. This store will not be usable any more. 
*/ - override def cancel(): Unit = { - state = CANCELLED + override def abort(): Unit = { + state = ABORTED if (tempDeltaFileStream != null) { tempDeltaFileStream.close() } @@ -176,8 +174,8 @@ private[state] class HDFSBackedStateStoreProvider( } /** - * Get an iterator of all the store data. This can be called only after committing the - * updates. + * Get an iterator of all the store data. + * This can be called only after committing all the updates made in the current thread. */ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { verify(state == COMMITTED, "Cannot get iterator of store data before comitting") @@ -186,7 +184,7 @@ private[state] class HDFSBackedStateStoreProvider( /** * Get an iterator of all the updates made to the store in the current version. - * This can be called only after committing the updates. + * This can be called only after committing all the updates made in the current thread. */ override def updates(): Iterator[StoreUpdate] = { verify(state == COMMITTED, "Cannot get iterator of updates before committing") @@ -196,7 +194,7 @@ private[state] class HDFSBackedStateStoreProvider( /** * Whether all updates have been committed */ - override def hasCommitted: Boolean = { + override private[state] def hasCommitted: Boolean = { state == COMMITTED } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index ca5c864d9e..d60e6185ac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -47,12 +47,11 @@ trait StateStore { /** Version of the data in this store before committing updates. */ def version: Long - /** - * Update the value of a key using the value generated by the update function. - * @note Do not mutate the retrieved value row as it will unexpectedly affect the previous - * versions of the store data. - */ - def update(key: UnsafeRow, updateFunc: Option[UnsafeRow] => UnsafeRow): Unit + /** Get the current value of a key. */ + def get(key: UnsafeRow): Option[UnsafeRow] + + /** Put a new value for a key. */ + def put(key: UnsafeRow, value: UnsafeRow) /** * Remove keys that match the following condition. @@ -65,24 +64,24 @@ trait StateStore { def commit(): Long /** Cancel all the updates that have been made to the store. */ - def cancel(): Unit + def abort(): Unit /** * Iterator of store data after a set of updates have been committed. - * This can be called only after commitUpdates() has been called in the current thread. + * This can be called only after committing all the updates made in the current thread. */ def iterator(): Iterator[(UnsafeRow, UnsafeRow)] /** * Iterator of the updates that have been committed. - * This can be called only after commitUpdates() has been called in the current thread. + * This can be called only after committing all the updates made in the current thread. 
*/ def updates(): Iterator[StoreUpdate] /** * Whether all updates have been committed */ - def hasCommitted: Boolean + private[state] def hasCommitted: Boolean } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala index cca22a0af8..f0f1f3a1a8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreConf.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.streaming.state import org.apache.spark.sql.internal.SQLConf /** A class that contains configuration parameters for [[StateStore]]s. */ -private[state] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { +private[streaming] class StateStoreConf(@transient private val conf: SQLConf) extends Serializable { def this() = this(new SQLConf) @@ -31,7 +31,7 @@ private[state] class StateStoreConf(@transient private val conf: SQLConf) extend val minVersionsToRetain = conf.getConf(STATE_STORE_MIN_VERSIONS_TO_RETAIN) } -private[state] object StateStoreConf { +private[streaming] object StateStoreConf { val empty = new StateStoreConf() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala index 3318660895..df3d82c113 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDD.scala @@ -54,17 +54,10 @@ class StateStoreRDD[T: ClassTag, U: ClassTag]( override def compute(partition: Partition, ctxt: TaskContext): Iterator[U] = { var store: StateStore = null - - Utils.tryWithSafeFinally { - val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) - store = StateStore.get( - storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) - val inputIter = dataRDD.iterator(partition, ctxt) - val outputIter = storeUpdateFunction(store, inputIter) - assert(store.hasCommitted) - outputIter - } { - if (store != null) store.cancel() - } + val storeId = StateStoreId(checkpointLocation, operatorId, partition.index) + store = StateStore.get( + storeId, keySchema, valueSchema, storeVersion, storeConf, confBroadcast.value.value) + val inputIter = dataRDD.iterator(partition, ctxt) + storeUpdateFunction(store, inputIter) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala index b249e37921..9b6d0918e2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/package.scala @@ -28,37 +28,36 @@ package object state { implicit class StateStoreOps[T: ClassTag](dataRDD: RDD[T]) { /** Map each partition of a RDD along with data in a [[StateStore]]. 
*/ - def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + def mapPartitionsWithStateStore[U: ClassTag]( + sqlContext: SQLContext, checkpointLocation: String, operatorId: Long, storeVersion: Long, keySchema: StructType, - valueSchema: StructType - )(implicit sqlContext: SQLContext): StateStoreRDD[T, U] = { + valueSchema: StructType)( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { - mapPartitionWithStateStore( - storeUpdateFunction, + mapPartitionsWithStateStore( checkpointLocation, operatorId, storeVersion, keySchema, valueSchema, new StateStoreConf(sqlContext.conf), - Some(sqlContext.streams.stateStoreCoordinator)) + Some(sqlContext.streams.stateStoreCoordinator))( + storeUpdateFunction) } /** Map each partition of a RDD along with data in a [[StateStore]]. */ - private[state] def mapPartitionWithStateStore[U: ClassTag]( - storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U], + private[streaming] def mapPartitionsWithStateStore[U: ClassTag]( checkpointLocation: String, operatorId: Long, storeVersion: Long, keySchema: StructType, valueSchema: StructType, storeConf: StateStoreConf, - storeCoordinator: Option[StateStoreCoordinatorRef] - ): StateStoreRDD[T, U] = { + storeCoordinator: Option[StateStoreCoordinatorRef])( + storeUpdateFunction: (StateStore, Iterator[T]) => Iterator[U]): StateStoreRDD[T, U] = { val cleanedF = dataRDD.sparkContext.clean(storeUpdateFunction) new StateStoreRDD( dataRDD, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 0d580703f5..4b3091ba22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -17,12 +17,12 @@ package org.apache.spark.sql.execution +import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.{expressions, InternalRow} import org.apache.spark.sql.catalyst.expressions.{ExprId, Literal, SubqueryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, ReturnAnswer} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.internal.SessionState import org.apache.spark.sql.types.DataType /** @@ -60,14 +60,13 @@ case class ScalarSubquery( } /** - * Convert the subquery from logical plan into executed plan. + * Plans scalar subqueries from that are present in the given [[SparkPlan]]. 
*/ -case class PlanSubqueries(sessionState: SessionState) extends Rule[SparkPlan] { +case class PlanSubqueries(sqlContext: SQLContext) extends Rule[SparkPlan] { def apply(plan: SparkPlan): SparkPlan = { plan.transformAllExpressions { case subquery: expressions.ScalarSubquery => - val sparkPlan = sessionState.planner.plan(ReturnAnswer(subquery.query)).next() - val executedPlan = sessionState.prepareForExecution.execute(sparkPlan) + val executedPlan = new QueryExecution(sqlContext, subquery.plan).executedPlan ScalarSubquery(executedPlan, subquery.exprId) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala index 844f3051fa..9cb356f1ca 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/Aggregator.scala @@ -84,10 +84,10 @@ abstract class Aggregator[-I, B, O] extends Serializable { implicit bEncoder: Encoder[B], cEncoder: Encoder[O]): TypedColumn[I, O] = { val expr = - new AggregateExpression( + AggregateExpression( TypedAggregateExpression(this), Complete, - false) + isDistinct = false) new TypedColumn[I, O](expr, encoderFor[O]) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index f7fdfacd31..cd3d254d1e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -86,20 +86,8 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Planner that converts optimized logical plans to physical plans. */ - lazy val planner: SparkPlanner = new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) - - /** - * Prepares a planned [[SparkPlan]] for execution by inserting shuffle operations and internal - * row format conversions as needed. - */ - lazy val prepareForExecution = new RuleExecutor[SparkPlan] { - override val batches: Seq[Batch] = Seq( - Batch("Subquery", Once, PlanSubqueries(SessionState.this)), - Batch("Add exchange", Once, EnsureRequirements(conf)), - Batch("Whole stage codegen", Once, CollapseCodegenStages(conf)), - Batch("Reuse duplicated exchanges", Once, ReuseExchange(conf)) - ) - } + def planner: SparkPlanner = + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala index b5be7ef47e..550c3c6f9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StreamTest.scala @@ -116,15 +116,30 @@ trait StreamTest extends QueryTest with Timeouts { def apply[A : Encoder](data: A*): CheckAnswerRows = { val encoder = encoderFor[A] val toExternalRow = RowEncoder(encoder.schema) - CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d)))) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), false) } - def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows) + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false) } - case class CheckAnswerRows(expectedAnswer: Seq[Row]) + /** + * Checks to make sure that the current data stored in the sink matches the `expectedAnswer`. 
+ * This operation automatically blocks until all added data has been processed. + */ + object CheckLastBatch { + def apply[A : Encoder](data: A*): CheckAnswerRows = { + val encoder = encoderFor[A] + val toExternalRow = RowEncoder(encoder.schema) + CheckAnswerRows(data.map(d => toExternalRow.fromRow(encoder.toRow(d))), true) + } + + def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true) + } + + case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean) extends StreamAction with StreamMustBeRunning { - override def toString: String = s"CheckAnswer: ${expectedAnswer.mkString(",")}" + override def toString: String = s"$operatorName: ${expectedAnswer.mkString(",")}" + private def operatorName = if (lastOnly) "CheckLastBatch" else "CheckAnswer" } /** Stops the stream. It must currently be running. */ @@ -224,11 +239,8 @@ trait StreamTest extends QueryTest with Timeouts { """.stripMargin def verify(condition: => Boolean, message: String): Unit = { - try { - Assertions.assert(condition) - } catch { - case NonFatal(e) => - failTest(message, e) + if (!condition) { + failTest(message) } } @@ -351,7 +363,7 @@ trait StreamTest extends QueryTest with Timeouts { case a: AddData => awaiting.put(a.source, a.addData()) - case CheckAnswerRows(expectedAnswer) => + case CheckAnswerRows(expectedAnswer, lastOnly) => verify(currentStream != null, "stream not running") // Block until all data added has been processed @@ -361,12 +373,12 @@ trait StreamTest extends QueryTest with Timeouts { } } - val allData = try sink.allData catch { + val sparkAnswer = try if (lastOnly) sink.lastBatch else sink.allData catch { case e: Exception => failTest("Exception while getting data from sink", e) } - QueryTest.sameRows(expectedAnswer, allData).foreach { + QueryTest.sameRows(expectedAnswer, sparkAnswer).foreach { error => failTest(error) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala index ed0d3f56e5..38318740a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala @@ -231,10 +231,8 @@ object SparkPlanTest { } private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = { - // A very simple resolver to make writing tests easier. In contrast to the real resolver - // this is always case sensitive and does not try to handle scoping or complex type resolution. 
- val resolvedPlan = sqlContext.sessionState.prepareForExecution.execute( - outputPlan transform { + val execution = new QueryExecution(sqlContext, null) { + override lazy val sparkPlan: SparkPlan = outputPlan transform { case plan: SparkPlan => val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap plan transformExpressions { @@ -243,8 +241,8 @@ object SparkPlanTest { sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap")) } } - ) - resolvedPlan.executeCollectPublic().toSeq + } + execution.executedPlan.executeCollectPublic().toSeq } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala index 85db05157c..6be94eb24f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreRDDSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.scheduler.ExecutorCacheTaskLocation import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.util.quietly import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} -import org.apache.spark.util.Utils +import org.apache.spark.util.{CompletionIterator, Utils} class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAndAfterAll { @@ -54,62 +54,93 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn } test("versioning and immutability") { - quietly { - withSpark(new SparkContext(sparkConf)) { sc => - implicit val sqlContet = new SQLContext(sc) - val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val increment = (store: StateStore, iter: Iterator[String]) => { - iter.foreach { s => - store.update( - stringToRow(s), oldRow => { - val oldValue = oldRow.map(rowToInt).getOrElse(0) - intToRow(oldValue + 1) - }) - } - store.commit() - store.iterator().map(rowsToStringInt) - } - val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) - assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + withSpark(new SparkContext(sparkConf)) { sc => + val sqlContext = new SQLContext(sc) + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val opId = 0 + val rdd1 = + makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)( + increment) + assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + + // Generate next version of stores + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) + assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + + // Make sure the previous RDD still has the same data. 
+ assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + } + } - // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 1, keySchema, valueSchema) - assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) + test("recovering from files") { + val opId = 0 + val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + + def makeStoreRDD( + sc: SparkContext, + seq: Seq[String], + storeVersion: Int): RDD[(String, Int)] = { + implicit val sqlContext = new SQLContext(sc) + makeRDD(sc, Seq("a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion, keySchema, valueSchema)(increment) + } - // Make sure the previous RDD still has the same data. - assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) + // Generate RDDs and state store data + withSpark(new SparkContext(sparkConf)) { sc => + for (i <- 1 to 20) { + require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) } } + + // With a new context, try using the earlier state store data + withSpark(new SparkContext(sparkConf)) { sc => + assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) + } } - test("recovering from files") { - quietly { - val opId = 0 + test("usage with iterators - only gets and only puts") { + withSpark(new SparkContext(sparkConf)) { sc => + implicit val sqlContext = new SQLContext(sc) val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString + val opId = 0 - def makeStoreRDD( - sc: SparkContext, - seq: Seq[String], - storeVersion: Int): RDD[(String, Int)] = { - implicit val sqlContext = new SQLContext(sc) - makeRDD(sc, Seq("a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion, keySchema, valueSchema) + // Returns an iterator of the incremented value made into the store + def iteratorOfPuts(store: StateStore, iter: Iterator[String]): Iterator[(String, Int)] = { + val resIterator = iter.map { s => + val key = stringToRow(s) + val oldValue = store.get(key).map(rowToInt).getOrElse(0) + val newValue = oldValue + 1 + store.put(key, intToRow(newValue)) + (s, newValue) + } + CompletionIterator[(String, Int), Iterator[(String, Int)]](resIterator, { + store.commit() + }) } - // Generate RDDs and state store data - withSpark(new SparkContext(sparkConf)) { sc => - for (i <- 1 to 20) { - require(makeStoreRDD(sc, Seq("a"), i - 1).collect().toSet === Set("a" -> i)) + def iteratorOfGets( + store: StateStore, + iter: Iterator[String]): Iterator[(String, Option[Int])] = { + iter.map { s => + val key = stringToRow(s) + val value = store.get(key).map(rowToInt) + (s, value) } } - // With a new context, try using the earlier state store data - withSpark(new SparkContext(sparkConf)) { sc => - assert(makeStoreRDD(sc, Seq("a"), 20).collect().toSet === Set("a" -> 21)) - } + val rddOfGets1 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfGets) + assert(rddOfGets1.collect().toSet === Set("a" -> None, "b" -> None, "c" -> None)) + + val rddOfPuts = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(iteratorOfPuts) + assert(rddOfPuts.collect().toSet === Set("a" -> 1, "a" -> 2, "b" -> 1)) + + val rddOfGets2 = makeRDD(sc, Seq("a", "b", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(iteratorOfGets) + assert(rddOfGets2.collect().toSet 
=== Set("a" -> Some(2), "b" -> Some(1), "c" -> None)) } } @@ -128,8 +159,8 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn coordinatorRef.getLocation(StateStoreId(path, opId, 0)) === Some(ExecutorCacheTaskLocation("host1", "exec1").toString)) - val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) + val rdd = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) require(rdd.partitions.length === 2) assert( @@ -148,27 +179,16 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn test("distributed test") { quietly { withSpark(new SparkContext(sparkConf.setMaster("local-cluster[2, 1, 1024]"))) { sc => - implicit val sqlContet = new SQLContext(sc) + implicit val sqlContext = new SQLContext(sc) val path = Utils.createDirectory(tempDir, Random.nextString(10)).toString - val increment = (store: StateStore, iter: Iterator[String]) => { - iter.foreach { s => - store.update( - stringToRow(s), oldRow => { - val oldValue = oldRow.map(rowToInt).getOrElse(0) - intToRow(oldValue + 1) - }) - } - store.commit() - store.iterator().map(rowsToStringInt) - } val opId = 0 - val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 0, keySchema, valueSchema) + val rdd1 = makeRDD(sc, Seq("a", "b", "a")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 0, keySchema, valueSchema)(increment) assert(rdd1.collect().toSet === Set("a" -> 2, "b" -> 1)) // Generate next version of stores - val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionWithStateStore( - increment, path, opId, storeVersion = 1, keySchema, valueSchema) + val rdd2 = makeRDD(sc, Seq("a", "c")).mapPartitionsWithStateStore( + sqlContext, path, opId, storeVersion = 1, keySchema, valueSchema)(increment) assert(rdd2.collect().toSet === Set("a" -> 3, "b" -> 1, "c" -> 1)) // Make sure the previous RDD still has the same data. 
@@ -183,11 +203,9 @@ class StateStoreRDDSuite extends SparkFunSuite with BeforeAndAfter with BeforeAn private val increment = (store: StateStore, iter: Iterator[String]) => { iter.foreach { s => - store.update( - stringToRow(s), oldRow => { - val oldValue = oldRow.map(rowToInt).getOrElse(0) - intToRow(oldValue + 1) - }) + val key = stringToRow(s) + val oldValue = store.get(key).map(rowToInt).getOrElse(0) + store.put(key, intToRow(oldValue + 1)) } store.commit() store.iterator().map(rowsToStringInt) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 22b2f4f75d..0e5936d53f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -51,7 +51,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth StateStore.stop() } - test("update, remove, commit, and all data iterator") { + test("get, put, remove, commit, and all data iterator") { val provider = newStoreProvider() // Verify state before starting a new set of updates @@ -67,7 +67,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth } // Verify state after updating - update(store, "a", 1) + put(store, "a", 1) intercept[IllegalStateException] { store.iterator() } @@ -77,8 +77,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth assert(provider.latestIterator().isEmpty) // Make updates, commit and then verify state - update(store, "b", 2) - update(store, "aa", 3) + put(store, "b", 2) + put(store, "aa", 3) remove(store, _.startsWith("a")) assert(store.commit() === 1) @@ -101,7 +101,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val reloadedProvider = new HDFSBackedStateStoreProvider( store.id, keySchema, valueSchema, StateStoreConf.empty, new Configuration) val reloadedStore = reloadedProvider.getStore(1) - update(reloadedStore, "c", 4) + put(reloadedStore, "c", 4) assert(reloadedStore.commit() === 2) assert(rowsToSet(reloadedStore.iterator()) === Set("b" -> 2, "c" -> 4)) assert(getDataFromFiles(provider) === Set("b" -> 2, "c" -> 4)) @@ -112,6 +112,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("updates iterator with all combos of updates and removes") { val provider = newStoreProvider() var currentVersion: Int = 0 + def withStore(body: StateStore => Unit): Unit = { val store = provider.getStore(currentVersion) body(store) @@ -120,9 +121,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // New data should be seen in updates as value added, even if they had multiple updates withStore { store => - update(store, "a", 1) - update(store, "aa", 1) - update(store, "aa", 2) + put(store, "a", 1) + put(store, "aa", 1) + put(store, "aa", 2) store.commit() assert(updatesToSet(store.updates()) === Set(Added("a", 1), Added("aa", 2))) assert(rowsToSet(store.iterator()) === Set("a" -> 1, "aa" -> 2)) @@ -131,8 +132,8 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Multiple updates to same key should be collapsed in the updates as a single value update // Keys that have not been updated should not appear in the updates withStore { store => - update(store, "a", 4) - update(store, "a", 6) + put(store, "a", 4) + 
put(store, "a", 6) store.commit() assert(updatesToSet(store.updates()) === Set(Updated("a", 6))) assert(rowsToSet(store.iterator()) === Set("a" -> 6, "aa" -> 2)) @@ -140,9 +141,9 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Keys added, updated and finally removed before commit should not appear in updates withStore { store => - update(store, "b", 4) // Added, finally removed - update(store, "bb", 5) // Added, updated, finally removed - update(store, "bb", 6) + put(store, "b", 4) // Added, finally removed + put(store, "bb", 5) // Added, updated, finally removed + put(store, "bb", 6) remove(store, _.startsWith("b")) store.commit() assert(updatesToSet(store.updates()) === Set.empty) @@ -153,7 +154,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Removed, but re-added data should be seen in updates as a value update withStore { store => remove(store, _.startsWith("a")) - update(store, "a", 10) + put(store, "a", 10) store.commit() assert(updatesToSet(store.updates()) === Set(Updated("a", 10), Removed("aa"))) assert(rowsToSet(store.iterator()) === Set("a" -> 10)) @@ -163,14 +164,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth test("cancel") { val provider = newStoreProvider() val store = provider.getStore(0) - update(store, "a", 1) + put(store, "a", 1) store.commit() assert(rowsToSet(store.iterator()) === Set("a" -> 1)) // cancelUpdates should not change the data in the files val store1 = provider.getStore(1) - update(store1, "b", 1) - store1.cancel() + put(store1, "b", 1) + store1.abort() assert(getDataFromFiles(provider) === Set("a" -> 1)) } @@ -183,7 +184,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Prepare some data in the stoer val store = provider.getStore(0) - update(store, "a", 1) + put(store, "a", 1) assert(store.commit() === 1) assert(rowsToSet(store.iterator()) === Set("a" -> 1)) @@ -193,14 +194,14 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Update store version with some data val store1 = provider.getStore(1) - update(store1, "b", 1) + put(store1, "b", 1) assert(store1.commit() === 2) assert(rowsToSet(store1.iterator()) === Set("a" -> 1, "b" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "b" -> 1)) // Overwrite the version with other data val store2 = provider.getStore(1) - update(store2, "c", 1) + put(store2, "c", 1) assert(store2.commit() === 2) assert(rowsToSet(store2.iterator()) === Set("a" -> 1, "c" -> 1)) assert(getDataFromFiles(provider) === Set("a" -> 1, "c" -> 1)) @@ -213,7 +214,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth def updateVersionTo(targetVersion: Int): Unit = { for (i <- currentVersion + 1 to targetVersion) { val store = provider.getStore(currentVersion) - update(store, "a", i) + put(store, "a", i) store.commit() currentVersion += 1 } @@ -264,7 +265,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth for (i <- 1 to 20) { val store = provider.getStore(i - 1) - update(store, "a", i) + put(store, "a", i) store.commit() provider.doMaintenance() // do cleanup } @@ -284,7 +285,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val provider = newStoreProvider(minDeltasForSnapshot = 5) for (i <- 1 to 6) { val store = provider.getStore(i - 1) - update(store, "a", i) + put(store, "a", i) store.commit() provider.doMaintenance() // do cleanup } @@ 
-333,7 +334,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth // Increase version of the store val store0 = StateStore.get(storeId, keySchema, valueSchema, 0, storeConf, hadoopConf) assert(store0.version === 0) - update(store0, "a", 1) + put(store0, "a", 1) store0.commit() assert(StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf).version == 1) @@ -345,7 +346,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth val store1 = StateStore.get(storeId, keySchema, valueSchema, 1, storeConf, hadoopConf) assert(StateStore.isLoaded(storeId)) - update(store1, "a", 2) + put(store1, "a", 2) assert(store1.commit() === 2) assert(rowsToSet(store1.iterator()) === Set("a" -> 2)) } @@ -371,7 +372,7 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth for (i <- 1 to 20) { val store = StateStore.get( storeId, keySchema, valueSchema, i - 1, storeConf, hadoopConf) - update(store, "a", i) + put(store, "a", i) store.commit() } eventually(timeout(10 seconds)) { @@ -507,8 +508,12 @@ class StateStoreSuite extends SparkFunSuite with BeforeAndAfter with PrivateMeth store.remove(row => condition(rowToString(row))) } - private def update(store: StateStore, key: String, value: Int): Unit = { - store.update(stringToRow(key), _ => intToRow(value)) + private def put(store: StateStore, key: String, value: Int): Unit = { + store.put(stringToRow(key), intToRow(value)) + } + + private def get(store: StateStore, key: String): Option[Int] = { + store.get(stringToRow(key)).map(rowToInt) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala new file mode 100644 index 0000000000..b63ce89d18 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.streaming + +import org.apache.spark.SparkException +import org.apache.spark.sql.{Encoder, StreamTest, SumOf, TypedColumn} +import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SharedSQLContext + +object FailureSinglton { + var firstTime = true +} + +class StreamingAggregationSuite extends StreamTest with SharedSQLContext { + + import testImplicits._ + + test("simple count") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated)( + AddData(inputData, 3), + CheckLastBatch((3, 1)), + AddData(inputData, 3, 2), + CheckLastBatch((3, 2), (2, 1)), + StopStream, + StartStream, + AddData(inputData, 3, 2, 1), + CheckLastBatch((3, 3), (2, 2), (1, 1)), + // By default we run in new tuple mode. + AddData(inputData, 4, 4, 4, 4), + CheckLastBatch((4, 4)) + ) + } + + test("multiple keys") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value", $"value" + 1) + .agg(count("*")) + .as[(Int, Int, Long)] + + testStream(aggregated)( + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 1), (2, 3, 1)), + AddData(inputData, 1, 2), + CheckLastBatch((1, 2, 2), (2, 3, 2)) + ) + } + + test("multiple aggregations") { + val inputData = MemoryStream[Int] + + val aggregated = + inputData.toDF() + .groupBy($"value") + .agg(count("*") as 'count) + .groupBy($"value" % 2) + .agg(sum($"count")) + .as[(Int, Long)] + + testStream(aggregated)( + AddData(inputData, 1, 2, 3, 4), + CheckLastBatch((0, 2), (1, 2)), + AddData(inputData, 1, 3, 5), + CheckLastBatch((1, 5)) + ) + } + + testQuietly("midbatch failure") { + val inputData = MemoryStream[Int] + FailureSinglton.firstTime = true + val aggregated = + inputData.toDS() + .map { i => + if (i == 4 && FailureSinglton.firstTime) { + FailureSinglton.firstTime = false + sys.error("injected failure") + } + + i + } + .groupBy($"value") + .agg(count("*")) + .as[(Int, Long)] + + testStream(aggregated)( + StartStream, + AddData(inputData, 1, 2, 3, 4), + ExpectFailure[SparkException](), + StartStream, + CheckLastBatch((1, 1), (2, 1), (3, 1), (4, 1)) + ) + } + + test("typed aggregators") { + def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = + new SumOf(f).toColumn + + val inputData = MemoryStream[(String, Int)] + val aggregated = inputData.toDS().groupByKey(_._1).agg(sum(_._2)) + + testStream(aggregated)( + AddData(inputData, ("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)), + CheckLastBatch(("a", 30), ("b", 3), ("c", 1)) + ) + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 2bdb428e9d..ff40c366c8 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -77,8 +77,9 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) /** * Planner that takes into account Hive-specific strategies. 
*/ - override lazy val planner: SparkPlanner = { - new SparkPlanner(ctx.sparkContext, conf, experimentalMethods) with HiveStrategies { + override def planner: SparkPlanner = { + new SparkPlanner(ctx.sparkContext, conf, experimentalMethods.extraStrategies) + with HiveStrategies { override val hiveContext = ctx override def strategies: Seq[Strategy] = { -- cgit v1.2.3 From 27e71a2cd930ae28c82c9c3ee6476a12ea165fdf Mon Sep 17 00:00:00 2001 From: Cheng Lian Date: Fri, 1 Apr 2016 22:00:24 -0700 Subject: [SPARK-14244][SQL] Don't use SizeBasedWindowFunction.n created on executor side when evaluating window functions ## What changes were proposed in this pull request? `SizeBasedWindowFunction.n` is a global singleton attribute created for evaluating size based aggregate window functions like `CUME_DIST`. However, this attribute gets different expression IDs when created on both driver side and executor side. This PR adds `withPartitionSize` method to `SizeBasedWindowFunction` so that we can easily rewrite `SizeBasedWindowFunction.n` on executor side. ## How was this patch tested? A test case is added in `HiveSparkSubmitSuite`, which supports launching multi-process clusters. Author: Cheng Lian Closes #12040 from liancheng/spark-14244-fix-sized-window-function. --- .../catalyst/expressions/windowExpressions.scala | 6 ++- .../org/apache/spark/sql/execution/Window.scala | 22 ++++++++--- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 4 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 44 +++++++++++++++++++++- 4 files changed, 67 insertions(+), 9 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala index b8679474cf..c0b453dccf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/windowExpressions.scala @@ -451,7 +451,11 @@ abstract class RowNumberLike extends AggregateWindowFunction { * A [[SizeBasedWindowFunction]] needs the size of the current window for its calculation. */ trait SizeBasedWindowFunction extends AggregateWindowFunction { - protected def n: AttributeReference = SizeBasedWindowFunction.n + // It's made a val so that the attribute created on driver side is serialized to executor side. + // Otherwise, if it's defined as a function, when it's called on executor side, it actually + // returns the singleton value instantiated on executor side, which has different expression ID + // from the one created on driver side. + val n: AttributeReference = SizeBasedWindowFunction.n } object SizeBasedWindowFunction { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 7acf020b28..7d0567842c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -874,7 +874,8 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( * processor class. 
*/ private[execution] object AggregateProcessor { - def apply(functions: Array[Expression], + def apply( + functions: Array[Expression], ordinal: Int, inputAttributes: Seq[Attribute], newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): @@ -885,11 +886,20 @@ private[execution] object AggregateProcessor { val evaluateExpressions = mutable.Buffer.fill[Expression](ordinal)(NoOp) val imperatives = mutable.Buffer.empty[ImperativeAggregate] + // SPARK-14244: `SizeBasedWindowFunction`s are firstly created on driver side and then + // serialized to executor side. These functions all reference a global singleton window + // partition size attribute reference, i.e., `SizeBasedWindowFunction.n`. Here we must collect + // the singleton instance created on driver side instead of using executor side + // `SizeBasedWindowFunction.n` to avoid binding failure caused by mismatching expression ID. + val partitionSize: Option[AttributeReference] = { + val aggs = functions.flatMap(_.collectFirst { case f: SizeBasedWindowFunction => f }) + aggs.headOption.map(_.n) + } + // Check if there are any SizeBasedWindowFunctions. If there are, we add the partition size to // the aggregation buffer. Note that the ordinal of the partition size value will always be 0. - val trackPartitionSize = functions.exists(_.isInstanceOf[SizeBasedWindowFunction]) - if (trackPartitionSize) { - aggBufferAttributes += SizeBasedWindowFunction.n + partitionSize.foreach { n => + aggBufferAttributes += n initialValues += NoOp updateExpressions += NoOp } @@ -920,7 +930,7 @@ private[execution] object AggregateProcessor { // Create the projections. val initialProjection = newMutableProjection( initialValues, - Seq(SizeBasedWindowFunction.n))() + partitionSize.toSeq)() val updateProjection = newMutableProjection( updateExpressions, aggBufferAttributes ++ inputAttributes)() @@ -935,7 +945,7 @@ private[execution] object AggregateProcessor { updateProjection, evaluateProjection, imperatives.toArray, - trackPartitionSize) + partitionSize.isDefined) } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index c07c428895..5ada3d5598 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -107,7 +107,9 @@ private[hive] class HiveFunctionRegistry( // If there is any other error, we throw an AnalysisException. val errorMessage = s"No handler for Hive udf ${functionInfo.getFunctionClass} " + s"because: ${throwable.getMessage}." 
- throw new AnalysisException(errorMessage) + val analysisException = new AnalysisException(errorMessage) + analysisException.setStackTrace(throwable.getStackTrace) + throw analysisException } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 16747cab37..53dec6348f 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -31,7 +31,7 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.internal.Logging -import org.apache.spark.sql.{QueryTest, SQLContext} +import org.apache.spark.sql.{QueryTest, Row, SQLContext} import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer @@ -135,6 +135,19 @@ class HiveSparkSubmitSuite runSparkSubmit(args) } + test("SPARK-14244 fix window partition size attribute binding failure") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val args = Seq( + "--class", SPARK_14244.getClass.getName.stripSuffix("$"), + "--name", "SparkSQLConfTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + unusedJar.toString) + runSparkSubmit(args) + } + // NOTE: This is an expensive operation in terms of time (10 seconds+). Use sparingly. // This is copied from org.apache.spark.deploy.SparkSubmitSuite private def runSparkSubmit(args: Seq[String]): Unit = { @@ -378,3 +391,32 @@ object SPARK_11009 extends QueryTest { } } } + +object SPARK_14244 extends QueryTest { + import org.apache.spark.sql.expressions.Window + import org.apache.spark.sql.functions._ + + protected var sqlContext: SQLContext = _ + + def main(args: Array[String]): Unit = { + Utils.configTestLog4j("INFO") + + val sparkContext = new SparkContext( + new SparkConf() + .set("spark.ui.enabled", "false") + .set("spark.sql.shuffle.partitions", "100")) + + val hiveContext = new TestHiveContext(sparkContext) + sqlContext = hiveContext + + import hiveContext.implicits._ + + try { + val window = Window.orderBy('id) + val df = sqlContext.range(2).select(cume_dist().over(window).as('cdist)).orderBy('cdist) + checkAnswer(df, Seq(Row(0.5D), Row(1.0D))) + } finally { + sparkContext.stop() + } + } +} -- cgit v1.2.3 From fa1af0aff7bde9bbf7bfa6a3ac74699734c2fd8a Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Fri, 1 Apr 2016 22:45:52 -0700 Subject: [SPARK-14251][SQL] Add SQL command for printing out generated code for debugging ## What changes were proposed in this pull request? This PR implements `EXPLAIN CODEGEN` SQL command which returns generated codes like `debugCodegen`. In `spark-shell`, we don't need to `import debug` module. In `spark-sql`, we can use this SQL command now. **Before** ``` scala> import org.apache.spark.sql.execution.debug._ scala> sql("select 'a' as a group by 1").debugCodegen() Found 2 WholeStageCodegen subtrees. == Subtree 1 / 2 == ... Generated code: ... == Subtree 2 / 2 == ... Generated code: ... ``` **After** ``` scala> sql("explain extended codegen select 'a' as a group by 1").collect().foreach(println) [Found 2 WholeStageCodegen subtrees.] [== Subtree 1 / 2 ==] ... [] [Generated code:] ... [] [== Subtree 2 / 2 ==] ... [] [Generated code:] ... 
``` ## How was this patch tested? Pass the Jenkins tests (including new testcases) Author: Dongjoon Hyun Closes #12099 from dongjoon-hyun/SPARK-14251. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 5 ++- .../spark/sql/execution/SparkSqlParser.scala | 3 +- .../spark/sql/execution/command/commands.scala | 15 ++++++-- .../apache/spark/sql/execution/debug/package.scala | 43 +++++++++++----------- .../spark/sql/execution/debug/DebuggingSuite.scala | 2 +- .../apache/spark/sql/hive/execution/commands.scala | 1 - .../sql/hive/execution/HiveExplainSuite.scala | 29 +++++++++++++++ 7 files changed, 67 insertions(+), 31 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d1747b9915..f34bb061e4 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -584,7 +584,7 @@ frameBound explainOption - : LOGICAL | FORMATTED | EXTENDED + : LOGICAL | FORMATTED | EXTENDED | CODEGEN ; transactionMode @@ -633,7 +633,7 @@ nonReserved | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | TEMPORARY | OPTIONS | GROUPING | CUBE | ROLLUP - | EXPLAIN | FORMAT | LOGICAL | FORMATTED + | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF | SET | VIEW | REPLACE @@ -724,6 +724,7 @@ DESCRIBE: 'DESCRIBE'; EXPLAIN: 'EXPLAIN'; FORMAT: 'FORMAT'; LOGICAL: 'LOGICAL'; +CODEGEN: 'CODEGEN'; CAST: 'CAST'; SHOW: 'SHOW'; TABLES: 'TABLES'; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 7efe98dd18..ff3ab7746c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -136,7 +136,8 @@ class SparkSqlAstBuilder extends AstBuilder { // Create the explain comment. val statement = plan(ctx.statement) if (isExplainableStatement(statement)) { - ExplainCommand(statement, extended = options.exists(_.EXTENDED != null)) + ExplainCommand(statement, extended = options.exists(_.EXTENDED != null), + codegen = options.exists(_.CODEGEN != null)) } else { ExplainCommand(OneRowRelation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index f90d8717ca..4bc62cdc4a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -28,10 +28,10 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.debug._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - /** * A logical command that is executed for its side-effects. `RunnableCommand`s are * wrapped in `ExecutedCommand` during execution. 
@@ -237,15 +237,22 @@ case class ExplainCommand( logicalPlan: LogicalPlan, override val output: Seq[Attribute] = Seq(AttributeReference("plan", StringType, nullable = true)()), - extended: Boolean = false) + extended: Boolean = false, + codegen: Boolean = false) extends RunnableCommand { // Run through the optimizer to generate the physical plan. override def run(sqlContext: SQLContext): Seq[Row] = try { // TODO in Hive, the "extended" ExplainCommand prints the AST as well, and detailed properties. val queryExecution = sqlContext.executePlan(logicalPlan) - val outputString = if (extended) queryExecution.toString else queryExecution.simpleString - + val outputString = + if (codegen) { + codegenString(queryExecution.executedPlan) + } else if (extended) { + queryExecution.toString + } else { + queryExecution.simpleString + } outputString.split("\n").map(Row(_)) } catch { case cause: TreeNodeException[_] => ("Error occurred during query planning: \n" + cause.getMessage).split("\n").map(Row(_)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala index 3a174ed94c..7b0c8ebdfa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala @@ -48,6 +48,25 @@ package object debug { // scalastyle:on println } + def codegenString(plan: SparkPlan): String = { + val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]() + plan transform { + case s: WholeStageCodegen => + codegenSubtrees += s + s + case s => s + } + var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" + for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { + output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" + output += s + output += "\nGenerated code:\n" + val (_, source) = s.doCodeGen() + output += s"${CodeFormatter.format(source)}\n" + } + output + } + /** * Augments [[SQLContext]] with debug methods. */ @@ -81,28 +100,7 @@ package object debug { * WholeStageCodegen subtree). */ def debugCodegen(): Unit = { - debugPrint(debugCodegenString()) - } - - /** Visible for testing. */ - def debugCodegenString(): String = { - val plan = query.queryExecution.executedPlan - val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegen]() - plan transform { - case s: WholeStageCodegen => - codegenSubtrees += s - s - case s => s - } - var output = s"Found ${codegenSubtrees.size} WholeStageCodegen subtrees.\n" - for ((s, i) <- codegenSubtrees.toSeq.zipWithIndex) { - output += s"== Subtree ${i + 1} / ${codegenSubtrees.size} ==\n" - output += s - output += "\nGenerated code:\n" - val (_, source) = s.doCodeGen() - output += s"${CodeFormatter.format(source)}\n" - } - output + debugPrint(codegenString(query.queryExecution.executedPlan)) } } @@ -123,6 +121,7 @@ package object debug { /** * A collection of metrics for each column of output. + * * @param elementTypes the actual runtime types for the output. Useful when there are bugs * causing the wrong data to be projected. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala index 979265e274..c0fce4b96a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala @@ -27,7 +27,7 @@ class DebuggingSuite extends SparkFunSuite with SharedSQLContext { } test("debugCodegen") { - val res = sqlContext.range(10).groupBy("id").count().debugCodegenString() + val res = codegenString(sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan) assert(res.contains("Subtree 1 / 2")) assert(res.contains("Subtree 2 / 2")) assert(res.contains("Object[]")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index cd26a68f35..64d1341a47 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -24,7 +24,6 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.command.RunnableCommand import org.apache.spark.sql.execution.datasources.{BucketSpec, DataSource, LogicalRelation} import org.apache.spark.sql.hive.HiveContext diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala index b7ef5d1db7..c45d49d6c0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala @@ -101,4 +101,33 @@ class HiveExplainSuite extends QueryTest with SQLTestUtils with TestHiveSingleto "Physical Plan should not contain Subquery since it's eliminated by optimizer") } } + + test("EXPLAIN CODEGEN command") { + checkExistence(sql("EXPLAIN CODEGEN SELECT 1"), true, + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object generate(Object[] references) {", + "/* 002 */ return new GeneratedIterator(references);", + "/* 003 */ }" + ) + + checkExistence(sql("EXPLAIN CODEGEN SELECT 1"), false, + "== Physical Plan ==" + ) + + checkExistence(sql("EXPLAIN EXTENDED CODEGEN SELECT 1"), true, + "WholeStageCodegen", + "Generated code:", + "/* 001 */ public Object generate(Object[] references) {", + "/* 002 */ return new GeneratedIterator(references);", + "/* 003 */ }" + ) + + checkExistence(sql("EXPLAIN EXTENDED CODEGEN SELECT 1"), false, + "== Parsed Logical Plan ==", + "== Analyzed Logical Plan ==", + "== Optimized Logical Plan ==", + "== Physical Plan ==" + ) + } } -- cgit v1.2.3 From 06694f1c68cb752ea311144f0dbe50e92e1393cf Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Sat, 2 Apr 2016 08:12:04 -0700 Subject: [MINOR] Typo fixes ## What changes were proposed in this pull request? Typo fixes. No functional changes. ## How was this patch tested? Built the sources and ran with samples. Author: Jacek Laskowski Closes #11802 from jaceklaskowski/typo-fixes. 
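[Editor's aside, placed here so the unrelated typo-fix diff below stays intact: looking back at the SPARK-14251 commit above, `EXPLAIN CODEGEN` is the SQL-level counterpart of `debugCodegen`. The minimal sketch below only restates what the patched tests (`HiveExplainSuite`, `DebuggingSuite`) already exercise; it is not part of the patch and assumes an already-constructed `SQLContext` named `sqlContext`.]

```scala
// Sketch only: mirrors the behaviour exercised by the tests in the SPARK-14251 diff above.
// SQL-level form: prints the generated Java source of each WholeStageCodegen subtree.
sqlContext.sql("EXPLAIN CODEGEN SELECT 'a' AS a GROUP BY 1").collect().foreach(println)

// Programmatic form: codegenString is now a package-level helper in execution.debug.
import org.apache.spark.sql.execution.debug._
val plan = sqlContext.range(10).groupBy("id").count().queryExecution.executedPlan
println(codegenString(plan))
```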
--- .../streaming/RecoverableNetworkWordCount.scala | 2 +- .../main/scala/org/apache/spark/ml/Pipeline.scala | 2 +- .../spark/ml/regression/LinearRegression.scala | 2 +- .../sql/catalyst/plans/logical/LogicalPlan.scala | 4 ++-- .../org/apache/spark/sql/ExperimentalMethods.scala | 2 +- .../sql/execution/joins/BroadcastHashJoin.scala | 2 +- .../scala/org/apache/spark/sql/functions.scala | 12 ++++++------ .../apache/spark/streaming/StreamingContext.scala | 13 +++++++------ .../streaming/dstream/ConstantInputDStream.scala | 2 +- .../apache/spark/streaming/dstream/DStream.scala | 8 ++++---- .../streaming/dstream/DStreamCheckpointData.scala | 6 +++--- .../spark/streaming/dstream/InputDStream.scala | 6 +++--- .../streaming/dstream/ReducedWindowedDStream.scala | 2 +- .../spark/streaming/dstream/StateDStream.scala | 12 ++++++------ .../streaming/scheduler/ReceivedBlockTracker.scala | 4 ++-- .../streaming/scheduler/rate/RateEstimator.scala | 22 ++++++++++++---------- 16 files changed, 52 insertions(+), 49 deletions(-) (limited to 'sql/catalyst') diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 05f8e65d65..b6b8bc33f7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -141,7 +141,7 @@ object RecoverableNetworkWordCount { def main(args: Array[String]) { if (args.length != 4) { - System.err.println("You arguments were " + args.mkString("[", ", ", "]")) + System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala index 3a99979a88..afefaaa883 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala @@ -147,7 +147,7 @@ class Pipeline @Since("1.4.0") ( t case _ => throw new IllegalArgumentException( - s"Do not support stage $stage of type ${stage.getClass}") + s"Does not support stage $stage of type ${stage.getClass}") } if (index < indexOfLastEstimator) { curDataset = transformer.transform(curDataset) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index ba5ad4c072..2633c06f40 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -58,7 +58,7 @@ private[regression] trait LinearRegressionParams extends PredictorParams * The specific squared error loss function used is: * L = 1/2n ||A coefficients - y||^2^ * - * This support multiple types of regularization: + * This supports multiple types of regularization: * - none (a.k.a. 
ordinary least squares) * - L2 (ridge regression) * - L1 (Lasso) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index ecf4285c46..aceeb8aadc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -79,13 +79,13 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging { /** * Computes [[Statistics]] for this plan. The default implementation assumes the output - * cardinality is the product of of all child plan's cardinality, i.e. applies in the case + * cardinality is the product of all child plan's cardinality, i.e. applies in the case * of cartesian joins. * * [[LeafNode]]s must override this. */ def statistics: Statistics = { - if (children.size == 0) { + if (children.isEmpty) { throw new UnsupportedOperationException(s"LeafNode $nodeName must implement statistics.") } Statistics(sizeInBytes = children.map(_.statistics.sizeInBytes).product) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala index d7cd84fd24..c5df028485 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ExperimentalMethods.scala @@ -37,7 +37,7 @@ class ExperimentalMethods private[sql]() { /** * Allows extra strategies to be injected into the query planner at runtime. Note this API - * should be consider experimental and is not intended to be stable across releases. + * should be considered experimental and is not intended to be stable across releases. * * @since 1.3.0 */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index f5b083c216..0ed1ed41b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.collection.CompactBuffer /** * Performs an inner hash join of two child relations. When the output RDD of this operator is * being constructed, a Spark job is asynchronously started to calculate the values for the - * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed + * broadcast relation. This data is then placed in a Spark broadcast variable. The streamed * relation is not shuffled. */ case class BroadcastHashJoin( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 74906050ac..baf947d037 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2232,7 +2232,7 @@ object functions { /** * Splits str around pattern (pattern is a regular expression). - * NOTE: pattern is a string represent the regular expression. + * NOTE: pattern is a string representation of the regular expression. * * @group string_funcs * @since 1.5.0 @@ -2267,9 +2267,9 @@ object functions { /** * Translate any character in the src by a character in replaceString. 
- * The characters in replaceString is corresponding to the characters in matchingString. - * The translate will happen when any character in the string matching with the character - * in the matchingString. + * The characters in replaceString correspond to the characters in matchingString. + * The translate will happen when any character in the string matches the character + * in the `matchingString`. * * @group string_funcs * @since 1.5.0 @@ -2692,7 +2692,7 @@ object functions { ////////////////////////////////////////////////////////////////////////////////////////////// /** - * Returns true if the array contain the value + * Returns true if the array contains `value` * @group collection_funcs * @since 1.5.0 */ @@ -2920,7 +2920,7 @@ object functions { /** * Defines a user-defined function (UDF) using a Scala closure. For this variant, the caller must - * specifcy the output data type, and there is no automatic input type coercion. + * specify the output data type, and there is no automatic input type coercion. * * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 3a664c4f5c..c1e151d08b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -132,7 +132,7 @@ class StreamingContext private[streaming] ( "both SparkContext and checkpoint as null") } - private[streaming] val isCheckpointPresent = (_cp != null) + private[streaming] val isCheckpointPresent: Boolean = _cp != null private[streaming] val sc: SparkContext = { if (_sc != null) { @@ -213,8 +213,8 @@ class StreamingContext private[streaming] ( def sparkContext: SparkContext = sc /** - * Set each DStreams in this context to remember RDDs it generated in the last given duration. - * DStreams remember RDDs only for a limited duration of time and releases them for garbage + * Set each DStream in this context to remember RDDs it generated in the last given duration. + * DStreams remember RDDs only for a limited duration of time and release them for garbage * collection. This method allows the developer to specify how long to remember the RDDs ( * if the developer wishes to query old data outside the DStream computation). * @param duration Minimum duration that each DStream should remember its RDDs @@ -282,13 +282,14 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream from TCP source hostname:port. Data is received using + * Creates an input stream from TCP source hostname:port. Data is received using * a TCP socket and the receive bytes is interpreted as UTF8 encoded `\n` delimited * lines. * @param hostname Hostname to connect to for receiving data * @param port Port to connect to for receiving data * @param storageLevel Storage level to use for storing the received objects * (default: StorageLevel.MEMORY_AND_DISK_SER_2) + * @see [[socketStream]] */ def socketTextStream( hostname: String, @@ -299,7 +300,7 @@ class StreamingContext private[streaming] ( } /** - * Create a input stream from TCP source hostname:port. Data is received using + * Creates an input stream from TCP source hostname:port. Data is received using * a TCP socket and the receive bytes it interpreted as object using the given * converter. 
* @param hostname Hostname to connect to for receiving data @@ -860,7 +861,7 @@ private class StreamingContextPythonHelper { */ def tryRecoverFromCheckpoint(checkpointPath: String): Option[StreamingContext] = { val checkpointOption = CheckpointReader.read( - checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, false) + checkpointPath, new SparkConf(), SparkHadoopUtil.get.conf, ignoreReadError = false) checkpointOption.map(new StreamingContext(null, _, null)) } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala index b5f86fe779..995470ec8d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ConstantInputDStream.scala @@ -23,7 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.streaming.{StreamingContext, Time} /** - * An input stream that always returns the same RDD on each timestep. Useful for testing. + * An input stream that always returns the same RDD on each time step. Useful for testing. */ class ConstantInputDStream[T: ClassTag](_ssc: StreamingContext, rdd: RDD[T]) extends InputDStream[T](_ssc) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index eb7b64eaf4..c40beeff97 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -83,7 +83,7 @@ abstract class DStream[T: ClassTag] ( // RDDs generated, marked as private[streaming] so that testsuites can access it @transient - private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]] () + private[streaming] var generatedRDDs = new HashMap[Time, RDD[T]]() // Time zero for the DStream private[streaming] var zeroTime: Time = null @@ -269,7 +269,7 @@ abstract class DStream[T: ClassTag] ( checkpointDuration == null || rememberDuration > checkpointDuration, s"The remember duration for ${this.getClass.getSimpleName} has been set to " + s" $rememberDuration which is not more than the checkpoint interval" + - s" ($checkpointDuration). Please set it to higher than $checkpointDuration." + s" ($checkpointDuration). Please set it to a value higher than $checkpointDuration." 
) dependencies.foreach(_.validateAtStart()) @@ -277,7 +277,7 @@ abstract class DStream[T: ClassTag] ( logInfo(s"Slide time = $slideDuration") logInfo(s"Storage level = ${storageLevel.description}") logInfo(s"Checkpoint interval = $checkpointDuration") - logInfo(s"Remember duration = $rememberDuration") + logInfo(s"Remember interval = $rememberDuration") logInfo(s"Initialized and validated $this") } @@ -535,7 +535,7 @@ abstract class DStream[T: ClassTag] ( private def readObject(ois: ObjectInputStream): Unit = Utils.tryOrIOException { logDebug(s"${this.getClass().getSimpleName}.readObject used") ois.defaultReadObject() - generatedRDDs = new HashMap[Time, RDD[T]] () + generatedRDDs = new HashMap[Time, RDD[T]]() } // ======================================================================= diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 365a6bc417..431c9dbe2c 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -29,7 +29,7 @@ import org.apache.spark.streaming.Time import org.apache.spark.util.Utils private[streaming] -class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) +class DStreamCheckpointData[T: ClassTag](dstream: DStream[T]) extends Serializable with Logging { protected val data = new HashMap[Time, AnyRef]() @@ -45,7 +45,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) /** * Updates the checkpoint data of the DStream. This gets called every time * the graph checkpoint is initiated. Default implementation records the - * checkpoint files to which the generate RDDs of the DStream has been saved. + * checkpoint files at which the generated RDDs of the DStream have been saved. */ def update(time: Time) { @@ -103,7 +103,7 @@ class DStreamCheckpointData[T: ClassTag] (dstream: DStream[T]) /** * Restore the checkpoint data. This gets called once when the DStream graph - * (along with its DStreams) are being restored from a graph checkpoint file. + * (along with its output DStreams) is being restored from a graph checkpoint file. * Default implementation restores the RDDs from their checkpoint files. 
*/ def restore() { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index 0b6b191dbe..dc88349db5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -39,7 +39,7 @@ import org.apache.spark.util.Utils * * @param _ssc Streaming context that will execute this input stream */ -abstract class InputDStream[T: ClassTag] (_ssc: StreamingContext) +abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) extends DStream[T](_ssc) { private[streaming] var lastValidTime: Time = null @@ -90,8 +90,8 @@ abstract class InputDStream[T: ClassTag] (_ssc: StreamingContext) } else { // Time is valid, but check it it is more than lastValidTime if (lastValidTime != null && time < lastValidTime) { - logWarning("isTimeValid called with " + time + " where as last valid time is " + - lastValidTime) + logWarning(s"isTimeValid called with $time whereas the last valid time " + + s"is $lastValidTime") } lastValidTime = time true diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala index a9be2f213f..a9e93838b8 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ReducedWindowedDStream.scala @@ -87,7 +87,7 @@ class ReducedWindowedDStream[K: ClassTag, V: ClassTag]( logDebug("Window time = " + windowDuration) logDebug("Slide time = " + slideDuration) - logDebug("ZeroTime = " + zeroTime) + logDebug("Zero time = " + zeroTime) logDebug("Current window = " + currentWindow) logDebug("Previous window = " + previousWindow) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 68eff89030..0379957e58 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -70,7 +70,7 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // Try to get the parent RDD parent.getOrCompute(validTime) match { case Some(parentRDD) => { // If parent RDD exists, then compute as usual - computeUsingPreviousRDD (parentRDD, prevStateRDD) + computeUsingPreviousRDD(parentRDD, prevStateRDD) } case None => { // If parent RDD does not exist @@ -98,15 +98,15 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // and then apply the update function val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, Iterable[V])]) => { - updateFuncLocal (iterator.map (tuple => (tuple._1, tuple._2.toSeq, None))) + updateFuncLocal(iterator.map(tuple => (tuple._1, tuple._2.toSeq, None))) } - val groupedRDD = parentRDD.groupByKey (partitioner) - val sessionRDD = groupedRDD.mapPartitions (finalFunc, preservePartitioning) + val groupedRDD = parentRDD.groupByKey(partitioner) + val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) // logDebug("Generating state RDD for time " + validTime + " (first)") - Some (sessionRDD) + Some(sessionRDD) } - case Some (initialStateRDD) => { + case Some(initialStateRDD) => { computeUsingPreviousRDD(parentRDD, initialStateRDD) } } diff --git 
a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala index 9c8e68b03d..5d9a8ac0d9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceivedBlockTracker.scala @@ -119,7 +119,7 @@ private[streaming] class ReceivedBlockTracker( timeToAllocatedBlocks.put(batchTime, allocatedBlocks) lastAllocatedBatchTime = batchTime } else { - logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") + logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery") } } else { // This situation occurs when: @@ -129,7 +129,7 @@ private[streaming] class ReceivedBlockTracker( // 2. Slow checkpointing makes recovered batch time older than WAL recovered // lastAllocatedBatchTime. // This situation will only occurs in recovery time. - logInfo(s"Possibly processed batch $batchTime need to be processed again in WAL recovery") + logInfo(s"Possibly processed batch $batchTime needs to be processed again in WAL recovery") } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala index d7210f64fc..7b2ef6881d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/rate/RateEstimator.scala @@ -21,18 +21,20 @@ import org.apache.spark.SparkConf import org.apache.spark.streaming.Duration /** - * A component that estimates the rate at wich an InputDStream should ingest - * elements, based on updates at every batch completion. + * A component that estimates the rate at which an `InputDStream` should ingest + * records, based on updates at every batch completion. + * + * @see [[org.apache.spark.streaming.scheduler.RateController]] */ private[streaming] trait RateEstimator extends Serializable { /** - * Computes the number of elements the stream attached to this `RateEstimator` + * Computes the number of records the stream attached to this `RateEstimator` * should ingest per second, given an update on the size and completion * times of the latest batch. * - * @param time The timetamp of the current batch interval that just finished - * @param elements The number of elements that were processed in this batch + * @param time The timestamp of the current batch interval that just finished + * @param elements The number of records that were processed in this batch * @param processingDelay The time in ms that took for the job to complete * @param schedulingDelay The time in ms that the job spent in the scheduling queue */ @@ -46,13 +48,13 @@ private[streaming] trait RateEstimator extends Serializable { object RateEstimator { /** - * Return a new RateEstimator based on the value of `spark.streaming.RateEstimator`. + * Return a new `RateEstimator` based on the value of + * `spark.streaming.backpressure.rateEstimator`. * - * The only known estimator right now is `pid`. + * The only known and acceptable estimator right now is `pid`. * * @return An instance of RateEstimator - * @throws IllegalArgumentException if there is a configured RateEstimator that doesn't match any - * known estimators. + * @throws IllegalArgumentException if the configured RateEstimator is not `pid`. 
*/ def create(conf: SparkConf, batchInterval: Duration): RateEstimator = conf.get("spark.streaming.backpressure.rateEstimator", "pid") match { @@ -64,6 +66,6 @@ object RateEstimator { new PIDRateEstimator(batchInterval.milliseconds, proportional, integral, derived, minRate) case estimator => - throw new IllegalArgumentException(s"Unkown rate estimator: $estimator") + throw new IllegalArgumentException(s"Unknown rate estimator: $estimator") } } -- cgit v1.2.3 From f705037617d55bb479ec60bcb1e55c736224be94 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 2 Apr 2016 17:48:53 -0700 Subject: [SPARK-14338][SQL] Improve `SimplifyConditionals` rule to handle `null` in IF/CASEWHEN ## What changes were proposed in this pull request? Currently, `SimplifyConditionals` handles `true` and `false` to optimize branches. This PR improves `SimplifyConditionals` to take advantage of `null` conditions for `if` and `CaseWhen` expressions, too. **Before** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [if (null) 1 else 0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [CASE WHEN null THEN 1 ELSE 2 END AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#14] : +- INPUT +- Scan OneRowRelation[] ``` **After** ``` scala> sql("SELECT IF(null, 1, 0)").explain() == Physical Plan == WholeStageCodegen : +- Project [0 AS (IF(CAST(NULL AS BOOLEAN), 1, 0))#4] : +- INPUT +- Scan OneRowRelation[] scala> sql("select case when cast(null as boolean) then 1 else 2 end").explain() == Physical Plan == WholeStageCodegen : +- Project [2 AS CASE WHEN CAST(NULL AS BOOLEAN) THEN 1 ELSE 2 END#4] : +- INPUT +- Scan OneRowRelation[] ``` **Hive** ``` hive> select if(null,1,2); OK 2 hive> select case when cast(null as boolean) then 1 else 2 end; OK 2 ``` ## How was this patch tested? Pass the Jenkins tests (including new extended test cases). Author: Dongjoon Hyun Closes #12122 from dongjoon-hyun/SPARK-14338. --- .../apache/spark/sql/catalyst/optimizer/Optimizer.scala | 13 ++++++++++--- .../catalyst/optimizer/SimplifyConditionalSuite.scala | 16 +++++++++++----- 2 files changed, 21 insertions(+), 8 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 326933ec9e..a5ab390c76 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -527,7 +527,7 @@ object LikeSimplification extends Rule[LogicalPlan] { * Null value propagation from bottom to top of the expression tree. */ object NullPropagation extends Rule[LogicalPlan] { - def nonNullLiteral(e: Expression): Boolean = e match { + private def nonNullLiteral(e: Expression): Boolean = e match { case Literal(null, _) => false case _ => true } @@ -773,17 +773,24 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { * Simplifies conditional expressions (if / case). 
*/ object SimplifyConditionals extends Rule[LogicalPlan] with PredicateHelper { + private def falseOrNullLiteral(e: Expression): Boolean = e match { + case FalseLiteral => true + case Literal(null, _) => true + case _ => false + } + def apply(plan: LogicalPlan): LogicalPlan = plan transform { case q: LogicalPlan => q transformExpressionsUp { case If(TrueLiteral, trueValue, _) => trueValue case If(FalseLiteral, _, falseValue) => falseValue + case If(Literal(null, _), _, falseValue) => falseValue - case e @ CaseWhen(branches, elseValue) if branches.exists(_._1 == FalseLiteral) => + case e @ CaseWhen(branches, elseValue) if branches.exists(x => falseOrNullLiteral(x._1)) => // If there are branches that are always false, remove them. // If there are no more branches left, just use the else value. // Note that these two are handled together here in a single case statement because // otherwise we cannot determine the data type for the elseValue if it is None (i.e. null). - val newBranches = branches.filter(_._1 != FalseLiteral) + val newBranches = branches.filter(x => !falseOrNullLiteral(x._1)) if (newBranches.isEmpty) { elseValue.getOrElse(Literal.create(null, e.dataType)) } else { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index d436b627f6..33239c0084 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLite import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules._ -import org.apache.spark.sql.types.IntegerType +import org.apache.spark.sql.types.{IntegerType, NullType} class SimplifyConditionalSuite extends PlanTest with PredicateHelper { @@ -41,6 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val trueBranch = (TrueLiteral, Literal(5)) private val normalBranch = (NonFoldableLiteral(true), Literal(10)) private val unreachableBranch = (FalseLiteral, Literal(20)) + private val nullBranch = (Literal(null, NullType), Literal(30)) test("simplify if") { assertEquivalent( @@ -50,18 +51,22 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { assertEquivalent( If(FalseLiteral, Literal(10), Literal(20)), Literal(20)) + + assertEquivalent( + If(Literal(null, NullType), Literal(10), Literal(20)), + Literal(20)) } test("remove unreachable branches") { // i.e. 
removing branches whose conditions are always false assertEquivalent( - CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: Nil, None), + CaseWhen(unreachableBranch :: normalBranch :: unreachableBranch :: nullBranch :: Nil, None), CaseWhen(normalBranch :: Nil, None)) } test("remove entire CaseWhen if only the else branch is reachable") { assertEquivalent( - CaseWhen(unreachableBranch :: unreachableBranch :: Nil, Some(Literal(30))), + CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: Nil, Some(Literal(30))), Literal(30)) assertEquivalent( @@ -71,12 +76,13 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { test("remove entire CaseWhen if the first branch is always true") { assertEquivalent( - CaseWhen(trueBranch :: normalBranch :: Nil, None), + CaseWhen(trueBranch :: normalBranch :: nullBranch :: Nil, None), Literal(5)) // Test branch elimination and simplification in combination assertEquivalent( - CaseWhen(unreachableBranch :: unreachableBranch:: trueBranch :: normalBranch :: Nil, None), + CaseWhen(unreachableBranch :: unreachableBranch :: nullBranch :: trueBranch :: normalBranch + :: Nil, None), Literal(5)) // Make sure this doesn't trigger if there is a non-foldable branch before the true branch -- cgit v1.2.3 From 4a6e78abd9d5edc4a5092738dff0006bbe202a89 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sat, 2 Apr 2016 17:50:40 -0700 Subject: [MINOR][DOCS] Use multi-line JavaDoc comments in Scala code. ## What changes were proposed in this pull request? This PR converts all Scala-style multiline comments into Java-style multiline comments throughout the Scala code. (All comment-only changes over 77 files: +786 lines, −747 lines) ## How was this patch tested? Manual. Author: Dongjoon Hyun Closes #12130 from dongjoon-hyun/use_multiine_javadoc_comments.
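For readers unfamiliar with the two conventions, the sketch below illustrates the layout change this patch applies throughout the diff; the doc text and method names are invented for illustration only, and the comment formatting is the only thing that matters.

```scala
object CommentStyleExample {

  // Scala-style multiline comment (the style being removed): the description
  // starts on the same line as the opening /**, and the closing */ may share
  // a line with text.
  /** Distribute a local Scala collection to form an RDD, with one or more
    * location preferences for each object. */
  def scalaStyleDocumented(): Unit = ()

  // Java-style multiline comment (the target style): /** and */ each sit on
  // their own line and every continuation line starts with " * ".
  /**
   * Distribute a local Scala collection to form an RDD, with one or more
   * location preferences for each object.
   */
  def javaStyleDocumented(): Unit = ()
}
```

As the commit message notes, the change is comment-only: both forms are valid Scaladoc, and the patch simply standardizes on the Java-style layout.
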
--- .../main/scala/org/apache/spark/FutureAction.scala | 14 +- .../main/scala/org/apache/spark/SSLOptions.scala | 57 +++--- .../main/scala/org/apache/spark/SparkContext.scala | 42 +++-- .../org/apache/spark/api/java/JavaPairRDD.scala | 8 +- .../apache/spark/api/java/JavaSparkContext.scala | 60 ++++--- .../apache/spark/deploy/worker/CommandUtils.scala | 2 +- .../scala/org/apache/spark/rdd/CoGroupedRDD.scala | 10 +- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 +- .../mesos/CoarseMesosSchedulerBackend.scala | 24 +-- .../cluster/mesos/MesosClusterScheduler.scala | 12 +- .../cluster/mesos/MesosSchedulerUtils.scala | 4 +- .../org/apache/spark/shuffle/ShuffleManager.scala | 6 +- .../apache/spark/storage/memory/MemoryStore.scala | 20 +-- .../main/scala/org/apache/spark/util/Utils.scala | 17 +- core/src/test/scala/org/apache/spark/Smuggle.scala | 46 ++--- .../apache/spark/memory/MemoryManagerSuite.scala | 24 +-- .../org/apache/spark/examples/BroadcastTest.scala | 4 +- .../apache/spark/examples/DFSReadWriteTest.scala | 20 +-- .../org/apache/spark/examples/GroupByTest.scala | 4 +- .../apache/spark/examples/MultiBroadcastTest.scala | 4 +- .../spark/examples/SimpleSkewedGroupByTest.scala | 4 +- .../apache/spark/examples/SkewedGroupByTest.scala | 4 +- .../streaming/clickstream/PageViewGenerator.scala | 23 +-- .../streaming/clickstream/PageViewStream.scala | 21 +-- .../spark/streaming/flume/FlumeInputDStream.scala | 15 +- .../spark/streaming/kafka/KafkaRDDPartition.scala | 15 +- .../scala/org/apache/spark/graphx/GraphOps.scala | 10 +- .../spark/graphx/lib/ConnectedComponents.scala | 18 +- .../spark/ml/feature/ElementwiseProduct.scala | 6 +- .../api/python/GaussianMixtureModelWrapper.scala | 8 +- .../mllib/api/python/Word2VecModelWrapper.scala | 4 +- .../org/apache/spark/mllib/linalg/Matrices.scala | 16 +- .../StreamingLinearRegressionWithSGD.scala | 4 +- .../scala/org/apache/spark/repl/SparkILoop.scala | 21 ++- .../scala/org/apache/spark/repl/SparkImports.scala | 5 +- .../main/scala/org/apache/spark/sql/Encoder.scala | 24 +-- .../spark/sql/catalyst/analysis/Analyzer.scala | 20 +-- .../sql/catalyst/expressions/Projection.scala | 6 +- .../expressions/codegen/CodeGenerator.scala | 26 +-- .../spark/sql/catalyst/expressions/grouping.scala | 18 +- .../spark/sql/catalyst/expressions/misc.scala | 4 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 40 ++--- .../spark/sql/catalyst/parser/AstBuilder.scala | 4 +- .../spark/sql/catalyst/planning/patterns.scala | 28 +-- .../spark/sql/catalyst/plans/QueryPlan.scala | 4 +- .../sql/catalyst/plans/physical/partitioning.scala | 6 +- .../optimizer/OptimizerExtendableSuite.scala | 14 +- .../scala/org/apache/spark/sql/SQLContext.scala | 14 +- .../apache/spark/sql/execution/CacheManager.scala | 7 +- .../org/apache/spark/sql/execution/SparkPlan.scala | 4 +- .../spark/sql/execution/WholeStageCodegen.scala | 172 +++++++++---------- .../org/apache/spark/sql/execution/Window.scala | 36 ++-- .../execution/aggregate/AggregationIterator.scala | 22 +-- .../aggregate/SortBasedAggregationIterator.scala | 6 +- .../execution/aggregate/TungstenAggregate.scala | 16 +- .../execution/datasources/SqlNewHadoopRDD.scala | 8 +- .../execution/datasources/csv/CSVInferSchema.scala | 22 +-- .../execution/datasources/csv/DefaultSource.scala | 4 +- .../spark/sql/execution/datasources/ddl.scala | 8 +- .../sql/execution/joins/BroadcastHashJoin.scala | 22 +-- .../sql/execution/joins/CartesianProduct.scala | 8 +- .../spark/sql/execution/joins/HashJoin.scala | 8 +- 
.../spark/sql/execution/joins/HashedRelation.scala | 76 ++++---- .../spark/sql/execution/joins/SortMergeJoin.scala | 36 ++-- .../state/HDFSBackedStateStoreProvider.scala | 8 +- .../spark/sql/execution/ui/SparkPlanGraph.scala | 4 +- .../scala/org/apache/spark/sql/functions.scala | 191 ++++++++++----------- .../org/apache/spark/sql/sources/interfaces.scala | 26 +-- .../scala/org/apache/spark/sql/QueryTest.scala | 4 +- .../sql/execution/BenchmarkWholeStageCodegen.scala | 8 +- .../execution/datasources/csv/CSVParserSuite.scala | 4 +- .../apache/spark/streaming/StreamingContext.scala | 7 +- .../streaming/api/java/JavaStreamingContext.scala | 7 +- .../spark/streaming/receiver/RateLimiter.scala | 23 +-- .../apache/spark/streaming/scheduler/JobSet.scala | 7 +- .../apache/spark/tools/GenerateMIMAIgnore.scala | 9 +- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 16 +- 77 files changed, 786 insertions(+), 747 deletions(-) (limited to 'sql/catalyst') diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 2a8220ff40..ce11772a6d 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -146,16 +146,16 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: /** - * Handle via which a "run" function passed to a [[ComplexFutureAction]] - * can submit jobs for execution. - */ + * Handle via which a "run" function passed to a [[ComplexFutureAction]] + * can submit jobs for execution. + */ @DeveloperApi trait JobSubmitter { /** - * Submit a job for execution and return a FutureAction holding the result. - * This is a wrapper around the same functionality provided by SparkContext - * to enable cancellation. - */ + * Submit a job for execution and return a FutureAction holding the result. + * This is a wrapper around the same functionality provided by SparkContext + * to enable cancellation. + */ def submitJob[T, U, R]( rdd: RDD[T], processPartition: Iterator[T] => U, diff --git a/core/src/main/scala/org/apache/spark/SSLOptions.scala b/core/src/main/scala/org/apache/spark/SSLOptions.scala index 30db6ccbf4..719905a2c9 100644 --- a/core/src/main/scala/org/apache/spark/SSLOptions.scala +++ b/core/src/main/scala/org/apache/spark/SSLOptions.scala @@ -132,34 +132,35 @@ private[spark] case class SSLOptions( private[spark] object SSLOptions extends Logging { - /** Resolves SSLOptions settings from a given Spark configuration object at a given namespace. 
- * - * The following settings are allowed: - * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively - * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory - * $ - `[ns].keyStorePassword` - a password to the key-store file - * $ - `[ns].keyPassword` - a password to the private key - * $ - `[ns].keyStoreType` - the type of the key-store - * $ - `[ns].needClientAuth` - whether SSL needs client authentication - * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current - * directory - * $ - `[ns].trustStorePassword` - a password to the trust-store file - * $ - `[ns].trustStoreType` - the type of trust-store - * $ - `[ns].protocol` - a protocol name supported by a particular Java version - * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers - * - * For a list of protocols and ciphers supported by particular Java versions, you may go to - * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle - * blog page]]. - * - * You can optionally specify the default configuration. If you do, for each setting which is - * missing in SparkConf, the corresponding setting is used from the default configuration. - * - * @param conf Spark configuration object where the settings are collected from - * @param ns the namespace name - * @param defaults the default configuration - * @return [[org.apache.spark.SSLOptions]] object - */ + /** + * Resolves SSLOptions settings from a given Spark configuration object at a given namespace. + * + * The following settings are allowed: + * $ - `[ns].enabled` - `true` or `false`, to enable or disable SSL respectively + * $ - `[ns].keyStore` - a path to the key-store file; can be relative to the current directory + * $ - `[ns].keyStorePassword` - a password to the key-store file + * $ - `[ns].keyPassword` - a password to the private key + * $ - `[ns].keyStoreType` - the type of the key-store + * $ - `[ns].needClientAuth` - whether SSL needs client authentication + * $ - `[ns].trustStore` - a path to the trust-store file; can be relative to the current + * directory + * $ - `[ns].trustStorePassword` - a password to the trust-store file + * $ - `[ns].trustStoreType` - the type of trust-store + * $ - `[ns].protocol` - a protocol name supported by a particular Java version + * $ - `[ns].enabledAlgorithms` - a comma separated list of ciphers + * + * For a list of protocols and ciphers supported by particular Java versions, you may go to + * [[https://blogs.oracle.com/java-platform-group/entry/diagnosing_tls_ssl_and_https Oracle + * blog page]]. + * + * You can optionally specify the default configuration. If you do, for each setting which is + * missing in SparkConf, the corresponding setting is used from the default configuration. 
+ * + * @param conf Spark configuration object where the settings are collected from + * @param ns the namespace name + * @param defaults the default configuration + * @return [[org.apache.spark.SSLOptions]] object + */ def parse(conf: SparkConf, ns: String, defaults: Option[SSLOptions] = None): SSLOptions = { val enabled = conf.getBoolean(s"$ns.enabled", defaultValue = defaults.exists(_.enabled)) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d7cb253d69..4b3264cbf5 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -773,9 +773,11 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli parallelize(seq, numSlices) } - /** Distribute a local Scala collection to form an RDD, with one or more - * location preferences (hostnames of Spark nodes) for each object. - * Create a new partition for each collection item. */ + /** + * Distribute a local Scala collection to form an RDD, with one or more + * location preferences (hostnames of Spark nodes) for each object. + * Create a new partition for each collection item. + */ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = withScope { assertNotStopped() val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap @@ -1095,14 +1097,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli new NewHadoopRDD(this, fClass, kClass, vClass, jconf) } - /** Get an RDD for a Hadoop SequenceFile with given key and value types. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle - * operation will create many references to the same object. - * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first - * copy them using a `map` function. - */ + /** + * Get an RDD for a Hadoop SequenceFile with given key and value types. + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. + */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V], @@ -1113,14 +1116,15 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions) } - /** Get an RDD for a Hadoop SequenceFile with given key and value types. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle - * operation will create many references to the same object. - * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first - * copy them using a `map` function. - * */ + /** + * Get an RDD for a Hadoop SequenceFile with given key and value types. 
+ * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD or directly passing it to an aggregation or shuffle + * operation will create many references to the same object. + * If you plan to directly cache, sort, or aggregate Hadoop writable objects, you should first + * copy them using a `map` function. + */ def sequenceFile[K, V]( path: String, keyClass: Class[K], diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index e080f91f50..2897272a8b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -461,10 +461,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)]) fromRDD(rdd.partitionBy(partitioner)) /** - * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each - * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and - * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. - */ + * Return an RDD containing all pairs of elements with matching keys in `this` and `other`. Each + * pair of elements will be returned as a (k, (v1, v2)) tuple, where (k, v1) is in `this` and + * (k, v2) is in `other`. Uses the given Partitioner to partition the output RDD. + */ def join[W](other: JavaPairRDD[K, W], partitioner: Partitioner): JavaPairRDD[K, (V, W)] = fromRDD(rdd.join(other, partitioner)) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index d362c40b7a..dfd91ae338 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -295,13 +295,14 @@ class JavaSparkContext(val sc: SparkContext) new JavaRDD(sc.binaryRecords(path, recordLength)) } - /** Get an RDD for a Hadoop SequenceFile with given key and value types. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - * */ + /** + * Get an RDD for a Hadoop SequenceFile with given key and value types. + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD will create many references to the same object. + * If you plan to directly cache Hadoop writable objects, you should first copy them using + * a `map` function. + */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V], @@ -312,13 +313,14 @@ class JavaSparkContext(val sc: SparkContext) new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minPartitions)) } - /** Get an RDD for a Hadoop SequenceFile. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - */ + /** + * Get an RDD for a Hadoop SequenceFile. 
+ * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD will create many references to the same object. + * If you plan to directly cache Hadoop writable objects, you should first copy them using + * a `map` function. + */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): JavaPairRDD[K, V] = { implicit val ctagK: ClassTag[K] = ClassTag(keyClass) @@ -411,13 +413,14 @@ class JavaSparkContext(val sc: SparkContext) new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } - /** Get an RDD for a Hadoop file with an arbitrary InputFormat. - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - */ + /** + * Get an RDD for a Hadoop file with an arbitrary InputFormat. + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD will create many references to the same object. + * If you plan to directly cache Hadoop writable objects, you should first copy them using + * a `map` function. + */ def hadoopFile[K, V, F <: InputFormat[K, V]]( path: String, inputFormatClass: Class[F], @@ -431,13 +434,14 @@ class JavaSparkContext(val sc: SparkContext) new JavaHadoopRDD(rdd.asInstanceOf[HadoopRDD[K, V]]) } - /** Get an RDD for a Hadoop file with an arbitrary InputFormat - * - * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each - * record, directly caching the returned RDD will create many references to the same object. - * If you plan to directly cache Hadoop writable objects, you should first copy them using - * a `map` function. - */ + /** + * Get an RDD for a Hadoop file with an arbitrary InputFormat + * + * '''Note:''' Because Hadoop's RecordReader class re-uses the same Writable object for each + * record, directly caching the returned RDD will create many references to the same object. + * If you plan to directly cache Hadoop writable objects, you should first copy them using + * a `map` function. + */ def hadoopFile[K, V, F <: InputFormat[K, V]]( path: String, inputFormatClass: Class[F], diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala index a4efafcb27..cba4aaffe2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/CommandUtils.scala @@ -29,7 +29,7 @@ import org.apache.spark.launcher.WorkerCommandBuilder import org.apache.spark.util.Utils /** - ** Utilities for running commands with the spark classpath. + * Utilities for running commands with the spark classpath. 
*/ private[deploy] object CommandUtils extends Logging { diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index e5ebc63082..7bc1eb0436 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -29,10 +29,12 @@ import org.apache.spark.serializer.Serializer import org.apache.spark.util.collection.{CompactBuffer, ExternalAppendOnlyMap} import org.apache.spark.util.Utils -/** The references to rdd and splitIndex are transient because redundant information is stored - * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from - * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the - * task closure. */ +/** + * The references to rdd and splitIndex are transient because redundant information is stored + * in the CoGroupedRDD object. Because CoGroupedRDD is serialized separately from + * CoGroupPartition, if rdd and splitIndex aren't transient, they'll be included twice in the + * task closure. + */ private[spark] case class NarrowCoGroupSplitDep( @transient rdd: RDD[_], @transient splitIndex: Int, diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index f96551c793..4a0a2199ef 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -255,8 +255,8 @@ abstract class RDD[T: ClassTag]( } /** - * Returns the number of partitions of this RDD. - */ + * Returns the number of partitions of this RDD. + */ @Since("1.6.0") final def getNumPartitions: Int = partitions.length diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 90b1813750..50b452c72f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -295,12 +295,12 @@ private[spark] class CoarseMesosSchedulerBackend( } /** - * Launches executors on accepted offers, and declines unused offers. Executors are launched - * round-robin on offers. - * - * @param d SchedulerDriver - * @param offers Mesos offers that match attribute constraints - */ + * Launches executors on accepted offers, and declines unused offers. Executors are launched + * round-robin on offers. + * + * @param d SchedulerDriver + * @param offers Mesos offers that match attribute constraints + */ private def handleMatchedOffers(d: SchedulerDriver, offers: Buffer[Offer]): Unit = { val tasks = buildMesosTasks(offers) for (offer <- offers) { @@ -336,12 +336,12 @@ private[spark] class CoarseMesosSchedulerBackend( } /** - * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize - * per-task memory and IO, tasks are round-robin assigned to offers. - * - * @param offers Mesos offers that match attribute constraints - * @return A map from OfferID to a list of Mesos tasks to launch on that offer - */ + * Returns a map from OfferIDs to the tasks to launch on those offers. In order to maximize + * per-task memory and IO, tasks are round-robin assigned to offers. 
+ * + * @param offers Mesos offers that match attribute constraints + * @return A map from OfferID to a list of Mesos tasks to launch on that offer + */ private def buildMesosTasks(offers: Buffer[Offer]): Map[OfferID, List[MesosTaskInfo]] = { // offerID -> tasks val tasks = new HashMap[OfferID, List[MesosTaskInfo]].withDefaultValue(Nil) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index c41fa58607..73bd4c58e1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -453,12 +453,12 @@ private[spark] class MesosClusterScheduler( } /** - * Escape args for Unix-like shells, unless already quoted by the user. - * Based on: http://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html - * and http://www.grymoire.com/Unix/Quote.html - * @param value argument - * @return escaped argument - */ + * Escape args for Unix-like shells, unless already quoted by the user. + * Based on: http://www.gnu.org/software/bash/manual/html_node/Double-Quotes.html + * and http://www.grymoire.com/Unix/Quote.html + * @param value argument + * @return escaped argument + */ private[scheduler] def shellEscape(value: String): String = { val WrappedInQuotes = """^(".+"|'.+')$""".r val ShellSpecialChars = (""".*([ '<>&|\?\*;!#\\(\)"$`]).*""").r diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 9a12a61f2f..35f914355d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -148,8 +148,8 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } /** - * Signal that the scheduler has registered with Mesos. - */ + * Signal that the scheduler has registered with Mesos. + */ protected def markRegistered(): Unit = { registerLatch.countDown() } diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala index 76fd249fbd..364fad664e 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala @@ -54,9 +54,9 @@ private[spark] trait ShuffleManager { context: TaskContext): ShuffleReader[K, C] /** - * Remove a shuffle's metadata from the ShuffleManager. - * @return true if the metadata removed successfully, otherwise false. - */ + * Remove a shuffle's metadata from the ShuffleManager. + * @return true if the metadata removed successfully, otherwise false. + */ def unregisterShuffle(shuffleId: Int): Boolean /** diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala index df38d11e43..99be4de065 100644 --- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala +++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala @@ -455,16 +455,16 @@ private[spark] class MemoryStore( } /** - * Try to evict blocks to free up a given amount of space to store a particular block. 
- * Can fail if either the block is bigger than our memory or it would require replacing - * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for - * RDDs that don't fit into memory that we want to avoid). - * - * @param blockId the ID of the block we are freeing space for, if any - * @param space the size of this block - * @param memoryMode the type of memory to free (on- or off-heap) - * @return the amount of memory (in bytes) freed by eviction - */ + * Try to evict blocks to free up a given amount of space to store a particular block. + * Can fail if either the block is bigger than our memory or it would require replacing + * another block from the same RDD (which leads to a wasteful cyclic replacement pattern for + * RDDs that don't fit into memory that we want to avoid). + * + * @param blockId the ID of the block we are freeing space for, if any + * @param space the size of this block + * @param memoryMode the type of memory to free (on- or off-heap) + * @return the amount of memory (in bytes) freed by eviction + */ private[spark] def evictBlocksToFreeSpace( blockId: Option[BlockId], space: Long, diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 73768ff4c8..50bcf85805 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -256,10 +256,11 @@ private[spark] object Utils extends Logging { dir } - /** Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream - * copying is disabled by default unless explicitly set transferToEnabled as true, - * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. - */ + /** + * Copy all data from an InputStream to an OutputStream. NIO way of file stream to file stream + * copying is disabled by default unless explicitly set transferToEnabled as true, + * the parameter transferToEnabled should be configured by spark.file.transferTo = [true|false]. + */ def copyStream(in: InputStream, out: OutputStream, closeStreams: Boolean = false, @@ -1564,9 +1565,11 @@ private[spark] object Utils extends Logging { else -1 } - /** Returns the system properties map that is thread-safe to iterator over. It gets the - * properties which have been set explicitly, as well as those for which only a default value - * has been defined. */ + /** + * Returns the system properties map that is thread-safe to iterator over. It gets the + * properties which have been set explicitly, as well as those for which only a default value + * has been defined. + */ def getSystemProperties: Map[String, String] = { System.getProperties.stringPropertyNames().asScala .map(key => (key, System.getProperty(key))).toMap diff --git a/core/src/test/scala/org/apache/spark/Smuggle.scala b/core/src/test/scala/org/apache/spark/Smuggle.scala index 9f0a1b4c25..9d9217ea1b 100644 --- a/core/src/test/scala/org/apache/spark/Smuggle.scala +++ b/core/src/test/scala/org/apache/spark/Smuggle.scala @@ -24,16 +24,16 @@ import scala.collection.mutable import scala.language.implicitConversions /** - * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. - * This is intended for testing purposes, primarily to make locks, semaphores, and - * other constructs that would not survive serialization available from within tasks. 
- * A Smuggle reference is itself serializable, but after being serialized and - * deserialized, it still refers to the same underlying "smuggled" object, as long - * as it was deserialized within the same JVM. This can be useful for tests that - * depend on the timing of task completion to be deterministic, since one can "smuggle" - * a lock or semaphore into the task, and then the task can block until the test gives - * the go-ahead to proceed via the lock. - */ + * Utility wrapper to "smuggle" objects into tasks while bypassing serialization. + * This is intended for testing purposes, primarily to make locks, semaphores, and + * other constructs that would not survive serialization available from within tasks. + * A Smuggle reference is itself serializable, but after being serialized and + * deserialized, it still refers to the same underlying "smuggled" object, as long + * as it was deserialized within the same JVM. This can be useful for tests that + * depend on the timing of task completion to be deterministic, since one can "smuggle" + * a lock or semaphore into the task, and then the task can block until the test gives + * the go-ahead to proceed via the lock. + */ class Smuggle[T] private(val key: Symbol) extends Serializable { def smuggledObject: T = Smuggle.get(key) } @@ -41,13 +41,13 @@ class Smuggle[T] private(val key: Symbol) extends Serializable { object Smuggle { /** - * Wraps the specified object to be smuggled into a serialized task without - * being serialized itself. - * - * @param smuggledObject - * @tparam T - * @return Smuggle wrapper around smuggledObject. - */ + * Wraps the specified object to be smuggled into a serialized task without + * being serialized itself. + * + * @param smuggledObject + * @tparam T + * @return Smuggle wrapper around smuggledObject. + */ def apply[T](smuggledObject: T): Smuggle[T] = { val key = Symbol(UUID.randomUUID().toString) lock.writeLock().lock() @@ -72,12 +72,12 @@ object Smuggle { } /** - * Implicit conversion of a Smuggle wrapper to the object being smuggled. - * - * @param smuggle the wrapper to unpack. - * @tparam T - * @return the smuggled object represented by the wrapper. - */ + * Implicit conversion of a Smuggle wrapper to the object being smuggled. + * + * @param smuggle the wrapper to unpack. + * @tparam T + * @return the smuggled object represented by the wrapper. + */ implicit def unpackSmuggledObject[T](smuggle : Smuggle[T]): T = smuggle.smuggledObject } diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala index 3d1a0e9795..99d5b496bc 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryManagerSuite.scala @@ -78,18 +78,18 @@ private[memory] trait MemoryManagerSuite extends SparkFunSuite with BeforeAndAft } /** - * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. - * - * This is a significant simplification of the real method, which actually drops existing - * blocks based on the size of each block. Instead, here we simply release as many bytes - * as needed to ensure the requested amount of free space. This allows us to set up the - * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in - * many other dependencies. - * - * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that - * records the number of bytes this is called with. 
This variable is expected to be cleared - * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. - */ + * Simulate the part of [[MemoryStore.evictBlocksToFreeSpace]] that releases storage memory. + * + * This is a significant simplification of the real method, which actually drops existing + * blocks based on the size of each block. Instead, here we simply release as many bytes + * as needed to ensure the requested amount of free space. This allows us to set up the + * test without relying on the [[org.apache.spark.storage.BlockManager]], which brings in + * many other dependencies. + * + * Every call to this method will set a global variable, [[evictBlocksToFreeSpaceCalled]], that + * records the number of bytes this is called with. This variable is expected to be cleared + * by the test code later through [[assertEvictBlocksToFreeSpaceCalled]]. + */ private def evictBlocksToFreeSpaceAnswer(mm: MemoryManager): Answer[Long] = { new Answer[Long] { override def answer(invocation: InvocationOnMock): Long = { diff --git a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala index 3da5236745..af5a815f6e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/BroadcastTest.scala @@ -21,8 +21,8 @@ package org.apache.spark.examples import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: BroadcastTest [slices] [numElem] [blockSize] - */ + * Usage: BroadcastTest [slices] [numElem] [blockSize] + */ object BroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala index 743fc13db7..7bf023667d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DFSReadWriteTest.scala @@ -25,16 +25,16 @@ import scala.io.Source._ import org.apache.spark.{SparkConf, SparkContext} /** - * Simple test for reading and writing to a distributed - * file system. This example does the following: - * - * 1. Reads local file - * 2. Computes word count on local file - * 3. Writes local file to a DFS - * 4. Reads the file back from the DFS - * 5. Computes word count on the file using Spark - * 6. Compares the word count results - */ + * Simple test for reading and writing to a distributed + * file system. This example does the following: + * + * 1. Reads local file + * 2. Computes word count on local file + * 3. Writes local file to a DFS + * 4. Reads the file back from the DFS + * 5. Computes word count on the file using Spark + * 6. 
Compares the word count results + */ object DFSReadWriteTest { private var localFilePath: File = new File(".") diff --git a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala index 08b6c717d4..4db229b5de 100644 --- a/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/GroupByTest.scala @@ -23,8 +23,8 @@ import java.util.Random import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] - */ + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object GroupByTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("GroupBy Test") diff --git a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala index 134c3d1d63..3eb0c27723 100644 --- a/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/MultiBroadcastTest.scala @@ -22,8 +22,8 @@ import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.rdd.RDD /** - * Usage: MultiBroadcastTest [slices] [numElem] - */ + * Usage: MultiBroadcastTest [slices] [numElem] + */ object MultiBroadcastTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index 7c09664c2f..ec07e6323e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -23,8 +23,8 @@ import java.util.Random import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] - */ + * Usage: SimpleSkewedGroupByTest [numMappers] [numKVPairs] [valSize] [numReducers] [ratio] + */ object SimpleSkewedGroupByTest { def main(args: Array[String]) { diff --git a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala index d498af9c39..8e4c2b6229 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SkewedGroupByTest.scala @@ -23,8 +23,8 @@ import java.util.Random import org.apache.spark.{SparkConf, SparkContext} /** - * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] - */ + * Usage: GroupByTest [numMappers] [numKVPairs] [KeySize] [numReducers] + */ object SkewedGroupByTest { def main(args: Array[String]) { val sparkConf = new SparkConf().setAppName("GroupBy Test") diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 50216b9bd4..0ddd065f0d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -38,17 +38,18 @@ object PageView extends Serializable { } // scalastyle:off -/** Generates streaming events to simulate page views on a website. 
- * - * This should be used in tandem with PageViewStream.scala. Example: - * - * To run the generator - * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` - * To process the generated stream - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` - * - */ +/** + * Generates streaming events to simulate page views on a website. + * + * This should be used in tandem with PageViewStream.scala. Example: + * + * To run the generator + * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` + * To process the generated stream + * `$ bin/run-example \ + * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` + * + */ // scalastyle:on object PageViewGenerator { val pages = Map("http://foo.com/" -> .7, diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 773a2e5fc2..1ba093f57b 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -22,16 +22,17 @@ import org.apache.spark.examples.streaming.StreamingExamples import org.apache.spark.streaming.{Seconds, StreamingContext} // scalastyle:off -/** Analyses a streaming dataset of web page views. This class demonstrates several types of - * operators available in Spark streaming. - * - * This should be used in tandem with PageViewStream.scala. Example: - * To run the generator - * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` - * To process the generated stream - * `$ bin/run-example \ - * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` - */ +/** + * Analyses a streaming dataset of web page views. This class demonstrates several types of + * operators available in Spark streaming. + * + * This should be used in tandem with PageViewStream.scala. Example: + * To run the generator + * `$ bin/run-example org.apache.spark.examples.streaming.clickstream.PageViewGenerator 44444 10` + * To process the generated stream + * `$ bin/run-example \ + * org.apache.spark.examples.streaming.clickstream.PageViewStream errorRatePerZipCode localhost 44444` + */ // scalastyle:on object PageViewStream { def main(args: Array[String]) { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 7dc9606913..6e7c3f358e 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -185,13 +185,14 @@ class FlumeReceiver( override def preferredLocation: Option[String] = Option(host) - /** A Netty Pipeline factory that will decompress incoming data from - * and the Netty client and compress data going back to the client. 
- * - * The compression on the return is required because Flume requires - * a successful response to indicate it can remove the event/batch - * from the configured channel - */ + /** + * A Netty Pipeline factory that will decompress incoming data from + * and the Netty client and compress data going back to the client. + * + * The compression on the return is required because Flume requires + * a successful response to indicate it can remove the event/batch + * from the configured channel + */ private[streaming] class CompressionChannelPipelineFactory extends ChannelPipelineFactory { def getPipeline(): ChannelPipeline = { diff --git a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala index a660d2a00c..02917becf0 100644 --- a/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala +++ b/external/kafka/src/main/scala/org/apache/spark/streaming/kafka/KafkaRDDPartition.scala @@ -19,13 +19,14 @@ package org.apache.spark.streaming.kafka import org.apache.spark.Partition -/** @param topic kafka topic name - * @param partition kafka partition id - * @param fromOffset inclusive starting offset - * @param untilOffset exclusive ending offset - * @param host preferred kafka host, i.e. the leader at the time the rdd was created - * @param port preferred kafka host's port - */ +/** + * @param topic kafka topic name + * @param partition kafka partition id + * @param fromOffset inclusive starting offset + * @param untilOffset exclusive ending offset + * @param host preferred kafka host, i.e. the leader at the time the rdd was created + * @param port preferred kafka host's port + */ private[kafka] class KafkaRDDPartition( val index: Int, diff --git a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala index a783fe305f..868658dfe5 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala @@ -415,11 +415,11 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Seriali } /** - * Compute the connected component membership of each vertex and return a graph with the vertex - * value containing the lowest vertex id in the connected component containing that vertex. - * - * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]] - */ + * Compute the connected component membership of each vertex and return a graph with the vertex + * value containing the lowest vertex id in the connected component containing that vertex. + * + * @see [[org.apache.spark.graphx.lib.ConnectedComponents$#run]] + */ def connectedComponents(): Graph[VertexId, ED] = { ConnectedComponents.run(graph) } diff --git a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala index 137c512c99..4e9b13162e 100644 --- a/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala +++ b/graphx/src/main/scala/org/apache/spark/graphx/lib/ConnectedComponents.scala @@ -60,15 +60,15 @@ object ConnectedComponents { } // end of connectedComponents /** - * Compute the connected component membership of each vertex and return a graph with the vertex - * value containing the lowest vertex id in the connected component containing that vertex. 
- * - * @tparam VD the vertex attribute type (discarded in the computation) - * @tparam ED the edge attribute type (preserved in the computation) - * @param graph the graph for which to compute the connected components - * @return a graph with vertex attributes containing the smallest vertex in each - * connected component - */ + * Compute the connected component membership of each vertex and return a graph with the vertex + * value containing the lowest vertex id in the connected component containing that vertex. + * + * @tparam VD the vertex attribute type (discarded in the computation) + * @tparam ED the edge attribute type (preserved in the computation) + * @param graph the graph for which to compute the connected components + * @return a graph with vertex attributes containing the smallest vertex in each + * connected component + */ def run[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]): Graph[VertexId, ED] = { run(graph, Int.MaxValue) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala index 2c7ffdb7ba..1b0a9a12e8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala @@ -38,9 +38,9 @@ class ElementwiseProduct(override val uid: String) def this() = this(Identifiable.randomUID("elemProd")) /** - * the vector to multiply with input vectors - * @group param - */ + * the vector to multiply with input vectors + * @group param + */ val scalingVec: Param[Vector] = new Param(this, "scalingVec", "vector for hadamard product") /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala index a689b09341..364d5eea08 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/GaussianMixtureModelWrapper.scala @@ -24,15 +24,15 @@ import org.apache.spark.mllib.clustering.GaussianMixtureModel import org.apache.spark.mllib.linalg.{Vector, Vectors} /** - * Wrapper around GaussianMixtureModel to provide helper methods in Python - */ + * Wrapper around GaussianMixtureModel to provide helper methods in Python + */ private[python] class GaussianMixtureModelWrapper(model: GaussianMixtureModel) { val weights: Vector = Vectors.dense(model.weights) val k: Int = weights.size /** - * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian - */ + * Returns gaussians as a List of Vectors and Matrices corresponding each MultivariateGaussian + */ val gaussians: Array[Byte] = { val modelGaussians = model.gaussians.map { gaussian => Array[Any](gaussian.mu, gaussian.sigma) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala index 073f03e16f..05273c3434 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/Word2VecModelWrapper.scala @@ -27,8 +27,8 @@ import org.apache.spark.mllib.feature.Word2VecModel import org.apache.spark.mllib.linalg.{Vector, Vectors} /** - * Wrapper around Word2VecModel to provide helper methods in Python - */ + * Wrapper around Word2VecModel to provide helper 
methods in Python + */ private[python] class Word2VecModelWrapper(model: Word2VecModel) { def transform(word: String): Vector = { model.transform(word) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index 6e571fe35a..8c09b69b3c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -123,14 +123,18 @@ sealed trait Matrix extends Serializable { @Since("1.4.0") def toString(maxLines: Int, maxLineWidth: Int): String = toBreeze.toString(maxLines, maxLineWidth) - /** Map the values of this matrix using a function. Generates a new matrix. Performs the - * function on only the backing array. For example, an operation such as addition or - * subtraction will only be performed on the non-zero values in a `SparseMatrix`. */ + /** + * Map the values of this matrix using a function. Generates a new matrix. Performs the + * function on only the backing array. For example, an operation such as addition or + * subtraction will only be performed on the non-zero values in a `SparseMatrix`. + */ private[spark] def map(f: Double => Double): Matrix - /** Update all the values of this matrix using the function f. Performed in-place on the - * backing array. For example, an operation such as addition or subtraction will only be - * performed on the non-zero values in a `SparseMatrix`. */ + /** + * Update all the values of this matrix using the function f. Performed in-place on the + * backing array. For example, an operation such as addition or subtraction will only be + * performed on the non-zero values in a `SparseMatrix`. + */ private[mllib] def update(f: Double => Double): Matrix /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala index e8f4422fd4..84764963b5 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/StreamingLinearRegressionWithSGD.scala @@ -81,8 +81,8 @@ class StreamingLinearRegressionWithSGD private[mllib] ( } /** - * Set the number of iterations of gradient descent to run per update. Default: 50. - */ + * Set the number of iterations of gradient descent to run per update. Default: 50. + */ @Since("1.1.0") def setNumIterations(numIterations: Int): this.type = { this.algorithm.optimizer.setNumIterations(numIterations) diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala index 67a616dc15..c5dc6ba221 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -797,9 +797,11 @@ class SparkILoop( // echo("Switched " + (if (old) "off" else "on") + " result printing.") } - /** Run one command submitted by the user. Two values are returned: - * (1) whether to keep running, (2) the line to record for replay, - * if any. */ + /** + * Run one command submitted by the user. Two values are returned: + * (1) whether to keep running, (2) the line to record for replay, + * if any. 
+ */ private[repl] def command(line: String): Result = { if (line startsWith ":") { val cmd = line.tail takeWhile (x => !x.isWhitespace) @@ -841,12 +843,13 @@ class SparkILoop( } import paste.{ ContinueString, PromptString } - /** Interpret expressions starting with the first line. - * Read lines until a complete compilation unit is available - * or until a syntax error has been seen. If a full unit is - * read, go ahead and interpret it. Return the full string - * to be recorded for replay, if any. - */ + /** + * Interpret expressions starting with the first line. + * Read lines until a complete compilation unit is available + * or until a syntax error has been seen. If a full unit is + * read, go ahead and interpret it. Return the full string + * to be recorded for replay, if any. + */ private def interpretStartingWith(code: String): Option[String] = { // signal completion non-completion input has been received in.completion.resetVerbosity() diff --git a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala index 1d0fe10d3d..f22776592c 100644 --- a/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala +++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala @@ -118,8 +118,9 @@ private[repl] trait SparkImports { case class ReqAndHandler(req: Request, handler: MemberHandler) { } def reqsToUse: List[ReqAndHandler] = { - /** Loop through a list of MemberHandlers and select which ones to keep. - * 'wanted' is the set of names that need to be imported. + /** + * Loop through a list of MemberHandlers and select which ones to keep. + * 'wanted' is the set of names that need to be imported. */ def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = { // Single symbol imports might be implicits! See bug #1752. Rather than diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index 1f20e26354..e0bfe3c32f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -140,27 +140,27 @@ object Encoders { def STRING: Encoder[java.lang.String] = ExpressionEncoder() /** - * An encoder for nullable decimal type. - * @since 1.6.0 - */ + * An encoder for nullable decimal type. + * @since 1.6.0 + */ def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() /** - * An encoder for nullable date type. - * @since 1.6.0 - */ + * An encoder for nullable date type. + * @since 1.6.0 + */ def DATE: Encoder[java.sql.Date] = ExpressionEncoder() /** - * An encoder for nullable timestamp type. - * @since 1.6.0 - */ + * An encoder for nullable timestamp type. + * @since 1.6.0 + */ def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() /** - * An encoder for arrays of bytes. - * @since 1.6.1 - */ + * An encoder for arrays of bytes. 
+ * @since 1.6.1 + */ def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 05e2b9a447..a6e317ebf0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -733,9 +733,9 @@ class Analyzer( } /** - * Add the missing attributes into projectList of Project/Window or aggregateExpressions of - * Aggregate. - */ + * Add the missing attributes into projectList of Project/Window or aggregateExpressions of + * Aggregate. + */ private def addMissingAttr(plan: LogicalPlan, missingAttrs: AttributeSet): LogicalPlan = { if (missingAttrs.isEmpty) { return plan @@ -767,9 +767,9 @@ class Analyzer( } /** - * Resolve the expression on a specified logical plan and it's child (recursively), until - * the expression is resolved or meet a non-unary node or Subquery. - */ + * Resolve the expression on a specified logical plan and it's child (recursively), until + * the expression is resolved or meet a non-unary node or Subquery. + */ @tailrec private def resolveExpressionRecursively(expr: Expression, plan: LogicalPlan): Expression = { val resolved = resolveExpression(expr, plan) @@ -1398,8 +1398,8 @@ class Analyzer( } /** - * Check and add order to [[AggregateWindowFunction]]s. - */ + * Check and add order to [[AggregateWindowFunction]]s. + */ object ResolveWindowOrder extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case logical: LogicalPlan => logical transformExpressions { @@ -1489,8 +1489,8 @@ object EliminateSubqueryAliases extends Rule[LogicalPlan] { } /** - * Removes [[Union]] operators from the plan if it just has one child. - */ + * Removes [[Union]] operators from the plan if it just has one child. + */ object EliminateUnions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { case Union(children) if children.size == 1 => children.head diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index 053e612f3e..354311c5e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -136,9 +136,9 @@ object UnsafeProjection { } /** - * Same as other create()'s but allowing enabling/disabling subexpression elimination. - * TODO: refactor the plumbing and clean this up. - */ + * Same as other create()'s but allowing enabling/disabling subexpression elimination. + * TODO: refactor the plumbing and clean this up. 
+ */ def create( exprs: Seq[Expression], inputSchema: Seq[Attribute], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index cd490dd676..b64d3eea49 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -58,10 +58,10 @@ class CodegenContext { val references: mutable.ArrayBuffer[Any] = new mutable.ArrayBuffer[Any]() /** - * Add an object to `references`, create a class member to access it. - * - * Returns the name of class member. - */ + * Add an object to `references`, create a class member to access it. + * + * Returns the name of class member. + */ def addReferenceObj(name: String, obj: Any, className: String = null): String = { val term = freshName(name) val idx = references.length @@ -72,9 +72,9 @@ class CodegenContext { } /** - * Holding a list of generated columns as input of current operator, will be used by - * BoundReference to generate code. - */ + * Holding a list of generated columns as input of current operator, will be used by + * BoundReference to generate code. + */ var currentVars: Seq[ExprCode] = null /** @@ -169,14 +169,14 @@ class CodegenContext { final var INPUT_ROW = "i" /** - * The map from a variable name to it's next ID. - */ + * The map from a variable name to it's next ID. + */ private val freshNameIds = new mutable.HashMap[String, Int] freshNameIds += INPUT_ROW -> 1 /** - * A prefix used to generate fresh name. - */ + * A prefix used to generate fresh name. + */ var freshNamePrefix = "" /** @@ -234,8 +234,8 @@ class CodegenContext { } /** - * Update a column in MutableRow from ExprCode. - */ + * Update a column in MutableRow from ExprCode. + */ def updateColumn( row: String, dataType: DataType, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala index 437e417266..3be761c867 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala @@ -22,8 +22,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.types._ /** - * A placeholder expression for cube/rollup, which will be replaced by analyzer - */ + * A placeholder expression for cube/rollup, which will be replaced by analyzer + */ trait GroupingSet extends Expression with CodegenFallback { def groupByExprs: Seq[Expression] @@ -43,9 +43,9 @@ case class Cube(groupByExprs: Seq[Expression]) extends GroupingSet {} case class Rollup(groupByExprs: Seq[Expression]) extends GroupingSet {} /** - * Indicates whether a specified column expression in a GROUP BY list is aggregated or not. - * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set. - */ + * Indicates whether a specified column expression in a GROUP BY list is aggregated or not. + * GROUPING returns 1 for aggregated or 0 for not aggregated in the result set. 
+ */ case class Grouping(child: Expression) extends Expression with Unevaluable { override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) override def children: Seq[Expression] = child :: Nil @@ -54,10 +54,10 @@ case class Grouping(child: Expression) extends Expression with Unevaluable { } /** - * GroupingID is a function that computes the level of grouping. - * - * If groupByExprs is empty, it means all grouping expressions in GroupingSets. - */ + * GroupingID is a function that computes the level of grouping. + * + * If groupByExprs is empty, it means all grouping expressions in GroupingSets. + */ case class GroupingID(groupByExprs: Seq[Expression]) extends Expression with Unevaluable { override def references: AttributeSet = AttributeSet(VirtualColumn.groupingIdAttribute :: Nil) override def children: Seq[Expression] = groupByExprs diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index e8a3e129b4..eb8dc1423a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -467,8 +467,8 @@ object Murmur3HashFunction extends InterpretedHashFunction { } /** - * Print the result of an expression to stderr (used for debugging codegen). - */ + * Print the result of an expression to stderr (used for debugging codegen). + */ case class PrintToStderr(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index a5ab390c76..69b09bcb35 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -31,9 +31,9 @@ import org.apache.spark.sql.catalyst.rules._ import org.apache.spark.sql.types._ /** - * Abstract class all optimizers should inherit of, contains the standard batches (extending - * Optimizers can override this. - */ + * Abstract class all optimizers should inherit of, contains the standard batches (extending + * Optimizers can override this. + */ abstract class Optimizer extends RuleExecutor[LogicalPlan] { def batches: Seq[Batch] = { // Technically some of the rules in Finish Analysis are not optimizer rules and belong more @@ -111,11 +111,11 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { } /** - * Non-abstract representation of the standard Spark optimizing strategies - * - * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while - * specific rules go to the subclasses - */ + * Non-abstract representation of the standard Spark optimizing strategies + * + * To ensure extendability, we leave the standard rules in the abstract optimizer rules, while + * specific rules go to the subclasses + */ object DefaultOptimizer extends Optimizer /** @@ -962,21 +962,21 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel } /** - * Reorder the joins and push all the conditions into join, so that the bottom ones have at least - * one condition. - * - * The order of joins will not be changed if all of them already have at least one condition. 
- */ + * Reorder the joins and push all the conditions into join, so that the bottom ones have at least + * one condition. + * + * The order of joins will not be changed if all of them already have at least one condition. + */ object ReorderJoin extends Rule[LogicalPlan] with PredicateHelper { /** - * Join a list of plans together and push down the conditions into them. - * - * The joined plan are picked from left to right, prefer those has at least one join condition. - * - * @param input a list of LogicalPlans to join. - * @param conditions a list of condition for join. - */ + * Join a list of plans together and push down the conditions into them. + * + * The joined plan are picked from left to right, prefer those has at least one join condition. + * + * @param input a list of LogicalPlans to join. + * @param conditions a list of condition for join. + */ @tailrec def createOrderedJoin(input: Seq[LogicalPlan], conditions: Seq[Expression]): LogicalPlan = { assert(input.size >= 2) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index c350f3049f..8541b1f7c6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1430,8 +1430,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a [[StructType]] from a sequence of [[StructField]]s. - */ + * Create a [[StructType]] from a sequence of [[StructField]]s. + */ protected def createStructType(ctx: ColTypeListContext): StructType = { StructType(Option(ctx).toSeq.flatMap(visitColTypeList)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 28d2c445b1..6f35d87ebb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -140,20 +140,20 @@ object ExtractEquiJoinKeys extends Logging with PredicateHelper { } /** - * A pattern that collects the filter and inner joins. - * - * Filter - * | - * inner Join - * / \ ----> (Seq(plan0, plan1, plan2), conditions) - * Filter plan2 - * | - * inner join - * / \ - * plan0 plan1 - * - * Note: This pattern currently only works for left-deep trees. - */ + * A pattern that collects the filter and inner joins. + * + * Filter + * | + * inner Join + * / \ ----> (Seq(plan0, plan1, plan2), conditions) + * Filter plan2 + * | + * inner join + * / \ + * plan0 plan1 + * + * Note: This pattern currently only works for left-deep trees. + */ object ExtractFiltersAndInnerJoins extends PredicateHelper { // flatten all inner joins, which are next to each other diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 22a4461e66..609a33e2f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -122,8 +122,8 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output)) /** - * The set of all attributes that are produced by this node. 
- */ + * The set of all attributes that are produced by this node. + */ def producedAttributes: AttributeSet = AttributeSet.empty /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala index be9f1ffa22..d449088498 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala @@ -76,9 +76,9 @@ case class OrderedDistribution(ordering: Seq[SortOrder]) extends Distribution { } /** - * Represents data where tuples are broadcasted to every node. It is quite common that the - * entire set of tuples is transformed into different data structure. - */ + * Represents data where tuples are broadcasted to every node. It is quite common that the + * entire set of tuples is transformed into different data structure. + */ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala index 7e3da6bea7..6e5672ddc3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/OptimizerExtendableSuite.scala @@ -23,21 +23,21 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule /** - * This is a test for SPARK-7727 if the Optimizer is kept being extendable - */ + * This is a test for SPARK-7727 if the Optimizer is kept being extendable + */ class OptimizerExtendableSuite extends SparkFunSuite { /** - * Dummy rule for test batches - */ + * Dummy rule for test batches + */ object DummyRule extends Rule[LogicalPlan] { def apply(p: LogicalPlan): LogicalPlan = p } /** - * This class represents a dummy extended optimizer that takes the batches of the - * Optimizer and adds custom ones. - */ + * This class represents a dummy extended optimizer that takes the batches of the + * Optimizer and adds custom ones. + */ class ExtendedOptimizer extends Optimizer { // rules set to DummyRule, would not be executed anyways diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 221782ee8f..d4290fee0a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -712,13 +712,13 @@ class SQLContext private[sql]( } /** - * :: Experimental :: - * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements - * in an range from `start` to `end` (exclusive) with an step value. - * - * @since 2.0.0 - * @group dataset - */ + * :: Experimental :: + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with an step value. 
+ * + * @since 2.0.0 + * @group dataset + */ @Experimental def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { range(start, end, step, numPartitions = sparkContext.defaultParallelism) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index f3478a873a..124ec09efd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -109,9 +109,10 @@ private[sql] class CacheManager extends Logging { cachedData.remove(dataIndex) } - /** Tries to remove the data for the given [[Dataset]] from the cache - * if it's cached - */ + /** + * Tries to remove the data for the given [[Dataset]] from the cache + * if it's cached + */ private[sql] def tryUncacheQuery( query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index b1b3d4ac81..ff19d1be1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -84,8 +84,8 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ private[sql] def metrics: Map[String, SQLMetric[_, _]] = Map.empty /** - * Reset all the metrics. - */ + * Reset all the metrics. + */ private[sql] def resetMetrics(): Unit = { metrics.valuesIterator.foreach(_.reset()) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 9bdf611f6e..9f539c4929 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -31,8 +31,8 @@ import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.internal.SQLConf /** - * An interface for those physical operators that support codegen. - */ + * An interface for those physical operators that support codegen. + */ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ @@ -46,10 +46,10 @@ trait CodegenSupport extends SparkPlan { } /** - * Creates a metric using the specified name. - * - * @return name of the variable representing the metric - */ + * Creates a metric using the specified name. + * + * @return name of the variable representing the metric + */ def metricTerm(ctx: CodegenContext, name: String): String = { val metric = ctx.addReferenceObj(name, longMetric(name)) val value = ctx.freshName("metricValue") @@ -59,25 +59,25 @@ trait CodegenSupport extends SparkPlan { } /** - * Whether this SparkPlan support whole stage codegen or not. - */ + * Whether this SparkPlan support whole stage codegen or not. + */ def supportCodegen: Boolean = true /** - * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan. - */ + * Which SparkPlan is calling produce() of this one. It's itself for the first SparkPlan. + */ protected var parent: CodegenSupport = null /** - * Returns all the RDDs of InternalRow which generates the input rows. - * - * Note: right now we support up to two RDDs. - */ + * Returns all the RDDs of InternalRow which generates the input rows. 
+ * + * Note: right now we support up to two RDDs. + */ def upstreams(): Seq[RDD[InternalRow]] /** - * Returns Java source code to process the rows from upstream. - */ + * Returns Java source code to process the rows from upstream. + */ final def produce(ctx: CodegenContext, parent: CodegenSupport): String = { this.parent = parent ctx.freshNamePrefix = variablePrefix @@ -89,28 +89,28 @@ trait CodegenSupport extends SparkPlan { } /** - * Generate the Java source code to process, should be overridden by subclass to support codegen. - * - * doProduce() usually generate the framework, for example, aggregation could generate this: - * - * if (!initialized) { - * # create a hash map, then build the aggregation hash map - * # call child.produce() - * initialized = true; - * } - * while (hashmap.hasNext()) { - * row = hashmap.next(); - * # build the aggregation results - * # create variables for results - * # call consume(), which will call parent.doConsume() + * Generate the Java source code to process, should be overridden by subclass to support codegen. + * + * doProduce() usually generate the framework, for example, aggregation could generate this: + * + * if (!initialized) { + * # create a hash map, then build the aggregation hash map + * # call child.produce() + * initialized = true; + * } + * while (hashmap.hasNext()) { + * row = hashmap.next(); + * # build the aggregation results + * # create variables for results + * # call consume(), which will call parent.doConsume() * if (shouldStop()) return; - * } - */ + * } + */ protected def doProduce(ctx: CodegenContext): String /** - * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume(). - */ + * Consume the generated columns or row from current SparkPlan, call it's parent's doConsume(). + */ final def consume(ctx: CodegenContext, outputVars: Seq[ExprCode], row: String = null): String = { val inputVars = if (row != null) { @@ -158,9 +158,9 @@ trait CodegenSupport extends SparkPlan { } /** - * Returns source code to evaluate all the variables, and clear the code of them, to prevent - * them to be evaluated twice. - */ + * Returns source code to evaluate all the variables, and clear the code of them, to prevent + * them to be evaluated twice. + */ protected def evaluateVariables(variables: Seq[ExprCode]): String = { val evaluate = variables.filter(_.code != "").map(_.code.trim).mkString("\n") variables.foreach(_.code = "") @@ -168,9 +168,9 @@ trait CodegenSupport extends SparkPlan { } /** - * Returns source code to evaluate the variables for required attributes, and clear the code - * of evaluated variables, to prevent them to be evaluated twice.. - */ + * Returns source code to evaluate the variables for required attributes, and clear the code + * of evaluated variables, to prevent them to be evaluated twice.. + */ protected def evaluateRequiredVariables( attributes: Seq[Attribute], variables: Seq[ExprCode], @@ -194,18 +194,18 @@ trait CodegenSupport extends SparkPlan { def usedInputs: AttributeSet = references /** - * Generate the Java source code to process the rows from child SparkPlan. - * - * This should be override by subclass to support codegen. - * - * For example, Filter will generate the code like this: - * - * # code to evaluate the predicate expression, result is isNull1 and value2 - * if (isNull1 || !value2) continue; - * # call consume(), which will call parent.doConsume() - * - * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input). 
- */ + * Generate the Java source code to process the rows from child SparkPlan. + * + * This should be override by subclass to support codegen. + * + * For example, Filter will generate the code like this: + * + * # code to evaluate the predicate expression, result is isNull1 and value2 + * if (isNull1 || !value2) continue; + * # call consume(), which will call parent.doConsume() + * + * Note: A plan can either consume the rows as UnsafeRow (row), or a list of variables (input). + */ def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { throw new UnsupportedOperationException } @@ -213,11 +213,11 @@ trait CodegenSupport extends SparkPlan { /** - * InputAdapter is used to hide a SparkPlan from a subtree that support codegen. - * - * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes - * an RDD iterator of InternalRow. - */ + * InputAdapter is used to hide a SparkPlan from a subtree that support codegen. + * + * This is the leaf node of a tree with WholeStageCodegen, is used to generate code that consumes + * an RDD iterator of InternalRow. + */ case class InputAdapter(child: SparkPlan) extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -260,33 +260,33 @@ object WholeStageCodegen { } /** - * WholeStageCodegen compile a subtree of plans that support codegen together into single Java - * function. - * - * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not): - * - * WholeStageCodegen Plan A FakeInput Plan B - * ========================================================================= - * - * -> execute() - * | - * doExecute() ---------> upstreams() -------> upstreams() ------> execute() - * | - * +-----------------> produce() - * | - * doProduce() -------> produce() - * | - * doProduce() - * | - * doConsume() <--------- consume() - * | - * doConsume() <-------- consume() - * - * SparkPlan A should override doProduce() and doConsume(). - * - * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, - * used to generated code for BoundReference. - */ + * WholeStageCodegen compile a subtree of plans that support codegen together into single Java + * function. + * + * Here is the call graph of to generate Java source (plan A support codegen, but plan B does not): + * + * WholeStageCodegen Plan A FakeInput Plan B + * ========================================================================= + * + * -> execute() + * | + * doExecute() ---------> upstreams() -------> upstreams() ------> execute() + * | + * +-----------------> produce() + * | + * doProduce() -------> produce() + * | + * doProduce() + * | + * doConsume() <--------- consume() + * | + * doConsume() <-------- consume() + * + * SparkPlan A should override doProduce() and doConsume(). + * + * doCodeGen() will create a CodeGenContext, which will hold a list of variables for input, + * used to generated code for BoundReference. + */ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSupport { override def output: Seq[Attribute] = child.output @@ -422,8 +422,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup /** - * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. - */ + * Find the chained plans that support codegen, collapse them together as WholeStageCodegen. 
+ */ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { private def supportCodegen(e: Expression): Boolean = e match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 7d0567842c..806089196c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -444,8 +444,8 @@ private[execution] final case class RangeBoundOrdering( } /** - * The interface of row buffer for a partition - */ + * The interface of row buffer for a partition + */ private[execution] abstract class RowBuffer { /** Number of rows. */ @@ -462,8 +462,8 @@ private[execution] abstract class RowBuffer { } /** - * A row buffer based on ArrayBuffer (the number of rows is limited) - */ + * A row buffer based on ArrayBuffer (the number of rows is limited) + */ private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends RowBuffer { private[this] var cursor: Int = -1 @@ -493,8 +493,8 @@ private[execution] class ArrayRowBuffer(buffer: ArrayBuffer[UnsafeRow]) extends } /** - * An external buffer of rows based on UnsafeExternalSorter - */ + * An external buffer of rows based on UnsafeExternalSorter + */ private[execution] class ExternalRowBuffer(sorter: UnsafeExternalSorter, numFields: Int) extends RowBuffer { @@ -654,12 +654,16 @@ private[execution] final class SlidingWindowFunctionFrame( /** The rows within current sliding window. */ private[this] val buffer = new util.ArrayDeque[InternalRow]() - /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ private[this] var inputHighIndex = 0 - /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ + /** + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. + */ private[this] var inputLowIndex = 0 /** Prepare the frame for calculating a new partition. Reset all variables. */ @@ -763,8 +767,10 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( /** The next row from `input`. */ private[this] var nextRow: InternalRow = null - /** Index of the first input row with a value greater than the upper bound of the current - * output row. */ + /** + * Index of the first input row with a value greater than the upper bound of the current + * output row. + */ private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. */ @@ -819,8 +825,10 @@ private[execution] final class UnboundedFollowingWindowFunctionFrame( /** Rows of the partition currently being processed. */ private[this] var input: RowBuffer = null - /** Index of the first input row with a value equal to or greater than the lower bound of the - * current output row. */ + /** + * Index of the first input row with a value equal to or greater than the lower bound of the + * current output row. + */ private[this] var inputIndex = 0 /** Prepare the frame for calculating a new partition. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 15627a7004..042c731901 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -47,17 +47,17 @@ abstract class AggregationIterator( /////////////////////////////////////////////////////////////////////////// /** - * The following combinations of AggregationMode are supported: - * - Partial - * - PartialMerge (for single distinct) - * - Partial and PartialMerge (for single distinct) - * - Final - * - Complete (for SortBasedAggregate with functions that does not support Partial) - * - Final and Complete (currently not used) - * - * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression - * could have a flag to tell it's final or not. - */ + * The following combinations of AggregationMode are supported: + * - Partial + * - PartialMerge (for single distinct) + * - Partial and PartialMerge (for single distinct) + * - Final + * - Complete (for SortBasedAggregate with functions that does not support Partial) + * - Final and Complete (currently not used) + * + * TODO: AggregateMode should have only two modes: Update and Merge, AggregateExpression + * could have a flag to tell it's final or not. + */ { val modes = aggregateExpressions.map(_.mode).distinct.toSet require(modes.size <= 2, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index 8f974980bb..de1491d357 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -46,9 +46,9 @@ class SortBasedAggregationIterator( newMutableProjection) { /** - * Creates a new aggregation buffer and initializes buffer values - * for all aggregate functions. - */ + * Creates a new aggregation buffer and initializes buffer values + * for all aggregate functions. + */ private def newBuffer: MutableRow = { val bufferSchema = aggregateFunctions.flatMap(_.aggBufferAttributes) val bufferRowSize: Int = bufferSchema.length diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 7c215d1b96..60027edc7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -266,8 +266,8 @@ case class TungstenAggregate( private var sorterTerm: String = _ /** - * This is called by generated Java class, should be public. - */ + * This is called by generated Java class, should be public. + */ def createHashMap(): UnsafeFixedWidthAggregationMap = { // create initialized aggregate buffer val initExpr = declFunctions.flatMap(f => f.initialValues) @@ -286,15 +286,15 @@ case class TungstenAggregate( } /** - * This is called by generated Java class, should be public. - */ + * This is called by generated Java class, should be public. 
+ */ def createUnsafeJoiner(): UnsafeRowJoiner = { GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) } /** - * Called by generated Java class to finish the aggregate and return a KVIterator. - */ + * Called by generated Java class to finish the aggregate and return a KVIterator. + */ def finishAggregate( hashMap: UnsafeFixedWidthAggregationMap, sorter: UnsafeKVExternalSorter): KVIterator[UnsafeRow, UnsafeRow] = { @@ -372,8 +372,8 @@ case class TungstenAggregate( } /** - * Generate the code for output. - */ + * Generate the code for output. + */ private def generateResultCode( ctx: CodegenContext, keyTerm: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index f3514cd14c..159fdc99dd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -168,10 +168,10 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( private[this] var reader: RecordReader[Void, V] = null /** - * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this - * fails (for example, unsupported schema), try with the normal reader. - * TODO: plumb this through a different way? - */ + * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this + * fails (for example, unsupported schema), try with the normal reader. + * TODO: plumb this through a different way? + */ if (enableVectorizedParquetReader && format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { val parquetReader: VectorizedParquetRecordReader = new VectorizedParquetRecordReader() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala index 797f740dc5..ea843a1013 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchema.scala @@ -33,11 +33,11 @@ import org.apache.spark.unsafe.types.UTF8String private[csv] object CSVInferSchema { /** - * Similar to the JSON schema inference - * 1. Infer type of each row - * 2. Merge row types to find common type - * 3. Replace any null types with string type - */ + * Similar to the JSON schema inference + * 1. Infer type of each row + * 2. Merge row types to find common type + * 3. Replace any null types with string type + */ def infer( tokenRdd: RDD[Array[String]], header: Array[String], @@ -75,9 +75,9 @@ private[csv] object CSVInferSchema { } /** - * Infer type of string field. Given known type Double, and a string "1", there is no - * point checking if it is an Int, as the final type must be Double or higher. - */ + * Infer type of string field. Given known type Double, and a string "1", there is no + * point checking if it is an Int, as the final type must be Double or higher. 
+ */ def inferField(typeSoFar: DataType, field: String, nullValue: String = ""): DataType = { if (field == null || field.isEmpty || field == nullValue) { typeSoFar @@ -142,9 +142,9 @@ private[csv] object CSVInferSchema { private val numericPrecedence: IndexedSeq[DataType] = HiveTypeCoercion.numericPrecedence /** - * Copied from internal Spark api - * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] - */ + * Copied from internal Spark api + * [[org.apache.spark.sql.catalyst.analysis.HiveTypeCoercion]] + */ val findTightestCommonType: (DataType, DataType) => Option[DataType] = { case (t1, t2) if t1 == t2 => Some(t1) case (NullType, t1) => Some(t1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala index c0d6f6fbf7..34fcbdf871 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/DefaultSource.scala @@ -38,8 +38,8 @@ import org.apache.spark.util.SerializableConfiguration import org.apache.spark.util.collection.BitSet /** - * Provides access to CSV data from pure SQL statements. - */ + * Provides access to CSV data from pure SQL statements. + */ class DefaultSource extends FileFormat with DataSourceRegister { override def shortName(): String = "csv" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala index 877e159fbd..2e88d588be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ddl.scala @@ -51,11 +51,11 @@ case class DescribeCommand( } /** - * Used to represent the operation of create table using a data source. + * Used to represent the operation of create table using a data source. * - * @param allowExisting If it is true, we will do nothing when the table already exists. - * If it is false, an exception will be thrown - */ + * @param allowExisting If it is true, we will do nothing when the table already exists. + * If it is false, an exception will be thrown + */ case class CreateTableUsing( tableIdent: TableIdentifier, userSpecifiedSchema: Option[StructType], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 0ed1ed41b0..41e566c27b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -122,8 +122,8 @@ case class BroadcastHashJoin( } /** - * Returns a tuple of Broadcast of HashedRelation and the variable name for it. - */ + * Returns a tuple of Broadcast of HashedRelation and the variable name for it. + */ private def prepareBroadcast(ctx: CodegenContext): (Broadcast[HashedRelation], String) = { // create a name for HashedRelation val broadcastRelation = buildPlan.executeBroadcast[HashedRelation]() @@ -139,9 +139,9 @@ case class BroadcastHashJoin( } /** - * Returns the code for generating join key for stream side, and expression of whether the key - * has any null in it or not. - */ + * Returns the code for generating join key for stream side, and expression of whether the key + * has any null in it or not. 
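As a gloss on the comment above about returning both the join key and a null check: in an equi-join a key containing null can never match, so the generated code can skip the hash lookup entirely whenever that check is true. A plain-Scala sketch of the idea, not the generated Java:
```scala
// `None` plays the role of "anyNull evaluated to true" in the generated code.
def probe[K, V](streamKeyOrNull: Option[K], hashed: Map[K, Seq[V]]): Seq[V] =
  streamKeyOrNull match {
    case None      => Seq.empty                      // null key: never matches
    case Some(key) => hashed.getOrElse(key, Seq.empty)
  }
```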
+ */ private def genStreamSideJoinKey( ctx: CodegenContext, input: Seq[ExprCode]): (ExprCode, String) = { @@ -160,8 +160,8 @@ case class BroadcastHashJoin( } /** - * Generates the code for variable of build side. - */ + * Generates the code for variable of build side. + */ private def genBuildSideVars(ctx: CodegenContext, matched: String): Seq[ExprCode] = { ctx.currentVars = null ctx.INPUT_ROW = matched @@ -188,8 +188,8 @@ case class BroadcastHashJoin( } /** - * Generates the code for Inner join. - */ + * Generates the code for Inner join. + */ private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) @@ -254,8 +254,8 @@ case class BroadcastHashJoin( /** - * Generates the code for left or right outer join. - */ + * Generates the code for left or right outer join. + */ private def codegenOuter(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala index fb65b50da8..edb4c5a16f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/CartesianProduct.scala @@ -28,10 +28,10 @@ import org.apache.spark.util.CompletionIterator import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter /** - * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, - * will be much faster than building the right partition for every row in left RDD, it also - * materialize the right RDD (in case of the right RDD is nondeterministic). - */ + * An optimized CartesianRDD for UnsafeRow, which will cache the rows from second child RDD, + * will be much faster than building the right partition for every row in left RDD, it also + * materialize the right RDD (in case of the right RDD is nondeterministic). + */ private[spark] class UnsafeCartesianRDD(left : RDD[UnsafeRow], right : RDD[UnsafeRow], numFieldsOfRight: Int) extends CartesianRDD[UnsafeRow, UnsafeRow](left.sparkContext, left, right) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 5f42d07273..c298b7dee0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -64,10 +64,10 @@ trait HashJoin { } /** - * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. - * - * If not, returns the original expressions. - */ + * Try to rewrite the key as LongType so we can use getLong(), if they key can fit with a long. + * + * If not, returns the original expressions. 
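To illustrate the key rewrite described above, the arithmetic behind packing two int keys into one long looks roughly like this; the actual change rewrites Catalyst expressions rather than plain Scala values.
```scala
// Two 32-bit keys packed into one Long so a LongType-keyed relation and
// getLong() can be used; illustrative arithmetic only.
def packKeys(a: Int, b: Int): Long = (a.toLong << 32) | (b.toLong & 0xFFFFFFFFL)
def unpackKeys(packed: Long): (Int, Int) = ((packed >>> 32).toInt, packed.toInt)

assert(unpackKeys(packKeys(7, -3)) == (7, -3))  // round-trips, negatives included
```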
+ */ def rewriteKeyExpr(keys: Seq[Expression]): Seq[Expression] = { var keyExpr: Expression = null var width = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala index dc4793e85a..91c470d187 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -38,20 +38,20 @@ import org.apache.spark.util.collection.CompactBuffer */ private[execution] sealed trait HashedRelation { /** - * Returns matched rows. - */ + * Returns matched rows. + */ def get(key: InternalRow): Seq[InternalRow] /** - * Returns matched rows for a key that has only one column with LongType. - */ + * Returns matched rows for a key that has only one column with LongType. + */ def get(key: Long): Seq[InternalRow] = { throw new UnsupportedOperationException } /** - * Returns the size of used memory. - */ + * Returns the size of used memory. + */ def getMemorySize: Long = 1L // to make the test happy /** @@ -77,20 +77,20 @@ private[execution] sealed trait HashedRelation { } /** - * Interface for a hashed relation that have only one row per key. - * - * We should call getValue() for better performance. - */ + * Interface for a hashed relation that have only one row per key. + * + * We should call getValue() for better performance. + */ private[execution] trait UniqueHashedRelation extends HashedRelation { /** - * Returns the matched single row. - */ + * Returns the matched single row. + */ def getValue(key: InternalRow): InternalRow /** - * Returns the matched single row with key that have only one column of LongType. - */ + * Returns the matched single row with key that have only one column of LongType. + */ def getValue(key: Long): InternalRow = { throw new UnsupportedOperationException } @@ -345,8 +345,8 @@ private[joins] object UnsafeHashedRelation { } /** - * An interface for a hashed relation that the key is a Long. - */ + * An interface for a hashed relation that the key is a Long. + */ private[joins] trait LongHashedRelation extends HashedRelation { override def get(key: InternalRow): Seq[InternalRow] = { get(key.getLong(0)) @@ -396,26 +396,26 @@ private[joins] final class UniqueLongHashedRelation( } /** - * A relation that pack all the rows into a byte array, together with offsets and sizes. - * - * All the bytes of UnsafeRow are packed together as `bytes`: - * - * [ Row0 ][ Row1 ][] ... [ RowN ] - * - * With keys: - * - * start start+1 ... start+N - * - * `offsets` are offsets of UnsafeRows in the `bytes` - * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. - * - * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: - * - * start = 3 - * offsets = [0, 0, 24] - * sizes = [24, 0, 32] - * bytes = [0 - 24][][24 - 56] - */ + * A relation that pack all the rows into a byte array, together with offsets and sizes. + * + * All the bytes of UnsafeRow are packed together as `bytes`: + * + * [ Row0 ][ Row1 ][] ... [ RowN ] + * + * With keys: + * + * start start+1 ... start+N + * + * `offsets` are offsets of UnsafeRows in the `bytes` + * `sizes` are the numbers of bytes of UnsafeRows, 0 means no row for this key. 
+ * + * For example, two UnsafeRows (24 bytes and 32 bytes), with keys as 3 and 5 will stored as: + * + * start = 3 + * offsets = [0, 0, 24] + * sizes = [24, 0, 32] + * bytes = [0 - 24][][24 - 56] + */ private[joins] final class LongArrayRelation( private var numFields: Int, private var start: Long, @@ -483,8 +483,8 @@ private[joins] final class LongArrayRelation( } /** - * Create hashed relation with key that is long. - */ + * Create hashed relation with key that is long. + */ private[joins] object LongHashedRelation { val DENSE_FACTOR = 0.2 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala index 60bd8ea39a..0e7b2f2f31 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/SortMergeJoin.scala @@ -256,9 +256,9 @@ case class SortMergeJoin( } /** - * Generate a function to scan both left and right to find a match, returns the term for - * matched one row from left side and buffered rows from right side. - */ + * Generate a function to scan both left and right to find a match, returns the term for + * matched one row from left side and buffered rows from right side. + */ private def genScanner(ctx: CodegenContext): (String, String) = { // Create class member for next row from both sides. val leftRow = ctx.freshName("leftRow") @@ -341,12 +341,12 @@ case class SortMergeJoin( } /** - * Creates variables for left part of result row. - * - * In order to defer the access after condition and also only access once in the loop, - * the variables should be declared separately from accessing the columns, we can't use the - * codegen of BoundReference here. - */ + * Creates variables for left part of result row. + * + * In order to defer the access after condition and also only access once in the loop, + * the variables should be declared separately from accessing the columns, we can't use the + * codegen of BoundReference here. + */ private def createLeftVars(ctx: CodegenContext, leftRow: String): Seq[ExprCode] = { ctx.INPUT_ROW = leftRow left.output.zipWithIndex.map { case (a, i) => @@ -370,9 +370,9 @@ case class SortMergeJoin( } /** - * Creates the variables for right part of result row, using BoundReference, since the right - * part are accessed inside the loop. - */ + * Creates the variables for right part of result row, using BoundReference, since the right + * part are accessed inside the loop. + */ private def createRightVar(ctx: CodegenContext, rightRow: String): Seq[ExprCode] = { ctx.INPUT_ROW = rightRow right.output.zipWithIndex.map { case (a, i) => @@ -381,12 +381,12 @@ case class SortMergeJoin( } /** - * Splits variables based on whether it's used by condition or not, returns the code to create - * these variables before the condition and after the condition. - * - * Only a few columns are used by condition, then we can skip the accessing of those columns - * that are not used by condition also filtered out by condition. - */ + * Splits variables based on whether it's used by condition or not, returns the code to create + * these variables before the condition and after the condition. + * + * Only a few columns are used by condition, then we can skip the accessing of those columns + * that are not used by condition also filtered out by condition. 
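A small sketch of the split described in the SortMergeJoin comment above: variables the join condition needs are materialized before it is evaluated, everything else only after a row survives, so columns filtered out by the condition are never read. The names here are illustrative placeholders, not Spark's ExprCode.
```scala
final case class VarCode(attrName: String, code: String)

// Returns (code to emit before the condition, code to emit after it).
def splitVarsByCondition(vars: Seq[VarCode], usedByCondition: Set[String]): (String, String) = {
  val (before, after) = vars.partition(v => usedByCondition.contains(v.attrName))
  (before.map(_.code).mkString("\n"), after.map(_.code).mkString("\n"))
}
```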
+ */ private def splitVarsByCondition( attributes: Seq[Attribute], variables: Seq[ExprCode]): (String, String) = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 998eb82de1..8ece3c971a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -468,10 +468,10 @@ private[state] class HDFSBackedStateStoreProvider( } /** - * Clean up old snapshots and delta files that are not needed any more. It ensures that last - * few versions of the store can be recovered from the files, so re-executed RDD operations - * can re-apply updates on the past versions of the store. - */ + * Clean up old snapshots and delta files that are not needed any more. It ensures that last + * few versions of the store can be recovered from the files, so re-executed RDD operations + * can re-apply updates on the past versions of the store. + */ private[state] def cleanup(): Unit = { try { val files = fetchFiles() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 24a01f5be1..012b125d6b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -45,8 +45,8 @@ private[ui] case class SparkPlanGraph( } /** - * All the SparkPlanGraphNodes, including those inside of WholeStageCodegen. - */ + * All the SparkPlanGraphNodes, including those inside of WholeStageCodegen. + */ val allNodes: Seq[SparkPlanGraphNode] = { nodes.flatMap { case cluster: SparkPlanGraphCluster => cluster.nodes :+ cluster diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index baf947d037..da58ba2add 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -332,95 +332,94 @@ object functions { } /** - * Aggregate function: returns the first value in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ def first(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { new First(e.expr, Literal(ignoreNulls)) } /** - * Aggregate function: returns the first value of a column in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. 
It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ def first(columnName: String, ignoreNulls: Boolean): Column = { first(Column(columnName), ignoreNulls) } /** - * Aggregate function: returns the first value in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the first value in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ def first(e: Column): Column = first(e, ignoreNulls = false) /** - * Aggregate function: returns the first value of a column in a group. - * - * The function by default returns the first values it sees. It will return the first non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the first value of a column in a group. + * + * The function by default returns the first values it sees. It will return the first non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ def first(columnName: String): Column = first(Column(columnName)) - /** - * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated - * or not, returns 1 for aggregated or 0 for not aggregated in the result set. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * + * @group agg_funcs + * @since 2.0.0 + */ def grouping(e: Column): Column = Column(Grouping(e.expr)) /** - * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated - * or not, returns 1 for aggregated or 0 for not aggregated in the result set. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: indicates whether a specified column in a GROUP BY list is aggregated + * or not, returns 1 for aggregated or 0 for not aggregated in the result set. + * + * @group agg_funcs + * @since 2.0.0 + */ def grouping(columnName: String): Column = grouping(Column(columnName)) /** - * Aggregate function: returns the level of grouping, equals to - * - * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) - * - * Note: the list of columns should match with grouping columns exactly, or empty (means all the - * grouping columns). - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the level of grouping, equals to + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns should match with grouping columns exactly, or empty (means all the + * grouping columns). + * + * @group agg_funcs + * @since 2.0.0 + */ def grouping_id(cols: Column*): Column = Column(GroupingID(cols.map(_.expr))) /** - * Aggregate function: returns the level of grouping, equals to - * - * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... 
+ grouping(cn) - * - * Note: the list of columns should match with grouping columns exactly. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the level of grouping, equals to + * + * (grouping(c1) << (n-1)) + (grouping(c2) << (n-2)) + ... + grouping(cn) + * + * Note: the list of columns should match with grouping columns exactly. + * + * @group agg_funcs + * @since 2.0.0 + */ def grouping_id(colName: String, colNames: String*): Column = { grouping_id((Seq(colName) ++ colNames).map(n => Column(n)) : _*) } @@ -442,51 +441,51 @@ object functions { def kurtosis(columnName: String): Column = kurtosis(Column(columnName)) /** - * Aggregate function: returns the last value in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ def last(e: Column, ignoreNulls: Boolean): Column = withAggregateFunction { new Last(e.expr, Literal(ignoreNulls)) } /** - * Aggregate function: returns the last value of the column in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 2.0.0 - */ + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 2.0.0 + */ def last(columnName: String, ignoreNulls: Boolean): Column = { last(Column(columnName), ignoreNulls) } /** - * Aggregate function: returns the last value in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the last value in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. + * + * @group agg_funcs + * @since 1.3.0 + */ def last(e: Column): Column = last(e, ignoreNulls = false) /** - * Aggregate function: returns the last value of the column in a group. - * - * The function by default returns the last values it sees. It will return the last non-null - * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. - * - * @group agg_funcs - * @since 1.3.0 - */ + * Aggregate function: returns the last value of the column in a group. + * + * The function by default returns the last values it sees. It will return the last non-null + * value it sees when ignoreNulls is set to true. If all values are null, then null is returned. 
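A hypothetical usage of the first/last variants documented in this hunk, assuming a DataFrame with columns "k" and "v" where "v" may contain nulls (the DataFrame itself is a stand-in here):
```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, first, last}

val df: DataFrame = ???  // assumed to exist with columns "k" and "v"

// First and last non-null value of "v" per group, per the ignoreNulls
// semantics described above.
val summary = df.groupBy(col("k")).agg(
  first(col("v"), ignoreNulls = true).as("first_non_null"),
  last(col("v"), ignoreNulls = true).as("last_non_null"))
```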
+ * + * @group agg_funcs + * @since 1.3.0 + */ def last(columnName: String): Column = last(Column(columnName), ignoreNulls = false) /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index e8834d052c..14e14710f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -152,19 +152,19 @@ trait StreamSinkProvider { @DeveloperApi trait CreatableRelationProvider { /** - * Creates a relation with the given parameters based on the contents of the given - * DataFrame. The mode specifies the expected behavior of createRelation when - * data already exists. - * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. - * Append mode means that when saving a DataFrame to a data source, if data already exists, - * contents of the DataFrame are expected to be appended to existing data. - * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, - * existing data is expected to be overwritten by the contents of the DataFrame. - * ErrorIfExists mode means that when saving a DataFrame to a data source, - * if data already exists, an exception is expected to be thrown. - * - * @since 1.3.0 - */ + * Creates a relation with the given parameters based on the contents of the given + * DataFrame. The mode specifies the expected behavior of createRelation when + * data already exists. + * Right now, there are three modes, Append, Overwrite, and ErrorIfExists. + * Append mode means that when saving a DataFrame to a data source, if data already exists, + * contents of the DataFrame are expected to be appended to existing data. + * Overwrite mode means that when saving a DataFrame to a data source, if data already exists, + * existing data is expected to be overwritten by the contents of the DataFrame. + * ErrorIfExists mode means that when saving a DataFrame to a data source, + * if data already exists, an exception is expected to be thrown. + * + * @since 1.3.0 + */ def createRelation( sqlContext: SQLContext, mode: SaveMode, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 854a662cc4..d160f8ab8c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -286,8 +286,8 @@ abstract class QueryTest extends PlanTest { } /** - * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. - */ + * Asserts that a given [[Dataset]] does not have missing inputs in all the analyzed plans. 
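For the createRelation contract quoted in the interfaces.scala hunk above, a minimal sketch of how a data source might dispatch on SaveMode; the string results stand in for real write logic and this is not an actual implementation.
```scala
import org.apache.spark.sql.SaveMode

def applyMode(mode: SaveMode, dataExists: Boolean): String = mode match {
  case SaveMode.Append        => "append the DataFrame contents to the existing data"
  case SaveMode.Overwrite     => "replace the existing data with the DataFrame contents"
  case SaveMode.ErrorIfExists =>
    if (dataExists) sys.error("data already exists") else "write the DataFrame contents"
  case other                  => s"mode $other is outside the three modes described above"
}
```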
+ */ def assertEmptyMissingInput(query: Dataset[_]): Unit = { assert(query.queryExecution.analyzed.missingInput.isEmpty, s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala index 55906793c0..289e1b6db9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala @@ -32,10 +32,10 @@ import org.apache.spark.unsafe.map.BytesToBytesMap import org.apache.spark.util.Benchmark /** - * Benchmark to measure whole stage codegen performance. - * To run this: - * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" - */ + * Benchmark to measure whole stage codegen performance. + * To run this: + * build/sbt "sql/test-only *BenchmarkWholeStageCodegen" + */ class BenchmarkWholeStageCodegen extends SparkFunSuite { lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark") .set("spark.sql.shuffle.partitions", "1") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala index dc54883277..aaeecef5f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVParserSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.datasources.csv import org.apache.spark.SparkFunSuite /** - * test cases for StringIteratorReader - */ + * test cases for StringIteratorReader + */ class CSVParserSuite extends SparkFunSuite { private def readAll(iter: Iterator[String]) = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index c1e151d08b..ac37e8e022 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -497,9 +497,10 @@ class StreamingContext private[streaming] ( new TransformedDStream[T](dstreams, sparkContext.clean(transformFunc)) } - /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for - * receiving system events related to streaming. - */ + /** + * Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. + */ def addStreamingListener(streamingListener: StreamingListener) { scheduler.listenerBus.addListener(streamingListener) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index 05f4da6fac..922e4a5e4d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -517,9 +517,10 @@ class JavaStreamingContext(val ssc: StreamingContext) extends Closeable { ssc.remember(duration) } - /** Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for - * receiving system events related to streaming. 
- */ + /** + * Add a [[org.apache.spark.streaming.scheduler.StreamingListener]] object for + * receiving system events related to streaming. + */ def addStreamingListener(streamingListener: StreamingListener) { ssc.addStreamingListener(streamingListener) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala index 0a861f22b1..fbac4880bd 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/RateLimiter.scala @@ -22,17 +22,18 @@ import com.google.common.util.concurrent.{RateLimiter => GuavaRateLimiter} import org.apache.spark.SparkConf import org.apache.spark.internal.Logging -/** Provides waitToPush() method to limit the rate at which receivers consume data. - * - * waitToPush method will block the thread if too many messages have been pushed too quickly, - * and only return when a new message has been pushed. It assumes that only one message is - * pushed at a time. - * - * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages - * per second that each receiver will accept. - * - * @param conf spark configuration - */ +/** + * Provides waitToPush() method to limit the rate at which receivers consume data. + * + * waitToPush method will block the thread if too many messages have been pushed too quickly, + * and only return when a new message has been pushed. It assumes that only one message is + * pushed at a time. + * + * The spark configuration spark.streaming.receiver.maxRate gives the maximum number of messages + * per second that each receiver will accept. + * + * @param conf spark configuration + */ private[receiver] abstract class RateLimiter(conf: SparkConf) extends Logging { // treated as an upper limit diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala index 66d5ffb797..0baedaf275 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobSet.scala @@ -21,9 +21,10 @@ import scala.collection.mutable.HashSet import org.apache.spark.streaming.Time -/** Class representing a set of Jobs - * belong to the same batch. - */ +/** + * Class representing a set of Jobs + * belong to the same batch. + */ private[streaming] case class JobSet( time: Time, diff --git a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala index 0df3c501de..c9058ff409 100644 --- a/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala +++ b/tools/src/main/scala/org/apache/spark/tools/GenerateMIMAIgnore.scala @@ -91,10 +91,11 @@ object GenerateMIMAIgnore { (ignoredClasses.flatMap(c => Seq(c, c.replace("$", "#"))).toSet, ignoredMembers.toSet) } - /** Scala reflection does not let us see inner function even if they are upgraded - * to public for some reason. So had to resort to java reflection to get all inner - * functions with $$ in there name. - */ + /** + * Scala reflection does not let us see inner function even if they are upgraded + * to public for some reason. So had to resort to java reflection to get all inner + * functions with $$ in there name. 
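The GenerateMIMAIgnore comment above boils down to a small Java-reflection fallback; a self-contained sketch of that idea follows, with error handling simplified relative to whatever the real method does.
```scala
import scala.util.control.NonFatal

// List method names containing "$$" via Java reflection, since Scala reflection
// does not surface these inner functions.
def innerFunctionNames(className: String, loader: ClassLoader): Seq[String] =
  try {
    Class.forName(className, false, loader).getMethods.map(_.getName)
      .filter(_.contains("$$")).toSeq
  } catch {
    case NonFatal(_) => Seq.empty
  }
```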
+ */ def getInnerFunctions(classSymbol: unv.ClassSymbol): Seq[String] = { try { Class.forName(classSymbol.fullName, false, classLoader).getMethods.map(_.getName) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index 5af2c29808..4b36da309d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -135,8 +135,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } /** - * Obtains token for the Hive metastore and adds them to the credentials. - */ + * Obtains token for the Hive metastore and adds them to the credentials. + */ def obtainTokenForHiveMetastore( sparkConf: SparkConf, conf: Configuration, @@ -149,8 +149,8 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } /** - * Obtain a security token for HBase. - */ + * Obtain a security token for HBase. + */ def obtainTokenForHBase( sparkConf: SparkConf, conf: Configuration, @@ -164,10 +164,10 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { } /** - * Return whether delegation tokens should be retrieved for the given service when security is - * enabled. By default, tokens are retrieved, but that behavior can be changed by setting - * a service-specific configuration. - */ + * Return whether delegation tokens should be retrieved for the given service when security is + * enabled. By default, tokens are retrieved, but that behavior can be changed by setting + * a service-specific configuration. + */ private def shouldGetTokens(conf: SparkConf, service: String): Boolean = { conf.getBoolean(s"spark.yarn.security.tokens.${service}.enabled", true) } -- cgit v1.2.3 From 7be46205083fc688249ee619ac7758904f7aa55d Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 2 Apr 2016 23:05:23 -0700 Subject: [HOTFIX] Fix Scala 2.10 compilation --- .../spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala index 33239c0084..c02fec3085 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/SimplifyConditionalSuite.scala @@ -41,7 +41,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { private val trueBranch = (TrueLiteral, Literal(5)) private val normalBranch = (NonFoldableLiteral(true), Literal(10)) private val unreachableBranch = (FalseLiteral, Literal(20)) - private val nullBranch = (Literal(null, NullType), Literal(30)) + private val nullBranch = (Literal.create(null, NullType), Literal(30)) test("simplify if") { assertEquivalent( @@ -53,7 +53,7 @@ class SimplifyConditionalSuite extends PlanTest with PredicateHelper { Literal(20)) assertEquivalent( - If(Literal(null, NullType), Literal(10), Literal(20)), + If(Literal.create(null, NullType), Literal(10), Literal(20)), Literal(20)) } -- cgit v1.2.3 From c238cd07448f94bbceb661daad90b6a6d597a846 Mon Sep 17 00:00:00 2001 From: bomeng Date: Sun, 3 Apr 2016 17:15:02 +0200 Subject: [SPARK-14341][SQL] Throw exception on unsupported create / drop macro ddl ## What changes were proposed in this pull request? 
We throw an AnalysisException that looks like this: ``` scala> sqlContext.sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))") org.apache.spark.sql.catalyst.parser.ParseException: Unsupported SQL statement == SQL == CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x)) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.nativeCommand(ParseDriver.scala:66) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser$$anonfun$parsePlan$1.apply(ParseDriver.scala:56) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser$$anonfun$parsePlan$1.apply(ParseDriver.scala:53) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parse(ParseDriver.scala:86) at org.apache.spark.sql.catalyst.parser.AbstractSqlParser.parsePlan(ParseDriver.scala:53) at org.apache.spark.sql.SQLContext.parseSql(SQLContext.scala:198) at org.apache.spark.sql.SQLContext.sql(SQLContext.scala:749) ... 48 elided ``` ## How was this patch tested? Add test cases in HiveQuerySuite.scala Author: bomeng Closes #12125 from bomeng/SPARK-14341. --- .../main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 3 +++ .../apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala | 6 ++++-- .../scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala | 7 +++++++ 3 files changed, 14 insertions(+), 2 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index f34bb061e4..6cf47b5c30 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -173,6 +173,8 @@ unsupportedHiveNativeCommands | kw1=LOCK kw2=DATABASE | kw1=UNLOCK kw2=TABLE | kw1=UNLOCK kw2=DATABASE + | kw1=CREATE kw2=TEMPORARY kw3=MACRO + | kw1=DROP kw2=TEMPORARY kw3=MACRO ; createTableHeader @@ -759,6 +761,7 @@ SNAPSHOT: 'SNAPSHOT'; READ: 'READ'; WRITE: 'WRITE'; ONLY: 'ONLY'; +MACRO: 'MACRO'; IF: 'IF'; diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index d8695bc5db..4b4f88ece0 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -363,7 +363,10 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Index commands are not supported "drop_index", "drop_index_removes_partition_dirs", - "alter_index" + "alter_index", + + // Macro commands are not supported + "macro" ) /** @@ -733,7 +736,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_file_with_space_in_the_name", "loadpart1", "louter_join_ppr", - "macro", "mapjoin_distinct", "mapjoin_filter_on_outerjoin", "mapjoin_mapjoin", diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 5450368b88..b951948fda 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -1295,6 +1295,13 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { 
assertUnsupportedFeature { sql("ALTER INDEX my_index ON my_table set IDXPROPERTIES (\"prop1\"=\"val1_new\")")} } + + test("create/drop macro commands are not supported") { + assertUnsupportedFeature { + sql("CREATE TEMPORARY MACRO SIGMOID (x DOUBLE) 1.0 / (1.0 + EXP(-x))") + } + assertUnsupportedFeature { sql("DROP TEMPORARY MACRO SIGMOID") } + } } // for SPARK-2180 test -- cgit v1.2.3 From 3f749f7ed443899d667c9e2b2a11bc595d6fc7f6 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 3 Apr 2016 18:14:16 -0700 Subject: [SPARK-14355][BUILD] Fix typos in Exception/Testcase/Comments and static analysis results ## What changes were proposed in this pull request? This PR contains the following 5 types of maintenance fix over 59 files (+94 lines, -93 lines). - Fix typos(exception/log strings, testcase name, comments) in 44 lines. - Fix lint-java errors (MaxLineLength) in 6 lines. (New codes after SPARK-14011) - Use diamond operators in 40 lines. (New codes after SPARK-13702) - Fix redundant semicolon in 5 lines. - Rename class `InferSchemaSuite` to `CSVInferSchemaSuite` in CSVInferSchemaSuite.scala. ## How was this patch tested? Manual and pass the Jenkins tests. Author: Dongjoon Hyun Closes #12139 from dongjoon-hyun/SPARK-14355. --- .../spark/network/client/TransportClientFactory.java | 2 +- .../spark/network/client/TransportResponseHandler.java | 6 +++--- .../spark/network/server/OneForOneStreamManager.java | 2 +- .../apache/spark/network/sasl/ShuffleSecretManager.java | 2 +- .../collection/unsafe/sort/UnsafeSorterSpillMerger.java | 2 +- core/src/main/scala/org/apache/spark/api/r/RRunner.scala | 2 +- .../scala/org/apache/spark/util/random/RandomSampler.scala | 2 +- .../spark/shuffle/sort/UnsafeShuffleWriterSuite.java | 4 ++-- .../main/java/org/apache/spark/examples/JavaLogQuery.java | 4 ++-- .../mllib/JavaMultiLabelClassificationMetricsExample.java | 14 +++++++------- .../mllib/JavaPowerIterationClusteringExample.java | 10 +++++----- .../examples/mllib/JavaStratifiedSamplingExample.java | 2 +- .../spark/examples/streaming/JavaFlumeEventCount.java | 4 ++-- .../apache/spark/streaming/flume/JavaFlumeStreamSuite.java | 11 ++++++----- .../org/apache/spark/launcher/CommandBuilderUtils.java | 2 +- .../main/java/org/apache/spark/launcher/SparkLauncher.java | 2 +- .../org/apache/spark/launcher/LauncherServerSuite.java | 2 +- .../spark/launcher/SparkSubmitCommandBuilderSuite.java | 2 +- .../spark/ml/classification/DecisionTreeClassifier.scala | 2 +- .../apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 2 +- .../org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- .../apache/spark/ml/regression/DecisionTreeRegressor.scala | 2 +- .../main/scala/org/apache/spark/ml/tree/treeModels.scala | 2 +- .../spark/mllib/classification/LogisticRegression.scala | 2 +- .../java/org/apache/spark/ml/param/JavaTestParams.java | 2 +- .../JavaStreamingLogisticRegressionSuite.java | 4 ++-- .../spark/mllib/clustering/JavaStreamingKMeansSuite.java | 4 ++-- .../org/apache/spark/mllib/linalg/JavaVectorsSuite.java | 4 ++-- .../regression/JavaStreamingLinearRegressionSuite.java | 4 ++-- .../apache/spark/ml/regression/LinearRegressionSuite.scala | 2 +- .../spark/sql/execution/UnsafeExternalRowSorter.java | 2 +- .../scala/org/apache/spark/sql/catalyst/CatalystConf.scala | 2 +- .../apache/spark/sql/catalyst/expressions/Expression.scala | 4 ++-- .../sql/catalyst/expressions/codegen/CodeGenerator.scala | 2 +- .../sql/catalyst/expressions/conditionalExpressions.scala | 2 +- 
.../org/apache/spark/sql/catalyst/parser/AstBuilder.scala | 4 ++-- .../src/test/scala/org/apache/spark/sql/RowTest.scala | 2 +- .../apache/spark/sql/execution/UnsafeKVExternalSorter.java | 2 +- .../spark/sql/execution/vectorized/ColumnarBatch.java | 4 ++-- .../spark/sql/execution/vectorized/OnHeapColumnVector.java | 2 +- .../main/scala/org/apache/spark/sql/ContinuousQuery.scala | 2 +- sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 2 +- .../org/apache/spark/sql/execution/SparkStrategies.scala | 2 +- .../main/scala/org/apache/spark/sql/execution/Window.scala | 2 +- .../org/apache/spark/sql/execution/basicOperators.scala | 2 +- .../spark/sql/execution/columnar/ColumnBuilder.scala | 2 +- .../apache/spark/sql/execution/stat/StatFunctions.scala | 2 +- .../spark/sql/execution/streaming/FileStreamSink.scala | 2 +- .../streaming/state/HDFSBackedStateStoreProvider.scala | 2 +- .../spark/sql/execution/streaming/state/StateStore.scala | 2 +- .../scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 2 +- .../java/test/org/apache/spark/sql/JavaDatasetSuite.java | 10 +++++----- .../src/test/scala/org/apache/spark/sql/QueryTest.scala | 4 ++-- .../execution/datasources/csv/CSVInferSchemaSuite.scala | 2 +- .../sql/execution/datasources/parquet/ParquetIOSuite.scala | 2 +- .../org/apache/spark/sql/streaming/FileStressSuite.scala | 2 +- .../scala/org/apache/spark/sql/hive/test/TestHive.scala | 4 ++-- .../spark/sql/hive/execution/HiveComparisonTest.scala | 2 +- .../scala/org/apache/spark/sql/hive/parquetSuites.scala | 4 ++-- 59 files changed, 94 insertions(+), 93 deletions(-) (limited to 'sql/catalyst') diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java index 5a36e18b09..b5a9d6671f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java @@ -94,7 +94,7 @@ public class TransportClientFactory implements Closeable { this.context = Preconditions.checkNotNull(context); this.conf = context.getConf(); this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps)); - this.connectionPool = new ConcurrentHashMap(); + this.connectionPool = new ConcurrentHashMap<>(); this.numConnectionsPerPeer = conf.numConnectionsPerPeer(); this.rand = new Random(); diff --git a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java index f0e2004d2d..8a69223c88 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java +++ b/common/network-common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java @@ -64,9 +64,9 @@ public class TransportResponseHandler extends MessageHandler { public TransportResponseHandler(Channel channel) { this.channel = channel; - this.outstandingFetches = new ConcurrentHashMap(); - this.outstandingRpcs = new ConcurrentHashMap(); - this.streamCallbacks = new ConcurrentLinkedQueue(); + this.outstandingFetches = new ConcurrentHashMap<>(); + this.outstandingRpcs = new ConcurrentHashMap<>(); + this.streamCallbacks = new ConcurrentLinkedQueue<>(); this.timeOfLastRequestNs = new AtomicLong(0); } diff --git 
a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java index e2222ae085..ae7e520b2f 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java +++ b/common/network-common/src/main/java/org/apache/spark/network/server/OneForOneStreamManager.java @@ -63,7 +63,7 @@ public class OneForOneStreamManager extends StreamManager { // For debugging purposes, start with a random stream id to help identifying different streams. // This does not need to be globally unique, only unique to this class. nextStreamId = new AtomicLong((long) new Random().nextInt(Integer.MAX_VALUE) * 1000); - streams = new ConcurrentHashMap(); + streams = new ConcurrentHashMap<>(); } @Override diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java index 268cb40121..56a025c4d9 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/sasl/ShuffleSecretManager.java @@ -37,7 +37,7 @@ public class ShuffleSecretManager implements SecretKeyHolder { private static final String SPARK_SASL_USER = "sparkSaslUser"; public ShuffleSecretManager() { - shuffleSecretMap = new ConcurrentHashMap(); + shuffleSecretMap = new ConcurrentHashMap<>(); } /** diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java index 2b1c860e55..01aed95878 100644 --- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java +++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeSorterSpillMerger.java @@ -45,7 +45,7 @@ final class UnsafeSorterSpillMerger { } } }; - priorityQueue = new PriorityQueue(numSpills, comparator); + priorityQueue = new PriorityQueue<>(numSpills, comparator); } /** diff --git a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala index ff279ec270..07d1fa2c4a 100644 --- a/core/src/main/scala/org/apache/spark/api/r/RRunner.scala +++ b/core/src/main/scala/org/apache/spark/api/r/RRunner.scala @@ -182,7 +182,7 @@ private[spark] class RRunner[U]( } stream.flush() } catch { - // TODO: We should propogate this error to the task thread + // TODO: We should propagate this error to the task thread case e: Exception => logError("R Writer thread got an exception", e) } finally { diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala index d397cca4b4..8c67364ef1 100644 --- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala +++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala @@ -326,7 +326,7 @@ class GapSamplingReplacement( /** * Skip elements with replication factor zero (i.e. elements that won't be sampled). 
* Samples 'k' from geometric distribution P(k) = (1-q)(q)^k, where q = e^(-f), that is - * q is the probabililty of Poisson(0; f) + * q is the probability of Poisson(0; f) */ private def advance(): Unit = { val u = math.max(rng.nextDouble(), epsilon) diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java index 44733dcdaf..30750b1bf1 100644 --- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java +++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java @@ -170,11 +170,11 @@ public class UnsafeShuffleWriterSuite { private UnsafeShuffleWriter createWriter( boolean transferToEnabled) throws IOException { conf.set("spark.file.transferTo", String.valueOf(transferToEnabled)); - return new UnsafeShuffleWriter( + return new UnsafeShuffleWriter<>( blockManager, shuffleBlockResolver, taskMemoryManager, - new SerializedShuffleHandle(0, 1, shuffleDep), + new SerializedShuffleHandle<>(0, 1, shuffleDep), 0, // map id taskContext, conf diff --git a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java index 8abc03e73d..ebb0687b14 100644 --- a/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java +++ b/examples/src/main/java/org/apache/spark/examples/JavaLogQuery.java @@ -82,10 +82,10 @@ public final class JavaLogQuery { String user = m.group(3); String query = m.group(5); if (!user.equalsIgnoreCase("-")) { - return new Tuple3(ip, user, query); + return new Tuple3<>(ip, user, query); } } - return new Tuple3(null, null, null); + return new Tuple3<>(null, null, null); } public static Stats extractStats(String line) { diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java index 5904260e2d..bc99dc023f 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaMultiLabelClassificationMetricsExample.java @@ -34,13 +34,13 @@ public class JavaMultiLabelClassificationMetricsExample { JavaSparkContext sc = new JavaSparkContext(conf); // $example on$ List> data = Arrays.asList( - new Tuple2(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), - new Tuple2(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{}, new double[]{0.0}), - new Tuple2(new double[]{2.0}, new double[]{2.0}), - new Tuple2(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), - new Tuple2(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), - new Tuple2(new double[]{1.0}, new double[]{1.0, 2.0}) + new Tuple2<>(new double[]{0.0, 1.0}, new double[]{0.0, 2.0}), + new Tuple2<>(new double[]{0.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2<>(new double[]{}, new double[]{0.0}), + new Tuple2<>(new double[]{2.0}, new double[]{2.0}), + new Tuple2<>(new double[]{2.0, 0.0}, new double[]{2.0, 0.0}), + new Tuple2<>(new double[]{0.0, 1.0, 2.0}, new double[]{0.0, 1.0}), + new Tuple2<>(new double[]{1.0}, new double[]{1.0, 2.0}) ); JavaRDD> scoreAndLabels = sc.parallelize(data); diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java index 
b62fa90c34..91c3bd72da 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaPowerIterationClusteringExample.java @@ -40,11 +40,11 @@ public class JavaPowerIterationClusteringExample { @SuppressWarnings("unchecked") // $example on$ JavaRDD> similarities = sc.parallelize(Lists.newArrayList( - new Tuple3(0L, 1L, 0.9), - new Tuple3(1L, 2L, 0.9), - new Tuple3(2L, 3L, 0.9), - new Tuple3(3L, 4L, 0.1), - new Tuple3(4L, 5L, 0.9))); + new Tuple3<>(0L, 1L, 0.9), + new Tuple3<>(1L, 2L, 0.9), + new Tuple3<>(2L, 3L, 0.9), + new Tuple3<>(3L, 4L, 0.1), + new Tuple3<>(4L, 5L, 0.9))); PowerIterationClustering pic = new PowerIterationClustering() .setK(2) diff --git a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java index c27fba2783..86c389e11c 100644 --- a/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java +++ b/examples/src/main/java/org/apache/spark/examples/mllib/JavaStratifiedSamplingExample.java @@ -36,7 +36,7 @@ public class JavaStratifiedSamplingExample { JavaSparkContext jsc = new JavaSparkContext(conf); // $example on$ - List> list = new ArrayList>( + List> list = new ArrayList<>( Arrays.>asList( new Tuple2(1, 'a'), new Tuple2(1, 'b'), diff --git a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java index da56637fe8..bae4b78ac2 100644 --- a/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java +++ b/examples/src/main/java/org/apache/spark/examples/streaming/JavaFlumeEventCount.java @@ -19,7 +19,6 @@ package org.apache.spark.examples.streaming; import org.apache.spark.SparkConf; import org.apache.spark.api.java.function.Function; -import org.apache.spark.examples.streaming.StreamingExamples; import org.apache.spark.streaming.*; import org.apache.spark.streaming.api.java.*; import org.apache.spark.streaming.flume.FlumeUtils; @@ -58,7 +57,8 @@ public final class JavaFlumeEventCount { Duration batchInterval = new Duration(2000); SparkConf sparkConf = new SparkConf().setAppName("JavaFlumeEventCount"); JavaStreamingContext ssc = new JavaStreamingContext(sparkConf, batchInterval); - JavaReceiverInputDStream flumeStream = FlumeUtils.createStream(ssc, host, port); + JavaReceiverInputDStream flumeStream = + FlumeUtils.createStream(ssc, host, port); flumeStream.count(); diff --git a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java index 3b5e0c7746..ada05f203b 100644 --- a/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java +++ b/external/flume/src/test/java/org/apache/spark/streaming/flume/JavaFlumeStreamSuite.java @@ -27,10 +27,11 @@ public class JavaFlumeStreamSuite extends LocalJavaStreamingContext { @Test public void testFlumeStream() { // tests the API, does not actually test data receiving - JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", 12345); - JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", 12345, - StorageLevel.MEMORY_AND_DISK_SER_2()); - JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", 12345, - 
StorageLevel.MEMORY_AND_DISK_SER_2(), false); + JavaReceiverInputDStream test1 = FlumeUtils.createStream(ssc, "localhost", + 12345); + JavaReceiverInputDStream test2 = FlumeUtils.createStream(ssc, "localhost", + 12345, StorageLevel.MEMORY_AND_DISK_SER_2()); + JavaReceiverInputDStream test3 = FlumeUtils.createStream(ssc, "localhost", + 12345, StorageLevel.MEMORY_AND_DISK_SER_2(), false); } } diff --git a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java index 1e55aad5c9..a08c8dcba4 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java +++ b/launcher/src/main/java/org/apache/spark/launcher/CommandBuilderUtils.java @@ -34,7 +34,7 @@ class CommandBuilderUtils { /** The set of known JVM vendors. */ enum JavaVendor { Oracle, IBM, OpenJDK, Unknown - }; + } /** Returns whether the given string is null or empty. */ static boolean isEmpty(String s) { diff --git a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java index a542159901..a083f05a2a 100644 --- a/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java +++ b/launcher/src/main/java/org/apache/spark/launcher/SparkLauncher.java @@ -477,6 +477,6 @@ public class SparkLauncher { // No op. } - }; + } } diff --git a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java index 5bf2babdd1..a9039b3ec9 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/LauncherServerSuite.java @@ -175,7 +175,7 @@ public class LauncherServerSuite extends BaseSuite { TestClient(Socket s) throws IOException { super(s); - this.inbound = new LinkedBlockingQueue(); + this.inbound = new LinkedBlockingQueue<>(); this.clientThread = new Thread(this); clientThread.setName("TestClient"); clientThread.setDaemon(true); diff --git a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java index b7f4f2efc5..29cbbe825b 100644 --- a/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java +++ b/launcher/src/test/java/org/apache/spark/launcher/SparkSubmitCommandBuilderSuite.java @@ -160,7 +160,7 @@ public class SparkSubmitCommandBuilderSuite extends BaseSuite { "SparkPi", "42"); - Map env = new HashMap(); + Map env = new HashMap<>(); List cmd = buildCommand(sparkSubmitArgs, env); assertEquals("foo", findArgValue(cmd, parser.MASTER)); assertEquals("bar", findArgValue(cmd, parser.DEPLOY_MODE)); diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 23c4af17f9..4525bf71f6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -205,7 +205,7 @@ final class DecisionTreeClassificationModel private[ml] ( @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) - /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */ + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ override 
private[spark] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Classification) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 3ce129b12c..1d03a5b4f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -62,7 +62,7 @@ private[shared] object SharedParamsCodeGen { "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + - "will filter out rows with bad values), or error (which will throw an errror). More " + + "will filter out rows with bad values), or error (which will throw an error). More " + "options may be added later", isValid = "ParamValidators.inArray(Array(\"skip\", \"error\"))"), ParamDesc[Boolean]("standardization", "whether to standardize the training features" + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index 96263c5baf..64d6af2766 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -270,10 +270,10 @@ private[ml] trait HasFitIntercept extends Params { private[ml] trait HasHandleInvalid extends Params { /** - * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. + * Param for how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). More options may be added later. * @group param */ - final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later", ParamValidators.inArray(Array("skip", "error"))) + final val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an error). 
More options may be added later", ParamValidators.inArray(Array("skip", "error"))) /** @group getParam */ final def getHandleInvalid: String = $(handleInvalid) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 0a3d00e470..1289a317ee 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -205,7 +205,7 @@ final class DecisionTreeRegressionModel private[ml] ( @Since("2.0.0") lazy val featureImportances: Vector = TreeEnsembleModel.featureImportances(this, numFeatures) - /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */ + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ override private[spark] def toOld: OldDecisionTreeModel = { new OldDecisionTreeModel(rootNode.toOld(1), OldAlgo.Regression) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index 1fad9d6d8c..8ea767b2b3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -71,7 +71,7 @@ private[spark] trait DecisionTreeModel { */ private[ml] def maxSplitFeatureIndex(): Int = rootNode.maxSplitFeatureIndex() - /** Convert to spark.mllib DecisionTreeModel (losing some infomation) */ + /** Convert to spark.mllib DecisionTreeModel (losing some information) */ private[spark] def toOld: OldDecisionTreeModel } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala index c0404be019..f10570e662 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala @@ -418,7 +418,7 @@ class LogisticRegressionWithLBFGS private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean): LogisticRegressionModel = { - // ml's Logisitic regression only supports binary classifcation currently. + // ml's Logistic regression only supports binary classification currently. 
if (numOfLinearPredictor == 1) { def runWithMlLogisitcRegression(elasticNetParam: Double) = { // Prepare the ml LogisticRegression based on our settings diff --git a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java index 65841182df..06f7fbb86e 100644 --- a/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java +++ b/mllib/src/test/java/org/apache/spark/ml/param/JavaTestParams.java @@ -89,7 +89,7 @@ public class JavaTestParams extends JavaParams { myDoubleParam_ = new DoubleParam(this, "myDoubleParam", "this is a double param", ParamValidators.inRange(0.0, 1.0)); List validStrings = Arrays.asList("a", "b"); - myStringParam_ = new Param(this, "myStringParam", "this is a string param", + myStringParam_ = new Param<>(this, "myStringParam", "this is a string param", ParamValidators.inArray(validStrings)); myDoubleArrayParam_ = new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param"); diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java index c9e5ee22f3..62c6d9b7e3 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaStreamingLogisticRegressionSuite.java @@ -66,8 +66,8 @@ public class JavaStreamingLogisticRegressionSuite implements Serializable { JavaDStream training = attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); List> testBatch = Arrays.asList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); + new Tuple2<>(10, Vectors.dense(1.0)), + new Tuple2<>(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLogisticRegressionWithSGD slr = new StreamingLogisticRegressionWithSGD() diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java index d644766d1e..62edbd3a29 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaStreamingKMeansSuite.java @@ -66,8 +66,8 @@ public class JavaStreamingKMeansSuite implements Serializable { JavaDStream training = attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); List> testBatch = Arrays.asList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); + new Tuple2<>(10, Vectors.dense(1.0)), + new Tuple2<>(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingKMeans skmeans = new StreamingKMeans() diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java index 77c8c6274f..4ba8e543a9 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java @@ -37,8 +37,8 @@ public class JavaVectorsSuite implements Serializable { public void sparseArrayConstruction() { @SuppressWarnings("unchecked") Vector v = Vectors.sparse(3, Arrays.asList( - new 
Tuple2(0, 2.0), - new Tuple2(2, 3.0))); + new Tuple2<>(0, 2.0), + new Tuple2<>(2, 3.0))); assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0); } } diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java index dbf6488d41..ea0ccd7448 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaStreamingLinearRegressionSuite.java @@ -65,8 +65,8 @@ public class JavaStreamingLinearRegressionSuite implements Serializable { JavaDStream training = attachTestInputStream(ssc, Arrays.asList(trainingBatch, trainingBatch), 2); List> testBatch = Arrays.asList( - new Tuple2(10, Vectors.dense(1.0)), - new Tuple2(11, Vectors.dense(0.0))); + new Tuple2<>(10, Vectors.dense(1.0)), + new Tuple2<>(11, Vectors.dense(0.0))); JavaPairDStream test = JavaPairDStream.fromJavaDStream( attachTestInputStream(ssc, Arrays.asList(testBatch, testBatch), 2)); StreamingLinearRegressionWithSGD slr = new StreamingLinearRegressionWithSGD() diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index cccb7f8d1b..eb19d13093 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -759,7 +759,7 @@ class LinearRegressionSuite .sliding(2) .forall(x => x(0) >= x(1))) } else { - // To clalify that the normal solver is used here. + // To clarify that the normal solver is used here. assert(model.summary.objectiveHistory.length == 1) assert(model.summary.objectiveHistory(0) == 0.0) val devianceResidualsR = Array(-0.47082, 0.34635) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java index aa7fc2121e..7784345a7a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java @@ -151,7 +151,7 @@ public final class UnsafeExternalRowSorter { Platform.throwException(e); } throw new RuntimeException("Exception should have been re-thrown in next()"); - }; + } }; } catch (IOException e) { cleanupResources(); diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index d5ac01500b..2b98aacdd7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -26,7 +26,7 @@ private[spark] trait CatalystConf { def groupByOrdinal: Boolean /** - * Returns the [[Resolver]] for the current configuration, which can be used to determin if two + * Returns the [[Resolver]] for the current configuration, which can be used to determine if two * identifiers are equal. 
*/ def resolver: Resolver = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 5f8899d599..a24a5db8d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -153,8 +153,8 @@ abstract class Expression extends TreeNode[Expression] { * evaluate to the same result. */ lazy val canonicalized: Expression = { - val canonicalizedChildred = children.map(_.canonicalized) - Canonicalize.execute(withNewChildren(canonicalizedChildred)) + val canonicalizedChildren = children.map(_.canonicalized) + Canonicalize.execute(withNewChildren(canonicalizedChildren)) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index b64d3eea49..1bebd4e904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -509,7 +509,7 @@ class CodegenContext { /** * Checks and sets up the state and codegen for subexpression elimination. This finds the - * common subexpresses, generates the functions that evaluate those expressions and populates + * common subexpressions, generates the functions that evaluate those expressions and populates * the mapping of common subexpressions to the generated functions. */ private def subexpressionElimination(expressions: Seq[Expression]) = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 103ab365e3..35a7b46020 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -222,7 +222,7 @@ object CaseWhen { } /** - * A factory method to faciliate the creation of this expression when used in parsers. + * A factory method to facilitate the creation of this expression when used in parsers. * @param branches Expressions at even position are the branch conditions, and expressions at odd * position are branch values. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 8541b1f7c6..61ea3e4010 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -965,7 +965,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { /** * Create a binary arithmetic expression. The following arithmetic operators are supported: - * - Mulitplication: '*' + * - Multiplication: '*' * - Division: '/' * - Hive Long Division: 'DIV' * - Modulo: '%' @@ -1270,7 +1270,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a double literal for a number denoted in scientifc notation. + * Create a double literal for a number denoted in scientific notation. 
*/ override def visitScientificDecimalLiteral( ctx: ScientificDecimalLiteralContext): Literal = withOrigin(ctx) { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala index d9577dea1b..c9c9599e7f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RowTest.scala @@ -121,7 +121,7 @@ class RowTest extends FunSpec with Matchers { externalRow should be theSameInstanceAs externalRow.copy() } - it("copy should return same ref for interal rows") { + it("copy should return same ref for internal rows") { internalRow should be theSameInstanceAs internalRow.copy() } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java index d3bfb00b3f..8132bba04c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/UnsafeKVExternalSorter.java @@ -272,5 +272,5 @@ public final class UnsafeKVExternalSorter { public void close() { cleanupResources(); } - }; + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java index 792e17911f..d1cc4e6d03 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java @@ -79,7 +79,7 @@ public final class ColumnarBatch { /** * Called to close all the columns in this batch. It is not valid to access the data after - * calling this. This must be called at the end to clean up memory allcoations. + * calling this. This must be called at the end to clean up memory allocations. */ public void close() { for (ColumnVector c: columns) { @@ -315,7 +315,7 @@ public final class ColumnarBatch { public int numRows() { return numRows; } /** - * Returns the number of valid rowss. + * Returns the number of valid rows. */ public int numValidRows() { assert(numRowsFiltered <= numRows); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index b1429fe7cb..708a00953a 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -212,7 +212,7 @@ public final class OnHeapColumnVector extends ColumnVector { public void putIntsLittleEndian(int rowId, int count, byte[] src, int srcIndex) { int srcOffset = srcIndex + Platform.BYTE_ARRAY_OFFSET; for (int i = 0; i < count; ++i) { - intData[i + rowId] = Platform.getInt(src, srcOffset);; + intData[i + rowId] = Platform.getInt(src, srcOffset); srcIndex += 4; srcOffset += 4; } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala index 1dc9a6893e..d9973b092d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/ContinuousQuery.scala @@ -94,7 +94,7 @@ trait ContinuousQuery { /** * Blocks until all available data in the source has been processed an committed to the sink. 
* This method is intended for testing. Note that in the case of continually arriving data, this - * method may block forever. Additionally, this method is only guranteed to block until data that + * method may block forever. Additionally, this method is only guaranteed to block until data that * has been synchronously appended data to a [[org.apache.spark.sql.execution.streaming.Source]] * prior to invocation. (i.e. `getOffset` must immediately reflect the addition). */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 41cb799b97..a39a2113e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -2077,7 +2077,7 @@ class Dataset[T] private[sql]( /** * Returns a new [[Dataset]] partitioned by the given partitioning expressions into - * `numPartitions`. The resulting Datasetis hash partitioned. + * `numPartitions`. The resulting Dataset is hash partitioned. * * This is the same operation as "DISTRIBUTE BY" in SQL (Hive QL). * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 5bcc172ca7..e1fabf519a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -108,7 +108,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { /** * Matches a plan whose single partition should be small enough to build a hash table. * - * Note: this assume that the number of partition is fixed, requires addtional work if it's + * Note: this assume that the number of partition is fixed, requires additional work if it's * dynamic. */ def canBuildHashMap(plan: LogicalPlan): Boolean = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 806089196c..8e9214fa25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -811,7 +811,7 @@ private[execution] final class UnboundedPrecedingWindowFunctionFrame( * * This is a very expensive operator to use, O(n * (n - 1) /2), because we need to maintain a * buffer and must do full recalculation after each row. Reverse iteration would be possible, if - * the communitativity of the used window functions can be guaranteed. + * the commutativity of the used window functions can be guaranteed. * * @param target to write results to. * @param processor to calculate the row values with. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index fb1c6182cf..aba500ad8d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -146,7 +146,7 @@ case class Filter(condition: Expression, child: SparkPlan) // This has the property of not doing redundant IsNotNull checks and taking better advantage of // short-circuiting, not loading attributes until they are needed. // This is very perf sensitive. - // TODO: revisit this. We can consider reodering predicates as well. + // TODO: revisit this. 
We can consider reordering predicates as well. val generatedIsNotNullChecks = new Array[Boolean](notNullPreds.length) val generated = otherPreds.map { c => val nullChecks = c.references.map { r => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 7e26f19bb7..9a173367f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -185,7 +185,7 @@ private[columnar] object ColumnBuilder { case udt: UserDefinedType[_] => return apply(udt.sqlType, initialSize, columnName, useCompression) case other => - throw new Exception(s"not suppported type: $other") + throw new Exception(s"not supported type: $other") } builder.initialize(initialSize, columnName, useCompression) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala index e0b6709c51..d603f63a08 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/stat/StatFunctions.scala @@ -296,7 +296,7 @@ private[sql] object StatFunctions extends Logging { val defaultRelativeError: Double = 0.01 /** - * Statisttics from the Greenwald-Khanna paper. + * Statistics from the Greenwald-Khanna paper. * @param value the sampled value * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index e819e95d61..6921ae584d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -32,7 +32,7 @@ object FileStreamSink { /** * A sink that writes out results to parquet files. Each batch is written out to a unique - * directory. After all of the files in a batch have been succesfully written, the list of + * directory. After all of the files in a batch have been successfully written, the list of * file paths is appended to the log atomically. In the case of partial failures, some duplicate * data may be present in the target directory, but only one copy of each file will be present * in the log. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala index 8ece3c971a..1e0a4a5d4f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/HDFSBackedStateStoreProvider.scala @@ -178,7 +178,7 @@ private[state] class HDFSBackedStateStoreProvider( * This can be called only after committing all the updates made in the current thread. 
*/ override def iterator(): Iterator[(UnsafeRow, UnsafeRow)] = { - verify(state == COMMITTED, "Cannot get iterator of store data before comitting") + verify(state == COMMITTED, "Cannot get iterator of store data before committing") HDFSBackedStateStoreProvider.this.iterator(newVersion) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index d60e6185ac..07f63f928b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -220,7 +220,7 @@ private[state] object StateStore extends Logging { val executorId = SparkEnv.get.blockManager.blockManagerId.executorId val verified = coordinatorRef.map(_.verifyIfInstanceActive(storeId, executorId)).getOrElse(false) - logDebug(s"Verifyied whether the loaded instance $storeId is active: $verified" ) + logDebug(s"Verified whether the loaded instance $storeId is active: $verified" ) verified } catch { case NonFatal(e) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index ca2d909e2c..cfe4911cb7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -126,7 +126,7 @@ object JdbcDialects { /** * Register a dialect for use on all new matching jdbc [[org.apache.spark.sql.DataFrame]]. - * Readding an existing dialect will cause a move-to-front. + * Reading an existing dialect will cause a move-to-front. * * @param dialect The new dialect. 
*/ diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index a5ab446e08..873f681bdf 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -318,14 +318,14 @@ public class JavaDatasetSuite implements Serializable { Encoder> encoder3 = Encoders.tuple(Encoders.INT(), Encoders.LONG(), Encoders.STRING()); List> data3 = - Arrays.asList(new Tuple3(1, 2L, "a")); + Arrays.asList(new Tuple3<>(1, 2L, "a")); Dataset> ds3 = context.createDataset(data3, encoder3); Assert.assertEquals(data3, ds3.collectAsList()); Encoder> encoder4 = Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING()); List> data4 = - Arrays.asList(new Tuple4(1, "b", 2L, "a")); + Arrays.asList(new Tuple4<>(1, "b", 2L, "a")); Dataset> ds4 = context.createDataset(data4, encoder4); Assert.assertEquals(data4, ds4.collectAsList()); @@ -333,7 +333,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.INT(), Encoders.STRING(), Encoders.LONG(), Encoders.STRING(), Encoders.BOOLEAN()); List> data5 = - Arrays.asList(new Tuple5(1, "b", 2L, "a", true)); + Arrays.asList(new Tuple5<>(1, "b", 2L, "a", true)); Dataset> ds5 = context.createDataset(data5, encoder5); Assert.assertEquals(data5, ds5.collectAsList()); @@ -354,7 +354,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.INT(), Encoders.tuple(Encoders.STRING(), Encoders.STRING(), Encoders.LONG())); List>> data2 = - Arrays.asList(tuple2(1, new Tuple3("a", "b", 3L))); + Arrays.asList(tuple2(1, new Tuple3<>("a", "b", 3L))); Dataset>> ds2 = context.createDataset(data2, encoder2); Assert.assertEquals(data2, ds2.collectAsList()); @@ -376,7 +376,7 @@ public class JavaDatasetSuite implements Serializable { Encoders.tuple(Encoders.DOUBLE(), Encoders.DECIMAL(), Encoders.DATE(), Encoders.TIMESTAMP(), Encoders.FLOAT()); List> data = - Arrays.asList(new Tuple5( + Arrays.asList(new Tuple5<>( 1.7976931348623157E308, new BigDecimal("0.922337203685477589"), Date.valueOf("1970-01-01"), new Timestamp(System.currentTimeMillis()), Float.MAX_VALUE)); Dataset> ds = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index d160f8ab8c..f7f3bd78e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -105,10 +105,10 @@ abstract class QueryTest extends PlanTest { val expected = expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted val actual = decoded.toSet.toSeq.map((a: Any) => a.toString).sorted - val comparision = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") + val comparison = sideBySide("expected" +: expected, "spark" +: actual).mkString("\n") fail( s"""Decoded objects do not match expected objects: - |$comparision + |$comparison |${ds.resolvedTEncoder.deserializer.treeString} """.stripMargin) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala index 3a7cb25b4f..23d422635b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVInferSchemaSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.csv import org.apache.spark.SparkFunSuite import org.apache.spark.sql.types._ -class InferSchemaSuite extends SparkFunSuite { +class CSVInferSchemaSuite extends SparkFunSuite { test("String fields types are inferred correctly from null types") { assert(CSVInferSchema.inferField(NullType, "") == NullType) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 9746187d22..a3017258d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -469,7 +469,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { } } - testQuietly("SPARK-9849 DirectParquetOutputCommitter qualified name backwards compatiblity") { + testQuietly("SPARK-9849 DirectParquetOutputCommitter qualified name backwards compatibility") { val clonedConf = new Configuration(hadoopConfiguration) // Write to a parquet file and let it fail. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala index 3916430cdf..5b49a0a86a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStressSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils /** - * A stress test for streamign queries that read and write files. This test constists of + * A stress test for streaming queries that read and write files. This test consists of * two threads: * - one that writes out `numRecords` distinct integers to files of random sizes (the total * number of records is fixed but each files size / creation time is random). 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 4afc8d18a6..9393302355 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -380,8 +380,8 @@ class TestHiveContext private[hive]( """.stripMargin.cmd, s"LOAD DATA LOCAL INPATH '${getHiveFile("data/files/episodes.avro")}' INTO TABLE episodes".cmd ), - // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC PARITIONING - // IS NOT YET SUPPORTED + // THIS TABLE IS NOT THE SAME AS THE HIVE TEST TABLE episodes_partitioned AS DYNAMIC + // PARTITIONING IS NOT YET SUPPORTED TestTable("episodes_part", s"""CREATE TABLE episodes_part (title STRING, air_date STRING, doctor INT) |PARTITIONED BY (doctor_pt INT) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 4c1b425b16..e67fcbedc3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -482,7 +482,7 @@ abstract class HiveComparisonTest val tablesGenerated = queryList.zip(executions).flatMap { // We should take executedPlan instead of sparkPlan, because in following codes we // will run the collected plans. As we will do extra processing for sparkPlan such - // as adding exchage, collapsing codegen stages, etc., collecing sparkPlan here + // as adding exchange, collapsing codegen stages, etc., collecting sparkPlan here // will cause some errors when running these plans later. case (q, e) => e.executedPlan.collect { case i: InsertIntoHiveTable if tablesRead contains i.table.tableName => diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index b6fc61d453..eac65d5720 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -311,7 +311,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { case ExecutedCommand(_: InsertIntoHadoopFsRelation) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[HadoopFsRelation ].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan. " + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan. " + s"However, found a ${o.toString} ") } @@ -341,7 +341,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { case ExecutedCommand(_: InsertIntoHadoopFsRelation) => // OK case o => fail("test_insert_parquet should be converted to a " + s"${classOf[HadoopFsRelation ].getCanonicalName} and " + - s"${classOf[InsertIntoDataSource].getCanonicalName} is expcted as the SparkPlan." + + s"${classOf[InsertIntoDataSource].getCanonicalName} is expected as the SparkPlan." + s"However, found a ${o.toString} ") } -- cgit v1.2.3 From 2715bc68bd1661d207b1af5f44ae8d02aec9d4ec Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Tue, 5 Apr 2016 08:41:59 +0200 Subject: [SPARK-14348][SQL] Support native execution of SHOW TBLPROPERTIES command ## What changes were proposed in this pull request? 
This PR adds native execution of the SHOW TBLPROPERTIES command. Command Syntax: ``` SQL SHOW TBLPROPERTIES table_name[(property_key_literal)] ``` ## How was this patch tested? Tests added in HiveCommandSuite and DDLCommandSuite Author: Dilip Biswal Closes #12133 from dilipbiswal/dkb_show_tblproperties. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 + .../sql/catalyst/catalog/SessionCatalog.scala | 11 ++ .../spark/sql/execution/SparkSqlParser.scala | 37 ++++-- .../spark/sql/execution/command/commands.scala | 44 +++++++- .../sql/execution/command/DDLCommandSuite.scala | 8 ++ .../sql/hive/execution/HiveCommandSuite.scala | 125 +++++++++++++++++++++ .../spark/sql/hive/execution/SQLQuerySuite.scala | 22 ---- 7 files changed, 219 insertions(+), 30 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6cf47b5c30..27b01e0bed 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -116,6 +116,8 @@ statement | SHOW TABLES ((FROM | IN) db=identifier)? (LIKE? pattern=STRING)? #showTables | SHOW DATABASES (LIKE pattern=STRING)? #showDatabases + | SHOW TBLPROPERTIES table=tableIdentifier + ('(' key=tablePropertyKey ')')? #showTblProperties | SHOW FUNCTIONS (LIKE? (qualifiedName | pattern=STRING))? #showFunctions | (DESC | DESCRIBE) FUNCTION EXTENDED? qualifiedName #describeFunction | (DESC | DESCRIBE) option=(EXTENDED | FORMATTED)? diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 569b99e414..3b8ce6373d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -163,6 +163,7 @@ class SessionCatalog( /** * Retrieve the metadata of an existing metastore table. * If no database is specified, assume the table is in the current database. + * If the specified table is not found in the database then an [[AnalysisException]] is thrown. */ def getTable(name: TableIdentifier): CatalogTable = { val db = name.database.getOrElse(currentDb) @@ -271,6 +272,16 @@ class SessionCatalog( } } + /** + * Return whether a table with the specified name is a temporary table. + * + * Note: The temporary table cache is checked only when database is not + * explicitly specified. + */ + def isTemporaryTable(name: TableIdentifier): Boolean = { + !name.database.isDefined && tempTables.contains(formatTableName(name.table)) + } + /** * List all tables in the specified database, including temporary tables. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index ff3ab7746c..fb106d1aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -92,6 +92,22 @@ class SparkSqlAstBuilder extends AstBuilder { ShowDatabasesCommand(Option(ctx.pattern).map(string)) } + /** + * A command for users to list the properties for a table.
If propertyKey is specified, the value + * for the propertyKey is returned. If propertyKey is not specified, all the keys and their + * corresponding values are returned. + * The syntax of using this command in SQL is: + * {{{ + * SHOW TBLPROPERTIES table_name[('propertyKey')]; + * }}} + */ + override def visitShowTblProperties( + ctx: ShowTblPropertiesContext): LogicalPlan = withOrigin(ctx) { + ShowTablePropertiesCommand( + visitTableIdentifier(ctx.tableIdentifier), + Option(ctx.key).map(visitTablePropertyKey)) + } + /** * Create a [[RefreshTable]] logical plan. */ @@ -220,18 +236,25 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitTablePropertyList( ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { ctx.tableProperty.asScala.map { property => - // A key can either be a String or a collection of dot separated elements. We need to treat - // these differently. - val key = if (property.key.STRING != null) { - string(property.key.STRING) - } else { - property.key.getText - } + val key = visitTablePropertyKey(property.key) val value = Option(property.value).map(string).orNull key -> value }.toMap } + /** + * A table property key can either be String or a collection of dot separated elements. This + * function extracts the property key based on whether its a string literal or a table property + * identifier. + */ + override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + /** * Create a [[CreateDatabase]] command. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 4eb8d7ff0d..a4be3bc333 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -21,7 +21,7 @@ import java.util.NoSuchElementException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{Dataset, Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Dataset, Row, SQLContext} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.errors.TreeNodeException import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} @@ -380,6 +380,48 @@ case class ShowDatabasesCommand(databasePattern: Option[String]) extends Runnabl } } +/** + * A command for users to list the properties for a table If propertyKey is specified, the value + * for the propertyKey is returned. If propertyKey is not specified, all the keys and their + * corresponding values are returned. 
+ * The syntax of using this command in SQL is: + * {{{ + * SHOW TBLPROPERTIES table_name[('propertyKey')]; + * }}} + */ +case class ShowTablePropertiesCommand( + table: TableIdentifier, + propertyKey: Option[String]) extends RunnableCommand { + + override val output: Seq[Attribute] = { + val schema = AttributeReference("value", StringType, nullable = false)() :: Nil + propertyKey match { + case None => AttributeReference("key", StringType, nullable = false)() :: schema + case _ => schema + } + } + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + + if (catalog.isTemporaryTable(table)) { + Seq.empty[Row] + } else { + val catalogTable = sqlContext.sessionState.catalog.getTable(table) + + propertyKey match { + case Some(p) => + val propValue = catalogTable + .properties + .getOrElse(p, s"Table ${catalogTable.qualifiedName} does not have property: $p") + Seq(Row(propValue)) + case None => + catalogTable.properties.map(p => Row(p._1, p._2)).toSeq + } + } + } +} + /** * A command for users to list all of the registered functions. * The syntax of using this command in SQL is: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 458f36e832..8b2a5979e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -773,4 +773,12 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } + test("show tblproperties") { + val parsed1 = parser.parsePlan("SHOW TBLPROPERTIES tab1") + val expected1 = ShowTablePropertiesCommand(TableIdentifier("tab1", None), None) + val parsed2 = parser.parsePlan("SHOW TBLPROPERTIES tab1('propKey1')") + val expected2 = ShowTablePropertiesCommand(TableIdentifier("tab1", None), Some("propKey1")) + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala new file mode 100644 index 0000000000..4c3f450522 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive.execution + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils + +class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + protected override def beforeAll(): Unit = { + super.beforeAll() + sql( + """ + |CREATE EXTERNAL TABLE parquet_tab1 (c1 INT, c2 STRING) + |USING org.apache.spark.sql.parquet.DefaultSource + """.stripMargin) + + sql( + """ + |CREATE EXTERNAL TABLE parquet_tab2 (c1 INT, c2 STRING) + |STORED AS PARQUET + |TBLPROPERTIES('prop1Key'="prop1Val", '`prop2Key`'="prop2Val") + """.stripMargin) + } + + override protected def afterAll(): Unit = { + try { + sql("DROP TABLE IF EXISTS parquet_tab1") + sql("DROP TABLE IF EXISTS parquet_tab2") + } finally { + super.afterAll() + } + } + + test("show tables") { + withTable("show1a", "show2b") { + sql("CREATE TABLE show1a(c1 int)") + sql("CREATE TABLE show2b(c2 int)") + checkAnswer( + sql("SHOW TABLES IN default 'show1*'"), + Row("show1a", false) :: Nil) + checkAnswer( + sql("SHOW TABLES IN default 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + checkAnswer( + sql("SHOW TABLES 'show1*|show2*'"), + Row("show1a", false) :: + Row("show2b", false) :: Nil) + assert( + sql("SHOW TABLES").count() >= 2) + assert( + sql("SHOW TABLES IN default").count() >= 2) + } + } + + test("show tblproperties of data source tables - basic") { + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1") + .filter(s"key = 'spark.sql.sources.provider'"), + Row("spark.sql.sources.provider", "org.apache.spark.sql.parquet.DefaultSource") :: Nil + ) + + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1(spark.sql.sources.provider)"), + Row("org.apache.spark.sql.parquet.DefaultSource") :: Nil + ) + + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1") + .filter(s"key = 'spark.sql.sources.schema.numParts'"), + Row("spark.sql.sources.schema.numParts", "1") :: Nil + ) + + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1('spark.sql.sources.schema.numParts')"), + Row("1")) + } + + test("show tblproperties for datasource table - errors") { + val message1 = intercept[AnalysisException] { + sql("SHOW TBLPROPERTIES badtable") + }.getMessage + assert(message1.contains("Table badtable not found in database default")) + + // When key is not found, a row containing the error is returned. + checkAnswer( + sql("SHOW TBLPROPERTIES parquet_tab1('invalid.prop.key')"), + Row("Table default.parquet_tab1 does not have property: invalid.prop.key") :: Nil + ) + } + + test("show tblproperties for hive table") { + checkAnswer(sql("SHOW TBLPROPERTIES parquet_tab2('prop1Key')"), Row("prop1Val")) + checkAnswer(sql("SHOW TBLPROPERTIES parquet_tab2('`prop2Key`')"), Row("prop2Val")) + } + + test("show tblproperties for spark temporary table - empty row") { + withTempTable("parquet_temp") { + sql( + """ + |CREATE TEMPORARY TABLE parquet_temp (c1 INT, c2 STRING) + |USING org.apache.spark.sql.parquet.DefaultSource + """.stripMargin) + + // An empty sequence of row is returned for session temporary table. 
+ checkAnswer(sql("SHOW TBLPROPERTIES parquet_temp"), Nil) + } + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index c203518fdd..6199253d34 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1811,26 +1811,4 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } } } - - test("show tables") { - withTable("show1a", "show2b") { - sql("CREATE TABLE show1a(c1 int)") - sql("CREATE TABLE show2b(c2 int)") - checkAnswer( - sql("SHOW TABLES IN default 'show1*'"), - Row("show1a", false) :: Nil) - checkAnswer( - sql("SHOW TABLES IN default 'show1*|show2*'"), - Row("show1a", false) :: - Row("show2b", false) :: Nil) - checkAnswer( - sql("SHOW TABLES 'show1*|show2*'"), - Row("show1a", false) :: - Row("show2b", false) :: Nil) - assert( - sql("SHOW TABLES").count() >= 2) - assert( - sql("SHOW TABLES IN default").count() >= 2) - } - } } -- cgit v1.2.3 From 78071736799b6c86b5c01b27395f4ab87075342b Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 5 Apr 2016 11:19:46 +0200 Subject: [SPARK-14349][SQL] Issue Error Messages for Unsupported Operators/DML/DDL in SQL Context. #### What changes were proposed in this pull request? Currently, weird error messages are issued if we use Hive Context-only operations in SQL Context. For example, - When calling `Drop Table` in SQL Context, we got the following message: ``` Expected exception org.apache.spark.sql.catalyst.parser.ParseException to be thrown, but java.lang.ClassCastException was thrown. ``` - When calling `Script Transform` in SQL Context, we got the message: ``` assertion failed: No plan for ScriptTransformation [key#9,value#10], cat, [tKey#155,tValue#156], null +- LogicalRDD [key#9,value#10], MapPartitionsRDD[3] at beforeAll at BeforeAndAfterAll.scala:187 ``` Updates: Based on the investigation from hvanhovell, the root cause is `visitChildren`, which is the default implementation. It always returns the result of the last defined context child (see the sketch after this message). After merging the code changes from hvanhovell, it works! Thank you hvanhovell! #### How was this patch tested? A few test cases are added. Not sure if the same issue exists for the other operators/DDL/DML. hvanhovell Author: gatorsmile Author: xiaoli Author: Herman van Hovell Author: Xiao Li Closes #12134 from gatorsmile/hiveParserCommand.
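For readers unfamiliar with ANTLR visitors, here is a minimal sketch of the override described above. The class name `SingleChildVisitor` is illustrative only (the actual change lands in Catalyst's `AstBuilder`, as the diff below shows), and the sketch assumes only the ANTLR 4 runtime on the classpath.

```scala
import org.antlr.v4.runtime.tree.{AbstractParseTreeVisitor, RuleNode}

// ANTLR's default visitChildren visits every child and keeps the last
// aggregated result, so a rule with several children can silently resolve to
// whatever its last child produced. Returning a result only when a rule wraps
// exactly one child makes every other case yield null, which the SQL parser
// can then report as an unsupported operation instead of failing later with a
// confusing ClassCastException or assertion error.
class SingleChildVisitor extends AbstractParseTreeVisitor[AnyRef] {
  override def visitChildren(node: RuleNode): AnyRef = {
    if (node.getChildCount == 1) {
      node.getChild(0).accept(this) // pass through trivial wrapper rules
    } else {
      null // no generic way to combine several children; treat as "no result"
    }
  }
}
```

With multi-child rules no longer silently resolving to a child's plan, unsupported statements can be reported explicitly; for example, the `withScriptIOSchema` change in the diff below throws `ParseException("Script Transform is not supported", ctx)` rather than the assertion failure quoted above.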
--- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 14 +-- .../spark/sql/catalyst/parser/AstBuilder.scala | 118 +++++++++++---------- .../sql/catalyst/parser/PlanParserSuite.scala | 10 -- .../sql/execution/command/DDLCommandSuite.scala | 23 ++++ .../spark/sql/hive/execution/HiveSqlParser.scala | 16 ++- .../org/apache/spark/sql/hive/HiveQlSuite.scala | 31 +++++- 6 files changed, 133 insertions(+), 79 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 27b01e0bed..96c170be3d 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -468,15 +468,15 @@ booleanExpression // https://github.com/antlr/antlr4/issues/780 // https://github.com/antlr/antlr4/issues/781 predicated - : valueExpression predicate[$valueExpression.ctx]? + : valueExpression predicate? ; -predicate[ParserRuleContext value] - : NOT? BETWEEN lower=valueExpression AND upper=valueExpression #between - | NOT? IN '(' expression (',' expression)* ')' #inList - | NOT? IN '(' query ')' #inSubquery - | NOT? like=(RLIKE | LIKE) pattern=valueExpression #like - | IS NOT? NULL #nullPredicate +predicate + : NOT? kind=BETWEEN lower=valueExpression AND upper=valueExpression + | NOT? kind=IN '(' expression (',' expression)* ')' + | NOT? kind=IN '(' query ')' + | NOT? kind=(RLIKE | LIKE) pattern=valueExpression + | IS NOT? kind=NULL ; valueExpression diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 61ea3e4010..14c90918e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -22,7 +22,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.antlr.v4.runtime.{ParserRuleContext, Token} -import org.antlr.v4.runtime.tree.{ParseTree, TerminalNode} +import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} @@ -46,6 +46,19 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { ctx.accept(this).asInstanceOf[T] } + /** + * Override the default behavior for all visit methods. This will only return a non-null result + * when the context has only one child. This is done because there is no generic method to + * combine the results of the context children. In all other cases null is returned. 
+ */ + override def visitChildren(node: RuleNode): AnyRef = { + if (node.getChildCount == 1) { + node.getChild(0).accept(this) + } else { + null + } + } + override def visitSingleStatement(ctx: SingleStatementContext): LogicalPlan = withOrigin(ctx) { visit(ctx.statement).asInstanceOf[LogicalPlan] } @@ -351,7 +364,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { string(script), attributes, withFilter, - withScriptIOSchema(inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) + withScriptIOSchema( + ctx, inRowFormat, recordWriter, outRowFormat, recordReader, schemaLess)) case SqlBaseParser.SELECT => // Regular select @@ -398,11 +412,14 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create a (Hive based) [[ScriptInputOutputSchema]]. */ protected def withScriptIOSchema( + ctx: QuerySpecificationContext, inRowFormat: RowFormatContext, recordWriter: Token, outRowFormat: RowFormatContext, recordReader: Token, - schemaLess: Boolean): ScriptInputOutputSchema = null + schemaLess: Boolean): ScriptInputOutputSchema = { + throw new ParseException("Script Transform is not supported", ctx) + } /** * Create a logical plan for a given 'FROM' clause. Note that we support multiple (comma @@ -778,17 +795,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { trees.asScala.map(expression) } - /** - * Invert a boolean expression if it has a valid NOT clause. - */ - private def invertIfNotDefined(expression: Expression, not: TerminalNode): Expression = { - if (not != null) { - Not(expression) - } else { - expression - } - } - /** * Create a star (i.e. all) expression; this selects all elements (in the specified object). * Both un-targeted (global) and targeted aliases are supported. @@ -909,57 +915,55 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { } /** - * Create a BETWEEN expression. This tests if an expression lies with in the bounds set by two - * other expressions. The inverse can also be created. - */ - override def visitBetween(ctx: BetweenContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - val between = And( - GreaterThanOrEqual(value, expression(ctx.lower)), - LessThanOrEqual(value, expression(ctx.upper))) - invertIfNotDefined(between, ctx.NOT) - } - - /** - * Create an IN expression. This tests if the value of the left hand side expression is - * contained by the sequence of expressions on the right hand side. + * Create a predicated expression. A predicated expression is a normal expression with a + * predicate attached to it, for example: + * {{{ + * a + 1 IS NULL + * }}} */ - override def visitInList(ctx: InListContext): Expression = withOrigin(ctx) { - val in = In(expression(ctx.value), ctx.expression().asScala.map(expression)) - invertIfNotDefined(in, ctx.NOT) + override def visitPredicated(ctx: PredicatedContext): Expression = withOrigin(ctx) { + val e = expression(ctx.valueExpression) + if (ctx.predicate != null) { + withPredicate(e, ctx.predicate) + } else { + e + } } /** - * Create an IN expression, where the the right hand side is a query. This is unsupported. + * Add a predicate to the given expression. Supported expressions are: + * - (NOT) BETWEEN + * - (NOT) IN + * - (NOT) LIKE + * - (NOT) RLIKE + * - IS (NOT) NULL. 
*/ - override def visitInSubquery(ctx: InSubqueryContext): Expression = { - throw new ParseException("IN with a Sub-query is currently not supported.", ctx) - } + private def withPredicate(e: Expression, ctx: PredicateContext): Expression = withOrigin(ctx) { + // Invert a predicate if it has a valid NOT clause. + def invertIfNotDefined(e: Expression): Expression = ctx.NOT match { + case null => e + case not => Not(e) + } - /** - * Create a (R)LIKE/REGEXP expression. - */ - override def visitLike(ctx: LikeContext): Expression = { - val left = expression(ctx.value) - val right = expression(ctx.pattern) - val like = ctx.like.getType match { + // Create the predicate. + ctx.kind.getType match { + case SqlBaseParser.BETWEEN => + // BETWEEN is translated to lower <= e && e <= upper + invertIfNotDefined(And( + GreaterThanOrEqual(e, expression(ctx.lower)), + LessThanOrEqual(e, expression(ctx.upper)))) + case SqlBaseParser.IN if ctx.query != null => + throw new ParseException("IN with a Sub-query is currently not supported.", ctx) + case SqlBaseParser.IN => + invertIfNotDefined(In(e, ctx.expression.asScala.map(expression))) case SqlBaseParser.LIKE => - Like(left, right) + invertIfNotDefined(Like(e, expression(ctx.pattern))) case SqlBaseParser.RLIKE => - RLike(left, right) - } - invertIfNotDefined(like, ctx.NOT) - } - - /** - * Create an IS (NOT) NULL expression. - */ - override def visitNullPredicate(ctx: NullPredicateContext): Expression = withOrigin(ctx) { - val value = expression(ctx.value) - if (ctx.NOT != null) { - IsNotNull(value) - } else { - IsNull(value) + invertIfNotDefined(RLike(e, expression(ctx.pattern))) + case SqlBaseParser.NULL if ctx.NOT != null => + IsNotNull(e) + case SqlBaseParser.NULL => + IsNull(e) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 23f05ce846..9e1660df06 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -122,16 +122,6 @@ class PlanParserSuite extends PlanTest { table("a").union(table("b")).as("c").select(star())) } - test("transform query spec") { - val p = ScriptTransformation(Seq('a, 'b), "func", Seq.empty, table("e"), null) - assertEqual("select transform(a, b) using 'func' from e where f < 10", - p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) - assertEqual("map a, b using 'func' as c, d from e", - p.copy(output = Seq('c.string, 'd.string))) - assertEqual("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e", - p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) - } - test("multi select query") { assertEqual( "from a select * select * where s < 10", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 8b2a5979e2..47e295a7e7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.execution.SparkSqlParser 
import org.apache.spark.sql.execution.datasources.BucketSpec @@ -781,4 +782,26 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) } + + test("commands only available in HiveContext") { + intercept[ParseException] { + parser.parsePlan("DROP TABLE D1.T1") + } + intercept[ParseException] { + parser.parsePlan("CREATE VIEW testView AS SELECT id FROM tab") + } + intercept[ParseException] { + parser.parsePlan("ALTER VIEW testView AS SELECT id FROM tab") + } + intercept[ParseException] { + parser.parsePlan( + """ + |CREATE EXTERNAL TABLE parquet_tab2(c1 INT, c2 STRING) + |TBLPROPERTIES('prop1Key '= "prop1Val", ' `prop2Key` '= "prop2Val") + """.stripMargin) + } + intercept[ParseException] { + parser.parsePlan("SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM testData") + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 12e4f49756..55e69f99a4 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -133,6 +133,18 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } } + /** + * Create a [[CatalogStorageFormat]]. This is part of the [[CreateTableAsSelect]] command. + */ + override def visitCreateFileFormat( + ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { + if (ctx.storageHandler == null) { + typedVisit[CatalogStorageFormat](ctx.fileFormat) + } else { + visitStorageHandler(ctx.storageHandler) + } + } + /** * Create a [[CreateTableAsSelect]] command. */ @@ -282,6 +294,7 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { * Create a [[HiveScriptIOSchema]]. */ override protected def withScriptIOSchema( + ctx: QuerySpecificationContext, inRowFormat: RowFormatContext, recordWriter: Token, outRowFormat: RowFormatContext, @@ -391,7 +404,8 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { /** * Storage Handlers are currently not supported in the statements we support (CTAS). 
*/ - override def visitStorageHandler(ctx: StorageHandlerContext): AnyRef = withOrigin(ctx) { + override def visitStorageHandler( + ctx: StorageHandlerContext): CatalogStorageFormat = withOrigin(ctx) { throw new ParseException("Storage Handlers are currently unsupported.", ctx) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala index 75108c6d47..a8a0d6b8de 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveQlSuite.scala @@ -18,16 +18,19 @@ package org.apache.spark.sql.hive import org.apache.hadoop.hive.serde.serdeConstants -import org.scalatest.BeforeAndAfterAll -import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans +import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.JsonTuple -import org.apache.spark.sql.catalyst.plans.logical.Generate +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} import org.apache.spark.sql.hive.execution.HiveSqlParser -class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { +class HiveQlSuite extends PlanTest { val parser = HiveSqlParser private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { @@ -201,6 +204,26 @@ class HiveQlSuite extends SparkFunSuite with BeforeAndAfterAll { assert(plan.children.head.asInstanceOf[Generate].generator.isInstanceOf[JsonTuple]) } + test("transform query spec") { + val plan1 = parser.parsePlan("select transform(a, b) using 'func' from e where f < 10") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan2 = parser.parsePlan("map a, b using 'func' as c, d from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + val plan3 = parser.parsePlan("reduce a, b using 'func' as (c: int, d decimal(10, 0)) from e") + .asInstanceOf[ScriptTransformation].copy(ioschema = null) + + val p = ScriptTransformation( + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b")), + "func", Seq.empty, plans.table("e"), null) + + comparePlans(plan1, + p.copy(child = p.child.where('f < 10), output = Seq('key.string, 'value.string))) + comparePlans(plan2, + p.copy(output = Seq('c.string, 'd.string))) + comparePlans(plan3, + p.copy(output = Seq('c.int, 'd.decimal(10, 0)))) + } + test("use backticks in output of Script Transform") { val plan = parser.parsePlan( """SELECT `t`.`thing1` -- cgit v1.2.3 From f77f11c67125fdac2e6849a4d45d9286fc872ed9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 5 Apr 2016 10:53:54 -0700 Subject: [SPARK-14345][SQL] Decouple deserializer expression resolution from ObjectOperator ## What changes were proposed in this pull request? This PR decouples deserializer expression resolution from `ObjectOperator`, so that we can use deserializer expression in normal operators. This is needed by #12061 and #12067 , I abstracted the logic out and put them in this PR to reduce code change in the future. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #12131 from cloud-fan/separate. 
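As a rough illustration of the new hook (a sketch written against the classes added below, not part of the patch itself): an operator no longer has to expose a `deserializers` method. It simply wraps the encoder's deserializer in an `UnresolvedDeserializer`, and the analyzer's new `ResolveDeserializer` rule binds it during analysis.

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder

object UnresolvedDeserializerSketch {
  def main(args: Array[String]): Unit = {
    val enc = ExpressionEncoder[Int]()

    // An empty attribute list means "resolve against the child's output";
    // MapGroups/CoGroup instead pass grouping or data attributes explicitly.
    val wrapped = UnresolvedDeserializer(enc.deserializer, Nil)

    // Nothing is bound yet: only the ResolveDeserializer analyzer rule (plus
    // ResolveNewInstance for inner classes) turns this into a resolved
    // deserialization expression.
    println(wrapped.resolved) // false
  }
}
```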
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 183 +++++++++++---------- .../spark/sql/catalyst/analysis/unresolved.scala | 22 +++ .../sql/catalyst/encoders/ExpressionEncoder.scala | 8 +- .../spark/sql/catalyst/expressions/objects.scala | 14 +- .../spark/sql/catalyst/plans/logical/object.scala | 52 ++---- 5 files changed, 153 insertions(+), 126 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a6e317ebf0..3e0a6d29b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.catalyst.analysis -import java.lang.reflect.Modifier - import scala.annotation.tailrec import scala.collection.mutable.ArrayBuffer @@ -87,9 +85,11 @@ class Analyzer( Batch("Resolution", fixedPoint, ResolveRelations :: ResolveReferences :: + ResolveDeserializer :: + ResolveNewInstance :: + ResolveUpCast :: ResolveGroupingAnalytics :: ResolvePivot :: - ResolveUpCast :: ResolveOrdinalInOrderByAndGroupBy :: ResolveSortReferences :: ResolveGenerate :: @@ -499,18 +499,9 @@ class Analyzer( Generate(newG.asInstanceOf[Generator], join, outer, qualifier, output, child) } - // A special case for ObjectOperator, because the deserializer expressions in ObjectOperator - // should be resolved by their corresponding attributes instead of children's output. - case o: ObjectOperator if containsUnresolvedDeserializer(o.deserializers.map(_._1)) => - val deserializerToAttributes = o.deserializers.map { - case (deserializer, attributes) => new TreeNodeRef(deserializer) -> attributes - }.toMap - - o.transformExpressions { - case expr => deserializerToAttributes.get(new TreeNodeRef(expr)).map { attributes => - resolveDeserializer(expr, attributes) - }.getOrElse(expr) - } + // Skips plan which contains deserializer expressions, as they should be resolved by another + // rule: ResolveDeserializer. + case plan if containsDeserializer(plan.expressions) => plan case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString}") @@ -526,38 +517,6 @@ class Analyzer( } } - private def containsUnresolvedDeserializer(exprs: Seq[Expression]): Boolean = { - exprs.exists { expr => - !expr.resolved || expr.find(_.isInstanceOf[BoundReference]).isDefined - } - } - - def resolveDeserializer( - deserializer: Expression, - attributes: Seq[Attribute]): Expression = { - val unbound = deserializer transform { - case b: BoundReference => attributes(b.ordinal) - } - - resolveExpression(unbound, LocalRelation(attributes), throws = true) transform { - case n: NewInstance - // If this is an inner class of another class, register the outer object in `OuterScopes`. - // Note that static inner classes (e.g., inner classes within Scala objects) don't need - // outer pointer registration. 
- if n.outerPointer.isEmpty && - n.cls.isMemberClass && - !Modifier.isStatic(n.cls.getModifiers) => - val outer = OuterScopes.getOuterScope(n.cls) - if (outer == null) { - throw new AnalysisException( - s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + - "access to the scope that this class was defined in.\n" + - "Try moving this class out of its parent class.") - } - n.copy(outerPointer = Some(outer)) - } - } - def newAliases(expressions: Seq[NamedExpression]): Seq[NamedExpression] = { expressions.map { case a: Alias => Alias(a.child, a.name)(isGenerated = a.isGenerated) @@ -623,6 +582,10 @@ class Analyzer( } } + private def containsDeserializer(exprs: Seq[Expression]): Boolean = { + exprs.exists(_.find(_.isInstanceOf[UnresolvedDeserializer]).isDefined) + } + protected[sql] def resolveExpression( expr: Expression, plan: LogicalPlan, @@ -1475,7 +1438,94 @@ class Analyzer( Project(projectList, Join(left, right, joinType, newCondition)) } + /** + * Replaces [[UnresolvedDeserializer]] with the deserialization expression that has been resolved + * to the given input attributes. + */ + object ResolveDeserializer extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + case p if p.resolved => p + case p => p transformExpressions { + case UnresolvedDeserializer(deserializer, inputAttributes) => + val inputs = if (inputAttributes.isEmpty) { + p.children.flatMap(_.output) + } else { + inputAttributes + } + val unbound = deserializer transform { + case b: BoundReference => inputs(b.ordinal) + } + resolveExpression(unbound, LocalRelation(inputs), throws = true) + } + } + } + + /** + * Resolves [[NewInstance]] by finding and adding the outer scope to it if the object being + * constructed is an inner class. + */ + object ResolveNewInstance extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + case p if p.resolved => p + + case p => p transformExpressions { + case n: NewInstance if n.childrenResolved && !n.resolved => + val outer = OuterScopes.getOuterScope(n.cls) + if (outer == null) { + throw new AnalysisException( + s"Unable to generate an encoder for inner class `${n.cls.getName}` without " + + "access to the scope that this class was defined in.\n" + + "Try moving this class out of its parent class.") + } + n.copy(outerPointer = Some(outer)) + } + } + } + + /** + * Replace the [[UpCast]] expression by [[Cast]], and throw exceptions if the cast may truncate. 
+ */ + object ResolveUpCast extends Rule[LogicalPlan] { + private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { + throw new AnalysisException(s"Cannot up cast ${from.sql} from " + + s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + + "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + + "You can either add an explicit cast to the input data or choose a higher precision " + + "type of the field in the target object") + } + + private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { + val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) + val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) + toPrecedence > 0 && fromPrecedence > toPrecedence + } + + def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case p if !p.childrenResolved => p + case p if p.resolved => p + + case p => p transformExpressions { + case u @ UpCast(child, _, _) if !child.resolved => u + + case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { + case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => + fail(child, to, walkedTypePath) + case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => + fail(child, to, walkedTypePath) + case (from, to) if illegalNumericPrecedence(from, to) => + fail(child, to, walkedTypePath) + case (TimestampType, DateType) => + fail(child, DateType, walkedTypePath) + case (StringType, to: NumericType) => + fail(child, to, walkedTypePath) + case _ => Cast(child, dataType.asNullable) + } + } + } + } } /** @@ -1559,45 +1609,6 @@ object CleanupAliases extends Rule[LogicalPlan] { } } -/** - * Replace the `UpCast` expression by `Cast`, and throw exceptions if the cast may truncate. - */ -object ResolveUpCast extends Rule[LogicalPlan] { - private def fail(from: Expression, to: DataType, walkedTypePath: Seq[String]) = { - throw new AnalysisException(s"Cannot up cast ${from.sql} from " + - s"${from.dataType.simpleString} to ${to.simpleString} as it may truncate\n" + - "The type path of the target object is:\n" + walkedTypePath.mkString("", "\n", "\n") + - "You can either add an explicit cast to the input data or choose a higher precision " + - "type of the field in the target object") - } - - private def illegalNumericPrecedence(from: DataType, to: DataType): Boolean = { - val fromPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(from) - val toPrecedence = HiveTypeCoercion.numericPrecedence.indexOf(to) - toPrecedence > 0 && fromPrecedence > toPrecedence - } - - def apply(plan: LogicalPlan): LogicalPlan = { - plan transformAllExpressions { - case u @ UpCast(child, _, _) if !child.resolved => u - - case UpCast(child, dataType, walkedTypePath) => (child.dataType, dataType) match { - case (from: NumericType, to: DecimalType) if !to.isWiderThan(from) => - fail(child, to, walkedTypePath) - case (from: DecimalType, to: NumericType) if !from.isTighterThan(to) => - fail(child, to, walkedTypePath) - case (from, to) if illegalNumericPrecedence(from, to) => - fail(child, to, walkedTypePath) - case (TimestampType, DateType) => - fail(child, DateType, walkedTypePath) - case (StringType, to: NumericType) => - fail(child, to, walkedTypePath) - case _ => Cast(child, dataType.asNullable) - } - } - } -} - /** * Maps a time column to multiple time windows using the Expand operator. 
Since it's non-trivial to * figure out how many windows a time column can map to, we over-estimate the number of windows and diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index e73d367a73..fbbf6302e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -307,3 +307,25 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) override lazy val resolved = false } + +/** + * Holds the deserializer expression and the attributes that are available during the resolution + * for it. Deserializer expression is a special kind of expression that is not always resolved by + * children output, but by given attributes, e.g. the `keyDeserializer` in `MapGroups` should be + * resolved by `groupingAttributes` instead of children output. + * + * @param deserializer The unresolved deserializer expression + * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty + * if we want to resolve deserializer by children output. + */ +case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute]) + extends UnaryExpression with Unevaluable with NonSQLExpression { + // The input attributes used to resolve deserializer expression must be all resolved. + require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.") + + override def child: Expression = deserializer + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 1c712fde26..56d29cfbe1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -24,7 +24,7 @@ import scala.reflect.runtime.universe.{typeTag, TypeTag} import org.apache.spark.sql.{AnalysisException, Encoder} import org.apache.spark.sql.catalyst.{InternalRow, JavaTypeInference, ScalaReflection} -import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedExtractValue} +import org.apache.spark.sql.catalyst.analysis.{SimpleAnalyzer, UnresolvedAttribute, UnresolvedDeserializer, UnresolvedExtractValue} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts @@ -317,11 +317,11 @@ case class ExpressionEncoder[T]( def resolve( schema: Seq[Attribute], outerScopes: ConcurrentMap[String, AnyRef]): ExpressionEncoder[T] = { - val resolved = SimpleAnalyzer.ResolveReferences.resolveDeserializer(deserializer, schema) - // Make a fake plan to wrap the deserializer, so that we can go though the whole analyzer, check // analysis, go through optimizer, etc. 
- val plan = Project(Alias(resolved, "")() :: Nil, LocalRelation(schema)) + val plan = Project( + Alias(UnresolvedDeserializer(deserializer, schema), "")() :: Nil, + LocalRelation(schema)) val analyzedPlan = SimpleAnalyzer.execute(plan) SimpleAnalyzer.checkAnalysis(analyzedPlan) copy(deserializer = SimplifyCasts(analyzedPlan).expressions.head.children.head) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 07b67a0240..eebd43dae9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.expressions +import java.lang.reflect.Modifier + import scala.annotation.tailrec import scala.language.existentials import scala.reflect.ClassTag @@ -112,7 +114,7 @@ case class Invoke( arguments: Seq[Expression] = Nil) extends Expression with NonSQLExpression { override def nullable: Boolean = true - override def children: Seq[Expression] = arguments.+:(targetObject) + override def children: Seq[Expression] = targetObject +: arguments override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") @@ -214,6 +216,16 @@ case class NewInstance( override def children: Seq[Expression] = arguments + override lazy val resolved: Boolean = { + // If the class to construct is an inner class, we need to get its outer pointer, or this + // expression should be regarded as unresolved. + // Note that static inner classes (e.g., inner classes within Scala objects) don't need + // outer pointer registration. + val needOuterPointer = + outerPointer.isEmpty && cls.isMemberClass && !Modifier.isStatic(cls.getModifiers) + childrenResolved && !needOuterPointer + } + override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 058fb6bff1..58313c7b72 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.catalyst.plans.logical import org.apache.spark.sql.Encoder +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types.{ObjectType, StructType} @@ -32,13 +33,6 @@ trait ObjectOperator extends LogicalPlan { override def output: Seq[Attribute] = serializer.map(_.toAttribute) - /** - * An [[ObjectOperator]] may have one or more deserializers to convert internal rows to objects. - * It must also provide the attributes that are available during the resolution of each - * deserializer. - */ - def deserializers: Seq[(Expression, Seq[Attribute])] - /** * The object type that is produced by the user defined function. Note that the return type here * is the same whether or not the operator is output serialized data. 
@@ -71,7 +65,7 @@ object MapPartitions { child: LogicalPlan): MapPartitions = { MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - encoderFor[T].deserializer, + UnresolvedDeserializer(encoderFor[T].deserializer, Nil), encoderFor[U].namedExpressions, child) } @@ -87,9 +81,7 @@ case class MapPartitions( func: Iterator[Any] => Iterator[Any], deserializer: Expression, serializer: Seq[NamedExpression], - child: LogicalPlan) extends UnaryNode with ObjectOperator { - override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) -} + child: LogicalPlan) extends UnaryNode with ObjectOperator /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { @@ -98,7 +90,7 @@ object AppendColumns { child: LogicalPlan): AppendColumns = { new AppendColumns( func.asInstanceOf[Any => Any], - encoderFor[T].deserializer, + UnresolvedDeserializer(encoderFor[T].deserializer, Nil), encoderFor[U].namedExpressions, child) } @@ -120,8 +112,6 @@ case class AppendColumns( override def output: Seq[Attribute] = child.output ++ newColumns def newColumns: Seq[Attribute] = serializer.map(_.toAttribute) - - override def deserializers: Seq[(Expression, Seq[Attribute])] = Seq(deserializer -> child.output) } /** Factory for constructing new `MapGroups` nodes. */ @@ -133,8 +123,8 @@ object MapGroups { child: LogicalPlan): MapGroups = { new MapGroups( func.asInstanceOf[(Any, Iterator[Any]) => TraversableOnce[Any]], - encoderFor[K].deserializer, - encoderFor[T].deserializer, + UnresolvedDeserializer(encoderFor[K].deserializer, groupingAttributes), + UnresolvedDeserializer(encoderFor[T].deserializer, dataAttributes), encoderFor[U].namedExpressions, groupingAttributes, dataAttributes, @@ -158,11 +148,7 @@ case class MapGroups( serializer: Seq[NamedExpression], groupingAttributes: Seq[Attribute], dataAttributes: Seq[Attribute], - child: LogicalPlan) extends UnaryNode with ObjectOperator { - - override def deserializers: Seq[(Expression, Seq[Attribute])] = - Seq(keyDeserializer -> groupingAttributes, valueDeserializer -> dataAttributes) -} + child: LogicalPlan) extends UnaryNode with ObjectOperator /** Factory for constructing new `CoGroup` nodes. */ object CoGroup { @@ -170,22 +156,24 @@ object CoGroup { func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result], leftGroup: Seq[Attribute], rightGroup: Seq[Attribute], - leftData: Seq[Attribute], - rightData: Seq[Attribute], + leftAttr: Seq[Attribute], + rightAttr: Seq[Attribute], left: LogicalPlan, right: LogicalPlan): CoGroup = { require(StructType.fromAttributes(leftGroup) == StructType.fromAttributes(rightGroup)) CoGroup( func.asInstanceOf[(Any, Iterator[Any], Iterator[Any]) => TraversableOnce[Any]], - encoderFor[Key].deserializer, - encoderFor[Left].deserializer, - encoderFor[Right].deserializer, + // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to + // resolve the `keyDeserializer` based on either of them, here we pick the left one. 
+ UnresolvedDeserializer(encoderFor[Key].deserializer, leftGroup), + UnresolvedDeserializer(encoderFor[Left].deserializer, leftAttr), + UnresolvedDeserializer(encoderFor[Right].deserializer, rightAttr), encoderFor[Result].namedExpressions, leftGroup, rightGroup, - leftData, - rightData, + leftAttr, + rightAttr, left, right) } @@ -206,10 +194,4 @@ case class CoGroup( leftAttr: Seq[Attribute], rightAttr: Seq[Attribute], left: LogicalPlan, - right: LogicalPlan) extends BinaryNode with ObjectOperator { - - override def deserializers: Seq[(Expression, Seq[Attribute])] = - // The `leftGroup` and `rightGroup` are guaranteed te be of same schema, so it's safe to resolve - // the `keyDeserializer` based on either of them, here we pick the left one. - Seq(keyDeserializer -> leftGroup, leftDeserializer -> leftAttr, rightDeserializer -> rightAttr) -} + right: LogicalPlan) extends BinaryNode with ObjectOperator -- cgit v1.2.3 From 72544d6f2a72b9e44e0a32de1fb379e3342be5c3 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Tue, 5 Apr 2016 12:27:06 -0700 Subject: [SPARK-14123][SPARK-14384][SQL] Handle CreateFunction/DropFunction ## What changes were proposed in this pull request? This PR implements CreateFunction and DropFunction commands. Besides implementing these two commands, we also change how to manage functions. Here are the main changes. * `FunctionRegistry` will be a container to store all functions builders and it will not actively load any functions. Because of this change, we do not need to maintain a separate registry for HiveContext. So, `HiveFunctionRegistry` is deleted. * SessionCatalog takes care the job of loading a function if this function is not in the `FunctionRegistry` but its metadata is stored in the external catalog. For this case, SessionCatalog will (1) load the metadata from the external catalog, (2) load all needed resources (i.e. jars and files), (3) create a function builder based on the function definition, (4) register the function builder in the `FunctionRegistry`. * A `UnresolvedGenerator` is created. So, the parser will not need to call `FunctionRegistry` directly during parsing, which is not a good time to create a Hive UDTF. In the analysis phase, we will resolve `UnresolvedGenerator`. This PR is based on viirya's https://github.com/apache/spark/pull/12036/ ## How was this patch tested? Existing tests and new tests. ## TODOs [x] Self-review [x] Cleanup [x] More tests for create/drop functions (we need to more tests for permanent functions). [ ] File JIRAs for all TODOs [x] Standardize the error message when a function does not exist. Author: Yin Huai Author: Liang-Chi Hsieh Closes #12117 from yhuai/function. 
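To sketch how the reworked function management is meant to be used, here is a hedged example built only from the `SessionCatalog` API added below; `Upper` is just a convenient built-in expression picked for illustration.

```scala
import org.apache.spark.sql.catalyst.catalog.{InMemoryCatalog, SessionCatalog}
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal, Upper}

object FunctionCatalogSketch {
  def main(args: Array[String]): Unit = {
    // Testing-only constructor: wires in a SimpleFunctionRegistry and a
    // case-sensitive SimpleCatalystConf behind the scenes.
    val catalog = new SessionCatalog(new InMemoryCatalog)

    // Register a temporary function: a name, its metadata, and a builder that
    // turns the call's arguments into a concrete Expression.
    val info = new ExpressionInfo(classOf[Upper].getCanonicalName, "my_upper")
    val builder = (children: Seq[Expression]) => Upper(children.head)
    catalog.createTempFunction("my_upper", info, builder, ignoreIfExists = false)

    // Lookup goes through the FunctionRegistry first; only names that are not
    // registered fall back to the external catalog.
    val expr = catalog.lookupFunction("my_upper", Seq(Literal("spark")))
    println(expr) // roughly upper(spark); the exact toString may differ
  }
}
```

For a permanent function, the same `lookupFunction` call would instead fetch the `CatalogFunction` metadata from the external catalog, load the jar/file resources listed for it, and register a builder under `databaseName.functionName` before building the expression.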
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 17 +- .../sql/catalyst/analysis/FunctionRegistry.scala | 2 + .../spark/sql/catalyst/analysis/unresolved.scala | 31 +++- .../sql/catalyst/catalog/InMemoryCatalog.scala | 5 - .../sql/catalyst/catalog/SessionCatalog.scala | 171 +++++++++++++-------- .../sql/catalyst/catalog/functionResources.scala | 61 ++++++++ .../spark/sql/catalyst/catalog/interface.scala | 16 +- .../apache/spark/sql/catalyst/identifiers.scala | 16 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 14 +- .../spark/sql/catalyst/analysis/AnalysisTest.scala | 2 +- .../catalyst/analysis/DecimalPrecisionSuite.scala | 2 +- .../sql/catalyst/catalog/CatalogTestCases.scala | 20 +-- .../sql/catalyst/catalog/SessionCatalogSuite.scala | 132 ++++------------ .../optimizer/BooleanSimplificationSuite.scala | 1 - .../catalyst/optimizer/EliminateSortsSuite.scala | 2 +- .../sql/catalyst/parser/PlanParserSuite.scala | 15 +- .../scala/org/apache/spark/sql/SQLContext.scala | 18 ++- .../spark/sql/execution/SparkSqlParser.scala | 7 +- .../spark/sql/execution/command/commands.scala | 29 ++-- .../apache/spark/sql/execution/command/ddl.scala | 20 --- .../spark/sql/execution/command/functions.scala | 114 ++++++++++++++ .../apache/spark/sql/internal/SessionState.scala | 11 +- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 +- .../test/scala/org/apache/spark/sql/UDFSuite.scala | 3 +- .../sql/execution/command/DDLCommandSuite.scala | 12 +- .../org/apache/spark/sql/test/SQLTestUtils.scala | 22 +++ .../thriftserver/HiveThriftServer2Suites.scala | 94 +++++------ .../spark/sql/hive/HiveExternalCatalog.scala | 11 +- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 143 ++++++++++++++++- .../apache/spark/sql/hive/HiveSessionState.scala | 18 +-- .../spark/sql/hive/client/HiveClientImpl.scala | 20 ++- .../spark/sql/hive/execution/HiveSqlParser.scala | 17 +- .../scala/org/apache/spark/sql/hive/hiveUDFs.scala | 119 +------------- .../org/apache/spark/sql/hive/test/TestHive.scala | 16 +- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 165 ++++++++++++++++++++ .../scala/org/apache/spark/sql/hive/UDFSuite.scala | 167 +++++++++++++++++++- .../spark/sql/hive/execution/HiveQuerySuite.scala | 8 +- .../spark/sql/hive/execution/HiveUDFSuite.scala | 10 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 74 ++++++--- 39 files changed, 1100 insertions(+), 513 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3e0a6d29b4..473c91e69e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -45,10 +45,7 @@ object SimpleAnalyzer new SimpleCatalystConf(caseSensitiveAnalysis = true)) class SimpleAnalyzer(functionRegistry: FunctionRegistry, conf: CatalystConf) - extends Analyzer( - new SessionCatalog(new InMemoryCatalog, functionRegistry, conf), - functionRegistry, - conf) + extends Analyzer(new SessionCatalog(new InMemoryCatalog, functionRegistry, conf), conf) /** * Provides a logical query plan analyzer, which translates [[UnresolvedAttribute]]s and @@ -57,7 +54,6 @@ class SimpleAnalyzer(functionRegistry: 
FunctionRegistry, conf: CatalystConf) */ class Analyzer( catalog: SessionCatalog, - registry: FunctionRegistry, conf: CatalystConf, maxIterations: Int = 100) extends RuleExecutor[LogicalPlan] with CheckAnalysis { @@ -756,9 +752,18 @@ class Analyzer( case q: LogicalPlan => q transformExpressions { case u if !u.childrenResolved => u // Skip until children are resolved. + case u @ UnresolvedGenerator(name, children) => + withPosition(u) { + catalog.lookupFunction(name, children) match { + case generator: Generator => generator + case other => + failAnalysis(s"$name is expected to be a generator. However, " + + s"its class is ${other.getClass.getCanonicalName}, which is not a generator.") + } + } case u @ UnresolvedFunction(name, children, isDistinct) => withPosition(u) { - registry.lookupFunction(name, children) match { + catalog.lookupFunction(name, children) match { // DISTINCT is not meaningful for a Max or a Min. case max: Max if isDistinct => AggregateExpression(max, Complete, isDistinct = false) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ca8db3cbc5..7af5ffbe47 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -52,6 +52,8 @@ trait FunctionRegistry { /** Drop a function and return whether the function existed. */ def dropFunction(name: String): Boolean + /** Checks if a function with a given name exists. */ + def functionExists(name: String): Boolean = lookupFunction(name).isDefined } class SimpleFunctionRegistry extends FunctionRegistry { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index fbbf6302e9..b2f362b6b8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{errors, TableIdentifier} +import org.apache.spark.sql.catalyst.{errors, InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan} import org.apache.spark.sql.catalyst.trees.TreeNode import org.apache.spark.sql.catalyst.util.quoteIdentifier @@ -133,6 +133,33 @@ object UnresolvedAttribute { } } +/** + * Represents an unresolved generator, which will be created by the parser for + * the [[org.apache.spark.sql.catalyst.plans.logical.Generate]] operator. + * The analyzer will resolve this generator. 
+ */ +case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends Generator { + + override def elementTypes: Seq[(DataType, Boolean, String)] = + throw new UnresolvedException(this, "elementTypes") + override def dataType: DataType = throw new UnresolvedException(this, "dataType") + override def foldable: Boolean = throw new UnresolvedException(this, "foldable") + override def nullable: Boolean = throw new UnresolvedException(this, "nullable") + override lazy val resolved = false + + override def prettyName: String = name + override def toString: String = s"'$name(${children.mkString(", ")})" + + override def eval(input: InternalRow = null): TraversableOnce[InternalRow] = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + override protected def genCode(ctx: CodegenContext, ev: ExprCode): String = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") + + override def terminate(): TraversableOnce[InternalRow] = + throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") +} + case class UnresolvedFunction( name: String, children: Seq[Expression], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 2bbb970ec9..2af0107fa3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -315,11 +315,6 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).functions.put(newName, newFunc) } - override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = synchronized { - requireFunctionExists(db, funcDefinition.identifier.funcName) - catalog(db).functions.put(funcDefinition.identifier.funcName, funcDefinition) - } - override def getFunction(db: String, funcName: String): CatalogFunction = synchronized { requireFunctionExists(db, funcName) catalog(db).functions(funcName) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 3b8ce6373d..c08ffbb235 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -24,7 +24,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, SimpleFunctionRegistry} +import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionException, SimpleFunctionRegistry} import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} @@ -39,17 +39,21 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} */ class SessionCatalog( externalCatalog: ExternalCatalog, + functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, conf: CatalystConf) { import ExternalCatalog._ - def this(externalCatalog: ExternalCatalog, functionRegistry: 
FunctionRegistry) { - this(externalCatalog, functionRegistry, new SimpleCatalystConf(true)) + def this( + externalCatalog: ExternalCatalog, + functionRegistry: FunctionRegistry, + conf: CatalystConf) { + this(externalCatalog, DummyFunctionResourceLoader, functionRegistry, conf) } // For testing only. def this(externalCatalog: ExternalCatalog) { - this(externalCatalog, new SimpleFunctionRegistry) + this(externalCatalog, new SimpleFunctionRegistry, new SimpleCatalystConf(true)) } protected[this] val tempTables = new mutable.HashMap[String, LogicalPlan] @@ -439,53 +443,88 @@ class SessionCatalog( */ def dropFunction(name: FunctionIdentifier): Unit = { val db = name.database.getOrElse(currentDb) + val qualified = name.copy(database = Some(db)).unquotedString + if (functionRegistry.functionExists(qualified)) { + // If we have loaded this function into the FunctionRegistry, + // also drop it from there. + // For a permanent function, because we loaded it to the FunctionRegistry + // when it's first used, we also need to drop it from the FunctionRegistry. + functionRegistry.dropFunction(qualified) + } externalCatalog.dropFunction(db, name.funcName) } - /** - * Alter a metastore function whose name that matches the one specified in `funcDefinition`. - * - * If no database is specified in `funcDefinition`, assume the function is in the - * current database. - * - * Note: If the underlying implementation does not support altering a certain field, - * this becomes a no-op. - */ - def alterFunction(funcDefinition: CatalogFunction): Unit = { - val db = funcDefinition.identifier.database.getOrElse(currentDb) - val newFuncDefinition = funcDefinition.copy( - identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))) - externalCatalog.alterFunction(db, newFuncDefinition) - } - /** * Retrieve the metadata of a metastore function. * * If a database is specified in `name`, this will return the function in that database. * If no database is specified, this will return the function in the current database. */ + // TODO: have a better name. This method is actually for fetching the metadata of a function. def getFunction(name: FunctionIdentifier): CatalogFunction = { val db = name.database.getOrElse(currentDb) externalCatalog.getFunction(db, name.funcName) } + /** + * Check if the specified function exists. + */ + def functionExists(name: FunctionIdentifier): Boolean = { + if (functionRegistry.functionExists(name.unquotedString)) { + // This function exists in the FunctionRegistry. + true + } else { + // Need to check if this function exists in the metastore. + try { + // TODO: It's better to ask external catalog if this function exists. + // So, we can avoid of having this hacky try/catch block. + getFunction(name) != null + } catch { + case _: NoSuchFunctionException => false + case _: AnalysisException => false // HiveExternalCatalog wraps all exceptions with it. + } + } + } // ---------------------------------------------------------------- // | Methods that interact with temporary and metastore functions | // ---------------------------------------------------------------- + /** + * Construct a [[FunctionBuilder]] based on the provided class that represents a function. + * + * This performs reflection to decide what type of [[Expression]] to return in the builder. + */ + private[sql] def makeFunctionBuilder(name: String, functionClassName: String): FunctionBuilder = { + // TODO: at least support UDAFs here + throw new UnsupportedOperationException("Use sqlContext.udf.register(...) 
instead.") + } + + /** + * Loads resources such as JARs and Files for a function. Every resource is represented + * by a tuple (resource type, resource uri). + */ + def loadFunctionResources(resources: Seq[(String, String)]): Unit = { + resources.foreach { case (resourceType, uri) => + val functionResource = + FunctionResource(FunctionResourceType.fromString(resourceType.toLowerCase), uri) + functionResourceLoader.loadResource(functionResource) + } + } + /** * Create a temporary function. * This assumes no database is specified in `funcDefinition`. */ def createTempFunction( name: String, + info: ExpressionInfo, funcDefinition: FunctionBuilder, ignoreIfExists: Boolean): Unit = { if (functionRegistry.lookupFunctionBuilder(name).isDefined && !ignoreIfExists) { throw new AnalysisException(s"Temporary function '$name' already exists.") } - functionRegistry.registerFunction(name, funcDefinition) + functionRegistry.registerFunction(name, info, funcDefinition) } /** @@ -501,41 +540,59 @@ class SessionCatalog( } } - /** - * Rename a function. - * - * If a database is specified in `oldName`, this will rename the function in that database. - * If no database is specified, this will first attempt to rename a temporary function with - * the same name, then, if that does not exist, rename the function in the current database. - * - * This assumes the database specified in `oldName` matches the one specified in `newName`. - */ - def renameFunction(oldName: FunctionIdentifier, newName: FunctionIdentifier): Unit = { - if (oldName.database != newName.database) { - throw new AnalysisException("rename does not support moving functions across databases") - } - val db = oldName.database.getOrElse(currentDb) - val oldBuilder = functionRegistry.lookupFunctionBuilder(oldName.funcName) - if (oldName.database.isDefined || oldBuilder.isEmpty) { - externalCatalog.renameFunction(db, oldName.funcName, newName.funcName) - } else { - val oldExpressionInfo = functionRegistry.lookupFunction(oldName.funcName).get - val newExpressionInfo = new ExpressionInfo( - oldExpressionInfo.getClassName, - newName.funcName, - oldExpressionInfo.getUsage, - oldExpressionInfo.getExtended) - functionRegistry.dropFunction(oldName.funcName) - functionRegistry.registerFunction(newName.funcName, newExpressionInfo, oldBuilder.get) - } + protected def failFunctionLookup(name: String): Nothing = { + throw new AnalysisException(s"Undefined function: $name. This function is " + + s"neither a registered temporary function nor " + + s"a permanent function registered in the database $currentDb.") } /** * Return an [[Expression]] that represents the specified function, assuming it exists. - * Note: This is currently only used for temporary functions. + * + * For a temporary function or a permanent function that has been loaded, + * this method will simply lookup the function through the + * FunctionRegistry and create an expression based on the builder. + * + * For a permanent function that has not been loaded, we will first fetch its metadata + * from the underlying external catalog. Then, we will load all resources associated + * with this function (i.e. jars and files). Finally, we create a function builder + * based on the function class and put the builder into the FunctionRegistry. + * The name of this function in the FunctionRegistry will be `databaseName.functionName`. 
*/ def lookupFunction(name: String, children: Seq[Expression]): Expression = { - functionRegistry.lookupFunction(name, children) + // TODO: Right now, the name can be qualified or not qualified. + // It will be better to get a FunctionIdentifier. + // TODO: Right now, we assume that name is not qualified! + val qualifiedName = FunctionIdentifier(name, Some(currentDb)).unquotedString + if (functionRegistry.functionExists(name)) { + // This function has been already loaded into the function registry. + functionRegistry.lookupFunction(name, children) + } else if (functionRegistry.functionExists(qualifiedName)) { + // This function has been already loaded into the function registry. + // Unlike the above block, we find this function by using the qualified name. + functionRegistry.lookupFunction(qualifiedName, children) + } else { + // The function has not been loaded to the function registry, which means + // that the function is a permanent function (if it actually has been registered + // in the metastore). We need to first put the function in the FunctionRegistry. + val catalogFunction = try { + externalCatalog.getFunction(currentDb, name) + } catch { + case e: AnalysisException => failFunctionLookup(name) + case e: NoSuchFunctionException => failFunctionLookup(name) + } + loadFunctionResources(catalogFunction.resources) + // Please note that qualifiedName is provided by the user. However, + // catalogFunction.identifier.unquotedString is returned by the underlying + // catalog. So, it is possible that qualifiedName is not exactly the same as + // catalogFunction.identifier.unquotedString (difference is on case-sensitivity). + // At here, we preserve the input from the user. + val info = new ExpressionInfo(catalogFunction.className, qualifiedName) + val builder = makeFunctionBuilder(qualifiedName, catalogFunction.className) + createTempFunction(qualifiedName, info, builder, ignoreIfExists = false) + // Now, we need to create the Expression. + functionRegistry.lookupFunction(qualifiedName, children) + } } /** @@ -545,17 +602,11 @@ class SessionCatalog( val dbFunctions = externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } val regex = pattern.replaceAll("\\*", ".*").r - val _tempFunctions = functionRegistry.listFunction() + val loadedFunctions = functionRegistry.listFunction() .filter { f => regex.pattern.matcher(f).matches() } .map { f => FunctionIdentifier(f) } - dbFunctions ++ _tempFunctions - } - - /** - * Return a temporary function. For testing only. - */ - private[catalog] def getTempFunction(name: String): Option[FunctionBuilder] = { - functionRegistry.lookupFunctionBuilder(name) + // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry. + // So, the returned list may have two entries for the same function. + dbFunctions ++ loadedFunctions } - } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala new file mode 100644 index 0000000000..5adcc892cf --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/functionResources.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.catalog + +import org.apache.spark.sql.AnalysisException + +/** An trait that represents the type of a resourced needed by a function. */ +sealed trait FunctionResourceType + +object JarResource extends FunctionResourceType + +object FileResource extends FunctionResourceType + +// We do not allow users to specify a archive because it is YARN specific. +// When loading resources, we will throw an exception and ask users to +// use --archive with spark submit. +object ArchiveResource extends FunctionResourceType + +object FunctionResourceType { + def fromString(resourceType: String): FunctionResourceType = { + resourceType.toLowerCase match { + case "jar" => JarResource + case "file" => FileResource + case "archive" => ArchiveResource + case other => + throw new AnalysisException(s"Resource Type '$resourceType' is not supported.") + } + } +} + +case class FunctionResource(resourceType: FunctionResourceType, uri: String) + +/** + * A simple trait representing a class that can be used to load resources used by + * a function. Because only a SQLContext can load resources, we create this trait + * to avoid of explicitly passing SQLContext around. + */ +trait FunctionResourceLoader { + def loadResource(resource: FunctionResource): Unit +} + +object DummyFunctionResourceLoader extends FunctionResourceLoader { + override def loadResource(resource: FunctionResource): Unit = { + throw new UnsupportedOperationException + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 303846d313..97b9946140 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -150,15 +150,6 @@ abstract class ExternalCatalog { def renameFunction(db: String, oldName: String, newName: String): Unit - /** - * Alter a function whose name that matches the one specified in `funcDefinition`, - * assuming the function exists. - * - * Note: If the underlying implementation does not support altering a certain field, - * this becomes a no-op. - */ - def alterFunction(db: String, funcDefinition: CatalogFunction): Unit - def getFunction(db: String, funcName: String): CatalogFunction def listFunctions(db: String, pattern: String): Seq[String] @@ -171,8 +162,13 @@ abstract class ExternalCatalog { * * @param identifier name of the function * @param className fully qualified class name, e.g. "org.apache.spark.util.MyFunc" + * @param resources resource types and Uris used by the function */ -case class CatalogFunction(identifier: FunctionIdentifier, className: String) +// TODO: Use FunctionResource instead of (String, String) as the element type of resources. 
+case class CatalogFunction( + identifier: FunctionIdentifier, + className: String, + resources: Seq[(String, String)]) /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala index 87f4d1b007..aae75956ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/identifiers.scala @@ -25,10 +25,10 @@ package org.apache.spark.sql.catalyst * Format (quoted): "`name`" or "`db`.`name`" */ sealed trait IdentifierWithDatabase { - val name: String + val identifier: String def database: Option[String] - def quotedString: String = database.map(db => s"`$db`.`$name`").getOrElse(s"`$name`") - def unquotedString: String = database.map(db => s"$db.$name").getOrElse(name) + def quotedString: String = database.map(db => s"`$db`.`$identifier`").getOrElse(s"`$identifier`") + def unquotedString: String = database.map(db => s"$db.$identifier").getOrElse(identifier) override def toString: String = quotedString } @@ -36,13 +36,15 @@ sealed trait IdentifierWithDatabase { /** * Identifies a table in a database. * If `database` is not defined, the current database is used. + * When we register a permenent function in the FunctionRegistry, we use + * unquotedString as the function name. */ case class TableIdentifier(table: String, database: Option[String]) extends IdentifierWithDatabase { - override val name: String = table + override val identifier: String = table - def this(name: String) = this(name, None) + def this(table: String) = this(table, None) } @@ -58,9 +60,9 @@ object TableIdentifier { case class FunctionIdentifier(funcName: String, database: Option[String]) extends IdentifierWithDatabase { - override val name: String = funcName + override val identifier: String = funcName - def this(name: String) = this(name, None) + def this(funcName: String) = this(funcName, None) } object FunctionIdentifier { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 14c90918e6..5a3aebff09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -549,8 +549,8 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { Explode(expressions.head) case "json_tuple" => JsonTuple(expressions) - case other => - withGenerator(other, expressions, ctx) + case name => + UnresolvedGenerator(name, expressions) } Generate( @@ -562,16 +562,6 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { query) } - /** - * Create a [[Generator]]. Override this method in order to support custom Generators. - */ - protected def withGenerator( - name: String, - expressions: Seq[Expression], - ctx: LateralViewContext): Generator = { - throw new ParseException(s"Generator function '$name' is not supported", ctx) - } - /** * Create a joins between two or more logical plans. 
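Because the registry keys for permanent functions are built from `unquotedString`, the renamed `identifier` field is easiest to understand through its output. A tiny worked example with a stand-in case class that copies the formatting logic above (not the Spark class itself):

```scala
object IdentifierFormatSketch {
  // Stand-in with the same formatting as IdentifierWithDatabase in this patch.
  final case class FuncId(identifier: String, database: Option[String] = None) {
    def quotedString: String =
      database.map(db => s"`$db`.`$identifier`").getOrElse(s"`$identifier`")
    def unquotedString: String =
      database.map(db => s"$db.$identifier").getOrElse(identifier)
  }

  def main(args: Array[String]): Unit = {
    val qualified = FuncId("func1", Some("db2"))
    assert(qualified.unquotedString == "db2.func1")   // the key cached in the FunctionRegistry
    assert(qualified.quotedString == "`db2`.`func1`") // what toString prints

    val bare = FuncId("func1")
    assert(bare.unquotedString == "func1")
  }
}
```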
*/ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala index 3ec95ef5b5..b1fcf011f4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisTest.scala @@ -32,7 +32,7 @@ trait AnalysisTest extends PlanTest { val conf = new SimpleCatalystConf(caseSensitive) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) catalog.createTempTable("TaBlE", TestRelations.testRelation, overrideIfExists = true) - new Analyzer(catalog, EmptyFunctionRegistry, conf) { + new Analyzer(catalog, conf) { override val extendedResolutionRules = EliminateSubqueryAliases :: Nil } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala index 1a350bf847..b3b1f5b920 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DecimalPrecisionSuite.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.types._ class DecimalPrecisionSuite extends PlanTest with BeforeAndAfter { private val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true) private val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) - private val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) + private val analyzer = new Analyzer(catalog, conf) private val relation = LocalRelation( AttributeReference("i", IntegerType)(), diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 959bd564d9..fbcac09ce2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -433,7 +433,8 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { test("get function") { val catalog = newBasicCatalog() assert(catalog.getFunction("db2", "func1") == - CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass)) + CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, + Seq.empty[(String, String)])) intercept[AnalysisException] { catalog.getFunction("db2", "does_not_exist") } @@ -464,21 +465,6 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { } } - test("alter function") { - val catalog = newBasicCatalog() - assert(catalog.getFunction("db2", "func1").className == funcClass) - catalog.alterFunction("db2", newFunc("func1").copy(className = "muhaha")) - assert(catalog.getFunction("db2", "func1").className == "muhaha") - intercept[AnalysisException] { catalog.alterFunction("db2", newFunc("funcky")) } - } - - test("alter function when database does not exist") { - val catalog = newBasicCatalog() - intercept[AnalysisException] { - catalog.alterFunction("does_not_exist", newFunc()) - } - } - test("list functions") { val catalog = newBasicCatalog() catalog.createFunction("db2", newFunc("func2")) @@ -557,7 +543,7 @@ abstract class CatalogTestUtils { } def newFunc(name: String, database: Option[String] = None): CatalogFunction = { - 
CatalogFunction(FunctionIdentifier(name, database), funcClass) + CatalogFunction(FunctionIdentifier(name, database), funcClass, Seq.empty[(String, String)]) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index acd97592b6..4d56d001b3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.catalog import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} -import org.apache.spark.sql.catalyst.expressions.{Expression, Literal} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal} import org.apache.spark.sql.catalyst.plans.logical.{Range, SubqueryAlias} @@ -685,19 +685,26 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last - catalog.createTempFunction("temp1", tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("temp2", tempFunc2, ignoreIfExists = false) - assert(catalog.getTempFunction("temp1") == Some(tempFunc1)) - assert(catalog.getTempFunction("temp2") == Some(tempFunc2)) - assert(catalog.getTempFunction("temp3") == None) + val info1 = new ExpressionInfo("tempFunc1", "temp1") + val info2 = new ExpressionInfo("tempFunc2", "temp2") + catalog.createTempFunction("temp1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("temp2", info2, tempFunc2, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction("temp1", arguments) === Literal(1)) + assert(catalog.lookupFunction("temp2", arguments) === Literal(3)) + // Temporary function does not exist. 
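The rewritten test also doubles as the intended calling pattern for the new `createTempFunction` signature: an `ExpressionInfo` plus a builder, then resolution through `lookupFunction`. Roughly, assuming a `SessionCatalog` like the one built in this suite (the function name `pick_first` and the builder are made up for illustration):

```scala
import org.apache.spark.sql.catalyst.catalog.SessionCatalog
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo, Literal}

object TempFunctionUsageSketch {
  // `catalog` would be e.g. new SessionCatalog(newBasicCatalog()) as in the suite.
  def registerAndUse(catalog: SessionCatalog): Unit = {
    val headBuilder = (e: Seq[Expression]) => e.head
    val info = new ExpressionInfo("headBuilder", "pick_first") // className, functionName
    catalog.createTempFunction("pick_first", info, headBuilder, ignoreIfExists = false)

    // The builder runs at lookup time against the concrete argument expressions.
    val resolved = catalog.lookupFunction("pick_first", Seq(Literal(1), Literal(2)))
    assert(resolved == Literal(1))
  }
}
```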
+ intercept[AnalysisException] { + catalog.lookupFunction("temp3", arguments) + } val tempFunc3 = (e: Seq[Expression]) => Literal(e.size) + val info3 = new ExpressionInfo("tempFunc3", "temp1") // Temporary function already exists intercept[AnalysisException] { - catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = false) + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = false) } // Temporary function is overridden - catalog.createTempFunction("temp1", tempFunc3, ignoreIfExists = true) - assert(catalog.getTempFunction("temp1") == Some(tempFunc3)) + catalog.createTempFunction("temp1", info3, tempFunc3, ignoreIfExists = true) + assert(catalog.lookupFunction("temp1", arguments) === Literal(arguments.length)) } test("drop function") { @@ -726,11 +733,15 @@ class SessionCatalogSuite extends SparkFunSuite { test("drop temp function") { val catalog = new SessionCatalog(newBasicCatalog()) + val info = new ExpressionInfo("tempFunc", "func1") val tempFunc = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", tempFunc, ignoreIfExists = false) - assert(catalog.getTempFunction("func1") == Some(tempFunc)) + catalog.createTempFunction("func1", info, tempFunc, ignoreIfExists = false) + val arguments = Seq(Literal(1), Literal(2), Literal(3)) + assert(catalog.lookupFunction("func1", arguments) === Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) - assert(catalog.getTempFunction("func1") == None) + intercept[AnalysisException] { + catalog.lookupFunction("func1", arguments) + } intercept[AnalysisException] { catalog.dropTempFunction("func1", ignoreIfNotExists = false) } @@ -739,7 +750,9 @@ class SessionCatalogSuite extends SparkFunSuite { test("get function") { val catalog = new SessionCatalog(newBasicCatalog()) - val expected = CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass) + val expected = + CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, + Seq.empty[(String, String)]) assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == expected) // Get function without explicitly specifying database catalog.setCurrentDatabase("db2") @@ -758,8 +771,9 @@ class SessionCatalogSuite extends SparkFunSuite { test("lookup temp function") { val catalog = new SessionCatalog(newBasicCatalog()) + val info1 = new ExpressionInfo("tempFunc1", "func1") val tempFunc1 = (e: Seq[Expression]) => e.head - catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) assert(catalog.lookupFunction("func1", Seq(Literal(1), Literal(2), Literal(3))) == Literal(1)) catalog.dropTempFunction("func1", ignoreIfNotExists = false) intercept[AnalysisException] { @@ -767,98 +781,16 @@ class SessionCatalogSuite extends SparkFunSuite { } } - test("rename function") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val newName = "funcky" - assert(sessionCatalog.getFunction( - FunctionIdentifier("func1", Some("db2"))) == newFunc("func1", Some("db2"))) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - sessionCatalog.renameFunction( - FunctionIdentifier("func1", Some("db2")), FunctionIdentifier(newName, Some("db2"))) - assert(sessionCatalog.getFunction( - FunctionIdentifier(newName, Some("db2"))) == newFunc(newName, Some("db2"))) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set(newName)) - // Rename function without explicitly specifying database 
- sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.renameFunction(FunctionIdentifier(newName), FunctionIdentifier("func1")) - assert(sessionCatalog.getFunction( - FunctionIdentifier("func1")) == newFunc("func1", Some("db2"))) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - // Renaming "db2.func1" to "db1.func2" should fail because databases don't match - intercept[AnalysisException] { - sessionCatalog.renameFunction( - FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func2", Some("db1"))) - } - } - - test("rename function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { - catalog.renameFunction( - FunctionIdentifier("func1", Some("does_not_exist")), - FunctionIdentifier("func5", Some("does_not_exist"))) - } - intercept[AnalysisException] { - catalog.renameFunction( - FunctionIdentifier("does_not_exist", Some("db2")), - FunctionIdentifier("x", Some("db2"))) - } - } - - test("rename temp function") { - val externalCatalog = newBasicCatalog() - val sessionCatalog = new SessionCatalog(externalCatalog) - val tempFunc = (e: Seq[Expression]) => e.head - sessionCatalog.createTempFunction("func1", tempFunc, ignoreIfExists = false) - sessionCatalog.setCurrentDatabase("db2") - // If a database is specified, we'll always rename the function in that database - sessionCatalog.renameFunction( - FunctionIdentifier("func1", Some("db2")), FunctionIdentifier("func3", Some("db2"))) - assert(sessionCatalog.getTempFunction("func1") == Some(tempFunc)) - assert(sessionCatalog.getTempFunction("func3") == None) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func3")) - // If no database is specified, we'll first rename temporary functions - sessionCatalog.createFunction(newFunc("func1", Some("db2"))) - sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func4")) - assert(sessionCatalog.getTempFunction("func4") == Some(tempFunc)) - assert(sessionCatalog.getTempFunction("func1") == None) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1", "func3")) - // Then, if no such temporary function exist, rename the function in the current database - sessionCatalog.renameFunction(FunctionIdentifier("func1"), FunctionIdentifier("func5")) - assert(sessionCatalog.getTempFunction("func5") == None) - assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func3", "func5")) - } - - test("alter function") { - val catalog = new SessionCatalog(newBasicCatalog()) - assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))).className == funcClass) - catalog.alterFunction(newFunc("func1", Some("db2")).copy(className = "muhaha")) - assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))).className == "muhaha") - // Alter function without explicitly specifying database - catalog.setCurrentDatabase("db2") - catalog.alterFunction(newFunc("func1").copy(className = "derpy")) - assert(catalog.getFunction(FunctionIdentifier("func1")).className == "derpy") - } - - test("alter function when database/function does not exist") { - val catalog = new SessionCatalog(newBasicCatalog()) - intercept[AnalysisException] { - catalog.alterFunction(newFunc("func5", Some("does_not_exist"))) - } - intercept[AnalysisException] { - catalog.alterFunction(newFunc("funcky", Some("db2"))) - } - } - test("list functions") { val catalog = new SessionCatalog(newBasicCatalog()) + val info1 = new ExpressionInfo("tempFunc1", "func1") + val info2 = new 
ExpressionInfo("tempFunc2", "yes_me") val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last catalog.createFunction(newFunc("func2", Some("db2"))) catalog.createFunction(newFunc("not_me", Some("db2"))) - catalog.createTempFunction("func1", tempFunc1, ignoreIfExists = false) - catalog.createTempFunction("yes_me", tempFunc2, ignoreIfExists = false) + catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) + catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) assert(catalog.listFunctions("db1", "*").toSet == Set(FunctionIdentifier("func1"), FunctionIdentifier("yes_me"))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala index dd6b5cac28..8147d06969 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BooleanSimplificationSuite.scala @@ -141,7 +141,6 @@ class BooleanSimplificationSuite extends PlanTest with PredicateHelper { private val caseInsensitiveConf = new SimpleCatalystConf(false) private val caseInsensitiveAnalyzer = new Analyzer( new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, caseInsensitiveConf), - EmptyFunctionRegistry, caseInsensitiveConf) test("(a && b) || (a && c) => a && (b || c) when case insensitive") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala index 009889d5a1..8c92ad82ac 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSortsSuite.scala @@ -30,7 +30,7 @@ import org.apache.spark.sql.catalyst.rules._ class EliminateSortsSuite extends PlanTest { val conf = new SimpleCatalystConf(caseSensitiveAnalysis = true, orderByOrdinal = false) val catalog = new SessionCatalog(new InMemoryCatalog, EmptyFunctionRegistry, conf) - val analyzer = new Analyzer(catalog, EmptyFunctionRegistry, conf) + val analyzer = new Analyzer(catalog, conf) object Optimize extends RuleExecutor[LogicalPlan] { val batches = diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 9e1660df06..262537d9c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -296,10 +297,18 @@ class PlanParserSuite extends PlanTest { .insertInto("t2"), from.where('s < 10).select(star()).insertInto("t3"))) - // Unsupported generator. - intercept( + // Unresolved generator. 
+ val expected = table("t") + .generate( + UnresolvedGenerator("posexplode", Seq('x)), + join = true, + outer = false, + Some("posexpl"), + Seq("x", "y")) + .select(star()) + assertEqual( "select * from t lateral view posexplode(x) posexpl as x, y", - "Generator function 'posexplode' is not supported") + expected) } test("joins") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d4290fee0a..587ba1ea05 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -32,7 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalyst._ -import org.apache.spark.sql.catalyst.catalog.{ExternalCatalog, InMemoryCatalog} +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} @@ -208,6 +208,22 @@ class SQLContext private[sql]( sparkContext.addJar(path) } + /** A [[FunctionResourceLoader]] that can be used in SessionCatalog. */ + @transient protected[sql] lazy val functionResourceLoader: FunctionResourceLoader = { + new FunctionResourceLoader { + override def loadResource(resource: FunctionResource): Unit = { + resource.resourceType match { + case JarResource => addJar(resource.uri) + case FileResource => sparkContext.addFile(resource.uri) + case ArchiveResource => + throw new AnalysisException( + "Archive is not allowed to be loaded. If YARN mode is used, " + + "please use --archives options while calling spark-submit.") + } + } + } + } + /** * :: Experimental :: * A collection of methods that are considered experimental, but can be used to hook into diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index fb106d1aef..382cc61fac 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -337,10 +337,9 @@ class SparkSqlAstBuilder extends AstBuilder { CreateFunction( database, function, - string(ctx.className), // TODO this is not an alias. + string(ctx.className), resources, - ctx.TEMPORARY != null)( - command(ctx)) + ctx.TEMPORARY != null) } /** @@ -353,7 +352,7 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitDropFunction(ctx: DropFunctionContext): LogicalPlan = withOrigin(ctx) { val (database, function) = visitFunctionName(ctx.qualifiedName) - DropFunction(database, function, ctx.EXISTS != null, ctx.TEMPORARY != null)(command(ctx)) + DropFunction(database, function, ctx.EXISTS != null, ctx.TEMPORARY != null) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index a4be3bc333..faa7a2cdb4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -426,8 +426,12 @@ case class ShowTablePropertiesCommand( * A command for users to list all of the registered functions. 
* The syntax of using this command in SQL is: * {{{ - * SHOW FUNCTIONS + * SHOW FUNCTIONS [LIKE pattern] * }}} + * For the pattern, '*' matches any sequence of characters (including no characters) and + * '|' is for alternation. + * For example, "show functions like 'yea*|windo*'" will return "window" and "year". + * * TODO currently we are simply ignore the db */ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends RunnableCommand { @@ -438,18 +442,17 @@ case class ShowFunctions(db: Option[String], pattern: Option[String]) extends Ru schema.toAttributes } - override def run(sqlContext: SQLContext): Seq[Row] = pattern match { - case Some(p) => - try { - val regex = java.util.regex.Pattern.compile(p) - sqlContext.sessionState.functionRegistry.listFunction() - .filter(regex.matcher(_).matches()).map(Row(_)) - } catch { - // probably will failed in the regex that user provided, then returns empty row. - case _: Throwable => Seq.empty[Row] - } - case None => - sqlContext.sessionState.functionRegistry.listFunction().map(Row(_)) + override def run(sqlContext: SQLContext): Seq[Row] = { + val dbName = db.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase) + // If pattern is not specified, we use '*', which is used to + // match any sequence of characters (including no characters). + val functionNames = + sqlContext.sessionState.catalog + .listFunctions(dbName, pattern.getOrElse("*")) + .map(_.unquotedString) + // The session catalog caches some persistent functions in the FunctionRegistry + // so there can be duplicates. + functionNames.distinct.sorted.map(Row(_)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index cd7e0eed65..6896881910 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -175,26 +175,6 @@ case class DescribeDatabase( } } -case class CreateFunction( - databaseName: Option[String], - functionName: String, - alias: String, - resources: Seq[(String, String)], - isTemp: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging - -/** - * The DDL command that drops a function. - * ifExists: returns an error if the function doesn't exist, unless this is true. - * isTemp: indicates if it is a temporary function. - */ -case class DropFunction( - databaseName: Option[String], - functionName: String, - ifExists: Boolean, - isTemp: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging - /** Rename in ALTER TABLE/VIEW: change the name of a table/view to a different name. */ case class AlterTableRename( oldName: TableIdentifier, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala new file mode 100644 index 0000000000..66d17e322e --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -0,0 +1,114 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
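The `SHOW FUNCTIONS` rewrite delegates matching to `SessionCatalog.listFunctions`, where (earlier in this diff) `*` in the pattern is expanded to the regex `.*`; the command then de-duplicates and sorts because persistent functions can also be cached in the `FunctionRegistry`. A standalone sketch of that filtering over a made-up name list, reproducing the example from the scaladoc above:

```scala
object ShowFunctionsSketch {
  // '*' in the user-supplied pattern becomes '.*'; '|' is plain regex alternation.
  def showFunctions(allNames: Seq[String], pattern: String): Seq[String] = {
    val regex = pattern.replaceAll("\\*", ".*").r
    allNames
      .filter(name => regex.pattern.matcher(name).matches())
      .distinct // a persistent function may appear twice once cached in the registry
      .sorted
  }

  def main(args: Array[String]): Unit = {
    // Made-up names; "year" is listed twice to mimic a function present in both
    // the external catalog listing and the FunctionRegistry.
    val names = Seq("year", "year", "window", "weekofyear", "count")
    assert(showFunctions(names, "yea*|windo*") == Seq("window", "year"))
  }
}
```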
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} +import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.expressions.ExpressionInfo + + +/** + * The DDL command that creates a function. + * To create a temporary function, the syntax of using this command in SQL is: + * {{{ + * CREATE TEMPORARY FUNCTION functionName + * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] + * }}} + * + * To create a permanent function, the syntax in SQL is: + * {{{ + * CREATE FUNCTION [databaseName.]functionName + * AS className [USING JAR\FILE 'uri' [, JAR|FILE 'uri']] + * }}} + */ +// TODO: Use Seq[FunctionResource] instead of Seq[(String, String)] for resources. +case class CreateFunction( + databaseName: Option[String], + functionName: String, + className: String, + resources: Seq[(String, String)], + isTemp: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + if (isTemp) { + if (databaseName.isDefined) { + throw new AnalysisException( + s"It is not allowed to provide database name when defining a temporary function. " + + s"However, database name ${databaseName.get} is provided.") + } + // We first load resources and then put the builder in the function registry. + // Please note that it is allowed to overwrite an existing temp function. + sqlContext.sessionState.catalog.loadFunctionResources(resources) + val info = new ExpressionInfo(className, functionName) + val builder = + sqlContext.sessionState.catalog.makeFunctionBuilder(functionName, className) + sqlContext.sessionState.catalog.createTempFunction( + functionName, info, builder, ignoreIfExists = false) + } else { + // For a permanent, we will store the metadata into underlying external catalog. + // This function will be loaded into the FunctionRegistry when a query uses it. + // We do not load it into FunctionRegistry right now. + val dbName = databaseName.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase) + val func = FunctionIdentifier(functionName, Some(dbName)) + val catalogFunc = CatalogFunction(func, className, resources) + if (sqlContext.sessionState.catalog.functionExists(func)) { + throw new AnalysisException( + s"Function '$functionName' already exists in database '$dbName'.") + } + sqlContext.sessionState.catalog.createFunction(catalogFunc) + } + Seq.empty[Row] + } +} + +/** + * The DDL command that drops a function. + * ifExists: returns an error if the function doesn't exist, unless this is true. + * isTemp: indicates if it is a temporary function. + */ +case class DropFunction( + databaseName: Option[String], + functionName: String, + ifExists: Boolean, + isTemp: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (isTemp) { + if (databaseName.isDefined) { + throw new AnalysisException( + s"It is not allowed to provide database name when dropping a temporary function. 
" + + s"However, database name ${databaseName.get} is provided.") + } + catalog.dropTempFunction(functionName, ifExists) + } else { + // We are dropping a permanent function. + val dbName = databaseName.getOrElse(catalog.getCurrentDatabase) + val func = FunctionIdentifier(functionName, Some(dbName)) + if (!ifExists && !catalog.functionExists(func)) { + throw new AnalysisException( + s"Function '$functionName' does not exist in database '$dbName'.") + } + catalog.dropFunction(func) + } + Seq.empty[Row] + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index cd29def3be..69e3358d4e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -51,7 +51,12 @@ private[sql] class SessionState(ctx: SQLContext) { /** * Internal catalog for managing table and database states. */ - lazy val catalog = new SessionCatalog(ctx.externalCatalog, functionRegistry, conf) + lazy val catalog = + new SessionCatalog( + ctx.externalCatalog, + ctx.functionResourceLoader, + functionRegistry, + conf) /** * Interface exposed to the user for registering user-defined functions. @@ -62,7 +67,7 @@ private[sql] class SessionState(ctx: SQLContext) { * Logical query plan analyzer for resolving unresolved attributes and relations. */ lazy val analyzer: Analyzer = { - new Analyzer(catalog, functionRegistry, conf) { + new Analyzer(catalog, conf) { override val extendedResolutionRules = PreInsertCastAndRename :: DataSourceAnalysis :: @@ -98,5 +103,5 @@ private[sql] class SessionState(ctx: SQLContext) { * Interface to start and stop [[org.apache.spark.sql.ContinuousQuery]]s. */ lazy val continuousQueryManager: ContinuousQueryManager = new ContinuousQueryManager(ctx) - } + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index b727e88668..5a851b47ca 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -61,8 +61,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { .filter(regex.matcher(_).matches()).map(Row(_)) } checkAnswer(sql("SHOW functions"), getFunctions(".*")) - Seq("^c.*", ".*e$", "log.*", ".*date.*").foreach { pattern => - checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) + Seq("^c*", "*e$", "log*", "*date*").foreach { pattern => + // For the pattern part, only '*' and '|' are allowed as wildcards. + // For '*', we need to replace it to '.*'. 
+ checkAnswer( + sql(s"SHOW FUNCTIONS '$pattern'"), + getFunctions(pattern.replaceAll("\\*", ".*"))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index fd736718af..ec950332c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -83,7 +83,8 @@ class UDFSuite extends QueryTest with SharedSQLContext { val e = intercept[AnalysisException] { df.selectExpr("a_function_that_does_not_exist()") } - assert(e.getMessage.contains("undefined function")) + assert(e.getMessage.contains("Undefined function")) + assert(e.getMessage.contains("a_function_that_does_not_exist")) } test("Simple UDF") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 47e295a7e7..c42e8e7233 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -147,13 +147,13 @@ class DDLCommandSuite extends PlanTest { "helloworld", "com.matthewrathbone.example.SimpleUDFExample", Seq(("jar", "/path/to/jar1"), ("jar", "/path/to/jar2")), - isTemp = true)(sql1) + isTemp = true) val expected2 = CreateFunction( Some("hello"), "world", "com.matthewrathbone.example.SimpleUDFExample", Seq(("archive", "/path/to/archive"), ("file", "/path/to/file")), - isTemp = false)(sql2) + isTemp = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) } @@ -173,22 +173,22 @@ class DDLCommandSuite extends PlanTest { None, "helloworld", ifExists = false, - isTemp = true)(sql1) + isTemp = true) val expected2 = DropFunction( None, "helloworld", ifExists = true, - isTemp = true)(sql2) + isTemp = true) val expected3 = DropFunction( Some("hello"), "world", ifExists = false, - isTemp = false)(sql3) + isTemp = false) val expected4 = DropFunction( Some("hello"), "world", ifExists = true, - isTemp = false)(sql4) + isTemp = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala index 80a85a6615..7844d1b296 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala @@ -29,6 +29,7 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.Filter @@ -131,6 +132,27 @@ private[sql] trait SQLTestUtils try f(dir) finally Utils.deleteRecursively(dir) } + /** + * Drops functions after calling `f`. A function is represented by (functionName, isTemporary). + */ + protected def withUserDefinedFunction(functions: (String, Boolean)*)(f: => Unit): Unit = { + try { + f + } catch { + case cause: Throwable => throw cause + } finally { + // If the test failed part way, we don't want to mask the failure by failing to remove + // temp tables that never got created. 
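Reading the parser, command, and test-helper changes together, the end-to-end way to exercise this feature is plain SQL. A hedged sketch follows: the database, function name, class name, and jar URI are placeholders, and `sqlContext` is passed in as a parameter. In a test, the same statements could be wrapped in the new `withUserDefinedFunction` helper so the functions are dropped even if an assertion fails.

```scala
import org.apache.spark.sql.SQLContext

object CreateDropFunctionSketch {
  def registerUdfs(sqlContext: SQLContext): Unit = {
    // Temporary function: resources are loaded and the builder goes straight
    // into the FunctionRegistry.
    sqlContext.sql(
      """CREATE TEMPORARY FUNCTION my_upper
        |AS 'com.example.udf.MyUpper'
        |USING JAR 'file:///tmp/my-udfs.jar'
      """.stripMargin)

    // Permanent function: only metadata is written to the external catalog;
    // the function is loaded into the FunctionRegistry lazily, on first use.
    sqlContext.sql("CREATE FUNCTION mydb.my_upper AS 'com.example.udf.MyUpper'")
  }

  def dropUdfs(sqlContext: SQLContext): Unit = {
    sqlContext.sql("DROP TEMPORARY FUNCTION IF EXISTS my_upper")
    sqlContext.sql("DROP FUNCTION IF EXISTS mydb.my_upper")
  }
}
```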
+ try functions.foreach { case (functionName, isTemporary) => + val withTemporary = if (isTemporary) "TEMPORARY" else "" + sqlContext.sql(s"DROP $withTemporary FUNCTION IF EXISTS $functionName") + assert( + !sqlContext.sessionState.catalog.functionExists(FunctionIdentifier(functionName)), + s"Function $functionName should have been dropped. But, it still exists.") + } + } + } + /** * Drops temporary table `tableName` after calling `f`. */ diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala index 2c7358e59a..a1268b8e94 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/HiveThriftServer2Suites.scala @@ -491,46 +491,50 @@ class HiveThriftBinaryServerSuite extends HiveThriftJdbcTest { test("SPARK-11595 ADD JAR with input path having URL scheme") { withJdbcStatement { statement => - val jarPath = "../hive/src/test/resources/TestUDTF.jar" - val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" + try { + val jarPath = "../hive/src/test/resources/TestUDTF.jar" + val jarURL = s"file://${System.getProperty("user.dir")}/$jarPath" - Seq( - s"ADD JAR $jarURL", - s"""CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin - ).foreach(statement.execute) + Seq( + s"ADD JAR $jarURL", + s"""CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin + ).foreach(statement.execute) - val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + val rs1 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") - assert(rs1.next()) - assert(rs1.getString(1) === "Function: udtf_count2") + assert(rs1.next()) + assert(rs1.getString(1) === "Function: udtf_count2") - assert(rs1.next()) - assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { - rs1.getString(1) - } + assert(rs1.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs1.getString(1) + } - assert(rs1.next()) - assert(rs1.getString(1) === "Usage: To be added.") + assert(rs1.next()) + assert(rs1.getString(1) === "Usage: To be added.") - val dataPath = "../hive/src/test/resources/data/files/kv1.txt" + val dataPath = "../hive/src/test/resources/data/files/kv1.txt" - Seq( - s"CREATE TABLE test_udtf(key INT, value STRING)", - s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf" - ).foreach(statement.execute) + Seq( + s"CREATE TABLE test_udtf(key INT, value STRING)", + s"LOAD DATA LOCAL INPATH '$dataPath' OVERWRITE INTO TABLE test_udtf" + ).foreach(statement.execute) - val rs2 = statement.executeQuery( - "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc") + val rs2 = statement.executeQuery( + "SELECT key, cc FROM test_udtf LATERAL VIEW udtf_count2(value) dd AS cc") - assert(rs2.next()) - assert(rs2.getInt(1) === 97) - assert(rs2.getInt(2) === 500) + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) - assert(rs2.next()) - assert(rs2.getInt(1) === 97) - assert(rs2.getInt(2) === 500) + assert(rs2.next()) + assert(rs2.getInt(1) === 97) + assert(rs2.getInt(2) === 500) + } finally { + statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2") + } } } @@ -565,24 +569,28 @@ class 
SingleSessionSuite extends HiveThriftJdbcTest { }, { statement => - val rs1 = statement.executeQuery("SET foo") + try { + val rs1 = statement.executeQuery("SET foo") - assert(rs1.next()) - assert(rs1.getString(1) === "foo") - assert(rs1.getString(2) === "bar") + assert(rs1.next()) + assert(rs1.getString(1) === "foo") + assert(rs1.getString(2) === "bar") - val rs2 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") + val rs2 = statement.executeQuery("DESCRIBE FUNCTION udtf_count2") - assert(rs2.next()) - assert(rs2.getString(1) === "Function: udtf_count2") + assert(rs2.next()) + assert(rs2.getString(1) === "Function: udtf_count2") - assert(rs2.next()) - assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { - rs2.getString(1) - } + assert(rs2.next()) + assertResult("Class: org.apache.spark.sql.hive.execution.GenericUDTFCount2") { + rs2.getString(1) + } - assert(rs2.next()) - assert(rs2.getString(1) === "Usage: To be added.") + assert(rs2.next()) + assert(rs2.getString(1) === "Usage: To be added.") + } finally { + statement.executeQuery("DROP TEMPORARY FUNCTION udtf_count2") + } } ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 11205ae67c..98a5998d03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -272,7 +272,12 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat override def createFunction( db: String, funcDefinition: CatalogFunction): Unit = withClient { - client.createFunction(db, funcDefinition) + // Hive's metastore is case insensitive. However, Hive's createFunction does + // not normalize the function name (unlike the getFunction part). So, + // we are normalizing the function name. 
+ val functionName = funcDefinition.identifier.funcName.toLowerCase + val functionIdentifier = funcDefinition.identifier.copy(funcName = functionName) + client.createFunction(db, funcDefinition.copy(identifier = functionIdentifier)) } override def dropFunction(db: String, name: String): Unit = withClient { @@ -283,10 +288,6 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat client.renameFunction(db, oldName, newName) } - override def alterFunction(db: String, funcDefinition: CatalogFunction): Unit = withClient { - client.alterFunction(db, funcDefinition) - } - override def getFunction(db: String, funcName: String): CatalogFunction = withClient { client.getFunction(db, funcName) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index dfbf22cc47..d315f39a91 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -17,27 +17,39 @@ package org.apache.spark.sql.hive +import scala.util.{Failure, Success, Try} +import scala.util.control.NonFatal + import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.conf.HiveConf +import org.apache.hadoop.hive.ql.exec.{UDAF, UDF} +import org.apache.hadoop.hive.ql.exec.{FunctionRegistry => HiveFunctionRegistry} +import org.apache.hadoop.hive.ql.udf.generic.{AbstractGenericUDAFResolver, GenericUDF, GenericUDTF} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.FunctionRegistry -import org.apache.spark.sql.catalyst.catalog.SessionCatalog +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.catalog.{FunctionResourceLoader, SessionCatalog} +import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.execution.datasources.BucketSpec +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils -class HiveSessionCatalog( +private[sql] class HiveSessionCatalog( externalCatalog: HiveExternalCatalog, client: HiveClient, context: HiveContext, + functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, conf: SQLConf) - extends SessionCatalog(externalCatalog, functionRegistry, conf) { + extends SessionCatalog(externalCatalog, functionResourceLoader, functionRegistry, conf) { override def setCurrentDatabase(db: String): Unit = { super.setCurrentDatabase(db) @@ -112,4 +124,129 @@ class HiveSessionCatalog( metastoreCatalog.cachedDataSourceTables.getIfPresent(key) } + override def makeFunctionBuilder(funcName: String, className: String): FunctionBuilder = { + makeFunctionBuilder(funcName, Utils.classForName(className)) + } + + /** + * Construct a [[FunctionBuilder]] based on the provided class that represents a function. 
+ */ + private def makeFunctionBuilder(name: String, clazz: Class[_]): FunctionBuilder = { + // When we instantiate hive UDF wrapper class, we may throw exception if the input + // expressions don't satisfy the hive UDF, such as type mismatch, input number + // mismatch, etc. Here we catch the exception and throw AnalysisException instead. + (children: Seq[Expression]) => { + try { + if (classOf[UDF].isAssignableFrom(clazz)) { + val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(clazz.getName), children) + udf.dataType // Force it to check input data types. + udf + } else if (classOf[GenericUDF].isAssignableFrom(clazz)) { + val udf = HiveGenericUDF(name, new HiveFunctionWrapper(clazz.getName), children) + udf.dataType // Force it to check input data types. + udf + } else if (classOf[AbstractGenericUDAFResolver].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(clazz.getName), children) + udaf.dataType // Force it to check input data types. + udaf + } else if (classOf[UDAF].isAssignableFrom(clazz)) { + val udaf = HiveUDAFFunction( + name, + new HiveFunctionWrapper(clazz.getName), + children, + isUDAFBridgeRequired = true) + udaf.dataType // Force it to check input data types. + udaf + } else if (classOf[GenericUDTF].isAssignableFrom(clazz)) { + val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(clazz.getName), children) + udtf.elementTypes // Force it to check input data types. + udtf + } else { + throw new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}'") + } + } catch { + case ae: AnalysisException => + throw ae + case NonFatal(e) => + val analysisException = + new AnalysisException(s"No handler for Hive UDF '${clazz.getCanonicalName}': $e") + analysisException.setStackTrace(e.getStackTrace) + throw analysisException + } + } + } + + // We have a list of Hive built-in functions that we do not support. So, we will check + // Hive's function registry and lazily load needed functions into our own function registry. + // Those Hive built-in functions are + // assert_true, collect_list, collect_set, compute_stats, context_ngrams, create_union, + // current_user ,elt, ewah_bitmap, ewah_bitmap_and, ewah_bitmap_empty, ewah_bitmap_or, field, + // histogram_numeric, in_file, index, inline, java_method, map_keys, map_values, + // matchpath, ngrams, noop, noopstreaming, noopwithmap, noopwithmapstreaming, + // parse_url, parse_url_tuple, percentile, percentile_approx, posexplode, reflect, reflect2, + // regexp, sentences, stack, std, str_to_map, windowingtablefunction, xpath, xpath_boolean, + // xpath_double, xpath_float, xpath_int, xpath_long, xpath_number, + // xpath_short, and xpath_string. + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + // TODO: Once lookupFunction accepts a FunctionIdentifier, we should refactor this method to + // if (super.functionExists(name)) { + // super.lookupFunction(name, children) + // } else { + // // This function is a Hive builtin function. + // ... + // } + Try(super.lookupFunction(name, children)) match { + case Success(expr) => expr + case Failure(error) => + if (functionRegistry.functionExists(name)) { + // If the function actually exists in functionRegistry, it means that there is an + // error when we create the Expression using the given children. + // We need to throw the original exception. + throw error + } else { + // This function is not in functionRegistry, let's try to load it as a Hive's + // built-in function. + // Hive is case insensitive. 
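The `HiveSessionCatalog.lookupFunction` override that this hunk is part of replaces the `HiveFunctionRegistry` wrapper deleted further down: try the session catalog first, and only when the name is genuinely unknown consult Hive's builtin registry, build a wrapper, and cache it as a temporary function. Below is a simplified, self-contained sketch of that control flow, with plain maps standing in for the two registries and `Int` builders instead of real expressions (the `histogram_numeric` builder's behavior is obviously fake; only the fallback-and-cache flow is the point).

```scala
import scala.collection.mutable
import scala.util.{Failure, Success, Try}

object HiveBuiltinFallbackSketch {
  type Builder = Seq[Int] => Int

  // Plain maps stand in for Spark's FunctionRegistry and Hive's builtin registry.
  val sessionRegistry: mutable.Map[String, Builder] = mutable.Map("upper_case" -> (_.head))
  val hiveBuiltins: Map[String, Builder] = Map("histogram_numeric" -> (_.sum))

  def lookupFunction(name: String, args: Seq[Int]): Int =
    Try(sessionRegistry(name)(args)) match {
      case Success(value) => value
      case Failure(error) if sessionRegistry.contains(name) =>
        // The name exists but applying the builder failed (e.g. bad arguments):
        // rethrow the original error instead of falling back.
        throw error
      case Failure(_) =>
        // Unknown name: consult the (case-insensitive) Hive builtins, cache the
        // builder like a temporary function, then resolve through the registry.
        val lower = name.toLowerCase
        val builder = hiveBuiltins.getOrElse(lower,
          throw new NoSuchElementException(s"Undefined function: $name"))
        sessionRegistry(lower) = builder
        sessionRegistry(lower)(args)
    }

  def main(args: Array[String]): Unit = {
    assert(lookupFunction("histogram_numeric", Seq(1, 2, 3)) == 6) // loaded lazily
    assert(sessionRegistry.contains("histogram_numeric"))          // ...and now cached
  }
}
```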
+ val functionName = name.toLowerCase + // TODO: This may not really work for current_user because current_user is not evaluated + // with session info. + // We do not need to use executionHive at here because we only load + // Hive's builtin functions, which do not need current db. + val functionInfo = { + try { + Option(HiveFunctionRegistry.getFunctionInfo(functionName)).getOrElse( + failFunctionLookup(name)) + } catch { + // If HiveFunctionRegistry.getFunctionInfo throws an exception, + // we are failing to load a Hive builtin function, which means that + // the given function is not a Hive builtin function. + case NonFatal(e) => failFunctionLookup(name) + } + } + val className = functionInfo.getFunctionClass.getName + val builder = makeFunctionBuilder(functionName, className) + // Put this Hive built-in function to our function registry. + val info = new ExpressionInfo(className, functionName) + createTempFunction(functionName, info, builder, ignoreIfExists = false) + // Now, we need to create the Expression. + functionRegistry.lookupFunction(functionName, children) + } + } + } + + // Pre-load a few commonly used Hive built-in functions. + HiveSessionCatalog.preloadedHiveBuiltinFunctions.foreach { + case (functionName, clazz) => + val builder = makeFunctionBuilder(functionName, clazz) + val info = new ExpressionInfo(clazz.getCanonicalName, functionName) + createTempFunction(functionName, info, builder, ignoreIfExists = false) + } +} + +private[sql] object HiveSessionCatalog { + // This is the list of Hive's built-in functions that are commonly used and we want to + // pre-load when we create the FunctionRegistry. + val preloadedHiveBuiltinFunctions = + ("collect_set", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectSet]) :: + ("collect_list", classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDAFCollectList]) :: Nil } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index 829afa8432..cff24e28fd 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -35,26 +35,24 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) override def caseSensitiveAnalysis: Boolean = getConf(SQLConf.CASE_SENSITIVE, false) } - /** - * Internal catalog for managing functions registered by the user. - * Note that HiveUDFs will be overridden by functions registered in this context. - */ - override lazy val functionRegistry: FunctionRegistry = { - new HiveFunctionRegistry(FunctionRegistry.builtin.copy(), ctx.executionHive) - } - /** * Internal catalog for managing table and database states. */ override lazy val catalog = { - new HiveSessionCatalog(ctx.hiveCatalog, ctx.metadataHive, ctx, functionRegistry, conf) + new HiveSessionCatalog( + ctx.hiveCatalog, + ctx.metadataHive, + ctx, + ctx.functionResourceLoader, + functionRegistry, + conf) } /** * An analyzer that uses the Hive metastore. 
*/ override lazy val analyzer: Analyzer = { - new Analyzer(catalog, functionRegistry, conf) { + new Analyzer(catalog, conf) { override val extendedResolutionRules = catalog.ParquetConversions :: catalog.OrcConversions :: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index a31178e347..1f66fbfd85 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -21,13 +21,14 @@ import java.io.{File, PrintStream} import scala.collection.JavaConverters._ import scala.language.reflectiveCalls +import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.cli.CliSessionState import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} -import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceUri} +import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Hive, Partition => HivePartition, Table => HiveTable} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc @@ -37,6 +38,7 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.spark.{SparkConf, SparkException} import org.apache.spark.internal.Logging +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis.{NoSuchDatabaseException, NoSuchPartitionException} import org.apache.spark.sql.catalyst.catalog._ @@ -611,6 +613,9 @@ private[hive] class HiveClientImpl( .asInstanceOf[Class[_ <: org.apache.hadoop.hive.ql.io.HiveOutputFormat[_, _]]] private def toHiveFunction(f: CatalogFunction, db: String): HiveFunction = { + val resourceUris = f.resources.map { case (resourceType, resourcePath) => + new ResourceUri(ResourceType.valueOf(resourceType.toUpperCase), resourcePath) + } new HiveFunction( f.identifier.funcName, db, @@ -619,12 +624,21 @@ private[hive] class HiveClientImpl( PrincipalType.USER, (System.currentTimeMillis / 1000).toInt, FunctionType.JAVA, - List.empty[ResourceUri].asJava) + resourceUris.asJava) } private def fromHiveFunction(hf: HiveFunction): CatalogFunction = { val name = FunctionIdentifier(hf.getFunctionName, Option(hf.getDbName)) - new CatalogFunction(name, hf.getClassName) + val resources = hf.getResourceUris.asScala.map { uri => + val resourceType = uri.getResourceType() match { + case ResourceType.ARCHIVE => "archive" + case ResourceType.FILE => "file" + case ResourceType.JAR => "jar" + case r => throw new AnalysisException(s"Unknown resource type: $r") + } + (resourceType, uri.getUri()) + } + new CatalogFunction(name, hf.getClassName, resources) } private def toHiveColumn(c: CatalogColumn): FieldSchema = { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 55e69f99a4..c6c0b2ca59 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -21,13 +21,13 @@ import scala.collection.JavaConverters._ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.conf.HiveConf.ConfVars -import org.apache.hadoop.hive.ql.exec.FunctionRegistry import org.apache.hadoop.hive.ql.parse.EximUtil import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe -import org.apache.spark.sql.catalyst.catalog.{CatalogColumn, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator +import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ @@ -277,19 +277,6 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { CreateView(tableDesc, plan(query), allowExist, replace, command(ctx)) } - /** - * Create a [[Generator]]. Override this method in order to support custom Generators. - */ - override protected def withGenerator( - name: String, - expressions: Seq[Expression], - ctx: LateralViewContext): Generator = { - val info = Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse { - throw new ParseException(s"Couldn't find Generator function '$name'", ctx) - } - HiveGenericUDTF(name, new HiveFunctionWrapper(info.getFunctionClass.getName), expressions) - } - /** * Create a [[HiveScriptIOSchema]]. */ diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 5ada3d5598..784b018353 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.hive import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer -import scala.util.Try import org.apache.hadoop.hive.ql.exec._ import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} @@ -31,130 +30,14 @@ import org.apache.hadoop.hive.serde2.objectinspector.{ConstantObjectInspector, O import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions import org.apache.spark.internal.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.{analysis, InternalRow} -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.hive.client.HiveClientImpl import org.apache.spark.sql.types._ -private[hive] class HiveFunctionRegistry( - underlying: analysis.FunctionRegistry, - executionHive: HiveClientImpl) - extends analysis.FunctionRegistry with HiveInspectors { - - def getFunctionInfo(name: String): FunctionInfo = { - // Hive Registry need current database to lookup function - // TODO: the current database of executionHive should be consistent with metadataHive - executionHive.withHiveState { - FunctionRegistry.getFunctionInfo(name) - } - } - - override def lookupFunction(name: String, children: Seq[Expression]): 
Expression = { - Try(underlying.lookupFunction(name, children)).getOrElse { - // We only look it up to see if it exists, but do not include it in the HiveUDF since it is - // not always serializable. - val functionInfo: FunctionInfo = - Option(getFunctionInfo(name.toLowerCase)).getOrElse( - throw new AnalysisException(s"undefined function $name")) - - val functionClassName = functionInfo.getFunctionClass.getName - - // When we instantiate hive UDF wrapper class, we may throw exception if the input expressions - // don't satisfy the hive UDF, such as type mismatch, input number mismatch, etc. Here we - // catch the exception and throw AnalysisException instead. - try { - if (classOf[GenericUDFMacro].isAssignableFrom(functionInfo.getFunctionClass)) { - val udf = HiveGenericUDF( - name, new HiveFunctionWrapper(functionClassName, functionInfo.getGenericUDF), children) - udf.dataType // Force it to check input data types. - udf - } else if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udf = HiveSimpleUDF(name, new HiveFunctionWrapper(functionClassName), children) - udf.dataType // Force it to check input data types. - udf - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udf = HiveGenericUDF(name, new HiveFunctionWrapper(functionClassName), children) - udf.dataType // Force it to check input data types. - udf - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - val udaf = HiveUDAFFunction(name, new HiveFunctionWrapper(functionClassName), children) - udaf.dataType // Force it to check input data types. - udaf - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udaf = HiveUDAFFunction( - name, new HiveFunctionWrapper(functionClassName), children, isUDAFBridgeRequired = true) - udaf.dataType // Force it to check input data types. - udaf - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - val udtf = HiveGenericUDTF(name, new HiveFunctionWrapper(functionClassName), children) - udtf.elementTypes // Force it to check input data types. - udtf - } else { - throw new AnalysisException(s"No handler for udf ${functionInfo.getFunctionClass}") - } - } catch { - case analysisException: AnalysisException => - // If the exception is an AnalysisException, just throw it. - throw analysisException - case throwable: Throwable => - // If there is any other error, we throw an AnalysisException. - val errorMessage = s"No handler for Hive udf ${functionInfo.getFunctionClass} " + - s"because: ${throwable.getMessage}." - val analysisException = new AnalysisException(errorMessage) - analysisException.setStackTrace(throwable.getStackTrace) - throw analysisException - } - } - } - - override def registerFunction(name: String, info: ExpressionInfo, builder: FunctionBuilder) - : Unit = underlying.registerFunction(name, info, builder) - - /* List all of the registered function names. */ - override def listFunction(): Seq[String] = { - (FunctionRegistry.getFunctionNames.asScala ++ underlying.listFunction()).toList.sorted - } - - /* Get the class of the registered function by specified name. 
*/ - override def lookupFunction(name: String): Option[ExpressionInfo] = { - underlying.lookupFunction(name).orElse( - Try { - val info = getFunctionInfo(name) - val annotation = info.getFunctionClass.getAnnotation(classOf[Description]) - if (annotation != null) { - Some(new ExpressionInfo( - info.getFunctionClass.getCanonicalName, - annotation.name(), - annotation.value(), - annotation.extended())) - } else { - Some(new ExpressionInfo( - info.getFunctionClass.getCanonicalName, - name, - null, - null)) - } - }.getOrElse(None)) - } - - override def lookupFunctionBuilder(name: String): Option[FunctionBuilder] = { - underlying.lookupFunctionBuilder(name) - } - - // Note: This does not drop functions stored in the metastore - override def dropFunction(name: String): Boolean = { - underlying.dropFunction(name) - } - -} - private[hive] case class HiveSimpleUDF( name: String, funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) extends Expression with HiveInspectors with CodegenFallback with Logging { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index 9393302355..7f6ca21782 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala @@ -201,8 +201,13 @@ class TestHiveContext private[hive]( } override lazy val functionRegistry = { - new TestHiveFunctionRegistry( - org.apache.spark.sql.catalyst.analysis.FunctionRegistry.builtin.copy(), self.executionHive) + // We use TestHiveFunctionRegistry at here to track functions that have been explicitly + // unregistered (through TestHiveFunctionRegistry.unregisterFunction method). + val fr = new TestHiveFunctionRegistry + org.apache.spark.sql.catalyst.analysis.FunctionRegistry.expressions.foreach { + case (name, (info, builder)) => fr.registerFunction(name, info, builder) + } + fr } } @@ -528,19 +533,18 @@ class TestHiveContext private[hive]( } -private[hive] class TestHiveFunctionRegistry(fr: SimpleFunctionRegistry, client: HiveClientImpl) - extends HiveFunctionRegistry(fr, client) { +private[hive] class TestHiveFunctionRegistry extends SimpleFunctionRegistry { private val removedFunctions = collection.mutable.ArrayBuffer.empty[(String, (ExpressionInfo, FunctionBuilder))] def unregisterFunction(name: String): Unit = { - fr.functionBuilders.remove(name).foreach(f => removedFunctions += name -> f) + functionBuilders.remove(name).foreach(f => removedFunctions += name -> f) } def restore(): Unit = { removedFunctions.foreach { - case (name, (info, builder)) => fr.registerFunction(name, info, builder) + case (name, (info, builder)) => registerFunction(name, info, builder) } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index 53dec6348f..dd2129375d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -32,6 +32,8 @@ import org.scalatest.time.SpanSugar._ import org.apache.spark._ import org.apache.spark.internal.Logging import org.apache.spark.sql.{QueryTest, Row, SQLContext} +import org.apache.spark.sql.catalyst.catalog.CatalogFunction +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.hive.test.{TestHive, TestHiveContext} 
import org.apache.spark.sql.test.ProcessTestUtils.ProcessOutputCapturer @@ -55,6 +57,57 @@ class HiveSparkSubmitSuite System.setProperty("spark.testing", "true") } + test("temporary Hive UDF: define a UDF and use it") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") + val args = Seq( + "--class", TemporaryHiveUDFTest.getClass.getName.stripSuffix("$"), + "--name", "TemporaryHiveUDFTest", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("permanent Hive UDF: define a UDF and use it") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") + val args = Seq( + "--class", PermanentHiveUDFTest1.getClass.getName.stripSuffix("$"), + "--name", "PermanentHiveUDFTest1", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + + test("permanent Hive UDF: use a already defined permanent function") { + val unusedJar = TestUtils.createJarWithClasses(Seq.empty) + val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) + val jar2 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassB")) + val jarsString = Seq(jar1, jar2).map(j => j.toString).mkString(",") + val args = Seq( + "--class", PermanentHiveUDFTest2.getClass.getName.stripSuffix("$"), + "--name", "PermanentHiveUDFTest2", + "--master", "local-cluster[2,1,1024]", + "--conf", "spark.ui.enabled=false", + "--conf", "spark.master.rest.enabled=false", + "--driver-java-options", "-Dderby.system.durability=test", + "--jars", jarsString, + unusedJar.toString, "SparkSubmitClassA", "SparkSubmitClassB") + runSparkSubmit(args) + } + test("SPARK-8368: includes jars passed in through --jars") { val unusedJar = TestUtils.createJarWithClasses(Seq.empty) val jar1 = TestUtils.createJarWithClasses(Seq("SparkSubmitClassA")) @@ -208,6 +261,118 @@ class HiveSparkSubmitSuite } } +// This application is used to test defining a new Hive UDF (with an associated jar) +// and use this UDF. We need to run this test in separate JVM to make sure we +// can load the jar defined with the function. +object TemporaryHiveUDFTest extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + + // Load a Hive UDF from the jar. 
+ logInfo("Registering a temporary Hive UDF provided in a jar.") + val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + hiveContext.sql( + s""" + |CREATE TEMPORARY FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + |USING JAR '$jar' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Actually use the loaded UDF. + logInfo("Using the UDF.") + val result = hiveContext.sql( + "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = result.orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"Result table should have 10 rows instead of $count rows") + } + hiveContext.sql("DROP temporary FUNCTION example_max") + logInfo("Test finishes.") + sc.stop() + } +} + +// This application is used to test defining a new Hive UDF (with an associated jar) +// and use this UDF. We need to run this test in separate JVM to make sure we +// can load the jar defined with the function. +object PermanentHiveUDFTest1 extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + + // Load a Hive UDF from the jar. + logInfo("Registering a permanent Hive UDF provided in a jar.") + val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + hiveContext.sql( + s""" + |CREATE FUNCTION example_max + |AS 'org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax' + |USING JAR '$jar' + """.stripMargin) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Actually use the loaded UDF. + logInfo("Using the UDF.") + val result = hiveContext.sql( + "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = result.orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"Result table should have 10 rows instead of $count rows") + } + hiveContext.sql("DROP FUNCTION example_max") + logInfo("Test finishes.") + sc.stop() + } +} + +// This application is used to test that a pre-defined permanent function with a jar +// resources can be used. We need to run this test in separate JVM to make sure we +// can load the jar defined with the function. +object PermanentHiveUDFTest2 extends Logging { + def main(args: Array[String]) { + Utils.configTestLog4j("INFO") + val conf = new SparkConf() + conf.set("spark.ui.enabled", "false") + val sc = new SparkContext(conf) + val hiveContext = new TestHiveContext(sc) + // Load a Hive UDF from the jar. + logInfo("Write the metadata of a permanent Hive UDF into metastore.") + val jar = hiveContext.getHiveFile("hive-contrib-0.13.1.jar").getCanonicalPath + val function = CatalogFunction( + FunctionIdentifier("example_max"), + "org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax", + ("JAR" -> jar) :: Nil) + hiveContext.sessionState.catalog.createFunction(function) + val source = + hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") + source.registerTempTable("sourceTable") + // Actually use the loaded UDF. 
+ logInfo("Using the UDF.") + val result = hiveContext.sql( + "SELECT example_max(key) as key, val FROM sourceTable GROUP BY val") + logInfo("Running a simple query on the table.") + val count = result.orderBy("key", "val").count() + if (count != 10) { + throw new Exception(s"Result table should have 10 rows instead of $count rows") + } + hiveContext.sql("DROP FUNCTION example_max") + logInfo("Test finishes.") + sc.stop() + } +} + // This object is used for testing SPARK-8368: https://issues.apache.org/jira/browse/SPARK-8368. // We test if we can load user jars in both driver and executors when HiveContext is used. object SparkSubmitClassLoaderTest extends Logging { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala index 3ab4576811..d1aa5aa931 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala @@ -17,12 +17,51 @@ package org.apache.spark.sql.hive -import org.apache.spark.sql.QueryTest +import org.scalatest.BeforeAndAfterEach + +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.test.SQLTestUtils case class FunctionResult(f1: String, f2: String) -class UDFSuite extends QueryTest with TestHiveSingleton { +/** + * A test suite for UDF related functionalities. Because Hive metastore is + * case insensitive, database names and function names have both upper case + * letters and lower case letters. + */ +class UDFSuite + extends QueryTest + with SQLTestUtils + with TestHiveSingleton + with BeforeAndAfterEach { + + import hiveContext.implicits._ + + private[this] val functionName = "myUPper" + private[this] val functionNameUpper = "MYUPPER" + private[this] val functionNameLower = "myupper" + + private[this] val functionClass = + classOf[org.apache.hadoop.hive.ql.udf.generic.GenericUDFUpper].getCanonicalName + + private var testDF: DataFrame = null + private[this] val testTableName = "testDF_UDFSuite" + private var expectedDF: DataFrame = null + + override def beforeAll(): Unit = { + sql("USE default") + + testDF = (1 to 10).map(i => s"sTr$i").toDF("value") + testDF.registerTempTable(testTableName) + expectedDF = (1 to 10).map(i => s"STR$i").toDF("value") + super.beforeAll() + } + + override def afterEach(): Unit = { + sql("USE default") + super.afterEach() + } test("UDF case insensitive") { hiveContext.udf.register("random0", () => { Math.random() }) @@ -32,4 +71,128 @@ class UDFSuite extends QueryTest with TestHiveSingleton { assert(hiveContext.sql("SELECT RANDOm1() FROM src LIMIT 1").head().getDouble(0) >= 0.0) assert(hiveContext.sql("SELECT strlenscala('test', 1) FROM src LIMIT 1").head().getInt(0) === 5) } + + test("temporary function: create and drop") { + withUserDefinedFunction(functionName -> true) { + intercept[AnalysisException] { + sql(s"CREATE TEMPORARY FUNCTION default.$functionName AS '$functionClass'") + } + sql(s"CREATE TEMPORARY FUNCTION $functionName AS '$functionClass'") + checkAnswer( + sql(s"SELECT $functionNameLower(value) from $testTableName"), + expectedDF + ) + intercept[AnalysisException] { + sql(s"DROP TEMPORARY FUNCTION default.$functionName") + } + } + } + + test("permanent function: create and drop without specifying db name") { + withUserDefinedFunction(functionName -> false) { + sql(s"CREATE FUNCTION $functionName AS '$functionClass'") + checkAnswer( + sql("SHOW 
functions like '.*upper'"), + Row(s"default.$functionNameLower") + ) + checkAnswer( + sql(s"SELECT $functionName(value) from $testTableName"), + expectedDF + ) + assert( + sql("SHOW functions").collect() + .map(_.getString(0)) + .contains(s"default.$functionNameLower")) + } + } + + test("permanent function: create and drop with a db name") { + // For this block, drop function command uses functionName as the function name. + withUserDefinedFunction(functionNameUpper -> false) { + sql(s"CREATE FUNCTION default.$functionName AS '$functionClass'") + // TODO: Re-enable it after can distinguish qualified and unqualified function name + // in SessionCatalog.lookupFunction. + // checkAnswer( + // sql(s"SELECT default.myuPPer(value) from $testTableName"), + // expectedDF + // ) + checkAnswer( + sql(s"SELECT $functionName(value) from $testTableName"), + expectedDF + ) + checkAnswer( + sql(s"SELECT default.$functionName(value) from $testTableName"), + expectedDF + ) + } + + // For this block, drop function command uses default.functionName as the function name. + withUserDefinedFunction(s"DEfault.$functionNameLower" -> false) { + sql(s"CREATE FUNCTION dEFault.$functionName AS '$functionClass'") + checkAnswer( + sql(s"SELECT $functionNameUpper(value) from $testTableName"), + expectedDF + ) + } + } + + test("permanent function: create and drop a function in another db") { + // For this block, drop function command uses functionName as the function name. + withTempDatabase { dbName => + withUserDefinedFunction(functionName -> false) { + sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") + // TODO: Re-enable it after can distinguish qualified and unqualified function name + // checkAnswer( + // sql(s"SELECT $dbName.myuPPer(value) from $testTableName"), + // expectedDF + // ) + + checkAnswer( + sql(s"SHOW FUNCTIONS like $dbName.$functionNameUpper"), + Row(s"$dbName.$functionNameLower") + ) + + sql(s"USE $dbName") + + checkAnswer( + sql(s"SELECT $functionName(value) from $testTableName"), + expectedDF + ) + + sql(s"USE default") + + checkAnswer( + sql(s"SELECT $dbName.$functionName(value) from $testTableName"), + expectedDF + ) + + sql(s"USE $dbName") + } + + sql(s"USE default") + + // For this block, drop function command uses default.functionName as the function name. 
+ withUserDefinedFunction(s"$dbName.$functionNameUpper" -> false) { + sql(s"CREATE FUNCTION $dbName.$functionName AS '$functionClass'") + // TODO: Re-enable it after can distinguish qualified and unqualified function name + // checkAnswer( + // sql(s"SELECT $dbName.myupper(value) from $testTableName"), + // expectedDF + // ) + + sql(s"USE $dbName") + + assert( + sql("SHOW functions").collect() + .map(_.getString(0)) + .contains(s"$dbName.$functionNameLower")) + checkAnswer( + sql(s"SELECT $functionNameLower(value) from $testTableName"), + expectedDF + ) + + sql(s"USE default") + } + } + } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index b951948fda..0c57ede9ed 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -62,7 +62,7 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { TestHive.cacheTables = false TimeZone.setDefault(originalTimeZone) Locale.setDefault(originalLocale) - sql("DROP TEMPORARY FUNCTION udtf_count2") + sql("DROP TEMPORARY FUNCTION IF EXISTS udtf_count2") } finally { super.afterAll() } @@ -1230,14 +1230,16 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter { val e = intercept[AnalysisException] { range(1).selectExpr("not_a_udf()") } - assert(e.getMessage.contains("undefined function not_a_udf")) + assert(e.getMessage.contains("Undefined function")) + assert(e.getMessage.contains("not_a_udf")) var success = false val t = new Thread("test") { override def run(): Unit = { val e = intercept[AnalysisException] { range(1).selectExpr("not_a_udf()") } - assert(e.getMessage.contains("undefined function not_a_udf")) + assert(e.getMessage.contains("Undefined function")) + assert(e.getMessage.contains("not_a_udf")) success = true } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index b0e263dff9..d07ac56586 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -303,7 +303,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDFTwoListList() FROM testUDF") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") } @@ -313,7 +313,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDFAnd() FROM testUDF") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFAnd") } @@ -323,7 +323,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDAFPercentile(a) FROM testUDF GROUP BY b") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFPercentile") } @@ -333,7 +333,7 @@ class HiveUDFSuite extends QueryTest with 
TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDAFAverage() FROM testUDF GROUP BY b") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDAFAverage") } @@ -343,7 +343,7 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { val message = intercept[AnalysisException] { sql("SELECT testUDTFExplode() FROM testUDF") }.getMessage - assert(message.contains("No handler for Hive udf")) + assert(message.contains("No handler for Hive UDF")) sql("DROP TEMPORARY FUNCTION IF EXISTS testUDTFExplode") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 6199253d34..14a1d4cd30 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, FunctionRegistry} +import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.functions._ import org.apache.spark.sql.hive.{HiveContext, MetastoreRelation} @@ -67,22 +68,43 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { import hiveContext.implicits._ test("UDTF") { - sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") - // The function source code can be found at: - // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF - sql( - """ - |CREATE TEMPORARY FUNCTION udtf_count2 - |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' - """.stripMargin) + withUserDefinedFunction("udtf_count2" -> true) { + sql(s"ADD JAR ${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}") + // The function source code can be found at: + // https://cwiki.apache.org/confluence/display/Hive/DeveloperGuide+UDTF + sql( + """ + |CREATE TEMPORARY FUNCTION udtf_count2 + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + """.stripMargin) - checkAnswer( - sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), - Row(97, 500) :: Row(97, 500) :: Nil) + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count2(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) - checkAnswer( - sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), - Row(3) :: Row(3) :: Nil) + checkAnswer( + sql("SELECT udtf_count2(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } + } + + test("permanent UDTF") { + withUserDefinedFunction("udtf_count_temp" -> false) { + sql( + s""" + |CREATE FUNCTION udtf_count_temp + |AS 'org.apache.spark.sql.hive.execution.GenericUDTFCount2' + |USING JAR '${hiveContext.getHiveFile("TestUDTF.jar").getCanonicalPath()}' + """.stripMargin) + + checkAnswer( + sql("SELECT key, cc FROM src LATERAL VIEW udtf_count_temp(value) dd AS cc"), + Row(97, 500) :: Row(97, 500) :: Nil) + + checkAnswer( + sql("SELECT udtf_count_temp(a) FROM (SELECT 1 AS a FROM src LIMIT 3) t"), + Row(3) :: Row(3) :: Nil) + } } test("SPARK-6835: udtf in lateral view") { @@ -169,9 +191,7 @@ class SQLQuerySuite extends QueryTest 
with SQLTestUtils with TestHiveSingleton { } test("show functions") { - val allBuiltinFunctions = - (FunctionRegistry.builtin.listFunction().toSet[String] ++ - org.apache.hadoop.hive.ql.exec.FunctionRegistry.getFunctionNames.asScala).toList.sorted + val allBuiltinFunctions = FunctionRegistry.builtin.listFunction().toSet[String].toList.sorted // The TestContext is shared by all the test cases, some functions may be registered before // this, so we check that all the builtin functions are returned. val allFunctions = sql("SHOW functions").collect().map(r => r(0)) @@ -183,11 +203,16 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) - checkAnswer(sql("SHOW functions `~`"), Row("~")) + // TODO: Re-enable this test after we fix SPARK-14335. + // checkAnswer(sql("SHOW functions `~`"), Row("~")) checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) - checkAnswer(sql("SHOW functions `weekofyea.*`"), Row("weekofyear")) + checkAnswer(sql("SHOW functions `weekofyea*`"), Row("weekofyear")) // this probably will failed if we add more function with `sha` prefixing. - checkAnswer(sql("SHOW functions `sha.*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + checkAnswer(sql("SHOW functions `sha*`"), Row("sha") :: Row("sha1") :: Row("sha2") :: Nil) + // Test '|' for alternation. + checkAnswer( + sql("SHOW functions 'sha*|weekofyea*'"), + Row("sha") :: Row("sha1") :: Row("sha2") :: Row("weekofyear") :: Nil) } test("describe functions") { @@ -211,10 +236,11 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkExistence(sql("describe functioN abcadf"), true, "Function: abcadf not found.") - checkExistence(sql("describe functioN `~`"), true, - "Function: ~", - "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", - "Usage: ~ n - Bitwise not") + // TODO: Re-enable this test after we fix SPARK-14335. + // checkExistence(sql("describe functioN `~`"), true, + // "Function: ~", + // "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", + // "Usage: ~ n - Bitwise not") } test("SPARK-5371: union with null and sum") { -- cgit v1.2.3 From 9ee5c257176d5c7989031d260e74e3eca530c120 Mon Sep 17 00:00:00 2001 From: Burak Yavuz Date: Tue, 5 Apr 2016 13:18:39 -0700 Subject: [SPARK-14353] Dataset Time Window `window` API for Python, and SQL ## What changes were proposed in this pull request? The `window` function was added to Dataset with [this PR](https://github.com/apache/spark/pull/12008). This PR adds the Python, and SQL, API for this function. With this PR, SQL, Java, and Scala will share the same APIs as in users can use: - `window(timeColumn, windowDuration)` - `window(timeColumn, windowDuration, slideDuration)` - `window(timeColumn, windowDuration, slideDuration, startTime)` In Python, users can access all APIs above, but in addition they can do - In Python: `window(timeColumn, windowDuration, startTime=...)` that is, they can provide the startTime without providing the `slideDuration`. In this case, we will generate tumbling windows. ## How was this patch tested? Unit tests + manual tests Author: Burak Yavuz Closes #12136 from brkyvz/python-windows. 
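For reference, the new Python API can be exercised as follows. This is a minimal sketch adapted from the doctest this patch adds to `python/pyspark/sql/functions.py`; the sample data and the 5-second window are illustrative, and an existing `sqlContext` is assumed.

```python
# Sketch of the Python `window` API added by this patch; assumes a SQLContext
# named sqlContext is already available (e.g. in a PySpark shell).
from pyspark.sql import functions as F

df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val")

# No slideDuration is given, so these are tumbling 5-second windows.
w = df.groupBy(F.window("date", "5 seconds")).agg(F.sum("val").alias("sum"))

# The output column is a struct named 'window' with 'start' and 'end' fields.
w.select(w.window.start.cast("string").alias("start"),
         w.window.end.cast("string").alias("end"),
         "sum").collect()
# [Row(start=u'2016-03-11 09:00:05', end=u'2016-03-11 09:00:10', sum=1)]
```

As described above, passing `startTime` without a `slideDuration` in Python likewise produces tumbling windows offset by the given start time.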
--- python/pyspark/sql/functions.py | 49 +++++++++++++++++++ .../sql/catalyst/analysis/FunctionRegistry.scala | 5 +- .../sql/catalyst/expressions/TimeWindow.scala | 35 +++++++++++++ .../apache/spark/sql/catalyst/trees/TreeNode.scala | 27 +++++++--- .../sql/catalyst/expressions/TimeWindowSuite.scala | 37 +++++++++++++- .../scala/org/apache/spark/sql/functions.scala | 9 ++-- .../spark/sql/DataFrameTimeWindowingSuite.scala | 57 ++++++++++++++++++++++ 7 files changed, 204 insertions(+), 15 deletions(-) (limited to 'sql/catalyst') diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 3b20ba5177..5017ab5b36 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -1053,6 +1053,55 @@ def to_utc_timestamp(timestamp, tz): return Column(sc._jvm.functions.to_utc_timestamp(_to_java_column(timestamp), tz)) +@since(2.0) +@ignore_unicode_prefix +def window(timeColumn, windowDuration, slideDuration=None, startTime=None): + """Bucketize rows into one or more time windows given a timestamp specifying column. Window + starts are inclusive but the window ends are exclusive, e.g. 12:05 will be in the window + [12:05,12:10) but not in [12:00,12:05). Windows can support microsecond precision. Windows in + the order of months are not supported. + + The time column must be of TimestampType. + + Durations are provided as strings, e.g. '1 second', '1 day 12 hours', '2 minutes'. Valid + interval strings are 'week', 'day', 'hour', 'minute', 'second', 'millisecond', 'microsecond'. + If the `slideDuration` is not provided, the windows will be tumbling windows. + + The startTime is the offset with respect to 1970-01-01 00:00:00 UTC with which to start + window intervals. For example, in order to have hourly tumbling windows that start 15 minutes + past the hour, e.g. 12:15-13:15, 13:15-14:15... provide `startTime` as `15 minutes`. + + The output column will be a struct called 'window' by default with the nested columns 'start' + and 'end', where 'start' and 'end' will be of `TimestampType`. + + >>> df = sqlContext.createDataFrame([("2016-03-11 09:00:07", 1)]).toDF("date", "val") + >>> w = df.groupBy(window("date", "5 seconds")).agg(sum("val").alias("sum")) + >>> w.select(w.window.start.cast("string").alias("start"), + ... 
w.window.end.cast("string").alias("end"), "sum").collect() + [Row(start=u'2016-03-11 09:00:05', end=u'2016-03-11 09:00:10', sum=1)] + """ + def check_string_field(field, fieldName): + if not field or type(field) is not str: + raise TypeError("%s should be provided as a string" % fieldName) + + sc = SparkContext._active_spark_context + time_col = _to_java_column(timeColumn) + check_string_field(windowDuration, "windowDuration") + if slideDuration and startTime: + check_string_field(slideDuration, "slideDuration") + check_string_field(startTime, "startTime") + res = sc._jvm.functions.window(time_col, windowDuration, slideDuration, startTime) + elif slideDuration: + check_string_field(slideDuration, "slideDuration") + res = sc._jvm.functions.window(time_col, windowDuration, slideDuration) + elif startTime: + check_string_field(startTime, "startTime") + res = sc._jvm.functions.window(time_col, windowDuration, windowDuration, startTime) + else: + res = sc._jvm.functions.window(time_col, windowDuration) + return Column(res) + + # ---------------------------- misc functions ---------------------------------- @since(1.5) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 7af5ffbe47..1ebdf49348 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -364,7 +364,10 @@ object FunctionRegistry { } Try(f.newInstance(expressions : _*).asInstanceOf[Expression]) match { case Success(e) => e - case Failure(e) => throw new AnalysisException(e.getMessage) + case Failure(e) => + // the exception is an invocation exception. To get a meaningful message, we need the + // cause. 
+ throw new AnalysisException(e.getCause.getMessage) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 8e13833486..daf3de95dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.commons.lang.StringUtils +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.TypeCheckFailure import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} @@ -34,6 +35,28 @@ case class TimeWindow( with Unevaluable with NonSQLExpression { + ////////////////////////// + // SQL Constructors + ////////////////////////// + + def this( + timeColumn: Expression, + windowDuration: Expression, + slideDuration: Expression, + startTime: Expression) = { + this(timeColumn, TimeWindow.parseExpression(windowDuration), + TimeWindow.parseExpression(windowDuration), TimeWindow.parseExpression(startTime)) + } + + def this(timeColumn: Expression, windowDuration: Expression, slideDuration: Expression) = { + this(timeColumn, TimeWindow.parseExpression(windowDuration), + TimeWindow.parseExpression(windowDuration), 0) + } + + def this(timeColumn: Expression, windowDuration: Expression) = { + this(timeColumn, windowDuration, windowDuration) + } + override def child: Expression = timeColumn override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) override def dataType: DataType = new StructType() @@ -104,6 +127,18 @@ object TimeWindow { cal.microseconds } + /** + * Parses the duration expression to generate the long value for the original constructor so + * that we can use `window` in SQL. + */ + private def parseExpression(expr: Expression): Long = expr match { + case NonNullLiteral(s, StringType) => getIntervalInMicroSeconds(s.toString) + case IntegerLiteral(i) => i.toLong + case NonNullLiteral(l, LongType) => l.toString.toLong + case _ => throw new AnalysisException("The duration and time inputs to window must be " + + "an integer, long or string literal.") + } + def apply( timeColumn: Expression, windowDuration: String, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala index 6b7997e903..232ca43588 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala @@ -22,6 +22,7 @@ import java.util.UUID import scala.collection.Map import scala.collection.mutable.Stack +import org.apache.commons.lang.ClassUtils import org.json4s.JsonAST._ import org.json4s.JsonDSL._ import org.json4s.jackson.JsonMethods._ @@ -365,20 +366,32 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product { * @param newArgs the new product arguments. */ def makeCopy(newArgs: Array[AnyRef]): BaseType = attachTree(this, "makeCopy") { + // Skip no-arg constructors that are just there for kryo. 
val ctors = getClass.getConstructors.filter(_.getParameterTypes.size != 0) if (ctors.isEmpty) { sys.error(s"No valid constructor for $nodeName") } - val defaultCtor = ctors.maxBy(_.getParameterTypes.size) + val allArgs: Array[AnyRef] = if (otherCopyArgs.isEmpty) { + newArgs + } else { + newArgs ++ otherCopyArgs + } + val defaultCtor = ctors.find { ctor => + if (ctor.getParameterTypes.length != allArgs.length) { + false + } else if (allArgs.contains(null)) { + // if there is a `null`, we can't figure out the class, therefore we should just fallback + // to older heuristic + false + } else { + val argsArray: Array[Class[_]] = allArgs.map(_.getClass) + ClassUtils.isAssignable(argsArray, ctor.getParameterTypes, true /* autoboxing */) + } + }.getOrElse(ctors.maxBy(_.getParameterTypes.length)) // fall back to older heuristic try { CurrentOrigin.withOrigin(origin) { - // Skip no-arg constructors that are just there for kryo. - if (otherCopyArgs.isEmpty) { - defaultCtor.newInstance(newArgs: _*).asInstanceOf[BaseType] - } else { - defaultCtor.newInstance((newArgs ++ otherCopyArgs).toArray: _*).asInstanceOf[BaseType] - } + defaultCtor.newInstance(allArgs.toArray: _*).asInstanceOf[BaseType] } } catch { case e: java.lang.IllegalArgumentException => diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala index 71f969aee2..b82cf8d169 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/TimeWindowSuite.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.catalyst.expressions +import org.scalatest.PrivateMethodTester + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.types.LongType -class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper { +class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper with PrivateMethodTester { test("time window is unevaluable") { intercept[UnsupportedOperationException] { @@ -73,4 +76,36 @@ class TimeWindowSuite extends SparkFunSuite with ExpressionEvalHelper { === seconds) } } + + private val parseExpression = PrivateMethod[Long]('parseExpression) + + test("parse sql expression for duration in microseconds - string") { + val dur = TimeWindow.invokePrivate(parseExpression(Literal("5 seconds"))) + assert(dur.isInstanceOf[Long]) + assert(dur === 5000000) + } + + test("parse sql expression for duration in microseconds - integer") { + val dur = TimeWindow.invokePrivate(parseExpression(Literal(100))) + assert(dur.isInstanceOf[Long]) + assert(dur === 100) + } + + test("parse sql expression for duration in microseconds - long") { + val dur = TimeWindow.invokePrivate(parseExpression(Literal.create(2 << 52, LongType))) + assert(dur.isInstanceOf[Long]) + assert(dur === (2 << 52)) + } + + test("parse sql expression for duration in microseconds - invalid interval") { + intercept[IllegalArgumentException] { + TimeWindow.invokePrivate(parseExpression(Literal("2 apples"))) + } + } + + test("parse sql expression for duration in microseconds - invalid expression") { + intercept[AnalysisException] { + TimeWindow.invokePrivate(parseExpression(Rand(123))) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index da58ba2add..5bc0034cb0 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -2574,8 +2574,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time can be as TimestampType or LongType, however when using LongType, - * the time must be given in seconds. + * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for * valid duration identifiers. @@ -2629,8 +2628,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time can be as TimestampType or LongType, however when using LongType, - * the time must be given in seconds. + * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for * valid duration identifiers. @@ -2672,8 +2670,7 @@ object functions { * processing time. * * @param timeColumn The column or the expression to use as the timestamp for windowing by time. - * The time can be as TimestampType or LongType, however when using LongType, - * the time must be given in seconds. + * The time column must be of TimestampType. * @param windowDuration A string specifying the width of the window, e.g. `10 minutes`, * `1 second`. Check [[org.apache.spark.unsafe.types.CalendarInterval]] for * valid duration identifiers. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala index e8103a31d5..06584ec21e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTimeWindowingSuite.scala @@ -239,4 +239,61 @@ class DataFrameTimeWindowingSuite extends QueryTest with SharedSQLContext with B Row("2016-03-27 09:00:00.68", "2016-03-27 09:00:00.88", 1)) ) } + + private def withTempTable(f: String => Unit): Unit = { + val tableName = "temp" + Seq( + ("2016-03-27 19:39:34", 1), + ("2016-03-27 19:39:56", 2), + ("2016-03-27 19:39:27", 4)).toDF("time", "value").registerTempTable(tableName) + try { + f(tableName) + } finally { + sqlContext.dropTempTable(tableName) + } + } + + test("time window in SQL with single string expression") { + withTempTable { table => + checkAnswer( + sqlContext.sql(s"""select window(time, "10 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2) + ) + ) + } + } + + test("time window in SQL with with two expressions") { + withTempTable { table => + checkAnswer( + sqlContext.sql( + s"""select window(time, "10 seconds", 10000000), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:20", "2016-03-27 19:39:30", 4), + Row("2016-03-27 19:39:30", "2016-03-27 19:39:40", 1), + Row("2016-03-27 19:39:50", "2016-03-27 19:40:00", 2) + ) + ) + } + } + + test("time window in SQL with with three expressions") { + withTempTable { table => + 
checkAnswer( + sqlContext.sql( + s"""select window(time, "10 seconds", 10000000, "5 seconds"), value from $table""") + .select($"window.start".cast(StringType), $"window.end".cast(StringType), $"value"), + Seq( + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 1), + Row("2016-03-27 19:39:25", "2016-03-27 19:39:35", 4), + Row("2016-03-27 19:39:55", "2016-03-27 19:40:05", 2) + ) + ) + } + } } -- cgit v1.2.3 From c59abad052b7beec4ef550049413e95578e545be Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 5 Apr 2016 13:31:00 -0700 Subject: [SPARK-14402][SQL] initcap UDF doesn't match Hive/Oracle behavior in lowercasing rest of string ## What changes were proposed in this pull request? Current, SparkSQL `initCap` is using `toTitleCase` function. However, `UTF8String.toTitleCase` implementation changes only the first letter and just copy the other letters: e.g. sParK --> SParK. This is the correct implementation `toTitleCase`. ``` hive> select initcap('sParK'); Spark ``` ``` scala> sql("select initcap('sParK')").head res0: org.apache.spark.sql.Row = [SParK] ``` This PR updates the implementation of `initcap` using `toLowerCase` and `toTitleCase`. ## How was this patch tested? Pass the Jenkins tests (including new testcase). Author: Dongjoon Hyun Closes #12175 from dongjoon-hyun/SPARK-14402. --- .../spark/sql/catalyst/expressions/stringExpressions.scala | 11 ++++++++--- .../sql/catalyst/expressions/StringExpressionsSuite.scala | 1 + .../scala/org/apache/spark/sql/StringFunctionsSuite.scala | 6 +++--- 3 files changed, 12 insertions(+), 6 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 3ee19cc4ad..b6ea03cd5c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -618,19 +618,24 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } /** - * Returns string, with the first letter of each word in uppercase. + * Returns string, with the first letter of each word in uppercase, all other letters in lowercase. * Words are delimited by whitespace. */ +@ExpressionDescription( + usage = "_FUNC_(str) - " + + "Returns str, with the first letter of each word in uppercase, all other letters in " + + "lowercase. 
Words are delimited by white space.", + extended = "> SELECT initcap('sPark sql');\n 'Spark Sql'") case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(StringType) override def dataType: DataType = StringType override def nullSafeEval(string: Any): Any = { - string.asInstanceOf[UTF8String].toTitleCase + string.asInstanceOf[UTF8String].toLowerCase.toTitleCase } override def genCode(ctx: CodegenContext, ev: ExprCode): String = { - defineCodeGen(ctx, ev, str => s"$str.toTitleCase()") + defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()") } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 99e3b13ce8..2cf8ca7000 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -382,6 +382,7 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(InitCap(Literal("a b")), "A B") checkEvaluation(InitCap(Literal(" a")), " A") checkEvaluation(InitCap(Literal("the test")), "The Test") + checkEvaluation(InitCap(Literal("sParK")), "Spark") // scalastyle:off // non ascii characters are not allowed in the code, so we disable the scalastyle here. checkEvaluation(InitCap(Literal("世界")), "世界") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index e2090b0a83..6809f26968 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -272,12 +272,12 @@ class StringFunctionsSuite extends QueryTest with SharedSQLContext { } test("initcap function") { - val df = Seq(("ab", "a B")).toDF("l", "r") + val df = Seq(("ab", "a B", "sParK")).toDF("x", "y", "z") checkAnswer( - df.select(initcap($"l"), initcap($"r")), Row("Ab", "A B")) + df.select(initcap($"x"), initcap($"y"), initcap($"z")), Row("Ab", "A B", "Spark")) checkAnswer( - df.selectExpr("InitCap(l)", "InitCap(r)"), Row("Ab", "A B")) + df.selectExpr("InitCap(x)", "InitCap(y)", "InitCap(z)"), Row("Ab", "A B", "Spark")) } test("number format function") { -- cgit v1.2.3 From 45d8cdee3945bf94d0f1bd93a12e4cb0d416468e Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Tue, 5 Apr 2016 14:54:07 -0700 Subject: [SPARK-14129][SPARK-14128][SQL] Alter table DDL commands ## What changes were proposed in this pull request? In Spark 2.0, we want to handle the most common `ALTER TABLE` commands ourselves instead of passing the entire query text to Hive. This is done using the new `SessionCatalog` API introduced recently. The commands supported in this patch include: ``` ALTER TABLE ... RENAME TO ... ALTER TABLE ... SET TBLPROPERTIES ... ALTER TABLE ... UNSET TBLPROPERTIES ... ALTER TABLE ... SET LOCATION ... ALTER TABLE ... SET SERDE ... ``` The commands we explicitly do not support are: ``` ALTER TABLE ... CLUSTERED BY ... ALTER TABLE ... SKEWED BY ... ALTER TABLE ... NOT CLUSTERED ALTER TABLE ... NOT SORTED ALTER TABLE ... NOT SKEWED ALTER TABLE ... NOT STORED AS DIRECTORIES ``` For these we throw exceptions complaining that they are not supported. ## How was this patch tested? 
`DDLSuite` Author: Andrew Or Closes #12121 from andrewor14/alter-table-ddl. --- .../sql/catalyst/analysis/FunctionRegistry.scala | 12 + .../sql/catalyst/catalog/SessionCatalog.scala | 42 +++ .../spark/sql/catalyst/util/StringKeyHashMap.scala | 2 + .../spark/sql/execution/SparkSqlParser.scala | 121 ++------ .../apache/spark/sql/execution/command/ddl.scala | 216 +++++++++++--- .../sql/execution/command/DDLCommandSuite.scala | 127 +------- .../spark/sql/execution/command/DDLSuite.scala | 330 ++++++++++++++++++--- .../hive/execution/HiveCompatibilitySuite.scala | 10 +- .../apache/spark/sql/hive/HiveSessionCatalog.scala | 2 +- 9 files changed, 562 insertions(+), 300 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1ebdf49348..f239b33e44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -54,6 +54,10 @@ trait FunctionRegistry { /** Checks if a function with a given name exists. */ def functionExists(name: String): Boolean = lookupFunction(name).isDefined + + /** Clear all registered functions. */ + def clear(): Unit + } class SimpleFunctionRegistry extends FunctionRegistry { @@ -93,6 +97,10 @@ class SimpleFunctionRegistry extends FunctionRegistry { functionBuilders.remove(name).isDefined } + override def clear(): Unit = { + functionBuilders.clear() + } + def copy(): SimpleFunctionRegistry = synchronized { val registry = new SimpleFunctionRegistry functionBuilders.iterator.foreach { case (name, (info, builder)) => @@ -132,6 +140,10 @@ object EmptyFunctionRegistry extends FunctionRegistry { throw new UnsupportedOperationException } + override def clear(): Unit = { + throw new UnsupportedOperationException + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c08ffbb235..62a3b1c105 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -304,11 +304,18 @@ class SessionCatalog( dbTables ++ _tempTables } + // TODO: It's strange that we have both refresh and invalidate here. + /** * Refresh the cache entry for a metastore table, if any. */ def refreshTable(name: TableIdentifier): Unit = { /* no-op */ } + /** + * Invalidate the cache entry for a metastore table, if any. + */ + def invalidateTable(name: TableIdentifier): Unit = { /* no-op */ } + /** * Drop all existing temporary tables. * For testing only. @@ -595,6 +602,11 @@ class SessionCatalog( } } + /** + * List all functions in the specified database, including temporary functions. + */ + def listFunctions(db: String): Seq[FunctionIdentifier] = listFunctions(db, "*") + /** * List all matching functions in the specified database, including temporary functions. */ @@ -609,4 +621,34 @@ class SessionCatalog( // So, the returned list may have two entries for the same function. 
dbFunctions ++ loadedFunctions } + + + // ----------------- + // | Other methods | + // ----------------- + + /** + * Drop all existing databases (except "default") along with all associated tables, + * partitions and functions, and set the current database to "default". + * + * This is mainly used for tests. + */ + private[sql] def reset(): Unit = { + val default = "default" + listDatabases().filter(_ != default).foreach { db => + dropDatabase(db, ignoreIfNotExists = false, cascade = true) + } + tempTables.clear() + functionRegistry.clear() + // restore built-in functions + FunctionRegistry.builtin.listFunction().foreach { f => + val expressionInfo = FunctionRegistry.builtin.lookupFunction(f) + val functionBuilder = FunctionRegistry.builtin.lookupFunctionBuilder(f) + require(expressionInfo.isDefined, s"built-in function '$f' is missing expression info") + require(functionBuilder.isDefined, s"built-in function '$f' is missing function builder") + functionRegistry.registerFunction(f, expressionInfo.get, functionBuilder.get) + } + setCurrentDatabase(default) + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala index 191d5e6399..d5d151a580 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringKeyHashMap.scala @@ -41,4 +41,6 @@ class StringKeyHashMap[T](normalizer: (String) => String) { def remove(key: String): Option[T] = base.remove(normalizer(key)) def iterator: Iterator[(String, T)] = base.toIterator + + def clear(): Unit = base.clear() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 382cc61fac..d3086fc91e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution import scala.collection.JavaConverters._ -import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.{AbstractSqlParser, AstBuilder, ParseException} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ @@ -378,8 +378,7 @@ class SparkSqlAstBuilder extends AstBuilder { override def visitRenameTable(ctx: RenameTableContext): LogicalPlan = withOrigin(ctx) { AlterTableRename( visitTableIdentifier(ctx.from), - visitTableIdentifier(ctx.to))( - command(ctx)) + visitTableIdentifier(ctx.to)) } /** @@ -395,8 +394,7 @@ class SparkSqlAstBuilder extends AstBuilder { ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableSetProperties( visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList))( - command(ctx)) + visitTablePropertyList(ctx.tablePropertyList)) } /** @@ -404,17 +402,16 @@ class SparkSqlAstBuilder extends AstBuilder { * * For example: * {{{ - * ALTER TABLE table UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); - * ALTER VIEW view UNSET TBLPROPERTIES IF EXISTS ('comment', 'key'); + * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); * }}} */ override def visitUnsetTableProperties( ctx: 
UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { AlterTableUnsetProperties( visitTableIdentifier(ctx.tableIdentifier), - visitTablePropertyList(ctx.tablePropertyList), - ctx.EXISTS != null)( - command(ctx)) + visitTablePropertyList(ctx.tablePropertyList).keys.toSeq, + ctx.EXISTS != null) } /** @@ -432,116 +429,41 @@ class SparkSqlAstBuilder extends AstBuilder { Option(ctx.STRING).map(string), Option(ctx.tablePropertyList).map(visitTablePropertyList), // TODO a partition spec is allowed to have optional values. This is currently violated. - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))( - command(ctx)) + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec)) } - /** - * Create an [[AlterTableStorageProperties]] command. - * - * For example: - * {{{ - * ALTER TABLE table CLUSTERED BY (col, ...) [SORTED BY (col, ...)] INTO n BUCKETS; - * }}} - */ + // TODO: don't even bother parsing alter table commands related to bucketing and skewing + override def visitBucketTable(ctx: BucketTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableStorageProperties( - visitTableIdentifier(ctx.tableIdentifier), - visitBucketSpec(ctx.bucketSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... CLUSTERED BY ... INTO N BUCKETS") } - /** - * Create an [[AlterTableNotClustered]] command. - * - * For example: - * {{{ - * ALTER TABLE table NOT CLUSTERED; - * }}} - */ override def visitUnclusterTable(ctx: UnclusterTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableNotClustered(visitTableIdentifier(ctx.tableIdentifier))(command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT CLUSTERED") } - /** - * Create an [[AlterTableNotSorted]] command. - * - * For example: - * {{{ - * ALTER TABLE table NOT SORTED; - * }}} - */ override def visitUnsortTable(ctx: UnsortTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableNotSorted(visitTableIdentifier(ctx.tableIdentifier))(command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SORTED") } - /** - * Create an [[AlterTableSkewed]] command. - * - * For example: - * {{{ - * ALTER TABLE table SKEWED BY (col1, col2) - * ON ((col1_value, col2_value) [, (col1_value, col2_value), ...]) - * [STORED AS DIRECTORIES]; - * }}} - */ override def visitSkewTable(ctx: SkewTableContext): LogicalPlan = withOrigin(ctx) { - val table = visitTableIdentifier(ctx.tableIdentifier) - val (cols, values, storedAsDirs) = visitSkewSpec(ctx.skewSpec) - AlterTableSkewed(table, cols, values, storedAsDirs)(command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... SKEWED BY ...") } - /** - * Create an [[AlterTableNotSorted]] command. - * - * For example: - * {{{ - * ALTER TABLE table NOT SKEWED; - * }}} - */ override def visitUnskewTable(ctx: UnskewTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableNotSkewed(visitTableIdentifier(ctx.tableIdentifier))(command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... NOT SKEWED") } - /** - * Create an [[AlterTableNotStoredAsDirs]] command. - * - * For example: - * {{{ - * ALTER TABLE table NOT STORED AS DIRECTORIES - * }}} - */ override def visitUnstoreTable(ctx: UnstoreTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableNotStoredAsDirs(visitTableIdentifier(ctx.tableIdentifier))(command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... NOT STORED AS DIRECTORIES") } - /** - * Create an [[AlterTableSkewedLocation]] command. 
- * - * For example: - * {{{ - * ALTER TABLE table SET SKEWED LOCATION (col1="loc1" [, (col2, col3)="loc2", ...] ); - * }}} - */ override def visitSetTableSkewLocations( ctx: SetTableSkewLocationsContext): LogicalPlan = withOrigin(ctx) { - val skewedMap = ctx.skewedLocationList.skewedLocation.asScala.flatMap { - slCtx => - val location = string(slCtx.STRING) - if (slCtx.constant != null) { - Seq(visitStringConstant(slCtx.constant) -> location) - } else { - // TODO this is similar to what was in the original implementation. However this does not - // make to much sense to me since we should be storing a tuple of values (not column - // names) for which we want a dedicated storage location. - visitConstantList(slCtx.constantList).map(_ -> location) - } - }.toMap - - AlterTableSkewedLocation( - visitTableIdentifier(ctx.tableIdentifier), - skewedMap)( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... SET SKEWED LOCATION ...") } /** @@ -703,8 +625,7 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableSetLocation( visitTableIdentifier(ctx.tableIdentifier), Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - visitLocationSpec(ctx.locationSpec))( - command(ctx)) + visitLocationSpec(ctx.locationSpec)) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 6896881910..0d38c41a3f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -18,12 +18,11 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogDatabase +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.types._ @@ -175,67 +174,133 @@ case class DescribeDatabase( } } -/** Rename in ALTER TABLE/VIEW: change the name of a table/view to a different name. */ +/** + * A command that renames a table/view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; + * }}} + */ case class AlterTableRename( oldName: TableIdentifier, - newName: TableIdentifier)(sql: String) - extends NativeDDLCommand(sql) with Logging + newName: TableIdentifier) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + catalog.invalidateTable(oldName) + catalog.renameTable(oldName, newName) + Seq.empty[Row] + } -/** Set Properties in ALTER TABLE/VIEW: add metadata to a table/view. */ +} + +/** + * A command that sets table/view properties. 
+ * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 SET TBLPROPERTIES ('key1' = 'val1', 'key2' = 'val2', ...); + * ALTER VIEW view1 SET TBLPROPERTIES ('key1' = 'val1', 'key2' = 'val2', ...); + * }}} + */ case class AlterTableSetProperties( tableName: TableIdentifier, - properties: Map[String, String])(sql: String) - extends NativeDDLCommand(sql) with Logging + properties: Map[String, String]) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTable(tableName) + val newProperties = table.properties ++ properties + if (DDLUtils.isDatasourceTable(newProperties)) { + throw new AnalysisException( + "alter table properties is not supported for tables defined using the datasource API") + } + val newTable = table.copy(properties = newProperties) + catalog.alterTable(newTable) + Seq.empty[Row] + } + +} -/** Unset Properties in ALTER TABLE/VIEW: remove metadata from a table/view. */ +/** + * A command that unsets table/view properties. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 UNSET TBLPROPERTIES [IF EXISTS] ('key1', 'key2', ...); + * ALTER VIEW view1 UNSET TBLPROPERTIES [IF EXISTS] ('key1', 'key2', ...); + * }}} + */ case class AlterTableUnsetProperties( tableName: TableIdentifier, - properties: Map[String, String], - ifExists: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + propKeys: Seq[String], + ifExists: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTable(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table properties is not supported for datasource tables") + } + if (!ifExists) { + propKeys.foreach { k => + if (!table.properties.contains(k)) { + throw new AnalysisException( + s"attempted to unset non-existent property '$k' in table '$tableName'") + } + } + } + val newProperties = table.properties.filter { case (k, _) => !propKeys.contains(k) } + val newTable = table.copy(properties = newProperties) + catalog.alterTable(newTable) + Seq.empty[Row] + } + +} +/** + * A command that sets the serde class and/or serde properties of a table/view. 
+ * + * The syntax of this command is: + * {{{ + * ALTER TABLE table [PARTITION spec] SET SERDE serde_name [WITH SERDEPROPERTIES props]; + * ALTER TABLE table [PARTITION spec] SET SERDEPROPERTIES serde_properties; + * }}} + */ case class AlterTableSerDeProperties( tableName: TableIdentifier, serdeClassName: Option[String], serdeProperties: Option[Map[String, String]], - partition: Option[Map[String, String]])(sql: String) - extends NativeDDLCommand(sql) with Logging - -case class AlterTableStorageProperties( - tableName: TableIdentifier, - buckets: BucketSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + partition: Option[Map[String, String]]) + extends RunnableCommand { -case class AlterTableNotClustered( - tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging + // should never happen if we parsed things correctly + require(serdeClassName.isDefined || serdeProperties.isDefined, + "alter table attempted to set neither serde class name nor serde properties") -case class AlterTableNotSorted( - tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTable(tableName) + // Do not support setting serde for datasource tables + if (serdeClassName.isDefined && DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table serde is not supported for datasource tables") + } + val newTable = table.withNewStorage( + serde = serdeClassName.orElse(table.storage.serde), + serdeProperties = table.storage.serdeProperties ++ serdeProperties.getOrElse(Map())) + catalog.alterTable(newTable) + Seq.empty[Row] + } -case class AlterTableSkewed( - tableName: TableIdentifier, - // e.g. (dt, country) - skewedCols: Seq[String], - // e.g. ('2008-08-08', 'us), ('2009-09-09', 'uk') - skewedValues: Seq[Seq[String]], - storedAsDirs: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging { - - require(skewedValues.forall(_.size == skewedCols.size), - "number of columns in skewed values do not match number of skewed columns provided") } -case class AlterTableNotSkewed( - tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging - -case class AlterTableNotStoredAsDirs( - tableName: TableIdentifier)(sql: String) extends NativeDDLCommand(sql) with Logging - -case class AlterTableSkewedLocation( - tableName: TableIdentifier, - skewedMap: Map[String, String])(sql: String) - extends NativeDDLCommand(sql) with Logging - /** * Add Partition in ALTER TABLE/VIEW: add the table/view partitions. * 'partitionSpecsAndLocs': the syntax of ALTER VIEW is identical to ALTER TABLE, @@ -292,11 +357,53 @@ case class AlterTableSetFileFormat( genericFormat: Option[String])(sql: String) extends NativeDDLCommand(sql) with Logging +/** + * A command that sets the location of a table or a partition. + * + * For normal tables, this just sets the location URI in the table/partition's storage format. + * For datasource tables, this sets a "path" parameter in the table/partition's serde properties. 
+ * + * The syntax of this command is: + * {{{ + * ALTER TABLE table_name [PARTITION partition_spec] SET LOCATION "loc"; + * }}} + */ case class AlterTableSetLocation( tableName: TableIdentifier, partitionSpec: Option[TablePartitionSpec], - location: String)(sql: String) - extends NativeDDLCommand(sql) with Logging + location: String) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTable(tableName) + partitionSpec match { + case Some(spec) => + // Partition spec is specified, so we set the location only for this partition + val part = catalog.getPartition(tableName, spec) + val newPart = + if (DDLUtils.isDatasourceTable(table)) { + part.copy(storage = part.storage.copy( + serdeProperties = part.storage.serdeProperties ++ Map("path" -> location))) + } else { + part.copy(storage = part.storage.copy(locationUri = Some(location))) + } + catalog.alterPartitions(tableName, Seq(newPart)) + case None => + // No partition spec is specified, so we set the location for the table itself + val newTable = + if (DDLUtils.isDatasourceTable(table)) { + table.withNewStorage( + serdeProperties = table.storage.serdeProperties ++ Map("path" -> location)) + } else { + table.withNewStorage(locationUri = Some(location)) + } + catalog.alterTable(newTable) + } + Seq.empty[Row] + } + +} case class AlterTableTouch( tableName: TableIdentifier, @@ -341,3 +448,16 @@ case class AlterTableReplaceCol( restrict: Boolean, cascade: Boolean)(sql: String) extends NativeDDLCommand(sql) with Logging + + +private object DDLUtils { + + def isDatasourceTable(props: Map[String, String]): Boolean = { + props.contains("spark.sql.sources.provider") + } + + def isDatasourceTable(table: CatalogTable): Boolean = { + isDatasourceTable(table.properties) + } +} + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index c42e8e7233..618c9a58a6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -205,10 +205,10 @@ class DDLCommandSuite extends PlanTest { val parsed_view = parser.parsePlan(sql_view) val expected_table = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None))(sql_table) + TableIdentifier("new_table_name", None)) val expected_view = AlterTableRename( TableIdentifier("table_name", None), - TableIdentifier("new_table_name", None))(sql_view) + TableIdentifier("new_table_name", None)) comparePlans(parsed_table, expected_table) comparePlans(parsed_view, expected_view) } @@ -235,14 +235,14 @@ class DDLCommandSuite extends PlanTest { val tableIdent = TableIdentifier("table_name", None) val expected1_table = AlterTableSetProperties( - tableIdent, Map("test" -> "test", "comment" -> "new_comment"))(sql1_table) + tableIdent, Map("test" -> "test", "comment" -> "new_comment")) val expected2_table = AlterTableUnsetProperties( - tableIdent, Map("comment" -> null, "test" -> null), ifExists = false)(sql2_table) + tableIdent, Seq("comment", "test"), ifExists = false) val expected3_table = AlterTableUnsetProperties( - tableIdent, Map("comment" -> null, "test" -> null), ifExists = true)(sql3_table) - val expected1_view = expected1_table.copy()(sql = sql1_view) - val expected2_view = expected2_table.copy()(sql = sql2_view) - val 
expected3_view = expected3_table.copy()(sql = sql3_view) + tableIdent, Seq("comment", "test"), ifExists = true) + val expected1_view = expected1_table + val expected2_view = expected2_table + val expected3_view = expected3_table comparePlans(parsed1_table, expected1_table) comparePlans(parsed2_table, expected2_table) @@ -282,97 +282,24 @@ class DDLCommandSuite extends PlanTest { val parsed5 = parser.parsePlan(sql5) val tableIdent = TableIdentifier("table_name", None) val expected1 = AlterTableSerDeProperties( - tableIdent, Some("org.apache.class"), None, None)(sql1) + tableIdent, Some("org.apache.class"), None, None) val expected2 = AlterTableSerDeProperties( tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - None)(sql2) + None) val expected3 = AlterTableSerDeProperties( - tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), None)(sql3) + tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), None) val expected4 = AlterTableSerDeProperties( tableIdent, Some("org.apache.class"), Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))(sql4) + Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) val expected5 = AlterTableSerDeProperties( tableIdent, None, Some(Map("columns" -> "foo,bar", "field.delim" -> ",")), - Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us")))(sql5) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - comparePlans(parsed4, expected4) - comparePlans(parsed5, expected5) - } - - test("alter table: storage properties") { - val sql1 = "ALTER TABLE table_name CLUSTERED BY (dt, country) INTO 10 BUCKETS" - val sql2 = "ALTER TABLE table_name CLUSTERED BY (dt, country) SORTED BY " + - "(dt, country DESC) INTO 10 BUCKETS" - val sql3 = "ALTER TABLE table_name NOT CLUSTERED" - val sql4 = "ALTER TABLE table_name NOT SORTED" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) - val parsed4 = parser.parsePlan(sql4) - val tableIdent = TableIdentifier("table_name", None) - val cols = List("dt", "country") - // TODO: also test the sort directions once we keep track of that - val expected1 = AlterTableStorageProperties( - tableIdent, BucketSpec(10, cols, Nil))(sql1) - val expected2 = AlterTableStorageProperties( - tableIdent, BucketSpec(10, cols, cols))(sql2) - val expected3 = AlterTableNotClustered(tableIdent)(sql3) - val expected4 = AlterTableNotSorted(tableIdent)(sql4) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - comparePlans(parsed4, expected4) - } - - test("alter table: skewed") { - val sql1 = - """ - |ALTER TABLE table_name SKEWED BY (dt, country) ON - |(('2008-08-08', 'us'), ('2009-09-09', 'uk'), ('2010-10-10', 'cn')) STORED AS DIRECTORIES - """.stripMargin - val sql2 = - """ - |ALTER TABLE table_name SKEWED BY (dt, country) ON - |('2008-08-08', 'us') STORED AS DIRECTORIES - """.stripMargin - val sql3 = - """ - |ALTER TABLE table_name SKEWED BY (dt, country) ON - |(('2008-08-08', 'us'), ('2009-09-09', 'uk')) - """.stripMargin - val sql4 = "ALTER TABLE table_name NOT SKEWED" - val sql5 = "ALTER TABLE table_name NOT STORED AS DIRECTORIES" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) - val parsed4 = parser.parsePlan(sql4) - val parsed5 = 
parser.parsePlan(sql5) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSkewed( - tableIdent, - Seq("dt", "country"), - Seq(List("2008-08-08", "us"), List("2009-09-09", "uk"), List("2010-10-10", "cn")), - storedAsDirs = true)(sql1) - val expected2 = AlterTableSkewed( - tableIdent, - Seq("dt", "country"), - Seq(List("2008-08-08", "us")), - storedAsDirs = true)(sql2) - val expected3 = AlterTableSkewed( - tableIdent, - Seq("dt", "country"), - Seq(List("2008-08-08", "us"), List("2009-09-09", "uk")), - storedAsDirs = false)(sql3) - val expected4 = AlterTableNotSkewed(tableIdent)(sql4) - val expected5 = AlterTableNotStoredAsDirs(tableIdent)(sql5) + Some(Map("test" -> null, "dt" -> "2008-08-08", "country" -> "us"))) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) comparePlans(parsed3, expected3) @@ -380,30 +307,6 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed5, expected5) } - test("alter table: skewed location") { - val sql1 = - """ - |ALTER TABLE table_name SET SKEWED LOCATION - |('123'='location1', 'test'='location2') - """.stripMargin - val sql2 = - """ - |ALTER TABLE table_name SET SKEWED LOCATION - |(('2008-08-08', 'us')='location1', 'test'='location2') - """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSkewedLocation( - tableIdent, - Map("123" -> "location1", "test" -> "location2"))(sql1) - val expected2 = AlterTableSkewedLocation( - tableIdent, - Map("2008-08-08" -> "location1", "us" -> "location1", "test" -> "location2"))(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - } - // ALTER TABLE table_name ADD [IF NOT EXISTS] PARTITION partition_spec // [LOCATION 'location1'] partition_spec [LOCATION 'location2'] ...; test("alter table: add partition") { @@ -615,11 +518,11 @@ class DDLCommandSuite extends PlanTest { val expected1 = AlterTableSetLocation( tableIdent, None, - "new location")(sql1) + "new location") val expected2 = AlterTableSetLocation( tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")), - "new location")(sql2) + "new location") comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 885a04af59..d8e2c94a8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -19,34 +19,63 @@ package org.apache.spark.sql.execution.command import java.io.File +import org.scalatest.BeforeAndAfterEach + import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.apache.spark.sql.catalyst.catalog.CatalogDatabase +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogStorageFormat} +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, SessionCatalog} +import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.test.SharedSQLContext -class DDLSuite extends QueryTest with SharedSQLContext { - +class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private val escapedIdentifier = "`(.+)`".r + override def 
afterEach(): Unit = { + try { + // drop all databases, tables and functions after each test + sqlContext.sessionState.catalog.reset() + } finally { + super.afterEach() + } + } + /** * Strip backticks, if any, from the string. */ - def cleanIdentifier(ident: String): String = { + private def cleanIdentifier(ident: String): String = { ident match { case escapedIdentifier(i) => i case plainIdent => plainIdent } } - /** - * Drops database `databaseName` after calling `f`. - */ - private def withDatabase(dbNames: String*)(f: => Unit): Unit = { - try f finally { - dbNames.foreach { name => - sqlContext.sql(s"DROP DATABASE IF EXISTS $name CASCADE") - } - sqlContext.sessionState.catalog.setCurrentDatabase("default") + private def assertUnsupported(query: String): Unit = { + val e = intercept[AnalysisException] { + sql(query) } + assert(e.getMessage.toLowerCase.contains("operation not allowed")) + } + + private def createDatabase(catalog: SessionCatalog, name: String): Unit = { + catalog.createDatabase(CatalogDatabase(name, "", "", Map()), ignoreIfExists = false) + } + + private def createTable(catalog: SessionCatalog, name: TableIdentifier): Unit = { + catalog.createTable(CatalogTable( + identifier = name, + tableType = CatalogTableType.EXTERNAL_TABLE, + storage = CatalogStorageFormat(None, None, None, None, Map()), + schema = Seq()), ignoreIfExists = false) + } + + private def createTablePartition( + catalog: SessionCatalog, + spec: TablePartitionSpec, + tableName: TableIdentifier): Unit = { + val part = CatalogTablePartition(spec, CatalogStorageFormat(None, None, None, None, Map())) + catalog.createPartitions(tableName, Seq(part), ignoreIfExists = false) } test("Create/Drop Database") { @@ -55,7 +84,7 @@ class DDLSuite extends QueryTest with SharedSQLContext { val databaseNames = Seq("db1", "`database`") databaseNames.foreach { dbName => - withDatabase(dbName) { + try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") @@ -67,6 +96,8 @@ class DDLSuite extends QueryTest with SharedSQLContext { Map.empty)) sql(s"DROP DATABASE $dbName CASCADE") assert(!catalog.databaseExists(dbNameWithoutBackTicks)) + } finally { + catalog.reset() } } } @@ -76,8 +107,8 @@ class DDLSuite extends QueryTest with SharedSQLContext { val databaseNames = Seq("db1", "`database`") databaseNames.foreach { dbName => - val dbNameWithoutBackTicks = cleanIdentifier(dbName) - withDatabase(dbName) { + try { + val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") val db1 = catalog.getDatabase(dbNameWithoutBackTicks) assert(db1 == CatalogDatabase( @@ -90,6 +121,8 @@ class DDLSuite extends QueryTest with SharedSQLContext { sql(s"CREATE DATABASE $dbName") }.getMessage assert(message.contains(s"Database '$dbNameWithoutBackTicks' already exists.")) + } finally { + catalog.reset() } } } @@ -99,7 +132,7 @@ class DDLSuite extends QueryTest with SharedSQLContext { val databaseNames = Seq("db1", "`database`") databaseNames.foreach { dbName => - withDatabase(dbName) { + try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) val location = System.getProperty("java.io.tmpdir") + File.separator + s"$dbNameWithoutBackTicks.db" @@ -129,6 +162,8 @@ class DDLSuite extends QueryTest with SharedSQLContext { Row("Description", "") :: Row("Location", location) :: Row("Properties", "((a,a), (b,b), (c,c), (d,d))") :: Nil) + } finally { + catalog.reset() } } } @@ -159,6 +194,131 @@ class DDLSuite extends QueryTest with SharedSQLContext { } } + // TODO: test drop database in restrict mode 
+ + test("alter table: rename") { + val catalog = sqlContext.sessionState.catalog + val tableIdent1 = TableIdentifier("tab1", Some("dbx")) + val tableIdent2 = TableIdentifier("tab2", Some("dbx")) + createDatabase(catalog, "dbx") + createDatabase(catalog, "dby") + createTable(catalog, tableIdent1) + assert(catalog.listTables("dbx") == Seq(tableIdent1)) + sql("ALTER TABLE dbx.tab1 RENAME TO dbx.tab2") + assert(catalog.listTables("dbx") == Seq(tableIdent2)) + catalog.setCurrentDatabase("dbx") + // rename without explicitly specifying database + sql("ALTER TABLE tab2 RENAME TO tab1") + assert(catalog.listTables("dbx") == Seq(tableIdent1)) + // table to rename does not exist + intercept[AnalysisException] { + sql("ALTER TABLE dbx.does_not_exist RENAME TO dbx.tab2") + } + // destination database is different + intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 RENAME TO dby.tab2") + } + } + + test("alter table: set location") { + testSetLocation(isDatasourceTable = false) + } + + test("alter table: set location (datasource table)") { + testSetLocation(isDatasourceTable = true) + } + + test("alter table: set properties") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assert(catalog.getTable(tableIdent).properties.isEmpty) + // set table properties + sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('andrew' = 'or14', 'kor' = 'bel')") + assert(catalog.getTable(tableIdent).properties == Map("andrew" -> "or14", "kor" -> "bel")) + // set table properties without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET TBLPROPERTIES ('kor' = 'belle', 'kar' = 'bol')") + assert(catalog.getTable(tableIdent).properties == + Map("andrew" -> "or14", "kor" -> "belle", "kar" -> "bol")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist SET TBLPROPERTIES ('winner' = 'loser')") + } + // throw exception for datasource tables + convertToDatasourceTable(catalog, tableIdent) + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 SET TBLPROPERTIES ('sora' = 'bol')") + } + assert(e.getMessage.contains("datasource")) + } + + test("alter table: unset properties") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + // unset table properties + sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan')") + sql("ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('j')") + assert(catalog.getTable(tableIdent).properties == Map("p" -> "an", "c" -> "lan")) + // unset table properties without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('p')") + assert(catalog.getTable(tableIdent).properties == Map("c" -> "lan")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist UNSET TBLPROPERTIES ('c' = 'lan')") + } + // property to unset does not exist + val e = intercept[AnalysisException] { + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('c', 'xyz')") + } + assert(e.getMessage.contains("xyz")) + // property to unset does not exist, but "IF EXISTS" is specified + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES IF EXISTS ('c', 'xyz')") + assert(catalog.getTable(tableIdent).properties.isEmpty) + // throw exception for datasource tables + convertToDatasourceTable(catalog, 
tableIdent) + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('sora')") + } + assert(e1.getMessage.contains("datasource")) + } + + test("alter table: set serde") { + testSetSerde(isDatasourceTable = false) + } + + test("alter table: set serde (datasource table)") { + testSetSerde(isDatasourceTable = true) + } + + test("alter table: bucketing is not supported") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assertUnsupported("ALTER TABLE dbx.tab1 CLUSTERED BY (blood, lemon, grape) INTO 11 BUCKETS") + assertUnsupported("ALTER TABLE dbx.tab1 CLUSTERED BY (fuji) SORTED BY (grape) INTO 5 BUCKETS") + assertUnsupported("ALTER TABLE dbx.tab1 NOT CLUSTERED") + assertUnsupported("ALTER TABLE dbx.tab1 NOT SORTED") + } + + test("alter table: skew is not supported") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assertUnsupported("ALTER TABLE dbx.tab1 SKEWED BY (dt, country) ON " + + "(('2008-08-08', 'us'), ('2009-09-09', 'uk'), ('2010-10-10', 'cn'))") + assertUnsupported("ALTER TABLE dbx.tab1 SKEWED BY (dt, country) ON " + + "(('2008-08-08', 'us'), ('2009-09-09', 'uk')) STORED AS DIRECTORIES") + assertUnsupported("ALTER TABLE dbx.tab1 NOT SKEWED") + assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") + } + // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext test("show tables") { @@ -206,29 +366,129 @@ class DDLSuite extends QueryTest with SharedSQLContext { } test("show databases") { - withDatabase("showdb1A", "showdb2B") { - sql("CREATE DATABASE showdb1A") - sql("CREATE DATABASE showdb2B") + sql("CREATE DATABASE showdb1A") + sql("CREATE DATABASE showdb2B") - assert( - sql("SHOW DATABASES").count() >= 2) + assert( + sql("SHOW DATABASES").count() >= 2) - checkAnswer( - sql("SHOW DATABASES LIKE '*db1A'"), - Row("showdb1A") :: Nil) + checkAnswer( + sql("SHOW DATABASES LIKE '*db1A'"), + Row("showdb1A") :: Nil) - checkAnswer( - sql("SHOW DATABASES LIKE 'showdb1A'"), - Row("showdb1A") :: Nil) + checkAnswer( + sql("SHOW DATABASES LIKE 'showdb1A'"), + Row("showdb1A") :: Nil) - checkAnswer( - sql("SHOW DATABASES LIKE '*db1A|*db2B'"), - Row("showdb1A") :: - Row("showdb2B") :: Nil) + checkAnswer( + sql("SHOW DATABASES LIKE '*db1A|*db2B'"), + Row("showdb1A") :: + Row("showdb2B") :: Nil) - checkAnswer( - sql("SHOW DATABASES LIKE 'non-existentdb'"), - Nil) + checkAnswer( + sql("SHOW DATABASES LIKE 'non-existentdb'"), + Nil) + } + + private def convertToDatasourceTable( + catalog: SessionCatalog, + tableIdent: TableIdentifier): Unit = { + catalog.alterTable(catalog.getTable(tableIdent).copy( + properties = Map("spark.sql.sources.provider" -> "csv"))) + } + + private def testSetLocation(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val partSpec = Map("a" -> "1") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, partSpec, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.getTable(tableIdent).storage.locationUri.isEmpty) + assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty) + 
assert(catalog.getPartition(tableIdent, partSpec).storage.serdeProperties.isEmpty) + // Verify that the location is set to the expected string + def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { + val storageFormat = spec + .map { s => catalog.getPartition(tableIdent, s).storage } + .getOrElse { catalog.getTable(tableIdent).storage } + if (isDatasourceTable) { + assert(storageFormat.serdeProperties.get("path") === Some(expected)) + } else { + assert(storageFormat.locationUri === Some(expected)) + } + } + // set table location + sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") + verifyLocation("/path/to/your/lovely/heart") + // set table partition location + sql("ALTER TABLE dbx.tab1 PARTITION (a='1') SET LOCATION '/path/to/part/ways'") + verifyLocation("/path/to/part/ways", Some(partSpec)) + // set table location without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET LOCATION '/swanky/steak/place'") + verifyLocation("/swanky/steak/place") + // set table partition location without explicitly specifying database + sql("ALTER TABLE tab1 PARTITION (a='1') SET LOCATION 'vienna'") + verifyLocation("vienna", Some(partSpec)) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE dbx.does_not_exist SET LOCATION '/mister/spark'") + } + // partition to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 PARTITION (b='2') SET LOCATION '/mister/spark'") } } + + private def testSetSerde(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.getTable(tableIdent).storage.serde.isEmpty) + assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + // set table serde and/or properties (should fail on datasource tables) + if (isDatasourceTable) { + val e1 = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 SET SERDE 'whatever'") + } + val e2 = intercept[AnalysisException] { + sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") + } + assert(e1.getMessage.contains("datasource")) + assert(e2.getMessage.contains("datasource")) + } else { + sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'") + assert(catalog.getTable(tableIdent).storage.serde == Some("org.apache.jadoop")) + assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") + assert(catalog.getTable(tableIdent).storage.serde == Some("org.apache.madoop")) + assert(catalog.getTable(tableIdent).storage.serdeProperties == + Map("k" -> "v", "kay" -> "vee")) + } + // set serde properties only + sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") + assert(catalog.getTable(tableIdent).storage.serdeProperties == + Map("k" -> "vvv", "kay" -> "vee")) + // set things without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')") + assert(catalog.getTable(tableIdent).storage.serdeProperties == + Map("k" -> "vvv", "kay" -> "veee")) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist SET SERDEPROPERTIES 
('x' = 'y')") + } + } + } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 4b4f88ece0..b01f556f0a 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -360,6 +360,12 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "show_create_table_serde", "show_create_table_view", + // These tests try to change how a table is bucketed, which we don't support + "alter4", + "sort_merge_join_desc_5", + "sort_merge_join_desc_6", + "sort_merge_join_desc_7", + // Index commands are not supported "drop_index", "drop_index_removes_partition_dirs", @@ -381,7 +387,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alias_casted_column", "alter2", "alter3", - "alter4", "alter5", "alter_merge_2", "alter_partition_format_loc", @@ -880,9 +885,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "sort_merge_join_desc_2", "sort_merge_join_desc_3", "sort_merge_join_desc_4", - "sort_merge_join_desc_5", - "sort_merge_join_desc_6", - "sort_merge_join_desc_7", "stats0", "stats_aggregator_error_1", "stats_empty_partition", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index d315f39a91..0cccc22e5a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -94,7 +94,7 @@ private[sql] class HiveSessionCatalog( metastoreCatalog.refreshTable(name) } - def invalidateTable(name: TableIdentifier): Unit = { + override def invalidateTable(name: TableIdentifier): Unit = { metastoreCatalog.invalidateTable(name) } -- cgit v1.2.3 From f6456fa80ba442bfd7ce069fc23b7dbd993e6cb9 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 6 Apr 2016 12:09:10 +0800 Subject: [SPARK-14296][SQL] whole stage codegen support for Dataset.map ## What changes were proposed in this pull request? This PR adds a new operator `MapElements` for `Dataset.map`, it's a 1-1 mapping and is easier to adapt to whole stage codegen framework. ## How was this patch tested? new test in `WholeStageCodegenSuite` Author: Wenchen Fan Closes #12087 from cloud-fan/map. 
--- .../spark/sql/catalyst/analysis/unresolved.scala | 2 +- .../spark/sql/catalyst/expressions/objects.scala | 40 ++++++---- .../spark/sql/catalyst/optimizer/Optimizer.scala | 9 +++ .../spark/sql/catalyst/plans/logical/object.scala | 28 ++++++- .../main/scala/org/apache/spark/sql/Dataset.scala | 22 +++--- .../spark/sql/execution/SparkStrategies.scala | 2 + .../spark/sql/execution/WholeStageCodegen.scala | 11 ++- .../org/apache/spark/sql/execution/objects.scala | 69 ++++++++++++++++- .../org/apache/spark/sql/DatasetBenchmark.scala | 86 ++++++++++++++++++++++ .../scala/org/apache/spark/sql/QueryTest.scala | 5 +- .../sql/execution/WholeStageCodegenSuite.scala | 14 +++- 11 files changed, 247 insertions(+), 41 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index b2f362b6b8..4ec43aba02 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -345,7 +345,7 @@ case class UnresolvedAlias(child: Expression, aliasName: Option[String] = None) * @param inputAttributes The input attributes used to resolve deserializer expression, can be empty * if we want to resolve deserializer by children output. */ -case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute]) +case class UnresolvedDeserializer(deserializer: Expression, inputAttributes: Seq[Attribute] = Nil) extends UnaryExpression with Unevaluable with NonSQLExpression { // The input attributes used to resolve deserializer expression must be all resolved. require(inputAttributes.forall(_.resolved), "Input attributes must all be resolved.") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index eebd43dae9..a0490e1351 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -119,18 +119,18 @@ case class Invoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - lazy val method = targetObject.dataType match { + @transient lazy val method = targetObject.dataType match { case ObjectType(cls) => - cls - .getMethods - .find(_.getName == functionName) - .getOrElse(sys.error(s"Couldn't find $functionName on $cls")) - .getReturnType - .getName - case _ => "" + val m = cls.getMethods.find(_.getName == functionName) + if (m.isEmpty) { + sys.error(s"Couldn't find $functionName on $cls") + } else { + m + } + case _ => None } - lazy val unboxer = (dataType, method) match { + lazy val unboxer = (dataType, method.map(_.getReturnType.getName).getOrElse("")) match { case (IntegerType, "java.lang.Object") => (s: String) => s"((java.lang.Integer)$s).intValue()" case (LongType, "java.lang.Object") => (s: String) => @@ -157,21 +157,31 @@ case class Invoke( // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. 
val objNullCheck = if (ctx.defaultValue(dataType) == "null") { - s"${ev.isNull} = ${ev.value} == null;" + s"boolean ${ev.isNull} = ${ev.value} == null;" } else { + ev.isNull = obj.isNull "" } val value = unboxer(s"${obj.value}.$functionName($argString)") + val evaluate = if (method.forall(_.getExceptionTypes.isEmpty)) { + s"$javaType ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType) $value;" + } else { + s""" + $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; + try { + ${ev.value} = ${obj.isNull} ? ${ctx.defaultValue(javaType)} : ($javaType) $value; + } catch (Exception e) { + org.apache.spark.unsafe.Platform.throwException(e); + } + """ + } + s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} - - boolean ${ev.isNull} = ${obj.isNull}; - $javaType ${ev.value} = - ${ev.isNull} ? - ${ctx.defaultValue(dataType)} : ($javaType) $value; + $evaluate $objNullCheck """ } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 69b09bcb35..c085a377ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -136,6 +136,7 @@ object SamplePushDown extends Rule[LogicalPlan] { * representation of data item. For example back to back map operations. */ object EliminateSerialization extends Rule[LogicalPlan] { + // TODO: find a more general way to do this optimization. def apply(plan: LogicalPlan): LogicalPlan = plan transform { case m @ MapPartitions(_, deserializer, _, child: ObjectOperator) if !deserializer.isInstanceOf[Attribute] && @@ -144,6 +145,14 @@ object EliminateSerialization extends Rule[LogicalPlan] { m.copy( deserializer = childWithoutSerialization.output.head, child = childWithoutSerialization) + + case m @ MapElements(_, deserializer, _, child: ObjectOperator) + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => + val childWithoutSerialization = child.withObjectOutput + m.copy( + deserializer = childWithoutSerialization.output.head, + child = childWithoutSerialization) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index 58313c7b72..ec33a538a9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -65,7 +65,7 @@ object MapPartitions { child: LogicalPlan): MapPartitions = { MapPartitions( func.asInstanceOf[Iterator[Any] => Iterator[Any]], - UnresolvedDeserializer(encoderFor[T].deserializer, Nil), + UnresolvedDeserializer(encoderFor[T].deserializer), encoderFor[U].namedExpressions, child) } @@ -83,6 +83,30 @@ case class MapPartitions( serializer: Seq[NamedExpression], child: LogicalPlan) extends UnaryNode with ObjectOperator +object MapElements { + def apply[T : Encoder, U : Encoder]( + func: AnyRef, + child: LogicalPlan): MapElements = { + MapElements( + func, + UnresolvedDeserializer(encoderFor[T].deserializer), + encoderFor[U].namedExpressions, + child) + } +} + +/** + * A relation produced by applying `func` to each element of the `child`. + * + * @param deserializer used to extract the input to `func` from an input row. 
+ * @param serializer use to serialize the output of `func`. + */ +case class MapElements( + func: AnyRef, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode with ObjectOperator + /** Factory for constructing new `AppendColumn` nodes. */ object AppendColumns { def apply[T : Encoder, U : Encoder]( @@ -90,7 +114,7 @@ object AppendColumns { child: LogicalPlan): AppendColumns = { new AppendColumns( func.asInstanceOf[Any => Any], - UnresolvedDeserializer(encoderFor[T].deserializer, Nil), + UnresolvedDeserializer(encoderFor[T].deserializer), encoderFor[U].namedExpressions, child) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index f472a5068e..2854d5f9da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -766,7 +766,8 @@ class Dataset[T] private[sql]( implicit val tuple2Encoder: Encoder[(T, U)] = ExpressionEncoder.tuple(this.unresolvedTEncoder, other.unresolvedTEncoder) - withTypedPlan[(T, U)](other, encoderFor[(T, U)]) { (left, right) => + + withTypedPlan { Project( leftData :: rightData :: Nil, joined.analyzed) @@ -1900,7 +1901,9 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def map[U : Encoder](func: T => U): Dataset[U] = mapPartitions(_.map(func)) + def map[U : Encoder](func: T => U): Dataset[U] = withTypedPlan { + MapElements[T, U](func, logicalPlan) + } /** * :: Experimental :: @@ -1911,8 +1914,10 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = - map(t => func.call(t))(encoder) + def map[U](func: MapFunction[T, U], encoder: Encoder[U]): Dataset[U] = { + implicit val uEnc = encoder + withTypedPlan(MapElements[T, U](func, logicalPlan)) + } /** * :: Experimental :: @@ -2412,12 +2417,7 @@ class Dataset[T] private[sql]( } /** A convenient function to wrap a logical plan and produce a Dataset. 
*/ - @inline private def withTypedPlan(logicalPlan: => LogicalPlan): Dataset[T] = { - new Dataset[T](sqlContext, logicalPlan, encoder) + @inline private def withTypedPlan[U : Encoder](logicalPlan: => LogicalPlan): Dataset[U] = { + Dataset(sqlContext, logicalPlan) } - - private[sql] def withTypedPlan[R]( - other: Dataset[_], encoder: Encoder[R])( - f: (LogicalPlan, LogicalPlan) => LogicalPlan): Dataset[R] = - new Dataset[R](sqlContext, f(logicalPlan, other.logicalPlan), encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e52f05a5f4..5f3128d8e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -341,6 +341,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil + case logical.MapElements(f, in, out, child) => + execution.MapElements(f, in, out, planLater(child)) :: Nil case logical.AppendColumns(f, in, out, child) => execution.AppendColumns(f, in, out, planLater(child)) :: Nil case logical.MapGroups(f, key, in, out, grouping, data, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 9f539c4929..4e75a3a794 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -152,7 +152,7 @@ trait CodegenSupport extends SparkPlan { s""" | |/*** CONSUME: ${toCommentSafeString(parent.simpleString)} */ - |${evaluated} + |$evaluated |${parent.doConsume(ctx, inputVars, rowVar)} """.stripMargin } @@ -169,20 +169,20 @@ trait CodegenSupport extends SparkPlan { /** * Returns source code to evaluate the variables for required attributes, and clear the code - * of evaluated variables, to prevent them to be evaluated twice.. + * of evaluated variables, to prevent them to be evaluated twice. 
*/ protected def evaluateRequiredVariables( attributes: Seq[Attribute], variables: Seq[ExprCode], required: AttributeSet): String = { - var evaluateVars = "" + val evaluateVars = new StringBuilder variables.zipWithIndex.foreach { case (ev, i) => if (ev.code != "" && required.contains(attributes(i))) { - evaluateVars += ev.code.trim + "\n" + evaluateVars.append(ev.code.trim + "\n") ev.code = "" } } - evaluateVars + evaluateVars.toString() } /** @@ -305,7 +305,6 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup def doCodeGen(): (CodegenContext, String) = { val ctx = new CodegenContext val code = child.asInstanceOf[CodegenSupport].produce(ctx, this) - val references = ctx.references.toArray val source = s""" public Object generate(Object[] references) { return new GeneratedIterator(references); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index 582dda8603..f48f3f09c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.execution +import scala.language.existentials + +import org.apache.spark.api.java.function.MapFunction import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection, GenerateUnsafeRowJoiner} +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.ObjectType @@ -67,6 +70,70 @@ case class MapPartitions( } } +/** + * Applies the given function to each input row and encodes the result. + * + * Note that, each serializer expression needs the result object which is returned by the given + * function, as input. This operator uses some tricks to make sure we only calculate the result + * object once. We don't use [[Project]] directly as subexpression elimination doesn't work with + * whole stage codegen and it's confusing to show the un-common-subexpression-eliminated version of + * a project while explain. 
+ */ +case class MapElements( + func: AnyRef, + deserializer: Expression, + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with ObjectOperator with CodegenSupport { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val (funcClass, methodName) = func match { + case m: MapFunction[_, _] => classOf[MapFunction[_, _]] -> "call" + case _ => classOf[Any => Any] -> "apply" + } + val funcObj = Literal.create(func, ObjectType(funcClass)) + val resultObjType = serializer.head.collect { case b: BoundReference => b }.head.dataType + val callFunc = Invoke(funcObj, methodName, resultObjType, Seq(deserializer)) + + val bound = ExpressionCanonicalizer.execute( + BindReferences.bindReference(callFunc, child.output)) + ctx.currentVars = input + val evaluated = bound.gen(ctx) + + val resultObj = LambdaVariable(evaluated.value, evaluated.isNull, resultObjType) + val outputFields = serializer.map(_ transform { + case _: BoundReference => resultObj + }) + val resultVars = outputFields.map(_.gen(ctx)) + s""" + ${evaluated.code} + ${consume(ctx, resultVars)} + """ + } + + override protected def doExecute(): RDD[InternalRow] = { + val callFunc: Any => Any = func match { + case m: MapFunction[_, _] => i => m.asInstanceOf[MapFunction[Any, Any]].call(i) + case _ => func.asInstanceOf[Any => Any] + } + child.execute().mapPartitionsInternal { iter => + val getObject = generateToObject(deserializer, child.output) + val outputObject = generateToRow(serializer) + iter.map(row => outputObject(callFunc(getObject(row)))) + } + } + + override def outputOrdering: Seq[SortOrder] = child.outputOrdering +} + /** * Applies the given function to each input row, appending the encoded result at the end of the row. */ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala new file mode 100644 index 0000000000..6eb952445f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import org.apache.spark.SparkContext +import org.apache.spark.sql.types.StringType +import org.apache.spark.util.Benchmark + +/** + * Benchmark for Dataset typed operations comparing with DataFrame and RDD versions. 
+ */ +object DatasetBenchmark { + + case class Data(l: Long, s: String) + + def main(args: Array[String]): Unit = { + val sparkContext = new SparkContext("local[*]", "Dataset benchmark") + val sqlContext = new SQLContext(sparkContext) + + import sqlContext.implicits._ + + val numRows = 10000000 + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val numChains = 10 + + val benchmark = new Benchmark("back-to-back map", numRows) + + val func = (d: Data) => Data(d.l + 1, d.s) + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] + var i = 0 + while (i < numChains) { + res = res.map(func) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.select($"l" + 1 as "l") + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = rdd.map(func) + i += 1 + } + res.foreach(_ => Unit) + } + + /* + Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 + Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz + back-to-back map: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Dataset 902 / 995 11.1 90.2 1.0X + DataFrame 132 / 167 75.5 13.2 6.8X + RDD 216 / 237 46.3 21.6 4.2X + */ + benchmark.run() + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index f7f3bd78e9..4e62fac919 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -198,10 +198,7 @@ abstract class QueryTest extends PlanTest { val logicalPlan = df.queryExecution.analyzed // bypass some cases that we can't handle currently. 
logicalPlan.transform { - case _: MapPartitions => return - case _: MapGroups => return - case _: AppendColumns => return - case _: CoGroup => return + case _: ObjectOperator => return case _: LogicalRelation => return }.transformAllExpressions { case a: ImperativeAggregate => return
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index 6d5be0b5dd..f73ca887f1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,7 +17,8 @@ package org.apache.spark.sql.execution -import org.apache.spark.sql.Row +import org.apache.spark.api.java.function.MapFunction +import org.apache.spark.sql.{Encoders, Row} import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions.{avg, broadcast, col, max} @@ -70,4 +71,15 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Sort]).isDefined) assert(df.collect() === Array(Row(1), Row(2), Row(3))) } + + test("MapElements should be included in WholeStageCodegen") { + import testImplicits._ + + val ds = sqlContext.range(10).map(_.toString) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) + assert(ds.collect() === 0.until(10).map(_.toString).toArray) + } } -- cgit v1.2.3
From 10494feae0c2c1aca545c73ba61af6d8f743c5bb Mon Sep 17 00:00:00 2001 From: Kousuke Saruta Date: Wed, 6 Apr 2016 10:57:46 -0700 Subject: [SPARK-14426][SQL] Merge ParserUtils and ParseUtils ## What changes were proposed in this pull request? We have ParserUtils and ParseUtils, which are both utility collections used during the parsing process. Their names and purposes are very similar, so I think we can merge them. Also, the original unescapeSQLString method has a bug: when "\u0061" style character literals are passed to the method, they are not unescaped correctly. This patch fixes the bug. ## How was this patch tested? Added a new test case. Author: Kousuke Saruta Closes #12199 from sarutak/merge-ParseUtils-and-ParserUtils.
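The unicode part of the fix comes down to how the four hex digits after `\u` are combined: the old Java helper multiplied them by a decimal place-value table (1000, 100, 10, 1), while the rewritten `ParserUtils.unescapeSQLString` in the diff below folds them together with a four-bit shift per digit. A minimal stand-alone sketch of that decoding step (the helper name `decodeUnicodeEscape` is illustrative and not part of the patch):

```scala
// Decode one "\u0061"-style escape starting at the backslash:
// skip the "\u" prefix, then fold the four hex digits into a code point.
def decodeUnicodeEscape(s: String, backslashIndex: Int): Char = {
  val base = backslashIndex + 2
  val code = (0 until 4).foldLeft(0) { (acc, j) =>
    (acc << 4) + Character.digit(s.charAt(base + j), 16)
  }
  code.toChar
}

// Build the six characters \u0061 without tripping Scala's own escape handling.
val input = "\\" + "u0061"
decodeUnicodeEscape(input, 0)  // 'a' (0x61); the old decimal scheme yielded code 61 instead
```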
--- scalastyle-config.xml | 2 +- .../spark/sql/catalyst/parser/ParseUtils.java | 135 --------------------- .../spark/sql/catalyst/parser/ParserUtils.scala | 78 +++++++++++- .../catalyst/parser/ExpressionParserSuite.scala | 2 +- .../sql/catalyst/parser/ParserUtilsSuite.scala | 65 ++++++++++ 5 files changed, 144 insertions(+), 138 deletions(-) delete mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala (limited to 'sql/catalyst') diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 37d2ecf48e..33c2cbd293 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -116,7 +116,7 @@ This file is divided into 3 sections: - + diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java deleted file mode 100644 index 01f89112a7..0000000000 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/parser/ParseUtils.java +++ /dev/null @@ -1,135 +0,0 @@ -/** - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.parser; - -import java.nio.charset.StandardCharsets; - -/** - * A couple of utility methods that help with parsing ASTs. - * - * The 'unescapeSQLString' method in this class was take from the SemanticAnalyzer in Hive: - * ql/src/java/org/apache/hadoop/hive/ql/parse/BaseSemanticAnalyzer.java - */ -public final class ParseUtils { - private ParseUtils() { - super(); - } - - private static final int[] multiplier = new int[] {1000, 100, 10, 1}; - - @SuppressWarnings("nls") - public static String unescapeSQLString(String b) { - Character enclosure = null; - - // Some of the strings can be passed in as unicode. 
For example, the - // delimiter can be passed in as \002 - So, we first check if the - // string is a unicode number, else go back to the old behavior - StringBuilder sb = new StringBuilder(b.length()); - for (int i = 0; i < b.length(); i++) { - - char currentChar = b.charAt(i); - if (enclosure == null) { - if (currentChar == '\'' || b.charAt(i) == '\"') { - enclosure = currentChar; - } - // ignore all other chars outside the enclosure - continue; - } - - if (enclosure.equals(currentChar)) { - enclosure = null; - continue; - } - - if (currentChar == '\\' && (i + 6 < b.length()) && b.charAt(i + 1) == 'u') { - int code = 0; - int base = i + 2; - for (int j = 0; j < 4; j++) { - int digit = Character.digit(b.charAt(j + base), 16); - code += digit * multiplier[j]; - } - sb.append((char)code); - i += 5; - continue; - } - - if (currentChar == '\\' && (i + 4 < b.length())) { - char i1 = b.charAt(i + 1); - char i2 = b.charAt(i + 2); - char i3 = b.charAt(i + 3); - if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') - && (i3 >= '0' && i3 <= '7')) { - byte bVal = (byte) ((i3 - '0') + ((i2 - '0') * 8) + ((i1 - '0') * 8 * 8)); - byte[] bValArr = new byte[1]; - bValArr[0] = bVal; - String tmp = new String(bValArr, StandardCharsets.UTF_8); - sb.append(tmp); - i += 3; - continue; - } - } - - if (currentChar == '\\' && (i + 2 < b.length())) { - char n = b.charAt(i + 1); - switch (n) { - case '0': - sb.append("\0"); - break; - case '\'': - sb.append("'"); - break; - case '"': - sb.append("\""); - break; - case 'b': - sb.append("\b"); - break; - case 'n': - sb.append("\n"); - break; - case 'r': - sb.append("\r"); - break; - case 't': - sb.append("\t"); - break; - case 'Z': - sb.append("\u001A"); - break; - case '\\': - sb.append("\\"); - break; - // The following 2 lines are exactly what MySQL does TODO: why do we do this? - case '%': - sb.append("\\%"); - break; - case '_': - sb.append("\\_"); - break; - default: - sb.append(n); - } - i++; - } else { - sb.append(currentChar); - } - } - return sb.toString(); - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 90b76dc314..cb9fefec8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,11 +16,12 @@ */ package org.apache.spark.sql.catalyst.parser +import scala.collection.mutable.StringBuilder + import org.antlr.v4.runtime.{CharStream, ParserRuleContext, Token} import org.antlr.v4.runtime.misc.Interval import org.antlr.v4.runtime.tree.TerminalNode -import org.apache.spark.sql.catalyst.parser.ParseUtils.unescapeSQLString import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.trees.{CurrentOrigin, Origin} @@ -87,6 +88,81 @@ object ParserUtils { } } + /** Unescape baskslash-escaped string enclosed by quotes. */ + def unescapeSQLString(b: String): String = { + var enclosure: Character = null + val sb = new StringBuilder(b.length()) + + def appendEscapedChar(n: Char) { + n match { + case '0' => sb.append('\u0000') + case '\'' => sb.append('\'') + case '"' => sb.append('\"') + case 'b' => sb.append('\b') + case 'n' => sb.append('\n') + case 'r' => sb.append('\r') + case 't' => sb.append('\t') + case 'Z' => sb.append('\u001A') + case '\\' => sb.append('\\') + // The following 2 lines are exactly what MySQL does TODO: why do we do this? 
+ case '%' => sb.append("\\%") + case '_' => sb.append("\\_") + case _ => sb.append(n) + } + } + + var i = 0 + val strLength = b.length + while (i < strLength) { + val currentChar = b.charAt(i) + if (enclosure == null) { + if (currentChar == '\'' || currentChar == '\"') { + enclosure = currentChar + } + } else if (enclosure == currentChar) { + enclosure = null + } else if (currentChar == '\\') { + + if ((i + 6 < strLength) && b.charAt(i + 1) == 'u') { + // \u0000 style character literals. + + val base = i + 2 + val code = (0 until 4).foldLeft(0) { (mid, j) => + val digit = Character.digit(b.charAt(j + base), 16) + (mid << 4) + digit + } + sb.append(code.asInstanceOf[Char]) + i += 5 + } else if (i + 4 < strLength) { + // \000 style character literals. + + val i1 = b.charAt(i + 1) + val i2 = b.charAt(i + 2) + val i3 = b.charAt(i + 3) + + if ((i1 >= '0' && i1 <= '1') && (i2 >= '0' && i2 <= '7') && (i3 >= '0' && i3 <= '7')) { + val tmp = ((i3 - '0') + ((i2 - '0') << 3) + ((i1 - '0') << 6)).asInstanceOf[Char] + sb.append(tmp) + i += 3 + } else { + appendEscapedChar(i1) + i += 1 + } + } else if (i + 2 < strLength) { + // escaped character literals. + val n = b.charAt(i + 1) + appendEscapedChar(n) + i += 1 + } + } else { + // non-escaped character literals. + sb.append(currentChar) + } + i += 1 + } + sb.toString() + } + /** Some syntactic sugar which makes it easier to work with optional clauses for LogicalPlans. */ implicit class EnhancedLogicalPlan(val plan: LogicalPlan) extends AnyVal { /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala index a80d29ce5d..6f40ec67ec 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ExpressionParserSuite.scala @@ -415,7 +415,7 @@ class ExpressionParserSuite extends PlanTest { assertEqual("'\\110\\145\\154\\154\\157\\041'", "Hello!") // Unicode - assertEqual("'\\u0087\\u0111\\u0114\\u0108\\u0100\\u0032\\u0058\\u0041'", "World :)") + assertEqual("'\\u0057\\u006F\\u0072\\u006C\\u0064\\u0020\\u003A\\u0029'", "World :)") } test("intervals") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala new file mode 100644 index 0000000000..d090daf7b4 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/ParserUtilsSuite.scala @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.SparkFunSuite + +class ParserUtilsSuite extends SparkFunSuite { + + import ParserUtils._ + + test("unescapeSQLString") { + // scalastyle:off nonascii + + // String not including escaped characters and enclosed by double quotes. + assert(unescapeSQLString(""""abcdefg"""") == "abcdefg") + + // String enclosed by single quotes. + assert(unescapeSQLString("""'C0FFEE'""") == "C0FFEE") + + // Strings including single escaped characters. + assert(unescapeSQLString("""'\0'""") == "\u0000") + assert(unescapeSQLString(""""\'"""") == "\'") + assert(unescapeSQLString("""'\"'""") == "\"") + assert(unescapeSQLString(""""\b"""") == "\b") + assert(unescapeSQLString("""'\n'""") == "\n") + assert(unescapeSQLString(""""\r"""") == "\r") + assert(unescapeSQLString("""'\t'""") == "\t") + assert(unescapeSQLString(""""\Z"""") == "\u001A") + assert(unescapeSQLString("""'\\'""") == "\\") + assert(unescapeSQLString(""""\%"""") == "\\%") + assert(unescapeSQLString("""'\_'""") == "\\_") + + // String including '\000' style literal characters. + assert(unescapeSQLString("""'3 + 5 = \070'""") == "3 + 5 = \u0038") + assert(unescapeSQLString(""""\000"""") == "\u0000") + + // String including invalid '\000' style literal characters. + assert(unescapeSQLString(""""\256"""") == "256") + + // String including a '\u0000' style literal characters (\u732B is a cat in Kanji). + assert(unescapeSQLString(""""How cute \u732B are"""") == "How cute \u732B are") + + // String including a surrogate pair character + // (\uD867\uDE3D is Okhotsk atka mackerel in Kanji). + assert(unescapeSQLString(""""\uD867\uDE3D is a fish"""") == "\uD867\uDE3D is a fish") + + // scalastyle:on nonascii + } + + // TODO: Add test cases for other methods in ParserUtils +} -- cgit v1.2.3 From 5abd02c02b3fa3505defdc8ab0c5c5e23a16aa80 Mon Sep 17 00:00:00 2001 From: bomeng Date: Wed, 6 Apr 2016 11:05:52 -0700 Subject: [SPARK-14429][SQL] Improve LIKE pattern in "SHOW TABLES / FUNCTIONS LIKE " DDL LIKE is commonly used in SHOW TABLES / FUNCTIONS etc DDL. In the pattern, user can use `|` or `*` as wildcards. 1. Currently, we used `replaceAll()` to replace `*` with `.*`, but the replacement was scattered in several places; I have created an utility method and use it in all the places; 2. Consistency with Hive: the pattern is case insensitive in Hive and white spaces will be trimmed, but current pattern matching does not do that. For example, suppose we have tables (t1, t2, t3), `SHOW TABLES LIKE ' T* ' ` will list all the t-tables. Please use Hive to verify it. 3. Combined with `|`, the result will be sorted. For pattern like `' B*|a* '`, it will list the result in a-b order. I've made some changes to the utility method to make sure we will get the same result as Hive does. A new method was created in StringUtil and test cases were added. andrewor14 Author: bomeng Closes #12206 from bomeng/SPARK-14429. 
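The rules described above (trimmed input, case-insensitive matching, `*` and `|` as the only wildcards, sorted output) are centralised in the new `StringUtils.filterPattern` helper; the full diff follows. A small usage sketch, reusing the name list from the accompanying StringUtilsSuite test, with expected results shown in comments:

```scala
import org.apache.spark.sql.catalyst.util.StringUtils.filterPattern

val names = Seq("a1", "a2", "b2", "c3")

// Surrounding whitespace is trimmed and '*' expands to '.*'.
filterPattern(names, " a* ")     // Seq("a1", "a2")

// '|' separates alternative sub-patterns, matching is case-insensitive,
// and the combined result comes back sorted, hence the a-b order here.
filterPattern(names, " B*|a* ")  // Seq("a1", "a2", "b2")

// A pattern that matches nothing simply yields an empty sequence.
filterPattern(names, " d* ")     // Nil
```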
--- .../sql/catalyst/catalog/InMemoryCatalog.scala | 13 ++++-------- .../sql/catalyst/catalog/SessionCatalog.scala | 10 +++------- .../spark/sql/catalyst/util/StringUtils.scala | 23 +++++++++++++++++++++- .../spark/sql/catalyst/util/StringUtilsSuite.scala | 12 +++++++++++ .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 12 +++++------ 5 files changed, 46 insertions(+), 24 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 2af0107fa3..5d136b663f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -21,7 +21,7 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} - +import org.apache.spark.sql.catalyst.util.StringUtils /** * An in-memory (ephemeral) implementation of the system catalog. @@ -47,11 +47,6 @@ class InMemoryCatalog extends ExternalCatalog { // Database name -> description private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] - private def filterPattern(names: Seq[String], pattern: String): Seq[String] = { - val regex = pattern.replaceAll("\\*", ".*").r - names.filter { funcName => regex.pattern.matcher(funcName).matches() } - } - private def functionExists(db: String, funcName: String): Boolean = { requireDbExists(db) catalog(db).functions.contains(funcName) @@ -141,7 +136,7 @@ class InMemoryCatalog extends ExternalCatalog { } override def listDatabases(pattern: String): Seq[String] = synchronized { - filterPattern(listDatabases(), pattern) + StringUtils.filterPattern(listDatabases(), pattern) } override def setCurrentDatabase(db: String): Unit = { /* no-op */ } @@ -208,7 +203,7 @@ class InMemoryCatalog extends ExternalCatalog { } override def listTables(db: String, pattern: String): Seq[String] = synchronized { - filterPattern(listTables(db), pattern) + StringUtils.filterPattern(listTables(db), pattern) } // -------------------------------------------------------------------------- @@ -322,7 +317,7 @@ class InMemoryCatalog extends ExternalCatalog { override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { requireDbExists(db) - filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) + StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 62a3b1c105..2acf584e8f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, NoSuchFunctionE import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionInfo} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias} - +import org.apache.spark.sql.catalyst.util.StringUtils /** * An internal catalog that is used by a Spark Session. 
This internal catalog serves as a @@ -297,9 +297,7 @@ class SessionCatalog( def listTables(db: String, pattern: String): Seq[TableIdentifier] = { val dbTables = externalCatalog.listTables(db, pattern).map { t => TableIdentifier(t, Some(db)) } - val regex = pattern.replaceAll("\\*", ".*").r - val _tempTables = tempTables.keys.toSeq - .filter { t => regex.pattern.matcher(t).matches() } + val _tempTables = StringUtils.filterPattern(tempTables.keys.toSeq, pattern) .map { t => TableIdentifier(t) } dbTables ++ _tempTables } @@ -613,9 +611,7 @@ class SessionCatalog( def listFunctions(db: String, pattern: String): Seq[FunctionIdentifier] = { val dbFunctions = externalCatalog.listFunctions(db, pattern).map { f => FunctionIdentifier(f, Some(db)) } - val regex = pattern.replaceAll("\\*", ".*").r - val loadedFunctions = functionRegistry.listFunction() - .filter { f => regex.pattern.matcher(f).matches() } + val loadedFunctions = StringUtils.filterPattern(functionRegistry.listFunction(), pattern) .map { f => FunctionIdentifier(f) } // TODO: Actually, there will be dbFunctions that have been loaded into the FunctionRegistry. // So, the returned list may have two entries for the same function. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index c2eeb3c565..0f65028261 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.util -import java.util.regex.Pattern +import java.util.regex.{Pattern, PatternSyntaxException} import org.apache.spark.unsafe.types.UTF8String @@ -52,4 +52,25 @@ object StringUtils { def isTrueString(s: UTF8String): Boolean = trueStrings.contains(s.toLowerCase) def isFalseString(s: UTF8String): Boolean = falseStrings.contains(s.toLowerCase) + + /** + * This utility can be used for filtering pattern in the "Like" of "Show Tables / Functions" DDL + * @param names the names list to be filtered + * @param pattern the filter pattern, only '*' and '|' are allowed as wildcards, others will + * follows regular expression convention, case insensitive match and white spaces + * on both ends will be ignored + * @return the filtered names list in order + */ + def filterPattern(names: Seq[String], pattern: String): Seq[String] = { + val funcNames = scala.collection.mutable.SortedSet.empty[String] + pattern.trim().split("\\|").foreach { subPattern => + try { + val regex = ("(?i)" + subPattern.replaceAll("\\*", ".*")).r + funcNames ++= names.filter{ name => regex.pattern.matcher(name).matches() } + } catch { + case _: PatternSyntaxException => + } + } + funcNames.toSeq + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala index d6f273f9e5..2ffc18a8d1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/StringUtilsSuite.scala @@ -31,4 +31,16 @@ class StringUtilsSuite extends SparkFunSuite { assert(escapeLikeRegex("**") === "(?s)\\Q*\\E\\Q*\\E") assert(escapeLikeRegex("a_b") === "(?s)\\Qa\\E.\\Qb\\E") } + + test("filter pattern") { + val names = Seq("a1", "a2", "b2", "c3") + assert(filterPattern(names, " * ") === Seq("a1", "a2", "b2", "c3")) + 
assert(filterPattern(names, "*a*") === Seq("a1", "a2")) + assert(filterPattern(names, " *a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a* ") === Seq("a1", "a2")) + assert(filterPattern(names, " a.* ") === Seq("a1", "a2")) + assert(filterPattern(names, " B.*|a* ") === Seq("a1", "a2", "b2")) + assert(filterPattern(names, " a. ") === Seq("a1", "a2")) + assert(filterPattern(names, " d* ") === Nil) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5a851b47ca..2ab7c1581c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,6 +24,7 @@ import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.catalyst.analysis.UnresolvedException import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.plans.logical.Aggregate +import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, CartesianProduct, SortMergeJoin} import org.apache.spark.sql.functions._ @@ -56,17 +57,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { test("show functions") { def getFunctions(pattern: String): Seq[Row] = { - val regex = java.util.regex.Pattern.compile(pattern) - sqlContext.sessionState.functionRegistry.listFunction() - .filter(regex.matcher(_).matches()).map(Row(_)) + StringUtils.filterPattern(sqlContext.sessionState.functionRegistry.listFunction(), pattern) + .map(Row(_)) } - checkAnswer(sql("SHOW functions"), getFunctions(".*")) + checkAnswer(sql("SHOW functions"), getFunctions("*")) Seq("^c*", "*e$", "log*", "*date*").foreach { pattern => // For the pattern part, only '*' and '|' are allowed as wildcards. // For '*', we need to replace it to '.*'. - checkAnswer( - sql(s"SHOW FUNCTIONS '$pattern'"), - getFunctions(pattern.replaceAll("\\*", ".*"))) + checkAnswer(sql(s"SHOW FUNCTIONS '$pattern'"), getFunctions(pattern)) } } -- cgit v1.2.3 From 3c8d8821654e3d82ef927c55272348e1bcc34a79 Mon Sep 17 00:00:00 2001 From: bomeng Date: Wed, 6 Apr 2016 11:12:48 -0700 Subject: [SPARK-14383][SQL] missing "|" in the g4 file ## What changes were proposed in this pull request? A very trivial one. It missed "|" between DISTRIBUTE and UNSET. ## How was this patch tested? I do not think it is really needed. Author: bomeng Closes #12156 from bomeng/SPARK-14383. 
--- .../main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../org/apache/spark/sql/execution/command/DDLCommandSuite.scala | 7 +++++++ 2 files changed, 8 insertions(+), 1 deletion(-) (limited to 'sql/catalyst')
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 96c170be3d..8a45b4f2e1 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -645,7 +645,7 @@ nonReserved | NO | DATA | START | TRANSACTION | COMMIT | ROLLBACK | WORK | ISOLATION | LEVEL | SNAPSHOT | READ | WRITE | ONLY - | SORT | CLUSTER | DISTRIBUTE UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION + | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE | FIRST | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 46dcadd690..8e63b69876 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.types._ @@ -685,4 +686,10 @@ class DDLCommandSuite extends PlanTest { parser.parsePlan("SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM testData") } } + + test("SPARK-14383: DISTRIBUTE and UNSET as non-keywords") { + val sql = "SELECT distribute, unset FROM x" + val parsed = parser.parsePlan(sql) + assert(parsed.isInstanceOf[Project]) + } } -- cgit v1.2.3
From 5a4b11a901703464b9261dea0642d80cf8d4856c Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 6 Apr 2016 15:33:39 -0700 Subject: [SPARK-14224] [SPARK-14223] [SPARK-14310] [SQL] fix RowEncoder and parquet reader for wide table ## What changes were proposed in this pull request? 1) Fix the RowEncoder for wide tables (many columns) by splitting the generated code into multiple functions. 2) Separate DataSourceScan into RowDataSourceScan and BatchedDataSourceScan. 3) Disable returning columnar batches from the parquet reader if there are many columns. 4) Add an internal config for the maximum number of (nested) fields supported by whole-stage codegen. Closes #12098 ## How was this patch tested? Added a test for a table with 1000 columns. Author: Davies Liu Closes #12047 from davies/many_columns.
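Items 3) and 4) hinge on a recursive field count that is compared against the new `spark.sql.codegen.maxFields` setting (default 200; see the SQLConf change below). A condensed restatement of that check, omitting the `UserDefinedType` case handled in the actual patch:

```scala
import org.apache.spark.sql.types._

// Count leaf fields recursively; the patch skips whole-stage codegen when this
// total exceeds spark.sql.codegen.maxFields (batched parquet reads use a similar cap).
def numOfNestedFields(dataType: DataType): Int = dataType match {
  case st: StructType => st.fields.map(f => numOfNestedFields(f.dataType)).sum
  case m: MapType     => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType)
  case a: ArrayType   => numOfNestedFields(a.elementType)
  case _              => 1
}

val wide = StructType((0 until 1000).map(i => StructField(s"c$i", IntegerType)))
numOfNestedFields(wide) > 200  // true, so the 1000-column test table falls back to non-codegen paths
```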
--- .../spark/sql/catalyst/expressions/objects.scala | 24 +- .../parquet/VectorizedParquetRecordReader.java | 19 +- .../scala/org/apache/spark/sql/SQLContext.scala | 2 +- .../apache/spark/sql/execution/ExistingRDD.scala | 291 +++++++++++---------- .../spark/sql/execution/WholeStageCodegen.scala | 13 +- .../execution/datasources/DataSourceStrategy.scala | 6 +- .../execution/datasources/FileSourceStrategy.scala | 2 +- .../execution/datasources/SqlNewHadoopRDD.scala | 34 +-- .../datasources/parquet/ParquetRelation.scala | 77 ++++-- .../org/apache/spark/sql/internal/SQLConf.scala | 11 + .../org/apache/spark/sql/sources/interfaces.scala | 9 + .../datasources/FileSourceStrategySuite.scala | 3 +- .../datasources/parquet/ParquetQuerySuite.scala | 10 + 13 files changed, 267 insertions(+), 234 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index a0490e1351..28b6b2adf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -524,22 +524,26 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") - val schemaField = ctx.addReferenceObj("schema", schema) - s""" - boolean ${ev.isNull} = false; - final Object[] $values = new Object[${children.size}]; - """ + - children.zipWithIndex.map { case (e, i) => - val eval = e.gen(ctx) - eval.code + s""" + ctx.addMutableState("Object[]", values, "") + + val childrenCodes = children.zipWithIndex.map { case (e, i) => + val eval = e.gen(ctx) + eval.code + s""" if (${eval.isNull}) { $values[$i] = null; } else { $values[$i] = ${eval.value}; } """ - }.mkString("\n") + - s"final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField);" + } + val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) + val schemaField = ctx.addReferenceObj("schema", schema) + s""" + boolean ${ev.isNull} = false; + $values = new Object[${children.size}]; + $childrenCode + final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); + """ } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index a0b6276ef5..51bdf0f0f2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,7 +31,8 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.ColumnarBatch; -import org.apache.spark.sql.types.*; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.types.StructType; /** * A specialized RecordReader that reads into InternalRows or ColumnarBatches directly using the @@ -99,20 +100,6 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private static final MemoryMode 
DEFAULT_MEMORY_MODE = MemoryMode.ON_HEAP; - /** - * Tries to initialize the reader for this split. Returns true if this reader supports reading - * this split and false otherwise. - */ - public boolean tryInitialize(InputSplit inputSplit, TaskAttemptContext taskAttemptContext) - throws IOException, InterruptedException { - try { - initialize(inputSplit, taskAttemptContext); - return true; - } catch (UnsupportedOperationException e) { - return false; - } - } - /** * Implementation of RecordReader API. */ @@ -222,7 +209,7 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa return columnarBatch; } - /** + /* * Can be called before any rows are returned to enable returning columnar batches directly. */ public void enableReturningBatches() { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index 1c9cb79ba4..9259ff4062 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -120,7 +120,7 @@ class SQLContext private[sql]( */ @transient protected[sql] lazy val sessionState: SessionState = new SessionState(self) - protected[sql] def conf: SQLConf = sessionState.conf + protected[spark] def conf: SQLConf = sessionState.conf /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala index ab575e90c9..392c48fb7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala @@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, UnknownPartitioning} +import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, UnknownPartitioning} import org.apache.spark.sql.catalyst.util.toCommentSafeString import org.apache.spark.sql.execution.datasources.parquet.{DefaultSource => ParquetSource} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation} -import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.types.{AtomicType, DataType} object RDDConversions { def productToRowRdd[A <: Product](data: RDD[A], outputTypes: Seq[DataType]): RDD[InternalRow] = { @@ -123,28 +123,30 @@ private[sql] case class PhysicalRDD( } } -/** Physical plan node for scanning data from a relation. 
*/ -private[sql] case class DataSourceScan( - output: Seq[Attribute], - rdd: RDD[InternalRow], - @transient relation: BaseRelation, - override val metadata: Map[String, String] = Map.empty) - extends LeafNode with CodegenSupport { +private[sql] trait DataSourceScan extends LeafNode { + val rdd: RDD[InternalRow] + val relation: BaseRelation override val nodeName: String = relation.toString // Ignore rdd when checking results - override def sameResult(plan: SparkPlan ): Boolean = plan match { + override def sameResult(plan: SparkPlan): Boolean = plan match { case other: DataSourceScan => relation == other.relation && metadata == other.metadata case _ => false } +} - private[sql] override lazy val metrics = if (canProcessBatches()) { - Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), - "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) - } else { +/** Physical plan node for scanning data from a relation. */ +private[sql] case class RowDataSourceScan( + output: Seq[Attribute], + rdd: RDD[InternalRow], + @transient relation: BaseRelation, + override val outputPartitioning: Partitioning, + override val metadata: Map[String, String] = Map.empty) + extends DataSourceScan with CodegenSupport { + + private[sql] override lazy val metrics = Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows")) - } val outputUnsafeRows = relation match { case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] => @@ -153,38 +155,6 @@ private[sql] case class DataSourceScan( case _ => false } - override val outputPartitioning = { - val bucketSpec = relation match { - // TODO: this should be closer to bucket planning. - case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec - case _ => None - } - - def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse { - throw new AnalysisException(s"bucket column $colName not found in existing columns " + - s"(${output.map(_.name).mkString(", ")})") - } - - bucketSpec.map { spec => - val numBuckets = spec.numBuckets - val bucketColumns = spec.bucketColumnNames.map(toAttribute) - HashPartitioning(bucketColumns, numBuckets) - }.getOrElse { - UnknownPartitioning(0) - } - } - - private def canProcessBatches(): Boolean = { - relation match { - case r: HadoopFsRelation if r.fileFormat.isInstanceOf[ParquetSource] && - SQLContext.getActive().get.conf.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED) && - SQLContext.getActive().get.conf.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED) => - true - case _ => - false - } - } - protected override def doExecute(): RDD[InternalRow] = { val unsafeRow = if (outputUnsafeRows) { rdd @@ -211,6 +181,57 @@ private[sql] case class DataSourceScan( rdd :: Nil } + override protected def doProduce(ctx: CodegenContext): String = { + val numOutputRows = metricTerm(ctx, "numOutputRows") + // PhysicalRDD always just has one input + val input = ctx.freshName("input") + ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + val exprRows = output.zipWithIndex.map{ case (a, i) => + new BoundReference(i, a.dataType, a.nullable) + } + val row = ctx.freshName("row") + ctx.INPUT_ROW = row + ctx.currentVars = null + val columnsRowInput = exprRows.map(_.gen(ctx)) + val inputRow = if (outputUnsafeRows) row else null + s""" + |while ($input.hasNext()) { + | InternalRow $row = (InternalRow) $input.next(); + | $numOutputRows.add(1); + | ${consume(ctx, columnsRowInput, inputRow).trim} + | if 
(shouldStop()) return; + |} + """.stripMargin + } +} + +/** Physical plan node for scanning data from a batched relation. */ +private[sql] case class BatchedDataSourceScan( + output: Seq[Attribute], + rdd: RDD[InternalRow], + @transient relation: BaseRelation, + override val outputPartitioning: Partitioning, + override val metadata: Map[String, String] = Map.empty) + extends DataSourceScan with CodegenSupport { + + private[sql] override lazy val metrics = + Map("numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"), + "scanTime" -> SQLMetrics.createTimingMetric(sparkContext, "scan time")) + + protected override def doExecute(): RDD[InternalRow] = { + throw new UnsupportedOperationException + } + + override def simpleString: String = { + val metadataEntries = for ((key, value) <- metadata.toSeq.sorted) yield s"$key: $value" + val metadataStr = metadataEntries.mkString(" ", ", ", "") + s"BatchedScan $nodeName${output.mkString("[", ",", "]")}$metadataStr" + } + + override def upstreams(): Seq[RDD[InternalRow]] = { + rdd :: Nil + } + private def genCodeColumnVector(ctx: CodegenContext, columnVar: String, ordinal: String, dataType: DataType, nullable: Boolean): ExprCode = { val javaType = ctx.javaType(dataType) @@ -232,113 +253,65 @@ private[sql] case class DataSourceScan( // Support codegen so that we can avoid the UnsafeRow conversion in all cases. Codegen // never requires UnsafeRow as input. override protected def doProduce(ctx: CodegenContext): String = { - val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" - val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector" val input = ctx.freshName("input") - val idx = ctx.freshName("batchIdx") - val rowidx = ctx.freshName("rowIdx") - val batch = ctx.freshName("batch") // PhysicalRDD always just has one input ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];") + + // metrics + val numOutputRows = metricTerm(ctx, "numOutputRows") + val scanTimeMetric = metricTerm(ctx, "scanTime") + val scanTimeTotalNs = ctx.freshName("scanTime") + ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") + + val columnarBatchClz = "org.apache.spark.sql.execution.vectorized.ColumnarBatch" + val batch = ctx.freshName("batch") ctx.addMutableState(columnarBatchClz, batch, s"$batch = null;") + + val columnVectorClz = "org.apache.spark.sql.execution.vectorized.ColumnVector" + val idx = ctx.freshName("batchIdx") ctx.addMutableState("int", idx, s"$idx = 0;") val colVars = output.indices.map(i => ctx.freshName("colInstance" + i)) val columnAssigns = colVars.zipWithIndex.map { case (name, i) => ctx.addMutableState(columnVectorClz, name, s"$name = null;") - s"$name = ${batch}.column($i);" } - - val row = ctx.freshName("row") - val numOutputRows = metricTerm(ctx, "numOutputRows") + s"$name = $batch.column($i);" + } - // The input RDD can either return (all) ColumnarBatches or InternalRows. We determine this - // by looking at the first value of the RDD and then calling the function which will process - // the remaining. It is faster to return batches. - // TODO: The abstractions between this class and SqlNewHadoopRDD makes it difficult to know - // here which path to use. Fix this. 
+ val nextBatch = ctx.freshName("nextBatch") + ctx.addNewFunction(nextBatch, + s""" + |private void $nextBatch() throws java.io.IOException { + | long getBatchStart = System.nanoTime(); + | if ($input.hasNext()) { + | $batch = ($columnarBatchClz)$input.next(); + | $numOutputRows.add($batch.numRows()); + | $idx = 0; + | ${columnAssigns.mkString("", "\n", "\n")} + | } + | $scanTimeTotalNs += System.nanoTime() - getBatchStart; + |}""".stripMargin) - val exprRows = - output.zipWithIndex.map(x => new BoundReference(x._2, x._1.dataType, x._1.nullable)) - ctx.INPUT_ROW = row ctx.currentVars = null - val columnsRowInput = exprRows.map(_.gen(ctx)) - val inputRow = if (outputUnsafeRows) row else null - val scanRows = ctx.freshName("processRows") - ctx.addNewFunction(scanRows, - s""" - | private void $scanRows(InternalRow $row) throws java.io.IOException { - | boolean firstRow = true; - | while (!shouldStop() && (firstRow || $input.hasNext())) { - | if (firstRow) { - | firstRow = false; - | } else { - | $row = (InternalRow) $input.next(); - | } - | $numOutputRows.add(1); - | ${consume(ctx, columnsRowInput, inputRow).trim} - | } - | }""".stripMargin) - - // Timers for how long we spent inside the scan. We can only maintain this when using batches, - // otherwise the overhead is too high. - if (canProcessBatches()) { - val scanTimeMetric = metricTerm(ctx, "scanTime") - val getBatchStart = ctx.freshName("scanStart") - val scanTimeTotalNs = ctx.freshName("scanTime") - ctx.currentVars = null - val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => - genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } - val scanBatches = ctx.freshName("processBatches") - ctx.addMutableState("long", scanTimeTotalNs, s"$scanTimeTotalNs = 0;") - - ctx.addNewFunction(scanBatches, - s""" - | private void $scanBatches() throws java.io.IOException { - | while (true) { - | int numRows = $batch.numRows(); - | if ($idx == 0) { - | ${columnAssigns.mkString("", "\n", "\n")} - | $numOutputRows.add(numRows); - | } - | - | while (!shouldStop() && $idx < numRows) { - | int $rowidx = $idx++; - | ${consume(ctx, columnsBatchInput).trim} - | } - | if (shouldStop()) return; - | - | long $getBatchStart = System.nanoTime(); - | if (!$input.hasNext()) { - | $batch = null; - | $scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); - | break; - | } - | $batch = ($columnarBatchClz)$input.next(); - | $scanTimeTotalNs += System.nanoTime() - $getBatchStart; - | $idx = 0; - | } - | }""".stripMargin) - - val value = ctx.freshName("value") - s""" - | if ($batch != null) { - | $scanBatches(); - | } else if ($input.hasNext()) { - | Object $value = $input.next(); - | if ($value instanceof $columnarBatchClz) { - | $batch = ($columnarBatchClz)$value; - | $scanBatches(); - | } else { - | $scanRows((InternalRow) $value); - | } - | } - """.stripMargin - } else { - s""" - |if ($input.hasNext()) { - | $scanRows((InternalRow) $input.next()); - |} - """.stripMargin + val rowidx = ctx.freshName("rowIdx") + val columnsBatchInput = (output zip colVars).map { case (attr, colVar) => + genCodeColumnVector(ctx, colVar, rowidx, attr.dataType, attr.nullable) } + s""" + |if ($batch == null) { + | $nextBatch(); + |} + |while ($batch != null) { + | int numRows = $batch.numRows(); + | while ($idx < numRows) { + | int $rowidx = $idx++; + | ${consume(ctx, columnsBatchInput).trim} + | if (shouldStop()) return; + | } + | $batch = null; + | $nextBatch(); + |} + |$scanTimeMetric.add($scanTimeTotalNs / (1000 * 1000)); + |$scanTimeTotalNs = 0; + 
""".stripMargin } } @@ -346,4 +319,38 @@ private[sql] object DataSourceScan { // Metadata keys val INPUT_PATHS = "InputPaths" val PUSHED_FILTERS = "PushedFilters" + + def create( + output: Seq[Attribute], + rdd: RDD[InternalRow], + relation: BaseRelation, + metadata: Map[String, String] = Map.empty): DataSourceScan = { + val outputPartitioning = { + val bucketSpec = relation match { + // TODO: this should be closer to bucket planning. + case r: HadoopFsRelation if r.sqlContext.conf.bucketingEnabled => r.bucketSpec + case _ => None + } + + def toAttribute(colName: String): Attribute = output.find(_.name == colName).getOrElse { + throw new AnalysisException(s"bucket column $colName not found in existing columns " + + s"(${output.map(_.name).mkString(", ")})") + } + + bucketSpec.map { spec => + val numBuckets = spec.numBuckets + val bucketColumns = spec.bucketColumnNames.map(toAttribute) + HashPartitioning(bucketColumns, numBuckets) + }.getOrElse { + UnknownPartitioning(0) + } + } + + relation match { + case r: HadoopFsRelation if r.fileFormat.supportBatch(r.sqlContext, relation.schema) => + BatchedDataSourceScan(output, rdd, relation, outputPartitioning, metadata) + case _ => + RowDataSourceScan(output, rdd, relation, outputPartitioning, metadata) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 4e75a3a794..98129d6c52 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, SortMergeJoin} import org.apache.spark.sql.execution.metric.{LongSQLMetricValue, SQLMetrics} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ /** * An interface for those physical operators that support codegen. 
@@ -433,12 +434,20 @@ case class CollapseCodegenStages(conf: SQLConf) extends Rule[SparkPlan] { case _ => true } + private def numOfNestedFields(dataType: DataType): Int = dataType match { + case dt: StructType => dt.fields.map(f => numOfNestedFields(f.dataType)).sum + case m: MapType => numOfNestedFields(m.keyType) + numOfNestedFields(m.valueType) + case a: ArrayType => numOfNestedFields(a.elementType) + case u: UserDefinedType[_] => numOfNestedFields(u.sqlType) + case _ => 1 + } + private def supportCodegen(plan: SparkPlan): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns - val haveManyColumns = plan.output.length > 200 - !willFallback && !haveManyColumns + val haveTooManyFields = numOfNestedFields(plan.schema) > conf.wholeStageMaxNumFields + !willFallback && !haveTooManyFields case _ => false } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 52c8f3ef0b..8c183317f6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -238,7 +238,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { } case l @ LogicalRelation(baseRelation: TableScan, _, _) => - execution.DataSourceScan( + execution.DataSourceScan.create( l.output, toCatalystRDD(l, baseRelation.buildScan()), baseRelation) :: Nil case i @ logical.InsertIntoTable(l @ LogicalRelation(t: InsertableRelation, _, _), @@ -610,7 +610,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { // Don't request columns that are only referenced by pushed filters. 
.filterNot(handledSet.contains) - val scan = execution.DataSourceScan( + val scan = execution.DataSourceScan.create( projects.map(_.toAttribute), scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, metadata) @@ -620,7 +620,7 @@ private[sql] object DataSourceStrategy extends Strategy with Logging { val requestedColumns = (projectSet ++ filterSet -- handledSet).map(relation.attributeMap).toSeq - val scan = execution.DataSourceScan( + val scan = execution.DataSourceScan.create( requestedColumns, scanBuilder(requestedColumns, candidatePredicates, pushedFilters), relation.relation, metadata) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 618d5a522b..aa1f76450c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -181,7 +181,7 @@ private[sql] object FileSourceStrategy extends Strategy with Logging { } val scan = - DataSourceScan( + DataSourceScan.create( readDataColumns ++ partitionColumns, new FileScanRDD( files.sqlContext, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala index 159fdc99dd..6ddb218a22 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala @@ -97,13 +97,6 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( @transient protected val jobId = new JobID(jobTrackerId, id) - // If true, enable using the custom RecordReader for parquet. This only works for - // a subset of the types (no complex types). - protected val enableVectorizedParquetReader: Boolean = - sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean - protected val enableWholestageCodegen: Boolean = - sqlContext.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key).toBoolean - override def getPartitions: Array[SparkPartition] = { val conf = getConf(isDriverSide = true) val inputFormat = inputFormatClass.newInstance @@ -165,32 +158,9 @@ private[spark] class SqlNewHadoopRDD[V: ClassTag]( } val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - private[this] var reader: RecordReader[Void, V] = null - - /** - * If the format is ParquetInputFormat, try to create the optimized RecordReader. If this - * fails (for example, unsupported schema), try with the normal reader. - * TODO: plumb this through a different way? 
- */ - if (enableVectorizedParquetReader && - format.getClass.getName == "org.apache.parquet.hadoop.ParquetInputFormat") { - val parquetReader: VectorizedParquetRecordReader = new VectorizedParquetRecordReader() - if (!parquetReader.tryInitialize( - split.serializableHadoopSplit.value, hadoopAttemptContext)) { - parquetReader.close() - } else { - reader = parquetReader.asInstanceOf[RecordReader[Void, V]] - parquetReader.resultBatch() - // Whole stage codegen (PhysicalRDD) is able to deal with batches directly - if (enableWholestageCodegen) parquetReader.enableReturningBatches() - } - } - - if (reader == null) { - reader = format.createRecordReader( + private[this] var reader: RecordReader[Void, V] = format.createRecordReader( split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - } + reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) // Register an on-task-completion callback to close the input stream. context.addTaskCompletionListener(context => close()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala index 5b58fa1fc5..a2fd8da782 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRelation.scala @@ -24,7 +24,6 @@ import java.util.logging.{Logger => JLogger} import scala.collection.JavaConverters._ import scala.collection.mutable import scala.util.{Failure, Try} -import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileStatus, Path} @@ -53,7 +52,7 @@ import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.types.{AtomicType, DataType, StructType} import org.apache.spark.util.{SerializableConfiguration, Utils} import org.apache.spark.util.collection.BitSet @@ -276,6 +275,16 @@ private[sql] class DefaultSource file.getName == ParquetFileWriter.PARQUET_METADATA_FILE } + /** + * Returns whether the reader will the rows as batch or not. + */ + override def supportBatch(sqlContext: SQLContext, schema: StructType): Boolean = { + val conf = SQLContext.getActive().get.conf + conf.useFileScan && conf.parquetVectorizedReaderEnabled && + conf.wholeStageEnabled && schema.length <= conf.wholeStageMaxNumFields && + schema.forall(_.dataType.isInstanceOf[AtomicType]) + } + override def buildReader( sqlContext: SQLContext, dataSchema: StructType, @@ -306,6 +315,10 @@ private[sql] class DefaultSource SQLConf.PARQUET_INT96_AS_TIMESTAMP.key, sqlContext.conf.getConf(SQLConf.PARQUET_INT96_AS_TIMESTAMP)) + // Whole stage codegen (PhysicalRDD) is able to deal with batches directly + val returningBatch = + supportBatch(sqlContext, StructType(partitionSchema.fields ++ dataSchema.fields)) + // Try to push down filters when filter push-down is enabled. val pushed = if (sqlContext.getConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key).toBoolean) { filters @@ -324,10 +337,8 @@ private[sql] class DefaultSource // TODO: if you move this into the closure it reverts to the default values. 
// If true, enable using the custom RecordReader for parquet. This only works for // a subset of the types (no complex types). - val enableVectorizedParquetReader: Boolean = - sqlContext.getConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key).toBoolean - val enableWholestageCodegen: Boolean = - sqlContext.getConf(SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key).toBoolean + val enableVectorizedParquetReader: Boolean = sqlContext.conf.parquetVectorizedReaderEnabled && + dataSchema.forall(_.dataType.isInstanceOf[AtomicType]) (file: PartitionedFile) => { assert(file.partitionValues.numFields == partitionSchema.size) @@ -347,32 +358,27 @@ private[sql] class DefaultSource val attemptId = new TaskAttemptID(new TaskID(new JobID(), TaskType.MAP, 0), 0) val hadoopAttemptContext = new TaskAttemptContextImpl(broadcastedConf.value.value, attemptId) - val parquetReader = try { - if (!enableVectorizedParquetReader) sys.error("Vectorized reader turned off.") + val parquetReader = if (enableVectorizedParquetReader) { val vectorizedReader = new VectorizedParquetRecordReader() vectorizedReader.initialize(split, hadoopAttemptContext) logDebug(s"Appending $partitionSchema ${file.partitionValues}") vectorizedReader.initBatch(partitionSchema, file.partitionValues) - // Whole stage codegen (PhysicalRDD) is able to deal with batches directly - // TODO: fix column appending - if (enableWholestageCodegen) { - logDebug(s"Enabling batch returning") + if (returningBatch) { vectorizedReader.enableReturningBatches() } vectorizedReader - } catch { - case NonFatal(e) => - logDebug(s"Falling back to parquet-mr: $e", e) - val reader = pushed match { - case Some(filter) => - new ParquetRecordReader[InternalRow]( - new CatalystReadSupport, - FilterCompat.get(filter, null)) - case _ => - new ParquetRecordReader[InternalRow](new CatalystReadSupport) - } - reader.initialize(split, hadoopAttemptContext) - reader + } else { + logDebug(s"Falling back to parquet-mr") + val reader = pushed match { + case Some(filter) => + new ParquetRecordReader[InternalRow]( + new CatalystReadSupport, + FilterCompat.get(filter, null)) + case _ => + new ParquetRecordReader[InternalRow](new CatalystReadSupport) + } + reader.initialize(split, hadoopAttemptContext) + reader } val iter = new RecordReaderIterator(parquetReader) @@ -432,13 +438,21 @@ private[sql] class DefaultSource val setInputPaths = ParquetRelation.initializeDriverSideJobFunc(inputFiles, parquetBlockSize) _ + val allPrimitiveTypes = dataSchema.forall(_.dataType.isInstanceOf[AtomicType]) + val inputFormatCls = if (sqlContext.conf.parquetVectorizedReaderEnabled + && allPrimitiveTypes) { + classOf[VectorizedParquetInputFormat] + } else { + classOf[ParquetInputFormat[InternalRow]] + } + Utils.withDummyCallSite(sqlContext.sparkContext) { new SqlNewHadoopRDD( sqlContext = sqlContext, broadcastedConf = broadcastedConf, initDriverSideJobFuncOpt = Some(setInputPaths), initLocalJobFuncOpt = Some(initLocalJobFuncOpt), - inputFormatClass = classOf[ParquetInputFormat[InternalRow]], + inputFormatClass = inputFormatCls, valueClass = classOf[InternalRow]) { val cacheMetadata = useMetadataCache @@ -481,6 +495,17 @@ private[sql] class DefaultSource } } +/** + * The ParquetInputFormat that create VectorizedParquetRecordReader. 
+ */ +final class VectorizedParquetInputFormat extends ParquetInputFormat[InternalRow] { + override def createRecordReader( + inputSplit: InputSplit, + taskAttemptContext: TaskAttemptContext): ParquetRecordReader[InternalRow] = { + new VectorizedParquetRecordReader().asInstanceOf[ParquetRecordReader[InternalRow]] + } +} + // NOTE: This class is instantiated and used on executor side only, no need to be serializable. private[sql] class ParquetOutputWriter( path: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 927af89949..dc6ba1bcfb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -396,6 +396,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val WHOLESTAGE_MAX_NUM_FIELDS = SQLConfigBuilder("spark.sql.codegen.maxFields") + .internal() + .doc("The maximum number of fields (including nested fields) that will be supported before" + + " deactivating whole-stage codegen.") + .intConf + .createWithDefault(200) + val FILES_MAX_PARTITION_BYTES = SQLConfigBuilder("spark.sql.files.maxPartitionBytes") .doc("The maximum number of bytes to pack into a single partition when reading files.") .longConf @@ -480,6 +487,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def parquetCacheMetadata: Boolean = getConf(PARQUET_CACHE_METADATA) + def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED) + def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE) def numShufflePartitions: Int = getConf(SHUFFLE_PARTITIONS) @@ -504,6 +513,8 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { def wholeStageEnabled: Boolean = getConf(WHOLESTAGE_CODEGEN_ENABLED) + def wholeStageMaxNumFields: Int = getConf(WHOLESTAGE_MAX_NUM_FIELDS) + def exchangeReuseEnabled: Boolean = getConf(EXCHANGE_REUSE_ENABLED) def canonicalView: Boolean = getConf(CANONICAL_NATIVE_VIEW) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala index 14e14710f6..6acb41dd1f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/interfaces.scala @@ -468,6 +468,15 @@ trait FileFormat { broadcastedConf: Broadcast[SerializableConfiguration], options: Map[String, String]): RDD[InternalRow] + /** + * Returns whether this format support returning columnar batch or not. + * + * TODO: we should just have different traits for the different formats. + */ + def supportBatch(sqlContext: SQLContext, dataSchema: StructType): Boolean = { + false + } + /** * Returns a function that can be used to read a single file in as an Iterator of InternalRow. 
* diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 4446a6881c..41f536fc37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -279,7 +279,8 @@ class FileSourceStrategySuite extends QueryTest with SharedSQLContext with Predi /** Plans the query and calls the provided validation function with the planned partitioning. */ def checkScan(df: DataFrame)(func: Seq[FilePartition] => Unit): Unit = { val fileScan = df.queryExecution.executedPlan.collect { - case DataSourceScan(_, scan: FileScanRDD, _, _) => scan + case scan: DataSourceScan if scan.rdd.isInstanceOf[FileScanRDD] => + scan.rdd.asInstanceOf[FileScanRDD] }.headOption.getOrElse { fail(s"No FileScan in query\n${df.queryExecution}") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index 2f806ebba6..7d206e7bc4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -579,6 +579,16 @@ class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext assert(CatalystReadSupport.expandUDT(schema) === expected) } + + test("read/write wide table") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = sqlContext.range(1000).select(Seq.tabulate(1000) {i => ('id + i).as(s"c$i")} : _*) + df.write.mode(SaveMode.Overwrite).parquet(path) + checkAnswer(sqlContext.read.parquet(path), df) + } + } } object TestingUDT { -- cgit v1.2.3 From d717ae1fd74d125a9df21350a70e7c2b2d2b4786 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 6 Apr 2016 16:02:55 -0700 Subject: [SPARK-14444][BUILD] Add a new scalastyle `NoScalaDoc` to prevent ScalaDoc-style multiline comments ## What changes were proposed in this pull request? According to the [Spark Code Style Guide](https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide#SparkCodeStyleGuide-Indentation), this PR adds a new scalastyle rule to prevent the followings. ``` /** In Spark, we don't use the ScalaDoc style so this * is not correct. */ ``` ## How was this patch tested? Pass the Jenkins tests (including `lint-scala`). Author: Dongjoon Hyun Closes #12221 from dongjoon-hyun/SPARK-14444. 
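For reference, a multiline comment written with the Javadoc-style indentation that this rule expects would look like the sketch below. This is an illustrative snippet only; the `example` method is hypothetical and not taken from this patch.

```scala
/**
 * Javadoc-style multiline comment: the opening marker sits alone on its own
 * line, and every continuation line starts with a single space followed by an
 * asterisk. This is the layout the new NoScalaDoc check accepts.
 */
def example(): Unit = ()
```

The ScalaDoc-style layout quoted above, where the text starts on the opening line and the continuation asterisks are aligned under the second asterisk, is what the check rejects.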
--- core/src/main/scala/org/apache/spark/SparkConf.scala | 6 ++++-- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 12 ++++++------ .../scala/org/apache/spark/partial/BoundedDouble.scala | 4 ++-- .../apache/spark/examples/DriverSubmissionTest.scala | 6 ++++-- .../spark/streaming/flume/FlumeInputDStream.scala | 6 ++++-- .../mllib/stat/distribution/MultivariateGaussian.scala | 10 ++++++---- scalastyle-config.xml | 5 +++++ .../apache/spark/sql/catalyst/ScalaReflection.scala | 18 +++++++++--------- .../catalyst/expressions/codegen/CodeGenerator.scala | 8 ++++---- .../apache/spark/sql/catalyst/plans/QueryPlan.scala | 6 ++++-- .../apache/spark/sql/RelationalGroupedDataset.scala | 4 ++-- .../apache/spark/sql/execution/WholeStageCodegen.scala | 4 ++-- .../apache/spark/sql/execution/ui/SparkPlanGraph.scala | 4 ++-- .../scala/org/apache/spark/sql/jdbc/JdbcDialects.scala | 10 +++++----- 14 files changed, 59 insertions(+), 44 deletions(-) (limited to 'sql/catalyst') diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 5da2e98f1f..e0fd248c43 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -419,8 +419,10 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging { */ private[spark] def getenv(name: String): String = System.getenv(name) - /** Checks for illegal or deprecated config settings. Throws an exception for the former. Not - * idempotent - may mutate this conf object to convert deprecated settings to supported ones. */ + /** + * Checks for illegal or deprecated config settings. Throws an exception for the former. Not + * idempotent - may mutate this conf object to convert deprecated settings to supported ones. + */ private[spark] def validateSettings() { if (contains("spark.local.dir")) { val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " + diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 4e8e363635..41ac308808 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -76,9 +76,9 @@ class SparkHadoopUtil extends Logging { /** - * Appends S3-specific, spark.hadoop.*, and spark.buffer.size configurations to a Hadoop - * configuration. - */ + * Appends S3-specific, spark.hadoop.*, and spark.buffer.size configurations to a Hadoop + * configuration. + */ def appendS3AndSparkHadoopConfigurations(conf: SparkConf, hadoopConf: Configuration): Unit = { // Note: this null check is around more than just access to the "conf" object to maintain // the behavior of the old implementation of this code, for backwards compatibility. @@ -108,9 +108,9 @@ class SparkHadoopUtil extends Logging { } /** - * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop - * subsystems. - */ + * Return an appropriate (subclass) of Configuration. Creating config can initializes some Hadoop + * subsystems. 
+ */ def newConfiguration(conf: SparkConf): Configuration = { val hadoopConf = new Configuration() appendS3AndSparkHadoopConfigurations(conf, hadoopConf) diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index d06b2c67d2..c562c70aba 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -28,8 +28,8 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v this.mean.hashCode ^ this.confidence.hashCode ^ this.low.hashCode ^ this.high.hashCode /** - * Note that consistent with Double, any NaN value will make equality false - */ + * Note that consistent with Double, any NaN value will make equality false + */ override def equals(that: Any): Boolean = that match { case that: BoundedDouble => { diff --git a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala index a2d59a1c95..d12ef642bd 100644 --- a/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/DriverSubmissionTest.scala @@ -22,8 +22,10 @@ import scala.collection.JavaConverters._ import org.apache.spark.util.Utils -/** Prints out environmental information, sleeps, and then exits. Made to - * test driver submission in the standalone scheduler. */ +/** + * Prints out environmental information, sleeps, and then exits. Made to + * test driver submission in the standalone scheduler. + */ object DriverSubmissionTest { def main(args: Array[String]) { if (args.length < 1) { diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala index 6e7c3f358e..13aa817492 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumeInputDStream.scala @@ -130,8 +130,10 @@ class FlumeEventServer(receiver: FlumeReceiver) extends AvroSourceProtocol { } } -/** A NetworkReceiver which listens for events using the - * Flume Avro interface. */ +/** + * A NetworkReceiver which listens for events using the + * Flume Avro interface. 
+ */ private[streaming] class FlumeReceiver( host: String, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala index 052b5b1d65..6c6e9fb7c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/distribution/MultivariateGaussian.scala @@ -61,15 +61,17 @@ class MultivariateGaussian @Since("1.3.0") ( */ private val (rootSigmaInv: DBM[Double], u: Double) = calculateCovarianceConstants - /** Returns density of this multivariate Gaussian at given point, x - */ + /** + * Returns density of this multivariate Gaussian at given point, x + */ @Since("1.3.0") def pdf(x: Vector): Double = { pdf(x.toBreeze) } - /** Returns the log-density of this multivariate Gaussian at given point, x - */ + /** + * Returns the log-density of this multivariate Gaussian at given point, x + */ @Since("1.3.0") def logpdf(x: Vector): Double = { logpdf(x.toBreeze) diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 33c2cbd293..472a8f4084 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -223,6 +223,11 @@ This file is divided into 3 sections: ]]> + + (?m)^(\s*)/[*][*].*$(\r|)\n^\1 [*] + Use Javadoc style indentation for multiline comments + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index d241b8a79b..4795fc2557 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -762,15 +762,15 @@ trait ScalaReflection { } /** - * Returns the full class name for a type. The returned name is the canonical - * Scala name, where each component is separated by a period. It is NOT the - * Java-equivalent runtime name (no dollar signs). - * - * In simple cases, both the Scala and Java names are the same, however when Scala - * generates constructs that do not map to a Java equivalent, such as singleton objects - * or nested classes in package objects, it uses the dollar sign ($) to create - * synthetic classes, emulating behaviour in Java bytecode. - */ + * Returns the full class name for a type. The returned name is the canonical + * Scala name, where each component is separated by a period. It is NOT the + * Java-equivalent runtime name (no dollar signs). + * + * In simple cases, both the Scala and Java names are the same, however when Scala + * generates constructs that do not map to a Java equivalent, such as singleton objects + * or nested classes in package objects, it uses the dollar sign ($) to create + * synthetic classes, emulating behaviour in Java bytecode. 
+ */ def getClassNameFromType(tpe: `Type`): String = { tpe.erasure.typeSymbol.asClass.fullName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 1bebd4e904..ee7f4fadca 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -626,15 +626,15 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin object CodeGenerator extends Logging { /** - * Compile the Java source code into a Java class, using Janino. - */ + * Compile the Java source code into a Java class, using Janino. + */ def compile(code: String): GeneratedClass = { cache.get(code) } /** - * Compile the Java source code into a Java class, using Janino. - */ + * Compile the Java source code into a Java class, using Janino. + */ private[this] def doCompile(code: String): GeneratedClass = { val evaluator = new ClassBodyEvaluator() evaluator.setParentClassLoader(Utils.getContextOrSparkClassLoader) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 609a33e2f1..0a11574f44 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -211,8 +211,10 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT if (changed) makeCopy(newArgs).asInstanceOf[this.type] else this } - /** Returns the result of running [[transformExpressions]] on this node - * and all its children. */ + /** + * Returns the result of running [[transformExpressions]] on this node + * and all its children. 
+ */ def transformAllExpressions(rule: PartialFunction[Expression, Expression]): this.type = { transform { case q: QueryPlan[_] => q.transformExpressions(rule).asInstanceOf[PlanType] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala index 91c02053ae..7dbf2e6c7c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala @@ -408,7 +408,7 @@ private[sql] object RelationalGroupedDataset { private[sql] object RollupType extends GroupType /** - * To indicate it's the PIVOT - */ + * To indicate it's the PIVOT + */ private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index 98129d6c52..c4594f0480 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -312,8 +312,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup } /** Codegened pipeline for: - * ${toCommentSafeString(child.treeString.trim)} - */ + * ${toCommentSafeString(child.treeString.trim)} + */ final class GeneratedIterator extends org.apache.spark.sql.execution.BufferedRowIterator { private Object[] references; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala index 012b125d6b..c6fcb6956c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala @@ -167,8 +167,8 @@ private[ui] class SparkPlanGraphNode( } /** - * Represent a tree of SparkPlan for WholeStageCodegen. - */ + * Represent a tree of SparkPlan for WholeStageCodegen. + */ private[ui] class SparkPlanGraphCluster( id: Long, name: String, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index cfe4911cb7..948106fd06 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -100,11 +100,11 @@ abstract class JdbcDialect extends Serializable { } /** - * Override connection specific properties to run before a select is made. This is in place to - * allow dialects that need special treatment to optimize behavior. - * @param connection The connection object - * @param properties The connection properties. This is passed through from the relation. - */ + * Override connection specific properties to run before a select is made. This is in place to + * allow dialects that need special treatment to optimize behavior. + * @param connection The connection object + * @param properties The connection properties. This is passed through from the relation. 
+ */ def beforeFetch(connection: Connection, properties: Map[String, String]): Unit = { } -- cgit v1.2.3 From d76592276f9f66fed8012d876595de8717f516a9 Mon Sep 17 00:00:00 2001 From: Herman van Hovell Date: Wed, 6 Apr 2016 19:25:10 -0700 Subject: [SPARK-12610][SQL] Left Anti Join ### What changes were proposed in this pull request? This PR adds support for `LEFT ANTI JOIN` to Spark SQL. A `LEFT ANTI JOIN` is the exact opposite of a `LEFT SEMI JOIN` and can be used to identify rows in one dataset that are not in another dataset. Note that `nulls` on the left side of the join cannot match a row on the right hand side of the join; the result is that left anti join will always select a row with a `null` in one or more of its keys. We currently add support for the following SQL join syntax: SELECT * FROM tbl1 A LEFT ANTI JOIN tbl2 B ON A.Id = B.Id Or using a dataframe: tbl1.as("a").join(tbl2.as("b"), $"a.id" === $"b.id", "left_anti") This PR serves as the basis for implementing `NOT EXISTS` and `NOT IN (...)` correlated sub-queries. It would also serve as a good basis for implementing a more efficient `EXCEPT` operator. The PR has been (loosely) based on PRs by both davies (https://github.com/apache/spark/pull/10706) and chenghao-intel (https://github.com/apache/spark/pull/10563); credit should be given where credit is due. This PR adds support for `LEFT ANTI JOIN` to `BroadcastHashJoin` (including codegeneration), `ShuffledHashJoin` and `BroadcastNestedLoopJoin`. ### How was this patch tested? Added tests to `JoinSuite` and ported `ExistenceJoinSuite` from https://github.com/apache/spark/pull/10563. cc davies chenghao-intel rxin Author: Herman van Hovell Closes #12214 from hvanhovell/SPARK-12610. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 + .../spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../spark/sql/catalyst/optimizer/Optimizer.scala | 8 +- .../spark/sql/catalyst/parser/AstBuilder.scala | 1 + .../spark/sql/catalyst/plans/joinTypes.scala | 17 ++- .../catalyst/plans/logical/basicOperators.scala | 4 +- .../sql/catalyst/parser/PlanParserSuite.scala | 5 +- .../apache/spark/sql/execution/SparkPlanner.scala | 2 +- .../spark/sql/execution/SparkStrategies.scala | 11 +- .../sql/execution/joins/BroadcastHashJoin.scala | 99 +++++++++---- .../execution/joins/BroadcastNestedLoopJoin.scala | 57 +++++--- .../spark/sql/execution/joins/HashJoin.scala | 18 ++- .../sql/execution/joins/ShuffledHashJoin.scala | 1 + .../scala/org/apache/spark/sql/JoinSuite.scala | 36 ++--- .../sql/execution/joins/ExistenceJoinSuite.scala | 159 +++++++++++++++++++++ .../spark/sql/execution/joins/SemiJoinSuite.scala | 129 ----------------- .../apache/spark/sql/hive/HiveSessionState.scala | 2 +- 17 files changed, 338 insertions(+), 215 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 8a45b4f2e1..85cb585919 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -380,6 +380,7 @@ joinType | LEFT SEMI | RIGHT OUTER? | FULL OUTER? + | LEFT? 
ANTI ; joinCriteria @@ -878,6 +879,7 @@ INDEX: 'INDEX'; INDEXES: 'INDEXES'; LOCKS: 'LOCKS'; OPTION: 'OPTION'; +ANTI: 'ANTI'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) )* '\'' diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 473c91e69e..bc8cf4e78a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1424,7 +1424,7 @@ class Analyzer( val projectList = joinType match { case LeftOuter => leftKeys ++ lUniqueOutput ++ rUniqueOutput.map(_.withNullability(true)) - case LeftSemi => + case LeftExistence(_) => leftKeys ++ lUniqueOutput case RightOuter => rightKeys ++ lUniqueOutput.map(_.withNullability(true)) ++ rUniqueOutput diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index c085a377ff..f581810c26 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -361,8 +361,8 @@ object ColumnPruning extends Rule[LogicalPlan] { case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) => p.copy(child = g.copy(join = false)) - // Eliminate unneeded attributes from right side of a LeftSemiJoin. - case j @ Join(left, right, LeftSemi, condition) => + // Eliminate unneeded attributes from right side of a Left Existence Join. + case j @ Join(left, right, LeftExistence(_), condition) => j.copy(right = prunedChild(right, j.references)) // all the columns will be used to compare, so we can't prune them @@ -1126,7 +1126,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { (leftFilterConditions ++ commonFilterCondition). reduceLeftOption(And).map(Filter(_, newJoin)).getOrElse(newJoin) - case _ @ (LeftOuter | LeftSemi) => + case LeftOuter | LeftExistence(_) => // push down the left side only `where` condition val newLeft = leftFilterConditions. reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) @@ -1147,7 +1147,7 @@ object PushPredicateThroughJoin extends Rule[LogicalPlan] with PredicateHelper { split(joinCondition.map(splitConjunctivePredicates).getOrElse(Nil), left, right) joinType match { - case _ @ (Inner | LeftSemi) => + case Inner | LeftExistence(_) => // push down the single side only join filter for both sides sub queries val newLeft = leftJoinConditions. 
reduceLeftOption(And).map(Filter(_, left)).getOrElse(left) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 5a3aebff09..aa59f3fb2a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -572,6 +572,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { case null => Inner case jt if jt.FULL != null => FullOuter case jt if jt.SEMI != null => LeftSemi + case jt if jt.ANTI != null => LeftAnti case jt if jt.LEFT != null => LeftOuter case jt if jt.RIGHT != null => RightOuter case _ => Inner diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala index 9ca4f13dd7..13f57c54a5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/joinTypes.scala @@ -26,13 +26,15 @@ object JoinType { case "leftouter" | "left" => LeftOuter case "rightouter" | "right" => RightOuter case "leftsemi" => LeftSemi + case "leftanti" => LeftAnti case _ => val supported = Seq( "inner", "outer", "full", "fullouter", "leftouter", "left", "rightouter", "right", - "leftsemi") + "leftsemi", + "leftanti") throw new IllegalArgumentException(s"Unsupported join type '$typ'. " + "Supported join types include: " + supported.mkString("'", "', '", "'") + ".") @@ -63,6 +65,10 @@ case object LeftSemi extends JoinType { override def sql: String = "LEFT SEMI" } +case object LeftAnti extends JoinType { + override def sql: String = "LEFT ANTI" +} + case class NaturalJoin(tpe: JoinType) extends JoinType { require(Seq(Inner, LeftOuter, RightOuter, FullOuter).contains(tpe), "Unsupported natural join type " + tpe) @@ -70,7 +76,14 @@ case class NaturalJoin(tpe: JoinType) extends JoinType { } case class UsingJoin(tpe: JoinType, usingColumns: Seq[UnresolvedAttribute]) extends JoinType { - require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter).contains(tpe), + require(Seq(Inner, LeftOuter, LeftSemi, RightOuter, FullOuter, LeftAnti).contains(tpe), "Unsupported using join type " + tpe) override def sql: String = "USING " + tpe.sql } + +object LeftExistence { + def unapply(joinType: JoinType): Option[JoinType] = joinType match { + case LeftSemi | LeftAnti => Some(joinType) + case _ => None + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index a18efc90ab..d3353beb09 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -252,7 +252,7 @@ case class Join( override def output: Seq[Attribute] = { joinType match { - case LeftSemi => + case LeftExistence(_) => left.output case LeftOuter => left.output ++ right.output.map(_.withNullability(true)) @@ -276,7 +276,7 @@ case class Join( .union(splitConjunctivePredicates(condition.get).toSet) case Inner => left.constraints.union(right.constraints) - case LeftSemi => + case LeftExistence(_) => left.constraints case LeftOuter => left.constraints diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 262537d9c7..411e2372f2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -334,7 +334,7 @@ class PlanParserSuite extends PlanTest { table("t").join(table("u"), UsingJoin(jt, Seq('a.attr, 'b.attr)), None).select(star())) } val testAll = Seq(testUnconditionalJoin, testConditionalJoin, testNaturalJoin, testUsingJoin) - + val testExistence = Seq(testUnconditionalJoin, testConditionalJoin, testUsingJoin) def test(sql: String, jt: JoinType, tests: Seq[(String, JoinType) => Unit]): Unit = { tests.foreach(_(sql, jt)) } @@ -348,6 +348,9 @@ class PlanParserSuite extends PlanTest { test("right outer join", RightOuter, testAll) test("full join", FullOuter, testAll) test("full outer join", FullOuter, testAll) + test("left semi join", LeftSemi, testExistence) + test("left anti join", LeftAnti, testExistence) + test("anti join", LeftAnti, testExistence) // Test multiple consecutive joins assertEqual( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala index ac8072f3ca..8d05ae470d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanner.scala @@ -38,7 +38,7 @@ class SparkPlanner( DDLStrategy :: SpecialLimits :: Aggregation :: - LeftSemiJoin :: + ExistenceJoin :: EquiJoinSelection :: InMemoryScans :: BasicOperators :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index d77aba7260..eee2b946e3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -62,16 +62,17 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } } - object LeftSemiJoin extends Strategy with PredicateHelper { + object ExistenceJoin extends Strategy with PredicateHelper { def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case ExtractEquiJoinKeys( - LeftSemi, leftKeys, rightKeys, condition, left, CanBroadcast(right)) => + LeftExistence(jt), leftKeys, rightKeys, condition, left, CanBroadcast(right)) => Seq(joins.BroadcastHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) // Find left semi joins where at least some predicates can be evaluated by matching join keys - case ExtractEquiJoinKeys(LeftSemi, leftKeys, rightKeys, condition, left, right) => + case ExtractEquiJoinKeys( + LeftExistence(jt), leftKeys, rightKeys, condition, left, right) => Seq(joins.ShuffledHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, condition, planLater(left), planLater(right))) + leftKeys, rightKeys, jt, BuildRight, condition, planLater(left), planLater(right))) case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index 67ac9e94ff..e3d554c2de 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -27,7 +27,6 @@ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.physical.{BroadcastDistribution, Distribution, Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.util.collection.CompactBuffer /** * Performs an inner hash join of two child relations. When the output RDD of this operator is @@ -87,6 +86,7 @@ case class BroadcastHashJoin( case Inner => codegenInner(ctx, input) case LeftOuter | RightOuter => codegenOuter(ctx, input) case LeftSemi => codegenSemi(ctx, input) + case LeftAnti => codegenAnti(ctx, input) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") @@ -160,15 +160,14 @@ case class BroadcastHashJoin( } /** - * Generates the code for Inner join. + * Generate the (non-equi) condition used to filter joined rows. This is used in Inner, Left Semi + * and Left Anti joins. */ - private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { - val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) - val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + private def getJoinCondition( + ctx: CodegenContext, + input: Seq[ExprCode]): (String, String, Seq[ExprCode]) = { val matched = ctx.freshName("matched") val buildVars = genBuildSideVars(ctx, matched) - val numOutput = metricTerm(ctx, "numOutputRows") - val checkCondition = if (condition.isDefined) { val expr = condition.get // evaluate the variables from build side that used by condition @@ -184,6 +183,17 @@ case class BroadcastHashJoin( } else { "" } + (matched, checkCondition, buildVars) + } + + /** + * Generates the code for Inner join. + */ + private def codegenInner(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, buildVars) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") val resultVars = buildSide match { case BuildLeft => buildVars ++ input @@ -221,7 +231,6 @@ case class BroadcastHashJoin( } } - /** * Generates the code for left or right outer join. 
*/ @@ -276,7 +285,6 @@ case class BroadcastHashJoin( ctx.copyResult = true val matches = ctx.freshName("matches") val iteratorCls = classOf[Iterator[UnsafeRow]].getName - val i = ctx.freshName("i") val found = ctx.freshName("found") s""" |// generate join key for stream side @@ -304,26 +312,8 @@ case class BroadcastHashJoin( private def codegenSemi(ctx: CodegenContext, input: Seq[ExprCode]): String = { val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) - val matched = ctx.freshName("matched") - val buildVars = genBuildSideVars(ctx, matched) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) val numOutput = metricTerm(ctx, "numOutputRows") - - val checkCondition = if (condition.isDefined) { - val expr = condition.get - // evaluate the variables from build side that used by condition - val eval = evaluateRequiredVariables(buildPlan.output, buildVars, expr.references) - // filter the output via condition - ctx.currentVars = input ++ buildVars - val ev = BindReferences.bindReference(expr, streamedPlan.output ++ buildPlan.output).gen(ctx) - s""" - |$eval - |${ev.code} - |if (${ev.isNull} || !${ev.value}) continue; - """.stripMargin - } else { - "" - } - if (broadcastRelation.value.keyIsUnique) { s""" |// generate join key for stream side @@ -357,4 +347,57 @@ case class BroadcastHashJoin( """.stripMargin } } + + /** + * Generates the code for anti join. + */ + private def codegenAnti(ctx: CodegenContext, input: Seq[ExprCode]): String = { + val (broadcastRelation, relationTerm) = prepareBroadcast(ctx) + val (keyEv, anyNull) = genStreamSideJoinKey(ctx, input) + val (matched, checkCondition, _) = getJoinCondition(ctx, input) + val numOutput = metricTerm(ctx, "numOutputRows") + + if (broadcastRelation.value.keyIsUnique) { + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | UnsafeRow $matched = (UnsafeRow)$relationTerm.getValue(${keyEv.value}); + | if ($matched != null) { + | // Evaluate the condition. + | $checkCondition + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } else { + val matches = ctx.freshName("matches") + val iteratorCls = classOf[Iterator[UnsafeRow]].getName + val found = ctx.freshName("found") + s""" + |// generate join key for stream side + |${keyEv.code} + |// Check if the key has nulls. + |if (!($anyNull)) { + | // Check if the HashedRelation exists. + | $iteratorCls $matches = ($iteratorCls)$relationTerm.get(${keyEv.value}); + | if ($matches != null) { + | // Evaluate the condition. 
+ | boolean $found = false; + | while (!$found && $matches.hasNext()) { + | UnsafeRow $matched = (UnsafeRow) $matches.next(); + | $checkCondition + | $found = true; + | } + | if ($found) continue; + | } + |} + |$numOutput.add(1); + |${consume(ctx, input)} + """.stripMargin + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala index 4143e944e5..4ba710c10a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastNestedLoopJoin.scala @@ -73,7 +73,7 @@ case class BroadcastNestedLoopJoin( left.output.map(_.withNullability(true)) ++ right.output case FullOuter => left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case LeftSemi => + case LeftExistence(_) => left.output case x => throw new IllegalArgumentException( @@ -175,8 +175,11 @@ case class BroadcastNestedLoopJoin( * The implementation for these joins: * * LeftSemi with BuildRight + * Anti with BuildRight */ - private def leftSemiJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { + private def leftExistenceJoin( + relation: Broadcast[Array[InternalRow]], + exists: Boolean): RDD[InternalRow] = { assert(buildSide == BuildRight) streamed.execute().mapPartitionsInternal { streamedIter => val buildRows = relation.value @@ -184,10 +187,12 @@ case class BroadcastNestedLoopJoin( if (condition.isDefined) { streamedIter.filter(l => - buildRows.exists(r => boundCondition(joinedRow(l, r))) + buildRows.exists(r => boundCondition(joinedRow(l, r))) == exists ) + } else if (buildRows.nonEmpty == exists) { + streamedIter } else { - streamedIter.filter(r => !buildRows.isEmpty) + Iterator.empty } } } @@ -199,6 +204,7 @@ case class BroadcastNestedLoopJoin( * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft + * Anti with BuildLeft */ private def defaultJoin(relation: Broadcast[Array[InternalRow]]): RDD[InternalRow] = { /** All rows that either match both-way, or rows from streamed joined with nulls. 
*/ @@ -236,7 +242,27 @@ case class BroadcastNestedLoopJoin( } i += 1 } - return sparkContext.makeRDD(buf.toSeq) + return sparkContext.makeRDD(buf) + } + + val notMatchedBroadcastRows: Seq[InternalRow] = { + val nulls = new GenericMutableRow(streamed.output.size) + val buf: CompactBuffer[InternalRow] = new CompactBuffer() + var i = 0 + val buildRows = relation.value + val joinedRow = new JoinedRow + joinedRow.withLeft(nulls) + while (i < buildRows.length) { + if (!matchedBroadcastRows.get(i)) { + buf += joinedRow.withRight(buildRows(i)).copy() + } + i += 1 + } + buf + } + + if (joinType == LeftAnti) { + return sparkContext.makeRDD(notMatchedBroadcastRows) } val matchedStreamRows = streamRdd.mapPartitionsInternal { streamedIter => @@ -264,22 +290,6 @@ case class BroadcastNestedLoopJoin( } } - val notMatchedBroadcastRows: Seq[InternalRow] = { - val nulls = new GenericMutableRow(streamed.output.size) - val buf: CompactBuffer[InternalRow] = new CompactBuffer() - var i = 0 - val buildRows = relation.value - val joinedRow = new JoinedRow - joinedRow.withLeft(nulls) - while (i < buildRows.length) { - if (!matchedBroadcastRows.get(i)) { - buf += joinedRow.withRight(buildRows(i)).copy() - } - i += 1 - } - buf.toSeq - } - sparkContext.union( matchedStreamRows, sparkContext.makeRDD(notMatchedBroadcastRows) @@ -295,13 +305,16 @@ case class BroadcastNestedLoopJoin( case (LeftOuter, BuildRight) | (RightOuter, BuildLeft) => outerJoin(broadcastedRelation) case (LeftSemi, BuildRight) => - leftSemiJoin(broadcastedRelation) + leftExistenceJoin(broadcastedRelation, exists = true) + case (LeftAnti, BuildRight) => + leftExistenceJoin(broadcastedRelation, exists = false) case _ => /** * LeftOuter with BuildLeft * RightOuter with BuildRight * FullOuter * LeftSemi with BuildLeft + * Anti with BuildLeft */ defaultJoin(broadcastedRelation) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index b7c0f3e7d1..8f45d57126 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -47,7 +47,7 @@ trait HashJoin { left.output ++ right.output.map(_.withNullability(true)) case RightOuter => left.output.map(_.withNullability(true)) ++ right.output - case LeftSemi => + case LeftExistence(_) => left.output case x => throw new IllegalArgumentException(s"HashJoin should not take $x as the JoinType") @@ -197,6 +197,20 @@ trait HashJoin { } } + private def antiJoin( + streamIter: Iterator[InternalRow], + hashedRelation: HashedRelation): Iterator[InternalRow] = { + val joinKeys = streamSideKeyGenerator() + val joinedRow = new JoinedRow + streamIter.filter { current => + val key = joinKeys(current) + lazy val buildIter = hashedRelation.get(key) + key.anyNull || buildIter == null || (condition.isDefined && !buildIter.exists { + row => boundCondition(joinedRow(current, row)) + }) + } + } + protected def join( streamedIter: Iterator[InternalRow], hashed: HashedRelation, @@ -209,6 +223,8 @@ trait HashJoin { outerJoin(streamedIter, hashed) case LeftSemi => semiJoin(streamedIter, hashed) + case LeftAnti => + antiJoin(streamedIter, hashed) case x => throw new IllegalArgumentException( s"BroadcastHashJoin should not take $x as the JoinType") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index c63faacf33..bf86096379 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -45,6 +45,7 @@ case class ShuffledHashJoin( override def outputPartitioning: Partitioning = joinType match { case Inner => PartitioningCollection(Seq(left.outputPartitioning, right.outputPartitioning)) + case LeftAnti => left.outputPartitioning case LeftSemi => left.outputPartitioning case LeftOuter => left.outputPartitioning case RightOuter => right.outputPartitioning diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index a5a4ff13de..a87a41c126 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -41,7 +41,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { assert(planned.size === 1) } - def assertJoin(sqlString: String, c: Class[_]): Any = { + def assertJoin(pair: (String, Class[_])): Any = { + val (sqlString, c) = pair val df = sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { @@ -53,8 +54,8 @@ class JoinSuite extends QueryTest with SharedSQLContext { } assert(operators.size === 1) - if (operators(0).getClass() != c) { - fail(s"$sqlString expected operator: $c, but got ${operators(0)}\n physical: \n$physical") + if (operators.head.getClass != c) { + fail(s"$sqlString expected operator: $c, but got ${operators.head}\n physical: \n$physical") } } @@ -93,8 +94,10 @@ class JoinSuite extends QueryTest with SharedSQLContext { ("SELECT * FROM testData right JOIN testData2 ON (key * a != key + a)", classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData full JOIN testData2 ON (key * a != key + a)", - classOf[BroadcastNestedLoopJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + classOf[BroadcastNestedLoopJoin]), + ("SELECT * FROM testData ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT ANTI JOIN testData2", classOf[BroadcastNestedLoopJoin]) + ).foreach(assertJoin) } } @@ -114,7 +117,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a where key = 2", classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) sql("UNCACHE TABLE testData") } @@ -129,7 +132,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastHashJoin]), ("SELECT * FROM testData right join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) sql("UNCACHE TABLE testData") } @@ -419,25 +422,22 @@ class JoinSuite extends QueryTest with SharedSQLContext { Row(null, 10)) } - test("broadcasted left semi join operator selection") { + test("broadcasted existence join operator selection") { sqlContext.cacheManager.clearCache() sql("CACHE TABLE testData") withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", - classOf[BroadcastHashJoin]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) - } + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON 
key = a", classOf[BroadcastHashJoin]), + ("SELECT * FROM testData ANT JOIN testData2 ON key = a", classOf[BroadcastHashJoin]) + ).foreach(assertJoin) } withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { Seq( - ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]) - ).foreach { - case (query, joinClass) => assertJoin(query, joinClass) - } + ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]), + ("SELECT * FROM testData LEFT ANTI JOIN testData2 ON key = a", classOf[ShuffledHashJoin]) + ).foreach(assertJoin) } sql("UNCACHE TABLE testData") @@ -489,7 +489,7 @@ class JoinSuite extends QueryTest with SharedSQLContext { classOf[BroadcastNestedLoopJoin]), ("SELECT * FROM testData full JOIN testData2 WHERE (key * a != key + a)", classOf[BroadcastNestedLoopJoin]) - ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } + ).foreach(assertJoin) checkAnswer( sql( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala new file mode 100644 index 0000000000..8cdfa8afd0 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/ExistenceJoinSuite.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.joins + +import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftAnti, LeftSemi} +import org.apache.spark.sql.catalyst.plans.logical.Join +import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} +import org.apache.spark.sql.execution.exchange.EnsureRequirements +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} + +class ExistenceJoinSuite extends SparkPlanTest with SharedSQLContext { + + private lazy val left = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(1, 2.0), + Row(1, 2.0), + Row(2, 1.0), + Row(2, 1.0), + Row(3, 3.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("a", IntegerType).add("b", DoubleType)) + + private lazy val right = sqlContext.createDataFrame( + sparkContext.parallelize(Seq( + Row(2, 3.0), + Row(2, 3.0), + Row(3, 2.0), + Row(4, 1.0), + Row(null, null), + Row(null, 5.0), + Row(6, null) + )), new StructType().add("c", IntegerType).add("d", DoubleType)) + + private lazy val condition = { + And((left.col("a") === right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + private lazy val conditionNEQ = { + And((left.col("a") < right.col("c")).expr, + LessThan(left.col("b").expr, right.col("d").expr)) + } + + // Note: the input dataframes and expression must be evaluated lazily because + // the SQLContext should be used only within a test to keep SQL tests stable + private def testExistenceJoin( + testName: String, + joinType: JoinType, + leftRows: => DataFrame, + rightRows: => DataFrame, + condition: => Expression, + expectedAnswer: Seq[Row]): Unit = { + + def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { + val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) + ExtractEquiJoinKeys.unapply(join) + } + + test(s"$testName using ShuffledHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + ShuffledHashJoin( + leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), + expectedAnswer, + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastHashJoin") { + extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) => + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastHashJoin( + leftKeys, rightKeys, joinType, BuildRight, boundCondition, left, right)), + expectedAnswer, + sortAnswers = true) + } + } + } + + test(s"$testName using BroadcastNestedLoopJoin build left") { + withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastNestedLoopJoin(left, right, BuildLeft, joinType, Some(condition))), + expectedAnswer, + sortAnswers = true) + } + } + + test(s"$testName using BroadcastNestedLoopJoin build right") { + 
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { + checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => + EnsureRequirements(left.sqlContext.sessionState.conf).apply( + BroadcastNestedLoopJoin(left, right, BuildRight, joinType, Some(condition))), + expectedAnswer, + sortAnswers = true) + } + } + } + + testExistenceJoin( + "basic test for left semi join", + LeftSemi, + left, + right, + condition, + Seq(Row(2, 1.0), Row(2, 1.0))) + + testExistenceJoin( + "basic test for left semi non equal join", + LeftSemi, + left, + right, + conditionNEQ, + Seq(Row(1, 2.0), Row(1, 2.0), Row(2, 1.0), Row(2, 1.0))) + + testExistenceJoin( + "basic test for anti join", + LeftAnti, + left, + right, + condition, + Seq(Row(1, 2.0), Row(1, 2.0), Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) + + testExistenceJoin( + "basic test for anti non equal join", + LeftAnti, + left, + right, + conditionNEQ, + Seq(Row(3, 3.0), Row(6, null), Row(null, 5.0), Row(null, null))) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala deleted file mode 100644 index 985a96f684..0000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala +++ /dev/null @@ -1,129 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.joins - -import org.apache.spark.sql.{DataFrame, Row} -import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan} -import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.{Inner, LeftSemi} -import org.apache.spark.sql.catalyst.plans.logical.Join -import org.apache.spark.sql.execution.{SparkPlan, SparkPlanTest} -import org.apache.spark.sql.execution.exchange.EnsureRequirements -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType} - -class SemiJoinSuite extends SparkPlanTest with SharedSQLContext { - - private lazy val left = sqlContext.createDataFrame( - sparkContext.parallelize(Seq( - Row(1, 2.0), - Row(1, 2.0), - Row(2, 1.0), - Row(2, 1.0), - Row(3, 3.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("a", IntegerType).add("b", DoubleType)) - - private lazy val right = sqlContext.createDataFrame( - sparkContext.parallelize(Seq( - Row(2, 3.0), - Row(2, 3.0), - Row(3, 2.0), - Row(4, 1.0), - Row(null, null), - Row(null, 5.0), - Row(6, null) - )), new StructType().add("c", IntegerType).add("d", DoubleType)) - - private lazy val condition = { - And((left.col("a") === right.col("c")).expr, - LessThan(left.col("b").expr, right.col("d").expr)) - } - - // Note: the input dataframes and expression must be evaluated lazily because - // the SQLContext should be used only within a test to keep SQL tests stable - private def testLeftSemiJoin( - testName: String, - leftRows: => DataFrame, - rightRows: => DataFrame, - condition: => Expression, - expectedAnswer: Seq[Product]): Unit = { - - def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = { - val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition)) - ExtractEquiJoinKeys.unapply(join) - } - - test(s"$testName using ShuffledHashJoin") { - extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - EnsureRequirements(left.sqlContext.sessionState.conf).apply( - ShuffledHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - test(s"$testName using BroadcastHashJoin") { - extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) => - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastHashJoin( - leftKeys, rightKeys, LeftSemi, BuildRight, boundCondition, left, right), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - } - - test(s"$testName using BroadcastNestedLoopJoin build left") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoin(left, right, BuildLeft, LeftSemi, Some(condition)), - expectedAnswer.map(Row.fromTuple), - sortAnswers = true) - } - } - - test(s"$testName using BroadcastNestedLoopJoin build right") { - withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") { - checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) => - BroadcastNestedLoopJoin(left, right, BuildRight, LeftSemi, Some(condition)), - expectedAnswer.map(Row.fromTuple), 
- sortAnswers = true) - } - } - } - - testLeftSemiJoin( - "basic test", - left, - right, - condition, - Seq( - (2, 1.0), - (2, 1.0) - ) - ) -} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala index cff24e28fd..b992fda18c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala @@ -92,7 +92,7 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx) DataSinks, Scripts, Aggregation, - LeftSemiJoin, + ExistenceJoin, EquiJoinSelection, BasicOperators, BroadcastNestedLoop, -- cgit v1.2.3 From 21d5ca128bf3afd5c2d4c7fcc56240e28443474f Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Wed, 6 Apr 2016 19:33:48 -0700 Subject: [SPARK-14134][CORE] Change the package name used for shading classes. The current package name uses a dash, which is a little weird but seemed to work. That is, until a new test tried to mock a class that references one of those shaded types, and then things started failing. Most changes are just noise to fix the logging configs. For reference, SPARK-8815 also raised this issue, although at the time it did not cause any issues in Spark, so it was not addressed. Author: Marcelo Vanzin Closes #11941 from vanzin/SPARK-14134. --- common/network-yarn/pom.xml | 4 ++-- conf/log4j.properties.template | 4 ++-- core/src/main/resources/org/apache/spark/log4j-defaults.properties | 4 ++-- core/src/test/resources/log4j.properties | 3 +-- external/flume-sink/src/test/resources/log4j.properties | 2 +- external/flume/src/test/resources/log4j.properties | 2 +- external/java8-tests/src/test/resources/log4j.properties | 2 +- external/kafka/src/test/resources/log4j.properties | 2 +- external/kinesis-asl/src/main/resources/log4j.properties | 4 ++-- external/kinesis-asl/src/test/resources/log4j.properties | 2 +- graphx/src/test/resources/log4j.properties | 3 +-- launcher/src/test/resources/log4j.properties | 3 +-- mllib/src/test/resources/log4j.properties | 2 +- pom.xml | 7 +++++-- repl/src/test/resources/log4j.properties | 2 +- sql/catalyst/src/test/resources/log4j.properties | 3 +-- streaming/src/test/resources/log4j.properties | 2 +- yarn/src/test/resources/log4j.properties | 2 +- .../scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala | 2 +- 19 files changed, 27 insertions(+), 28 deletions(-) (limited to 'sql/catalyst') diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index 3cb44324f2..bc83ef24c3 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -36,7 +36,7 @@ provided ${project.build.directory}/scala-${scala.binary.version}/spark-${project.version}-yarn-shuffle.jar - org/spark-project/ + org/spark_project/ @@ -91,7 +91,7 @@ com.fasterxml.jackson - org.spark-project.com.fasterxml.jackson + ${spark.shade.packageName}.com.fasterxml.jackson com.fasterxml.jackson.** diff --git a/conf/log4j.properties.template b/conf/log4j.properties.template index 9809b0c828..ec1aa187df 100644 --- a/conf/log4j.properties.template +++ b/conf/log4j.properties.template @@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: log4j.logger.org.apache.spark.repl.Main=WARN # Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR 
+log4j.logger.org.spark_project.jetty=WARN +log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO log4j.logger.org.apache.parquet=ERROR diff --git a/core/src/main/resources/org/apache/spark/log4j-defaults.properties b/core/src/main/resources/org/apache/spark/log4j-defaults.properties index 0750488e4a..89a7963a86 100644 --- a/core/src/main/resources/org/apache/spark/log4j-defaults.properties +++ b/core/src/main/resources/org/apache/spark/log4j-defaults.properties @@ -28,8 +28,8 @@ log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: log4j.logger.org.apache.spark.repl.Main=WARN # Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark_project.jetty=WARN +log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO diff --git a/core/src/test/resources/log4j.properties b/core/src/test/resources/log4j.properties index a54d27de91..fb9d9851cb 100644 --- a/core/src/test/resources/log4j.properties +++ b/core/src/test/resources/log4j.properties @@ -33,5 +33,4 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%t: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/flume-sink/src/test/resources/log4j.properties b/external/flume-sink/src/test/resources/log4j.properties index 42df8792f1..1e3f163f95 100644 --- a/external/flume-sink/src/test/resources/log4j.properties +++ b/external/flume-sink/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/flume/src/test/resources/log4j.properties b/external/flume/src/test/resources/log4j.properties index 75e3b53a09..fd51f8faf5 100644 --- a/external/flume/src/test/resources/log4j.properties +++ b/external/flume/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/java8-tests/src/test/resources/log4j.properties b/external/java8-tests/src/test/resources/log4j.properties index edbecdae92..3706a6e361 100644 --- a/external/java8-tests/src/test/resources/log4j.properties +++ b/external/java8-tests/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN 
+log4j.logger.org.spark_project.jetty=WARN diff --git a/external/kafka/src/test/resources/log4j.properties b/external/kafka/src/test/resources/log4j.properties index 75e3b53a09..fd51f8faf5 100644 --- a/external/kafka/src/test/resources/log4j.properties +++ b/external/kafka/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/external/kinesis-asl/src/main/resources/log4j.properties b/external/kinesis-asl/src/main/resources/log4j.properties index 6cdc9286c5..8118d12c5d 100644 --- a/external/kinesis-asl/src/main/resources/log4j.properties +++ b/external/kinesis-asl/src/main/resources/log4j.properties @@ -31,7 +31,7 @@ log4j.appender.console.layout=org.apache.log4j.PatternLayout log4j.appender.console.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss} %p %c{1}: %m%n # Settings to quiet third party logs that are too verbose -log4j.logger.org.spark-project.jetty=WARN -log4j.logger.org.spark-project.jetty.util.component.AbstractLifeCycle=ERROR +log4j.logger.org.spark_project.jetty=WARN +log4j.logger.org.spark_project.jetty.util.component.AbstractLifeCycle=ERROR log4j.logger.org.apache.spark.repl.SparkIMain$exprTyper=INFO log4j.logger.org.apache.spark.repl.SparkILoop$SparkILoopInterpreter=INFO \ No newline at end of file diff --git a/external/kinesis-asl/src/test/resources/log4j.properties b/external/kinesis-asl/src/test/resources/log4j.properties index edbecdae92..3706a6e361 100644 --- a/external/kinesis-asl/src/test/resources/log4j.properties +++ b/external/kinesis-asl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/graphx/src/test/resources/log4j.properties b/graphx/src/test/resources/log4j.properties index eb3b1999eb..3706a6e361 100644 --- a/graphx/src/test/resources/log4j.properties +++ b/graphx/src/test/resources/log4j.properties @@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/launcher/src/test/resources/log4j.properties b/launcher/src/test/resources/log4j.properties index c64b1565e1..744c456cb2 100644 --- a/launcher/src/test/resources/log4j.properties +++ b/launcher/src/test/resources/log4j.properties @@ -30,5 +30,4 @@ log4j.appender.childproc.layout=org.apache.log4j.PatternLayout log4j.appender.childproc.layout.ConversionPattern=%t: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/mllib/src/test/resources/log4j.properties b/mllib/src/test/resources/log4j.properties index 75e3b53a09..fd51f8faf5 100644 --- a/mllib/src/test/resources/log4j.properties +++ 
b/mllib/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/pom.xml b/pom.xml index 984b2859ef..66a34e4bdf 100644 --- a/pom.xml +++ b/pom.xml @@ -182,6 +182,9 @@ ${java.home} + + org.spark_project + ${project.build.directory}/scala-${scala.binary.version}/jars @@ -2204,14 +2207,14 @@ org.eclipse.jetty - org.spark-project.jetty + ${spark.shade.packageName}.jetty org.eclipse.jetty.** com.google.common - org.spark-project.guava + ${spark.shade.packageName}.guava diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index e2ee9c963a..7665bd5e7c 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -24,4 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/sql/catalyst/src/test/resources/log4j.properties b/sql/catalyst/src/test/resources/log4j.properties index eb3b1999eb..3706a6e361 100644 --- a/sql/catalyst/src/test/resources/log4j.properties +++ b/sql/catalyst/src/test/resources/log4j.properties @@ -24,5 +24,4 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN -org.spark-project.jetty.LEVEL=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/streaming/src/test/resources/log4j.properties b/streaming/src/test/resources/log4j.properties index 75e3b53a09..fd51f8faf5 100644 --- a/streaming/src/test/resources/log4j.properties +++ b/streaming/src/test/resources/log4j.properties @@ -24,5 +24,5 @@ log4j.appender.file.layout=org.apache.log4j.PatternLayout log4j.appender.file.layout.ConversionPattern=%d{yy/MM/dd HH:mm:ss.SSS} %t %p %c{1}: %m%n # Ignore messages below warning level from Jetty, because it's a bit verbose -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/yarn/src/test/resources/log4j.properties b/yarn/src/test/resources/log4j.properties index 6b9a799954..d13454d5ae 100644 --- a/yarn/src/test/resources/log4j.properties +++ b/yarn/src/test/resources/log4j.properties @@ -28,4 +28,4 @@ log4j.logger.com.sun.jersey=WARN log4j.logger.org.apache.hadoop=WARN log4j.logger.org.eclipse.jetty=WARN log4j.logger.org.mortbay=WARN -log4j.logger.org.spark-project.jetty=WARN +log4j.logger.org.spark_project.jetty=WARN diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 2f3a31cb04..9c3b18e4ec 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -53,7 +53,7 @@ abstract class BaseYarnClusterSuite |log4j.logger.org.apache.hadoop=WARN |log4j.logger.org.eclipse.jetty=WARN |log4j.logger.org.mortbay=WARN - 
|log4j.logger.org.spark-project.jetty=WARN + |log4j.logger.org.spark_project.jetty=WARN """.stripMargin private var yarnCluster: MiniYARNCluster = _ -- cgit v1.2.3 From e11aa9ec5c3cdcd8ca08d2486a7208840ad77bf8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 7 Apr 2016 00:46:57 -0700 Subject: [SPARK-14452][SQL] Explicit APIs in Scala for specifying encoders ## What changes were proposed in this pull request? The Scala Dataset public API currently only allows users to specify encoders through SQLContext.implicits. This is OK but sometimes people want to explicitly get encoders without a SQLContext (e.g. Aggregator implementations). This patch adds public APIs to Encoders class for getting Scala encoders. ## How was this patch tested? None - I will update test cases once https://github.com/apache/spark/pull/12231 is merged. Author: Reynold Xin Closes #12232 from rxin/SPARK-14452. --- .../main/scala/org/apache/spark/sql/Encoder.scala | 231 +-------------- .../main/scala/org/apache/spark/sql/Encoders.scala | 314 +++++++++++++++++++++ .../scala/org/apache/spark/sql/SQLImplicits.scala | 18 +- 3 files changed, 327 insertions(+), 236 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala index e0bfe3c32f..ffa694fcdc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala @@ -17,22 +17,20 @@ package org.apache.spark.sql -import java.lang.reflect.Modifier - import scala.annotation.implicitNotFound -import scala.reflect.{classTag, ClassTag} +import scala.reflect.ClassTag import org.apache.spark.annotation.Experimental -import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} -import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer} import org.apache.spark.sql.types._ + /** * :: Experimental :: * Used to convert a JVM object of type `T` to and from the internal Spark SQL representation. * * == Scala == - * Encoders are generally created automatically through implicits from a `SQLContext`. + * Encoders are generally created automatically through implicits from a `SQLContext`, or can be + * explicitly created by calling static methods on [[Encoders]]. * * {{{ * import sqlContext.implicits._ @@ -81,224 +79,3 @@ trait Encoder[T] extends Serializable { /** A ClassTag that can be used to construct and Array to contain a collection of `T`. */ def clsTag: ClassTag[T] } - -/** - * :: Experimental :: - * Methods for creating an [[Encoder]]. - * - * @since 1.6.0 - */ -@Experimental -object Encoders { - - /** - * An encoder for nullable boolean type. - * @since 1.6.0 - */ - def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() - - /** - * An encoder for nullable byte type. - * @since 1.6.0 - */ - def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() - - /** - * An encoder for nullable short type. - * @since 1.6.0 - */ - def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() - - /** - * An encoder for nullable int type. - * @since 1.6.0 - */ - def INT: Encoder[java.lang.Integer] = ExpressionEncoder() - - /** - * An encoder for nullable long type. - * @since 1.6.0 - */ - def LONG: Encoder[java.lang.Long] = ExpressionEncoder() - - /** - * An encoder for nullable float type. 
- * @since 1.6.0 - */ - def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() - - /** - * An encoder for nullable double type. - * @since 1.6.0 - */ - def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() - - /** - * An encoder for nullable string type. - * @since 1.6.0 - */ - def STRING: Encoder[java.lang.String] = ExpressionEncoder() - - /** - * An encoder for nullable decimal type. - * @since 1.6.0 - */ - def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() - - /** - * An encoder for nullable date type. - * @since 1.6.0 - */ - def DATE: Encoder[java.sql.Date] = ExpressionEncoder() - - /** - * An encoder for nullable timestamp type. - * @since 1.6.0 - */ - def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() - - /** - * An encoder for arrays of bytes. - * @since 1.6.1 - */ - def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() - - /** - * Creates an encoder for Java Bean of type T. - * - * T must be publicly accessible. - * - * supported types for java bean field: - * - primitive types: boolean, int, double, etc. - * - boxed types: Boolean, Integer, Double, etc. - * - String - * - java.math.BigDecimal - * - time related: java.sql.Date, java.sql.Timestamp - * - collection types: only array and java.util.List currently, map support is in progress - * - nested java bean. - * - * @since 1.6.0 - */ - def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) - - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - * - * T must be publicly accessible. - * - * @since 1.6.0 - */ - def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) - - /** - * Creates an encoder that serializes objects of type T using Kryo. - * This encoder maps T into a single byte array (binary) field. - * - * T must be publicly accessible. - * - * @since 1.6.0 - */ - def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) - - /** - * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java - * serialization. This encoder maps T into a single byte array (binary) field. - * - * Note that this is extremely inefficient and should only be used as the last resort. - * - * T must be publicly accessible. - * - * @since 1.6.0 - */ - def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) - - /** - * Creates an encoder that serializes objects of type T using generic Java serialization. - * This encoder maps T into a single byte array (binary) field. - * - * Note that this is extremely inefficient and should only be used as the last resort. - * - * T must be publicly accessible. - * - * @since 1.6.0 - */ - def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) - - /** Throws an exception if T is not a public class. */ - private def validatePublicClass[T: ClassTag](): Unit = { - if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { - throw new UnsupportedOperationException( - s"${classTag[T].runtimeClass.getName} is not a public class. " + - "Only public classes are supported.") - } - } - - /** A way to construct encoders using generic serializers. 
*/ - private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { - if (classTag[T].runtimeClass.isPrimitive) { - throw new UnsupportedOperationException("Primitive types are not supported.") - } - - validatePublicClass[T]() - - ExpressionEncoder[T]( - schema = new StructType().add("value", BinaryType), - flat = true, - serializer = Seq( - EncodeUsingSerializer( - BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), - deserializer = - DecodeUsingSerializer[T]( - BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), - clsTag = classTag[T] - ) - } - - /** - * An encoder for 2-ary tuples. - * @since 1.6.0 - */ - def tuple[T1, T2]( - e1: Encoder[T1], - e2: Encoder[T2]): Encoder[(T1, T2)] = { - ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) - } - - /** - * An encoder for 3-ary tuples. - * @since 1.6.0 - */ - def tuple[T1, T2, T3]( - e1: Encoder[T1], - e2: Encoder[T2], - e3: Encoder[T3]): Encoder[(T1, T2, T3)] = { - ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) - } - - /** - * An encoder for 4-ary tuples. - * @since 1.6.0 - */ - def tuple[T1, T2, T3, T4]( - e1: Encoder[T1], - e2: Encoder[T2], - e3: Encoder[T3], - e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { - ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) - } - - /** - * An encoder for 5-ary tuples. - * @since 1.6.0 - */ - def tuple[T1, T2, T3, T4, T5]( - e1: Encoder[T1], - e2: Encoder[T2], - e3: Encoder[T3], - e4: Encoder[T4], - e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { - ExpressionEncoder.tuple( - encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5)) - } -} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala new file mode 100644 index 0000000000..3f4df704db --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoders.scala @@ -0,0 +1,314 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql + +import java.lang.reflect.Modifier + +import scala.reflect.{classTag, ClassTag} +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.expressions.{BoundReference, DecodeUsingSerializer, EncodeUsingSerializer} +import org.apache.spark.sql.types._ + +/** + * :: Experimental :: + * Methods for creating an [[Encoder]]. + * + * @since 1.6.0 + */ +@Experimental +object Encoders { + + /** + * An encoder for nullable boolean type. + * The Scala primitive encoder is available as [[scalaBoolean]]. 
+ * @since 1.6.0 + */ + def BOOLEAN: Encoder[java.lang.Boolean] = ExpressionEncoder() + + /** + * An encoder for nullable byte type. + * The Scala primitive encoder is available as [[scalaByte]]. + * @since 1.6.0 + */ + def BYTE: Encoder[java.lang.Byte] = ExpressionEncoder() + + /** + * An encoder for nullable short type. + * The Scala primitive encoder is available as [[scalaShort]]. + * @since 1.6.0 + */ + def SHORT: Encoder[java.lang.Short] = ExpressionEncoder() + + /** + * An encoder for nullable int type. + * The Scala primitive encoder is available as [[scalaInt]]. + * @since 1.6.0 + */ + def INT: Encoder[java.lang.Integer] = ExpressionEncoder() + + /** + * An encoder for nullable long type. + * The Scala primitive encoder is available as [[scalaLong]]. + * @since 1.6.0 + */ + def LONG: Encoder[java.lang.Long] = ExpressionEncoder() + + /** + * An encoder for nullable float type. + * The Scala primitive encoder is available as [[scalaFloat]]. + * @since 1.6.0 + */ + def FLOAT: Encoder[java.lang.Float] = ExpressionEncoder() + + /** + * An encoder for nullable double type. + * The Scala primitive encoder is available as [[scalaDouble]]. + * @since 1.6.0 + */ + def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder() + + /** + * An encoder for nullable string type. + * + * @since 1.6.0 + */ + def STRING: Encoder[java.lang.String] = ExpressionEncoder() + + /** + * An encoder for nullable decimal type. + * + * @since 1.6.0 + */ + def DECIMAL: Encoder[java.math.BigDecimal] = ExpressionEncoder() + + /** + * An encoder for nullable date type. + * + * @since 1.6.0 + */ + def DATE: Encoder[java.sql.Date] = ExpressionEncoder() + + /** + * An encoder for nullable timestamp type. + * + * @since 1.6.0 + */ + def TIMESTAMP: Encoder[java.sql.Timestamp] = ExpressionEncoder() + + /** + * An encoder for arrays of bytes. + * + * @since 1.6.1 + */ + def BINARY: Encoder[Array[Byte]] = ExpressionEncoder() + + /** + * Creates an encoder for Java Bean of type T. + * + * T must be publicly accessible. + * + * supported types for java bean field: + * - primitive types: boolean, int, double, etc. + * - boxed types: Boolean, Integer, Double, etc. + * - String + * - java.math.BigDecimal + * - time related: java.sql.Date, java.sql.Timestamp + * - collection types: only array and java.util.List currently, map support is in progress + * - nested java bean. + * + * @since 1.6.0 + */ + def bean[T](beanClass: Class[T]): Encoder[T] = ExpressionEncoder.javaBean(beanClass) + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def kryo[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = true) + + /** + * Creates an encoder that serializes objects of type T using Kryo. + * This encoder maps T into a single byte array (binary) field. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def kryo[T](clazz: Class[T]): Encoder[T] = kryo(ClassTag[T](clazz)) + + /** + * (Scala-specific) Creates an encoder that serializes objects of type T using generic Java + * serialization. This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. 
+ * + * @since 1.6.0 + */ + def javaSerialization[T: ClassTag]: Encoder[T] = genericSerializer(useKryo = false) + + /** + * Creates an encoder that serializes objects of type T using generic Java serialization. + * This encoder maps T into a single byte array (binary) field. + * + * Note that this is extremely inefficient and should only be used as the last resort. + * + * T must be publicly accessible. + * + * @since 1.6.0 + */ + def javaSerialization[T](clazz: Class[T]): Encoder[T] = javaSerialization(ClassTag[T](clazz)) + + /** Throws an exception if T is not a public class. */ + private def validatePublicClass[T: ClassTag](): Unit = { + if (!Modifier.isPublic(classTag[T].runtimeClass.getModifiers)) { + throw new UnsupportedOperationException( + s"${classTag[T].runtimeClass.getName} is not a public class. " + + "Only public classes are supported.") + } + } + + /** A way to construct encoders using generic serializers. */ + private def genericSerializer[T: ClassTag](useKryo: Boolean): Encoder[T] = { + if (classTag[T].runtimeClass.isPrimitive) { + throw new UnsupportedOperationException("Primitive types are not supported.") + } + + validatePublicClass[T]() + + ExpressionEncoder[T]( + schema = new StructType().add("value", BinaryType), + flat = true, + serializer = Seq( + EncodeUsingSerializer( + BoundReference(0, ObjectType(classOf[AnyRef]), nullable = true), kryo = useKryo)), + deserializer = + DecodeUsingSerializer[T]( + BoundReference(0, BinaryType, nullable = true), classTag[T], kryo = useKryo), + clsTag = classTag[T] + ) + } + + /** + * An encoder for 2-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2]( + e1: Encoder[T1], + e2: Encoder[T2]): Encoder[(T1, T2)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2)) + } + + /** + * An encoder for 3-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2, T3]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3]): Encoder[(T1, T2, T3)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3)) + } + + /** + * An encoder for 4-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2, T3, T4]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = { + ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4)) + } + + /** + * An encoder for 5-ary tuples. + * + * @since 1.6.0 + */ + def tuple[T1, T2, T3, T4, T5]( + e1: Encoder[T1], + e2: Encoder[T2], + e3: Encoder[T3], + e4: Encoder[T4], + e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = { + ExpressionEncoder.tuple( + encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), encoderFor(e5)) + } + + /** + * An encoder for Scala's product type (tuples, case classes, etc). + * @since 2.0.0 + */ + def product[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive int type. + * @since 2.0.0 + */ + def scalaInt: Encoder[Int] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive long type. + * @since 2.0.0 + */ + def scalaLong: Encoder[Long] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive double type. + * @since 2.0.0 + */ + def scalaDouble: Encoder[Double] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive float type. + * @since 2.0.0 + */ + def scalaFloat: Encoder[Float] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive byte type. + * @since 2.0.0 + */ + def scalaByte: Encoder[Byte] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive short type. 
+ * @since 2.0.0 + */ + def scalaShort: Encoder[Short] = ExpressionEncoder() + + /** + * An encoder for Scala's primitive boolean type. + * @since 2.0.0 + */ + def scalaBoolean: Encoder[Boolean] = ExpressionEncoder() + +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index c35a969bf0..ad69e23540 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -44,33 +44,33 @@ abstract class SQLImplicits { } /** @since 1.6.0 */ - implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = ExpressionEncoder() + implicit def newProductEncoder[T <: Product : TypeTag]: Encoder[T] = Encoders.product[T] // Primitives /** @since 1.6.0 */ - implicit def newIntEncoder: Encoder[Int] = ExpressionEncoder() + implicit def newIntEncoder: Encoder[Int] = Encoders.scalaInt /** @since 1.6.0 */ - implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder() + implicit def newLongEncoder: Encoder[Long] = Encoders.scalaLong /** @since 1.6.0 */ - implicit def newDoubleEncoder: Encoder[Double] = ExpressionEncoder() + implicit def newDoubleEncoder: Encoder[Double] = Encoders.scalaDouble /** @since 1.6.0 */ - implicit def newFloatEncoder: Encoder[Float] = ExpressionEncoder() + implicit def newFloatEncoder: Encoder[Float] = Encoders.scalaFloat /** @since 1.6.0 */ - implicit def newByteEncoder: Encoder[Byte] = ExpressionEncoder() + implicit def newByteEncoder: Encoder[Byte] = Encoders.scalaByte /** @since 1.6.0 */ - implicit def newShortEncoder: Encoder[Short] = ExpressionEncoder() + implicit def newShortEncoder: Encoder[Short] = Encoders.scalaShort /** @since 1.6.0 */ - implicit def newBooleanEncoder: Encoder[Boolean] = ExpressionEncoder() + implicit def newBooleanEncoder: Encoder[Boolean] = Encoders.scalaBoolean /** @since 1.6.0 */ - implicit def newStringEncoder: Encoder[String] = ExpressionEncoder() + implicit def newStringEncoder: Encoder[String] = Encoders.STRING // Seqs -- cgit v1.2.3 From aa852215f82876977d164f371627e894e86baacc Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 7 Apr 2016 11:51:34 -0700 Subject: [SPARK-12740] [SPARK-13932] support grouping()/grouping_id() in having/order clause ## What changes were proposed in this pull request? This PR brings support for using grouping()/grouping_id() in the HAVING/ORDER BY clause. The resolved grouping()/grouping_id() will be replaced by the unresolved "spark_grouping_id" virtual attribute, then resolved by ResolveMissingReferences. This PR also fixes HAVING clauses that access a grouping column that is not present in the SELECT clause, for example: ```sql select count(1) from (select 1 as a) t group by a having a > 0 ``` ## How was this patch tested? Added new tests. Author: Davies Liu Closes #12235 from davies/grouping_having.
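For illustration, a minimal end-to-end sketch of what this change enables. It assumes a spark-shell style `sqlContext` with its implicits in scope; the `courseSales` name, columns, and query shape mirror the new SQLQuerySuite tests below, and the sample rows are made up:

```scala
// Minimal sketch (assumes a spark-shell style sqlContext); table name, columns and
// the shape of the query mirror the new SQLQuerySuite tests, the rows are illustrative.
import sqlContext.implicits._

Seq(("dotNET", 2012, 10000), ("Java", 2012, 20000),
    ("dotNET", 2013, 48000), ("Java", 2013, 30000))
  .toDF("course", "year", "earnings")
  .registerTempTable("courseSales")

// grouping()/grouping_id() can now appear in HAVING and ORDER BY; they are rewritten
// to the internal spark_grouping_id attribute and resolved by ResolveMissingReferences.
sqlContext.sql(
  """SELECT course, year, grouping(course), grouping(year)
    |FROM courseSales
    |GROUP BY CUBE(course, year)
    |HAVING grouping_id(course, year) > 0
    |ORDER BY grouping_id(course, year), course, year
  """.stripMargin).show()
```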
--- .../spark/sql/catalyst/analysis/Analyzer.scala | 181 ++++++++++++++------- .../catalyst/expressions/namedExpressions.scala | 4 +- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 82 ++++++++++ 3 files changed, 211 insertions(+), 56 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bc8cf4e78a..7bcba421fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -87,7 +87,7 @@ class Analyzer( ResolveGroupingAnalytics :: ResolvePivot :: ResolveOrdinalInOrderByAndGroupBy :: - ResolveSortReferences :: + ResolveMissingReferences :: ResolveGenerate :: ResolveFunctions :: ResolveAliases :: @@ -228,21 +228,56 @@ class Analyzer( Seq.tabulate(1 << c.groupByExprs.length)(i => i) } - private def hasGroupingId(expr: Seq[Expression]): Boolean = { - expr.exists(_.collectFirst { - case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.groupingIdName) => u - }.isDefined) + private def hasGroupingAttribute(expr: Expression): Boolean = { + expr.collectFirst { + case u: UnresolvedAttribute if resolver(u.name, VirtualColumn.hiveGroupingIdName) => u + }.isDefined } - def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + private def hasGroupingFunction(e: Expression): Boolean = { + e.collectFirst { + case g: Grouping => g + case g: GroupingID => g + }.isDefined + } + + private def replaceGroupingFunc( + expr: Expression, + groupByExprs: Seq[Expression], + gid: Expression): Expression = { + expr transform { + case e: GroupingID => + if (e.groupByExprs.isEmpty || e.groupByExprs == groupByExprs) { + gid + } else { + throw new AnalysisException( + s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + + s"grouping columns (${groupByExprs.mkString(",")})") + } + case Grouping(col: Expression) => + val idx = groupByExprs.indexOf(col) + if (idx >= 0) { + Cast(BitwiseAnd(ShiftRight(gid, Literal(groupByExprs.length - 1 - idx)), + Literal(1)), ByteType) + } else { + throw new AnalysisException(s"Column of grouping ($col) can't be found " + + s"in grouping columns ${groupByExprs.mkString(",")}") + } + } + } + + // This require transformUp to replace grouping()/grouping_id() in resolved Filter/Sort + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { case a if !a.childrenResolved => a // be sure all of the children are resolved. + case p if p.expressions.exists(hasGroupingAttribute) => + failAnalysis( + s"${VirtualColumn.hiveGroupingIdName} is deprecated; use grouping_id() instead") + case Aggregate(Seq(c @ Cube(groupByExprs)), aggregateExpressions, child) => GroupingSets(bitmasks(c), groupByExprs, child, aggregateExpressions) case Aggregate(Seq(r @ Rollup(groupByExprs)), aggregateExpressions, child) => GroupingSets(bitmasks(r), groupByExprs, child, aggregateExpressions) - case g: GroupingSets if g.expressions.exists(!_.resolved) && hasGroupingId(g.expressions) => - failAnalysis( - s"${VirtualColumn.groupingIdName} is deprecated; use grouping_id() instead") + // Ensure all the expressions have been resolved. 
case x: GroupingSets if x.expressions.forall(_.resolved) => val gid = AttributeReference(VirtualColumn.groupingIdName, IntegerType, false)() @@ -270,7 +305,7 @@ class Analyzer( def isPartOfAggregation(e: Expression): Boolean = { aggsBuffer.exists(a => a.find(_ eq e).isDefined) } - expr.transformDown { + replaceGroupingFunc(expr, x.groupByExprs, gid).transformDown { // AggregateExpression should be computed on the unmodified value of its argument // expressions, so we should not replace any references to grouping expression // inside it. @@ -278,23 +313,6 @@ class Analyzer( aggsBuffer += e e case e if isPartOfAggregation(e) => e - case e: GroupingID => - if (e.groupByExprs.isEmpty || e.groupByExprs == x.groupByExprs) { - gid - } else { - throw new AnalysisException( - s"Columns of grouping_id (${e.groupByExprs.mkString(",")}) does not match " + - s"grouping columns (${x.groupByExprs.mkString(",")})") - } - case Grouping(col: Expression) => - val idx = x.groupByExprs.indexOf(col) - if (idx >= 0) { - Cast(BitwiseAnd(ShiftRight(gid, Literal(x.groupByExprs.length - 1 - idx)), - Literal(1)), ByteType) - } else { - throw new AnalysisException(s"Column of grouping ($col) can't be found " + - s"in grouping columns ${x.groupByExprs.mkString(",")}") - } case e => val index = groupByAliases.indexWhere(_.child.semanticEquals(e)) if (index == -1) { @@ -306,9 +324,37 @@ class Analyzer( } Aggregate( - groupByAttributes :+ VirtualColumn.groupingIdAttribute, + groupByAttributes :+ gid, aggregations, Expand(x.bitmasks, groupByAliases, groupByAttributes, gid, x.child)) + + case f @ Filter(cond, child) if hasGroupingFunction(cond) => + val groupingExprs = findGroupingExprs(child) + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newCond = replaceGroupingFunc(cond, groupingExprs, VirtualColumn.groupingIdAttribute) + f.copy(condition = newCond) + + case s @ Sort(order, _, child) if order.exists(hasGroupingFunction) => + val groupingExprs = findGroupingExprs(child) + val gid = VirtualColumn.groupingIdAttribute + // The unresolved grouping id will be resolved by ResolveMissingReferences + val newOrder = order.map(replaceGroupingFunc(_, groupingExprs, gid).asInstanceOf[SortOrder]) + s.copy(order = newOrder) + } + + private def findGroupingExprs(plan: LogicalPlan): Seq[Expression] = { + plan.collectFirst { + case a: Aggregate => + // this Aggregate should have grouping id as the last grouping key. + val gid = a.groupingExpressions.last + if (!gid.isInstanceOf[AttributeReference] + || gid.asInstanceOf[AttributeReference].name != VirtualColumn.groupingIdName) { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } + a.groupingExpressions.take(a.groupingExpressions.length - 1) + }.getOrElse { + failAnalysis(s"grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + } } } @@ -663,13 +709,15 @@ class Analyzer( * clause. This rule detects such queries and adds the required attributes to the original * projection, so that they will be available during sorting. Another projection is added to * remove these attributes after sorting. + * + * The HAVING clause could also used a grouping columns that is not presented in the SELECT. */ - object ResolveSortReferences extends Rule[LogicalPlan] { + object ResolveMissingReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { // Skip sort with aggregate. 
This will be handled in ResolveAggregateFunctions case sa @ Sort(_, _, child: Aggregate) => sa - case s @ Sort(order, _, child) if !s.resolved && child.resolved => + case s @ Sort(order, _, child) if child.resolved => try { val newOrder = order.map(resolveExpressionRecursively(_, child).asInstanceOf[SortOrder]) val requiredAttrs = AttributeSet(newOrder).filter(_.resolved) @@ -689,6 +737,26 @@ class Analyzer( // in Sort case ae: AnalysisException => s } + + case f @ Filter(cond, child) if child.resolved => + try { + val newCond = resolveExpressionRecursively(cond, child) + val requiredAttrs = newCond.references.filter(_.resolved) + val missingAttrs = requiredAttrs -- child.outputSet + if (missingAttrs.nonEmpty) { + // Add missing attributes and then project them away. + Project(child.output, + Filter(newCond, addMissingAttr(child, missingAttrs))) + } else if (newCond != cond) { + f.copy(condition = newCond) + } else { + f + } + } catch { + // Attempting to resolve it might fail. When this happens, return the original plan. + // Users will see an AnalysisException for resolution failure of missing attributes + case ae: AnalysisException => f + } } /** @@ -843,27 +911,33 @@ class Analyzer( if aggregate.resolved => // Try resolving the condition of the filter as though it is in the aggregate clause - val aggregatedCondition = - Aggregate( - grouping, - Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, - child) - val resolvedOperator = execute(aggregatedCondition) - def resolvedAggregateFilter = - resolvedOperator - .asInstanceOf[Aggregate] - .aggregateExpressions.head - - // If resolution was successful and we see the filter has an aggregate in it, add it to - // the original aggregate operator. - if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { - val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs - - Project(aggregate.output, - Filter(resolvedAggregateFilter.toAttribute, - aggregate.copy(aggregateExpressions = aggExprsWithHaving))) - } else { - filter + try { + val aggregatedCondition = + Aggregate( + grouping, + Alias(havingCondition, "havingCondition")(isGenerated = true) :: Nil, + child) + val resolvedOperator = execute(aggregatedCondition) + def resolvedAggregateFilter = + resolvedOperator + .asInstanceOf[Aggregate] + .aggregateExpressions.head + + // If resolution was successful and we see the filter has an aggregate in it, add it to + // the original aggregate operator. + if (resolvedOperator.resolved && containsAggregate(resolvedAggregateFilter)) { + val aggExprsWithHaving = resolvedAggregateFilter +: originalAggExprs + + Project(aggregate.output, + Filter(resolvedAggregateFilter.toAttribute, + aggregate.copy(aggregateExpressions = aggExprsWithHaving))) + } else { + filter + } + } catch { + // Attempting to resolve in the aggregate can result in ambiguity. When this happens, + // just return the original plan. 
+ case ae: AnalysisException => filter } case sort @ Sort(sortOrder, global, aggregate: Aggregate) if aggregate.resolved => @@ -927,11 +1001,8 @@ class Analyzer( } } - private def isAggregateExpression(e: Expression): Boolean = { - e.isInstanceOf[AggregateExpression] || e.isInstanceOf[Grouping] || e.isInstanceOf[GroupingID] - } def containsAggregate(condition: Expression): Boolean = { - condition.find(isAggregateExpression).isDefined + condition.find(_.isInstanceOf[AggregateExpression]).isDefined } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index 2307122ea1..78310fb2f1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -333,6 +333,8 @@ case class PrettyAttribute( } object VirtualColumn { - val groupingIdName: String = "grouping__id" + // The attribute name used by Hive, which has different result than Spark, deprecated. + val hiveGroupingIdName: String = "grouping__id" + val groupingIdName: String = "spark_grouping_id" val groupingIdAttribute: UnresolvedAttribute = UnresolvedAttribute(groupingIdName) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 2ab7c1581c..dd648cdb81 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -2230,6 +2230,88 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") } + test("grouping and grouping_id in having") { + checkAnswer( + sql("select course, year from courseSales group by cube(course, year)" + + " having grouping(year) = 1 and grouping_id(course, year) > 0"), + Row("Java", null) :: + Row("dotNET", null) :: + Row(null, null) :: Nil + ) + + var error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " having grouping(course) > 0") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " having grouping_id(course, year) > 0") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by cube(course, year)" + + " having grouping__id > 0") + } + assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") + } + + test("grouping and grouping_id in sort") { + checkAnswer( + sql("select course, year, grouping(course), grouping(year) from courseSales" + + " group by cube(course, year) order by grouping_id(course, year), course, year"), + Row("Java", 2012, 0, 0) :: + Row("Java", 2013, 0, 0) :: + Row("dotNET", 2012, 0, 0) :: + Row("dotNET", 2013, 0, 0) :: + Row("Java", null, 0, 1) :: + Row("dotNET", null, 0, 1) :: + Row(null, 2012, 1, 0) :: + Row(null, 2013, 1, 0) :: + Row(null, null, 1, 1) :: Nil + ) + + checkAnswer( + sql("select course, year, grouping_id(course, year) from courseSales" + + " group by cube(course, year) order by grouping(course), 
grouping(year), course, year"), + Row("Java", 2012, 0) :: + Row("Java", 2013, 0) :: + Row("dotNET", 2012, 0) :: + Row("dotNET", 2013, 0) :: + Row("Java", null, 1) :: + Row("dotNET", null, 1) :: + Row(null, 2012, 2) :: + Row(null, 2013, 2) :: + Row(null, null, 3) :: Nil + ) + + var error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " order by grouping(course)") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by course, year" + + " order by grouping_id(course, year)") + } + assert(error.getMessage contains + "grouping()/grouping_id() can only be used with GroupingSets/Cube/Rollup") + error = intercept[AnalysisException] { + sql("select course, year from courseSales group by cube(course, year)" + + " order by grouping__id") + } + assert(error.getMessage contains "grouping__id is deprecated; use grouping_id() instead") + } + + test("filter on a grouping column that is not presented in SELECT") { + checkAnswer( + sql("select count(1) from (select 1 as a) t group by a having a > 0"), + Row(1) :: Nil) + } + test("SPARK-13056: Null in map value causes NPE") { val df = Seq(1 -> Map("abc" -> "somestring", "cba" -> null)).toDF("key", "value") withTempTable("maptest") { -- cgit v1.2.3 From ae1db91d158d1ae62a0ab7ea74467679ca050101 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 7 Apr 2016 16:23:17 -0700 Subject: [SPARK-14410][SQL] Push functions existence check into catalog ## What changes were proposed in this pull request? This is a followup to #12117 and addresses some of the TODOs introduced there. In particular, the resolution of database is now pushed into session catalog, which knows about the current database. Further, the logic for checking whether a function exists is pushed into the external catalog. No change in functionality is expected. ## How was this patch tested? `SessionCatalogSuite`, `DDLSuite` Author: Andrew Or Closes #12198 from andrewor14/function-exists. 
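For illustration, a rough sketch of the resulting catalog behaviour, written in the style of the updated SessionCatalogSuite below. This is internal API rather than a user-facing example; `newBasicCatalog()` is that suite's fixture helper, and the `func1`/`db2` names come from it:

```scala
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.catalog.SessionCatalog

// Existence is now answered by the catalog itself (function registry or external
// catalog) instead of a try/catch around getFunction.
val catalog = new SessionCatalog(newBasicCatalog())  // fixture helper from the suite
assert(catalog.functionExists(FunctionIdentifier("func1", Some("db2"))))

// createFunction/dropFunction take explicit ignoreIfExists/ignoreIfNotExists flags
// (see the SessionCatalog diff below).
catalog.dropFunction(FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false)
assert(!catalog.functionExists(FunctionIdentifier("func1", Some("db2"))))
catalog.dropFunction(FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = true)  // no-op
```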
--- .../sql/catalyst/catalog/InMemoryCatalog.scala | 10 ++-- .../sql/catalyst/catalog/SessionCatalog.scala | 59 ++++++++++------------ .../spark/sql/catalyst/catalog/interface.scala | 2 + .../sql/catalyst/catalog/SessionCatalogSuite.scala | 53 ++++++++++--------- .../spark/sql/execution/command/commands.scala | 2 +- .../apache/spark/sql/execution/command/ddl.scala | 13 ++--- .../spark/sql/execution/command/functions.scala | 31 ++++-------- .../spark/sql/execution/command/DDLSuite.scala | 41 +++++++-------- .../spark/sql/hive/HiveExternalCatalog.scala | 4 ++ .../apache/spark/sql/hive/client/HiveClient.scala | 5 ++ .../spark/sql/hive/client/HiveClientImpl.scala | 10 ++-- .../spark/sql/hive/HiveMetastoreCatalogSuite.scala | 8 +-- .../spark/sql/hive/HiveSparkSubmitSuite.scala | 2 +- 13 files changed, 126 insertions(+), 114 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 5d136b663f..186bbccef1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -47,11 +47,6 @@ class InMemoryCatalog extends ExternalCatalog { // Database name -> description private val catalog = new scala.collection.mutable.HashMap[String, DatabaseDesc] - private def functionExists(db: String, funcName: String): Boolean = { - requireDbExists(db) - catalog(db).functions.contains(funcName) - } - private def partitionExists(db: String, table: String, spec: TablePartitionSpec): Boolean = { requireTableExists(db, table) catalog(db).tables(table).partitions.contains(spec) @@ -315,6 +310,11 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).functions(funcName) } + override def functionExists(db: String, funcName: String): Boolean = { + requireDbExists(db) + catalog(db).functions.contains(funcName) + } + override def listFunctions(db: String, pattern: String): Seq[String] = synchronized { requireDbExists(db) StringUtils.filterPattern(catalog(db).functions.keysIterator.toSeq, pattern) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 2acf584e8f..7db9fd0527 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -95,7 +95,7 @@ class SessionCatalog( externalCatalog.alterDatabase(dbDefinition) } - def getDatabase(db: String): CatalogDatabase = { + def getDatabaseMetadata(db: String): CatalogDatabase = { externalCatalog.getDatabase(db) } @@ -169,7 +169,7 @@ class SessionCatalog( * If no database is specified, assume the table is in the current database. * If the specified table is not found in the database then an [[AnalysisException]] is thrown. */ - def getTable(name: TableIdentifier): CatalogTable = { + def getTableMetadata(name: TableIdentifier): CatalogTable = { val db = name.database.getOrElse(currentDb) val table = formatTableName(name.table) externalCatalog.getTable(db, table) @@ -435,28 +435,37 @@ class SessionCatalog( * Create a metastore function in the database specified in `funcDefinition`. * If no such database is specified, create it in the current database. 
*/ - def createFunction(funcDefinition: CatalogFunction): Unit = { + def createFunction(funcDefinition: CatalogFunction, ignoreIfExists: Boolean): Unit = { val db = funcDefinition.identifier.database.getOrElse(currentDb) - val newFuncDefinition = funcDefinition.copy( - identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db))) - externalCatalog.createFunction(db, newFuncDefinition) + val identifier = FunctionIdentifier(funcDefinition.identifier.funcName, Some(db)) + val newFuncDefinition = funcDefinition.copy(identifier = identifier) + if (!functionExists(identifier)) { + externalCatalog.createFunction(db, newFuncDefinition) + } else if (!ignoreIfExists) { + throw new AnalysisException(s"function '$identifier' already exists in database '$db'") + } } /** * Drop a metastore function. * If no database is specified, assume the function is in the current database. */ - def dropFunction(name: FunctionIdentifier): Unit = { + def dropFunction(name: FunctionIdentifier, ignoreIfNotExists: Boolean): Unit = { val db = name.database.getOrElse(currentDb) - val qualified = name.copy(database = Some(db)).unquotedString - if (functionRegistry.functionExists(qualified)) { - // If we have loaded this function into the FunctionRegistry, - // also drop it from there. - // For a permanent function, because we loaded it to the FunctionRegistry - // when it's first used, we also need to drop it from the FunctionRegistry. - functionRegistry.dropFunction(qualified) + val identifier = name.copy(database = Some(db)) + if (functionExists(identifier)) { + // TODO: registry should just take in FunctionIdentifier for type safety + if (functionRegistry.functionExists(identifier.unquotedString)) { + // If we have loaded this function into the FunctionRegistry, + // also drop it from there. + // For a permanent function, because we loaded it to the FunctionRegistry + // when it's first used, we also need to drop it from the FunctionRegistry. + functionRegistry.dropFunction(identifier.unquotedString) + } + externalCatalog.dropFunction(db, name.funcName) + } else if (!ignoreIfNotExists) { + throw new AnalysisException(s"function '$identifier' does not exist in database '$db'") } - externalCatalog.dropFunction(db, name.funcName) } /** @@ -465,8 +474,7 @@ class SessionCatalog( * If a database is specified in `name`, this will return the function in that database. * If no database is specified, this will return the function in the current database. */ - // TODO: have a better name. This method is actually for fetching the metadata of a function. - def getFunction(name: FunctionIdentifier): CatalogFunction = { + def getFunctionMetadata(name: FunctionIdentifier): CatalogFunction = { val db = name.database.getOrElse(currentDb) externalCatalog.getFunction(db, name.funcName) } @@ -475,20 +483,9 @@ class SessionCatalog( * Check if the specified function exists. */ def functionExists(name: FunctionIdentifier): Boolean = { - if (functionRegistry.functionExists(name.unquotedString)) { - // This function exists in the FunctionRegistry. - true - } else { - // Need to check if this function exists in the metastore. - try { - // TODO: It's better to ask external catalog if this function exists. - // So, we can avoid of having this hacky try/catch block. - getFunction(name) != null - } catch { - case _: NoSuchFunctionException => false - case _: AnalysisException => false // HiveExternalCatalog wraps all exceptions with it. 
- } - } + val db = name.database.getOrElse(currentDb) + functionRegistry.functionExists(name.unquotedString) || + externalCatalog.functionExists(db, name.funcName) } // ---------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 97b9946140..e29d6bd8b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -152,6 +152,8 @@ abstract class ExternalCatalog { def getFunction(db: String, funcName: String): CatalogFunction + def functionExists(db: String, funcName: String): Boolean + def listFunctions(db: String, pattern: String): Seq[String] } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 4d56d001b3..1850dc8156 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -62,7 +62,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("get database when a database exists") { val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabase("db1") + val db1 = catalog.getDatabaseMetadata("db1") assert(db1.name == "db1") assert(db1.description.contains("db1")) } @@ -70,7 +70,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("get database should throw exception when the database does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.getDatabase("db_that_does_not_exist") + catalog.getDatabaseMetadata("db_that_does_not_exist") } } @@ -128,10 +128,10 @@ class SessionCatalogSuite extends SparkFunSuite { test("alter database") { val catalog = new SessionCatalog(newBasicCatalog()) - val db1 = catalog.getDatabase("db1") + val db1 = catalog.getDatabaseMetadata("db1") // Note: alter properties here because Hive does not support altering other fields catalog.alterDatabase(db1.copy(properties = Map("k" -> "v3", "good" -> "true"))) - val newDb1 = catalog.getDatabase("db1") + val newDb1 = catalog.getDatabaseMetadata("db1") assert(db1.properties.isEmpty) assert(newDb1.properties.size == 2) assert(newDb1.properties.get("k") == Some("v3")) @@ -346,21 +346,21 @@ class SessionCatalogSuite extends SparkFunSuite { test("get table") { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) - assert(sessionCatalog.getTable(TableIdentifier("tbl1", Some("db2"))) + assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) == externalCatalog.getTable("db2", "tbl1")) // Get table without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") - assert(sessionCatalog.getTable(TableIdentifier("tbl1")) + assert(sessionCatalog.getTableMetadata(TableIdentifier("tbl1")) == externalCatalog.getTable("db2", "tbl1")) } test("get table when database/table does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.getTable(TableIdentifier("tbl1", Some("unknown_db"))) + catalog.getTableMetadata(TableIdentifier("tbl1", Some("unknown_db"))) } intercept[AnalysisException] { - 
catalog.getTable(TableIdentifier("unknown_table", Some("db2"))) + catalog.getTableMetadata(TableIdentifier("unknown_table", Some("db2"))) } } @@ -386,7 +386,7 @@ class SessionCatalogSuite extends SparkFunSuite { test("lookup table relation with alias") { val catalog = new SessionCatalog(newBasicCatalog()) val alias = "monster" - val tableMetadata = catalog.getTable(TableIdentifier("tbl1", Some("db2"))) + val tableMetadata = catalog.getTableMetadata(TableIdentifier("tbl1", Some("db2"))) val relation = SubqueryAlias("tbl1", CatalogRelation("db2", tableMetadata)) val relationWithAlias = SubqueryAlias(alias, @@ -659,26 +659,28 @@ class SessionCatalogSuite extends SparkFunSuite { val externalCatalog = newEmptyCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) sessionCatalog.createDatabase(newDb("mydb"), ignoreIfExists = false) - sessionCatalog.createFunction(newFunc("myfunc", Some("mydb"))) + sessionCatalog.createFunction(newFunc("myfunc", Some("mydb")), ignoreIfExists = false) assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc")) // Create function without explicitly specifying database sessionCatalog.setCurrentDatabase("mydb") - sessionCatalog.createFunction(newFunc("myfunc2")) + sessionCatalog.createFunction(newFunc("myfunc2"), ignoreIfExists = false) assert(externalCatalog.listFunctions("mydb", "*").toSet == Set("myfunc", "myfunc2")) } test("create function when database does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.createFunction(newFunc("func5", Some("does_not_exist"))) + catalog.createFunction( + newFunc("func5", Some("does_not_exist")), ignoreIfExists = false) } } test("create function that already exists") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.createFunction(newFunc("func1", Some("db2"))) + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = false) } + catalog.createFunction(newFunc("func1", Some("db2")), ignoreIfExists = true) } test("create temp function") { @@ -711,24 +713,27 @@ class SessionCatalogSuite extends SparkFunSuite { val externalCatalog = newBasicCatalog() val sessionCatalog = new SessionCatalog(externalCatalog) assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func1")) - sessionCatalog.dropFunction(FunctionIdentifier("func1", Some("db2"))) + sessionCatalog.dropFunction( + FunctionIdentifier("func1", Some("db2")), ignoreIfNotExists = false) assert(externalCatalog.listFunctions("db2", "*").isEmpty) // Drop function without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") - sessionCatalog.createFunction(newFunc("func2", Some("db2"))) + sessionCatalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) assert(externalCatalog.listFunctions("db2", "*").toSet == Set("func2")) - sessionCatalog.dropFunction(FunctionIdentifier("func2")) + sessionCatalog.dropFunction(FunctionIdentifier("func2"), ignoreIfNotExists = false) assert(externalCatalog.listFunctions("db2", "*").isEmpty) } test("drop function when database/function does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.dropFunction(FunctionIdentifier("something", Some("does_not_exist"))) + catalog.dropFunction( + FunctionIdentifier("something", Some("does_not_exist")), ignoreIfNotExists = false) } intercept[AnalysisException] { - catalog.dropFunction(FunctionIdentifier("does_not_exist")) + 
catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = false) } + catalog.dropFunction(FunctionIdentifier("does_not_exist"), ignoreIfNotExists = true) } test("drop temp function") { @@ -753,19 +758,19 @@ class SessionCatalogSuite extends SparkFunSuite { val expected = CatalogFunction(FunctionIdentifier("func1", Some("db2")), funcClass, Seq.empty[(String, String)]) - assert(catalog.getFunction(FunctionIdentifier("func1", Some("db2"))) == expected) + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("db2"))) == expected) // Get function without explicitly specifying database catalog.setCurrentDatabase("db2") - assert(catalog.getFunction(FunctionIdentifier("func1")) == expected) + assert(catalog.getFunctionMetadata(FunctionIdentifier("func1")) == expected) } test("get function when database/function does not exist") { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { - catalog.getFunction(FunctionIdentifier("func1", Some("does_not_exist"))) + catalog.getFunctionMetadata(FunctionIdentifier("func1", Some("does_not_exist"))) } intercept[AnalysisException] { - catalog.getFunction(FunctionIdentifier("does_not_exist", Some("db2"))) + catalog.getFunctionMetadata(FunctionIdentifier("does_not_exist", Some("db2"))) } } @@ -787,8 +792,8 @@ class SessionCatalogSuite extends SparkFunSuite { val info2 = new ExpressionInfo("tempFunc2", "yes_me") val tempFunc1 = (e: Seq[Expression]) => e.head val tempFunc2 = (e: Seq[Expression]) => e.last - catalog.createFunction(newFunc("func2", Some("db2"))) - catalog.createFunction(newFunc("not_me", Some("db2"))) + catalog.createFunction(newFunc("func2", Some("db2")), ignoreIfExists = false) + catalog.createFunction(newFunc("not_me", Some("db2")), ignoreIfExists = false) catalog.createTempFunction("func1", info1, tempFunc1, ignoreIfExists = false) catalog.createTempFunction("yes_me", info2, tempFunc2, ignoreIfExists = false) assert(catalog.listFunctions("db1", "*").toSet == diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index faa7a2cdb4..3fd2a93d29 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -407,7 +407,7 @@ case class ShowTablePropertiesCommand( if (catalog.isTemporaryTable(table)) { Seq.empty[Row] } else { - val catalogTable = sqlContext.sessionState.catalog.getTable(table) + val catalogTable = sqlContext.sessionState.catalog.getTableMetadata(table) propertyKey match { case Some(p) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 6d56a6fec8..20779d68e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -124,7 +124,7 @@ case class AlterDatabaseProperties( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val db: CatalogDatabase = catalog.getDatabase(databaseName) + val db: CatalogDatabase = catalog.getDatabaseMetadata(databaseName) catalog.alterDatabase(db.copy(properties = db.properties ++ props)) Seq.empty[Row] @@ -149,7 +149,8 @@ case class DescribeDatabase( extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { - val dbMetadata: 
CatalogDatabase = sqlContext.sessionState.catalog.getDatabase(databaseName) + val dbMetadata: CatalogDatabase = + sqlContext.sessionState.catalog.getDatabaseMetadata(databaseName) val result = Row("Database Name", dbMetadata.name) :: Row("Description", dbMetadata.description) :: @@ -213,7 +214,7 @@ case class AlterTableSetProperties( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val table = catalog.getTable(tableName) + val table = catalog.getTableMetadata(tableName) val newProperties = table.properties ++ properties if (DDLUtils.isDatasourceTable(newProperties)) { throw new AnalysisException( @@ -243,7 +244,7 @@ case class AlterTableUnsetProperties( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val table = catalog.getTable(tableName) + val table = catalog.getTableMetadata(tableName) if (DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( "alter table properties is not supported for datasource tables") @@ -286,7 +287,7 @@ case class AlterTableSerDeProperties( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val table = catalog.getTable(tableName) + val table = catalog.getTableMetadata(tableName) // Do not support setting serde for datasource tables if (serdeClassName.isDefined && DDLUtils.isDatasourceTable(table)) { throw new AnalysisException( @@ -376,7 +377,7 @@ case class AlterTableSetLocation( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - val table = catalog.getTable(tableName) + val table = catalog.getTableMetadata(tableName) partitionSpec match { case Some(spec) => // Partition spec is specified, so we set the location only for this partition diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala index 66d17e322e..c6e601799f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/functions.scala @@ -47,6 +47,7 @@ case class CreateFunction( extends RunnableCommand { override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog if (isTemp) { if (databaseName.isDefined) { throw new AnalysisException( @@ -55,24 +56,18 @@ case class CreateFunction( } // We first load resources and then put the builder in the function registry. // Please note that it is allowed to overwrite an existing temp function. - sqlContext.sessionState.catalog.loadFunctionResources(resources) + catalog.loadFunctionResources(resources) val info = new ExpressionInfo(className, functionName) - val builder = - sqlContext.sessionState.catalog.makeFunctionBuilder(functionName, className) - sqlContext.sessionState.catalog.createTempFunction( - functionName, info, builder, ignoreIfExists = false) + val builder = catalog.makeFunctionBuilder(functionName, className) + catalog.createTempFunction(functionName, info, builder, ignoreIfExists = false) } else { // For a permanent, we will store the metadata into underlying external catalog. // This function will be loaded into the FunctionRegistry when a query uses it. // We do not load it into FunctionRegistry right now. 
- val dbName = databaseName.getOrElse(sqlContext.sessionState.catalog.getCurrentDatabase) - val func = FunctionIdentifier(functionName, Some(dbName)) - val catalogFunc = CatalogFunction(func, className, resources) - if (sqlContext.sessionState.catalog.functionExists(func)) { - throw new AnalysisException( - s"Function '$functionName' already exists in database '$dbName'.") - } - sqlContext.sessionState.catalog.createFunction(catalogFunc) + // TODO: should we also parse "IF NOT EXISTS"? + catalog.createFunction( + CatalogFunction(FunctionIdentifier(functionName, databaseName), className, resources), + ignoreIfExists = false) } Seq.empty[Row] } @@ -101,13 +96,9 @@ case class DropFunction( catalog.dropTempFunction(functionName, ifExists) } else { // We are dropping a permanent function. - val dbName = databaseName.getOrElse(catalog.getCurrentDatabase) - val func = FunctionIdentifier(functionName, Some(dbName)) - if (!ifExists && !catalog.functionExists(func)) { - throw new AnalysisException( - s"Function '$functionName' does not exist in database '$dbName'.") - } - catalog.dropFunction(func) + catalog.dropFunction( + FunctionIdentifier(functionName, databaseName), + ignoreIfNotExists = ifExists) } Seq.empty[Row] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index a8db4e9923..7084665b3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -88,7 +88,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") - val db1 = catalog.getDatabase(dbNameWithoutBackTicks) + val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", @@ -110,7 +110,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { try { val dbNameWithoutBackTicks = cleanIdentifier(dbName) sql(s"CREATE DATABASE $dbName") - val db1 = catalog.getDatabase(dbNameWithoutBackTicks) + val db1 = catalog.getDatabaseMetadata(dbNameWithoutBackTicks) assert(db1 == CatalogDatabase( dbNameWithoutBackTicks, "", @@ -233,14 +233,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { val tableIdent = TableIdentifier("tab1", Some("dbx")) createDatabase(catalog, "dbx") createTable(catalog, tableIdent) - assert(catalog.getTable(tableIdent).properties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).properties.isEmpty) // set table properties sql("ALTER TABLE dbx.tab1 SET TBLPROPERTIES ('andrew' = 'or14', 'kor' = 'bel')") - assert(catalog.getTable(tableIdent).properties == Map("andrew" -> "or14", "kor" -> "bel")) + assert(catalog.getTableMetadata(tableIdent).properties == + Map("andrew" -> "or14", "kor" -> "bel")) // set table properties without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET TBLPROPERTIES ('kor' = 'belle', 'kar' = 'bol')") - assert(catalog.getTable(tableIdent).properties == + assert(catalog.getTableMetadata(tableIdent).properties == Map("andrew" -> "or14", "kor" -> "belle", "kar" -> "bol")) // table to alter does not exist intercept[AnalysisException] { @@ -262,11 +263,11 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { // unset table properties sql("ALTER TABLE dbx.tab1 SET 
TBLPROPERTIES ('j' = 'am', 'p' = 'an', 'c' = 'lan')") sql("ALTER TABLE dbx.tab1 UNSET TBLPROPERTIES ('j')") - assert(catalog.getTable(tableIdent).properties == Map("p" -> "an", "c" -> "lan")) + assert(catalog.getTableMetadata(tableIdent).properties == Map("p" -> "an", "c" -> "lan")) // unset table properties without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 UNSET TBLPROPERTIES ('p')") - assert(catalog.getTable(tableIdent).properties == Map("c" -> "lan")) + assert(catalog.getTableMetadata(tableIdent).properties == Map("c" -> "lan")) // table to alter does not exist intercept[AnalysisException] { sql("ALTER TABLE does_not_exist UNSET TBLPROPERTIES ('c' = 'lan')") @@ -278,7 +279,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(e.getMessage.contains("xyz")) // property to unset does not exist, but "IF EXISTS" is specified sql("ALTER TABLE tab1 UNSET TBLPROPERTIES IF EXISTS ('c', 'xyz')") - assert(catalog.getTable(tableIdent).properties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).properties.isEmpty) // throw exception for datasource tables convertToDatasourceTable(catalog, tableIdent) val e1 = intercept[AnalysisException] { @@ -393,7 +394,7 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { private def convertToDatasourceTable( catalog: SessionCatalog, tableIdent: TableIdentifier): Unit = { - catalog.alterTable(catalog.getTable(tableIdent).copy( + catalog.alterTable(catalog.getTableMetadata(tableIdent).copy( properties = Map("spark.sql.sources.provider" -> "csv"))) } @@ -407,15 +408,15 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { convertToDatasourceTable(catalog, tableIdent) } - assert(catalog.getTable(tableIdent).storage.locationUri.isEmpty) - assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.locationUri.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.locationUri.isEmpty) assert(catalog.getPartition(tableIdent, partSpec).storage.serdeProperties.isEmpty) // Verify that the location is set to the expected string def verifyLocation(expected: String, spec: Option[TablePartitionSpec] = None): Unit = { val storageFormat = spec .map { s => catalog.getPartition(tableIdent, s).storage } - .getOrElse { catalog.getTable(tableIdent).storage } + .getOrElse { catalog.getTableMetadata(tableIdent).storage } if (isDatasourceTable) { if (spec.isDefined) { assert(storageFormat.serdeProperties.isEmpty) @@ -467,8 +468,8 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { if (isDatasourceTable) { convertToDatasourceTable(catalog, tableIdent) } - assert(catalog.getTable(tableIdent).storage.serde.isEmpty) - assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serde.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty) // set table serde and/or properties (should fail on datasource tables) if (isDatasourceTable) { val e1 = intercept[AnalysisException] { @@ -482,22 +483,22 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(e2.getMessage.contains("datasource")) } else { sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.jadoop'") - assert(catalog.getTable(tableIdent).storage.serde == 
Some("org.apache.jadoop")) - assert(catalog.getTable(tableIdent).storage.serdeProperties.isEmpty) + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.jadoop")) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties.isEmpty) sql("ALTER TABLE dbx.tab1 SET SERDE 'org.apache.madoop' " + "WITH SERDEPROPERTIES ('k' = 'v', 'kay' = 'vee')") - assert(catalog.getTable(tableIdent).storage.serde == Some("org.apache.madoop")) - assert(catalog.getTable(tableIdent).storage.serdeProperties == + assert(catalog.getTableMetadata(tableIdent).storage.serde == Some("org.apache.madoop")) + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties == Map("k" -> "v", "kay" -> "vee")) } // set serde properties only sql("ALTER TABLE dbx.tab1 SET SERDEPROPERTIES ('k' = 'vvv', 'kay' = 'vee')") - assert(catalog.getTable(tableIdent).storage.serdeProperties == + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties == Map("k" -> "vvv", "kay" -> "vee")) // set things without explicitly specifying database catalog.setCurrentDatabase("dbx") sql("ALTER TABLE tab1 SET SERDEPROPERTIES ('kay' = 'veee')") - assert(catalog.getTable(tableIdent).storage.serdeProperties == + assert(catalog.getTableMetadata(tableIdent).storage.serdeProperties == Map("k" -> "vvv", "kay" -> "veee")) // table to alter does not exist intercept[AnalysisException] { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 98a5998d03..b1156fb3e2 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -292,6 +292,10 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat client.getFunction(db, funcName) } + override def functionExists(db: String, funcName: String): Boolean = withClient { + client.functionExists(db, funcName) + } + override def listFunctions(db: String, pattern: String): Seq[String] = withClient { client.listFunctions(db, pattern) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index ee56f9d75d..94794b1572 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -232,6 +232,11 @@ private[hive] trait HiveClient { /** Return an existing function in the database, or None if it doesn't exist. */ def getFunctionOption(db: String, name: String): Option[CatalogFunction] + /** Return whether a function exists in the specified database. */ + final def functionExists(db: String, name: String): Boolean = { + getFunctionOption(db, name).isDefined + } + /** Return the names of all functions that match the given pattern in the database. 
*/ def listFunctions(db: String, pattern: String): Seq[String] diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 1f66fbfd85..d0eb9ddf50 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -21,7 +21,6 @@ import java.io.{File, PrintStream} import scala.collection.JavaConverters._ import scala.language.reflectiveCalls -import scala.util.Try import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -30,7 +29,8 @@ import org.apache.hadoop.hive.conf.HiveConf import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver -import org.apache.hadoop.hive.ql.metadata.{Hive, Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} +import org.apache.hadoop.hive.ql.metadata.{Hive, HiveException} import org.apache.hadoop.hive.ql.plan.AddPartitionDesc import org.apache.hadoop.hive.ql.processors._ import org.apache.hadoop.hive.ql.session.SessionState @@ -559,7 +559,11 @@ private[hive] class HiveClientImpl( override def getFunctionOption( db: String, name: String): Option[CatalogFunction] = withHiveState { - Option(client.getFunction(db, name)).map(fromHiveFunction) + try { + Option(client.getFunction(db, name)).map(fromHiveFunction) + } catch { + case he: HiveException => None + } } override def listFunctions(db: String, pattern: String): Seq[String] = withHiveState { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 6967395613..ada8621d07 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -83,7 +83,7 @@ class DataSourceWithHiveMetastoreCatalogSuite .saveAsTable("t") } - val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default"))) + val hiveTable = sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) assert(hiveTable.storage.inputFormat === Some(inputFormat)) assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) @@ -114,7 +114,8 @@ class DataSourceWithHiveMetastoreCatalogSuite .saveAsTable("t") } - val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default"))) + val hiveTable = + sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) assert(hiveTable.storage.inputFormat === Some(inputFormat)) assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) @@ -144,7 +145,8 @@ class DataSourceWithHiveMetastoreCatalogSuite |AS SELECT 1 AS d1, "val_1" AS d2 """.stripMargin) - val hiveTable = sessionState.catalog.getTable(TableIdentifier("t", Some("default"))) + val hiveTable = + sessionState.catalog.getTableMetadata(TableIdentifier("t", Some("default"))) assert(hiveTable.storage.inputFormat === Some(inputFormat)) assert(hiveTable.storage.outputFormat === Some(outputFormat)) 
assert(hiveTable.storage.serde === Some(serde)) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala index dd2129375d..c5417b06a4 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveSparkSubmitSuite.scala @@ -354,7 +354,7 @@ object PermanentHiveUDFTest2 extends Logging { FunctionIdentifier("example_max"), "org.apache.hadoop.hive.contrib.udaf.example.UDAFExampleMax", ("JAR" -> jar) :: Nil) - hiveContext.sessionState.catalog.createFunction(function) + hiveContext.sessionState.catalog.createFunction(function, ignoreIfExists = false) val source = hiveContext.createDataFrame((1 to 10).map(i => (i, s"str$i"))).toDF("key", "val") source.registerTempTable("sourceTable") -- cgit v1.2.3 From 49fb237081bbca0d811aa48aa06f4728fea62781 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 7 Apr 2016 17:23:34 -0700 Subject: [SPARK-14270][SQL] whole stage codegen support for typed filter ## What changes were proposed in this pull request? We implement typed filter by `MapPartitions`, which doesn't work well with whole stage codegen. This PR use `Filter` to implement typed filter and we can get the whole stage codegen support for free. This PR also introduced `DeserializeToObject` and `SerializeFromObject`, to seperate serialization logic from object operator, so that it's eaiser to write optimization rules for adjacent object operators. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #12061 from cloud-fan/whole-stage-codegen. --- .../spark/sql/catalyst/analysis/Analyzer.scala | 2 + .../apache/spark/sql/catalyst/dsl/package.scala | 19 ++++++ .../spark/sql/catalyst/optimizer/Optimizer.scala | 39 ++++++++++- .../spark/sql/catalyst/plans/logical/object.scala | 37 ++++++++++- .../optimizer/TypedFilterOptimizationSuite.scala | 74 +++++++++++++++++++++ .../main/scala/org/apache/spark/sql/Dataset.scala | 16 ++++- .../spark/sql/execution/SparkStrategies.scala | 4 ++ .../org/apache/spark/sql/execution/objects.scala | 67 +++++++++++++++++++ .../org/apache/spark/sql/DatasetBenchmark.scala | 76 +++++++++++++++++++--- .../scala/org/apache/spark/sql/QueryTest.scala | 2 + .../sql/execution/WholeStageCodegenSuite.scala | 21 +++++- 11 files changed, 342 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 7bcba421fd..3555a6d7fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -1670,6 +1670,8 @@ object CleanupAliases extends Rule[LogicalPlan] { // Operators that operate on objects should only have expressions from encoders, which should // never have extra aliases. 
case o: ObjectOperator => o + case d: DeserializeToObject => d + case s: SerializeFromObject => s case other => var stop = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index 105947028d..1e7296664b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -21,6 +21,7 @@ import java.sql.{Date, Timestamp} import scala.language.implicitConversions +import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ @@ -166,6 +167,14 @@ package object dsl { case target => UnresolvedStar(Option(target)) } + def callFunction[T, U]( + func: T => U, + returnType: DataType, + argument: Expression): Expression = { + val function = Literal.create(func, ObjectType(classOf[T => U])) + Invoke(function, "apply", returnType, argument :: Nil) + } + implicit class DslSymbol(sym: Symbol) extends ImplicitAttribute { def s: String = sym.name } // TODO more implicit class for literal? implicit class DslString(val s: String) extends ImplicitOperators { @@ -270,6 +279,16 @@ package object dsl { def where(condition: Expression): LogicalPlan = Filter(condition, logicalPlan) + def filter[T : Encoder](func: T => Boolean): LogicalPlan = { + val deserialized = logicalPlan.deserialize[T] + val condition = expressions.callFunction(func, BooleanType, deserialized.output.head) + Filter(condition, deserialized).serialize[T] + } + + def serialize[T : Encoder]: LogicalPlan = CatalystSerde.serialize[T](logicalPlan) + + def deserialize[T : Encoder]: LogicalPlan = CatalystSerde.deserialize[T](logicalPlan) + def limit(limitExpr: Expression): LogicalPlan = Limit(limitExpr, logicalPlan) def join( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index f581810c26..619514e8aa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -93,6 +93,8 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { EliminateSerialization) :: Batch("Decimal Optimizations", FixedPoint(100), DecimalAggregates) :: + Batch("Typed Filter Optimization", FixedPoint(100), + EmbedSerializerInFilter) :: Batch("LocalRelation", FixedPoint(100), ConvertToLocalRelation) :: Batch("Subquery", Once, @@ -147,12 +149,18 @@ object EliminateSerialization extends Rule[LogicalPlan] { child = childWithoutSerialization) case m @ MapElements(_, deserializer, _, child: ObjectOperator) - if !deserializer.isInstanceOf[Attribute] && - deserializer.dataType == child.outputObject.dataType => + if !deserializer.isInstanceOf[Attribute] && + deserializer.dataType == child.outputObject.dataType => val childWithoutSerialization = child.withObjectOutput m.copy( deserializer = childWithoutSerialization.output.head, child = childWithoutSerialization) + + case d @ DeserializeToObject(_, s: SerializeFromObject) + if d.outputObjectType == s.inputObjectType => + // Adds an extra Project here, to preserve the output expr id of `DeserializeToObject`. 
+ val objAttr = Alias(s.child.output.head, "obj")(exprId = d.output.head.exprId) + Project(objAttr :: Nil, s.child) } } @@ -1329,3 +1337,30 @@ object ComputeCurrentTime extends Rule[LogicalPlan] { } } } + +/** + * Typed [[Filter]] is by default surrounded by a [[DeserializeToObject]] beneath it and a + * [[SerializeFromObject]] above it. If these serializations can't be eliminated, we should embed + * the deserializer in filter condition to save the extra serialization at last. + */ +object EmbedSerializerInFilter extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case s @ SerializeFromObject(_, Filter(condition, d: DeserializeToObject)) => + val numObjects = condition.collect { + case a: Attribute if a == d.output.head => a + }.length + + if (numObjects > 1) { + // If the filter condition references the object more than one times, we should not embed + // deserializer in it as the deserialization will happen many times and slow down the + // execution. + // TODO: we can still embed it if we can make sure subexpression elimination works here. + s + } else { + val newCondition = condition transform { + case a: Attribute if a == d.output.head => d.deserializer.child + } + Filter(newCondition, d.child) + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala index ec33a538a9..6df46189b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/object.scala @@ -21,7 +21,42 @@ import org.apache.spark.sql.Encoder import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.types.{ObjectType, StructType} +import org.apache.spark.sql.types.{DataType, ObjectType, StructType} + +object CatalystSerde { + def deserialize[T : Encoder](child: LogicalPlan): DeserializeToObject = { + val deserializer = UnresolvedDeserializer(encoderFor[T].deserializer) + DeserializeToObject(Alias(deserializer, "obj")(), child) + } + + def serialize[T : Encoder](child: LogicalPlan): SerializeFromObject = { + SerializeFromObject(encoderFor[T].namedExpressions, child) + } +} + +/** + * Takes the input row from child and turns it into object using the given deserializer expression. + * The output of this operator is a single-field safe row containing the deserialized object. + */ +case class DeserializeToObject( + deserializer: Alias, + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = deserializer.toAttribute :: Nil + + def outputObjectType: DataType = deserializer.dataType +} + +/** + * Takes the input object from child and turns in into unsafe row using the given serializer + * expression. The output of its child must be a single-field row containing the input object. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: LogicalPlan) extends UnaryNode { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + def inputObjectType: DataType = child.output.head.dataType +} /** * A trait for logical operators that apply user defined functions to domain objects. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala new file mode 100644 index 0000000000..1fae64e3bc --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/TypedFilterOptimizationSuite.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.reflect.runtime.universe.TypeTag + +import org.apache.spark.sql.catalyst.analysis.UnresolvedDeserializer +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.BooleanType + +class TypedFilterOptimizationSuite extends PlanTest { + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("EliminateSerialization", FixedPoint(50), + EliminateSerialization) :: + Batch("EmbedSerializerInFilter", FixedPoint(50), + EmbedSerializerInFilter) :: Nil + } + + implicit private def productEncoder[T <: Product : TypeTag] = ExpressionEncoder[T]() + + test("back to back filter") { + val input = LocalRelation('_1.int, '_2.int) + val f1 = (i: (Int, Int)) => i._1 > 0 + val f2 = (i: (Int, Int)) => i._2 > 0 + + val query = input.filter(f1).filter(f2).analyze + + val optimized = Optimize.execute(query) + + val expected = input.deserialize[(Int, Int)] + .where(callFunction(f1, BooleanType, 'obj)) + .select('obj.as("obj")) + .where(callFunction(f2, BooleanType, 'obj)) + .serialize[(Int, Int)].analyze + + comparePlans(optimized, expected) + } + + test("embed deserializer in filter condition if there is only one filter") { + val input = LocalRelation('_1.int, '_2.int) + val f = (i: (Int, Int)) => i._1 > 0 + + val query = input.filter(f).analyze + + val optimized = Optimize.execute(query) + + val deserializer = UnresolvedDeserializer(encoderFor[(Int, Int)].deserializer) + val condition = callFunction(f, BooleanType, deserializer) + val expected = input.where(condition).analyze + + comparePlans(optimized, expected) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 2854d5f9da..2f6d8d109f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1879,7 +1879,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ 
@Experimental - def filter(func: T => Boolean): Dataset[T] = mapPartitions(_.filter(func)) + def filter(func: T => Boolean): Dataset[T] = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val function = Literal.create(func, ObjectType(classOf[T => Boolean])) + val condition = Invoke(function, "apply", BooleanType, deserialized.output) + val filter = Filter(condition, deserialized) + withTypedPlan(CatalystSerde.serialize[T](filter)) + } /** * :: Experimental :: @@ -1890,7 +1896,13 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ @Experimental - def filter(func: FilterFunction[T]): Dataset[T] = filter(t => func.call(t)) + def filter(func: FilterFunction[T]): Dataset[T] = { + val deserialized = CatalystSerde.deserialize[T](logicalPlan) + val function = Literal.create(func, ObjectType(classOf[FilterFunction[T]])) + val condition = Invoke(function, "call", BooleanType, deserialized.output) + val filter = Filter(condition, deserialized) + withTypedPlan(CatalystSerde.serialize[T](filter)) + } /** * :: Experimental :: diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index eee2b946e3..c15aaed365 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -346,6 +346,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { throw new IllegalStateException( "logical intersect operator should have been replaced by semi-join in the optimizer") + case logical.DeserializeToObject(deserializer, child) => + execution.DeserializeToObject(deserializer, planLater(child)) :: Nil + case logical.SerializeFromObject(serializer, child) => + execution.SerializeFromObject(serializer, planLater(child)) :: Nil case logical.MapPartitions(f, in, out, child) => execution.MapPartitions(f, in, out, planLater(child)) :: Nil case logical.MapElements(f, in, out, child) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala index f48f3f09c7..d2ab18ef0e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/objects.scala @@ -27,6 +27,73 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.types.ObjectType +/** + * Takes the input row from child and turns it into object using the given deserializer expression. + * The output of this operator is a single-field safe row containing the deserialized object. 
+ */ +case class DeserializeToObject( + deserializer: Alias, + child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = deserializer.toAttribute :: Nil + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val bound = ExpressionCanonicalizer.execute( + BindReferences.bindReference(deserializer, child.output)) + ctx.currentVars = input + val resultVars = bound.gen(ctx) :: Nil + consume(ctx, resultVars) + } + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val projection = GenerateSafeProjection.generate(deserializer :: Nil, child.output) + iter.map(projection) + } + } +} + +/** + * Takes the input object from child and turns in into unsafe row using the given serializer + * expression. The output of its child must be a single-field row containing the input object. + */ +case class SerializeFromObject( + serializer: Seq[NamedExpression], + child: SparkPlan) extends UnaryNode with CodegenSupport { + override def output: Seq[Attribute] = serializer.map(_.toAttribute) + + override def upstreams(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].upstreams() + } + + protected override def doProduce(ctx: CodegenContext): String = { + child.asInstanceOf[CodegenSupport].produce(ctx, this) + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + val bound = serializer.map { expr => + ExpressionCanonicalizer.execute(BindReferences.bindReference(expr, child.output)) + } + ctx.currentVars = input + val resultVars = bound.map(_.gen(ctx)) + consume(ctx, resultVars) + } + + override protected def doExecute(): RDD[InternalRow] = { + child.execute().mapPartitionsInternal { iter => + val projection = UnsafeProjection.create(serializer) + iter.map(projection) + } + } +} + /** * Helper functions for physical operators that work with user defined objects. 
*/ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala index 6eb952445f..5f3dd906fe 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetBenchmark.scala @@ -28,16 +28,10 @@ object DatasetBenchmark { case class Data(l: Long, s: String) - def main(args: Array[String]): Unit = { - val sparkContext = new SparkContext("local[*]", "Dataset benchmark") - val sqlContext = new SQLContext(sparkContext) - + def backToBackMap(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { import sqlContext.implicits._ - val numRows = 10000000 val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) - val numChains = 10 - val benchmark = new Benchmark("back-to-back map", numRows) val func = (d: Data) => Data(d.l + 1, d.s) @@ -61,7 +55,7 @@ object DatasetBenchmark { res.queryExecution.toRdd.foreach(_ => Unit) } - val rdd = sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) benchmark.addCase("RDD") { iter => var res = rdd var i = 0 @@ -72,6 +66,63 @@ object DatasetBenchmark { res.foreach(_ => Unit) } + benchmark + } + + def backToBackFilter(sqlContext: SQLContext, numRows: Long, numChains: Int): Benchmark = { + import sqlContext.implicits._ + + val df = sqlContext.range(1, numRows).select($"id".as("l"), $"id".cast(StringType).as("s")) + val benchmark = new Benchmark("back-to-back filter", numRows) + + val func = (d: Data, i: Int) => d.l % (100L + i) == 0L + val funcs = 0.until(numChains).map { i => + (d: Data) => func(d, i) + } + benchmark.addCase("Dataset") { iter => + var res = df.as[Data] + var i = 0 + while (i < numChains) { + res = res.filter(funcs(i)) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + benchmark.addCase("DataFrame") { iter => + var res = df + var i = 0 + while (i < numChains) { + res = res.filter($"l" % (100L + i) === 0L) + i += 1 + } + res.queryExecution.toRdd.foreach(_ => Unit) + } + + val rdd = sqlContext.sparkContext.range(1, numRows).map(l => Data(l, l.toString)) + benchmark.addCase("RDD") { iter => + var res = rdd + var i = 0 + while (i < numChains) { + res = rdd.filter(funcs(i)) + i += 1 + } + res.foreach(_ => Unit) + } + + benchmark + } + + def main(args: Array[String]): Unit = { + val sparkContext = new SparkContext("local[*]", "Dataset benchmark") + val sqlContext = new SQLContext(sparkContext) + + val numRows = 10000000 + val numChains = 10 + + val benchmark = backToBackMap(sqlContext, numRows, numChains) + val benchmark2 = backToBackFilter(sqlContext, numRows, numChains) + /* Java HotSpot(TM) 64-Bit Server VM 1.8.0_60-b27 on Mac OS X 10.11.4 Intel(R) Core(TM) i7-4960HQ CPU @ 2.60GHz @@ -82,5 +133,14 @@ object DatasetBenchmark { RDD 216 / 237 46.3 21.6 4.2X */ benchmark.run() + + /* + back-to-back filter: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ------------------------------------------------------------------------------------------- + Dataset 585 / 628 17.1 58.5 1.0X + DataFrame 62 / 80 160.7 6.2 9.4X + RDD 205 / 220 48.7 20.5 2.8X + */ + benchmark2.run() } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 48a077d0e5..826862835a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -32,6 +32,7 @@ import org.apache.spark.sql.execution.LogicalRDD import org.apache.spark.sql.execution.columnar.InMemoryRelation import org.apache.spark.sql.execution.datasources.LogicalRelation import org.apache.spark.sql.execution.streaming.MemoryPlan +import org.apache.spark.sql.types.ObjectType abstract class QueryTest extends PlanTest { @@ -204,6 +205,7 @@ abstract class QueryTest extends PlanTest { case _: MemoryPlan => return }.transformAllExpressions { case a: ImperativeAggregate => return + case Literal(_, _: ObjectType) => return } // bypass hive tests before we fix all corner cases in hive module. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index f73ca887f1..4474cfcf6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -17,8 +17,7 @@ package org.apache.spark.sql.execution -import org.apache.spark.api.java.function.MapFunction -import org.apache.spark.sql.{Encoders, Row} +import org.apache.spark.sql.Row import org.apache.spark.sql.execution.aggregate.TungstenAggregate import org.apache.spark.sql.execution.joins.BroadcastHashJoin import org.apache.spark.sql.functions.{avg, broadcast, col, max} @@ -82,4 +81,22 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext { p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[MapElements]).isDefined) assert(ds.collect() === 0.until(10).map(_.toString).toArray) } + + test("typed filter should be included in WholeStageCodegen") { + val ds = sqlContext.range(10).filter(_ % 2 == 0) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[Filter]).isDefined) + assert(ds.collect() === Array(0, 2, 4, 6, 8)) + } + + test("back-to-back typed filter should be included in WholeStageCodegen") { + val ds = sqlContext.range(10).filter(_ % 2 == 0).filter(_ % 3 == 0) + val plan = ds.queryExecution.executedPlan + assert(plan.find(p => + p.isInstanceOf[WholeStageCodegen] && + p.asInstanceOf[WholeStageCodegen].child.isInstanceOf[SerializeFromObject]).isDefined) + assert(ds.collect() === Array(0, 6)) + } } -- cgit v1.2.3 From 6447098013fad708769423ef108a2b071e0930d8 Mon Sep 17 00:00:00 2001 From: Jacek Laskowski Date: Fri, 8 Apr 2016 11:36:41 +0100 Subject: [SPARK-14402][HOTFIX] Fix ExpressionDescription annotation ## What changes were proposed in this pull request? Fix for the error introduced in https://github.com/apache/spark/commit/c59abad052b7beec4ef550049413e95578e545be: ``` /Users/jacek/dev/oss/spark/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala:626: error: annotation argument needs to be a constant; found: "_FUNC_(str) - ".+("Returns str, with the first letter of each word in uppercase, all other letters in ").+("lowercase. Words are delimited by white space.") "Returns str, with the first letter of each word in uppercase, all other letters in " + ^ ``` ## How was this patch tested? Local build Author: Jacek Laskowski Closes #12192 from jaceklaskowski/SPARK-14402-HOTFIX. 
--- .../apache/spark/sql/catalyst/expressions/stringExpressions.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b6ea03cd5c..7e0e7a833b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -622,9 +622,9 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC * Words are delimited by whitespace. */ @ExpressionDescription( - usage = "_FUNC_(str) - " + - "Returns str, with the first letter of each word in uppercase, all other letters in " + - "lowercase. Words are delimited by white space.", + usage = + """_FUNC_(str) - Returns str with the first letter of each word in uppercase. + All other letters are in lowercase. Words are delimited by white space.""", extended = "> SELECT initcap('sPark sql');\n 'Spark Sql'") case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { -- cgit v1.2.3 From 10a95781ee8685d1473b156715b905c83af2aab0 Mon Sep 17 00:00:00 2001 From: bomeng Date: Sat, 9 Apr 2016 22:30:54 +0900 Subject: [SPARK-14496][SQL] fix some javadoc typos ## What changes were proposed in this pull request? Minor issues. Found 2 typos while browsing the code. ## How was this patch tested? None. Author: bomeng Closes #12264 from bomeng/SPARK-14496. --- .../src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala | 2 +- .../src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala index 0f65028261..cde8bd5b96 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/StringUtils.scala @@ -57,7 +57,7 @@ object StringUtils { * This utility can be used for filtering pattern in the "Like" of "Show Tables / Functions" DDL * @param names the names list to be filtered * @param pattern the filter pattern, only '*' and '|' are allowed as wildcards, others will - * follows regular expression convention, case insensitive match and white spaces + * follow regular expression convention, case insensitive match and white spaces * on both ends will be ignored * @return the filtered names list in order */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3de8aa0276..bf21c9d524 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -475,7 +475,7 @@ class SparkSqlAstBuilder extends AstBuilder { * ALTER VIEW view ADD [IF NOT EXISTS] PARTITION spec * }}} * - * ALTER VIEW ... DROP PARTITION ... is not supported because the concept of partitioning + * ALTER VIEW ... ADD PARTITION ... 
is not supported because the concept of partitioning * is associated with physical tables */ override def visitAddTablePartition( -- cgit v1.2.3 From cd2fed70129ba601f8c849a93eeb44a5d69c2402 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Sat, 9 Apr 2016 13:54:30 -0700 Subject: [SPARK-14335][SQL] Describe function command returns wrong output MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? …because some of built-in functions are not in function registry. This fix tries to fix issues in `describe function` command where some of the outputs still shows Hive's function because some built-in functions are not in FunctionRegistry. The following built-in functions have been added to FunctionRegistry: ``` - ! * / & % ^ + < <= <=> = == > >= | ~ and in like not or rlike when ``` The following listed functions are not added, but hard coded in `commands.scala` (hvanhovell): ``` != <> between case ``` Below are the existing result of the above functions that have not been added: ``` spark-sql> describe function `!=`; Function: <> Class: org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual Usage: a <> b - Returns TRUE if a is not equal to b ``` ``` spark-sql> describe function `<>`; Function: <> Class: org.apache.hadoop.hive.ql.udf.generic.GenericUDFOPNotEqual Usage: a <> b - Returns TRUE if a is not equal to b ``` ``` spark-sql> describe function `between`; Function: between Class: org.apache.hadoop.hive.ql.udf.generic.GenericUDFBetween Usage: between a [NOT] BETWEEN b AND c - evaluate if a is [not] in between b and c ``` ``` spark-sql> describe function `case`; Function: case Class: org.apache.hadoop.hive.ql.udf.generic.GenericUDFCase Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - When a = b, returns c; when a = d, return e; else return f ``` ## How was this patch tested? Existing tests passed. Additional test cases added. Author: Yong Tang Closes #12128 from yongtang/SPARK-14335. 
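As a rough sketch of the resulting behavior (assuming an existing `SQLContext` in scope as `sqlContext`; the commented output simply mirrors the updated expectations in `SQLQuerySuite` below and is illustrative only):

```scala
// Illustrative only: after this patch, DESCRIBE FUNCTION resolves `~` through Spark's
// own FunctionRegistry entry (BitwiseNot) instead of falling back to Hive's UDF.
sqlContext.sql("DESCRIBE FUNCTION `~`").collect().foreach(println)
// [Function: ~]
// [Class: org.apache.spark.sql.catalyst.expressions.BitwiseNot]
// [Usage: To be added.]

// `between` has no registry entry, so its description is hard coded in commands.scala.
sqlContext.sql("DESCRIBE FUNCTION `between`").collect().foreach(println)
// [Function: between]
// [Usage: a [NOT] BETWEEN b AND c - evaluate if a is [not] in between b and c]
```
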
--- .../sql/catalyst/analysis/FunctionRegistry.scala | 33 +++++++++++++++- .../spark/sql/execution/command/commands.scala | 44 +++++++++++++++------- .../spark/sql/hive/execution/SQLQuerySuite.scala | 30 +++++++++++---- 3 files changed, 86 insertions(+), 21 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index f239b33e44..f2abf136da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -171,6 +171,7 @@ object FunctionRegistry { expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), + expression[CaseWhen]("when"), // math functions expression[Acos]("acos"), @@ -217,6 +218,12 @@ object FunctionRegistry { expression[Tan]("tan"), expression[Tanh]("tanh"), + expression[Add]("+"), + expression[Subtract]("-"), + expression[Multiply]("*"), + expression[Divide]("/"), + expression[Remainder]("%"), + // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), expression[Average]("avg"), @@ -257,6 +264,7 @@ object FunctionRegistry { expression[Lower]("lcase"), expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[Like]("like"), expression[Lower]("lower"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), @@ -267,6 +275,7 @@ object FunctionRegistry { expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReverse]("reverse"), + expression[RLike]("rlike"), expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), expression[SoundEx]("soundex"), @@ -343,7 +352,29 @@ object FunctionRegistry { expression[NTile]("ntile"), expression[Rank]("rank"), expression[DenseRank]("dense_rank"), - expression[PercentRank]("percent_rank") + expression[PercentRank]("percent_rank"), + + // predicates + expression[And]("and"), + expression[In]("in"), + expression[Not]("not"), + expression[Or]("or"), + + expression[EqualNullSafe]("<=>"), + expression[EqualTo]("="), + expression[EqualTo]("=="), + expression[GreaterThan](">"), + expression[GreaterThanOrEqual](">="), + expression[LessThan]("<"), + expression[LessThanOrEqual]("<="), + expression[Not]("!"), + + // bitwise + expression[BitwiseAnd]("&"), + expression[BitwiseNot]("~"), + expression[BitwiseOr]("|"), + expression[BitwiseXor]("^") + ) val builtin: SimpleFunctionRegistry = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala index 3fd2a93d29..5d00c805a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/commands.scala @@ -483,20 +483,38 @@ case class DescribeFunction( } override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { - case Some(info) => - val result = - Row(s"Function: ${info.getName}") :: - Row(s"Class: ${info.getClassName}") :: - Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil - - if (isExtended) { - result :+ Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") - } else { - result - } + // Hard code "<>", "!=", 
"between", and "case" for now as there is no corresponding functions. + functionName.toLowerCase match { + case "<>" => + Row(s"Function: $functionName") :: + Row(s"Usage: a <> b - Returns TRUE if a is not equal to b") :: Nil + case "!=" => + Row(s"Function: $functionName") :: + Row(s"Usage: a != b - Returns TRUE if a is not equal to b") :: Nil + case "between" => + Row(s"Function: between") :: + Row(s"Usage: a [NOT] BETWEEN b AND c - " + + s"evaluate if a is [not] in between b and c") :: Nil + case "case" => + Row(s"Function: case") :: + Row(s"Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " + + s"When a = b, returns c; when a = d, return e; else return f") :: Nil + case _ => sqlContext.sessionState.functionRegistry.lookupFunction(functionName) match { + case Some(info) => + val result = + Row(s"Function: ${info.getName}") :: + Row(s"Class: ${info.getClassName}") :: + Row(s"Usage: ${replaceFunctionName(info.getUsage(), info.getName)}") :: Nil + + if (isExtended) { + result :+ + Row(s"Extended Usage:\n${replaceFunctionName(info.getExtended, info.getName)}") + } else { + result + } - case None => Seq(Row(s"Function: $functionName not found.")) + case None => Seq(Row(s"Function: $functionName not found.")) + } } } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 14a1d4cd30..d7ec85c15d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -203,8 +203,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkAnswer(sql("SHOW functions abc.abs"), Row("abs")) checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) checkAnswer(sql("SHOW functions `abc`.`abs`"), Row("abs")) - // TODO: Re-enable this test after we fix SPARK-14335. - // checkAnswer(sql("SHOW functions `~`"), Row("~")) + checkAnswer(sql("SHOW functions `~`"), Row("~")) checkAnswer(sql("SHOW functions `a function doens't exist`"), Nil) checkAnswer(sql("SHOW functions `weekofyea*`"), Row("weekofyear")) // this probably will failed if we add more function with `sha` prefixing. @@ -236,11 +235,28 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkExistence(sql("describe functioN abcadf"), true, "Function: abcadf not found.") - // TODO: Re-enable this test after we fix SPARK-14335. 
- // checkExistence(sql("describe functioN `~`"), true, - // "Function: ~", - // "Class: org.apache.hadoop.hive.ql.udf.UDFOPBitNot", - // "Usage: ~ n - Bitwise not") + checkExistence(sql("describe functioN `~`"), true, + "Function: ~", + "Class: org.apache.spark.sql.catalyst.expressions.BitwiseNot", + "Usage: To be added.") + + // Hard coded describe functions + checkExistence(sql("describe function `<>`"), true, + "Function: <>", + "Usage: a <> b - Returns TRUE if a is not equal to b") + + checkExistence(sql("describe function `!=`"), true, + "Function: !=", + "Usage: a != b - Returns TRUE if a is not equal to b") + + checkExistence(sql("describe function `between`"), true, + "Function: between", + "Usage: a [NOT] BETWEEN b AND c - evaluate if a is [not] in between b and c") + + checkExistence(sql("describe function `case`"), true, + "Function: case", + "Usage: CASE a WHEN b THEN c [WHEN d THEN e]* [ELSE f] END - " + + "When a = b, returns c; when a = d, return e; else return f") } test("SPARK-5371: union with null and sum") { -- cgit v1.2.3 From dfce9665c4b2b29a19e6302216dae2800da68ff9 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sat, 9 Apr 2016 17:40:36 -0700 Subject: [SPARK-14362][SPARK-14406][SQL] DDL Native Support: Drop View and Drop Table #### What changes were proposed in this pull request? This PR is to provide a native support for DDL `DROP VIEW` and `DROP TABLE`. The PR includes native parsing and native analysis. Based on the HIVE DDL document for [DROP_VIEW_WEB_LINK](https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL#LanguageManualDDL- DropView ), `DROP VIEW` is defined as, **Syntax:** ```SQL DROP VIEW [IF EXISTS] [db_name.]view_name; ``` - to remove metadata for the specified view. - illegal to use DROP TABLE on a view. - illegal to use DROP VIEW on a table. - this command only works in `HiveContext`. In `SQLContext`, we will get an exception. This PR also handles `DROP TABLE`. **Syntax:** ```SQL DROP TABLE [IF EXISTS] table_name [PURGE]; ``` - Previously, the `DROP TABLE` command only can drop Hive tables in `HiveContext`. Now, after this PR, this command also can drop temporary table, external table, external data source table in `SQLContext`. - In `HiveContext`, we will not issue an exception if the to-be-dropped table does not exist and users did not specify `IF EXISTS`. Instead, we just log an error message. If `IF EXISTS` is specified, we will not issue any error message/exception. - In `SQLContext`, we will issue an exception if the to-be-dropped table does not exist, unless `IF EXISTS` is specified. - Data will not be deleted if the tables are `external`, unless table type is `managed_table`. #### How was this patch tested? For verifying command parsing, added test cases in `spark/sql/hive/HiveDDLCommandSuite.scala` For verifying command analysis, added test cases in `spark/sql/hive/execution/HiveDDLSuite.scala` Author: gatorsmile Author: xiaoli Author: Xiao Li Closes #12146 from gatorsmile/dropView. 
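Before the diff itself, a hedged sketch of the behaviour the new native `DropTable` command enforces, mirroring the `HiveDDLSuite` tests added below. It assumes `sqlContext` is backed by a `HiveContext`, since a plain `SQLContext` rejects `DROP VIEW` outright.

```scala
import org.apache.spark.sql.AnalysisException

// Assumes `sqlContext` is a HiveContext-backed session (views are only
// supported there after this patch).
sqlContext.sql("CREATE TABLE tab1(c1 int)")

// IF EXISTS suppresses any error for a missing table; without it, HiveContext
// only logs an error message while SQLContext still throws.
sqlContext.sql("DROP TABLE IF EXISTS no_such_table")

// Mixing up the two commands is now rejected explicitly.
try {
  sqlContext.sql("DROP VIEW tab1")
} catch {
  case e: AnalysisException =>
    // "Cannot drop a table with DROP VIEW. Please use DROP TABLE instead"
    println(e.getMessage)
}

sqlContext.sql("DROP TABLE tab1")
```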
--- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/catalog/InMemoryCatalog.scala | 4 + .../sql/catalyst/catalog/SessionCatalog.scala | 29 +++- .../spark/sql/catalyst/catalog/interface.scala | 2 + .../sql/catalyst/catalog/SessionCatalogSuite.scala | 7 +- .../spark/sql/execution/SparkSqlParser.scala | 16 +++ .../apache/spark/sql/execution/command/ddl.scala | 54 ++++++- .../sql/execution/command/DDLCommandSuite.scala | 56 +++++++- .../spark/sql/execution/command/DDLSuite.scala | 48 +++++++ .../spark/sql/hive/thriftserver/CliSuite.scala | 6 +- .../spark/sql/hive/HiveExternalCatalog.scala | 4 + .../apache/spark/sql/hive/HiveSessionCatalog.scala | 2 + .../spark/sql/hive/execution/HiveSqlParser.scala | 13 -- .../apache/spark/sql/hive/execution/commands.scala | 30 ---- .../spark/sql/hive/HiveDDLCommandSuite.scala | 10 +- .../spark/sql/hive/execution/HiveDDLSuite.scala | 156 +++++++++++++++++++++ 16 files changed, 376 insertions(+), 63 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 85cb585919..2f2e060b38 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -104,6 +104,7 @@ statement REPLACE COLUMNS '(' colTypeList ')' (CASCADE | RESTRICT)? #replaceColumns | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? (FOR METADATA? REPLICATION '(' STRING ')')? #dropTable + | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable | CREATE (OR REPLACE)? VIEW (IF NOT EXISTS)? tableIdentifier identifierCommentList? (COMMENT STRING)? (PARTITIONED ON identifierList)? @@ -141,7 +142,6 @@ hiveNativeCommands | DELETE FROM tableIdentifier (WHERE booleanExpression)? | TRUNCATE TABLE tableIdentifier partitionSpec? (COLUMNS identifierList)? - | DROP VIEW (IF EXISTS)? qualifiedName | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? | START TRANSACTION (transactionMode (',' transactionMode)*)? | COMMIT WORK? 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 186bbccef1..1994acd1ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -187,6 +187,10 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).tables(table).table } + override def getTableOption(db: String, table: String): Option[CatalogTable] = synchronized { + if (!tableExists(db, table)) None else Option(catalog(db).tables(table).table) + } + override def tableExists(db: String, table: String): Boolean = synchronized { requireDbExists(db) catalog(db).tables.contains(table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index 7db9fd0527..c1e5a485e7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -21,6 +21,7 @@ import java.io.File import scala.collection.mutable +import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{CatalystConf, SimpleCatalystConf} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} @@ -41,7 +42,7 @@ class SessionCatalog( externalCatalog: ExternalCatalog, functionResourceLoader: FunctionResourceLoader, functionRegistry: FunctionRegistry, - conf: CatalystConf) { + conf: CatalystConf) extends Logging { import ExternalCatalog._ def this( @@ -175,6 +176,17 @@ class SessionCatalog( externalCatalog.getTable(db, table) } + /** + * Retrieve the metadata of an existing metastore table. + * If no database is specified, assume the table is in the current database. + * If the specified table is not found in the database then return None if it doesn't exist. + */ + def getTableMetadataOption(name: TableIdentifier): Option[CatalogTable] = { + val db = name.database.getOrElse(currentDb) + val table = formatTableName(name.table) + externalCatalog.getTableOption(db, table) + } + // ------------------------------------------------------------- // | Methods that interact with temporary and metastore tables | // ------------------------------------------------------------- @@ -229,7 +241,13 @@ class SessionCatalog( val db = name.database.getOrElse(currentDb) val table = formatTableName(name.table) if (name.database.isDefined || !tempTables.contains(table)) { - externalCatalog.dropTable(db, table, ignoreIfNotExists) + // When ignoreIfNotExists is false, no exception is issued when the table does not exist. + // Instead, log it as an error message. This is consistent with Hive. + if (externalCatalog.tableExists(db, table)) { + externalCatalog.dropTable(db, table, ignoreIfNotExists = true) + } else if (!ignoreIfNotExists) { + logError(s"Table '${name.quotedString}' does not exist") + } } else { tempTables.remove(table) } @@ -283,9 +301,14 @@ class SessionCatalog( * explicitly specified. 
*/ def isTemporaryTable(name: TableIdentifier): Boolean = { - !name.database.isDefined && tempTables.contains(formatTableName(name.table)) + name.database.isEmpty && tempTables.contains(formatTableName(name.table)) } + /** + * Return whether View is supported + */ + def isViewSupported: Boolean = false + /** * List all tables in the specified database, including temporary tables. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index e29d6bd8b0..4ef59316ce 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -91,6 +91,8 @@ abstract class ExternalCatalog { def getTable(db: String, table: String): CatalogTable + def getTableOption(db: String, table: String): Option[CatalogTable] + def tableExists(db: String, table: String): Boolean def listTables(db: String): Seq[String] diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 1850dc8156..862fc275ad 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -233,10 +233,9 @@ class SessionCatalogSuite extends SparkFunSuite { intercept[AnalysisException] { catalog.dropTable(TableIdentifier("tbl1", Some("unknown_db")), ignoreIfNotExists = true) } - // Table does not exist - intercept[AnalysisException] { - catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false) - } + // If the table does not exist, we do not issue an exception. Instead, we output an error log + // message to console when ignoreIfNotExists is set to false. + catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = false) catalog.dropTable(TableIdentifier("unknown_table", Some("db2")), ignoreIfNotExists = true) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index c8d0f4e3c5..3da715cdb3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -363,6 +363,22 @@ class SparkSqlAstBuilder extends AstBuilder { } } + /** + * Create a [[DropTable]] command. + */ + override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { + if (ctx.PURGE != null) { + throw new ParseException("Unsupported operation: PURGE option", ctx) + } + if (ctx.REPLICATION != null) { + throw new ParseException("Unsupported operation: REPLICATION clause", ctx) + } + DropTable( + visitTableIdentifier(ctx.tableIdentifier), + ctx.EXISTS != null, + ctx.VIEW != null) + } + /** * Create a [[AlterTableRename]] command. 
* diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 20779d68e0..e941736f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ @@ -175,13 +175,61 @@ case class DescribeDatabase( } } +/** + * Drops a table/view from the metastore and removes it if it is cached. + * + * The syntax of this command is: + * {{{ + * DROP TABLE [IF EXISTS] table_name; + * DROP VIEW [IF EXISTS] [db_name.]view_name; + * }}} + */ +case class DropTable( + tableName: TableIdentifier, + ifExists: Boolean, + isView: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (isView && !catalog.isViewSupported) { + throw new AnalysisException(s"Not supported object: views") + } + // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view + // issue an exception. + catalog.getTableMetadataOption(tableName).map(_.tableType match { + case CatalogTableType.VIRTUAL_VIEW if !isView => + throw new AnalysisException( + "Cannot drop a view with DROP TABLE. Please use DROP VIEW instead") + case o if o != CatalogTableType.VIRTUAL_VIEW && isView => + throw new AnalysisException( + s"Cannot drop a table with DROP VIEW. Please use DROP TABLE instead") + case _ => + }) + + try { + sqlContext.cacheManager.tryUncacheQuery(sqlContext.table(tableName.quotedString)) + } catch { + // This table's metadata is not in Hive metastore (e.g. the table does not exist). + case e if e.getClass.getName == "org.apache.hadoop.hive.ql.metadata.InvalidTableException" => + case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => + // Other Throwables can be caused by users providing wrong parameters in OPTIONS + // (e.g. invalid paths). We catch it and log a warning message. + // Users should be able to drop such kinds of tables regardless if there is an error. + case e: Throwable => log.warn(s"${e.getMessage}", e) + } + catalog.invalidateTable(tableName) + catalog.dropTable(tableName, ifExists) + Seq.empty[Row] + } +} + /** * A command that renames a table/view. 
* * The syntax of this command is: * {{{ - * ALTER TABLE table1 RENAME TO table2; - * ALTER VIEW view1 RENAME TO view2; + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; * }}} */ case class AlterTableRename( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index b1c1fd0951..ac69518ddf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -20,9 +20,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser -import org.apache.spark.sql.execution.datasources.BucketSpec import org.apache.spark.sql.types._ class DDLCommandSuite extends PlanTest { @@ -667,7 +666,10 @@ class DDLCommandSuite extends PlanTest { test("unsupported operations") { intercept[ParseException] { - parser.parsePlan("DROP TABLE D1.T1") + parser.parsePlan("DROP TABLE tab PURGE") + } + intercept[ParseException] { + parser.parsePlan("DROP TABLE tab FOR REPLICATION('eventid')") } intercept[ParseException] { parser.parsePlan("CREATE VIEW testView AS SELECT id FROM tab") @@ -700,4 +702,52 @@ class DDLCommandSuite extends PlanTest { val parsed = parser.parsePlan(sql) assert(parsed.isInstanceOf[Project]) } + + test("drop table") { + val tableName1 = "db.tab" + val tableName2 = "tab" + + val parsed1 = parser.parsePlan(s"DROP TABLE $tableName1") + val parsed2 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName1") + val parsed3 = parser.parsePlan(s"DROP TABLE $tableName2") + val parsed4 = parser.parsePlan(s"DROP TABLE IF EXISTS $tableName2") + + val expected1 = + DropTable(TableIdentifier("tab", Option("db")), ifExists = false, isView = false) + val expected2 = + DropTable(TableIdentifier("tab", Option("db")), ifExists = true, isView = false) + val expected3 = + DropTable(TableIdentifier("tab", None), ifExists = false, isView = false) + val expected4 = + DropTable(TableIdentifier("tab", None), ifExists = true, isView = false) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } + + test("drop view") { + val viewName1 = "db.view" + val viewName2 = "view" + + val parsed1 = parser.parsePlan(s"DROP VIEW $viewName1") + val parsed2 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName1") + val parsed3 = parser.parsePlan(s"DROP VIEW $viewName2") + val parsed4 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName2") + + val expected1 = + DropTable(TableIdentifier("view", Option("db")), ifExists = false, isView = true) + val expected2 = + DropTable(TableIdentifier("view", Option("db")), ifExists = true, isView = true) + val expected3 = + DropTable(TableIdentifier("view", None), ifExists = false, isView = true) + val expected4 = + DropTable(TableIdentifier("view", None), ifExists = true, isView = true) + + comparePlans(parsed1, expected1) + comparePlans(parsed2, expected2) + comparePlans(parsed3, expected3) + comparePlans(parsed4, expected4) + } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 7084665b3b..e75e5f5cb2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -391,6 +391,54 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { Nil) } + test("drop table - temporary table") { + val catalog = sqlContext.sessionState.catalog + sql( + """ + |CREATE TEMPORARY TABLE tab1 + |USING org.apache.spark.sql.sources.DDLScanSource + |OPTIONS ( + | From '1', + | To '10', + | Table 'test1' + |) + """.stripMargin) + assert(catalog.listTables("default") == Seq(TableIdentifier("tab1"))) + sql("DROP TABLE tab1") + assert(catalog.listTables("default") == Nil) + } + + test("drop table") { + testDropTable(isDatasourceTable = false) + } + + test("drop table - data source table") { + testDropTable(isDatasourceTable = true) + } + + private def testDropTable(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.listTables("dbx") == Seq(tableIdent)) + sql("DROP TABLE dbx.tab1") + assert(catalog.listTables("dbx") == Nil) + sql("DROP TABLE IF EXISTS dbx.tab1") + // no exception will be thrown + sql("DROP TABLE dbx.tab1") + } + + test("drop view") { + val e = intercept[AnalysisException] { + sql("DROP VIEW dbx.tab1") + } + assert(e.getMessage.contains("Not supported object: views")) + } + private def convertToDatasourceTable( catalog: SessionCatalog, tableIdent: TableIdentifier): Unit = { diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index e93b0c145f..9ec8b9a9a6 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -172,7 +172,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "SELECT COUNT(*) FROM hive_test;" -> "5", "DROP TABLE hive_test;" - -> "OK" + -> "" ) } @@ -220,9 +220,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "SELECT count(key) FROM t1;" -> "5", "DROP TABLE t1;" - -> "OK", + -> "", "DROP TABLE sourceTable;" - -> "OK" + -> "" ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index b1156fb3e2..a49ce33ba1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -182,6 +182,10 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat client.getTable(db, table) } + override def getTableOption(db: String, table: String): Option[CatalogTable] = withClient { + client.getTableOption(db, table) + } + override def tableExists(db: String, table: String): Boolean = withClient { client.getTableOption(db, table).isDefined } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala 
b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 0cccc22e5a..875652c226 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -70,6 +70,8 @@ private[sql] class HiveSessionCatalog( } } + override def isViewSupported: Boolean = true + // ---------------------------------------------------------------- // | Methods and fields for interacting with HiveMetastoreCatalog | // ---------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 657edb493a..7a435117e7 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -103,19 +103,6 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } } - /** - * Create a [[DropTable]] command. - */ - override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { - if (ctx.PURGE != null) { - logWarning("PURGE option is ignored.") - } - if (ctx.REPLICATION != null) { - logWarning("REPLICATION clause is ignored.") - } - DropTable(visitTableIdentifier(ctx.tableIdentifier).toString, ctx.EXISTS != null) - } - /** * Create an [[AnalyzeTable]] command. This currently only implements the NOSCAN option (other * options are passed on to Hive) e.g.: diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala index 64d1341a47..06badff474 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/commands.scala @@ -46,36 +46,6 @@ case class AnalyzeTable(tableName: String) extends RunnableCommand { } } -/** - * Drops a table from the metastore and removes it if it is cached. - */ -private[hive] -case class DropTable( - tableName: String, - ifExists: Boolean) extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - val hiveContext = sqlContext.asInstanceOf[HiveContext] - val ifExistsClause = if (ifExists) "IF EXISTS " else "" - try { - hiveContext.cacheManager.tryUncacheQuery(hiveContext.table(tableName)) - } catch { - // This table's metadata is not in Hive metastore (e.g. the table does not exist). - case _: org.apache.hadoop.hive.ql.metadata.InvalidTableException => - case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => - // Other Throwables can be caused by users providing wrong parameters in OPTIONS - // (e.g. invalid paths). We catch it and log a warning message. - // Users should be able to drop such kinds of tables regardless if there is an error. 
- case e: Throwable => log.warn(s"${e.getMessage}", e) - } - hiveContext.invalidateTable(tableName) - hiveContext.runSqlHive(s"DROP TABLE $ifExistsClause$tableName") - hiveContext.sessionState.catalog.dropTable( - TableIdentifier(tableName), ignoreIfNotExists = true) - Seq.empty[Row] - } -} - private[hive] case class AddJar(path: String) extends RunnableCommand { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index 12a582c10a..a144da4997 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -72,6 +72,7 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) // TODO will be SQLText assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) @@ -118,6 +119,7 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) // TODO will be SQLText assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) @@ -138,6 +140,7 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.locationUri == None) assert(desc.schema == Seq.empty[CatalogColumn]) assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewOriginalText.isEmpty) assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) assert(desc.storage.outputFormat == @@ -173,6 +176,7 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.locationUri == None) assert(desc.schema == Seq.empty[CatalogColumn]) assert(desc.viewText == None) // TODO will be SQLText + assert(desc.viewOriginalText.isEmpty) assert(desc.storage.serdeProperties == Map(("serde_p1" -> "p1"), ("serde_p2" -> "p2"))) assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) @@ -286,7 +290,7 @@ class HiveDDLCommandSuite extends PlanTest { } test("use backticks in output of Script Transform") { - val plan = parser.parsePlan( + parser.parsePlan( """SELECT `t`.`thing1` |FROM (SELECT TRANSFORM (`parquet_t1`.`key`, `parquet_t1`.`value`) |USING 'cat' AS (`thing1` int, `thing2` string) FROM `default`.`parquet_t1`) AS t @@ -294,7 +298,7 @@ class HiveDDLCommandSuite extends PlanTest { } test("use backticks in output of Generator") { - val plan = parser.parsePlan( + parser.parsePlan( """ |SELECT `gentab2`.`gencol2` |FROM `default`.`src` @@ -304,7 +308,7 @@ class HiveDDLCommandSuite extends PlanTest { } test("use escaped backticks in output of Generator") { - val plan = parser.parsePlan( + parser.parsePlan( """ |SELECT `gen``tab2`.`gen``col2` |FROM `default`.`src` diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala new file mode 
100644 index 0000000000..78ccdc7adb --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -0,0 +1,156 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import org.apache.hadoop.fs.Path + +import org.apache.spark.sql.{AnalysisException, QueryTest, SaveMode} +import org.apache.spark.sql.catalyst.catalog.CatalogTableType +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class HiveDDLSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { + import hiveContext.implicits._ + + // check if the directory for recording the data of the table exists. + private def tableDirectoryExists(tableIdentifier: TableIdentifier): Boolean = { + val expectedTablePath = + hiveContext.sessionState.catalog.hiveDefaultTableFilePath(tableIdentifier) + val filesystemPath = new Path(expectedTablePath) + val fs = filesystemPath.getFileSystem(sparkContext.hadoopConfiguration) + fs.exists(filesystemPath) + } + + test("drop tables") { + withTable("tab1") { + val tabName = "tab1" + + assert(!tableDirectoryExists(TableIdentifier(tabName))) + sql(s"CREATE TABLE $tabName(c1 int)") + + assert(tableDirectoryExists(TableIdentifier(tabName))) + sql(s"DROP TABLE $tabName") + + assert(!tableDirectoryExists(TableIdentifier(tabName))) + sql(s"DROP TABLE IF EXISTS $tabName") + sql(s"DROP VIEW IF EXISTS $tabName") + } + } + + test("drop managed tables") { + withTempDir { tmpDir => + val tabName = "tab1" + withTable(tabName) { + assert(tmpDir.listFiles.isEmpty) + sql( + s""" + |create table $tabName + |stored as parquet + |location '$tmpDir' + |as select 1, '3' + """.stripMargin) + + val hiveTable = + hiveContext.sessionState.catalog + .getTableMetadata(TableIdentifier(tabName, Some("default"))) + // It is a managed table, although it uses external in SQL + assert(hiveTable.tableType == CatalogTableType.MANAGED_TABLE) + + assert(tmpDir.listFiles.nonEmpty) + sql(s"DROP TABLE $tabName") + // The data are deleted since the table type is not EXTERNAL + assert(tmpDir.listFiles == null) + } + } + } + + test("drop external data source table") { + withTempDir { tmpDir => + val tabName = "tab1" + withTable(tabName) { + assert(tmpDir.listFiles.isEmpty) + + withSQLConf(SQLConf.PARQUET_WRITE_LEGACY_FORMAT.key -> "true") { + Seq(1 -> "a").toDF("i", "j") + .write + .mode(SaveMode.Overwrite) + .format("parquet") + .option("path", tmpDir.toString) + .saveAsTable(tabName) + } + + val hiveTable = + hiveContext.sessionState.catalog + .getTableMetadata(TableIdentifier(tabName, Some("default"))) + // This data source table is external table + 
assert(hiveTable.tableType == CatalogTableType.EXTERNAL_TABLE) + + assert(tmpDir.listFiles.nonEmpty) + sql(s"DROP TABLE $tabName") + // The data are not deleted since the table type is EXTERNAL + assert(tmpDir.listFiles.nonEmpty) + } + } + } + + test("drop views") { + withTable("tab1") { + val tabName = "tab1" + sqlContext.range(10).write.saveAsTable("tab1") + withView("view1") { + val viewName = "view1" + + assert(tableDirectoryExists(TableIdentifier(tabName))) + assert(!tableDirectoryExists(TableIdentifier(viewName))) + sql(s"CREATE VIEW $viewName AS SELECT * FROM tab1") + + assert(tableDirectoryExists(TableIdentifier(tabName))) + assert(!tableDirectoryExists(TableIdentifier(viewName))) + sql(s"DROP VIEW $viewName") + + assert(tableDirectoryExists(TableIdentifier(tabName))) + sql(s"DROP VIEW IF EXISTS $viewName") + } + } + } + + test("drop table using drop view") { + withTable("tab1") { + sql("CREATE TABLE tab1(c1 int)") + val message = intercept[AnalysisException] { + sql("DROP VIEW tab1") + }.getMessage + assert(message.contains("Cannot drop a table with DROP VIEW. Please use DROP TABLE instead")) + } + } + + test("drop view using drop table") { + withTable("tab1") { + sqlContext.range(10).write.saveAsTable("tab1") + withView("view1") { + sql("CREATE VIEW view1 AS SELECT * FROM tab1") + val message = intercept[AnalysisException] { + sql("DROP TABLE view1") + }.getMessage + assert(message.contains("Cannot drop a view with DROP TABLE. Please use DROP VIEW instead")) + } + } + } +} -- cgit v1.2.3 From 3fb09afd5e55b9a7a0a332273f09f984a78c3645 Mon Sep 17 00:00:00 2001 From: Yin Huai Date: Sat, 9 Apr 2016 23:32:17 -0700 Subject: [SPARK-14506][SQL] HiveClientImpl's toHiveTable misses a table property for external tables ## What changes were proposed in this pull request? For an external table's metadata (in Hive's representation), its table type needs to be EXTERNAL_TABLE. Also, there needs to be a field called EXTERNAL set in the table property with a value of TRUE (for a MANAGED_TABLE it will be FALSE) based on https://github.com/apache/hive/blob/release-1.2.1/metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105. HiveClientImpl's toHiveTable misses to set this table property. ## How was this patch tested? Added a new test. Author: Yin Huai Closes #12275 from yhuai/SPARK-14506. 
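As a self-contained illustration (not Spark code), the rule the patch adds to `HiveClientImpl.toHiveTable` boils down to mirroring the table type with an `EXTERNAL` table property; the enum and helper below are hypothetical stand-ins for the real `CatalogTableType` and the Hive `Table` API.

```scala
// Hypothetical stand-ins; the real patch calls hiveTable.setProperty(...) on a
// Hive metastore Table inside HiveClientImpl.toHiveTable.
sealed trait TableType
case object ExternalTable extends TableType
case object ManagedTable extends TableType

// Without this property the Hive metastore rewrites EXTERNAL_TABLE to
// MANAGED_TABLE (see the ObjectStore.java lines referenced above).
def externalProperty(tableType: TableType): (String, String) = tableType match {
  case ExternalTable => "EXTERNAL" -> "TRUE"
  case ManagedTable  => "EXTERNAL" -> "FALSE"
}

// externalProperty(ExternalTable) == ("EXTERNAL", "TRUE")
```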
--- .../spark/sql/catalyst/catalog/CatalogTestCases.scala | 9 +++++++++ .../org/apache/spark/sql/hive/client/HiveClientImpl.scala | 13 +++++++++++-- 2 files changed, 20 insertions(+), 2 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index fbcac09ce2..0009438b31 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -149,6 +149,15 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { // Tables // -------------------------------------------------------------------------- + test("the table type of an external table should be EXTERNAL_TABLE") { + val catalog = newBasicCatalog() + val table = + newTable("external_table1", "db2").copy(tableType = CatalogTableType.EXTERNAL_TABLE) + catalog.createTable("db2", table, ignoreIfExists = false) + val actual = catalog.getTable("db2", "external_table1") + assert(actual.tableType === CatalogTableType.EXTERNAL_TABLE) + } + test("drop table") { val catalog = newBasicCatalog() assert(catalog.listTables("db2").toSet == Set("tbl1", "tbl2")) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index d0eb9ddf50..a037671ef0 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -659,9 +659,18 @@ private[hive] class HiveClientImpl( private def toHiveTable(table: CatalogTable): HiveTable = { val hiveTable = new HiveTable(table.database, table.identifier.table) + // For EXTERNAL_TABLE/MANAGED_TABLE, we also need to set EXTERNAL field in + // the table properties accodringly. Otherwise, if EXTERNAL_TABLE is the table type + // but EXTERNAL field is not set, Hive metastore will change the type to + // MANAGED_TABLE (see + // metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105) hiveTable.setTableType(table.tableType match { - case CatalogTableType.EXTERNAL_TABLE => HiveTableType.EXTERNAL_TABLE - case CatalogTableType.MANAGED_TABLE => HiveTableType.MANAGED_TABLE + case CatalogTableType.EXTERNAL_TABLE => + hiveTable.setProperty("EXTERNAL", "TRUE") + HiveTableType.EXTERNAL_TABLE + case CatalogTableType.MANAGED_TABLE => + hiveTable.setProperty("EXTERNAL", "FALSE") + HiveTableType.MANAGED_TABLE case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW }) -- cgit v1.2.3 From a7ce473bd0520c71154ed028f295dab64a7485fe Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Sun, 10 Apr 2016 11:46:45 -0700 Subject: [SPARK-14415][SQL] All functions should show usages by command `DESC FUNCTION` ## What changes were proposed in this pull request? Currently, many functions do now show usages like the followings. ``` scala> sql("desc function extended `sin`").collect().foreach(println) [Function: sin] [Class: org.apache.spark.sql.catalyst.expressions.Sin] [Usage: To be added.] [Extended Usage: To be added.] ``` This PR adds descriptions for functions and adds a testcase prevent adding function without usage. 
``` scala> sql("desc function extended `sin`").collect().foreach(println); [Function: sin] [Class: org.apache.spark.sql.catalyst.expressions.Sin] [Usage: sin(x) - Returns the sine of x.] [Extended Usage: > SELECT sin(0); 0.0] ``` The only exceptions are `cube`, `grouping`, `grouping_id`, `rollup`, `window`. ## How was this patch tested? Pass the Jenkins tests (including new testcases.) Author: Dongjoon Hyun Closes #12185 from dongjoon-hyun/SPARK-14415. --- .../catalyst/expressions/aggregate/Average.scala | 2 + .../expressions/aggregate/CentralMomentAgg.scala | 14 +++ .../sql/catalyst/expressions/aggregate/Corr.scala | 2 + .../sql/catalyst/expressions/aggregate/Count.scala | 6 + .../expressions/aggregate/Covariance.scala | 4 + .../sql/catalyst/expressions/aggregate/First.scala | 5 + .../aggregate/HyperLogLogPlusPlus.scala | 7 +- .../sql/catalyst/expressions/aggregate/Last.scala | 2 + .../sql/catalyst/expressions/aggregate/Max.scala | 2 + .../sql/catalyst/expressions/aggregate/Min.scala | 3 +- .../sql/catalyst/expressions/aggregate/Sum.scala | 2 + .../sql/catalyst/expressions/arithmetic.scala | 23 +++- .../catalyst/expressions/bitwiseExpressions.scala | 12 ++ .../expressions/collectionOperations.scala | 10 ++ .../catalyst/expressions/complexTypeCreator.scala | 10 ++ .../expressions/conditionalExpressions.scala | 13 ++- .../catalyst/expressions/datetimeExpressions.scala | 82 ++++++++++++- .../sql/catalyst/expressions/generators.scala | 4 + .../sql/catalyst/expressions/jsonExpressions.scala | 6 + .../sql/catalyst/expressions/mathExpressions.scala | 129 ++++++++++++++++++++- .../spark/sql/catalyst/expressions/misc.scala | 2 + .../sql/catalyst/expressions/nullExpressions.scala | 11 ++ .../sql/catalyst/expressions/predicates.scala | 29 +++-- .../catalyst/expressions/randomExpressions.scala | 4 + .../catalyst/expressions/regexpExpressions.scala | 14 ++- .../catalyst/expressions/stringExpressions.scala | 106 ++++++++++++++++- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 8 ++ .../spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 28 files changed, 489 insertions(+), 25 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala index 94ac4bf09b..ff70774847 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala @@ -23,6 +23,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the mean calculated from values of a group.") case class Average(child: Expression) extends DeclarativeAggregate { override def prettyName: String = "avg" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala index 9d2db45144..17a7c6dce8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala @@ -130,6 +130,10 @@ abstract class CentralMomentAgg(child: Expression) extends DeclarativeAggregate } // Compute the 
population standard deviation of a column +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the population standard deviation calculated from values of a group.") +// scalastyle:on line.size.limit case class StddevPop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -143,6 +147,8 @@ case class StddevPop(child: Expression) extends CentralMomentAgg(child) { } // Compute the sample standard deviation of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sample standard deviation calculated from values of a group.") case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -157,6 +163,8 @@ case class StddevSamp(child: Expression) extends CentralMomentAgg(child) { } // Compute the population variance of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the population variance calculated from values of a group.") case class VariancePop(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -170,6 +178,8 @@ case class VariancePop(child: Expression) extends CentralMomentAgg(child) { } // Compute the sample variance of a column +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sample variance calculated from values of a group.") case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 2 @@ -183,6 +193,8 @@ case class VarianceSamp(child: Expression) extends CentralMomentAgg(child) { override def prettyName: String = "var_samp" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the Skewness value calculated from values of a group.") case class Skewness(child: Expression) extends CentralMomentAgg(child) { override def prettyName: String = "skewness" @@ -196,6 +208,8 @@ case class Skewness(child: Expression) extends CentralMomentAgg(child) { } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the Kurtosis value calculated from values of a group.") case class Kurtosis(child: Expression) extends CentralMomentAgg(child) { override protected def momentOrder = 4 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala index e6b8214ef2..e29265e2f4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Corr.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ * Definition of Pearson correlation can be found at * http://en.wikipedia.org/wiki/Pearson_product-moment_correlation_coefficient */ +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns Pearson coefficient of correlation between a set of number pairs.") case class Corr(x: Expression, y: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = Seq(x, y) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala index 663c69e799..17ae012af7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Count.scala @@ -21,6 +21,12 @@ import org.apache.spark.sql.catalyst.dsl.expressions._ 
import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.types._ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(*) - Returns the total number of retrieved rows, including rows containing NULL values. + _FUNC_(expr) - Returns the number of rows for which the supplied expression is non-NULL. + _FUNC_(DISTINCT expr[, expr...]) - Returns the number of rows for which the supplied expression(s) are unique and non-NULL.""") +// scalastyle:on line.size.limit case class Count(children: Seq[Expression]) extends DeclarativeAggregate { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala index c175a8c4c7..d80afbebf7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Covariance.scala @@ -76,6 +76,8 @@ abstract class Covariance(x: Expression, y: Expression) extends DeclarativeAggre } } +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the population covariance of a set of number pairs.") case class CovPopulation(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), @@ -85,6 +87,8 @@ case class CovPopulation(left: Expression, right: Expression) extends Covariance } +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the sample covariance of a set of number pairs.") case class CovSample(left: Expression, right: Expression) extends Covariance(left, right) { override val evaluateExpression: Expression = { If(n === Literal(0.0), Literal.create(null, DoubleType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala index 35f57426fe..b8ab0364dd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/First.scala @@ -28,6 +28,11 @@ import org.apache.spark.sql.types._ * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). */ +@ExpressionDescription( + usage = """_FUNC_(expr) - Returns the first value of `child` for a group of rows. + _FUNC_(expr,isIgnoreNull=false) - Returns the first value of `child` for a group of rows. + If isIgnoreNull is true, returns only non-null values. 
+ """) case class First(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index b6bd56cff6..1d218da6db 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -20,8 +20,6 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import java.lang.{Long => JLong} import java.util -import com.clearspring.analytics.hash.MurmurHash - import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ @@ -48,6 +46,11 @@ import org.apache.spark.sql.types._ * @param relativeSD the maximum estimation error allowed. */ // scalastyle:on +@ExpressionDescription( + usage = """_FUNC_(expr) - Returns the estimated cardinality by HyperLogLog++. + _FUNC_(expr, relativeSD=0.05) - Returns the estimated cardinality by HyperLogLog++ + with relativeSD, the maximum estimation error allowed. + """) case class HyperLogLogPlusPlus( child: Expression, relativeSD: Double = 0.05, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala index be7e12d7a2..b05d74b49b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Last.scala @@ -28,6 +28,8 @@ import org.apache.spark.sql.types._ * is used) its result will not be deterministic (unless the input table is sorted and has * a single partition, and we use a single reducer to do the aggregation.). 
*/ +@ExpressionDescription( + usage = "_FUNC_(expr,isIgnoreNull) - Returns the last value of `child` for a group of rows.") case class Last(child: Expression, ignoreNullsExpr: Expression) extends DeclarativeAggregate { def this(child: Expression) = this(child, Literal.create(false, BooleanType)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala index 906003188d..c534fe495f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Max.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the maximum value of expr.") case class Max(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala index 39f7afbd08..35289b4681 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Min.scala @@ -22,7 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the minimum value of expr.") case class Min(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala index 08a67ea3df..ad217f25b5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala @@ -22,6 +22,8 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sum calculated from values of a group.") case class Sum(child: Expression) extends DeclarativeAggregate { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index b388091538..f3d42fc0b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -23,7 +23,8 @@ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval - +@ExpressionDescription( + usage = "_FUNC_(a) - Returns -a.") case class UnaryMinus(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { @@ -59,6 +60,8 @@ case class UnaryMinus(child: Expression) extends UnaryExpression override def sql: String = s"(-${child.sql})" } 
+@ExpressionDescription( + usage = "_FUNC_(a) - Returns a.") case class UnaryPositive(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { override def prettyName: String = "positive" @@ -79,8 +82,8 @@ case class UnaryPositive(child: Expression) * A function that get the absolute value of the numeric value. */ @ExpressionDescription( - usage = "_FUNC_(expr) - Returns the absolute value of the numeric value", - extended = "> SELECT _FUNC_('-1');\n1") + usage = "_FUNC_(expr) - Returns the absolute value of the numeric value.", + extended = "> SELECT _FUNC_('-1');\n 1") case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes with NullIntolerant { @@ -126,6 +129,8 @@ private[sql] object BinaryArithmetic { def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] = Some((e.left, e.right)) } +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a+b.") case class Add(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def inputType: AbstractDataType = TypeCollection.NumericAndInterval @@ -155,6 +160,8 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit } } +@ExpressionDescription( + usage = "a _FUNC_ b - Returns a-b.") case class Subtract(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -185,6 +192,8 @@ case class Subtract(left: Expression, right: Expression) } } +@ExpressionDescription( + usage = "a _FUNC_ b - Multiplies a by b.") case class Multiply(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -198,6 +207,9 @@ case class Multiply(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = numeric.times(input1, input2) } +@ExpressionDescription( + usage = "a _FUNC_ b - Divides a by b.", + extended = "> SELECT 3 _FUNC_ 2;\n 1.5") case class Divide(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -275,6 +287,8 @@ case class Divide(left: Expression, right: Expression) } } +@ExpressionDescription( + usage = "a _FUNC_ b - Returns the remainder when dividing a by b.") case class Remainder(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { @@ -464,6 +478,9 @@ case class MinOf(left: Expression, right: Expression) override def symbol: String = "min" } +@ExpressionDescription( + usage = "_FUNC_(a, b) - Returns the positive modulo", + extended = "> SELECT _FUNC_(10,3);\n 1") case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic with NullIntolerant { override def toString: String = s"pmod($left, $right)" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 4c90b3f7d3..a7e1cd66f2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -26,6 +26,9 @@ import org.apache.spark.sql.types._ * * Code generation inherited from BinaryArithmetic. 
*/ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise AND.", + extended = "> SELECT 3 _FUNC_ 5; 1") case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -51,6 +54,9 @@ case class BitwiseAnd(left: Expression, right: Expression) extends BinaryArithme * * Code generation inherited from BinaryArithmetic. */ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise OR.", + extended = "> SELECT 3 _FUNC_ 5; 7") case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -76,6 +82,9 @@ case class BitwiseOr(left: Expression, right: Expression) extends BinaryArithmet * * Code generation inherited from BinaryArithmetic. */ +@ExpressionDescription( + usage = "a _FUNC_ b - Bitwise exclusive OR.", + extended = "> SELECT 3 _FUNC_ 5; 2") case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithmetic { override def inputType: AbstractDataType = IntegralType @@ -99,6 +108,9 @@ case class BitwiseXor(left: Expression, right: Expression) extends BinaryArithme /** * A function that calculates bitwise not(~) of a number. */ +@ExpressionDescription( + usage = "_FUNC_ b - Bitwise NOT.", + extended = "> SELECT _FUNC_ 0; -1") case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(IntegralType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index e36c985249..ab790cf372 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -26,6 +26,8 @@ import org.apache.spark.sql.types._ /** * Given an array or map, returns its size. */ +@ExpressionDescription( + usage = "_FUNC_(expr) - Returns the size of an array or a map.") case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType)) @@ -44,6 +46,11 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType * Sorts the input array in ascending / descending order according to the natural ordering of * the array elements and returns it. 
*/ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(array(obj1, obj2,...)) - Sorts the input array in ascending order according to the natural ordering of the array elements.", + extended = " > SELECT _FUNC_(array('b', 'd', 'c', 'a'));\n 'a', 'b', 'c', 'd'") +// scalastyle:on line.size.limit case class SortArray(base: Expression, ascendingOrder: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { @@ -125,6 +132,9 @@ case class SortArray(base: Expression, ascendingOrder: Expression) /** * Checks if the array (left) has the element (right) */ +@ExpressionDescription( + usage = "_FUNC_(array, value) - Returns TRUE if the array contains value.", + extended = " > SELECT _FUNC_(array(1, 2, 3), 2);\n true") case class ArrayContains(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index c299586dde..74de4a776d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -27,6 +27,8 @@ import org.apache.spark.unsafe.types.UTF8String /** * Returns an Array containing the evaluation of all children expressions. */ +@ExpressionDescription( + usage = "_FUNC_(n0, ...) - Returns an array with the given elements.") case class CreateArray(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -73,6 +75,8 @@ case class CreateArray(children: Seq[Expression]) extends Expression { * Returns a catalyst Map containing the evaluation of all children expressions as keys and values. * The children are a flatted sequence of kv pairs, e.g. (key1, value1, key2, value2, ...) */ +@ExpressionDescription( + usage = "_FUNC_(key0, value0, key1, value1...) - Creates a map with the given key/value pairs.") case class CreateMap(children: Seq[Expression]) extends Expression { private[sql] lazy val keys = children.indices.filter(_ % 2 == 0).map(children) private[sql] lazy val values = children.indices.filter(_ % 2 != 0).map(children) @@ -153,6 +157,8 @@ case class CreateMap(children: Seq[Expression]) extends Expression { /** * Returns a Row containing the evaluation of all children expressions. */ +@ExpressionDescription( + usage = "_FUNC_(col1, col2, col3, ...) - Creates a struct with the given field values.") case class CreateStruct(children: Seq[Expression]) extends Expression { override def foldable: Boolean = children.forall(_.foldable) @@ -204,6 +210,10 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { * * @param children Seq(name1, val1, name2, val2, ...) */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(name1, val1, name2, val2, ...) 
- Creates a struct with the given field names and values.") +// scalastyle:on line.size.limit case class CreateNamedStruct(children: Seq[Expression]) extends Expression { /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index 35a7b46020..ae6a94842f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -23,7 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util.TypeUtils import org.apache.spark.sql.types._ - +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr1,expr2,expr3) - If expr1 is TRUE then IF() returns expr2; otherwise it returns expr3.") +// scalastyle:on line.size.limit case class If(predicate: Expression, trueValue: Expression, falseValue: Expression) extends Expression { @@ -85,6 +88,10 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi * @param branches seq of (branch condition, branch value) * @param elseValue optional value for the else branch */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "CASE WHEN a THEN b [WHEN c THEN d]* [ELSE e] END - When a = true, returns b; when c = true, return d; else return e.") +// scalastyle:on line.size.limit case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[Expression] = None) extends Expression with CodegenFallback { @@ -256,6 +263,8 @@ object CaseKeyWhen { * A function that returns the least value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. */ +@ExpressionDescription( + usage = "_FUNC_(n1, ...) - Returns the least value of all parameters, skipping null values.") case class Least(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) @@ -315,6 +324,8 @@ case class Least(children: Seq[Expression]) extends Expression { * A function that returns the greatest value of all parameters, skipping null values. * It takes at least 2 parameters, and returns null iff all parameters are null. */ +@ExpressionDescription( + usage = "_FUNC_(n1, ...) - Returns the greatest value of all parameters, skipping null values.") case class Greatest(children: Seq[Expression]) extends Expression { override def nullable: Boolean = children.forall(_.nullable) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 1d0ea68d7a..9135753041 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -35,6 +35,8 @@ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} * * There is no code generation since this expression should get constant folded by the optimizer. 
*/ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current date at the start of query evaluation.") case class CurrentDate() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -54,6 +56,8 @@ case class CurrentDate() extends LeafExpression with CodegenFallback { * * There is no code generation since this expression should get constant folded by the optimizer. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns the current timestamp at the start of query evaluation.") case class CurrentTimestamp() extends LeafExpression with CodegenFallback { override def foldable: Boolean = true override def nullable: Boolean = false @@ -70,6 +74,9 @@ case class CurrentTimestamp() extends LeafExpression with CodegenFallback { /** * Adds a number of days to startdate. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days after start_date.", + extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-31'") case class DateAdd(startDate: Expression, days: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -96,6 +103,9 @@ case class DateAdd(startDate: Expression, days: Expression) /** * Subtracts a number of days to startdate. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_days) - Returns the date that is num_days before start_date.", + extended = "> SELECT _FUNC_('2016-07-30', 1);\n '2016-07-29'") case class DateSub(startDate: Expression, days: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = startDate @@ -118,6 +128,9 @@ case class DateSub(startDate: Expression, days: Expression) override def prettyName: String = "date_sub" } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the hour component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 12") case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -134,6 +147,9 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the minute component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 58") case class Minute(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -150,6 +166,9 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the second component of the string/timestamp/interval.", + extended = "> SELECT _FUNC_('2009-07-30 12:58:59');\n 59") case class Second(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) @@ -166,6 +185,9 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the day of year of date/timestamp.", + extended = "> SELECT _FUNC_('2016-04-09');\n 100") case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -182,7 +204,9 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas } } - +@ExpressionDescription( + usage = "_FUNC_(param) - Returns 
the year component of the date/timestamp/interval.", + extended = "> SELECT _FUNC_('2016-07-30');\n 2016") case class Year(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -199,6 +223,8 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the quarter of the year for date, in the range 1 to 4.") case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -215,6 +241,9 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the month component of the date/timestamp/interval", + extended = "> SELECT _FUNC_('2016-07-30');\n 7") case class Month(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -231,6 +260,9 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the day of month of date/timestamp, or the day of interval.", + extended = "> SELECT _FUNC_('2009-07-30');\n 30") case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -247,6 +279,9 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa } } +@ExpressionDescription( + usage = "_FUNC_(param) - Returns the week of the year of the given date.", + extended = "> SELECT _FUNC_('2008-02-20');\n 8") case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(DateType) @@ -283,6 +318,11 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date/timestamp/string, fmt) - Converts a date/timestamp/string to a value of string in the format specified by the date format fmt.", + extended = "> SELECT _FUNC_('2016-04-08', 'y')\n '2016'") +// scalastyle:on line.size.limit case class DateFormatClass(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -310,6 +350,8 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx * Converts time string with given pattern. * Deterministic version of [[UnixTimestamp]], must have at least one parameter. */ +@ExpressionDescription( + usage = "_FUNC_(date[, pattern]) - Returns the UNIX timestamp of the give time.") case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { override def left: Expression = timeExp override def right: Expression = format @@ -331,6 +373,8 @@ case class ToUnixTimestamp(timeExp: Expression, format: Expression) extends Unix * If the first parameter is a Date or Timestamp instead of String, we will ignore the * second parameter. */ +@ExpressionDescription( + usage = "_FUNC_([date[, pattern]]) - Returns the UNIX timestamp of current or specified time.") case class UnixTimestamp(timeExp: Expression, format: Expression) extends UnixTime { override def left: Expression = timeExp override def right: Expression = format @@ -459,6 +503,9 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { * format. 
If the format is missing, using format like "1970-01-01 00:00:00". * Note that hive Language Manual says it returns 0 if fail, but in fact it returns null. */ +@ExpressionDescription( + usage = "_FUNC_(unix_time, format) - Returns unix_time in the specified format", + extended = "> SELECT _FUNC_(0, 'yyyy-MM-dd HH:mm:ss');\n '1970-01-01 00:00:00'") case class FromUnixTime(sec: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -544,6 +591,9 @@ case class FromUnixTime(sec: Expression, format: Expression) /** * Returns the last day of the month which the date belongs to. */ +@ExpressionDescription( + usage = "_FUNC_(date) - Returns the last day of the month which the date belongs to.", + extended = "> SELECT _FUNC_('2009-01-12');\n '2009-01-31'") case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def child: Expression = startDate @@ -570,6 +620,11 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC * * Allowed "dayOfWeek" is defined in [[DateTimeUtils.getDayOfWeekFromString]]. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(start_date, day_of_week) - Returns the first date which is later than start_date and named as indicated.", + extended = "> SELECT _FUNC_('2015-01-14', 'TU');\n '2015-01-20'") +// scalastyle:on line.size.limit case class NextDay(startDate: Expression, dayOfWeek: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -654,6 +709,10 @@ case class TimeAdd(start: Expression, interval: Expression) /** * Assumes given timestamp is UTC and converts to given timezone. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is UTC and converts to given timezone.") +// scalastyle:on line.size.limit case class FromUTCTimestamp(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -729,6 +788,9 @@ case class TimeSub(start: Expression, interval: Expression) /** * Returns the date that is num_months after start_date. */ +@ExpressionDescription( + usage = "_FUNC_(start_date, num_months) - Returns the date that is num_months after start_date.", + extended = "> SELECT _FUNC_('2016-08-31', 1);\n '2016-09-30'") case class AddMonths(startDate: Expression, numMonths: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -756,6 +818,9 @@ case class AddMonths(startDate: Expression, numMonths: Expression) /** * Returns number of months between dates date1 and date2. */ +@ExpressionDescription( + usage = "_FUNC_(date1, date2) - returns number of months between dates date1 and date2.", + extended = "> SELECT _FUNC_('1997-02-28 10:30:00', '1996-10-30');\n 3.94959677") case class MonthsBetween(date1: Expression, date2: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -783,6 +848,10 @@ case class MonthsBetween(date1: Expression, date2: Expression) /** * Assumes given timestamp is in given timezone and converts to UTC. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(timestamp, string timezone) - Assumes given timestamp is in given timezone and converts to UTC.") +// scalastyle:on line.size.limit case class ToUTCTimestamp(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -830,6 +899,9 @@ case class ToUTCTimestamp(left: Expression, right: Expression) /** * Returns the date part of a timestamp or string. 
*/ +@ExpressionDescription( + usage = "_FUNC_(expr) - Extracts the date part of the date or datetime expression expr.", + extended = "> SELECT _FUNC_('2009-07-30 04:17:52');\n '2009-07-30'") case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { // Implicit casting of spark will accept string in both date and timestamp format, as @@ -850,6 +922,11 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn /** * Returns date truncated to the unit specified by the format. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(date, fmt) - Returns returns date with the time portion of the day truncated to the unit specified by the format model fmt.", + extended = "> SELECT _FUNC_('2009-02-12', 'MM')\n '2009-02-01'\n> SELECT _FUNC_('2015-10-27', 'YEAR');\n '2015-01-01'") +// scalastyle:on line.size.limit case class TruncDate(date: Expression, format: Expression) extends BinaryExpression with ImplicitCastInputTypes { override def left: Expression = date @@ -921,6 +998,9 @@ case class TruncDate(date: Expression, format: Expression) /** * Returns the number of days from startDate to endDate. */ +@ExpressionDescription( + usage = "_FUNC_(date1, date2) - Returns the number of days between date1 and date2.", + extended = "> SELECT _FUNC_('2009-07-30', '2009-07-31');\n 1") case class DateDiff(endDate: Expression, startDate: Expression) extends BinaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala index e7ef21aa85..65d7a1d5a0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala @@ -99,6 +99,10 @@ case class UserDefinedGenerator( /** * Given an input array produces a sequence of rows for each value in the array. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(a) - Separates the elements of array a into multiple rows, or the elements of a map into multiple rows and columns.") +// scalastyle:on line.size.limit case class Explode(child: Expression) extends UnaryExpression with Generator with CodegenFallback { override def children: Seq[Expression] = child :: Nil diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index 72b323587c..ecd09b7083 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -106,6 +106,8 @@ private[this] object SharedFactory { * Extracts json object from a json string based on json path specified, and returns json string * of the extracted json object. It will return null if the input json string is invalid. 
*/ +@ExpressionDescription( + usage = "_FUNC_(json_txt, path) - Extract a json object from path") case class GetJsonObject(json: Expression, path: Expression) extends BinaryExpression with ExpectsInputTypes with CodegenFallback { @@ -319,6 +321,10 @@ case class GetJsonObject(json: Expression, path: Expression) } } +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - like get_json_object, but it takes multiple names and return a tuple. All the input parameters and output column types are string.") +// scalastyle:on line.size.limit case class JsonTuple(children: Seq[Expression]) extends Generator with CodegenFallback { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index e3d1bc127d..c8a28e8477 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -50,6 +50,7 @@ abstract class LeafMathExpression(c: Double, name: String) /** * A unary expression specifically for math functions. Math Functions expect a specific type of * input format, therefore these functions extend `ExpectsInputTypes`. + * * @param f The math function. * @param name The short name of the function */ @@ -103,6 +104,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) /** * A binary expression specifically for math functions that take two `Double`s as input and returns * a `Double`. + * * @param f The math function. * @param name The short name of the function */ @@ -136,12 +138,18 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) * Euler's number. Note that there is no code generation because this is only * evaluated by the optimizer during constant folding. */ +@ExpressionDescription( + usage = "_FUNC_() - Returns Euler's number, E.", + extended = "> SELECT _FUNC_();\n 2.718281828459045") case class EulerNumber() extends LeafMathExpression(math.E, "E") /** * Pi. Note that there is no code generation because this is only * evaluated by the optimizer during constant folding. 
*/ +@ExpressionDescription( + usage = "_FUNC_() - Returns PI.", + extended = "> SELECT _FUNC_();\n 3.141592653589793") case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -150,14 +158,29 @@ case class Pi() extends LeafMathExpression(math.Pi, "PI") //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc cosine of x if -1<=x<=1 or NaN otherwise.", + extended = "> SELECT _FUNC_(1);\n 0.0\n> SELECT _FUNC_(2);\n NaN") case class Acos(child: Expression) extends UnaryMathExpression(math.acos, "ACOS") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc sin of x if -1<=x<=1 or NaN otherwise.", + extended = "> SELECT _FUNC_(0);\n 0.0\n> SELECT _FUNC_(2);\n NaN") case class Asin(child: Expression) extends UnaryMathExpression(math.asin, "ASIN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the arc tangent.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Atan(child: Expression) extends UnaryMathExpression(math.atan, "ATAN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the cube root of a double value.", + extended = "> SELECT _FUNC_(27.0);\n 3.0") case class Cbrt(child: Expression) extends UnaryMathExpression(math.cbrt, "CBRT") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the smallest integer not smaller than x.", + extended = "> SELECT _FUNC_(-0.1);\n 0\n> SELECT _FUNC_(5);\n 5") case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -184,16 +207,26 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the cosine of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Cos(child: Expression) extends UnaryMathExpression(math.cos, "COS") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic cosine of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH") /** * Convert a num from one base to another + * * @param numExpr the number to be converted * @param fromBaseExpr from which base * @param toBaseExpr to which base */ +@ExpressionDescription( + usage = "_FUNC_(num, from_base, to_base) - Convert num from from_base to to_base.", + extended = "> SELECT _FUNC_('100', 2, 10);\n '4'\n> SELECT _FUNC_(-10, 16, -10);\n '16'") case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -222,10 +255,19 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns e to the power of x.", + extended = "> SELECT _FUNC_(0);\n 1.0") case class Exp(child: Expression) extends UnaryMathExpression(math.exp, "EXP") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns exp(x) - 1.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Expm1(child: Expression) extends UnaryMathExpression(math.expm1, "EXPM1") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the largest integer not greater than x.", + extended = "> SELECT _FUNC_(-0.1);\n -1\n> SELECT _FUNC_(5);\n 5") case class 
Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLOOR") { override def dataType: DataType = child.dataType match { case dt @ DecimalType.Fixed(_, 0) => dt @@ -283,6 +325,9 @@ object Factorial { ) } +@ExpressionDescription( + usage = "_FUNC_(n) - Returns n factorial for n is [0..20]. Otherwise, NULL.", + extended = "> SELECT _FUNC_(5);\n 120") case class Factorial(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[DataType] = Seq(IntegerType) @@ -315,8 +360,14 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the natural logarithm of x with base e.", + extended = "> SELECT _FUNC_(1);\n 0.0") case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the logarithm of x with base 2.", + extended = "> SELECT _FUNC_(2);\n 1.0") case class Log2(child: Expression) extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { override def genCode(ctx: CodegenContext, ev: ExprCode): String = { @@ -332,36 +383,72 @@ case class Log2(child: Expression) } } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the logarithm of x with base 10.", + extended = "> SELECT _FUNC_(10);\n 1.0") case class Log10(child: Expression) extends UnaryLogExpression(math.log10, "LOG10") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns log(1 + x).", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Log1p(child: Expression) extends UnaryLogExpression(math.log1p, "LOG1P") { protected override val yAsymptote: Double = -1.0 } +@ExpressionDescription( + usage = "_FUNC_(x, d) - Return the rounded x at d decimal places.", + extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") case class Rint(child: Expression) extends UnaryMathExpression(math.rint, "ROUND") { override def funcName: String = "rint" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sign of x.", + extended = "> SELECT _FUNC_(40);\n 1.0") case class Signum(child: Expression) extends UnaryMathExpression(math.signum, "SIGNUM") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the sine of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Sin(child: Expression) extends UnaryMathExpression(math.sin, "SIN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic sine of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Sinh(child: Expression) extends UnaryMathExpression(math.sinh, "SINH") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the square root of x.", + extended = "> SELECT _FUNC_(4);\n 2.0") case class Sqrt(child: Expression) extends UnaryMathExpression(math.sqrt, "SQRT") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the tangent of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Tan(child: Expression) extends UnaryMathExpression(math.tan, "TAN") +@ExpressionDescription( + usage = "_FUNC_(x) - Returns the hyperbolic tangent of x.", + extended = "> SELECT _FUNC_(0);\n 0.0") case class Tanh(child: Expression) extends UnaryMathExpression(math.tanh, "TANH") +@ExpressionDescription( + usage = "_FUNC_(x) - Converts radians to degrees.", + extended = "> SELECT _FUNC_(3.141592653589793);\n 180.0") case class ToDegrees(child: Expression) extends UnaryMathExpression(math.toDegrees, "DEGREES") { override def funcName: String = "toDegrees" } +@ExpressionDescription( + usage = "_FUNC_(x) - Converts degrees to radians.", + extended = "> 
SELECT _FUNC_(180);\n 3.141592653589793") case class ToRadians(child: Expression) extends UnaryMathExpression(math.toRadians, "RADIANS") { override def funcName: String = "toRadians" } +@ExpressionDescription( + usage = "_FUNC_(x) - Returns x in binary.", + extended = "> SELECT _FUNC_(13);\n '1101'") case class Bin(child: Expression) extends UnaryExpression with Serializable with ImplicitCastInputTypes { @@ -453,6 +540,9 @@ object Hex { * Otherwise if the number is a STRING, it converts each character into its hex representation * and returns the resulting STRING. Negative numbers would be treated as two's complement. */ +@ExpressionDescription( + usage = "_FUNC_(x) - Convert the argument to hexadecimal.", + extended = "> SELECT _FUNC_(17);\n '11'\n> SELECT _FUNC_('Spark SQL');\n '537061726B2053514C'") case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = @@ -481,6 +571,9 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput * Performs the inverse operation of HEX. * Resulting characters are returned as a byte array. */ +@ExpressionDescription( + usage = "_FUNC_(x) - Converts hexadecimal argument to binary.", + extended = "> SELECT decode(_FUNC_('537061726B2053514C'),'UTF-8');\n 'Spark SQL'") case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(StringType) @@ -509,7 +602,9 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp //////////////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////////////// - +@ExpressionDescription( + usage = "_FUNC_(x,y) - Returns the arc tangent2.", + extended = "> SELECT _FUNC_(0, 0);\n 0.0") case class Atan2(left: Expression, right: Expression) extends BinaryMathExpression(math.atan2, "ATAN2") { @@ -523,6 +618,9 @@ case class Atan2(left: Expression, right: Expression) } } +@ExpressionDescription( + usage = "_FUNC_(x1, x2) - Raise x1 to the power of x2.", + extended = "> SELECT _FUNC_(2, 3);\n 8.0") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { override def genCode(ctx: CodegenContext, ev: ExprCode): String = { @@ -532,10 +630,14 @@ case class Pow(left: Expression, right: Expression) /** - * Bitwise unsigned left shift. + * Bitwise left shift. + * * @param left the base number to shift. * @param right number of bits to left shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise left shift.", + extended = "> SELECT _FUNC_(2, 1);\n 4") case class ShiftLeft(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -558,10 +660,14 @@ case class ShiftLeft(left: Expression, right: Expression) /** - * Bitwise unsigned left shift. + * Bitwise right shift. + * * @param left the base number to shift. - * @param right number of bits to left shift. + * @param right number of bits to right shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise right shift.", + extended = "> SELECT _FUNC_(4, 1);\n 2") case class ShiftRight(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -585,9 +691,13 @@ case class ShiftRight(left: Expression, right: Expression) /** * Bitwise unsigned right shift, for integer and long data type. + * * @param left the base number. 
* @param right the number of bits to right shift. */ +@ExpressionDescription( + usage = "_FUNC_(a, b) - Bitwise unsigned right shift.", + extended = "> SELECT _FUNC_(4, 1);\n 2") case class ShiftRightUnsigned(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -608,16 +718,22 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = "_FUNC_(a, b) - Returns sqrt(a**2 + b**2).", + extended = "> SELECT _FUNC_(3, 4);\n 5.0") case class Hypot(left: Expression, right: Expression) extends BinaryMathExpression(math.hypot, "HYPOT") /** * Computes the logarithm of a number. + * * @param left the logarithm base, default to e. * @param right the number to compute the logarithm of. */ +@ExpressionDescription( + usage = "_FUNC_(b, x) - Returns the logarithm of x with base b.", + extended = "> SELECT _FUNC_(10, 100);\n 2.0") case class Logarithm(left: Expression, right: Expression) extends BinaryMathExpression((c1, c2) => math.log(c2) / math.log(c1), "LOG") { @@ -674,6 +790,9 @@ case class Logarithm(left: Expression, right: Expression) * @param child expr to be round, all [[NumericType]] is allowed as Input * @param scale new scale to be round to, this should be a constant int at runtime */ +@ExpressionDescription( + usage = "_FUNC_(x, d) - Round x to d decimal places.", + extended = "> SELECT _FUNC_(12.3456, 1);\n 12.3") case class Round(child: Expression, scale: Expression) extends BinaryExpression with ImplicitCastInputTypes { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index eb8dc1423a..4bd918ed01 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -438,6 +438,8 @@ abstract class InterpretedHashFunction { * We should use this hash function for both shuffle and bucket, so that we can guarantee shuffle * and bucketing have same data distribution. */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) - Returns a hash value of the arguments.") case class Murmur3Hash(children: Seq[Expression], seed: Int) extends HashExpression[Int] { def this(arguments: Seq[Expression]) = this(arguments, 42) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index e22026d584..6a45249943 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -34,6 +34,9 @@ import org.apache.spark.sql.types._ * coalesce(null, null, null) => null * }}} */ +@ExpressionDescription( + usage = "_FUNC_(a1, a2, ...) - Returns the first non-null argument if exists. Otherwise, NULL.", + extended = "> SELECT _FUNC_(NULL, 1, NULL);\n 1") case class Coalesce(children: Seq[Expression]) extends Expression { /** Coalesce is nullable if all of its children are nullable, or if it has no children. */ @@ -89,6 +92,8 @@ case class Coalesce(children: Seq[Expression]) extends Expression { /** * Evaluates to `true` iff it's NaN. 
*/ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is NaN and false otherwise.") case class IsNaN(child: Expression) extends UnaryExpression with Predicate with ImplicitCastInputTypes { @@ -126,6 +131,8 @@ case class IsNaN(child: Expression) extends UnaryExpression * An Expression evaluates to `left` iff it's not NaN, or evaluates to `right` otherwise. * This Expression is useful for mapping NaN values to null. */ +@ExpressionDescription( + usage = "_FUNC_(a,b) - Returns a iff it's not NaN, or b otherwise.") case class NaNvl(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -180,6 +187,8 @@ case class NaNvl(left: Expression, right: Expression) /** * An expression that is evaluated to true if the input is null. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is NULL and false otherwise.") case class IsNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false @@ -201,6 +210,8 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { /** * An expression that is evaluated to true if the input is not null. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns true if a is not NULL and false otherwise.") case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { override def nullable: Boolean = false diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index 4eb33258ac..38f1210a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -88,7 +88,8 @@ trait PredicateHelper { expr.references.subsetOf(plan.outputSet) } - +@ExpressionDescription( + usage = "_FUNC_ a - Logical not") case class Not(child: Expression) extends UnaryExpression with Predicate with ImplicitCastInputTypes with NullIntolerant { @@ -109,6 +110,8 @@ case class Not(child: Expression) /** * Evaluates to `true` if `list` contains `value`. */ +@ExpressionDescription( + usage = "expr _FUNC_(val1, val2, ...) 
- Returns true if expr equals to any valN.") case class In(value: Expression, list: Seq[Expression]) extends Predicate with ImplicitCastInputTypes { @@ -243,6 +246,8 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with } } +@ExpressionDescription( + usage = "a _FUNC_ b - Logical AND.") case class And(left: Expression, right: Expression) extends BinaryOperator with Predicate { override def inputType: AbstractDataType = BooleanType @@ -306,7 +311,8 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Logical OR.") case class Or(left: Expression, right: Expression) extends BinaryOperator with Predicate { override def inputType: AbstractDataType = BooleanType @@ -401,7 +407,8 @@ private[sql] object Equality { } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a equals b and false otherwise.") case class EqualTo(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -426,7 +433,9 @@ case class EqualTo(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = """a _FUNC_ b - Returns same result with EQUAL(=) operator for non-null operands, + but returns TRUE if both are NULL, FALSE if one of the them is NULL.""") case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComparison { override def inputType: AbstractDataType = AnyDataType @@ -467,7 +476,8 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is less than b.") case class LessThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -480,7 +490,8 @@ case class LessThan(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lt(input1, input2) } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is not greater than b.") case class LessThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -493,7 +504,8 @@ case class LessThanOrEqual(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.lteq(input1, input2) } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is greater than b.") case class GreaterThan(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { @@ -506,7 +518,8 @@ case class GreaterThan(left: Expression, right: Expression) protected override def nullSafeEval(input1: Any, input2: Any): Any = ordering.gt(input1, input2) } - +@ExpressionDescription( + usage = "a _FUNC_ b - Returns TRUE if a is not smaller than b.") case class GreaterThanOrEqual(left: Expression, right: Expression) extends BinaryComparison with NullIntolerant { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 6be3cbcae6..1ec092a5be 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -55,6 +55,8 @@ abstract class RDG extends LeafExpression with Nondeterministic { } /** Generate a random column with i.i.d. uniformly distributed values in [0, 1). 
*/ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a random column with i.i.d. uniformly distributed values in [0, 1).") case class Rand(seed: Long) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextDouble() @@ -78,6 +80,8 @@ case class Rand(seed: Long) extends RDG { } /** Generate a random column with i.i.d. gaussian random distribution. */ +@ExpressionDescription( + usage = "_FUNC_(a) - Returns a random column with i.i.d. gaussian random distribution.") case class Randn(seed: Long) extends RDG { override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index b68009331b..85a5429263 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -67,6 +67,8 @@ trait StringRegexExpression extends ImplicitCastInputTypes { /** * Simple RegEx pattern matching function */ +@ExpressionDescription( + usage = "str _FUNC_ pattern - Returns true if str matches pattern and false otherwise.") case class Like(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { @@ -117,7 +119,8 @@ case class Like(left: Expression, right: Expression) } } - +@ExpressionDescription( + usage = "str _FUNC_ regexp - Returns true if str matches regexp and false otherwise.") case class RLike(left: Expression, right: Expression) extends BinaryExpression with StringRegexExpression { @@ -169,6 +172,9 @@ case class RLike(left: Expression, right: Expression) /** * Splits str around pat (pattern is a regular expression). */ +@ExpressionDescription( + usage = "_FUNC_(str, regex) - Splits str around occurrences that match regex", + extended = "> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');\n ['one', 'two', 'three']") case class StringSplit(str: Expression, pattern: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -198,6 +204,9 @@ case class StringSplit(str: Expression, pattern: Expression) * * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ +@ExpressionDescription( + usage = "_FUNC_(str, regexp, rep) - replace all substrings of str that match regexp with rep.", + extended = "> SELECT _FUNC_('100-200', '(\\d+)', 'num');\n 'num-num'") case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -289,6 +298,9 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio * * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
*/ +@ExpressionDescription( + usage = "_FUNC_(str, regexp[, idx]) - extracts a group that matches regexp.", + extended = "> SELECT _FUNC_('100-200', '(\\d+)-(\\d+)', 1);\n '100'") case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) extends TernaryExpression with ImplicitCastInputTypes { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 7e0e7a833b..a17482697d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -35,6 +35,9 @@ import org.apache.spark.unsafe.types.{ByteArray, UTF8String} * An expression that concatenates multiple input strings into a single string. * If any input is null, concat returns null. */ +@ExpressionDescription( + usage = "_FUNC_(str1, str2, ..., strN) - Returns the concatenation of str1, str2, ..., strN", + extended = "> SELECT _FUNC_('Spark','SQL');\n 'SparkSQL'") case class Concat(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.size)(StringType) @@ -70,6 +73,10 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas * * Returns null if the separator is null. Otherwise, concat_ws skips all null values. */ +@ExpressionDescription( + usage = + "_FUNC_(sep, [str | array(str)]+) - Returns the concatenation of the strings separated by sep.", + extended = "> SELECT _FUNC_(' ', Spark', 'SQL');\n 'Spark SQL'") case class ConcatWs(children: Seq[Expression]) extends Expression with ImplicitCastInputTypes { @@ -188,7 +195,7 @@ case class Upper(child: Expression) */ @ExpressionDescription( usage = "_FUNC_(str) - Returns str with all characters changed to lowercase", - extended = "> SELECT _FUNC_('SparkSql');\n'sparksql'") + extended = "> SELECT _FUNC_('SparkSql');\n 'sparksql'") case class Lower(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.toLowerCase @@ -270,6 +277,11 @@ object StringTranslate { * The translate will happen when any character in the string matching with the character * in the `matchingExpr`. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(input, from, to) - Translates the input string by replacing the characters present in the from string with the corresponding characters in the to string""", + extended = "> SELECT _FUNC_('AaBbCc', 'abc', '123');\n 'A1B2C3'") +// scalastyle:on line.size.limit case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -325,6 +337,12 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac * delimited list (right). Returns 0, if the string wasn't found or if the given * string (left) contains a comma. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(str, str_array) - Returns the index (1-based) of the given string (left) in the comma-delimited list (right). 
+ Returns 0, if the string wasn't found or if the given string (left) contains a comma.""", + extended = "> SELECT _FUNC_('ab','abc,b,ab,c,def');\n 3") +// scalastyle:on case class FindInSet(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -347,6 +365,9 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi /** * A function that trim the spaces from both ends for the specified string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the leading and trailing space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL'") case class StringTrim(child: Expression) extends UnaryExpression with String2StringExpression { @@ -362,6 +383,9 @@ case class StringTrim(child: Expression) /** * A function that trim the spaces from left end for given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the leading space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n 'SparkSQL '") case class StringTrimLeft(child: Expression) extends UnaryExpression with String2StringExpression { @@ -377,6 +401,9 @@ case class StringTrimLeft(child: Expression) /** * A function that trim the spaces from right end for given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Removes the trailing space characters from str.", + extended = "> SELECT _FUNC_(' SparkSQL ');\n ' SparkSQL'") case class StringTrimRight(child: Expression) extends UnaryExpression with String2StringExpression { @@ -396,6 +423,9 @@ case class StringTrimRight(child: Expression) * * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ +@ExpressionDescription( + usage = "_FUNC_(str, substr) - Returns the (1-based) index of the first occurrence of substr in str.", + extended = "> SELECT _FUNC_('SparkSQL', 'SQL');\n 6") case class StringInstr(str: Expression, substr: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -422,6 +452,15 @@ case class StringInstr(str: Expression, substr: Expression) * returned. If count is negative, every to the right of the final delimiter (counting from the * right) is returned. substring_index performs a case-sensitive match when searching for delim. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(str, delim, count) - Returns the substring from str before count occurrences of the delimiter delim. + If count is positive, everything to the left of the final delimiter (counting from the + left) is returned. If count is negative, everything to the right of the final delimiter + (counting from the right) is returned. Substring_index performs a case-sensitive match + when searching for delim.""", + extended = "> SELECT _FUNC_('www.apache.org', '.', 2);\n 'www.apache'") +// scalastyle:on line.size.limit case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -445,6 +484,12 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: * A function that returns the position of the first occurrence of substr * in given string after position pos. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = """_FUNC_(substr, str[, pos]) - Returns the position of the first occurrence of substr in str after position pos. 
+ The given pos and return value are 1-based.""", + extended = "> SELECT _FUNC_('bar', 'foobarbar', 5);\n 7") +// scalastyle:on line.size.limit case class StringLocate(substr: Expression, str: Expression, start: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -510,6 +555,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) /** * Returns str, left-padded with pad to a length of len. */ +@ExpressionDescription( + usage = """_FUNC_(str, len, pad) - Returns str, left-padded with pad to a length of len. + If str is longer than len, the return value is shortened to len characters.""", + extended = "> SELECT _FUNC_('hi', 5, '??');\n '???hi'\n" + + "> SELECT _FUNC_('hi', 1, '??');\n 'h'") case class StringLPad(str: Expression, len: Expression, pad: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -531,6 +581,11 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) /** * Returns str, right-padded with pad to a length of len. */ +@ExpressionDescription( + usage = """_FUNC_(str, len, pad) - Returns str, right-padded with pad to a length of len. + If str is longer than len, the return value is shortened to len characters.""", + extended = "> SELECT _FUNC_('hi', 5, '??');\n 'hi???'\n" + + "> SELECT _FUNC_('hi', 1, '??');\n 'h'") case class StringRPad(str: Expression, len: Expression, pad: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -552,6 +607,11 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) /** * Returns the input formatted according do printf-style format strings */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(String format, Obj... args) - Returns a formatted string from printf-style format strings.", + extended = "> SELECT _FUNC_(\"Hello World %d %s\", 100, \"days\");\n 'Hello World 100 days'") +// scalastyle:on line.size.limit case class FormatString(children: Expression*) extends Expression with ImplicitCastInputTypes { require(children.nonEmpty, "format_string() should take at least 1 argument") @@ -642,6 +702,9 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI /** * Returns the string which repeat the given string value n times. */ +@ExpressionDescription( + usage = "_FUNC_(str, n) - Returns the string which repeat the given string value n times.", + extended = "> SELECT _FUNC_('123', 2);\n '123123'") case class StringRepeat(str: Expression, times: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -664,6 +727,9 @@ case class StringRepeat(str: Expression, times: Expression) /** * Returns the reversed given string. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns the reversed given string.", + extended = "> SELECT _FUNC_('Spark SQL');\n 'LQS krapS'") case class StringReverse(child: Expression) extends UnaryExpression with String2StringExpression { override def convert(v: UTF8String): UTF8String = v.reverse() @@ -677,6 +743,9 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 /** * Returns a n spaces string. */ +@ExpressionDescription( + usage = "_FUNC_(n) - Returns a n spaces string.", + extended = "> SELECT _FUNC_(2);\n ' '") case class StringSpace(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { @@ -699,7 +768,14 @@ case class StringSpace(child: Expression) /** * A function that takes a substring of its first argument starting at a given position. * Defined for String and Binary types. 
+ * + * NOTE: that this is not zero based, but 1-based index. The first character in str has index 1. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(str, pos[, len]) - Returns the substring of str that starts at pos and is of length len or the slice of byte array that starts at pos and is of length len.", + extended = "> SELECT _FUNC_('Spark SQL', 5);\n 'k SQL'\n> SELECT _FUNC_('Spark SQL', -3);\n 'SQL'\n> SELECT _FUNC_('Spark SQL', 5, 1);\n 'k'") +// scalastyle:on line.size.limit case class Substring(str: Expression, pos: Expression, len: Expression) extends TernaryExpression with ImplicitCastInputTypes { @@ -737,6 +813,9 @@ case class Substring(str: Expression, pos: Expression, len: Expression) /** * A function that return the length of the given string or binary expression. */ +@ExpressionDescription( + usage = "_FUNC_(str | binary) - Returns the length of str or number of bytes in binary data.", + extended = "> SELECT _FUNC_('Spark SQL');\n 9") case class Length(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = IntegerType override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, BinaryType)) @@ -757,6 +836,9 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy /** * A function that return the Levenshtein distance between the two given strings. */ +@ExpressionDescription( + usage = "_FUNC_(str1, str2) - Returns the Levenshtein distance between the two given strings.", + extended = "> SELECT _FUNC_('kitten', 'sitting');\n 3") case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -775,6 +857,9 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres /** * A function that return soundex code of the given string expression. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns soundex code of the string.", + extended = "> SELECT _FUNC_('Miller');\n 'M460'") case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def dataType: DataType = StringType @@ -791,6 +876,10 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT /** * Returns the numeric value of the first character of str. */ +@ExpressionDescription( + usage = "_FUNC_(str) - Returns the numeric value of the first character of str.", + extended = "> SELECT _FUNC_('222');\n 50\n" + + "> SELECT _FUNC_(2);\n 50") case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = IntegerType @@ -822,6 +911,8 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp /** * Converts the argument from binary to a base 64 string. */ +@ExpressionDescription( + usage = "_FUNC_(bin) - Convert the argument from binary to a base 64 string.") case class Base64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = StringType @@ -844,6 +935,8 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn /** * Converts the argument from a base 64 string to BINARY. 
*/ +@ExpressionDescription( + usage = "_FUNC_(str) - Convert the argument from a base 64 string to binary.") case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { override def dataType: DataType = BinaryType @@ -865,6 +958,8 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. */ +@ExpressionDescription( + usage = "_FUNC_(bin, str) - Decode the first argument using the second argument character set.") case class Decode(bin: Expression, charset: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -894,7 +989,9 @@ case class Decode(bin: Expression, charset: Expression) * Encodes the first argument into a BINARY using the provided character set * (one of 'US-ASCII', 'ISO-8859-1', 'UTF-8', 'UTF-16BE', 'UTF-16LE', 'UTF-16'). * If either argument is null, the result will also be null. -*/ + */ +@ExpressionDescription( + usage = "_FUNC_(str, str) - Encode the first argument using the second argument character set.") case class Encode(value: Expression, charset: Expression) extends BinaryExpression with ImplicitCastInputTypes { @@ -924,6 +1021,11 @@ case class Encode(value: Expression, charset: Expression) * and returns the result as a string. If D is 0, the result has no decimal point or * fractional part. */ +@ExpressionDescription( + usage = """_FUNC_(X, D) - Formats the number X like '#,###,###.##', rounded to D decimal places. + If D is 0, the result has no decimal point or fractional part. + This is supposed to function like MySQL's FORMAT.""", + extended = "> SELECT _FUNC_(12332.123456, 4);\n '12,332.1235'") case class FormatNumber(x: Expression, d: Expression) extends BinaryExpression with ExpectsInputTypes { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index dd648cdb81..695dda269a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -89,6 +89,14 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { "Function: abcadf not found.") } + test("SPARK-14415: All functions should have own descriptions") { + for (f <- sqlContext.sessionState.functionRegistry.listFunction()) { + if (!Seq("cube", "grouping", "grouping_id", "rollup", "window").contains(f)) { + checkExistence(sql(s"describe function `$f`"), false, "To be added.") + } + } + } + test("SPARK-6743: no columns from cache") { Seq( (83, 0, 38), diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index f3796a9966..b4886eba7a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -238,7 +238,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { checkExistence(sql("describe functioN `~`"), true, "Function: ~", "Class: org.apache.spark.sql.catalyst.expressions.BitwiseNot", - "Usage: To be added.") + "Usage: ~ b - Bitwise NOT.") // Hard coded describe functions checkExistence(sql("describe function `<>`"), true, -- cgit v1.2.3 From 9f838bd24242866a687a2655a1b8ac2f5d562526 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 10 Apr 2016 
20:46:15 -0700 Subject: [SPARK-14362][SPARK-14406][SQL][FOLLOW-UP] DDL Native Support: Drop View and Drop Table #### What changes were proposed in this pull request? This PR is to address the comment: https://github.com/apache/spark/pull/12146#discussion-diff-59092238. It removes the function `isViewSupported` from `SessionCatalog`. After the removal, we still can capture the user errors if users try to drop a table using `DROP VIEW`. #### How was this patch tested? Modified the existing test cases Author: gatorsmile Closes #12284 from gatorsmile/followupDropTable. --- R/pkg/inst/tests/testthat/test_sparkSQL.R | 2 +- .../org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2 +- .../apache/spark/sql/catalyst/analysis/CheckAnalysis.scala | 2 +- .../spark/sql/catalyst/analysis/NoSuchItemException.scala | 2 +- .../spark/sql/catalyst/catalog/InMemoryCatalog.scala | 4 ++-- .../apache/spark/sql/catalyst/catalog/SessionCatalog.scala | 9 ++------- .../scala/org/apache/spark/sql/execution/command/ddl.scala | 3 --- .../test/scala/org/apache/spark/sql/SQLQuerySuite.scala | 4 ++-- .../org/apache/spark/sql/execution/command/DDLSuite.scala | 14 ++++++++++++-- .../org/apache/spark/sql/hive/thriftserver/CliSuite.scala | 2 +- .../org/apache/spark/sql/hive/HiveSessionCatalog.scala | 2 -- .../apache/spark/sql/hive/execution/HiveCommandSuite.scala | 2 +- 12 files changed, 24 insertions(+), 24 deletions(-) (limited to 'sql/catalyst') diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 22eb3ec984..d747d4f83f 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1853,7 +1853,7 @@ test_that("approxQuantile() on a DataFrame", { test_that("SQL error message is returned from JVM", { retError <- tryCatch(sql(sqlContext, "select * from blah"), error = function(e) e) - expect_equal(grepl("Table not found", retError), TRUE) + expect_equal(grepl("Table or View not found", retError), TRUE) expect_equal(grepl("blah", retError), TRUE) }) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 3555a6d7fa..de40ddde1b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -409,7 +409,7 @@ class Analyzer( catalog.lookupRelation(u.tableIdentifier, u.alias) } catch { case _: NoSuchTableException => - u.failAnalysis(s"Table not found: ${u.tableName}") + u.failAnalysis(s"Table or View not found: ${u.tableName}") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 4880502398..d6a8c3eec8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -52,7 +52,7 @@ trait CheckAnalysis { case p if p.analyzed => // Skip already analyzed sub-plans case u: UnresolvedRelation => - u.failAnalysis(s"Table not found: ${u.tableIdentifier}") + u.failAnalysis(s"Table or View not found: ${u.tableIdentifier}") case operator: LogicalPlan => operator transformExpressionsUp { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index e9f04eecf8..96fd1a027e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -33,7 +33,7 @@ class NoSuchDatabaseException(db: String) extends NoSuchItemException { } class NoSuchTableException(db: String, table: String) extends NoSuchItemException { - override def getMessage: String = s"Table $table not found in database $db" + override def getMessage: String = s"Table or View $table not found in database $db" } class NoSuchPartitionException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala index 1994acd1ad..f8a6fb74cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/InMemoryCatalog.scala @@ -62,7 +62,7 @@ class InMemoryCatalog extends ExternalCatalog { private def requireTableExists(db: String, table: String): Unit = { if (!tableExists(db, table)) { throw new AnalysisException( - s"Table not found: '$table' does not exist in database '$db'") + s"Table or View not found: '$table' does not exist in database '$db'") } } @@ -164,7 +164,7 @@ class InMemoryCatalog extends ExternalCatalog { catalog(db).tables.remove(table) } else { if (!ignoreIfNotExists) { - throw new AnalysisException(s"Table '$table' does not exist in database '$db'") + throw new AnalysisException(s"Table or View '$table' does not exist in database '$db'") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index c1e5a485e7..34e1cb7315 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -242,11 +242,11 @@ class SessionCatalog( val table = formatTableName(name.table) if (name.database.isDefined || !tempTables.contains(table)) { // When ignoreIfNotExists is false, no exception is issued when the table does not exist. - // Instead, log it as an error message. This is consistent with Hive. + // Instead, log it as an error message. if (externalCatalog.tableExists(db, table)) { externalCatalog.dropTable(db, table, ignoreIfNotExists = true) } else if (!ignoreIfNotExists) { - logError(s"Table '${name.quotedString}' does not exist") + logError(s"Table or View '${name.quotedString}' does not exist") } } else { tempTables.remove(table) @@ -304,11 +304,6 @@ class SessionCatalog( name.database.isEmpty && tempTables.contains(formatTableName(name.table)) } - /** - * Return whether View is supported - */ - def isViewSupported: Boolean = false - /** * List all tables in the specified database, including temporary tables. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index e941736f9a..8a37cf8f4c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -191,9 +191,6 @@ case class DropTable( override def run(sqlContext: SQLContext): Seq[Row] = { val catalog = sqlContext.sessionState.catalog - if (isView && !catalog.isViewSupported) { - throw new AnalysisException(s"Not supported object: views") - } // If the command DROP VIEW is to drop a table or DROP TABLE is to drop a view // issue an exception. catalog.getTableMetadataOption(tableName).map(_.tableType match { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 695dda269a..cdd404d699 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1827,12 +1827,12 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { val e1 = intercept[AnalysisException] { sql("select * from in_valid_table") } - assert(e1.message.contains("Table not found")) + assert(e1.message.contains("Table or View not found")) val e2 = intercept[AnalysisException] { sql("select * from no_db.no_table").show() } - assert(e2.message.contains("Table not found")) + assert(e2.message.contains("Table or View not found")) val e3 = intercept[AnalysisException] { sql("select * from json.invalid_file") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index e75e5f5cb2..c6479bf33e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -432,11 +432,21 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { sql("DROP TABLE dbx.tab1") } - test("drop view") { + test("drop view in SQLContext") { + // SQLContext does not support create view. Log an error message, if tab1 does not exists + sql("DROP VIEW tab1") + + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + assert(catalog.listTables("dbx") == Seq(tableIdent)) + val e = intercept[AnalysisException] { sql("DROP VIEW dbx.tab1") } - assert(e.getMessage.contains("Not supported object: views")) + assert( + e.getMessage.contains("Cannot drop a table with DROP VIEW. 
Please use DROP TABLE instead")) } private def convertToDatasourceTable( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index 9ec8b9a9a6..bfc3d195ff 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -230,7 +230,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { runCliWithin(timeout = 2.minute, errorResponses = Seq("AnalysisException"))( "select * from nonexistent_table;" - -> "Error in query: Table not found: nonexistent_table;" + -> "Error in query: Table or View not found: nonexistent_table;" ) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala index 875652c226..0cccc22e5a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionCatalog.scala @@ -70,8 +70,6 @@ private[sql] class HiveSessionCatalog( } } - override def isViewSupported: Boolean = true - // ---------------------------------------------------------------- // | Methods and fields for interacting with HiveMetastoreCatalog | // ---------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala index 8de2bdcfc0..061d1512a5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveCommandSuite.scala @@ -96,7 +96,7 @@ class HiveCommandSuite extends QueryTest with SQLTestUtils with TestHiveSingleto val message1 = intercept[AnalysisException] { sql("SHOW TBLPROPERTIES badtable") }.getMessage - assert(message1.contains("Table badtable not found in database default")) + assert(message1.contains("Table or View badtable not found in database default")) // When key is not found, a row containing the error is returned. checkAnswer( -- cgit v1.2.3 From 652c4703099c1e5b17732fa019318358dc99ad50 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Mon, 11 Apr 2016 09:43:16 -0700 Subject: [SPARK-14528] [SQL] Fix same result of Union ## What changes were proposed in this pull request? This PR fixes sameResult() for Union. ## How was this patch tested? Added regression test. Author: Davies Liu Closes #12295 from davies/fix_sameResult.
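As a minimal sketch of the behaviour being fixed (it mirrors the `union` regression test added below; the relation names `r1`/`r2` are illustrative and the snippet assumes the catalyst DSL imports used by `SameResultSuite`):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, Union}

// Two relations with the same schema but different expression ids.
val r1 = LocalRelation('a.int, 'b.int, 'c.int)
val r2 = LocalRelation('a.int, 'b.int, 'c.int)

// Union stores its children in a Seq. Before this patch those children (with their
// differing expression ids) leaked into cleanArgs, so the comparison below came
// back false; with cleanArg now nulling out nested children it returns true.
Union(Seq(r1, r2)).sameResult(Union(Seq(r2, r1)))
```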
--- .../main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 7 +++---- .../org/apache/spark/sql/catalyst/plans/SameResultSuite.scala | 7 ++++++- 2 files changed, 9 insertions(+), 5 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 0a11574f44..d4447ca32d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -312,18 +312,17 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** Args that have cleaned such that differences in expression id should not affect equality */ protected lazy val cleanArgs: Seq[Any] = { def cleanArg(arg: Any): Any = arg match { + // Children are checked using sameResult above. + case tn: TreeNode[_] if containsChild(tn) => null case e: Expression => cleanExpression(e).canonicalized case other => other } productIterator.map { - // Children are checked using sameResult above. - case tn: TreeNode[_] if containsChild(tn) => null - case e: Expression => cleanArg(e) case s: Option[_] => s.map(cleanArg) case s: Seq[_] => s.map(cleanArg) case m: Map[_, _] => m.mapValues(cleanArg) - case other => other + case other => cleanArg(other) }.toSeq } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala index 37941cf34e..467f76193c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/SameResultSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.plans import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Union} import org.apache.spark.sql.catalyst.util._ /** @@ -61,4 +61,9 @@ class SameResultSuite extends SparkFunSuite { test("sorts") { assertSameResult(testRelation.orderBy('a.asc), testRelation2.orderBy('a.asc)) } + + test("union") { + assertSameResult(Union(Seq(testRelation, testRelation2)), + Union(Seq(testRelation2, testRelation))) + } } -- cgit v1.2.3 From 5de26194a3aaeab9b7a8323107f614126c90441f Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Mon, 11 Apr 2016 09:52:50 -0700 Subject: [SPARK-14502] [SQL] Add optimization for Binary Comparison Simplification ## What changes were proposed in this pull request? We can simplify binary comparisons with semantically-equal operands: 1. Replace '<=>' with 'true' literal. 2. Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable. 3. Replace '<' and '>' with 'false' literal if both operands are non-nullable. For example, the following plan ``` scala> sql("SELECT * FROM (SELECT explode(array(1,2,3)) a) T WHERE a BETWEEN a AND a+7").explain() ... : +- Filter ((a#59 >= a#59) && (a#59 <= (a#59 + 7))) ... ``` will be optimized into the following. ``` : +- Filter (a#47 <= (a#47 + 7)) ``` ## How was this patch tested? Pass the Jenkins tests including new `BinaryComparisonSimplificationSuite`. Author: Dongjoon Hyun Closes #12267 from dongjoon-hyun/SPARK-14502.
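A rough sketch of exercising the new rule on its own, modelled on `BinaryComparisonSimplificationSuite` below (the `relation` and `plan` names are illustrative; in the real optimizer the rule runs in a batch together with `BooleanSimplification` and `PruneFilters`):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.BinaryComparisonSimplification
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

// A relation whose only column is known to be non-nullable, as in the new suite.
val relation = LocalRelation('a.int.withNullability(false))

// 'a >= 'a folds to the TRUE literal because both sides are non-nullable and
// semantically equal; 'a <= 'a + 7 is left untouched, and the resulting
// TRUE && ... is cleaned up later by BooleanSimplification.
val plan = relation.where('a >= 'a && 'a <= ('a + 7)).analyze
val simplified = BinaryComparisonSimplification(plan)
```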
--- .../spark/sql/catalyst/optimizer/Optimizer.scala | 24 ++++++ .../BinaryComparisonSimplificationSuite.scala | 95 ++++++++++++++++++++++ 2 files changed, 119 insertions(+) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 619514e8aa..bad115d22f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -86,6 +86,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { BooleanSimplification, SimplifyConditionals, RemoveDispensableExpressions, + BinaryComparisonSimplification, PruneFilters, EliminateSorts, SimplifyCasts, @@ -786,6 +787,29 @@ object BooleanSimplification extends Rule[LogicalPlan] with PredicateHelper { } } +/** + * Simplifies binary comparisons with semantically-equal expressions: + * 1) Replace '<=>' with 'true' literal. + * 2) Replace '=', '<=', and '>=' with 'true' literal if both operands are non-nullable. + * 3) Replace '<' and '>' with 'false' literal if both operands are non-nullable. + */ +object BinaryComparisonSimplification extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsUp { + // True with equality + case a EqualNullSafe b if a.semanticEquals(b) => TrueLiteral + case a EqualTo b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + case a GreaterThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => + TrueLiteral + case a LessThanOrEqual b if !a.nullable && !b.nullable && a.semanticEquals(b) => TrueLiteral + + // False with inequality + case a GreaterThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + case a LessThan b if !a.nullable && !b.nullable && a.semanticEquals(b) => FalseLiteral + } + } +} + /** * Simplifies conditional expressions (if / case). */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala new file mode 100644 index 0000000000..7cd038570b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/BinaryComparisonSimplificationSuite.scala @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.analysis._ +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.Literal.{FalseLiteral, TrueLiteral} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules._ + +class BinaryComparisonSimplificationSuite extends PlanTest with PredicateHelper { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("AnalysisNodes", Once, + EliminateSubqueryAliases) :: + Batch("Constant Folding", FixedPoint(50), + NullPropagation, + ConstantFolding, + BooleanSimplification, + BinaryComparisonSimplification, + PruneFilters) :: Nil + } + + val nullableRelation = LocalRelation('a.int.withNullability(true)) + val nonNullableRelation = LocalRelation('a.int.withNullability(false)) + + test("Preserve nullable exprs in general") { + for (e <- Seq('a === 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a)) { + val plan = nullableRelation.where(e).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + } + + test("Preserve non-deterministic exprs") { + val plan = nonNullableRelation + .where(Rand(0) === Rand(0) && Rand(1) <=> Rand(1)).analyze + val actual = Optimize.execute(plan) + val correctAnswer = plan + comparePlans(actual, correctAnswer) + } + + test("Nullable Simplification Primitive: <=>") { + val plan = nullableRelation.select('a <=> 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nullableRelation.select(Alias(TrueLiteral, "(a <=> a)")()).analyze + comparePlans(actual, correctAnswer) + } + + test("Non-Nullable Simplification Primitive") { + val plan = nonNullableRelation + .select('a === 'a, 'a <=> 'a, 'a <= 'a, 'a >= 'a, 'a < 'a, 'a > 'a).analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation + .select( + Alias(TrueLiteral, "(a = a)")(), + Alias(TrueLiteral, "(a <=> a)")(), + Alias(TrueLiteral, "(a <= a)")(), + Alias(TrueLiteral, "(a >= a)")(), + Alias(FalseLiteral, "(a < a)")(), + Alias(FalseLiteral, "(a > a)")()) + .analyze + comparePlans(actual, correctAnswer) + } + + test("Expression Normalization") { + val plan = nonNullableRelation.where( + 'a * Literal(100) + Pi() === Pi() + Literal(100) * 'a && + DateAdd(CurrentDate(), 'a + Literal(2)) <= DateAdd(CurrentDate(), Literal(2) + 'a)) + .analyze + val actual = Optimize.execute(plan) + val correctAnswer = nonNullableRelation.analyze + comparePlans(actual, correctAnswer) + } +} -- cgit v1.2.3 From 83fb96403bcfb1566e9d765690744824724737ac Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Mon, 11 Apr 2016 20:59:45 -0700 Subject: [SPARK-14132][SPARK-14133][SQL] Alter table partition DDLs ## What changes were proposed in this pull request? This implements a few alter table partition commands using the `SessionCatalog`. In particular: ``` ALTER TABLE ... ADD PARTITION ... ALTER TABLE ... DROP PARTITION ... ALTER TABLE ... RENAME PARTITION ... TO ... ``` The following operations are not supported, and an `AnalysisException` with a helpful error message will be thrown if the user tries to use them: ``` ALTER TABLE ... EXCHANGE PARTITION ... ALTER TABLE ... ARCHIVE PARTITION ... ALTER TABLE ... UNARCHIVE PARTITION ... ALTER TABLE ... TOUCH ... ALTER TABLE ... COMPACT ... ALTER TABLE ... 
CONCATENATE MSCK REPAIR TABLE ... ``` ## How was this patch tested? `DDLSuite`, `DDLCommandSuite` and `HiveDDLCommandSuite` Author: Andrew Or Closes #12220 from andrewor14/alter-partition-ddl. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 6 +- .../sql/catalyst/catalog/CatalogTestCases.scala | 18 ++- .../sql/catalyst/catalog/SessionCatalogSuite.scala | 28 +++- .../spark/sql/execution/SparkSqlParser.scala | 56 +++---- .../apache/spark/sql/execution/command/ddl.scala | 104 ++++++++----- .../sql/execution/command/DDLCommandSuite.scala | 118 +++++---------- .../spark/sql/execution/command/DDLSuite.scala | 162 ++++++++++++++++++++- .../hive/execution/HiveCompatibilitySuite.scala | 10 +- .../spark/sql/hive/HiveExternalCatalog.scala | 21 +-- .../apache/spark/sql/hive/client/HiveClient.scala | 9 +- .../spark/sql/hive/client/HiveClientImpl.scala | 22 ++- .../spark/sql/hive/HiveDDLCommandSuite.scala | 12 ++ 12 files changed, 360 insertions(+), 206 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 2f2e060b38..0e2cd39448 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -148,7 +148,7 @@ hiveNativeCommands | ROLLBACK WORK? | SHOW PARTITIONS tableIdentifier partitionSpec? | DFS .*? - | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | MSCK | LOAD) .*? + | (CREATE | ALTER | DROP | SHOW | DESC | DESCRIBE | LOAD) .*? ; unsupportedHiveNativeCommands @@ -177,6 +177,7 @@ unsupportedHiveNativeCommands | kw1=UNLOCK kw2=DATABASE | kw1=CREATE kw2=TEMPORARY kw3=MACRO | kw1=DROP kw2=TEMPORARY kw3=MACRO + | kw1=MSCK kw2=REPAIR kw3=TABLE ; createTableHeader @@ -651,7 +652,7 @@ nonReserved | AFTER | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT | INPUTDRIVER | OUTPUTDRIVER | DBPROPERTIES | DFS | TRUNCATE | METADATA | REPLICATION | COMPUTE | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE + | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION ; @@ -867,6 +868,7 @@ GRANT: 'GRANT'; LOCK: 'LOCK'; UNLOCK: 'UNLOCK'; MSCK: 'MSCK'; +REPAIR: 'REPAIR'; EXPORT: 'EXPORT'; IMPORT: 'IMPORT'; LOAD: 'LOAD'; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 0009438b31..0d9b0851fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -281,31 +281,37 @@ abstract class CatalogTestCases extends SparkFunSuite with BeforeAndAfterEach { test("drop partitions") { val catalog = newBasicCatalog() assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part1, part2))) - catalog.dropPartitions("db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) + catalog.dropPartitions( + "db2", "tbl2", Seq(part1.spec), ignoreIfNotExists = false) assert(catalogPartitionsEqual(catalog, "db2", "tbl2", Seq(part2))) resetState() val catalog2 = newBasicCatalog() 
assert(catalogPartitionsEqual(catalog2, "db2", "tbl2", Seq(part1, part2))) - catalog2.dropPartitions("db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + catalog2.dropPartitions( + "db2", "tbl2", Seq(part1.spec, part2.spec), ignoreIfNotExists = false) assert(catalog2.listPartitions("db2", "tbl2").isEmpty) } test("drop partitions when database/table does not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.dropPartitions("does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) + catalog.dropPartitions( + "does_not_exist", "tbl1", Seq(), ignoreIfNotExists = false) } intercept[AnalysisException] { - catalog.dropPartitions("db2", "does_not_exist", Seq(), ignoreIfNotExists = false) + catalog.dropPartitions( + "db2", "does_not_exist", Seq(), ignoreIfNotExists = false) } } test("drop partitions that do not exist") { val catalog = newBasicCatalog() intercept[AnalysisException] { - catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) + catalog.dropPartitions( + "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = false) } - catalog.dropPartitions("db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) + catalog.dropPartitions( + "db2", "tbl2", Seq(part3.spec), ignoreIfNotExists = true) } test("get partition") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 862fc275ad..426273e1e3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -496,19 +496,25 @@ class SessionCatalogSuite extends SparkFunSuite { val sessionCatalog = new SessionCatalog(externalCatalog) assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec), + ignoreIfNotExists = false) assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part2))) // Drop partitions without explicitly specifying database sessionCatalog.setCurrentDatabase("db2") sessionCatalog.dropPartitions( - TableIdentifier("tbl2"), Seq(part2.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2"), + Seq(part2.spec), + ignoreIfNotExists = false) assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) // Drop multiple partitions at once sessionCatalog.createPartitions( TableIdentifier("tbl2", Some("db2")), Seq(part1, part2), ignoreIfExists = false) assert(catalogPartitionsEqual(externalCatalog, "db2", "tbl2", Seq(part1, part2))) sessionCatalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part1.spec, part2.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2", Some("db2")), + Seq(part1.spec, part2.spec), + ignoreIfNotExists = false) assert(externalCatalog.listPartitions("db2", "tbl2").isEmpty) } @@ -516,11 +522,15 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { catalog.dropPartitions( - TableIdentifier("tbl1", Some("does_not_exist")), Seq(), ignoreIfNotExists = false) + TableIdentifier("tbl1", Some("does_not_exist")), + Seq(), + ignoreIfNotExists = false) } intercept[AnalysisException] { catalog.dropPartitions( - TableIdentifier("does_not_exist", Some("db2")), 
Seq(), ignoreIfNotExists = false) + TableIdentifier("does_not_exist", Some("db2")), + Seq(), + ignoreIfNotExists = false) } } @@ -528,10 +538,14 @@ class SessionCatalogSuite extends SparkFunSuite { val catalog = new SessionCatalog(newBasicCatalog()) intercept[AnalysisException] { catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = false) + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = false) } catalog.dropPartitions( - TableIdentifier("tbl2", Some("db2")), Seq(part3.spec), ignoreIfNotExists = true) + TableIdentifier("tbl2", Some("db2")), + Seq(part3.spec), + ignoreIfNotExists = true) } test("get partition") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 3da715cdb3..73d9640c35 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -493,7 +493,9 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitAddTablePartition( ctx: AddTablePartitionContext): LogicalPlan = withOrigin(ctx) { - if (ctx.VIEW != null) throw new ParseException(s"Operation not allowed: partitioned views", ctx) + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } // Create partition spec to location mapping. val specsAndLocs = if (ctx.partitionSpec.isEmpty) { ctx.partitionSpecLocation.asScala.map { @@ -509,8 +511,7 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableAddPartition( visitTableIdentifier(ctx.tableIdentifier), specsAndLocs, - ctx.EXISTS != null)( - command(ctx)) + ctx.EXISTS != null) } /** @@ -523,11 +524,8 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitExchangeTablePartition( ctx: ExchangeTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableExchangePartition( - visitTableIdentifier(ctx.from), - visitTableIdentifier(ctx.to), - visitNonOptionalPartitionSpec(ctx.partitionSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... 
EXCHANGE PARTITION ...") } /** @@ -543,8 +541,7 @@ class SparkSqlAstBuilder extends AstBuilder { AlterTableRenamePartition( visitTableIdentifier(ctx.tableIdentifier), visitNonOptionalPartitionSpec(ctx.from), - visitNonOptionalPartitionSpec(ctx.to))( - command(ctx)) + visitNonOptionalPartitionSpec(ctx.to)) } /** @@ -561,13 +558,16 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitDropTablePartitions( ctx: DropTablePartitionsContext): LogicalPlan = withOrigin(ctx) { - if (ctx.VIEW != null) throw new ParseException(s"Operation not allowed: partitioned views", ctx) + if (ctx.VIEW != null) { + throw new AnalysisException(s"Operation not allowed: partitioned views") + } + if (ctx.PURGE != null) { + throw new AnalysisException(s"Operation not allowed: PURGE") + } AlterTableDropPartition( visitTableIdentifier(ctx.tableIdentifier), ctx.partitionSpec.asScala.map(visitNonOptionalPartitionSpec), - ctx.EXISTS != null, - ctx.PURGE != null)( - command(ctx)) + ctx.EXISTS != null) } /** @@ -580,10 +580,8 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitArchiveTablePartition( ctx: ArchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableArchivePartition( - visitTableIdentifier(ctx.tableIdentifier), - visitNonOptionalPartitionSpec(ctx.partitionSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... ARCHIVE PARTITION ...") } /** @@ -596,10 +594,8 @@ class SparkSqlAstBuilder extends AstBuilder { */ override def visitUnarchiveTablePartition( ctx: UnarchiveTablePartitionContext): LogicalPlan = withOrigin(ctx) { - AlterTableUnarchivePartition( - visitTableIdentifier(ctx.tableIdentifier), - visitNonOptionalPartitionSpec(ctx.partitionSpec))( - command(ctx)) + throw new AnalysisException( + "Operation not allowed: ALTER TABLE ... UNARCHIVE PARTITION ...") } /** @@ -658,10 +654,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitTouchTable(ctx: TouchTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableTouch( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))( - command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... TOUCH ...") } /** @@ -673,11 +666,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitCompactTable(ctx: CompactTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableCompact( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), - string(ctx.STRING))( - command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... COMPACT ...") } /** @@ -689,10 +678,7 @@ class SparkSqlAstBuilder extends AstBuilder { * }}} */ override def visitConcatenateTable(ctx: ConcatenateTableContext): LogicalPlan = withOrigin(ctx) { - AlterTableMerge( - visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec))( - command(ctx)) + throw new AnalysisException("Operation not allowed: ALTER TABLE ... 
CONCATENATE") } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 8a37cf8f4c..c55b1a690e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.execution.command import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable} +import org.apache.spark.sql.catalyst.catalog.{CatalogTablePartition, CatalogTableType} import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ @@ -348,53 +349,94 @@ case class AlterTableSerDeProperties( } /** - * Add Partition in ALTER TABLE/VIEW: add the table/view partitions. + * Add Partition in ALTER TABLE: add the table partitions. + * * 'partitionSpecsAndLocs': the syntax of ALTER VIEW is identical to ALTER TABLE, * EXCEPT that it is ILLEGAL to specify a LOCATION clause. * An error message will be issued if the partition exists, unless 'ifNotExists' is true. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table ADD [IF NOT EXISTS] PARTITION spec [LOCATION 'loc1'] + * }}} */ case class AlterTableAddPartition( tableName: TableIdentifier, partitionSpecsAndLocs: Seq[(TablePartitionSpec, Option[String])], - ifNotExists: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + ifNotExists: Boolean) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table add partition is not allowed for tables defined using the datasource API") + } + val parts = partitionSpecsAndLocs.map { case (spec, location) => + // inherit table storage format (possibly except for location) + CatalogTablePartition(spec, table.storage.copy(locationUri = location)) + } + catalog.createPartitions(tableName, parts, ignoreIfExists = ifNotExists) + Seq.empty[Row] + } +} + +/** + * Alter a table partition's spec. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table PARTITION spec1 RENAME TO PARTITION spec2; + * }}} + */ case class AlterTableRenamePartition( tableName: TableIdentifier, oldPartition: TablePartitionSpec, - newPartition: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + newPartition: TablePartitionSpec) + extends RunnableCommand { -case class AlterTableExchangePartition( - fromTableName: TableIdentifier, - toTableName: TableIdentifier, - spec: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.renamePartitions( + tableName, Seq(oldPartition), Seq(newPartition)) + Seq.empty[Row] + } + +} /** - * Drop Partition in ALTER TABLE/VIEW: to drop a particular partition for a table/view. + * Drop Partition in ALTER TABLE: to drop a particular partition for a table. + * * This removes the data and metadata for this partition. 
* The data is actually moved to the .Trash/Current directory if Trash is configured, * unless 'purge' is true, but the metadata is completely lost. * An error message will be issued if the partition does not exist, unless 'ifExists' is true. * Note: purge is always false when the target is a view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE]; + * }}} */ case class AlterTableDropPartition( tableName: TableIdentifier, specs: Seq[TablePartitionSpec], - ifExists: Boolean, - purge: Boolean)(sql: String) - extends NativeDDLCommand(sql) with Logging + ifExists: Boolean) + extends RunnableCommand { -case class AlterTableArchivePartition( - tableName: TableIdentifier, - spec: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + val table = catalog.getTableMetadata(tableName) + if (DDLUtils.isDatasourceTable(table)) { + throw new AnalysisException( + "alter table drop partition is not allowed for tables defined using the datasource API") + } + catalog.dropPartitions(tableName, specs, ignoreIfNotExists = ifExists) + Seq.empty[Row] + } -case class AlterTableUnarchivePartition( - tableName: TableIdentifier, - spec: TablePartitionSpec)(sql: String) - extends NativeDDLCommand(sql) with Logging +} case class AlterTableSetFileFormat( tableName: TableIdentifier, @@ -453,22 +495,6 @@ case class AlterTableSetLocation( } -case class AlterTableTouch( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec])(sql: String) - extends NativeDDLCommand(sql) with Logging - -case class AlterTableCompact( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec], - compactType: String)(sql: String) - extends NativeDDLCommand(sql) with Logging - -case class AlterTableMerge( - tableName: TableIdentifier, - partitionSpec: Option[TablePartitionSpec])(sql: String) - extends NativeDDLCommand(sql) with Logging - case class AlterTableChangeCol( tableName: TableIdentifier, partitionSpec: Option[TablePartitionSpec], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index ac69518ddf..1c8dd68286 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.command +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest @@ -24,9 +25,17 @@ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.types._ +// TODO: merge this with DDLSuite (SPARK-14441) class DDLCommandSuite extends PlanTest { private val parser = SparkSqlParser + private def assertUnsupported(sql: String): Unit = { + val e = intercept[AnalysisException] { + parser.parsePlan(sql) + } + assert(e.getMessage.toLowerCase.contains("operation not allowed")) + } + test("create database") { val sql = """ @@ -326,11 +335,11 @@ class DDLCommandSuite extends PlanTest { Seq( (Map("dt" -> "2008-08-08", "country" -> "us"), Some("location1")), (Map("dt" -> "2009-09-09", "country" -> 
"uk"), None)), - ifNotExists = true)(sql1) + ifNotExists = true) val expected2 = AlterTableAddPartition( TableIdentifier("table_name", None), Seq((Map("dt" -> "2008-08-08"), Some("loc"))), - ifNotExists = false)(sql2) + ifNotExists = false) comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) @@ -369,22 +378,16 @@ class DDLCommandSuite extends PlanTest { val expected = AlterTableRenamePartition( TableIdentifier("table_name", None), Map("dt" -> "2008-08-08", "country" -> "us"), - Map("dt" -> "2008-09-09", "country" -> "uk"))(sql) + Map("dt" -> "2008-09-09", "country" -> "uk")) comparePlans(parsed, expected) } - test("alter table: exchange partition") { - val sql = + test("alter table: exchange partition (not supported)") { + assertUnsupported( """ |ALTER TABLE table_name_1 EXCHANGE PARTITION |(dt='2008-08-08', country='us') WITH TABLE table_name_2 - """.stripMargin - val parsed = parser.parsePlan(sql) - val expected = AlterTableExchangePartition( - TableIdentifier("table_name_1", None), - TableIdentifier("table_name_2", None), - Map("dt" -> "2008-08-08", "country" -> "us"))(sql) - comparePlans(parsed, expected) + """.stripMargin) } // ALTER TABLE table_name DROP [IF EXISTS] PARTITION spec1[, PARTITION spec2, ...] [PURGE] @@ -405,7 +408,10 @@ class DDLCommandSuite extends PlanTest { val sql2_view = sql2_table.replace("TABLE", "VIEW").replace("PURGE", "") val parsed1_table = parser.parsePlan(sql1_table) - val parsed2_table = parser.parsePlan(sql2_table) + val e = intercept[ParseException] { + parser.parsePlan(sql2_table) + } + assert(e.getMessage.contains("Operation not allowed")) intercept[ParseException] { parser.parsePlan(sql1_view) @@ -420,36 +426,17 @@ class DDLCommandSuite extends PlanTest { Seq( Map("dt" -> "2008-08-08", "country" -> "us"), Map("dt" -> "2009-09-09", "country" -> "uk")), - ifExists = true, - purge = false)(sql1_table) - val expected2_table = AlterTableDropPartition( - tableIdent, - Seq( - Map("dt" -> "2008-08-08", "country" -> "us"), - Map("dt" -> "2009-09-09", "country" -> "uk")), - ifExists = false, - purge = true)(sql2_table) + ifExists = true) comparePlans(parsed1_table, expected1_table) - comparePlans(parsed2_table, expected2_table) } - test("alter table: archive partition") { - val sql = "ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')" - val parsed = parser.parsePlan(sql) - val expected = AlterTableArchivePartition( - TableIdentifier("table_name", None), - Map("dt" -> "2008-08-08", "country" -> "us"))(sql) - comparePlans(parsed, expected) + test("alter table: archive partition (not supported)") { + assertUnsupported("ALTER TABLE table_name ARCHIVE PARTITION (dt='2008-08-08', country='us')") } - test("alter table: unarchive partition") { - val sql = "ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')" - val parsed = parser.parsePlan(sql) - val expected = AlterTableUnarchivePartition( - TableIdentifier("table_name", None), - Map("dt" -> "2008-08-08", "country" -> "us"))(sql) - comparePlans(parsed, expected) + test("alter table: unarchive partition (not supported)") { + assertUnsupported("ALTER TABLE table_name UNARCHIVE PARTITION (dt='2008-08-08', country='us')") } test("alter table: set file format") { @@ -505,55 +492,24 @@ class DDLCommandSuite extends PlanTest { comparePlans(parsed2, expected2) } - test("alter table: touch") { - val sql1 = "ALTER TABLE table_name TOUCH" - val sql2 = "ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')" - val parsed1 = parser.parsePlan(sql1) - val 
parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableTouch( - tableIdent, - None)(sql1) - val expected2 = AlterTableTouch( - tableIdent, - Some(Map("dt" -> "2008-08-08", "country" -> "us")))(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + test("alter table: touch (not supported)") { + assertUnsupported("ALTER TABLE table_name TOUCH") + assertUnsupported("ALTER TABLE table_name TOUCH PARTITION (dt='2008-08-08', country='us')") } - test("alter table: compact") { - val sql1 = "ALTER TABLE table_name COMPACT 'compaction_type'" - val sql2 = + test("alter table: compact (not supported)") { + assertUnsupported("ALTER TABLE table_name COMPACT 'compaction_type'") + assertUnsupported( """ - |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') - |COMPACT 'MAJOR' - """.stripMargin - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableCompact( - tableIdent, - None, - "compaction_type")(sql1) - val expected2 = AlterTableCompact( - tableIdent, - Some(Map("dt" -> "2008-08-08", "country" -> "us")), - "MAJOR")(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + |ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') + |COMPACT 'MAJOR' + """.stripMargin) } - test("alter table: concatenate") { - val sql1 = "ALTER TABLE table_name CONCATENATE" - val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE" - val parsed1 = parser.parsePlan(sql1) - val parsed2 = parser.parsePlan(sql2) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableMerge(tableIdent, None)(sql1) - val expected2 = AlterTableMerge( - tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")))(sql2) - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) + test("alter table: concatenate (not supported)") { + assertUnsupported("ALTER TABLE table_name CONCATENATE") + assertUnsupported( + "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') CONCATENATE") } test("alter table: change column name/type/position/comment") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index c6479bf33e..40a8b0e614 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -58,6 +58,10 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(e.getMessage.toLowerCase.contains("operation not allowed")) } + private def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { + if (expectException) intercept[AnalysisException] { body } else body + } + private def createDatabase(catalog: SessionCatalog, name: String): Unit = { catalog.createDatabase(CatalogDatabase(name, "", "", Map()), ignoreIfExists = false) } @@ -320,6 +324,62 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assertUnsupported("ALTER TABLE dbx.tab1 NOT STORED AS DIRECTORIES") } + test("alter table: add partition") { + testAddPartitions(isDatasourceTable = false) + } + + test("alter table: add partition (datasource table)") { + testAddPartitions(isDatasourceTable = true) + } + + test("alter table: add partition is not supported for views") { + 
assertUnsupported("ALTER VIEW dbx.tab1 ADD IF NOT EXISTS PARTITION (b='2')") + } + + test("alter table: drop partition") { + testDropPartitions(isDatasourceTable = false) + } + + test("alter table: drop partition (datasource table)") { + testDropPartitions(isDatasourceTable = true) + } + + test("alter table: drop partition is not supported for views") { + assertUnsupported("ALTER VIEW dbx.tab1 DROP IF EXISTS PARTITION (b='2')") + } + + test("alter table: rename partition") { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3)) + sql("ALTER TABLE dbx.tab1 PARTITION (a='1') RENAME TO PARTITION (a='100')") + sql("ALTER TABLE dbx.tab1 PARTITION (b='2') RENAME TO PARTITION (b='200')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "100"), Map("b" -> "200"), part3)) + // rename without explicitly specifying database + catalog.setCurrentDatabase("dbx") + sql("ALTER TABLE tab1 PARTITION (a='100') RENAME TO PARTITION (a='10')") + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(Map("a" -> "10"), Map("b" -> "200"), part3)) + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist PARTITION (c='3') RENAME TO PARTITION (c='333')") + } + // partition to rename does not exist + intercept[AnalysisException] { + sql("ALTER TABLE tab1 PARTITION (x='300') RENAME TO PARTITION (x='333')") + } + } + // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext test("show tables") { @@ -487,10 +547,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { assert(storageFormat.locationUri === Some(expected)) } } - // Optionally expect AnalysisException - def maybeWrapException[T](expectException: Boolean)(body: => T): Unit = { - if (expectException) intercept[AnalysisException] { body } else body - } // set table location sql("ALTER TABLE dbx.tab1 SET LOCATION '/path/to/your/lovely/heart'") verifyLocation("/path/to/your/lovely/heart") @@ -564,4 +620,102 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } + private def testAddPartitions(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + val part4 = Map("d" -> "4") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1)) + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 ADD IF NOT EXISTS " + + "PARTITION (b='2') LOCATION 'paris' PARTITION (c='3')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2, part3)) + assert(catalog.getPartition(tableIdent, part1).storage.locationUri.isEmpty) + assert(catalog.getPartition(tableIdent, part2).storage.locationUri 
== Some("paris")) + assert(catalog.getPartition(tableIdent, part3).storage.locationUri.isEmpty) + } + // add partitions without explicitly specifying database + catalog.setCurrentDatabase("dbx") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + } + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist ADD IF NOT EXISTS PARTITION (d='4')") + } + // partition to add already exists + intercept[AnalysisException] { + sql("ALTER TABLE tab1 ADD PARTITION (d='4')") + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 ADD IF NOT EXISTS PARTITION (d='4')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + } + } + + private def testDropPartitions(isDatasourceTable: Boolean): Unit = { + val catalog = sqlContext.sessionState.catalog + val tableIdent = TableIdentifier("tab1", Some("dbx")) + val part1 = Map("a" -> "1") + val part2 = Map("b" -> "2") + val part3 = Map("c" -> "3") + val part4 = Map("d" -> "4") + createDatabase(catalog, "dbx") + createTable(catalog, tableIdent) + createTablePartition(catalog, part1, tableIdent) + createTablePartition(catalog, part2, tableIdent) + createTablePartition(catalog, part3, tableIdent) + createTablePartition(catalog, part4, tableIdent) + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == + Set(part1, part2, part3, part4)) + if (isDatasourceTable) { + convertToDatasourceTable(catalog, tableIdent) + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE dbx.tab1 DROP IF EXISTS PARTITION (d='4'), PARTITION (c='3')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec).toSet == Set(part1, part2)) + } + // drop partitions without explicitly specifying database + catalog.setCurrentDatabase("dbx") + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (b='2')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + } + // table to alter does not exist + intercept[AnalysisException] { + sql("ALTER TABLE does_not_exist DROP IF EXISTS PARTITION (b='2')") + } + // partition to drop does not exist + intercept[AnalysisException] { + sql("ALTER TABLE tab1 DROP PARTITION (x='300')") + } + maybeWrapException(isDatasourceTable) { + sql("ALTER TABLE tab1 DROP IF EXISTS PARTITION (x='300')") + } + if (!isDatasourceTable) { + assert(catalog.listPartitions(tableIdent).map(_.spec) == Seq(part1)) + } + } + } diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index 9e3cb18d45..f0eeda09db 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -376,7 +376,13 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { // Create partitioned view is not supported "create_like_view", - "describe_formatted_view_partitioned" + "describe_formatted_view_partitioned", + + // This uses CONCATENATE, which we don't support + "alter_merge_2", + + // TOUCH is not supported + 
"touch" ) /** @@ -392,7 +398,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "alter2", "alter3", "alter5", - "alter_merge_2", "alter_partition_format_loc", "alter_partition_with_whitelist", "alter_rename_partition", @@ -897,7 +902,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "timestamp_comparison", "timestamp_lazy", "timestamp_null", - "touch", "transform_ppr1", "transform_ppr2", "truncate_table", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index a49ce33ba1..482f47428d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -219,26 +219,7 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat parts: Seq[TablePartitionSpec], ignoreIfNotExists: Boolean): Unit = withClient { requireTableExists(db, table) - // Note: Unfortunately Hive does not currently support `ignoreIfNotExists` so we - // need to implement it here ourselves. This is currently somewhat expensive because - // we make multiple synchronous calls to Hive for each partition we want to drop. - val partsToDrop = - if (ignoreIfNotExists) { - parts.filter { spec => - try { - getPartition(db, table, spec) - true - } catch { - // Filter out the partitions that do not actually exist - case _: AnalysisException => false - } - } - } else { - parts - } - if (partsToDrop.nonEmpty) { - client.dropPartitions(db, table, partsToDrop) - } + client.dropPartitions(db, table, parts, ignoreIfNotExists) } override def renamePartitions( diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala index 94794b1572..6f7e7bf451 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClient.scala @@ -120,16 +120,13 @@ private[hive] trait HiveClient { ignoreIfExists: Boolean): Unit /** - * Drop one or many partitions in the given table. - * - * Note: Unfortunately, Hive does not currently provide a way to ignore this call if the - * partitions do not already exist. The seemingly relevant flag `ifExists` in - * [[org.apache.hadoop.hive.metastore.PartitionDropOptions]] is not read anywhere. + * Drop one or many partitions in the given table, assuming they exist. */ def dropPartitions( db: String, table: String, - specs: Seq[ExternalCatalog.TablePartitionSpec]): Unit + specs: Seq[ExternalCatalog.TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit /** * Rename one or many existing table partitions, assuming they exist. 
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index a037671ef0..39e26acd7f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -26,7 +26,7 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path import org.apache.hadoop.hive.cli.CliSessionState import org.apache.hadoop.hive.conf.HiveConf -import org.apache.hadoop.hive.metastore.{TableType => HiveTableType} +import org.apache.hadoop.hive.metastore.{PartitionDropOptions, TableType => HiveTableType} import org.apache.hadoop.hive.metastore.api.{Database => HiveDatabase, FieldSchema, Function => HiveFunction, FunctionType, PrincipalType, ResourceType, ResourceUri} import org.apache.hadoop.hive.ql.Driver import org.apache.hadoop.hive.ql.metadata.{Partition => HivePartition, Table => HiveTable} @@ -367,9 +367,25 @@ private[hive] class HiveClientImpl( override def dropPartitions( db: String, table: String, - specs: Seq[ExternalCatalog.TablePartitionSpec]): Unit = withHiveState { + specs: Seq[ExternalCatalog.TablePartitionSpec], + ignoreIfNotExists: Boolean): Unit = withHiveState { // TODO: figure out how to drop multiple partitions in one call - specs.foreach { s => client.dropPartition(db, table, s.values.toList.asJava, true) } + val hiveTable = client.getTable(db, table, true /* throw exception */) + specs.foreach { s => + // The provided spec here can be a partial spec, i.e. it will match all partitions + // whose specs are supersets of this partial spec. E.g. If a table has partitions + // (b='1', c='1') and (b='1', c='2'), a partial spec of (b='1') will match both. + val matchingParts = client.getPartitions(hiveTable, s.asJava).asScala + if (matchingParts.isEmpty && !ignoreIfNotExists) { + throw new AnalysisException( + s"partition to drop '$s' does not exist in table '$table' database '$db'") + } + matchingParts.foreach { hivePartition => + val dropOptions = new PartitionDropOptions + dropOptions.ifExists = ignoreIfNotExists + client.dropPartition(db, table, hivePartition.getValues, dropOptions) + } + } } override def renamePartitions( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index a144da4997..e8086aec32 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -41,6 +41,13 @@ class HiveDDLCommandSuite extends PlanTest { }.head } + private def assertUnsupported(sql: String): Unit = { + val e = intercept[ParseException] { + parser.parsePlan(sql) + } + assert(e.getMessage.toLowerCase.contains("unsupported")) + } + test("Test CTAS #1") { val s1 = """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view @@ -367,4 +374,9 @@ class HiveDDLCommandSuite extends PlanTest { parser.parsePlan(v1).isInstanceOf[HiveNativeCommand] } } + + test("MSCK repair table (not supported)") { + assertUnsupported("MSCK REPAIR TABLE tab1") + } + } -- cgit v1.2.3 From b0f5497e9520575e5082fa8ce8be5569f43abe74 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Tue, 12 Apr 2016 00:43:28 -0700 Subject: [SPARK-14508][BUILD] Add a new ScalaStyle Rule `OmitBracesInCase` ## What changes were proposed in this pull request? 
According to the [Spark Code Style Guide](https://cwiki.apache.org/confluence/display/SPARK/Spark+Code+Style+Guide) and [Scala Style Guide](http://docs.scala-lang.org/style/control-structures.html#curlybraces), we had better enforce the following rule. ``` case: Always omit braces in case clauses. ``` This PR makes a new ScalaStyle rule, 'OmitBracesInCase', and enforces it to the code. ## How was this patch tested? Pass the Jenkins tests (including Scala style checking) Author: Dongjoon Hyun Closes #12280 from dongjoon-hyun/SPARK-14508. --- .../main/scala/org/apache/spark/SparkContext.scala | 12 ++---- .../src/main/scala/org/apache/spark/SparkEnv.scala | 3 +- .../apache/spark/api/python/PythonHadoopUtil.scala | 3 +- .../org/apache/spark/api/python/PythonRDD.scala | 2 +- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 12 ++---- .../org/apache/spark/deploy/master/Master.scala | 48 ++++++++-------------- .../spark/deploy/master/MasterArguments.scala | 2 +- .../deploy/master/ZooKeeperPersistenceEngine.scala | 3 +- .../mesos/MesosClusterDispatcherArguments.scala | 3 +- .../spark/deploy/worker/ExecutorRunner.scala | 6 +-- .../org/apache/spark/deploy/worker/Worker.scala | 12 ++---- .../spark/deploy/worker/WorkerArguments.scala | 3 +- .../executor/CoarseGrainedExecutorBackend.scala | 3 +- .../org/apache/spark/metrics/MetricsSystem.scala | 3 +- .../org/apache/spark/partial/BoundedDouble.scala | 3 +- .../org/apache/spark/rdd/DoubleRDDFunctions.scala | 4 +- .../org/apache/spark/rdd/OrderedRDDFunctions.scala | 3 +- .../apache/spark/rdd/ParallelCollectionRDD.scala | 9 ++-- .../spark/rdd/PartitionerAwareUnionRDD.scala | 3 +- .../apache/spark/scheduler/InputFormatInfo.scala | 6 +-- .../org/apache/spark/scheduler/SplitInfo.scala | 3 +- .../apache/spark/scheduler/TaskResultGetter.scala | 2 +- .../apache/spark/scheduler/TaskSetManager.scala | 9 ++-- .../mesos/MesosClusterPersistenceEngine.scala | 3 +- .../cluster/mesos/MesosSchedulerBackendUtil.scala | 6 +-- .../cluster/mesos/MesosSchedulerUtils.scala | 6 +-- .../org/apache/spark/serializer/Serializer.scala | 3 +- .../storage/ShuffleBlockFetcherIterator.scala | 6 +-- .../spark/ui/exec/ExecutorThreadDumpPage.scala | 3 +- .../scala/org/apache/spark/util/EventLoop.scala | 3 +- .../org/apache/spark/util/SizeEstimator.scala | 3 +- .../scala/org/apache/spark/DistributedSuite.scala | 2 +- .../org/apache/spark/SparkContextInfoSuite.scala | 6 +-- .../scala/org/apache/spark/UnpersistSuite.scala | 2 +- .../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 15 +++---- .../apache/spark/examples/CassandraCQLTest.scala | 6 +-- .../org/apache/spark/examples/CassandraTest.scala | 6 +-- .../scala/org/apache/spark/examples/LocalALS.scala | 6 +-- .../spark/examples/ml/OneVsRestExample.scala | 6 +-- .../spark/examples/mllib/DecisionTreeRunner.scala | 6 +-- .../streaming/kinesis/KinesisRecordProcessor.scala | 25 +++++------ .../org/apache/spark/ml/r/SparkRWrappers.scala | 6 +-- .../spark/mllib/clustering/GaussianMixture.scala | 3 +- .../mllib/clustering/GaussianMixtureModel.scala | 3 +- .../org/apache/spark/mllib/clustering/KMeans.scala | 6 +-- .../mllib/stat/test/KolmogorovSmirnovTest.scala | 3 +- .../apache/spark/repl/ExecutorClassLoader.scala | 36 ++++++++-------- scalastyle-config.xml | 5 +++ .../spark/sql/catalyst/expressions/Cast.scala | 3 +- .../expressions/EquivalentExpressions.scala | 4 +- .../org/apache/spark/sql/RandomDataGenerator.scala | 18 +++----- .../org/apache/spark/sql/DataFrameSuite.scala | 9 ++-- .../execution/vectorized/ColumnarBatchSuite.scala | 8 +--- 
.../apache/spark/streaming/dstream/DStream.scala | 3 +- .../streaming/dstream/DStreamCheckpointData.scala | 3 +- .../spark/streaming/dstream/FileInputDStream.scala | 3 +- .../spark/streaming/dstream/StateDStream.scala | 26 ++++-------- .../spark/streaming/BasicOperationsSuite.scala | 3 +- .../apache/spark/streaming/CheckpointSuite.scala | 6 +-- .../apache/spark/streaming/MasterFailureTest.scala | 3 +- .../spark/deploy/yarn/ApplicationMaster.scala | 3 +- .../apache/spark/deploy/yarn/YarnAllocator.scala | 3 +- .../scheduler/cluster/YarnSchedulerBackend.scala | 6 +-- .../deploy/yarn/YarnSparkHadoopUtilSuite.scala | 24 ++++------- 64 files changed, 164 insertions(+), 293 deletions(-) (limited to 'sql/catalyst') diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index f0d152f05a..966198dd5e 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -2397,9 +2397,8 @@ object SparkContext extends Logging { } catch { // TODO: Enumerate the exact reasons why it can fail // But irrespective of it, it means we cannot proceed ! - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } val backend = try { val clazz = @@ -2407,9 +2406,8 @@ object SparkContext extends Logging { val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } scheduler.initialize(backend) (backend, scheduler) @@ -2421,9 +2419,8 @@ object SparkContext extends Logging { cons.newInstance(sc).asInstanceOf[TaskSchedulerImpl] } catch { - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } val backend = try { @@ -2432,9 +2429,8 @@ object SparkContext extends Logging { val cons = clazz.getConstructor(classOf[TaskSchedulerImpl], classOf[SparkContext]) cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] } catch { - case e: Exception => { + case e: Exception => throw new SparkException("YARN mode not available ?", e) - } } scheduler.initialize(backend) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index ab89f4c4e4..3d11db7461 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -101,14 +101,13 @@ class SparkEnv ( // We only need to delete the tmp dir create by driver, because sparkFilesDir is point to the // current working dir in executor which we do not need to delete. 
driverTmpDirToDelete match { - case Some(path) => { + case Some(path) => try { Utils.deleteRecursively(new File(path)) } catch { case e: Exception => logWarning(s"Exception while deleting Spark temp dir: $path", e) } - } case None => // We just need to delete tmp dir created by driver, so do nothing on executor } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala index 6f6730690f..6259bead3e 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala @@ -134,11 +134,10 @@ private[python] class JavaToWritableConverter extends Converter[Any, Writable] { mapWritable.put(convertToWritable(k), convertToWritable(v)) } mapWritable - case array: Array[Any] => { + case array: Array[Any] => val arrayWriteable = new ArrayWritable(classOf[Writable]) arrayWriteable.set(array.map(convertToWritable(_))) arrayWriteable - } case other => throw new SparkException( s"Data of type ${other.getClass.getName} cannot be used") } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 4bca16a234..ab5b6c8380 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -470,7 +470,7 @@ private[spark] object PythonRDD extends Logging { objs.append(obj) } } catch { - case eof: EOFException => {} + case eof: EOFException => // No-op } JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } finally { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 41ac308808..cda9d38c6a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -152,10 +152,9 @@ class SparkHadoopUtil extends Logging { val baselineBytesRead = f() Some(() => f() - baselineBytesRead) } catch { - case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => logDebug("Couldn't find method for retrieving thread-level FileSystem input data", e) None - } } } @@ -174,10 +173,9 @@ class SparkHadoopUtil extends Logging { val baselineBytesWritten = f() Some(() => f() - baselineBytesWritten) } catch { - case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => { + case e @ (_: NoSuchMethodException | _: ClassNotFoundException) => logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e) None - } } } @@ -315,7 +313,7 @@ class SparkHadoopUtil extends Logging { */ def substituteHadoopVariables(text: String, hadoopConf: Configuration): String = { text match { - case HADOOP_CONF_PATTERN(matched) => { + case HADOOP_CONF_PATTERN(matched) => logDebug(text + " matched " + HADOOP_CONF_PATTERN) val key = matched.substring(13, matched.length() - 1) // remove ${hadoopconf- .. } val eval = Option[String](hadoopConf.get(key)) @@ -330,11 +328,9 @@ class SparkHadoopUtil extends Logging { // Continue to substitute more variables. 
substituteHadoopVariables(eval.get, hadoopConf) } - } - case _ => { + case _ => logDebug(text + " didn't match " + HADOOP_CONF_PATTERN) text - } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 01901bbf85..9bd3fc1033 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -217,7 +217,7 @@ private[deploy] class Master( } override def receive: PartialFunction[Any, Unit] = { - case ElectedLeader => { + case ElectedLeader => val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData(rpcEnv) state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { RecoveryState.ALIVE @@ -233,16 +233,14 @@ private[deploy] class Master( } }, WORKER_TIMEOUT_MS, TimeUnit.MILLISECONDS) } - } case CompleteRecovery => completeRecovery() - case RevokedLeadership => { + case RevokedLeadership => logError("Leadership has been revoked -- master shutting down.") System.exit(0) - } - case RegisterApplication(description, driver) => { + case RegisterApplication(description, driver) => // TODO Prevent repeated registrations from some driver if (state == RecoveryState.STANDBY) { // ignore, don't send response @@ -255,12 +253,11 @@ private[deploy] class Master( driver.send(RegisteredApplication(app.id, self)) schedule() } - } - case ExecutorStateChanged(appId, execId, state, message, exitStatus) => { + case ExecutorStateChanged(appId, execId, state, message, exitStatus) => val execOption = idToApp.get(appId).flatMap(app => app.executors.get(execId)) execOption match { - case Some(exec) => { + case Some(exec) => val appInfo = idToApp(appId) val oldState = exec.state exec.state = state @@ -298,22 +295,19 @@ private[deploy] class Master( } } } - } case None => logWarning(s"Got status update for unknown executor $appId/$execId") } - } - case DriverStateChanged(driverId, state, exception) => { + case DriverStateChanged(driverId, state, exception) => state match { case DriverState.ERROR | DriverState.FINISHED | DriverState.KILLED | DriverState.FAILED => removeDriver(driverId, state, exception) case _ => throw new Exception(s"Received unexpected state update for driver $driverId: $state") } - } - case Heartbeat(workerId, worker) => { + case Heartbeat(workerId, worker) => idToWorker.get(workerId) match { case Some(workerInfo) => workerInfo.lastHeartbeat = System.currentTimeMillis() @@ -327,9 +321,8 @@ private[deploy] class Master( " This worker was never registered, so ignoring the heartbeat.") } } - } - case MasterChangeAcknowledged(appId) => { + case MasterChangeAcknowledged(appId) => idToApp.get(appId) match { case Some(app) => logInfo("Application has been re-registered: " + appId) @@ -339,9 +332,8 @@ private[deploy] class Master( } if (canCompleteRecovery) { completeRecovery() } - } - case WorkerSchedulerStateResponse(workerId, executors, driverIds) => { + case WorkerSchedulerStateResponse(workerId, executors, driverIds) => idToWorker.get(workerId) match { case Some(worker) => logInfo("Worker has been re-registered: " + workerId) @@ -367,7 +359,6 @@ private[deploy] class Master( } if (canCompleteRecovery) { completeRecovery() } - } case WorkerLatestState(workerId, executors, driverIds) => idToWorker.get(workerId) match { @@ -397,9 +388,8 @@ private[deploy] class Master( logInfo(s"Received unregister request from application $applicationId") 
idToApp.get(applicationId).foreach(finishApplication) - case CheckForWorkerTimeOut => { + case CheckForWorkerTimeOut => timeOutDeadWorkers() - } case AttachCompletedRebuildUI(appId) => // An asyncRebuildSparkUI has completed, so need to attach to master webUi @@ -408,7 +398,7 @@ private[deploy] class Master( override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case RegisterWorker( - id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => { + id, workerHost, workerPort, workerRef, cores, memory, workerWebUiUrl) => logInfo("Registering worker %s:%d with %d cores, %s RAM".format( workerHost, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -430,9 +420,8 @@ private[deploy] class Master( + workerAddress)) } } - } - case RequestSubmitDriver(description) => { + case RequestSubmitDriver(description) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + "Can only accept driver submissions in ALIVE state." @@ -451,9 +440,8 @@ private[deploy] class Master( context.reply(SubmitDriverResponse(self, true, Some(driver.id), s"Driver successfully submitted as ${driver.id}")) } - } - case RequestKillDriver(driverId) => { + case RequestKillDriver(driverId) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + s"Can only kill drivers in ALIVE state." @@ -484,9 +472,8 @@ private[deploy] class Master( context.reply(KillDriverResponse(self, driverId, success = false, msg)) } } - } - case RequestDriverStatus(driverId) => { + case RequestDriverStatus(driverId) => if (state != RecoveryState.ALIVE) { val msg = s"${Utils.BACKUP_STANDALONE_MASTER_PREFIX}: $state. " + "Can only request driver status in ALIVE state." 
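As a reading aid for these Master.scala hunks, a minimal sketch of the case-brace pattern the OmitBracesInCase rule rewrites; `Msg`, `Register`, and `Heartbeat` are hypothetical types, not Spark messages:

```scala
object OmitBracesInCaseSketch {
  sealed trait Msg
  case class Register(id: String) extends Msg
  case object Heartbeat extends Msg

  // Before: braces wrap each multi-line case body (the style this rule now flags).
  val before: Msg => String = {
    case Register(id) => {
      val normalized = id.trim
      s"registered $normalized"
    }
    case Heartbeat => "alive"
  }

  // After: braces omitted; a case body already extends to the next case.
  val after: Msg => String = {
    case Register(id) =>
      val normalized = id.trim
      s"registered $normalized"
    case Heartbeat => "alive"
  }

  def main(args: Array[String]): Unit = {
    assert(before(Register("  a ")) == after(Register("  a ")))
    assert(before(Heartbeat) == after(Heartbeat))
  }
}
```

Both forms behave identically; a case body already extends to the next `case` or the end of the enclosing block, so the braces are redundant.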
@@ -501,18 +488,15 @@ private[deploy] class Master( context.reply(DriverStatusResponse(found = false, None, None, None, None)) } } - } - case RequestMasterState => { + case RequestMasterState => context.reply(MasterStateResponse( address.host, address.port, restServerBoundPort, workers.toArray, apps.toArray, completedApps.toArray, drivers.toArray, completedDrivers.toArray, state)) - } - case BoundPortsRequest => { + case BoundPortsRequest => context.reply(BoundPortsResponse(address.port, webUi.boundPort, restServerBoundPort)) - } case RequestExecutors(appId, requestedTotal) => context.reply(handleRequestExecutors(appId, requestedTotal)) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala index 9cd7458ba0..585e0839d0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/MasterArguments.scala @@ -78,7 +78,7 @@ private[master] class MasterArguments(args: Array[String], conf: SparkConf) { case ("--help") :: tail => printUsageAndExit(0) - case Nil => {} + case Nil => // No-op case _ => printUsageAndExit(1) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index 79f77212fe..af850e4871 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -70,11 +70,10 @@ private[master] class ZooKeeperPersistenceEngine(conf: SparkConf, val serializer try { Some(serializer.newInstance().deserialize[T](ByteBuffer.wrap(fileData))) } catch { - case e: Exception => { + case e: Exception => logWarning("Exception while reading persisted file, deleting", e) zk.delete().forPath(WORKING_DIR + "/" + filename) None - } } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala index b97805a28b..11e13441ee 100644 --- a/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/mesos/MesosClusterDispatcherArguments.scala @@ -76,14 +76,13 @@ private[mesos] class MesosClusterDispatcherArguments(args: Array[String], conf: case ("--help") :: tail => printUsageAndExit(0) - case Nil => { + case Nil => if (masterUrl == null) { // scalastyle:off println System.err.println("--master is required") // scalastyle:on println printUsageAndExit(1) } - } case _ => printUsageAndExit(1) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index f9c92c3bb9..06066248ea 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -179,16 +179,14 @@ private[deploy] class ExecutorRunner( val message = "Command exited with code " + exitCode worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { - case interrupted: InterruptedException => { + case interrupted: InterruptedException => logInfo("Runner thread for executor " + fullId + " interrupted") state = ExecutorState.KILLED killProcess(None) - } - case e: Exception => { + case e: 
Exception => logError("Error running executor", e) state = ExecutorState.FAILED killProcess(Some(e.toString)) - } } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 1b7637a39c..449beb0811 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -480,7 +480,7 @@ private[deploy] class Worker( memoryUsed += memory_ sendToMaster(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { - case e: Exception => { + case e: Exception => logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) if (executors.contains(appId + "/" + execId)) { executors(appId + "/" + execId).kill() @@ -488,7 +488,6 @@ private[deploy] class Worker( } sendToMaster(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, Some(e.toString), None)) - } } } @@ -509,7 +508,7 @@ private[deploy] class Worker( } } - case LaunchDriver(driverId, driverDesc) => { + case LaunchDriver(driverId, driverDesc) => logInfo(s"Asked to launch driver $driverId") val driver = new DriverRunner( conf, @@ -525,9 +524,8 @@ private[deploy] class Worker( coresUsed += driverDesc.cores memoryUsed += driverDesc.mem - } - case KillDriver(driverId) => { + case KillDriver(driverId) => logInfo(s"Asked to kill driver $driverId") drivers.get(driverId) match { case Some(runner) => @@ -535,11 +533,9 @@ private[deploy] class Worker( case None => logError(s"Asked to kill unknown driver $driverId") } - } - case driverStateChanged @ DriverStateChanged(driverId, state, exception) => { + case driverStateChanged @ DriverStateChanged(driverId, state, exception) => handleDriverStateChanged(driverStateChanged) - } case ReregisterWithMaster => reregisterWithMaster() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala index 391eb41190..777020d4d5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerArguments.scala @@ -165,12 +165,11 @@ private[worker] class WorkerArguments(args: Array[String], conf: SparkConf) { } // scalastyle:on classforname } catch { - case e: Exception => { + case e: Exception => totalMb = 2*1024 // scalastyle:off println System.out.println("Failed to get total physical memory. Using " + totalMb + " MB") // scalastyle:on println - } } // Leave out 1 GB for the operating system, but don't return a negative memory size math.max(totalMb - 1024, Utils.DEFAULT_DRIVER_MEM_MB) diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index d4ed5845e7..71b4ad160d 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -62,10 +62,9 @@ private[spark] class CoarseGrainedExecutorBackend( // This is a very fast action so we can use "ThreadUtils.sameThread" case Success(msg) => // Always receive `true`. 
Just ignore it - case Failure(e) => { + case Failure(e) => logError(s"Cannot register with driver: $driverUrl", e) System.exit(1) - } }(ThreadUtils.sameThread) } diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala index 4da1017d28..0fed991049 100644 --- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala +++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala @@ -196,10 +196,9 @@ private[spark] class MetricsSystem private ( sinks += sink.asInstanceOf[Sink] } } catch { - case e: Exception => { + case e: Exception => logError("Sink class " + classPath + " cannot be instantiated") throw e - } } } } diff --git a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala index c562c70aba..ab6aba6fc7 100644 --- a/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala +++ b/core/src/main/scala/org/apache/spark/partial/BoundedDouble.scala @@ -32,12 +32,11 @@ class BoundedDouble(val mean: Double, val confidence: Double, val low: Double, v */ override def equals(that: Any): Boolean = that match { - case that: BoundedDouble => { + case that: BoundedDouble => this.mean == that.mean && this.confidence == that.confidence && this.low == that.low && this.high == that.high - } case _ => false } } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 5e9230e733..368916a39e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -166,8 +166,8 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { val counters = new Array[Long](buckets.length - 1) while (iter.hasNext) { bucketFunction(iter.next()) match { - case Some(x: Int) => {counters(x) += 1} - case _ => {} + case Some(x: Int) => counters(x) += 1 + case _ => // No-Op } } Iterator(counters) diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 363004e587..a5992022d0 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -86,12 +86,11 @@ class OrderedRDDFunctions[K : Ordering : ClassTag, def inRange(k: K): Boolean = ordering.gteq(k, lower) && ordering.lteq(k, upper) val rddToFilter: RDD[P] = self.partitioner match { - case Some(rp: RangePartitioner[K, V]) => { + case Some(rp: RangePartitioner[K, V]) => val partitionIndicies = (rp.getPartition(lower), rp.getPartition(upper)) match { case (l, u) => Math.min(l, u) to Math.max(l, u) } PartitionPruningRDD.create(self, partitionIndicies.contains) - } case _ => self } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 582fa93afe..462fb39ea2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -128,7 +128,7 @@ private object ParallelCollectionRDD { }) } seq match { - case r: Range => { + case r: Range => positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) => // If the range is inclusive, use inclusive range for the last slice if (r.isInclusive && index == 
numSlices - 1) { @@ -138,8 +138,7 @@ private object ParallelCollectionRDD { new Range(r.start + start * r.step, r.start + end * r.step, r.step) } }).toSeq.asInstanceOf[Seq[Seq[T]]] - } - case nr: NumericRange[_] => { + case nr: NumericRange[_] => // For ranges of Long, Double, BigInteger, etc val slices = new ArrayBuffer[Seq[T]](numSlices) var r = nr @@ -149,14 +148,12 @@ private object ParallelCollectionRDD { r = r.drop(sliceSize) } slices - } - case _ => { + case _ => val array = seq.toArray // To prevent O(n^2) operations for List etc positions(array.length, numSlices).map({ case (start, end) => array.slice(start, end).toSeq }).toSeq - } } } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index 9e3880714a..c3579d761d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -78,11 +78,10 @@ class PartitionerAwareUnionRDD[T: ClassTag]( logDebug("Finding preferred location for " + this + ", partition " + s.index) val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents val locations = rdds.zip(parentPartitions).flatMap { - case (rdd, part) => { + case (rdd, part) => val parentLocations = currPrefLocs(rdd, part) logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations) parentLocations - } } val location = if (locations.isEmpty) { None diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 0640f26051..a6b032cc00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -57,11 +57,10 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl // Since we are not doing canonicalization of path, this can be wrong : like relative vs // absolute path .. which is fine, this is best case effort to remove duplicates - right ? override def equals(other: Any): Boolean = other match { - case that: InputFormatInfo => { + case that: InputFormatInfo => // not checking config - that should be fine, right ? this.inputFormatClazz == that.inputFormatClazz && this.path == that.path - } case _ => false } @@ -86,10 +85,9 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl } } catch { - case e: ClassNotFoundException => { + case e: ClassNotFoundException => throw new IllegalArgumentException("Specified inputformat " + inputFormatClazz + " cannot be found ?", e) - } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala index 6e9337bb90..bc1431835e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SplitInfo.scala @@ -49,14 +49,13 @@ class SplitInfo( // So unless there is identity equality between underlyingSplits, it will always fail even if it // is pointing to same block. 
override def equals(other: Any): Boolean = other match { - case that: SplitInfo => { + case that: SplitInfo => this.hostLocation == that.hostLocation && this.inputFormatClazz == that.inputFormatClazz && this.path == that.path && this.length == that.length && // other split specific checks (like start for FileSplit) this.underlyingSplit == that.underlyingSplit - } case _ => false } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala index 873f1b56bd..ae7ef46abb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskResultGetter.scala @@ -133,7 +133,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: TaskSchedul // if we can't deserialize the reason. logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex: Exception => {} + case ex: Exception => // No-op } scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 15d3515a02..6e08cdd87a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -188,20 +188,18 @@ private[spark] class TaskSetManager( loc match { case e: ExecutorCacheTaskLocation => pendingTasksForExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index - case e: HDFSCacheTaskLocation => { + case e: HDFSCacheTaskLocation => val exe = sched.getExecutorsAliveOnHost(loc.host) exe match { - case Some(set) => { + case Some(set) => for (e <- set) { pendingTasksForExecutor.getOrElseUpdate(e, new ArrayBuffer) += index } logInfo(s"Pending task $index has a cached location at ${e.host} " + ", where there are executors " + set.mkString(",")) - } case None => logDebug(s"Pending task $index has a cached location at ${e.host} " + ", but there are no executors alive there.") } - } case _ => } pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer) += index @@ -437,7 +435,7 @@ private[spark] class TaskSetManager( } dequeueTask(execId, host, allowedLocality) match { - case Some((index, taskLocality, speculative)) => { + case Some((index, taskLocality, speculative)) => // Found a task; do some bookkeeping and return a task description val task = tasks(index) val taskId = sched.newTaskId() @@ -486,7 +484,6 @@ private[spark] class TaskSetManager( sched.dagScheduler.taskStarted(task, info) return Some(new TaskDescription(taskId = taskId, attemptNumber = attemptNum, execId, taskName, index, serializedTask)) - } case _ => } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala index 3971e6c382..61ab3e87c5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterPersistenceEngine.scala @@ -121,11 +121,10 @@ private[spark] class ZookeeperMesosClusterPersistenceEngine( Some(Utils.deserialize[T](fileData)) } catch { case e: NoNodeException => None - case e: Exception => { + case e: Exception => logWarning("Exception while reading persisted file, deleting", e) zk.delete().forPath(zkPath) None 
- } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala index 374c79a7e5..1b7ac172de 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackendUtil.scala @@ -55,11 +55,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { Some(vol.setContainerPath(container_path) .setHostPath(host_path) .setMode(Volume.Mode.RO)) - case spec => { + case spec => logWarning(s"Unable to parse volume specs: $volumes. " + "Expected form: \"[host-dir:]container-dir[:rw|:ro](, ...)\"") None - } } } .map { _.build() } @@ -90,11 +89,10 @@ private[mesos] object MesosSchedulerBackendUtil extends Logging { Some(portmap.setHostPort(host_port.toInt) .setContainerPort(container_port.toInt) .setProtocol(protocol)) - case spec => { + case spec => logWarning(s"Unable to parse port mapping specs: $portmaps. " + "Expected form: \"host_port:container_port[:udp|:tcp](, ...)\"") None - } } } .map { _.build() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 233bdc23e6..7295d50682 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -124,11 +124,10 @@ private[mesos] trait MesosSchedulerUtils extends Logging { markErr() } } catch { - case e: Exception => { + case e: Exception => logError("driver.run() failed", e) error = Some(e) markErr() - } } } }.start() @@ -184,7 +183,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { var remain = amountToUse var requestedResources = new ArrayBuffer[Resource] val remainingResources = resources.asScala.map { - case r => { + case r => if (remain > 0 && r.getType == Value.Type.SCALAR && r.getScalar.getValue > 0.0 && @@ -196,7 +195,6 @@ private[mesos] trait MesosSchedulerUtils extends Logging { } else { r } - } } // Filter any resource that has depleted. 
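The MesosSchedulerUtils hunk above reshapes the resource-splitting loop; a simplified model of that bookkeeping, with plain `Double` capacities standing in for Mesos `Resource` protos and hypothetical names throughout:

```scala
object ResourceSplitSketch {
  /** Consume `amount` greedily across `capacities`, returning (used, leftover) per offer. */
  def consume(amount: Double, capacities: Seq[Double]): (Seq[Double], Seq[Double]) = {
    var remain = amount
    val split = capacities.map { cap =>
      val usable = if (remain > 0 && cap > 0) math.min(cap, remain) else 0.0
      remain -= usable
      (usable, cap - usable)
    }
    // Keep only leftovers that are not depleted, mirroring the filter in the original code.
    (split.map(_._1), split.map(_._2).filter(_ > 0))
  }

  def main(args: Array[String]): Unit = {
    val (used, leftover) = consume(3.0, Seq(1.0, 4.0))
    assert(used == Seq(1.0, 2.0))
    assert(leftover == Seq(2.0))
  }
}
```

Greedy consumption followed by the final filter mirrors the "filter any resource that has depleted" step in the original code.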
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala index 5ead40e89e..cb95246d5b 100644 --- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala @@ -188,10 +188,9 @@ abstract class DeserializationStream { try { (readKey[Any](), readValue[Any]()) } catch { - case eof: EOFException => { + case eof: EOFException => finished = true null - } } } diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala index 25edb9f1e4..4ec5b4bbb0 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala @@ -143,13 +143,12 @@ final class ShuffleBlockFetcherIterator( while (iter.hasNext) { val result = iter.next() result match { - case SuccessFetchResult(_, address, _, buf, _) => { + case SuccessFetchResult(_, address, _, buf, _) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) } buf.release() - } case _ => } } @@ -313,7 +312,7 @@ final class ShuffleBlockFetcherIterator( shuffleMetrics.incFetchWaitTime(stopFetchWait - startFetchWait) result match { - case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => { + case SuccessFetchResult(_, address, size, buf, isNetworkReqDone) => if (address != blockManager.blockManagerId) { shuffleMetrics.incRemoteBytesRead(buf.size) shuffleMetrics.incRemoteBlocksFetched(1) @@ -323,7 +322,6 @@ final class ShuffleBlockFetcherIterator( reqsInFlight -= 1 logDebug("Number of requests in flight " + reqsInFlight) } - } case _ => } // Send fetch requests up to maxBytesInFlight diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala index cc476d61b5..a0ef80d9bd 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala @@ -38,7 +38,7 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage val content = maybeThreadDump.map { threadDump => val dumpRows = threadDump.sortWith { - case (threadTrace1, threadTrace2) => { + case (threadTrace1, threadTrace2) => val v1 = if (threadTrace1.threadName.contains("Executor task launch")) 1 else 0 val v2 = if (threadTrace2.threadName.contains("Executor task launch")) 1 else 0 if (v1 == v2) { @@ -46,7 +46,6 @@ private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage } else { v1 > v2 } - } }.map { thread => val threadId = thread.threadId { + case NonFatal(e) => try { onError(e) } catch { case NonFatal(e) => logError("Unexpected error in " + name, e) } - } } } } catch { diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 3f627a0145..6861a75612 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -151,13 +151,12 @@ object SizeEstimator extends Logging { // TODO: We could use reflection on the VMOption returned ? 
getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") } catch { - case e: Exception => { + case e: Exception => // Guess whether they've enabled UseCompressedOops based on whether maxMemory < 32 GB val guess = Runtime.getRuntime.maxMemory < (32L*1024*1024*1024) val guessInWords = if (guess) "yes" else "not" logWarning("Failed to check whether UseCompressedOops is set; assuming " + guessInWords) return guess - } } } diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 67d722c1dc..2110d3d770 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -320,7 +320,7 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex Thread.sleep(200) } } catch { - case _: Throwable => { Thread.sleep(10) } + case _: Throwable => Thread.sleep(10) // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } diff --git a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala index 3706455c3f..8feb3dee05 100644 --- a/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala +++ b/core/src/test/scala/org/apache/spark/SparkContextInfoSuite.scala @@ -82,20 +82,18 @@ package object testPackage extends Assertions { val curCallSite = sc.getCallSite().shortForm // note: 2 lines after definition of "rdd" val rddCreationLine = rddCreationSite match { - case CALL_SITE_REGEX(func, file, line) => { + case CALL_SITE_REGEX(func, file, line) => assert(func === "makeRDD") assert(file === "SparkContextInfoSuite.scala") line.toInt - } case _ => fail("Did not match expected call site format") } curCallSite match { - case CALL_SITE_REGEX(func, file, line) => { + case CALL_SITE_REGEX(func, file, line) => assert(func === "getCallSite") // this is correct because we called it from outside of Spark assert(file === "SparkContextInfoSuite.scala") assert(line.toInt === rddCreationLine.toInt + 2) - } case _ => fail("Did not match expected call site format") } } diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index f7a13ab399..09e21646ee 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -35,7 +35,7 @@ class UnpersistSuite extends SparkFunSuite with LocalSparkContext { Thread.sleep(200) } } catch { - case _: Throwable => { Thread.sleep(10) } + case _: Throwable => Thread.sleep(10) // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. 
} diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 43e61241b6..cebac2097f 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -127,9 +127,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) val reply = rpcEndpointRef.askWithRetry[String]("hello") @@ -141,9 +140,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) @@ -164,10 +162,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => Thread.sleep(100) context.reply(msg) - } } }) @@ -317,10 +314,9 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receive: PartialFunction[Any, Unit] = { - case m => { + case m => self callSelfSuccessfully = true - } } }) @@ -682,9 +678,8 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { override val rpcEnv = localEnv override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { - case msg: String => { + case msg: String => context.reply(msg) - } } }) val rpcEndpointRef = remoteEnv.setupEndpointRef(localEnv.address, "ask-authentication") diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala index 973b005f91..ca4eea2356 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraCQLTest.scala @@ -106,9 +106,8 @@ object CassandraCQLTest { println("Count: " + casRdd.count) val productSaleRDD = casRdd.map { - case (key, value) => { + case (key, value) => (ByteBufferUtil.string(value.get("prod_id")), ByteBufferUtil.toInt(value.get("quantity"))) - } } val aggregatedRDD = productSaleRDD.reduceByKey(_ + _) aggregatedRDD.collect().foreach { @@ -116,11 +115,10 @@ object CassandraCQLTest { } val casoutputCF = aggregatedRDD.map { - case (productId, saleCount) => { + case (productId, saleCount) => val outKey = Collections.singletonMap("prod_id", ByteBufferUtil.bytes(productId)) val outVal = Collections.singletonList(ByteBufferUtil.bytes(saleCount)) (outKey, outVal) - } } casoutputCF.saveAsNewAPIHadoopFile( diff --git a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala index 6a8f73ad00..eff840d36e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/CassandraTest.scala @@ -90,9 +90,8 @@ object CassandraTest { // Let us first get all the paragraphs from the retrieved rows val paraRdd = casRdd.map { - case (key, value) => { + case (key, value) => ByteBufferUtil.string(value.get(ByteBufferUtil.bytes("para")).value()) - } } // Lets get the word count in paras @@ -103,7 +102,7 
@@ object CassandraTest { } counts.map { - case (word, count) => { + case (word, count) => val colWord = new org.apache.cassandra.thrift.Column() colWord.setName(ByteBufferUtil.bytes("word")) colWord.setValue(ByteBufferUtil.bytes(word)) @@ -122,7 +121,6 @@ object CassandraTest { mutations.get(1).setColumn_or_supercolumn(new ColumnOrSuperColumn()) mutations.get(1).column_or_supercolumn.setColumn(colCount) (outputkey, mutations) - } }.saveAsNewAPIHadoopFile("casDemo", classOf[ByteBuffer], classOf[List[Mutation]], classOf[ColumnFamilyOutputFormat], job.getConfiguration) diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index af5f216f28..fa10101955 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -104,16 +104,14 @@ object LocalALS { def main(args: Array[String]) { args match { - case Array(m, u, f, iters) => { + case Array(m, u, f, iters) => M = m.toInt U = u.toInt F = f.toInt ITERATIONS = iters.toInt - } - case _ => { + case _ => System.err.println("Usage: LocalALS ") System.exit(1) - } } showWarning() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala index a0bb5dabf4..0b5d31c0ff 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/OneVsRestExample.scala @@ -118,17 +118,15 @@ object OneVsRestExample { val inputData = sqlContext.read.format("libsvm").load(params.input) // compute the train/test split: if testInput is not provided use part of input. val data = params.testInput match { - case Some(t) => { + case Some(t) => // compute the number of features in the training set. val numFeatures = inputData.first().getAs[Vector](1).size val testData = sqlContext.read.option("numFeatures", numFeatures.toString) .format("libsvm").load(t) Array[DataFrame](inputData, testData) - } - case None => { + case None => val f = params.fracTest inputData.randomSplit(Array(1 - f, f), seed = 12345) - } } val Array(train, test) = data.map(_.cache()) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala index c263f4f595..ee811d3aa1 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala @@ -180,7 +180,7 @@ object DecisionTreeRunner { } // For classification, re-index classes if needed. 
val (examples, classIndexMap, numClasses) = algo match { - case Classification => { + case Classification => // classCounts: class --> # examples in class val classCounts = origExamples.map(_.label).countByValue() val sortedClasses = classCounts.keys.toList.sorted @@ -209,7 +209,6 @@ object DecisionTreeRunner { println(s"$c\t$frac\t${classCounts(c)}") } (examples, classIndexMap, numClasses) - } case Regression => (origExamples, null, 0) case _ => @@ -225,7 +224,7 @@ object DecisionTreeRunner { case "libsvm" => MLUtils.loadLibSVMFile(sc, testInput, numFeatures) } algo match { - case Classification => { + case Classification => // classCounts: class --> # examples in class val testExamples = { if (classIndexMap.isEmpty) { @@ -235,7 +234,6 @@ object DecisionTreeRunner { } } Array(examples, testExamples) - } case Regression => Array(examples, origTestExamples) } diff --git a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala index 41c6ab123b..80e0cce055 100644 --- a/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala +++ b/external/kinesis-asl/src/main/scala/org/apache/spark/streaming/kinesis/KinesisRecordProcessor.scala @@ -73,7 +73,7 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w logDebug(s"Stored: Worker $workerId stored ${batch.size} records for shardId $shardId") receiver.setCheckpointer(shardId, checkpointer) } catch { - case NonFatal(e) => { + case NonFatal(e) => /* * If there is a failure within the batch, the batch will not be checkpointed. * This will potentially cause records since the last checkpoint to be processed @@ -84,7 +84,6 @@ private[kinesis] class KinesisRecordProcessor[T](receiver: KinesisReceiver[T], w /* Rethrow the exception to the Kinesis Worker that is managing this RecordProcessor. */ throw e - } } } else { /* RecordProcessor has been stopped. */ @@ -148,29 +147,25 @@ private[kinesis] object KinesisRecordProcessor extends Logging { /* If the function failed, either retry or throw the exception */ case util.Failure(e) => e match { /* Retry: Throttling or other Retryable exception has occurred */ - case _: ThrottlingException | _: KinesisClientLibDependencyException if numRetriesLeft > 1 - => { - val backOffMillis = Random.nextInt(maxBackOffMillis) - Thread.sleep(backOffMillis) - logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) - retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) - } + case _: ThrottlingException | _: KinesisClientLibDependencyException + if numRetriesLeft > 1 => + val backOffMillis = Random.nextInt(maxBackOffMillis) + Thread.sleep(backOffMillis) + logError(s"Retryable Exception: Random backOffMillis=${backOffMillis}", e) + retryRandom(expression, numRetriesLeft - 1, maxBackOffMillis) /* Throw: Shutdown has been requested by the Kinesis Client Library. */ - case _: ShutdownException => { + case _: ShutdownException => logError(s"ShutdownException: Caught shutdown exception, skipping checkpoint.", e) throw e - } /* Throw: Non-retryable exception has occurred with the Kinesis Client Library */ - case _: InvalidStateException => { + case _: InvalidStateException => logError(s"InvalidStateException: Cannot save checkpoint to the DynamoDB table used" + s" by the Amazon Kinesis Client Library. 
Table likely doesn't exist.", e) throw e - } /* Throw: Unexpected exception has occurred */ - case _ => { + case _ => logError(s"Unexpected, non-retryable exception.", e) throw e - } } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala index 551e75dc0a..fa143715be 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -53,7 +53,7 @@ private[r] object SparkRWrappers { def getModelCoefficients(model: PipelineModel): Array[Double] = { model.stages.last match { - case m: LinearRegressionModel => { + case m: LinearRegressionModel => val coefficientStandardErrorsR = Array(m.summary.coefficientStandardErrors.last) ++ m.summary.coefficientStandardErrors.dropRight(1) val tValuesR = Array(m.summary.tValues.last) ++ m.summary.tValues.dropRight(1) @@ -64,14 +64,12 @@ private[r] object SparkRWrappers { } else { m.coefficients.toArray ++ coefficientStandardErrorsR ++ tValuesR ++ pValuesR } - } - case m: LogisticRegressionModel => { + case m: LogisticRegressionModel => if (m.getFitIntercept) { Array(m.intercept) ++ m.coefficients.toArray } else { m.coefficients.toArray } - } } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala index 03eb903bb8..f04c87259c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixture.scala @@ -181,13 +181,12 @@ class GaussianMixture private ( val (weights, gaussians) = initialModel match { case Some(gmm) => (gmm.weights, gmm.gaussians) - case None => { + case None => val samples = breezeData.takeSample(withReplacement = true, k * nSamples, seed) (Array.fill(k)(1.0 / k), Array.tabulate(k) { i => val slice = samples.view(i * nSamples, (i + 1) * nSamples) new MultivariateGaussian(vectorMean(slice), initCovariance(slice)) }) - } } var llh = Double.MinValue // current log-likelihood diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala index 02417b1124..f87613cc72 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/GaussianMixtureModel.scala @@ -183,7 +183,7 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { val k = (metadata \ "k").extract[Int] val classNameV1_0 = SaveLoadV1_0.classNameV1_0 (loadedClassName, version) match { - case (classNameV1_0, "1.0") => { + case (classNameV1_0, "1.0") => val model = SaveLoadV1_0.load(sc, path) require(model.weights.length == k, s"GaussianMixtureModel requires weights of length $k " + @@ -192,7 +192,6 @@ object GaussianMixtureModel extends Loader[GaussianMixtureModel] { s"GaussianMixtureModel requires gaussians of length $k" + s"got gaussians of length ${model.gaussians.length}") model - } case _ => throw new Exception( s"GaussianMixtureModel.load did not recognize model with (className, format version):" + s"($loadedClassName, $version). 
Supported:\n" + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 37a21cd879..8ff0b83e8b 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -253,16 +253,14 @@ class KMeans private ( } val centers = initialModel match { - case Some(kMeansCenters) => { + case Some(kMeansCenters) => Array(kMeansCenters.clusterCenters.map(s => new VectorWithNorm(s))) - } - case None => { + case None => if (initializationMode == KMeans.RANDOM) { initRandom(data) } else { initKMeansParallel(data) } - } } val initTimeInSeconds = (System.nanoTime() - initStartTime) / 1e9 logInfo(s"Initialization with $initializationMode took " + "%.3f".format(initTimeInSeconds) + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala index 0ec8975fed..ef284531c9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/test/KolmogorovSmirnovTest.scala @@ -97,7 +97,7 @@ private[stat] object KolmogorovSmirnovTest extends Logging { : KolmogorovSmirnovTestResult = { val distObj = distName match { - case "norm" => { + case "norm" => if (params.nonEmpty) { // parameters are passed, then can only be 2 require(params.length == 2, "Normal distribution requires mean and standard " + @@ -109,7 +109,6 @@ private[stat] object KolmogorovSmirnovTest extends Logging { "initialized to standard normal (i.e. N(0, 1))") new NormalDistribution(0, 1) } - } case _ => throw new UnsupportedOperationException(s"$distName not yet supported through" + s" convenience method. Current options are:['norm'].") } diff --git a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala index 928aaa5629..4a15d52b57 100644 --- a/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala +++ b/repl/src/main/scala/org/apache/spark/repl/ExecutorClassLoader.scala @@ -70,26 +70,24 @@ class ExecutorClassLoader( } override def findClass(name: String): Class[_] = { - userClassPathFirst match { - case true => findClassLocally(name).getOrElse(parentLoader.loadClass(name)) - case false => { - try { - parentLoader.loadClass(name) - } catch { - case e: ClassNotFoundException => { - val classOption = findClassLocally(name) - classOption match { - case None => - // If this class has a cause, it will break the internal assumption of Janino - // (the compiler used for Spark SQL code-gen). - // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see - // its behavior will be changed if there is a cause and the compilation - // of generated class will fail. - throw new ClassNotFoundException(name) - case Some(a) => a - } + if (userClassPathFirst) { + findClassLocally(name).getOrElse(parentLoader.loadClass(name)) + } else { + try { + parentLoader.loadClass(name) + } catch { + case e: ClassNotFoundException => + val classOption = findClassLocally(name) + classOption match { + case None => + // If this class has a cause, it will break the internal assumption of Janino + // (the compiler used for Spark SQL code-gen). 
+ // See org.codehaus.janino.ClassLoaderIClassLoader's findIClass, you will see + // its behavior will be changed if there is a cause and the compilation + // of generated class will fail. + throw new ClassNotFoundException(name) + case Some(a) => a } - } } } } diff --git a/scalastyle-config.xml b/scalastyle-config.xml index 472a8f4084..a14e3e583f 100644 --- a/scalastyle-config.xml +++ b/scalastyle-config.xml @@ -228,6 +228,11 @@ This file is divided into 3 sections: Use Javadoc style indentation for multiline comments + + case[^\n>]*=>\s*\{ + Omit braces in case clauses. + + diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d842ffdc66..0f8876a9e6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -898,7 +898,7 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w val result = ctx.freshName("result") val tmpRow = ctx.freshName("tmpRow") - val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => { + val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) => val fromFieldPrim = ctx.freshName("ffp") val fromFieldNull = ctx.freshName("ffn") val toFieldPrim = ctx.freshName("tfp") @@ -920,7 +920,6 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w } } """ - } }.mkString("\n") (c, evPrim, evNull) => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala index affd1bdb32..8d8cc152ff 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/EquivalentExpressions.scala @@ -97,11 +97,11 @@ class EquivalentExpressions { def debugString(all: Boolean = false): String = { val sb: mutable.StringBuilder = new StringBuilder() sb.append("Equivalent expressions:\n") - equivalenceMap.foreach { case (k, v) => { + equivalenceMap.foreach { case (k, v) => if (all || v.length > 1) { sb.append(" " + v.mkString(", ")).append("\n") } - }} + } sb.toString() } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala index 8207d64798..711e870711 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/RandomDataGenerator.scala @@ -196,12 +196,11 @@ object RandomDataGenerator { case ShortType => randomNumeric[Short]( rand, _.nextInt().toShort, Seq(Short.MinValue, Short.MaxValue, 0.toShort)) case NullType => Some(() => null) - case ArrayType(elementType, containsNull) => { + case ArrayType(elementType, containsNull) => forType(elementType, nullable = containsNull, rand).map { elementGenerator => () => Seq.fill(rand.nextInt(MAX_ARR_SIZE))(elementGenerator()) } - } - case MapType(keyType, valueType, valueContainsNull) => { + case MapType(keyType, valueType, valueContainsNull) => for ( keyGenerator <- forType(keyType, nullable = false, rand); valueGenerator <- @@ -221,8 +220,7 @@ object RandomDataGenerator { keys.zip(values).toMap } } - } - case StructType(fields) => { + case StructType(fields) => val 
maybeFieldGenerators: Seq[Option[() => Any]] = fields.map { field => forType(field.dataType, nullable = field.nullable, rand) } @@ -232,8 +230,7 @@ object RandomDataGenerator { } else { None } - } - case udt: UserDefinedType[_] => { + case udt: UserDefinedType[_] => val maybeSqlTypeGenerator = forType(udt.sqlType, nullable, rand) // Because random data generator at here returns scala value, we need to // convert it to catalyst value to call udt's deserialize. @@ -253,7 +250,6 @@ object RandomDataGenerator { } else { None } - } case unsupportedType => None } // Handle nullability by wrapping the non-null value generator: @@ -277,7 +273,7 @@ object RandomDataGenerator { val fields = mutable.ArrayBuffer.empty[Any] schema.fields.foreach { f => f.dataType match { - case ArrayType(childType, nullable) => { + case ArrayType(childType, nullable) => val data = if (f.nullable && rand.nextFloat() <= PROBABILITY_OF_NULL) { null } else { @@ -294,10 +290,8 @@ object RandomDataGenerator { arr } fields += data - } - case StructType(children) => { + case StructType(children) => fields += randomRow(rand, StructType(children)) - } case _ => val generator = RandomDataGenerator.forType(f.dataType, f.nullable, rand) assert(generator.isDefined, "Unsupported type") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 86c6405522..e953a6e8ef 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -1153,14 +1153,12 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyNonExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => { + case agg: TungstenAggregate => atFirstAgg = !atFirstAgg - } - case _ => { + case _ => if (atFirstAgg) { fail("Should not have operators between the two aggregations") } - } } } @@ -1170,12 +1168,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { private def verifyExchangingAgg(df: DataFrame) = { var atFirstAgg: Boolean = false df.queryExecution.executedPlan.foreach { - case agg: TungstenAggregate => { + case agg: TungstenAggregate => if (atFirstAgg) { fail("Should not have back to back Aggregates") } atFirstAgg = true - } case e: ShuffleExchange => atFirstAgg = false case _ => } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 8a551cd78c..31b63f2ce1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -612,23 +612,20 @@ class ColumnarBatchSuite extends SparkFunSuite { val a2 = r2.getList(v._2).toArray assert(a1.length == a2.length, "Seed = " + seed) childType match { - case DoubleType => { + case DoubleType => var i = 0 while (i < a1.length) { assert(doubleEquals(a1(i).asInstanceOf[Double], a2(i).asInstanceOf[Double]), "Seed = " + seed) i += 1 } - } - case FloatType => { + case FloatType => var i = 0 while (i < a1.length) { assert(doubleEquals(a1(i).asInstanceOf[Float], a2(i).asInstanceOf[Float]), "Seed = " + seed) i += 1 } - } - case t: DecimalType => var i = 0 while (i < a1.length) { @@ -640,7 +637,6 @@ class ColumnarBatchSuite extends SparkFunSuite { } i += 1 } - 
case _ => assert(a1 === a2, "Seed = " + seed) } case StructType(childFields) => diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala index c40beeff97..58842f9c2f 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStream.scala @@ -429,13 +429,12 @@ abstract class DStream[T: ClassTag] ( */ private[streaming] def generateJob(time: Time): Option[Job] = { getOrCompute(time) match { - case Some(rdd) => { + case Some(rdd) => val jobFunc = () => { val emptyFunc = { (iterator: Iterator[T]) => {} } context.sparkContext.runJob(rdd, emptyFunc) } Some(new Job(time, jobFunc)) - } case None => None } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala index 431c9dbe2c..e73837eb96 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/DStreamCheckpointData.scala @@ -109,10 +109,9 @@ class DStreamCheckpointData[T: ClassTag](dstream: DStream[T]) def restore() { // Create RDDs from the checkpoint data currentCheckpointFiles.foreach { - case(time, file) => { + case(time, file) => logInfo("Restoring checkpointed RDD for time " + time + " from file '" + file + "'") dstream.generatedRDDs += ((time, dstream.context.sparkContext.checkpointFile[T](file))) - } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala index 7fba2e8ec0..36f50e04db 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FileInputDStream.scala @@ -333,14 +333,13 @@ class FileInputDStream[K, V, F <: NewInputFormat[K, V]]( override def restore() { hadoopFiles.toSeq.sortBy(_._1)(Time.ordering).foreach { - case (t, f) => { + case (t, f) => // Restore the metadata in both files and generatedRDDs logInfo("Restoring files for time " + t + " - " + f.mkString("[", ", ", "]") ) batchTimeToSelectedFiles.synchronized { batchTimeToSelectedFiles += ((t, f)) } recentlySelectedFiles ++= f generatedRDDs += ((t, filesToRDD(f))) - } } } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 0379957e58..28aed0ca45 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -65,14 +65,12 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // Try to get the previous state RDD getOrCompute(validTime - slideDuration) match { - case Some(prevStateRDD) => { // If previous state RDD exists - + case Some(prevStateRDD) => // If previous state RDD exists // Try to get the parent RDD parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual + case Some(parentRDD) => // If parent RDD exists, then compute as usual computeUsingPreviousRDD(parentRDD, prevStateRDD) - } - case None => { // If parent RDD does not exist + case None => // If parent RDD does not exist // Re-apply the 
update function to the old state RDD val updateFuncLocal = updateFunc @@ -82,17 +80,14 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( } val stateRDD = prevStateRDD.mapPartitions(finalFunc, preservePartitioning) Some(stateRDD) - } } - } - - case None => { // If previous session RDD does not exist (first input data) + case None => // If previous session RDD does not exist (first input data) // Try to get the parent RDD parent.getOrCompute(validTime) match { - case Some(parentRDD) => { // If parent RDD exists, then compute as usual + case Some(parentRDD) => // If parent RDD exists, then compute as usual initialRDD match { - case None => { + case None => // Define the function for the mapPartition operation on grouped RDD; // first map the grouped tuple to tuples of required type, // and then apply the update function @@ -105,18 +100,13 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( val sessionRDD = groupedRDD.mapPartitions(finalFunc, preservePartitioning) // logDebug("Generating state RDD for time " + validTime + " (first)") Some(sessionRDD) - } - case Some(initialStateRDD) => { + case Some(initialStateRDD) => computeUsingPreviousRDD(parentRDD, initialStateRDD) - } } - } - case None => { // If parent RDD does not exist, then nothing to do! + case None => // If parent RDD does not exist, then nothing to do! // logDebug("Not generating state RDD (no previous state, no parent)") None - } } - } } } } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala index bd60059b18..cfcbdc7c38 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala @@ -538,10 +538,9 @@ class BasicOperationsSuite extends TestSuiteBase { val stateObj = state.getOrElse(new StateObject) values.sum match { case 0 => stateObj.expireCounter += 1 // no new values - case n => { // has new values, increment and reset expireCounter + case n => // has new values, increment and reset expireCounter stateObj.counter += n stateObj.expireCounter = 0 - } } stateObj.expireCounter match { case 2 => None // seen twice with no new values, give it the boot diff --git a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala index fbb25d4c59..bdbac64b9b 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala @@ -267,10 +267,9 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before first failure") stateStream.checkpointData.currentCheckpointFiles.foreach { - case (time, file) => { + case (time, file) => assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before first failure does not exist") - } } // Run till a further time such that previous checkpoint files in the stream would be deleted @@ -297,10 +296,9 @@ class CheckpointSuite extends TestSuiteBase with DStreamCheckpointTester assert(!stateStream.checkpointData.currentCheckpointFiles.isEmpty, "No checkpointed RDDs in state stream before second failure") stateStream.checkpointData.currentCheckpointFiles.foreach { - case (time, file) => { + case 
(time, file) => assert(fs.exists(new Path(file)), "Checkpoint file '" + file +"' for time " + time + " for state stream before seconds failure does not exist") - } } ssc.stop() diff --git a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala index 29bee4adf2..60c8e70235 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/MasterFailureTest.scala @@ -382,11 +382,10 @@ class FileGeneratingThread(input: Seq[String], testDir: Path, interval: Long) fs.rename(tempHadoopFile, hadoopFile) done = true } catch { - case ioe: IOException => { + case ioe: IOException => fs = testDir.getFileSystem(new Configuration()) logWarning("Attempt " + tries + " at generating file " + hadoopFile + " failed.", ioe) - } } } if (!done) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 9e8453429c..d447a59937 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -374,7 +374,7 @@ private[spark] class ApplicationMaster( failureCount = 0 } catch { case i: InterruptedException => - case e: Throwable => { + case e: Throwable => failureCount += 1 // this exception was introduced in hadoop 2.4 and this code would not compile // with earlier versions if we refer it directly. @@ -390,7 +390,6 @@ private[spark] class ApplicationMaster( } else { logWarning(s"Reporter thread fails $failureCount time(s) in a row.", e) } - } } try { val numPendingAllocate = allocator.getPendingAllocate.size diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala index b0bfe855e9..23742eab62 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocator.scala @@ -148,11 +148,10 @@ private[yarn] class YarnAllocator( classOf[Array[String]], classOf[Array[String]], classOf[Priority], classOf[Boolean], classOf[String])) } catch { - case e: NoSuchMethodException => { + case e: NoSuchMethodException => logWarning(s"Node label expression $expr will be ignored because YARN version on" + " classpath does not support it.") None - } } } diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index 8720ee57fe..6b3c831e60 100644 --- a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -223,17 +223,15 @@ private[spark] abstract class YarnSchedulerBackend( val lossReasonRequest = GetExecutorLossReason(executorId) val future = am.ask[ExecutorLossReason](lossReasonRequest, askTimeout) future onSuccess { - case reason: ExecutorLossReason => { + case reason: ExecutorLossReason => driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, reason)) - } } future onFailure { - case NonFatal(e) => { + case NonFatal(e) => logWarning(s"Attempted to get executor loss reason" + s" for executor id ${executorId} at RPC address ${executorRpcAddress}," + s" but got no response. 
Marking as slave lost.", e) driverEndpoint.askWithRetry[Boolean](RemoveExecutor(executorId, SlaveLost())) - } case t => throw t } case None => diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala index de14e36f4e..fe09808ae5 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtilSuite.scala @@ -101,22 +101,18 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) viewAcls match { - case Some(vacls) => { + case Some(vacls) => val aclSet = vacls.split(',').map(_.trim).toSet assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } modifyAcls match { - case Some(macls) => { + case Some(macls) => val aclSet = macls.split(',').map(_.trim).toSet assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } } @@ -135,26 +131,22 @@ class YarnSparkHadoopUtilSuite extends SparkFunSuite with Matchers with Logging val modifyAcls = acls.get(ApplicationAccessType.MODIFY_APP) viewAcls match { - case Some(vacls) => { + case Some(vacls) => val aclSet = vacls.split(',').map(_.trim).toSet assert(aclSet.contains("user1")) assert(aclSet.contains("user2")) assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } modifyAcls match { - case Some(macls) => { + case Some(macls) => val aclSet = macls.split(',').map(_.trim).toSet assert(aclSet.contains("user3")) assert(aclSet.contains("user4")) assert(aclSet.contains(System.getProperty("user.name", "invalid"))) - } - case None => { + case None => fail() - } } } -- cgit v1.2.3 From 85e68b4bea3e4ad2e4063334dbf5b11af197d2ce Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 12 Apr 2016 12:29:54 -0700 Subject: [SPARK-14562] [SQL] improve constraints propagation in Union ## What changes were proposed in this pull request? Currently, Union only takes intersect of the constraints from it's children, all others are dropped, we should try to merge them together. This PR try to merge the constraints that have the same reference but came from different children, for example: `a > 10` and `a < 100` could be merged as `a > 10 || a < 100`. ## How was this patch tested? Added more cases in existing test. Author: Davies Liu Closes #12328 from davies/union_const. 
--- .../sql/catalyst/plans/logical/basicOperators.scala | 16 +++++++++++++++- .../sql/catalyst/plans/ConstraintPropagationSuite.scala | 14 ++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index d3353beb09..d4fc9e4da9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -236,10 +236,24 @@ case class Union(children: Seq[LogicalPlan]) extends LogicalPlan { }) } + private def merge(a: Set[Expression], b: Set[Expression]): Set[Expression] = { + val common = a.intersect(b) + // The constraint with only one reference could be easily inferred as predicate + // Grouping the constraints by it's references so we can combine the constraints with same + // reference together + val othera = a.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + val otherb = b.diff(common).filter(_.references.size == 1).groupBy(_.references.head) + // loose the constraints by: A1 && B1 || A2 && B2 -> (A1 || A2) && (B1 || B2) + val others = (othera.keySet intersect otherb.keySet).map { attr => + Or(othera(attr).reduceLeft(And), otherb(attr).reduceLeft(And)) + } + common ++ others + } + override protected def validConstraints: Set[Expression] = { children .map(child => rewriteConstraints(children.head.output, child.output, child.constraints)) - .reduce(_ intersect _) + .reduce(merge(_, _)) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala index 49c1353efb..81cc6b123c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/ConstraintPropagationSuite.scala @@ -148,6 +148,20 @@ class ConstraintPropagationSuite extends SparkFunSuite { .analyze.constraints, ExpressionSet(Seq(resolveColumn(tr1, "a") > 10, IsNotNull(resolveColumn(tr1, "a"))))) + + val a = resolveColumn(tr1, "a") + verifyConstraints(tr1 + .where('a.attr > 10) + .union(tr2.where('d.attr > 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, IsNotNull(a)))) + + val b = resolveColumn(tr1, "b") + verifyConstraints(tr1 + .where('a.attr > 10 && 'b.attr < 10) + .union(tr2.where('d.attr > 11 && 'e.attr < 11)) + .analyze.constraints, + ExpressionSet(Seq(a > 10 || a > 11, b < 10 || b < 11, IsNotNull(a), IsNotNull(b)))) } test("propagating constraints in intersect") { -- cgit v1.2.3 From bcd2076274b1a95f74616d0ceacb0696e38b5f4c Mon Sep 17 00:00:00 2001 From: bomeng Date: Tue, 12 Apr 2016 13:43:39 -0700 Subject: [SPARK-14414][SQL] improve the error message class hierarchy ## What changes were proposed in this pull request? Before we are using `AnalysisException`, `ParseException`, `NoSuchFunctionException` etc when a parsing error encounters. I am trying to make it consistent and also **minimum** code impact to the current implementation by changing the class hierarchy. 1. `NoSuchItemException` is removed, since it is an abstract class and it just simply takes a message string. 2. 
`NoSuchDatabaseException`, `NoSuchTableException`, `NoSuchPartitionException` and `NoSuchFunctionException` now extends `AnalysisException`, as well as `ParseException`, they are all under `AnalysisException` umbrella, but you can also determine how to use them in a granular way. ## How was this patch tested? The existing test cases should cover this patch. Author: bomeng Closes #12314 from bomeng/SPARK-14414. --- .../catalyst/analysis/NoSuchItemException.scala | 31 ++++++---------------- .../apache/spark/sql/execution/command/ddl.scala | 1 + .../spark/sql/hive/HiveExternalCatalog.scala | 3 --- .../spark/sql/hive/execution/HiveQuerySuite.scala | 1 - 4 files changed, 9 insertions(+), 27 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index 96fd1a027e..5e18316c94 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec @@ -24,29 +25,13 @@ import org.apache.spark.sql.catalyst.catalog.ExternalCatalog.TablePartitionSpec * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. */ -abstract class NoSuchItemException extends Exception { - override def getMessage: String -} +class NoSuchDatabaseException(db: String) extends AnalysisException(s"Database $db not found") -class NoSuchDatabaseException(db: String) extends NoSuchItemException { - override def getMessage: String = s"Database $db not found" -} +class NoSuchTableException(db: String, table: String) + extends AnalysisException(s"Table or View $table not found in database $db") -class NoSuchTableException(db: String, table: String) extends NoSuchItemException { - override def getMessage: String = s"Table or View $table not found in database $db" -} +class NoSuchPartitionException(db: String, table: String, spec: TablePartitionSpec) extends + AnalysisException(s"Partition not found in table $table database $db:\n" + spec.mkString("\n")) -class NoSuchPartitionException( - db: String, - table: String, - spec: TablePartitionSpec) - extends NoSuchItemException { - - override def getMessage: String = { - s"Partition not found in table $table database $db:\n" + spec.mkString("\n") - } -} - -class NoSuchFunctionException(db: String, func: String) extends NoSuchItemException { - override def getMessage: String = s"Function $func not found in database $db" -} +class NoSuchFunctionException(db: String, func: String) + extends AnalysisException(s"Function $func not found in database $db") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 758a7e45d2..5137bd11d8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.types._ + // Note: The definition of these 
commands are based on the ones described in // https://cwiki.apache.org/confluence/display/Hive/LanguageManual+DDL diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala index 482f47428d..f627384253 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveExternalCatalog.scala @@ -25,7 +25,6 @@ import org.apache.thrift.TException import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.NoSuchItemException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.hive.client.HiveClient @@ -66,8 +65,6 @@ private[spark] class HiveExternalCatalog(client: HiveClient) extends ExternalCat try { body } catch { - case e: NoSuchItemException => - throw new AnalysisException(e.getMessage) case NonFatal(e) if isClientException(e) => throw new AnalysisException(e.getClass.getCanonicalName + ": " + e.getMessage) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala index 0c57ede9ed..af73baa1f3 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala @@ -28,7 +28,6 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.{SparkException, SparkFiles} import org.apache.spark.sql.{AnalysisException, DataFrame, Row} -import org.apache.spark.sql.catalyst.analysis.NoSuchDatabaseException import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.joins.BroadcastNestedLoopJoin -- cgit v1.2.3 From 372baf0479840695388515170e6eae0b3fc4125e Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Tue, 12 Apr 2016 17:26:37 -0700 Subject: [SPARK-14578] [SQL] Fix codegen for CreateExternalRow with nested wide schema ## What changes were proposed in this pull request? The wide schema, the expression of fields will be splitted into multiple functions, but the variable for loopVar can't be accessed in splitted functions, this PR change them as class member. ## How was this patch tested? Added regression test. Author: Davies Liu Closes #12338 from davies/nested_row. 
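To make the failure mode concrete, here is a small hand-written Scala analogy (not Spark's generated Java; the class and method names are invented). Once a long consumer body is split into several helper methods, a local loop variable is no longer visible inside those helpers, whereas a member variable is, which is what promoting `loopVar` via `ctx.addMutableState(...)` achieves in the diff below:

```scala
// Analogy for the codegen fix: the per-element work lives in a separate method
// (as happens when the generated code for a very wide schema is split into multiple
// functions), so the loop variables must be fields of the class, not locals.
class MapObjectsAnalogy {
  // Promoted to mutable members, mirroring the two addMutableState calls in the patch.
  private var loopIsNull: Boolean = false
  private var loopValue: Int = 0

  def processAll(input: Array[Integer]): Seq[Int] = {
    val out = Seq.newBuilder[Int]
    var i = 0
    while (i < input.length) {
      loopIsNull = input(i) == null
      loopValue = if (loopIsNull) 0 else input(i).intValue()
      // The split-out helper still sees loopIsNull / loopValue because they are fields;
      // had they been locals of processAll, it could not reference them at all.
      out += applyLambda()
      i += 1
    }
    out.result()
  }

  // Stands in for one of the split-out functions of the generated class.
  private def applyLambda(): Int = if (loopIsNull) -1 else loopValue * 2
}

// new MapObjectsAnalogy().processAll(Array[Integer](1, null, 3)) == Seq(2, -1, 6)
```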
--- .../apache/spark/sql/catalyst/expressions/objects.scala | 8 +++++--- .../spark/sql/execution/datasources/json/JsonSuite.scala | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index 28b6b2adf8..26b1ff39b3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -446,6 +446,8 @@ case class MapObjects private( override def genCode(ctx: CodegenContext, ev: ExprCode): String = { val javaType = ctx.javaType(dataType) val elementJavaType = ctx.javaType(loopVar.dataType) + ctx.addMutableState("boolean", loopVar.isNull, "") + ctx.addMutableState(elementJavaType, loopVar.value, "") val genInputData = inputData.gen(ctx) val genFunction = lambdaFunction.gen(ctx) val dataLength = ctx.freshName("dataLength") @@ -466,9 +468,9 @@ case class MapObjects private( } val loopNullCheck = if (primitiveElement) { - s"boolean ${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" + s"${loopVar.isNull} = ${genInputData.value}.isNullAt($loopIndex);" } else { - s"boolean ${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" + s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } s""" @@ -484,7 +486,7 @@ case class MapObjects private( int $loopIndex = 0; while ($loopIndex < $dataLength) { - $elementJavaType ${loopVar.value} = + ${loopVar.value} = ($elementJavaType)${genInputData.value}${itemAccessor(loopIndex)}; $loopNullCheck diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala index 2a18acb95b..e17340c70b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala @@ -1664,4 +1664,19 @@ class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData { ) } } + + test("wide nested json table") { + val nested = (1 to 100).map { i => + s""" + |"c$i": $i + """.stripMargin + }.mkString(", ") + val json = s""" + |{"a": [{$nested}], "b": [{$nested}]} + """.stripMargin + val rdd = sqlContext.sparkContext.makeRDD(Seq(json)) + val df = sqlContext.read.json(rdd) + assert(df.schema.size === 2) + df.collect() + } } -- cgit v1.2.3 From 7d2ed8cc030f3d84fea47fded072c320c3d87ca7 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Wed, 13 Apr 2016 11:08:34 -0700 Subject: [SPARK-14388][SQL] Implement CREATE TABLE ## What changes were proposed in this pull request? This patch implements the `CREATE TABLE` command using the `SessionCatalog`. Previously we handled only `CTAS` and `CREATE TABLE ... USING`. This requires us to refactor `CatalogTable` to accept various fields (e.g. bucket and skew columns) and pass them to Hive. WIP: Note that I haven't verified whether this actually works yet! But I believe it does. ## How was this patch tested? Tests will come in a future commit. Author: Andrew Or Author: Yin Huai Closes #12271 from andrewor14/create-table-ddl. 
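As a rough usage sketch (assuming a Hive-enabled SQLContext; the table, column, and property names are invented), the kind of plain Hive-style statement this command is meant to handle, without a CTAS clause or a `USING` data source clause, looks like:

```scala
// Routed through the new CreateTable command and the SessionCatalog; the clauses used
// here are a subset of the syntax documented in the command added to tables.scala below.
sqlContext.sql(
  """CREATE TABLE IF NOT EXISTS sales (
    |  id INT COMMENT 'record id',
    |  amount DOUBLE
    |)
    |COMMENT 'toy example table'
    |PARTITIONED BY (dt STRING)
    |STORED AS TEXTFILE
    |TBLPROPERTIES ('owner' = 'analytics')
  """.stripMargin)
```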
--- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 3 +- .../spark/sql/catalyst/catalog/interface.scala | 26 ++- .../sql/catalyst/catalog/CatalogTestCases.scala | 8 +- .../spark/sql/execution/SparkSqlParser.scala | 25 +-- .../apache/spark/sql/execution/command/ddl.scala | 23 --- .../spark/sql/execution/command/tables.scala | 80 ++++++++ .../sql/execution/command/DDLCommandSuite.scala | 20 +- .../spark/sql/execution/command/DDLSuite.scala | 2 - .../spark/sql/hive/thriftserver/CliSuite.scala | 8 +- .../hive/execution/HiveCompatibilitySuite.scala | 128 ++++++------ .../spark/sql/hive/HiveMetastoreCatalog.scala | 16 +- .../spark/sql/hive/client/HiveClientImpl.scala | 44 ++-- .../spark/sql/hive/execution/HiveSqlParser.scala | 223 +++++++++++++-------- .../spark/sql/hive/HiveDDLCommandSuite.scala | 217 ++++++++++++++++++-- .../spark/sql/hive/HiveMetastoreCatalogSuite.scala | 4 +- .../spark/sql/hive/InsertIntoHiveTableSuite.scala | 2 +- .../spark/sql/hive/execution/PruningSuite.scala | 2 +- .../spark/sql/hive/execution/SQLQuerySuite.scala | 2 +- 18 files changed, 571 insertions(+), 262 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 0e2cd39448..a937ad1eb7 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -272,8 +272,7 @@ createFileFormat ; fileFormat - : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? - (INPUTDRIVER inDriver=STRING OUTPUTDRIVER outDriver=STRING)? #tableFileFormat + : INPUTFORMAT inFmt=STRING OUTPUTFORMAT outFmt=STRING (SERDE serdeCls=STRING)? 
#tableFileFormat | identifier #genericFileFormat ; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 4ef59316ce..ad989a97e4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -220,14 +220,30 @@ case class CatalogTable( tableType: CatalogTableType, storage: CatalogStorageFormat, schema: Seq[CatalogColumn], - partitionColumns: Seq[CatalogColumn] = Seq.empty, - sortColumns: Seq[CatalogColumn] = Seq.empty, - numBuckets: Int = 0, + partitionColumnNames: Seq[String] = Seq.empty, + sortColumnNames: Seq[String] = Seq.empty, + bucketColumnNames: Seq[String] = Seq.empty, + numBuckets: Int = -1, createTime: Long = System.currentTimeMillis, - lastAccessTime: Long = System.currentTimeMillis, + lastAccessTime: Long = -1, properties: Map[String, String] = Map.empty, viewOriginalText: Option[String] = None, - viewText: Option[String] = None) { + viewText: Option[String] = None, + comment: Option[String] = None) { + + // Verify that the provided columns are part of the schema + private val colNames = schema.map(_.name).toSet + private def requireSubsetOfSchema(cols: Seq[String], colType: String): Unit = { + require(cols.toSet.subsetOf(colNames), s"$colType columns (${cols.mkString(", ")}) " + + s"must be a subset of schema (${colNames.mkString(", ")}) in table '$identifier'") + } + requireSubsetOfSchema(partitionColumnNames, "partition") + requireSubsetOfSchema(sortColumnNames, "sort") + requireSubsetOfSchema(bucketColumnNames, "bucket") + + /** Columns this table is partitioned by. */ + def partitionColumns: Seq[CatalogColumn] = + schema.filter { c => partitionColumnNames.contains(c.name) } /** Return the database this table was specified to belong to, assuming it exists. */ def database: String = identifier.database.getOrElse { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala index 0d9b0851fa..f961fe3292 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/CatalogTestCases.scala @@ -553,8 +553,12 @@ abstract class CatalogTestUtils { identifier = TableIdentifier(name, database), tableType = CatalogTableType.EXTERNAL_TABLE, storage = storageFormat, - schema = Seq(CatalogColumn("col1", "int"), CatalogColumn("col2", "string")), - partitionColumns = Seq(CatalogColumn("a", "int"), CatalogColumn("b", "string"))) + schema = Seq( + CatalogColumn("col1", "int"), + CatalogColumn("col2", "string"), + CatalogColumn("a", "int"), + CatalogColumn("b", "string")), + partitionColumnNames = Seq("a", "b")) } def newFunc(name: String, database: Option[String] = None): CatalogFunction = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 73d9640c35..af92cecee5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -179,7 +179,9 @@ class SparkSqlAstBuilder extends AstBuilder { } } - /** Type to keep track of a table header. 
*/ + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + */ type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) /** @@ -616,10 +618,7 @@ class SparkSqlAstBuilder extends AstBuilder { case s: GenericFileFormatContext => (Seq.empty[String], Option(s.identifier.getText)) case s: TableFileFormatContext => - val elements = Seq(s.inFmt, s.outFmt) ++ - Option(s.serdeCls).toSeq ++ - Option(s.inDriver).toSeq ++ - Option(s.outDriver).toSeq + val elements = Seq(s.inFmt, s.outFmt) ++ Option(s.serdeCls).toSeq (elements.map(string), None) } AlterTableSetFileFormat( @@ -773,22 +772,6 @@ class SparkSqlAstBuilder extends AstBuilder { .map(_.identifier.getText)) } - /** - * Create a skew specification. This contains three components: - * - The Skewed Columns - * - Values for which are skewed. The size of each entry must match the number of skewed columns. - * - A store in directory flag. - */ - override def visitSkewSpec( - ctx: SkewSpecContext): (Seq[String], Seq[Seq[String]], Boolean) = withOrigin(ctx) { - val skewedValues = if (ctx.constantList != null) { - Seq(visitConstantList(ctx.constantList)) - } else { - visitNestedConstantList(ctx.nestedConstantList) - } - (visitIdentifierList(ctx.identifierList), skewedValues, ctx.DIRECTORIES != null) - } - /** * Convert a nested constants list into a sequence of string sequences. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index 5137bd11d8..234099ad15 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -224,29 +224,6 @@ case class DropTable( } } -/** - * A command that renames a table/view. - * - * The syntax of this command is: - * {{{ - * ALTER TABLE table1 RENAME TO table2; - * ALTER VIEW view1 RENAME TO view2; - * }}} - */ -case class AlterTableRename( - oldName: TableIdentifier, - newName: TableIdentifier) - extends RunnableCommand { - - override def run(sqlContext: SQLContext): Seq[Row] = { - val catalog = sqlContext.sessionState.catalog - catalog.invalidateTable(oldName) - catalog.renameTable(oldName, newName) - Seq.empty[Row] - } - -} - /** * A command that sets table/view properties. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala new file mode 100644 index 0000000000..9c6030502d --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -0,0 +1,80 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.CatalogTable + + +// TODO: move the rest of the table commands from ddl.scala to this file + +/** + * A command to create a table. + * + * Note: This is currently used only for creating Hive tables. + * This is not intended for temporary tables. + * + * The syntax of using this command in SQL is: + * {{{ + * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * [(col1 data_type [COMMENT col_comment], ...)] + * [COMMENT table_comment] + * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)] + * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS] + * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) + * [STORED AS DIRECTORIES] + * [ROW FORMAT row_format] + * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] + * [AS select_statement]; + * }}} + */ +case class CreateTable(table: CatalogTable, ifNotExists: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + sqlContext.sessionState.catalog.createTable(table, ifNotExists) + Seq.empty[Row] + } + +} + + +/** + * A command that renames a table/view. + * + * The syntax of this command is: + * {{{ + * ALTER TABLE table1 RENAME TO table2; + * ALTER VIEW view1 RENAME TO view2; + * }}} + */ +case class AlterTableRename( + oldName: TableIdentifier, + newName: TableIdentifier) + extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + catalog.invalidateTable(oldName) + catalog.renameTable(oldName, newName) + Seq.empty[Row] + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala index 1c8dd68286..6e6475ee29 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLCommandSuite.scala @@ -440,37 +440,25 @@ class DDLCommandSuite extends PlanTest { } test("alter table: set file format") { - val sql1 = - """ - |ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' - |OUTPUTFORMAT 'test' SERDE 'test' INPUTDRIVER 'test' OUTPUTDRIVER 'test' - """.stripMargin - val sql2 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " + + val sql1 = "ALTER TABLE table_name SET FILEFORMAT INPUTFORMAT 'test' " + "OUTPUTFORMAT 'test' SERDE 'test'" - val sql3 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + + val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + "SET FILEFORMAT PARQUET" val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) - val parsed3 = parser.parsePlan(sql3) val tableIdent = TableIdentifier("table_name", None) val expected1 = AlterTableSetFileFormat( tableIdent, None, - List("test", "test", "test", "test", "test"), + List("test", "test", "test"), None)(sql1) val expected2 = AlterTableSetFileFormat( - tableIdent, - None, - List("test", "test", "test"), - None)(sql2) - val expected3 = AlterTableSetFileFormat( tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")), Seq(), - Some("PARQUET"))(sql3) + Some("PARQUET"))(sql2) 
comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) } test("alter table: set location") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 40a8b0e614..9ffffa0bdd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -380,8 +380,6 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } - // TODO: ADD a testcase for Drop Database in Restric when we can create tables in SQLContext - test("show tables") { withTempTable("show1a", "show2b") { sql( diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala index bfc3d195ff..eb49eabcb1 100644 --- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala +++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/CliSuite.scala @@ -162,7 +162,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { runCliWithin(3.minute)( "CREATE TABLE hive_test(key INT, val STRING);" - -> "OK", + -> "", "SHOW TABLES;" -> "hive_test", s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE hive_test;" @@ -187,7 +187,7 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { "USE hive_test_db;" -> "", "CREATE TABLE hive_test(key INT, val STRING);" - -> "OK", + -> "", "SHOW TABLES;" -> "hive_test" ) @@ -210,9 +210,9 @@ class CliSuite extends SparkFunSuite with BeforeAndAfterAll with Logging { """CREATE TABLE t1(key string, val string) |ROW FORMAT SERDE 'org.apache.hive.hcatalog.data.JsonSerDe'; """.stripMargin - -> "OK", + -> "", "CREATE TABLE sourceTable (key INT, val STRING);" - -> "OK", + -> "", s"LOAD DATA LOCAL INPATH '$dataFilePath' OVERWRITE INTO TABLE sourceTable;" -> "OK", "INSERT INTO TABLE t1 SELECT key, val FROM sourceTable;" diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index f0eeda09db..a45d180464 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -366,10 +366,76 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "sort_merge_join_desc_6", "sort_merge_join_desc_7", + // These tests try to create a table with bucketed columns, which we don't support + "auto_join32", + "auto_join_filters", + "auto_smb_mapjoin_14", + "ct_case_insensitive", + "explain_rearrange", + "groupby_sort_10", + "groupby_sort_2", + "groupby_sort_3", + "groupby_sort_4", + "groupby_sort_5", + "groupby_sort_7", + "groupby_sort_8", + "groupby_sort_9", + "groupby_sort_test_1", + "inputddl4", + "join_filters", + "join_nulls", + "join_nullsafe", + "load_dyn_part2", + "orc_empty_files", + "reduce_deduplicate", + "smb_mapjoin9", + "smb_mapjoin_1", + "smb_mapjoin_10", + "smb_mapjoin_13", + "smb_mapjoin_14", + "smb_mapjoin_15", + "smb_mapjoin_16", + "smb_mapjoin_17", + "smb_mapjoin_2", + "smb_mapjoin_21", + "smb_mapjoin_25", + "smb_mapjoin_3", + 
"smb_mapjoin_4", + "smb_mapjoin_5", + "smb_mapjoin_6", + "smb_mapjoin_7", + "smb_mapjoin_8", + "sort_merge_join_desc_1", + "sort_merge_join_desc_2", + "sort_merge_join_desc_3", + "sort_merge_join_desc_4", + + // These tests try to create a table with skewed columns, which we don't support + "create_skewed_table1", + "skewjoinopt13", + "skewjoinopt18", + "skewjoinopt9", + // Index commands are not supported "drop_index", "drop_index_removes_partition_dirs", "alter_index", + "auto_sortmerge_join_1", + "auto_sortmerge_join_10", + "auto_sortmerge_join_11", + "auto_sortmerge_join_12", + "auto_sortmerge_join_13", + "auto_sortmerge_join_14", + "auto_sortmerge_join_15", + "auto_sortmerge_join_16", + "auto_sortmerge_join_2", + "auto_sortmerge_join_3", + "auto_sortmerge_join_4", + "auto_sortmerge_join_5", + "auto_sortmerge_join_6", + "auto_sortmerge_join_7", + "auto_sortmerge_join_8", + "auto_sortmerge_join_9", // Macro commands are not supported "macro", @@ -435,33 +501,14 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "auto_join3", "auto_join30", "auto_join31", - "auto_join32", "auto_join4", "auto_join5", "auto_join6", "auto_join7", "auto_join8", "auto_join9", - "auto_join_filters", "auto_join_nulls", "auto_join_reordering_values", - "auto_smb_mapjoin_14", - "auto_sortmerge_join_1", - "auto_sortmerge_join_10", - "auto_sortmerge_join_11", - "auto_sortmerge_join_12", - "auto_sortmerge_join_13", - "auto_sortmerge_join_14", - "auto_sortmerge_join_15", - "auto_sortmerge_join_16", - "auto_sortmerge_join_2", - "auto_sortmerge_join_3", - "auto_sortmerge_join_4", - "auto_sortmerge_join_5", - "auto_sortmerge_join_6", - "auto_sortmerge_join_7", - "auto_sortmerge_join_8", - "auto_sortmerge_join_9", "binary_constant", "binarysortable_1", "cast1", @@ -492,13 +539,11 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "create_insert_outputformat", "create_like_tbl_props", "create_nested_type", - "create_skewed_table1", "create_struct_table", "create_view_translate", "cross_join", "cross_product_check_1", "cross_product_check_2", - "ct_case_insensitive", "database_drop", "database_location", "database_properties", @@ -534,7 +579,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "escape_distributeby1", "escape_orderby1", "escape_sortby1", - "explain_rearrange", "fileformat_mix", "fileformat_sequencefile", "fileformat_text", @@ -589,16 +633,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "groupby_neg_float", "groupby_ppd", "groupby_ppr", - "groupby_sort_10", - "groupby_sort_2", - "groupby_sort_3", - "groupby_sort_4", - "groupby_sort_5", "groupby_sort_6", - "groupby_sort_7", - "groupby_sort_8", - "groupby_sort_9", - "groupby_sort_test_1", "having", "implicit_cast1", "index_serde", @@ -653,7 +688,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "inputddl1", "inputddl2", "inputddl3", - "inputddl4", "inputddl6", "inputddl7", "inputddl8", @@ -709,11 +743,8 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "join_array", "join_casesensitive", "join_empty", - "join_filters", "join_hive_626", "join_map_ppr", - "join_nulls", - "join_nullsafe", "join_rc", "join_reorder2", "join_reorder3", @@ -737,7 +768,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "load_dyn_part13", "load_dyn_part14", "load_dyn_part14_win", - "load_dyn_part2", "load_dyn_part3", "load_dyn_part4", "load_dyn_part5", @@ -790,7 +820,6 @@ 
class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "nullscript", "optional_outer", "orc_dictionary_threshold", - "orc_empty_files", "order", "order2", "outer_join_ppr", @@ -846,7 +875,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "rcfile_null_value", "rcfile_toleratecorruptions", "rcfile_union", - "reduce_deduplicate", "reduce_deduplicate_exclude_gby", "reduce_deduplicate_exclude_join", "reduce_deduplicate_extended", @@ -867,31 +895,7 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "show_functions", "show_partitions", "show_tblproperties", - "skewjoinopt13", - "skewjoinopt18", - "skewjoinopt9", - "smb_mapjoin9", - "smb_mapjoin_1", - "smb_mapjoin_10", - "smb_mapjoin_13", - "smb_mapjoin_14", - "smb_mapjoin_15", - "smb_mapjoin_16", - "smb_mapjoin_17", - "smb_mapjoin_2", - "smb_mapjoin_21", - "smb_mapjoin_25", - "smb_mapjoin_3", - "smb_mapjoin_4", - "smb_mapjoin_5", - "smb_mapjoin_6", - "smb_mapjoin_7", - "smb_mapjoin_8", "sort", - "sort_merge_join_desc_1", - "sort_merge_join_desc_2", - "sort_merge_join_desc_3", - "sort_merge_join_desc_4", "stats0", "stats_aggregator_error_1", "stats_empty_partition", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala index 14f331961e..ccc8345d73 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveMetastoreCatalog.scala @@ -91,7 +91,7 @@ private[hive] object HiveSerDe { "textfile" -> HiveSerDe( inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")), + outputFormat = Option("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), "avro" -> HiveSerDe( @@ -905,8 +905,13 @@ private[hive] case class MetastoreRelation( val sd = new org.apache.hadoop.hive.metastore.api.StorageDescriptor() tTable.setSd(sd) - sd.setCols(table.schema.map(toHiveColumn).asJava) - tTable.setPartitionKeys(table.partitionColumns.map(toHiveColumn).asJava) + + // Note: In Hive the schema and partition columns must be disjoint sets + val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => + table.partitionColumnNames.contains(c.getName) + } + sd.setCols(schema.asJava) + tTable.setPartitionKeys(partCols.asJava) table.storage.locationUri.foreach(sd.setLocation) table.storage.inputFormat.foreach(sd.setInputFormat) @@ -1013,7 +1018,10 @@ private[hive] case class MetastoreRelation( val partitionKeys = table.partitionColumns.map(_.toAttribute) /** Non-partitionKey attributes */ - val attributes = table.schema.map(_.toAttribute) + // TODO: just make this hold the schema itself, not just non-partition columns + val attributes = table.schema + .filter { c => !table.partitionColumnNames.contains(c.name) } + .map(_.toAttribute) val output = attributes ++ partitionKeys diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala index 39e26acd7f..2a1fff92b5 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/client/HiveClientImpl.scala @@ -299,6 +299,10 @@ private[hive] class HiveClientImpl( tableName: String): Option[CatalogTable] = withHiveState { logDebug(s"Looking up 
$dbName.$tableName") Option(client.getTable(dbName, tableName, false)).map { h => + // Note: Hive separates partition columns and the schema, but for us the + // partition columns are part of the schema + val partCols = h.getPartCols.asScala.map(fromHiveColumn) + val schema = h.getCols.asScala.map(fromHiveColumn) ++ partCols CatalogTable( identifier = TableIdentifier(h.getTableName, Option(h.getDbName)), tableType = h.getTableType match { @@ -307,9 +311,10 @@ private[hive] class HiveClientImpl( case HiveTableType.INDEX_TABLE => CatalogTableType.INDEX_TABLE case HiveTableType.VIRTUAL_VIEW => CatalogTableType.VIRTUAL_VIEW }, - schema = h.getCols.asScala.map(fromHiveColumn), - partitionColumns = h.getPartCols.asScala.map(fromHiveColumn), - sortColumns = Seq(), + schema = schema, + partitionColumnNames = partCols.map(_.name), + sortColumnNames = Seq(), // TODO: populate this + bucketColumnNames = h.getBucketCols.asScala, numBuckets = h.getNumBuckets, createTime = h.getTTable.getCreateTime.toLong * 1000, lastAccessTime = h.getLastAccessTime.toLong * 1000, @@ -675,24 +680,37 @@ private[hive] class HiveClientImpl( private def toHiveTable(table: CatalogTable): HiveTable = { val hiveTable = new HiveTable(table.database, table.identifier.table) - // For EXTERNAL_TABLE/MANAGED_TABLE, we also need to set EXTERNAL field in - // the table properties accodringly. Otherwise, if EXTERNAL_TABLE is the table type - // but EXTERNAL field is not set, Hive metastore will change the type to - // MANAGED_TABLE (see - // metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105) + // For EXTERNAL_TABLE, we also need to set EXTERNAL field in the table properties. + // Otherwise, Hive metastore will change the table to a MANAGED_TABLE. + // (metastore/src/java/org/apache/hadoop/hive/metastore/ObjectStore.java#L1095-L1105) hiveTable.setTableType(table.tableType match { case CatalogTableType.EXTERNAL_TABLE => hiveTable.setProperty("EXTERNAL", "TRUE") HiveTableType.EXTERNAL_TABLE case CatalogTableType.MANAGED_TABLE => - hiveTable.setProperty("EXTERNAL", "FALSE") HiveTableType.MANAGED_TABLE case CatalogTableType.INDEX_TABLE => HiveTableType.INDEX_TABLE case CatalogTableType.VIRTUAL_VIEW => HiveTableType.VIRTUAL_VIEW }) - hiveTable.setFields(table.schema.map(toHiveColumn).asJava) - hiveTable.setPartCols(table.partitionColumns.map(toHiveColumn).asJava) + // Note: In Hive the schema and partition columns must be disjoint sets + val (partCols, schema) = table.schema.map(toHiveColumn).partition { c => + table.partitionColumnNames.contains(c.getName) + } + if (table.schema.isEmpty) { + // This is a hack to preserve existing behavior. Before Spark 2.0, we do not + // set a default serde here (this was done in Hive), and so if the user provides + // an empty schema Hive would automatically populate the schema with a single + // field "col". However, after SPARK-14388, we set the default serde to + // LazySimpleSerde so this implicit behavior no longer happens. Therefore, + // we need to do it in Spark ourselves. 
+ hiveTable.setFields( + Seq(new FieldSchema("col", "array", "from deserializer")).asJava) + } else { + hiveTable.setFields(schema.asJava) + } + hiveTable.setPartCols(partCols.asJava) // TODO: set sort columns here too + hiveTable.setBucketCols(table.bucketColumnNames.asJava) hiveTable.setOwner(conf.getUser) hiveTable.setNumBuckets(table.numBuckets) hiveTable.setCreateTime((table.createTime / 1000).toInt) @@ -700,9 +718,11 @@ private[hive] class HiveClientImpl( table.storage.locationUri.foreach { loc => shim.setDataLocation(hiveTable, loc) } table.storage.inputFormat.map(toInputFormat).foreach(hiveTable.setInputFormatClass) table.storage.outputFormat.map(toOutputFormat).foreach(hiveTable.setOutputFormatClass) - table.storage.serde.foreach(hiveTable.setSerializationLib) + hiveTable.setSerializationLib( + table.storage.serde.getOrElse("org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe")) table.storage.serdeProperties.foreach { case (k, v) => hiveTable.setSerdeParam(k, v) } table.properties.foreach { case (k, v) => hiveTable.setProperty(k, v) } + table.comment.foreach { c => hiveTable.setProperty("comment", c) } table.viewOriginalText.foreach { t => hiveTable.setViewOriginalText(t) } table.viewText.foreach { t => hiveTable.setViewExpandedText(t) } hiveTable diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 7a435117e7..b14db7fe71 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.hive.ql.session.SessionState import org.apache.hadoop.hive.serde.serdeConstants import org.apache.hadoop.hive.serde2.`lazy`.LazySimpleSerDe +import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.UnresolvedGenerator import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions._ @@ -33,8 +34,9 @@ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkSqlAstBuilder +import org.apache.spark.sql.execution.command.CreateTable import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView} -import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveSerDe} +import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveMetastoreTypes, HiveSerDe} import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper /** @@ -121,84 +123,116 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } /** - * Create a [[CatalogStorageFormat]]. This is part of the [[CreateTableAsSelect]] command. + * Create a [[CatalogStorageFormat]] for creating tables. 
*/ override def visitCreateFileFormat( ctx: CreateFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - if (ctx.storageHandler == null) { - typedVisit[CatalogStorageFormat](ctx.fileFormat) - } else { - visitStorageHandler(ctx.storageHandler) + (ctx.fileFormat, ctx.storageHandler) match { + // Expected format: INPUTFORMAT input_format OUTPUTFORMAT output_format + case (c: TableFileFormatContext, null) => + visitTableFileFormat(c) + // Expected format: SEQUENCEFILE | TEXTFILE | RCFILE | ORC | PARQUET | AVRO + case (c: GenericFileFormatContext, null) => + visitGenericFileFormat(c) + case (null, storageHandler) => + throw new ParseException("Operation not allowed: ... STORED BY storage_handler ...", ctx) + case _ => + throw new ParseException("expected either STORED AS or STORED BY, not both", ctx) } } /** - * Create a [[CreateTableAsSelect]] command. + * Create a table, returning either a [[CreateTable]] or a [[CreateTableAsSelect]]. + * + * This is not used to create datasource tables, which is handled through + * "CREATE TABLE ... USING ...". + * + * Note: several features are currently not supported - temporary tables, bucketing, + * skewed columns and storage handlers (STORED BY). + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name + * [(col1 data_type [COMMENT col_comment], ...)] + * [COMMENT table_comment] + * [PARTITIONED BY (col3 data_type [COMMENT col_comment], ...)] + * [CLUSTERED BY (col1, ...) [SORTED BY (col1 [ASC|DESC], ...)] INTO num_buckets BUCKETS] + * [SKEWED BY (col1, col2, ...) ON ((col_value, col_value, ...), ...) [STORED AS DIRECTORIES]] + * [ROW FORMAT row_format] + * [STORED AS file_format | STORED BY storage_handler_class [WITH SERDEPROPERTIES (...)]] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] + * [AS select_statement]; + * }}} */ - override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = { - if (ctx.query == null) { - HiveNativeCommand(command(ctx)) + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + // TODO: implement temporary tables + if (temp) { + throw new ParseException( + "CREATE TEMPORARY TABLE is not supported yet. " + + "Please use registerTempTable as an alternative.", ctx) + } + if (ctx.skewSpec != null) { + throw new ParseException("Operation not allowed: CREATE TABLE ... SKEWED BY ...", ctx) + } + if (ctx.bucketSpec != null) { + throw new ParseException("Operation not allowed: CREATE TABLE ... CLUSTERED BY ...", ctx) + } + val tableType = if (external) { + CatalogTableType.EXTERNAL_TABLE } else { - // Get the table header. - val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - val tableType = if (external) { - CatalogTableType.EXTERNAL_TABLE - } else { - CatalogTableType.MANAGED_TABLE - } - - // Unsupported clauses. - if (temp) { - throw new ParseException(s"Unsupported operation: TEMPORARY clause.", ctx) - } - if (ctx.bucketSpec != null) { - // TODO add this - we need cluster columns in the CatalogTable for this to work. - throw new ParseException("Unsupported operation: " + - "CLUSTERED BY ... [ORDERED BY ...] INTO ... BUCKETS clause.", ctx) - } - if (ctx.skewSpec != null) { - throw new ParseException("Operation not allowed: " + - "SKEWED BY ... ON ... [STORED AS DIRECTORIES] clause.", ctx) - } - - // Create the schema. 
- val schema = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns(_, _.toLowerCase)) - - // Get the column by which the table is partitioned. - val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns(_)) - - // Create the storage. - def format(fmt: ParserRuleContext): CatalogStorageFormat = { - Option(fmt).map(typedVisit[CatalogStorageFormat]).getOrElse(EmptyStorageFormat) - } - // Default storage. + CatalogTableType.MANAGED_TABLE + } + val comment = Option(ctx.STRING).map(string) + val partitionCols = Option(ctx.partitionColumns).toSeq.flatMap(visitCatalogColumns) + val cols = Option(ctx.columns).toSeq.flatMap(visitCatalogColumns) + val properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty) + val selectQuery = Option(ctx.query).map(plan) + + // Note: Hive requires partition columns to be distinct from the schema, so we need + // to include the partition columns here explicitly + val schema = cols ++ partitionCols + + // Storage format + val defaultStorage: CatalogStorageFormat = { val defaultStorageType = hiveConf.getVar(HiveConf.ConfVars.HIVEDEFAULTFILEFORMAT) - val hiveSerDe = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf).getOrElse { - HiveSerDe( - inputFormat = Option("org.apache.hadoop.mapred.TextInputFormat"), - outputFormat = Option("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) - } - // Defined storage. - val fileStorage = format(ctx.createFileFormat) - val rowStorage = format(ctx.rowFormat) - val storage = CatalogStorageFormat( - Option(ctx.locationSpec).map(visitLocationSpec), - fileStorage.inputFormat.orElse(hiveSerDe.inputFormat), - fileStorage.outputFormat.orElse(hiveSerDe.outputFormat), - rowStorage.serde.orElse(hiveSerDe.serde).orElse(fileStorage.serde), - rowStorage.serdeProperties ++ fileStorage.serdeProperties - ) + val defaultHiveSerde = HiveSerDe.sourceToSerDe(defaultStorageType, hiveConf) + CatalogStorageFormat( + locationUri = None, + inputFormat = defaultHiveSerde.flatMap(_.inputFormat) + .orElse(Some("org.apache.hadoop.mapred.TextInputFormat")), + outputFormat = defaultHiveSerde.flatMap(_.outputFormat) + .orElse(Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")), + // Note: Keep this unspecified because we use the presence of the serde to decide + // whether to convert a table created by CTAS to a datasource table. + serde = None, + serdeProperties = Map()) + } + val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + .getOrElse(EmptyStorageFormat) + val rowStorage = Option(ctx.rowFormat).map(visitRowFormat).getOrElse(EmptyStorageFormat) + val location = Option(ctx.locationSpec).map(visitLocationSpec) + val storage = CatalogStorageFormat( + locationUri = location, + inputFormat = fileStorage.inputFormat.orElse(defaultStorage.inputFormat), + outputFormat = fileStorage.outputFormat.orElse(defaultStorage.outputFormat), + serde = rowStorage.serde.orElse(fileStorage.serde).orElse(defaultStorage.serde), + serdeProperties = rowStorage.serdeProperties ++ fileStorage.serdeProperties) + + // TODO support the sql text - have a proper location for this! 
+ val tableDesc = CatalogTable( + identifier = name, + tableType = tableType, + storage = storage, + schema = schema, + partitionColumnNames = partitionCols.map(_.name), + properties = properties, + comment = comment) - val tableDesc = CatalogTable( - identifier = table, - tableType = tableType, - schema = schema, - partitionColumns = partitionCols, - storage = storage, - properties = Option(ctx.tablePropertyList).map(visitTablePropertyList).getOrElse(Map.empty), - // TODO support the sql text - have a proper location for this! - viewText = Option(ctx.STRING).map(string)) - CTAS(tableDesc, plan(ctx.query), ifNotExists) + selectQuery match { + case Some(q) => CTAS(tableDesc, q, ifNotExists) + case None => CreateTable(tableDesc, ifNotExists) } } @@ -353,25 +387,19 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { private val EmptyStorageFormat = CatalogStorageFormat(None, None, None, None, Map.empty) /** - * Create a [[CatalogStorageFormat]]. The INPUTDRIVER and OUTPUTDRIVER clauses are currently - * ignored. + * Create a [[CatalogStorageFormat]]. */ override def visitTableFileFormat( ctx: TableFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { - import ctx._ - if (inDriver != null || outDriver != null) { - throw new ParseException( - s"Operation not allowed: INPUTDRIVER ... OUTPUTDRIVER ... clauses", ctx) - } EmptyStorageFormat.copy( - inputFormat = Option(string(inFmt)), - outputFormat = Option(string(outFmt)), - serde = Option(serdeCls).map(string) + inputFormat = Option(string(ctx.inFmt)), + outputFormat = Option(string(ctx.outFmt)), + serde = Option(ctx.serdeCls).map(string) ) } /** - * Resolve a [[HiveSerDe]] based on the format name given. + * Resolve a [[HiveSerDe]] based on the name given and return it as a [[CatalogStorageFormat]]. */ override def visitGenericFileFormat( ctx: GenericFileFormatContext): CatalogStorageFormat = withOrigin(ctx) { @@ -388,11 +416,28 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } /** - * Storage Handlers are currently not supported in the statements we support (CTAS). + * Create a [[RowFormat]] used for creating tables. + * + * Example format: + * {{{ + * SERDE serde_name [WITH SERDEPROPERTIES (k1=v1, k2=v2, ...)] + * }}} + * + * OR + * + * {{{ + * DELIMITED [FIELDS TERMINATED BY char [ESCAPED BY char]] + * [COLLECTION ITEMS TERMINATED BY char] + * [MAP KEYS TERMINATED BY char] + * [LINES TERMINATED BY char] + * [NULL DEFINED AS char] + * }}} */ - override def visitStorageHandler( - ctx: StorageHandlerContext): CatalogStorageFormat = withOrigin(ctx) { - throw new ParseException("Storage Handlers are currently unsupported.", ctx) + private def visitRowFormat(ctx: RowFormatContext): CatalogStorageFormat = withOrigin(ctx) { + ctx match { + case serde: RowFormatSerdeContext => visitRowFormatSerde(serde) + case delimited: RowFormatDelimitedContext => visitRowFormatDelimited(delimited) + } } /** @@ -435,13 +480,15 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { /** * Create a sequence of [[CatalogColumn]]s from a column list */ - private def visitCatalogColumns( - ctx: ColTypeListContext, - formatter: String => String = identity): Seq[CatalogColumn] = withOrigin(ctx) { + private def visitCatalogColumns(ctx: ColTypeListContext): Seq[CatalogColumn] = withOrigin(ctx) { ctx.colType.asScala.map { col => CatalogColumn( - formatter(col.identifier.getText), - col.dataType.getText.toLowerCase, // TODO validate this? 
+ col.identifier.getText.toLowerCase, + // Note: for types like "STRUCT" we can't + // just convert the whole type string to lower case, otherwise the struct field names + // will no longer be case sensitive. Instead, we rely on our parser to get the proper + // case before passing it to Hive. + CatalystSqlParser.parseDataType(col.dataType.getText).simpleString, nullable = true, Option(col.STRING).map(string)) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index e8086aec32..68d3ea6ed9 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} +import org.apache.spark.sql.execution.command.CreateTable import org.apache.spark.sql.hive.execution.{HiveNativeCommand, HiveSqlParser} class HiveDDLCommandSuite extends PlanTest { @@ -36,6 +37,7 @@ class HiveDDLCommandSuite extends PlanTest { private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parser.parsePlan(sql).collect { + case CreateTable(desc, allowExisting) => (desc, allowExisting) case CreateTableAsSelect(desc, _, allowExisting) => (desc, allowExisting) case CreateViewAsSelect(desc, _, allowExisting, _, _) => (desc, allowExisting) }.head @@ -76,9 +78,12 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("page_url", "string") :: CatalogColumn("referrer_url", "string") :: CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) + CatalogColumn("country", "string", comment = Some("country of origination")) :: + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) + assert(desc.comment == Some("This is the staging page view table")) // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.viewText.isEmpty) assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: @@ -123,9 +128,12 @@ class HiveDDLCommandSuite extends PlanTest { CatalogColumn("page_url", "string") :: CatalogColumn("referrer_url", "string") :: CatalogColumn("ip", "string", comment = Some("IP Address of the User")) :: - CatalogColumn("country", "string", comment = Some("country of origination")) :: Nil) + CatalogColumn("country", "string", comment = Some("country of origination")) :: + CatalogColumn("dt", "string", comment = Some("date type")) :: + CatalogColumn("hour", "string", comment = Some("hour of the day")) :: Nil) // TODO will be SQLText - assert(desc.viewText == Option("This is the staging page view table")) + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) assert(desc.viewOriginalText.isEmpty) assert(desc.partitionColumns == CatalogColumn("dt", "string", comment = Some("date type")) :: @@ -151,7 +159,7 @@ class HiveDDLCommandSuite extends PlanTest { assert(desc.storage.serdeProperties == Map()) assert(desc.storage.inputFormat == Some("org.apache.hadoop.mapred.TextInputFormat")) 
assert(desc.storage.outputFormat == - Some("org.apache.hadoop.hive.ql.io.IgnoreKeyTextOutputFormat")) + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) assert(desc.storage.serde.isEmpty) assert(desc.properties == Map()) } @@ -203,17 +211,6 @@ class HiveDDLCommandSuite extends PlanTest { |AS SELECT key, value FROM src ORDER BY key, value """.stripMargin) } - intercept[ParseException] { - parser.parsePlan( - """CREATE TABLE ctas2 - |STORED AS - |INPUTFORMAT "org.apache.hadoop.mapred.TextInputFormat" - |OUTPUTFORMAT "org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat" - |INPUTDRIVER "org.apache.hadoop.hive.howl.rcfile.RCFileInputDriver" - |OUTPUTDRIVER "org.apache.hadoop.hive.howl.rcfile.RCFileOutputDriver" - |AS SELECT key, value FROM src ORDER BY key, value - """.stripMargin) - } intercept[ParseException] { parser.parsePlan( """ @@ -324,6 +321,194 @@ class HiveDDLCommandSuite extends PlanTest { """.stripMargin) } + test("create table - basic") { + val query = "CREATE TABLE my_table (id int, name string)" + val (desc, allowExisting) = extractTableDesc(query) + assert(!allowExisting) + assert(desc.identifier.database.isEmpty) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.MANAGED_TABLE) + assert(desc.schema == Seq(CatalogColumn("id", "int"), CatalogColumn("name", "string"))) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.sortColumnNames.isEmpty) + assert(desc.bucketColumnNames.isEmpty) + assert(desc.numBuckets == -1) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.locationUri.isEmpty) + assert(desc.storage.inputFormat == + Some("org.apache.hadoop.mapred.TextInputFormat")) + assert(desc.storage.outputFormat == + Some("org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat")) + assert(desc.storage.serde.isEmpty) + assert(desc.storage.serdeProperties.isEmpty) + assert(desc.properties.isEmpty) + assert(desc.comment.isEmpty) + } + + test("create table - with database name") { + val query = "CREATE TABLE dbx.my_table (id int, name string)" + val (desc, _) = extractTableDesc(query) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + } + + test("create table - temporary") { + val query = "CREATE TEMPORARY TABLE tab1 (id int, name string)" + val e = intercept[ParseException] { parser.parsePlan(query) } + assert(e.message.contains("registerTempTable")) + } + + test("create table - external") { + val query = "CREATE EXTERNAL TABLE tab1 (id int, name string)" + val (desc, _) = extractTableDesc(query) + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + } + + test("create table - if not exists") { + val query = "CREATE TABLE IF NOT EXISTS tab1 (id int, name string)" + val (_, allowExisting) = extractTableDesc(query) + assert(allowExisting) + } + + test("create table - comment") { + val query = "CREATE TABLE my_table (id int, name string) COMMENT 'its hot as hell below'" + val (desc, _) = extractTableDesc(query) + assert(desc.comment == Some("its hot as hell below")) + } + + test("create table - partitioned columns") { + val query = "CREATE TABLE my_table (id int, name string) PARTITIONED BY (month int)" + val (desc, _) = extractTableDesc(query) + assert(desc.schema == Seq( + CatalogColumn("id", "int"), + CatalogColumn("name", "string"), + CatalogColumn("month", "int"))) + assert(desc.partitionColumnNames == Seq("month")) + } + + test("create table - clustered by") { + val baseQuery = "CREATE TABLE my_table (id int, name 
string) CLUSTERED BY(id)" + val query1 = s"$baseQuery INTO 10 BUCKETS" + val query2 = s"$baseQuery SORTED BY(id) INTO 10 BUCKETS" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - skewed by") { + val baseQuery = "CREATE TABLE my_table (id int, name string) SKEWED BY" + val query1 = s"$baseQuery(id) ON (1, 10, 100)" + val query2 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z'))" + val query3 = s"$baseQuery(id, name) ON ((1, 'x'), (2, 'y'), (3, 'z')) STORED AS DIRECTORIES" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + val e3 = intercept[ParseException] { parser.parsePlan(query3) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + assert(e3.getMessage.contains("Operation not allowed")) + } + + test("create table - row format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) ROW FORMAT" + val query1 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff'" + val query2 = s"$baseQuery SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1')" + val query3 = + s""" + |$baseQuery DELIMITED FIELDS TERMINATED BY 'x' ESCAPED BY 'y' + |COLLECTION ITEMS TERMINATED BY 'a' + |MAP KEYS TERMINATED BY 'b' + |LINES TERMINATED BY '\n' + |NULL DEFINED AS 'c' + """.stripMargin + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + val (desc3, _) = extractTableDesc(query3) + assert(desc1.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc1.storage.serdeProperties.isEmpty) + assert(desc2.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc2.storage.serdeProperties == Map("k1" -> "v1")) + assert(desc3.storage.serdeProperties == Map( + "field.delim" -> "x", + "escape.delim" -> "y", + "serialization.format" -> "x", + "line.delim" -> "\n", + "colelction.delim" -> "a", // yes, it's a typo from Hive :) + "mapkey.delim" -> "b")) + } + + test("create table - file format") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED AS" + val query1 = s"$baseQuery INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput'" + val query2 = s"$baseQuery ORC" + val (desc1, _) = extractTableDesc(query1) + val (desc2, _) = extractTableDesc(query2) + assert(desc1.storage.inputFormat == Some("winput")) + assert(desc1.storage.outputFormat == Some("wowput")) + assert(desc1.storage.serde.isEmpty) + assert(desc2.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcInputFormat")) + assert(desc2.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.orc.OrcOutputFormat")) + assert(desc2.storage.serde == Some("org.apache.hadoop.hive.ql.io.orc.OrcSerde")) + } + + test("create table - storage handler") { + val baseQuery = "CREATE TABLE my_table (id int, name string) STORED BY" + val query1 = s"$baseQuery 'org.papachi.StorageHandler'" + val query2 = s"$baseQuery 'org.mamachi.StorageHandler' WITH SERDEPROPERTIES ('k1'='v1')" + val e1 = intercept[ParseException] { parser.parsePlan(query1) } + val e2 = intercept[ParseException] { parser.parsePlan(query2) } + assert(e1.getMessage.contains("Operation not allowed")) + assert(e2.getMessage.contains("Operation not allowed")) + } + + test("create table - location") { + val query = "CREATE TABLE my_table (id int, name 
string) LOCATION '/path/to/mars'" + val (desc, _) = extractTableDesc(query) + assert(desc.storage.locationUri == Some("/path/to/mars")) + } + + test("create table - properties") { + val query = "CREATE TABLE my_table (id int, name string) TBLPROPERTIES ('k1'='v1', 'k2'='v2')" + val (desc, _) = extractTableDesc(query) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + } + + test("create table - everything!") { + val query = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS dbx.my_table (id int, name string) + |COMMENT 'no comment' + |PARTITIONED BY (month int) + |ROW FORMAT SERDE 'org.apache.poof.serde.Baff' WITH SERDEPROPERTIES ('k1'='v1') + |STORED AS INPUTFORMAT 'winput' OUTPUTFORMAT 'wowput' + |LOCATION '/path/to/mercury' + |TBLPROPERTIES ('k1'='v1', 'k2'='v2') + """.stripMargin + val (desc, allowExisting) = extractTableDesc(query) + assert(allowExisting) + assert(desc.identifier.database == Some("dbx")) + assert(desc.identifier.table == "my_table") + assert(desc.tableType == CatalogTableType.EXTERNAL_TABLE) + assert(desc.schema == Seq( + CatalogColumn("id", "int"), + CatalogColumn("name", "string"), + CatalogColumn("month", "int"))) + assert(desc.partitionColumnNames == Seq("month")) + assert(desc.sortColumnNames.isEmpty) + assert(desc.bucketColumnNames.isEmpty) + assert(desc.numBuckets == -1) + assert(desc.viewText.isEmpty) + assert(desc.viewOriginalText.isEmpty) + assert(desc.storage.locationUri == Some("/path/to/mercury")) + assert(desc.storage.inputFormat == Some("winput")) + assert(desc.storage.outputFormat == Some("wowput")) + assert(desc.storage.serde == Some("org.apache.poof.serde.Baff")) + assert(desc.storage.serdeProperties == Map("k1" -> "v1")) + assert(desc.properties == Map("k1" -> "v1", "k2" -> "v2")) + assert(desc.comment == Some("no comment")) + } + test("create view -- basic") { val v1 = "CREATE VIEW view1 AS SELECT * FROM tab1" val (desc, exists) = extractTableDesc(v1) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index ada8621d07..8648834f0d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -88,7 +88,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.partitionColumns.isEmpty) + assert(hiveTable.partitionColumnNames.isEmpty) assert(hiveTable.tableType === CatalogTableType.MANAGED_TABLE) val columns = hiveTable.schema @@ -151,7 +151,7 @@ class DataSourceWithHiveMetastoreCatalogSuite assert(hiveTable.storage.outputFormat === Some(outputFormat)) assert(hiveTable.storage.serde === Some(serde)) - assert(hiveTable.partitionColumns.isEmpty) + assert(hiveTable.partitionColumnNames.isEmpty) assert(hiveTable.tableType === CatalogTableType.EXTERNAL_TABLE) val columns = hiveTable.schema diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala index 40e9c9362c..4db95636e7 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/InsertIntoHiveTableSuite.scala @@ -81,7 +81,7 @@ class InsertIntoHiveTableSuite extends QueryTest with TestHiveSingleton with Bef test("Double create 
fails when allowExisting = false") { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") - intercept[QueryExecutionException] { + intercept[AnalysisException] { sql("CREATE TABLE doubleCreateAndInsertTest (key int, value string)") } } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala index 37c01792d9..97cb9d9720 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/PruningSuite.scala @@ -149,7 +149,7 @@ class PruningSuite extends HiveComparisonTest with BeforeAndAfter { val (actualScannedColumns, actualPartValues) = plan.collect { case p @ HiveTableScan(columns, relation, _) => val columnNames = columns.map(_.name) - val partValues = if (relation.table.partitionColumns.nonEmpty) { + val partValues = if (relation.table.partitionColumnNames.nonEmpty) { p.prunePartitions(relation.getHiveQlPartitions()).map(_.getValues) } else { Seq.empty diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 7eaf19dfe9..5ce16be4dc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -360,7 +360,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { var message = intercept[AnalysisException] { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") }.getMessage - assert(message.contains("ctas1 already exists")) + assert(message.contains("already exists")) checkRelation("ctas1", true) sql("DROP TABLE ctas1") -- cgit v1.2.3 From dbbe149070052af5cda04f7b110d65de73766ded Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Wed, 13 Apr 2016 13:01:13 -0700 Subject: [SPARK-14581] [SQL] push predicatese through more logical plans ## What changes were proposed in this pull request? Right now, filter push down only works with Project, Aggregate, Generate and Join, they can't be pushed through many other plans. This PR added support for Union, Intersect, Except and all unary plans. ## How was this patch tested? Added tests. Author: Davies Liu Closes #12342 from davies/filter_hint. 
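As a rough, user-level illustration of the behaviour described above (a sketch using the public DataFrame API of this era rather than the optimizer rules or test DSL touched by this patch; the column names, sample data and seed are arbitrary assumptions), a deterministic filter on top of a union is expected to be pushed into both children, while a non-deterministic filter stays above the Union:

````scala
// Sketch only: assumes a SQLContext named sqlContext is in scope (e.g. in spark-shell).
import org.apache.spark.sql.functions.rand
import sqlContext.implicits._

val left  = Seq((1, "a"), (2, "b")).toDF("a", "b")
val right = Seq((2, "c"), (3, "d")).toDF("a", "b")

val query = left.unionAll(right)
  .filter($"a" === 2)       // deterministic: expected to be pushed below each side of the Union
  .filter(rand(10) > 0.5)   // non-deterministic: expected to remain above the Union

// Inspecting the optimized plan should show a Filter (a = 2) under both Union children.
query.explain(true)
````

The same pattern is exercised at the optimizer level by the `union`, `intersect` and `except` tests added to `FilterPushdownSuite` in this patch.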
--- .../spark/sql/catalyst/optimizer/Optimizer.scala | 111 +++++++++++++-------- .../spark/sql/catalyst/planning/patterns.scala | 3 + .../catalyst/optimizer/ColumnPruningSuite.scala | 2 +- .../catalyst/optimizer/FilterPushdownSuite.scala | 76 ++++++++++++-- .../catalyst/optimizer/JoinOptimizationSuite.scala | 4 +- .../sql/catalyst/optimizer/PruneFiltersSuite.scala | 2 +- 6 files changed, 146 insertions(+), 52 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index bad115d22f..438cbabdbb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -66,9 +66,7 @@ abstract class Optimizer extends RuleExecutor[LogicalPlan] { ReorderJoin, OuterJoinElimination, PushPredicateThroughJoin, - PushPredicateThroughProject, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, + PushDownPredicate, LimitPushDown, ColumnPruning, InferFiltersFromConstraints, @@ -917,12 +915,13 @@ object PruneFilters extends Rule[LogicalPlan] with PredicateHelper { } /** - * Pushes [[Filter]] operators through [[Project]] operators, in-lining any [[Alias Aliases]] - * that were defined in the projection. + * Pushes [[Filter]] operators through many operators iff: + * 1) the operator is deterministic + * 2) the predicate is deterministic and the operator will not change any of rows. * * This heuristic is valid assuming the expression evaluation cost is minimal. */ -object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelper { +object PushDownPredicate extends Rule[LogicalPlan] with PredicateHelper { def apply(plan: LogicalPlan): LogicalPlan = plan transform { // SPARK-13473: We can't push the predicate down when the underlying projection output non- // deterministic field(s). Non-deterministic expressions are essentially stateful. This @@ -939,41 +938,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] with PredicateHelpe }) project.copy(child = Filter(replaceAlias(condition, aliasMap), grandChild)) - } - -} - -/** - * Push [[Filter]] operators through [[Generate]] operators. Parts of the predicate that reference - * attributes generated in [[Generate]] will remain above, and the rest should be pushed beneath. - */ -object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - case filter @ Filter(condition, g: Generate) => - // Predicates that reference attributes produced by the `Generate` operator cannot - // be pushed below the operator. - val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => - cond.references.subsetOf(g.child.outputSet) && cond.deterministic - } - if (pushDown.nonEmpty) { - val pushDownPredicate = pushDown.reduce(And) - val newGenerate = Generate(g.generator, join = g.join, outer = g.outer, - g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child)) - if (stayUp.isEmpty) newGenerate else Filter(stayUp.reduce(And), newGenerate) - } else { - filter - } - } -} -/** - * Push [[Filter]] operators through [[Aggregate]] operators, iff the filters reference only - * non-aggregate attributes (typically literals or grouping expressions). 
- */ -object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHelper { - - def apply(plan: LogicalPlan): LogicalPlan = plan transform { case filter @ Filter(condition, aggregate: Aggregate) => // Find all the aliased expressions in the aggregate list that don't include any actual // AggregateExpression, and create a map from the alias to the expression @@ -999,6 +964,72 @@ object PushPredicateThroughAggregate extends Rule[LogicalPlan] with PredicateHel } else { filter } + + case filter @ Filter(condition, child) + if child.isInstanceOf[Union] || child.isInstanceOf[Intersect] => + // Union/Intersect could change the rows, so non-deterministic predicate can't be pushed down + val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition { cond => + cond.deterministic + } + if (pushDown.nonEmpty) { + val pushDownCond = pushDown.reduceLeft(And) + val output = child.output + val newGrandChildren = child.children.map { grandchild => + val newCond = pushDownCond transform { + case e if output.exists(_.semanticEquals(e)) => + grandchild.output(output.indexWhere(_.semanticEquals(e))) + } + assert(newCond.references.subsetOf(grandchild.outputSet)) + Filter(newCond, grandchild) + } + val newChild = child.withNewChildren(newGrandChildren) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + filter + } + + case filter @ Filter(condition, e @ Except(left, _)) => + pushDownPredicate(filter, e.left) { predicate => + e.copy(left = Filter(predicate, left)) + } + + // two filters should be combine together by other rules + case filter @ Filter(_, f: Filter) => filter + // should not push predicates through sample, or will generate different results. + case filter @ Filter(_, s: Sample) => filter + // TODO: push predicates through expand + case filter @ Filter(_, e: Expand) => filter + + case filter @ Filter(condition, u: UnaryNode) if u.expressions.forall(_.deterministic) => + pushDownPredicate(filter, u.child) { predicate => + u.withNewChildren(Seq(Filter(predicate, u.child))) + } + } + + private def pushDownPredicate( + filter: Filter, + grandchild: LogicalPlan)(insertFilter: Expression => LogicalPlan): LogicalPlan = { + // Only push down the predicates that is deterministic and all the referenced attributes + // come from grandchild. + // TODO: non-deterministic predicates could be pushed through some operators that do not change + // the rows. 
+ val (pushDown, stayUp) = splitConjunctivePredicates(filter.condition).partition { cond => + cond.deterministic && cond.references.subsetOf(grandchild.outputSet) + } + if (pushDown.nonEmpty) { + val newChild = insertFilter(pushDown.reduceLeft(And)) + if (stayUp.nonEmpty) { + Filter(stayUp.reduceLeft(And), newChild) + } else { + newChild + } + } else { + filter + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala index 6f35d87ebb..0065619135 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala @@ -69,6 +69,9 @@ object PhysicalOperation extends PredicateHelper { val substitutedCondition = substitute(aliases)(condition) (fields, filters ++ splitConjunctivePredicates(substitutedCondition), other, aliases) + case BroadcastHint(child) => + collectProjectsAndFilters(child) + case other => (None, Nil, other, Map.empty) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 2248e03b2f..52b574c0e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -34,7 +34,7 @@ class ColumnPruningSuite extends PlanTest { object Optimize extends RuleExecutor[LogicalPlan] { val batches = Batch("Column pruning", FixedPoint(100), - PushPredicateThroughProject, + PushDownPredicate, ColumnPruning, CollapseProject) :: Nil } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala index b84ae7c5bb..df7529d83f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala @@ -33,14 +33,12 @@ class FilterPushdownSuite extends PlanTest { val batches = Batch("Subqueries", Once, EliminateSubqueryAliases) :: - Batch("Filter Pushdown", Once, + Batch("Filter Pushdown", FixedPoint(10), SamplePushDown, CombineFilters, - PushPredicateThroughProject, + PushDownPredicate, BooleanSimplification, PushPredicateThroughJoin, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, CollapseProject) :: Nil } @@ -620,8 +618,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where('a === 3) + .select('a, 'b) .groupBy('a)('a, count('b) as 'c) .where('c === 2L) .analyze @@ -638,8 +636,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where('a + 1 < 3) + .select('a, 'b) .groupBy('a)(('a + 1) as 'aa, count('b) as 'c) .where('c === 2L || 'aa > 4) .analyze @@ -656,8 +654,8 @@ class FilterPushdownSuite extends PlanTest { val optimized = Optimize.execute(originalQuery.analyze) val correctAnswer = testRelation - .select('a, 'b) .where("s" === "s") + .select('a, 'b) .groupBy('a)('a, count('b) as 'c, "s" as 'd) .where('c === 2L) .analyze @@ -681,4 +679,68 @@ class FilterPushdownSuite 
extends PlanTest { comparePlans(optimized, correctAnswer) } + + test("broadcast hint") { + val originalQuery = BroadcastHint(testRelation) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = BroadcastHint(testRelation.where('a === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("union") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Union(Seq(testRelation, testRelation2)) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Union(Seq( + testRelation.where('a === 2L), + testRelation2.where('d === 2L))) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("intersect") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Intersect(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Intersect( + testRelation.where('a === 2L), + testRelation2.where('d === 2L)) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } + + test("except") { + val testRelation2 = LocalRelation('d.int, 'e.int, 'f.int) + + val originalQuery = Except(testRelation, testRelation2) + .where('a === 2L && 'b + Rand(10).as("rnd") === 3) + + val optimized = Optimize.execute(originalQuery.analyze) + + val correctAnswer = Except( + testRelation.where('a === 2L), + testRelation2) + .where('b + Rand(10).as("rnd") === 3) + .analyze + + comparePlans(optimized, correctAnswer) + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index e2f8146bee..c1ebf8b09e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -36,12 +36,10 @@ class JoinOptimizationSuite extends PlanTest { EliminateSubqueryAliases) :: Batch("Filter Pushdown", FixedPoint(100), CombineFilters, - PushPredicateThroughProject, + PushDownPredicate, BooleanSimplification, ReorderJoin, PushPredicateThroughJoin, - PushPredicateThroughGenerate, - PushPredicateThroughAggregate, ColumnPruning, CollapseProject) :: Nil diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala index 14fb72a8a3..d8cfec5391 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/PruneFiltersSuite.scala @@ -34,7 +34,7 @@ class PruneFiltersSuite extends PlanTest { Batch("Filter Pushdown and Pruning", Once, CombineFilters, PruneFilters, - PushPredicateThroughProject, + PushDownPredicate, PushPredicateThroughJoin) :: Nil } -- cgit v1.2.3 From b4819404a65f9b97c1f8deb1fcb8419969831574 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 14 Apr 2016 15:43:44 +0800 Subject: [SPARK-14596][SQL] Remove not used SqlNewHadoopRDD and some more unused imports ## What changes were proposed in this pull request? 
Old `HadoopFsRelation` API includes `buildInternalScan()` which uses `SqlNewHadoopRDD` in `ParquetRelation`. Because now the old API is removed, `SqlNewHadoopRDD` is not used anymore. So, this PR removes `SqlNewHadoopRDD` and several unused imports. This was discussed in https://github.com/apache/spark/pull/12326. ## How was this patch tested? Several related existing unit tests and `sbt scalastyle`. Author: hyukjinkwon Closes #12354 from HyukjinKwon/SPARK-14596. --- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 8 +- .../org/apache/spark/rdd/InputFileNameHolder.scala | 41 +++ .../apache/spark/rdd/SqlNewHadoopRDDState.scala | 41 --- project/MimaExcludes.scala | 5 - .../sql/catalyst/expressions/InputFileName.scala | 8 +- .../sql/execution/datasources/FileScanRDD.scala | 11 +- .../execution/datasources/SqlNewHadoopRDD.scala | 282 --------------------- .../org/apache/spark/sql/internal/SQLConf.scala | 1 - .../datasources/FileSourceStrategySuite.scala | 5 +- 9 files changed, 54 insertions(+), 348 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala delete mode 100644 core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala (limited to 'sql/catalyst') diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index 08db96edd6..ac5ba9e79f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -213,15 +213,13 @@ class HadoopRDD[K, V]( logInfo("Input split: " + split.inputSplit) val jobConf = getJobConf() - // TODO: there is a lot of duplicate code between this and NewHadoopRDD and SqlNewHadoopRDD - val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) val existingBytesRead = inputMetrics.bytesRead // Sets the thread local variable for the file's name split.inputSplit.value match { - case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDDState.unsetInputFileName() + case fs: FileSplit => InputFileNameHolder.setInputFileName(fs.getPath.toString) + case _ => InputFileNameHolder.unsetInputFileName() } // Find a function that will return the FileSystem bytes read by this thread. Do this before @@ -271,7 +269,7 @@ class HadoopRDD[K, V]( override def close() { if (reader != null) { - SqlNewHadoopRDDState.unsetInputFileName() + InputFileNameHolder.unsetInputFileName() // Close the reader and release it. Note: it's very important that we don't close the // reader more than once, since that exposes us to MAPREDUCE-5918 when running against // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic diff --git a/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala new file mode 100644 index 0000000000..108e9d2558 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rdd/InputFileNameHolder.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.apache.spark.unsafe.types.UTF8String + +/** + * This holds file names of the current Spark task. This is used in HadoopRDD, + * FileScanRDD and InputFileName function in Spark SQL. + */ +private[spark] object InputFileNameHolder { + /** + * The thread variable for the name of the current file being read. This is used by + * the InputFileName function in Spark SQL. + */ + private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { + override protected def initialValue(): UTF8String = UTF8String.fromString("") + } + + def getInputFileName(): UTF8String = inputFileName.get() + + private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) + + private[spark] def unsetInputFileName(): Unit = inputFileName.remove() + +} diff --git a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala b/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala deleted file mode 100644 index 3f15fff793..0000000000 --- a/core/src/main/scala/org/apache/spark/rdd/SqlNewHadoopRDDState.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import org.apache.spark.unsafe.types.UTF8String - -/** - * State for SqlNewHadoopRDD objects. This is split this way because of the package splits. - * TODO: Move/Combine this with org.apache.spark.sql.datasources.SqlNewHadoopRDD - */ -private[spark] object SqlNewHadoopRDDState { - /** - * The thread variable for the name of the current file being read. This is used by - * the InputFileName function in Spark SQL. 
- */ - private[this] val inputFileName: ThreadLocal[UTF8String] = new ThreadLocal[UTF8String] { - override protected def initialValue(): UTF8String = UTF8String.fromString("") - } - - def getInputFileName(): UTF8String = inputFileName.get() - - private[spark] def setInputFileName(file: String) = inputFileName.set(UTF8String.fromString(file)) - - private[spark] def unsetInputFileName(): Unit = inputFileName.remove() - -} diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index a30581eb48..313bf93b5d 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -847,7 +847,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoDataSource$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopPartition"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DefaultWriterContainer"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$PartitionValues"), @@ -856,10 +855,8 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DynamicPartitionWriterContainer"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsingAsSelect"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitioningUtils$"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreInsertCastAndRename"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Partition$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LogicalRelation$"), @@ -870,7 +867,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PreWriteCheck"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTableUsing"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.RefreshTable"), - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsing"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CreateTempTableUsingAsSelect$"), @@ -884,7 +880,6 @@ object MimaExcludes { ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.CaseInsensitiveMap"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.InsertIntoHadoopFsRelation$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DataSourceStrategy"), - 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.SqlNewHadoopRDD$NewHadoopMapPartitionsWithSplitRDD$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.PartitionSpec$"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DescribeCommand"), ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.DDLException"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index dbd0acf06c..2ed6fc0d38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -17,14 +17,14 @@ package org.apache.spark.sql.catalyst.expressions -import org.apache.spark.rdd.SqlNewHadoopRDDState +import org.apache.spark.rdd.InputFileNameHolder import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.types.{DataType, StringType} import org.apache.spark.unsafe.types.UTF8String /** - * Expression that returns the name of the current file being read in using [[SqlNewHadoopRDD]] + * Expression that returns the name of the current file being read. */ @ExpressionDescription( usage = "_FUNC_() - Returns the name of the current file being read if available", @@ -40,12 +40,12 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override protected def initInternal(): Unit = {} override protected def evalInternal(input: InternalRow): UTF8String = { - SqlNewHadoopRDDState.getInputFileName() + InputFileNameHolder.getInputFileName() } override def genCode(ctx: CodegenContext, ev: ExprCode): String = { ev.isNull = "false" s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" + "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();" } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 988c785dbe..468e101fed 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.execution.datasources import org.apache.spark.{Partition, TaskContext} -import org.apache.spark.rdd.{RDD, SqlNewHadoopRDDState} +import org.apache.spark.rdd.{InputFileNameHolder, RDD} import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow @@ -37,7 +37,6 @@ case class PartitionedFile( } } - /** * A collection of files that should be read as a single task possibly from multiple partitioned * directories. 
@@ -50,7 +49,7 @@ class FileScanRDD( @transient val sqlContext: SQLContext, readFunction: (PartitionedFile) => Iterator[InternalRow], @transient val filePartitions: Seq[FilePartition]) - extends RDD[InternalRow](sqlContext.sparkContext, Nil) { + extends RDD[InternalRow](sqlContext.sparkContext, Nil) { override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = { val iterator = new Iterator[Object] with AutoCloseable { @@ -65,17 +64,17 @@ class FileScanRDD( if (files.hasNext) { val nextFile = files.next() logInfo(s"Reading File $nextFile") - SqlNewHadoopRDDState.setInputFileName(nextFile.filePath) + InputFileNameHolder.setInputFileName(nextFile.filePath) currentIterator = readFunction(nextFile) hasNext } else { - SqlNewHadoopRDDState.unsetInputFileName() + InputFileNameHolder.unsetInputFileName() false } } override def close() = { - SqlNewHadoopRDDState.unsetInputFileName() + InputFileNameHolder.unsetInputFileName() } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala deleted file mode 100644 index 4d6864d8ba..0000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SqlNewHadoopRDD.scala +++ /dev/null @@ -1,282 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import java.text.SimpleDateFormat -import java.util.Date - -import scala.reflect.ClassTag - -import org.apache.hadoop.conf.{Configurable, Configuration} -import org.apache.hadoop.io.Writable -import org.apache.hadoop.mapreduce._ -import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, FileSplit} -import org.apache.hadoop.mapreduce.task.{JobContextImpl, TaskAttemptContextImpl} - -import org.apache.spark.{Partition => SparkPartition, _} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.spark.executor.DataReadMethod -import org.apache.spark.internal.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{SerializableConfiguration, ShutdownHookManager} - -private[spark] class SqlNewHadoopPartition( - rddId: Int, - val index: Int, - rawSplit: InputSplit with Writable) - extends SparkPartition { - - val serializableHadoopSplit = new SerializableWritable(rawSplit) - - override def hashCode(): Int = 41 * (41 + rddId) + index -} - -/** - * An RDD that provides core functionality for reading data stored in Hadoop (e.g., files in HDFS, - * sources in HBase, or S3), using the new MapReduce API (`org.apache.hadoop.mapreduce`). - * It is based on [[org.apache.spark.rdd.NewHadoopRDD]]. It has three additions. - * 1. 
A shared broadcast Hadoop Configuration. - * 2. An optional closure `initDriverSideJobFuncOpt` that set configurations at the driver side - * to the shared Hadoop Configuration. - * 3. An optional closure `initLocalJobFuncOpt` that set configurations at both the driver side - * and the executor side to the shared Hadoop Configuration. - * - * Note: This is RDD is basically a cloned version of [[org.apache.spark.rdd.NewHadoopRDD]] with - * changes based on [[org.apache.spark.rdd.HadoopRDD]]. - */ -private[spark] class SqlNewHadoopRDD[V: ClassTag]( - sqlContext: SQLContext, - broadcastedConf: Broadcast[SerializableConfiguration], - @transient private val initDriverSideJobFuncOpt: Option[Job => Unit], - initLocalJobFuncOpt: Option[Job => Unit], - inputFormatClass: Class[_ <: InputFormat[Void, V]], - valueClass: Class[V]) - extends RDD[V](sqlContext.sparkContext, Nil) with Logging { - - protected def getJob(): Job = { - val conf = broadcastedConf.value.value - // "new Job" will make a copy of the conf. Then, it is - // safe to mutate conf properties with initLocalJobFuncOpt - // and initDriverSideJobFuncOpt. - val newJob = Job.getInstance(conf) - initLocalJobFuncOpt.map(f => f(newJob)) - newJob - } - - def getConf(isDriverSide: Boolean): Configuration = { - val job = getJob() - if (isDriverSide) { - initDriverSideJobFuncOpt.map(f => f(job)) - } - job.getConfiguration - } - - private val jobTrackerId: String = { - val formatter = new SimpleDateFormat("yyyyMMddHHmm") - formatter.format(new Date()) - } - - @transient protected val jobId = new JobID(jobTrackerId, id) - - override def getPartitions: Array[SparkPartition] = { - val conf = getConf(isDriverSide = true) - val inputFormat = inputFormatClass.newInstance - inputFormat match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val jobContext = new JobContextImpl(conf, jobId) - val rawSplits = inputFormat.getSplits(jobContext).toArray - val result = new Array[SparkPartition](rawSplits.size) - for (i <- 0 until rawSplits.size) { - result(i) = - new SqlNewHadoopPartition(id, i, rawSplits(i).asInstanceOf[InputSplit with Writable]) - } - result - } - - override def compute( - theSplit: SparkPartition, - context: TaskContext): Iterator[V] = { - val iter = new Iterator[V] { - val split = theSplit.asInstanceOf[SqlNewHadoopPartition] - logInfo("Input split: " + split.serializableHadoopSplit) - val conf = getConf(isDriverSide = false) - - val inputMetrics = context.taskMetrics().registerInputMetrics(DataReadMethod.Hadoop) - val existingBytesRead = inputMetrics.bytesRead - - // Sets the thread local variable for the file's name - split.serializableHadoopSplit.value match { - case fs: FileSplit => SqlNewHadoopRDDState.setInputFileName(fs.getPath.toString) - case _ => SqlNewHadoopRDDState.unsetInputFileName() - } - - // Find a function that will return the FileSystem bytes read by this thread. Do this before - // creating RecordReader, because RecordReader's constructor might read some bytes - val getBytesReadCallback: Option[() => Long] = split.serializableHadoopSplit.value match { - case _: FileSplit | _: CombineFileSplit => - SparkHadoopUtil.get.getFSBytesReadOnThreadCallback() - case _ => None - } - - // For Hadoop 2.5+, we get our input bytes from thread-local Hadoop FileSystem statistics. - // If we do a coalesce, however, we are likely to compute multiple partitions in the same - // task and in the same thread, in which case we need to avoid override values written by - // previous partitions (SPARK-13071). 
- def updateBytesRead(): Unit = { - getBytesReadCallback.foreach { getBytesRead => - inputMetrics.setBytesRead(existingBytesRead + getBytesRead()) - } - } - - val format = inputFormatClass.newInstance - format match { - case configurable: Configurable => - configurable.setConf(conf) - case _ => - } - val attemptId = new TaskAttemptID(jobTrackerId, id, TaskType.MAP, split.index, 0) - val hadoopAttemptContext = new TaskAttemptContextImpl(conf, attemptId) - private[this] var reader: RecordReader[Void, V] = format.createRecordReader( - split.serializableHadoopSplit.value, hadoopAttemptContext) - reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext) - - // Register an on-task-completion callback to close the input stream. - context.addTaskCompletionListener(context => close()) - - private[this] var havePair = false - private[this] var finished = false - - override def hasNext: Boolean = { - if (context.isInterrupted()) { - throw new TaskKilledException - } - if (!finished && !havePair) { - finished = !reader.nextKeyValue - if (finished) { - // Close and release the reader here; close() will also be called when the task - // completes, but for tasks that read from many files, it helps to release the - // resources early. - close() - } - havePair = !finished - } - !finished - } - - override def next(): V = { - if (!hasNext) { - throw new java.util.NoSuchElementException("End of stream") - } - havePair = false - if (!finished) { - inputMetrics.incRecordsReadInternal(1) - } - if (inputMetrics.recordsRead % SparkHadoopUtil.UPDATE_INPUT_METRICS_INTERVAL_RECORDS == 0) { - updateBytesRead() - } - reader.getCurrentValue - } - - private def close() { - if (reader != null) { - SqlNewHadoopRDDState.unsetInputFileName() - // Close the reader and release it. Note: it's very important that we don't close the - // reader more than once, since that exposes us to MAPREDUCE-5918 when running against - // Hadoop 1.x and older Hadoop 2.x releases. That bug can lead to non-deterministic - // corruption issues when reading compressed input. - try { - reader.close() - } catch { - case e: Exception => - if (!ShutdownHookManager.inShutdown()) { - logWarning("Exception in RecordReader.close()", e) - } - } finally { - reader = null - } - if (getBytesReadCallback.isDefined) { - updateBytesRead() - } else if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit] || - split.serializableHadoopSplit.value.isInstanceOf[CombineFileSplit]) { - // If we can't get the bytes read from the FS stats, fall back to the split size, - // which may be inaccurate. 
- try { - inputMetrics.incBytesReadInternal(split.serializableHadoopSplit.value.getLength) - } catch { - case e: java.io.IOException => - logWarning("Unable to get input size to set InputMetrics for task", e) - } - } - } - } - } - iter - } - - override def getPreferredLocations(hsplit: SparkPartition): Seq[String] = { - val split = hsplit.asInstanceOf[SqlNewHadoopPartition].serializableHadoopSplit.value - val locs = HadoopRDD.SPLIT_INFO_REFLECTIONS match { - case Some(c) => - try { - val infos = c.newGetLocationInfo.invoke(split).asInstanceOf[Array[AnyRef]] - Some(HadoopRDD.convertSplitLocationInfo(infos)) - } catch { - case e : Exception => - logDebug("Failed to use InputSplit#getLocationInfo.", e) - None - } - case None => None - } - locs.getOrElse(split.getLocations.filter(_ != "localhost")) - } - - override def persist(storageLevel: StorageLevel): this.type = { - if (storageLevel.deserialized) { - logWarning("Caching NewHadoopRDDs as deserialized objects usually leads to undesired" + - " behavior because Hadoop's RecordReader reuses the same Writable object for all records." + - " Use a map transformation to make copies of the records.") - } - super.persist(storageLevel) - } - - /** - * Analogous to [[org.apache.spark.rdd.MapPartitionsRDD]], but passes in an InputSplit to - * the given function rather than the index of the partition. - */ - private[spark] class NewHadoopMapPartitionsWithSplitRDD[U: ClassTag, T: ClassTag]( - prev: RDD[T], - f: (InputSplit, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean = false) - extends RDD[U](prev) { - - override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None - - override def getPartitions: Array[SparkPartition] = firstParent[T].partitions - - override def compute(split: SparkPartition, context: TaskContext): Iterator[U] = { - val partition = split.asInstanceOf[SqlNewHadoopPartition] - val inputSplit = partition.serializableHadoopSplit.value - f(inputSplit, firstParent[T].iterator(split, context)) - } - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index e74fb00cb2..2f9d63c2e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -28,7 +28,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 0b74f07540..dac56d3936 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -22,8 +22,6 @@ import java.io.File import org.apache.hadoop.fs.FileStatus import org.apache.hadoop.mapreduce.Job -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionSet, PredicateHelper} @@ -34,8 +32,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} -import org.apache.spark.util.{SerializableConfiguration, Utils} -import org.apache.spark.util.collection.BitSet +import org.apache.spark.util.Utils class FileSourceStrategySuite extends QueryTest with SharedSQLContext with PredicateHelper { import testImplicits._ -- cgit v1.2.3 From 6fc3dc8839eaed673c64ec87af6dfe24f8cebe0c Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Thu, 14 Apr 2016 09:43:41 +0100 Subject: [MINOR][SQL] Remove extra anonymous closure within functional transformations ## What changes were proposed in this pull request? This PR removes extra anonymous closures within functional transformations. For example, ```scala .map(item => { ... }) ``` which can be written more simply as: ```scala .map { item => ... } ``` ## How was this patch tested? Related unit tests and `sbt scalastyle`. Author: hyukjinkwon Closes #12382 from HyukjinKwon/minor-extra-closers.
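For readers less familiar with the Scala idiom being enforced here, the following self-contained sketch (illustrative only, not part of the patch; `ClosureStyleExample` and all value names are invented) shows that the two styles are equivalent and that the brace form simply drops the redundant wrapping parentheses and block:

```scala
// Illustrative sketch only -- not code from this patch.
// Both forms below produce identical results; the patch merely rewrites
// the first style into the second throughout the code base.
object ClosureStyleExample {
  def main(args: Array[String]): Unit = {
    val xs = Seq(1, 2, 3)

    // Redundant style: a function literal wrapped in an extra block and parentheses.
    val a = xs.map(x => {
      val doubled = x * 2
      doubled + 1
    })

    // Preferred style: pass the multi-line function literal with braces directly.
    val b = xs.map { x =>
      val doubled = x * 2
      doubled + 1
    }

    assert(a == b) // Seq(3, 5, 7) in both cases
    println(b)
  }
}
```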
--- .../main/scala/org/apache/spark/SparkContext.scala | 4 ++-- .../org/apache/spark/deploy/master/Master.scala | 4 ++-- .../main/scala/org/apache/spark/rdd/BlockRDD.scala | 4 ++-- .../scala/org/apache/spark/rdd/HadoopRDD.scala | 4 ++-- .../apache/spark/rdd/ParallelCollectionRDD.scala | 4 ++-- .../spark/rdd/PartitionerAwareUnionRDD.scala | 4 ++-- .../cluster/mesos/MesosSchedulerUtils.scala | 4 ++-- .../spark/shuffle/BlockStoreShuffleReader.scala | 4 ++-- .../org/apache/spark/ui/jobs/ExecutorTable.scala | 4 ++-- .../streaming/RecoverableNetworkWordCount.scala | 4 ++-- .../examples/streaming/SqlNetworkWordCount.scala | 4 ++-- .../flume/sink/SparkAvroCallbackHandler.scala | 4 ++-- .../spark/streaming/flume/sink/SparkSink.scala | 12 +++++----- .../flume/sink/TransactionProcessor.scala | 12 +++++----- .../streaming/flume/FlumePollingInputDStream.scala | 4 ++-- .../streaming/flume/PollingFlumeTestUtils.scala | 4 ++-- .../expressions/codegen/CodeGenerator.scala | 4 ++-- .../spark/sql/catalyst/optimizer/Optimizer.scala | 4 ++-- .../spark/sql/execution/basicOperators.scala | 4 ++-- .../sql/execution/datasources/jdbc/JdbcUtils.scala | 4 ++-- .../org/apache/spark/sql/hive/HiveInspectors.scala | 27 +++++++++------------- .../org/apache/spark/streaming/Checkpoint.scala | 8 +++---- .../spark/streaming/dstream/StateDStream.scala | 4 ++-- .../streaming/scheduler/ReceiverTracker.scala | 4 ++-- 24 files changed, 67 insertions(+), 72 deletions(-) (limited to 'sql/catalyst') diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 966198dd5e..e41088f7c8 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -723,7 +723,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli (safeEnd - safeStart) / step + 1 } } - parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex((i, _) => { + parallelize(0 until numSlices, numSlices).mapPartitionsWithIndex { (i, _) => val partitionStart = (i * numElements) / numSlices * step + start val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start def getSafeMargin(bi: BigInt): Long = @@ -762,7 +762,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli ret } } - }) + } } /** Distribute a local Scala collection to form an RDD. 
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 9bd3fc1033..b443e8f051 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -843,10 +843,10 @@ private[deploy] class Master( addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) - completedApps.take(toRemove).foreach( a => { + completedApps.take(toRemove).foreach { a => Option(appIdToUI.remove(a.id)).foreach { ui => webUi.detachSparkUI(ui) } applicationMetricsSystem.removeSource(a.appSource) - }) + } completedApps.trimStart(toRemove) } completedApps += app // Remember it in our history diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 8358244987..63d1d1767a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -35,9 +35,9 @@ class BlockRDD[T: ClassTag](sc: SparkContext, @transient val blockIds: Array[Blo override def getPartitions: Array[Partition] = { assertValid() - (0 until blockIds.length).map(i => { + (0 until blockIds.length).map { i => new BlockRDDPartition(blockIds(i), i).asInstanceOf[Partition] - }).toArray + }.toArray } override def compute(split: Partition, context: TaskContext): Iterator[T] = { diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala index ac5ba9e79f..f7c646c668 100644 --- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala @@ -422,7 +422,7 @@ private[spark] object HadoopRDD extends Logging { private[spark] def convertSplitLocationInfo(infos: Array[AnyRef]): Seq[String] = { val out = ListBuffer[String]() - infos.foreach { loc => { + infos.foreach { loc => val locationStr = HadoopRDD.SPLIT_INFO_REFLECTIONS.get. 
getLocation.invoke(loc).asInstanceOf[String] if (locationStr != "localhost") { @@ -434,7 +434,7 @@ private[spark] object HadoopRDD extends Logging { out += new HostTaskLocation(locationStr).toString } } - }} + } out.seq } } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index 462fb39ea2..bb84e4af15 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -121,11 +121,11 @@ private object ParallelCollectionRDD { // Sequences need to be sliced at the same set of index positions for operations // like RDD.zip() to behave as expected def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] = { - (0 until numSlices).iterator.map(i => { + (0 until numSlices).iterator.map { i => val start = ((i * length) / numSlices).toInt val end = (((i + 1) * length) / numSlices).toInt (start, end) - }) + } } seq match { case r: Range => diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala index c3579d761d..0abba15bec 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala @@ -68,9 +68,9 @@ class PartitionerAwareUnionRDD[T: ClassTag]( override def getPartitions: Array[Partition] = { val numPartitions = partitioner.get.numPartitions - (0 until numPartitions).map(index => { + (0 until numPartitions).map { index => new PartitionerAwareUnionRDDPartition(rdds, index) - }).toArray + }.toArray } // Get the location where most of the partitions of parent RDDs are located diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 7295d50682..1e322ac679 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -226,7 +226,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { * @return */ protected def toAttributeMap(offerAttributes: JList[Attribute]): Map[String, GeneratedMessage] = { - offerAttributes.asScala.map(attr => { + offerAttributes.asScala.map { attr => val attrValue = attr.getType match { case Value.Type.SCALAR => attr.getScalar case Value.Type.RANGES => attr.getRanges @@ -234,7 +234,7 @@ private[mesos] trait MesosSchedulerUtils extends Logging { case Value.Type.TEXT => attr.getText } (attr.getName, attrValue) - }).toMap + }.toMap } diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala index 637b2dfc19..876cdfaa87 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala @@ -69,10 +69,10 @@ private[spark] class BlockStoreShuffleReader[K, C]( // Update the context task metrics for each record read. 
val readMetrics = context.taskMetrics.registerTempShuffleReadMetrics() val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]]( - recordIter.map(record => { + recordIter.map { record => readMetrics.incRecordsRead(1) record - }), + }, context.taskMetrics().mergeShuffleReadMetrics()) // An interruptible iterator must be used here in order to support task cancellation diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala index 1304efd8f2..f609fb4cd2 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala @@ -42,13 +42,13 @@ private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: Stage var hasShuffleWrite = false var hasShuffleRead = false var hasBytesSpilled = false - stageData.foreach(data => { + stageData.foreach { data => hasInput = data.hasInput hasOutput = data.hasOutput hasShuffleRead = data.hasShuffleRead hasShuffleWrite = data.hasShuffleWrite hasBytesSpilled = data.hasBytesSpilled - }) + }
diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index b6b8bc33f7..bb2af9cd72 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -116,7 +116,7 @@ object RecoverableNetworkWordCount { val lines = ssc.socketTextStream(ip, port) val words = lines.flatMap(_.split(" ")) val wordCounts = words.map(x => (x, 1)).reduceByKey(_ + _) - wordCounts.foreachRDD((rdd: RDD[(String, Int)], time: Time) => { + wordCounts.foreachRDD { (rdd: RDD[(String, Int)], time: Time) => // Get or register the blacklist Broadcast val blacklist = WordBlacklist.getInstance(rdd.sparkContext) // Get or register the droppedWordsCounter Accumulator @@ -135,7 +135,7 @@ object RecoverableNetworkWordCount { println("Dropped " + droppedWordsCounter.value + " word(s) totally") println("Appending to " + outputFile.getAbsolutePath) Files.append(output + "\n", outputFile, Charset.defaultCharset()) - }) + } ssc } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala index 3727f8fe6a..918e124065 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/SqlNetworkWordCount.scala @@ -59,7 +59,7 @@ object SqlNetworkWordCount { val words = lines.flatMap(_.split(" ")) // Convert RDDs of the words DStream to DataFrame and run SQL query - words.foreachRDD((rdd: RDD[String], time: Time) => { + words.foreachRDD { (rdd: RDD[String], time: Time) => // Get the singleton instance of SQLContext val sqlContext = SQLContextSingleton.getInstance(rdd.sparkContext) import sqlContext.implicits._ @@ -75,7 +75,7 @@ object SqlNetworkWordCount { sqlContext.sql("select word, count(*) as total from words group by word") println(s"========= $time =========") wordCountsDataFrame.show() - }) + } ssc.start() ssc.awaitTermination() diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala index 719fca0938..8050ec357e 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkAvroCallbackHandler.scala @@ -129,9 +129,9 @@ private[flume] class SparkAvroCallbackHandler(val threads: Int, val channel: Cha * @param success Whether the batch was successful or not. 
*/ private def completeTransaction(sequenceNumber: CharSequence, success: Boolean) { - removeAndGetProcessor(sequenceNumber).foreach(processor => { + removeAndGetProcessor(sequenceNumber).foreach { processor => processor.batchProcessed(success) - }) + } } /** diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala index 14dffb15fe..41f27e9376 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/SparkSink.scala @@ -88,23 +88,23 @@ class SparkSink extends AbstractSink with Logging with Configurable { // dependencies which are being excluded in the build. In practice, // Netty dependencies are already available on the JVM as Flume would have pulled them in. serverOpt = Option(new NettyServer(responder, new InetSocketAddress(hostname, port))) - serverOpt.foreach(server => { + serverOpt.foreach { server => logInfo("Starting Avro server for sink: " + getName) server.start() - }) + } super.start() } override def stop() { logInfo("Stopping Spark Sink: " + getName) - handler.foreach(callbackHandler => { + handler.foreach { callbackHandler => callbackHandler.shutdown() - }) - serverOpt.foreach(server => { + } + serverOpt.foreach { server => logInfo("Stopping Avro Server for sink: " + getName) server.close() server.join() - }) + } blockingLatch.countDown() super.stop() } diff --git a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala index b15c2097e5..19e736f016 100644 --- a/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala +++ b/external/flume-sink/src/main/scala/org/apache/spark/streaming/flume/sink/TransactionProcessor.scala @@ -110,7 +110,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, eventBatch.setErrorMsg("Something went wrong. Channel was " + "unable to create a transaction!") } - txOpt.foreach(tx => { + txOpt.foreach { tx => tx.begin() val events = new util.ArrayList[SparkSinkEvent](maxBatchSize) val loop = new Breaks @@ -145,7 +145,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, // At this point, the events are available, so fill them into the event batch eventBatch = new EventBatch("", seqNum, events) } - }) + } } catch { case interrupted: InterruptedException => // Don't pollute logs if the InterruptedException came from this being stopped @@ -156,9 +156,9 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, logWarning("Error while processing transaction.", e) eventBatch.setErrorMsg(e.getMessage) try { - txOpt.foreach(tx => { + txOpt.foreach { tx => rollbackAndClose(tx, close = true) - }) + } } finally { txOpt = None } @@ -174,7 +174,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, */ private def processAckOrNack() { batchAckLatch.await(transactionTimeout, TimeUnit.SECONDS) - txOpt.foreach(tx => { + txOpt.foreach { tx => if (batchSuccess) { try { logDebug("Committing transaction") @@ -197,7 +197,7 @@ private class TransactionProcessor(val channel: Channel, val seqNum: String, // cause issues. 
This is required to ensure the TransactionProcessor instance is not leaked parent.removeAndGetProcessor(seqNum) } - }) + } } /** diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala index 250bfc1718..54565840fa 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/FlumePollingInputDStream.scala @@ -79,11 +79,11 @@ private[streaming] class FlumePollingReceiver( override def onStart(): Unit = { // Create the connections to each Flume agent. - addresses.foreach(host => { + addresses.foreach { host => val transceiver = new NettyTransceiver(host, channelFactory) val client = SpecificRequestor.getClient(classOf[SparkFlumeProtocol.Callback], transceiver) connections.add(new FlumeConnection(transceiver, client)) - }) + } for (i <- 0 until parallelism) { logInfo("Starting Flume Polling Receiver worker threads..") // Threads that pull data from Flume. diff --git a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala index 1a96df6e94..6a4dafb8ed 100644 --- a/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala +++ b/external/flume/src/main/scala/org/apache/spark/streaming/flume/PollingFlumeTestUtils.scala @@ -123,9 +123,9 @@ private[flume] class PollingFlumeTestUtils { val latch = new CountDownLatch(batchCount * channels.size) sinks.foreach(_.countdownWhenBatchReceived(latch)) - channels.foreach(channel => { + channels.foreach { channel => executorCompletion.submit(new TxnSubmitter(channel)) - }) + } for (i <- 0 until channels.size) { executorCompletion.take() diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index ee7f4fadca..f43626ca81 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -519,7 +519,7 @@ class CodegenContext { // Get all the expressions that appear at least twice and set up the state for subexpression // elimination. 
val commonExprs = equivalentExpressions.getAllEquivalentExprs.filter(_.size > 1) - commonExprs.foreach(e => { + commonExprs.foreach { e => val expr = e.head val fnName = freshName("evalExpr") val isNull = s"${fnName}IsNull" @@ -561,7 +561,7 @@ class CodegenContext { subexprFunctions += s"$fnName($INPUT_ROW);" val state = SubExprEliminationState(isNull, value) e.foreach(subExprEliminationExprs.put(_, state)) - }) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 438cbabdbb..aeb1842677 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -286,10 +286,10 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper { assert(children.nonEmpty) if (projectList.forall(_.deterministic)) { val newFirstChild = Project(projectList, children.head) - val newOtherChildren = children.tail.map ( child => { + val newOtherChildren = children.tail.map { child => val rewrites = buildRewrites(children.head, child) Project(projectList.map(pushToRight(_, rewrites)), child) - } ) + } Union(newFirstChild +: newOtherChildren) } else { p diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index aba500ad8d..344aaff348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -400,7 +400,7 @@ case class Range( sqlContext .sparkContext .parallelize(0 until numSlices, numSlices) - .mapPartitionsWithIndex((i, _) => { + .mapPartitionsWithIndex { (i, _) => val partitionStart = (i * numElements) / numSlices * step + start val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start def getSafeMargin(bi: BigInt): Long = @@ -444,7 +444,7 @@ case class Range( unsafeRow } } - }) + } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala index b7ff5f7242..065c8572b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtils.scala @@ -251,12 +251,12 @@ object JdbcUtils extends Logging { def schemaString(df: DataFrame, url: String): String = { val sb = new StringBuilder() val dialect = JdbcDialects.get(url) - df.schema.fields foreach { field => { + df.schema.fields foreach { field => val name = field.name val typ: String = getJdbcType(field.dataType, dialect).databaseTypeDefinition val nullable = if (field.nullable) "" else "NOT NULL" sb.append(s", $name $typ $nullable") - }} + } if (sb.length < 2) "" else sb.substring(2) } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala index 589862c7c0..585befe378 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala @@ -450,9 +450,7 @@ private[hive] trait HiveInspectors { if (o != null) { val array = o.asInstanceOf[ArrayData] val values = new 
java.util.ArrayList[Any](array.numElements()) - array.foreach(elementType, (_, e) => { - values.add(wrapper(e)) - }) + array.foreach(elementType, (_, e) => values.add(wrapper(e))) values } else { null @@ -468,9 +466,8 @@ private[hive] trait HiveInspectors { if (o != null) { val map = o.asInstanceOf[MapData] val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(mt.keyType, mt.valueType, (k, v) => { - jmap.put(keyWrapper(k), valueWrapper(v)) - }) + map.foreach(mt.keyType, mt.valueType, (k, v) => + jmap.put(keyWrapper(k), valueWrapper(v))) jmap } else { null @@ -587,9 +584,9 @@ private[hive] trait HiveInspectors { case x: ListObjectInspector => val list = new java.util.ArrayList[Object] val tpe = dataType.asInstanceOf[ArrayType].elementType - a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => { + a.asInstanceOf[ArrayData].foreach(tpe, (_, e) => list.add(wrap(e, x.getListElementObjectInspector, tpe)) - }) + ) list case x: MapObjectInspector => val keyType = dataType.asInstanceOf[MapType].keyType @@ -599,10 +596,10 @@ private[hive] trait HiveInspectors { // Some UDFs seem to assume we pass in a HashMap. val hashMap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(keyType, valueType, (k, v) => { + map.foreach(keyType, valueType, (k, v) => hashMap.put(wrap(k, x.getMapKeyObjectInspector, keyType), wrap(v, x.getMapValueObjectInspector, valueType)) - }) + ) hashMap } @@ -704,9 +701,8 @@ private[hive] trait HiveInspectors { ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, null) } else { val list = new java.util.ArrayList[Object]() - value.asInstanceOf[ArrayData].foreach(dt, (_, e) => { - list.add(wrap(e, listObjectInspector, dt)) - }) + value.asInstanceOf[ArrayData].foreach(dt, (_, e) => + list.add(wrap(e, listObjectInspector, dt))) ObjectInspectorFactory.getStandardConstantListObjectInspector(listObjectInspector, list) } case Literal(value, MapType(keyType, valueType, _)) => @@ -718,9 +714,8 @@ private[hive] trait HiveInspectors { val map = value.asInstanceOf[MapData] val jmap = new java.util.HashMap[Any, Any](map.numElements()) - map.foreach(keyType, valueType, (k, v) => { - jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType)) - }) + map.foreach(keyType, valueType, (k, v) => + jmap.put(wrap(k, keyOI, keyType), wrap(v, valueOI, valueType))) ObjectInspectorFactory.getStandardConstantMapObjectInspector(keyOI, valueOI, jmap) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 5cc677d085..0395600954 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -247,10 +247,10 @@ class CheckpointWriter( // Delete old checkpoint files val allCheckpointFiles = Checkpoint.getCheckpointFiles(checkpointDir, Some(fs)) if (allCheckpointFiles.size > 10) { - allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach(file => { + allCheckpointFiles.take(allCheckpointFiles.size - 10).foreach { file => logInfo("Deleting " + file) fs.delete(file, true) - }) + } } // All done, print success @@ -345,7 +345,7 @@ object CheckpointReader extends Logging { // Try to read the checkpoint files in the order logInfo("Checkpoint files found: " + checkpointFiles.mkString(",")) var readError: Exception = null - checkpointFiles.foreach(file => { + checkpointFiles.foreach { file => logInfo("Attempting to load checkpoint from file " + file) try { 
 val fis = fs.open(file) @@ -358,7 +358,7 @@ object CheckpointReader extends Logging { readError = e logWarning("Error reading checkpoint from file " + file, e) } - }) + } // If none of checkpoint files could be read, then throw exception if (!ignoreReadError) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala index 28aed0ca45..8efb09a8ce 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/StateDStream.scala @@ -48,11 +48,11 @@ class StateDStream[K: ClassTag, V: ClassTag, S: ClassTag]( // and then apply the update function val updateFuncLocal = updateFunc val finalFunc = (iterator: Iterator[(K, (Iterable[V], Iterable[S]))]) => { - val i = iterator.map(t => { + val i = iterator.map { t => val itr = t._2._2.iterator val headOption = if (itr.hasNext) Some(itr.next()) else None (t._1, t._2._1.toSeq, headOption) - }) + } updateFuncLocal(i) } val cogroupedRDD = parentRDD.cogroup(prevStateRDD, partitioner) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala index 3b33a979df..9aa2f0bbb9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/scheduler/ReceiverTracker.scala @@ -434,11 +434,11 @@ class ReceiverTracker(ssc: StreamingContext, skipReceiverLaunch: Boolean = false * worker nodes as a parallel collection, and runs them. */ private def launchReceivers(): Unit = { - val receivers = receiverInputStreams.map(nis => { + val receivers = receiverInputStreams.map { nis => val rcvr = nis.getReceiver() rcvr.setReceiverId(nis.id) rcvr - }) + } runDummySparkJob() -- cgit v1.2.3 From 3e27940a19e7bab448f1af11d2065ecd1ec66197 Mon Sep 17 00:00:00 2001 From: Liwei Lin Date: Thu, 14 Apr 2016 10:14:38 -0700 Subject: [SPARK-14630][BUILD][CORE][SQL][STREAMING] Code style: public abstract methods should have explicit return types ## What changes were proposed in this pull request? Currently many public abstract methods (in abstract classes as well as traits) don't declare return types explicitly, such as in [o.a.s.streaming.dstream.InputDStream](https://github.com/apache/spark/blob/master/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala#L110): ```scala def start() // should be: def start(): Unit def stop() // should be: def stop(): Unit ``` These methods exist in core, sql, streaming; this PR fixes them. ## How was this patch tested? N/A ## Which Scala style rule led to the changes? The rule was added separately in https://github.com/apache/spark/pull/12396 Author: Liwei Lin Closes #12389 from lw-lin/public-abstract-methods.
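As a hedged illustration of the convention this patch enforces (the `Service`, `LoggingService` and `ServiceDemo` names below are invented for the example and do not appear in Spark), an abstract member declared as `def start()` leaves the reader to infer the `Unit` result type, whereas the explicit form documents the contract at the declaration site:

```scala
// Illustrative sketch only -- not code from this patch.
// Public abstract members spell out their result types, including Unit.
trait Service {
  def start(): Unit       // explicit Unit: clearly a side-effecting method
  def stop(): Unit
  def status(): String    // a non-Unit abstract member must declare its type anyway
}

class LoggingService extends Service {
  override def start(): Unit = println("service started")
  override def stop(): Unit = println("service stopped")
  override def status(): String = "running"
}

object ServiceDemo {
  def main(args: Array[String]): Unit = {
    val s: Service = new LoggingService
    s.start()
    println(s.status())
    s.stop()
  }
}
```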
--- core/src/main/scala/org/apache/spark/ContextCleaner.scala | 10 +++++----- core/src/main/scala/org/apache/spark/FutureAction.scala | 4 ++-- .../org/apache/spark/deploy/client/AppClientListener.scala | 3 ++- .../org/apache/spark/deploy/master/LeaderElectionAgent.scala | 4 ++-- .../org/apache/spark/deploy/master/PersistenceEngine.scala | 4 ++-- .../scala/org/apache/spark/deploy/worker/DriverRunner.scala | 2 +- .../main/scala/org/apache/spark/executor/ExecutorBackend.scala | 2 +- .../scala/org/apache/spark/network/BlockTransferService.scala | 2 +- .../main/scala/org/apache/spark/scheduler/JobListener.scala | 4 ++-- .../scala/org/apache/spark/scheduler/SchedulableBuilder.scala | 4 ++-- .../main/scala/org/apache/spark/scheduler/TaskScheduler.scala | 2 +- .../scala/org/apache/spark/serializer/KryoSerializer.scala | 2 +- .../org/apache/spark/shuffle/FileShuffleBlockResolver.scala | 2 +- core/src/main/scala/org/apache/spark/ui/WebUI.scala | 2 +- .../scala/org/apache/spark/util/logging/RollingPolicy.scala | 4 ++-- .../main/scala/org/apache/spark/util/random/Pseudorandom.scala | 2 +- .../spark/sql/catalyst/expressions/SpecificMutableRow.scala | 2 +- .../scala/org/apache/spark/sql/catalyst/expressions/rows.scala | 2 +- .../apache/spark/sql/execution/columnar/ColumnAccessor.scala | 2 +- .../apache/spark/sql/execution/columnar/ColumnBuilder.scala | 4 ++-- .../spark/sql/execution/streaming/state/StateStore.scala | 2 +- .../org/apache/spark/sql/util/ContinuousQueryListener.scala | 6 +++--- .../org/apache/spark/streaming/dstream/InputDStream.scala | 4 ++-- .../org/apache/spark/streaming/receiver/BlockGenerator.scala | 8 ++++---- .../apache/spark/streaming/receiver/ReceivedBlockHandler.scala | 2 +- .../scala/org/apache/spark/streaming/receiver/Receiver.scala | 4 ++-- .../apache/spark/streaming/receiver/ReceiverSupervisor.scala | 10 +++++----- 27 files changed, 50 insertions(+), 49 deletions(-) (limited to 'sql/catalyst') diff --git a/core/src/main/scala/org/apache/spark/ContextCleaner.scala b/core/src/main/scala/org/apache/spark/ContextCleaner.scala index 8fc657c5eb..76692ccec8 100644 --- a/core/src/main/scala/org/apache/spark/ContextCleaner.scala +++ b/core/src/main/scala/org/apache/spark/ContextCleaner.scala @@ -278,9 +278,9 @@ private object ContextCleaner { * Listener class used for testing when any item has been cleaned by the Cleaner class. */ private[spark] trait CleanerListener { - def rddCleaned(rddId: Int) - def shuffleCleaned(shuffleId: Int) - def broadcastCleaned(broadcastId: Long) - def accumCleaned(accId: Long) - def checkpointCleaned(rddId: Long) + def rddCleaned(rddId: Int): Unit + def shuffleCleaned(shuffleId: Int): Unit + def broadcastCleaned(broadcastId: Long): Unit + def accumCleaned(accId: Long): Unit + def checkpointCleaned(rddId: Long): Unit } diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index ce11772a6d..339266a5d4 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -41,7 +41,7 @@ trait FutureAction[T] extends Future[T] { /** * Cancels the execution of this action. */ - def cancel() + def cancel(): Unit /** * Blocks until this action completes. @@ -65,7 +65,7 @@ trait FutureAction[T] extends Future[T] { * When this action is completed, either through an exception, or a value, applies the provided * function. 
*/ - def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext) + def onComplete[U](func: (Try[T]) => U)(implicit executor: ExecutionContext): Unit /** * Returns whether the action has already been completed with a value or an exception. diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala index e584952a9a..94506a0cbb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClientListener.scala @@ -33,7 +33,8 @@ private[spark] trait AppClientListener { /** An application death is an unrecoverable failure condition. */ def dead(reason: String): Unit - def executorAdded(fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int) + def executorAdded( + fullId: String, workerId: String, hostPort: String, cores: Int, memory: Int): Unit def executorRemoved(fullId: String, message: String, exitStatus: Option[Int]): Unit } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala index 70f21fbe0d..52e2854961 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/LeaderElectionAgent.scala @@ -32,8 +32,8 @@ trait LeaderElectionAgent { @DeveloperApi trait LeaderElectable { - def electedLeader() - def revokedLeadership() + def electedLeader(): Unit + def revokedLeadership(): Unit } /** Single-node implementation of LeaderElectionAgent -- we're initially and always the leader. */ diff --git a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala index dddf2be57e..b30bc821b7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/PersistenceEngine.scala @@ -40,12 +40,12 @@ abstract class PersistenceEngine { * Defines how the object is serialized and persisted. Implementation will * depend on the store used. */ - def persist(name: String, obj: Object) + def persist(name: String, obj: Object): Unit /** * Defines how the object referred by its name is removed from the store. */ - def unpersist(name: String) + def unpersist(name: String): Unit /** * Gives all objects, matching a prefix. 
This defines how objects are diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 9c6bc5c62f..aad2e91b25 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -218,7 +218,7 @@ private[deploy] class DriverRunner( } private[deploy] trait Sleeper { - def sleep(seconds: Int) + def sleep(seconds: Int): Unit } // Needed because ProcessBuilder is a final class and cannot be mocked diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala index e07cb31cbe..7153323d01 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorBackend.scala @@ -25,6 +25,6 @@ import org.apache.spark.TaskState.TaskState * A pluggable interface used by the Executor to send updates to the cluster scheduler. */ private[spark] trait ExecutorBackend { - def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) + def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer): Unit } diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala index e43e3a2de2..09ce012e4e 100644 --- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala +++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala @@ -36,7 +36,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo * Initialize the transfer service by giving it the BlockDataManager that can be used to fetch * local blocks or put local blocks. */ - def init(blockDataManager: BlockDataManager) + def init(blockDataManager: BlockDataManager): Unit /** * Tear down the transfer service. diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala index 50c2b9acd6..e0f7c8f021 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobListener.scala @@ -23,6 +23,6 @@ package org.apache.spark.scheduler * job fails (and no further taskSucceeded events will happen). 
*/ private[spark] trait JobListener { - def taskSucceeded(index: Int, result: Any) - def jobFailed(exception: Exception) + def taskSucceeded(index: Int, result: Any): Unit + def jobFailed(exception: Exception): Unit } diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 5baebe8c1f..100ed76ecb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -34,9 +34,9 @@ import org.apache.spark.util.Utils private[spark] trait SchedulableBuilder { def rootPool: Pool - def buildPools() + def buildPools(): Unit - def addTaskSetManager(manager: Schedulable, properties: Properties) + def addTaskSetManager(manager: Schedulable, properties: Properties): Unit } private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala index 8477a66b39..647d44a0f0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala @@ -51,7 +51,7 @@ private[spark] trait TaskScheduler { def submitTasks(taskSet: TaskSet): Unit // Cancel a stage. - def cancelTasks(stageId: Int, interruptThread: Boolean) + def cancelTasks(stageId: Int, interruptThread: Boolean): Unit // Set the DAG scheduler for upcalls. This is guaranteed to be set before submitTasks is called. def setDAGScheduler(dagScheduler: DAGScheduler): Unit diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 3d090a4353..918ae376f6 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -357,7 +357,7 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ * serialization. */ trait KryoRegistrator { - def registerClasses(kryo: Kryo) + def registerClasses(kryo: Kryo): Unit } private[serializer] object KryoSerializer { diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala index 6cd7d69518..be1e84a2ba 100644 --- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala +++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala @@ -35,7 +35,7 @@ private[spark] trait ShuffleWriterGroup { val writers: Array[DiskBlockObjectWriter] /** @param success Indicates all writes were successful. If false, no blocks will be recorded. */ - def releaseWriters(success: Boolean) + def releaseWriters(success: Boolean): Unit } /** diff --git a/core/src/main/scala/org/apache/spark/ui/WebUI.scala b/core/src/main/scala/org/apache/spark/ui/WebUI.scala index 3939b111b5..2b0bc32cf6 100644 --- a/core/src/main/scala/org/apache/spark/ui/WebUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/WebUI.scala @@ -129,7 +129,7 @@ private[spark] abstract class WebUI( } /** Initialize all components of the server. */ - def initialize() + def initialize(): Unit /** Bind to the HTTP server behind this web interface. 
*/ def bind() { diff --git a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala index b34880d3a7..6e80db2f51 100644 --- a/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala +++ b/core/src/main/scala/org/apache/spark/util/logging/RollingPolicy.scala @@ -32,10 +32,10 @@ private[spark] trait RollingPolicy { def shouldRollover(bytesToBeWritten: Long): Boolean /** Notify that rollover has occurred */ - def rolledOver() + def rolledOver(): Unit /** Notify that bytes have been written */ - def bytesWritten(bytes: Long) + def bytesWritten(bytes: Long): Unit /** Get the desired name of the rollover file */ def generateRolledOverFileSuffix(): String diff --git a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala index 70f3dd62b9..41f28f6e51 100644 --- a/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala +++ b/core/src/main/scala/org/apache/spark/util/random/Pseudorandom.scala @@ -26,5 +26,5 @@ import org.apache.spark.annotation.DeveloperApi @DeveloperApi trait Pseudorandom { /** Set random seed. */ - def setSeed(seed: Long) + def setSeed(seed: Long): Unit } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala index 4615c55d67..61ca7272df 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SpecificMutableRow.scala @@ -62,7 +62,7 @@ import org.apache.spark.sql.types._ abstract class MutableValue extends Serializable { var isNull: Boolean = true def boxed: Any - def update(v: Any) + def update(v: Any): Unit def copy(): MutableValue } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala index be6b2530ef..93a8278528 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/rows.scala @@ -164,7 +164,7 @@ trait BaseGenericInternalRow extends InternalRow { abstract class MutableRow extends InternalRow { def setNullAt(i: Int): Unit - def update(i: Int, value: Any) + def update(i: Int, value: Any): Unit // default implementation (slow) def setBoolean(i: Int, value: Boolean): Unit = { update(i, value) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala index 78664baa56..7cde04b626 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnAccessor.scala @@ -38,7 +38,7 @@ private[columnar] trait ColumnAccessor { def hasNext: Boolean - def extractTo(row: MutableRow, ordinal: Int) + def extractTo(row: MutableRow, ordinal: Int): Unit protected def underlyingBuffer: ByteBuffer } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala index 9a173367f4..d30655e0c4 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/ColumnBuilder.scala @@ -28,12 +28,12 @@ private[columnar] trait ColumnBuilder { /** * Initializes with an approximate lower bound on the expected number of elements in this column. */ - def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false) + def initialize(initialSize: Int, columnName: String = "", useCompression: Boolean = false): Unit /** * Appends `row(ordinal)` to the column builder. */ - def appendFrom(row: InternalRow, ordinal: Int) + def appendFrom(row: InternalRow, ordinal: Int): Unit /** * Column statistics information diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala index cc5327e0e2..9521506325 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/state/StateStore.scala @@ -50,7 +50,7 @@ trait StateStore { def get(key: UnsafeRow): Option[UnsafeRow] /** Put a new value for a key. */ - def put(key: UnsafeRow, value: UnsafeRow) + def put(key: UnsafeRow, value: UnsafeRow): Unit /** * Remove keys that match the following condition. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala index bf78be9d9f..ba1facf11b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/ContinuousQueryListener.scala @@ -37,7 +37,7 @@ abstract class ContinuousQueryListener { * `DataFrameWriter.startStream()` returns the corresponding [[ContinuousQuery]]. Please * don't block this method as it will block your query. */ - def onQueryStarted(queryStarted: QueryStarted) + def onQueryStarted(queryStarted: QueryStarted): Unit /** * Called when there is some status update (ingestion rate updated, etc.) @@ -47,10 +47,10 @@ abstract class ContinuousQueryListener { * may be changed before/when you process the event. E.g., you may find [[ContinuousQuery]] * is terminated when you are processing [[QueryProgress]]. */ - def onQueryProgress(queryProgress: QueryProgress) + def onQueryProgress(queryProgress: QueryProgress): Unit /** Called when a query is stopped, with or without error */ - def onQueryTerminated(queryTerminated: QueryTerminated) + def onQueryTerminated(queryTerminated: QueryTerminated): Unit } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala index dc88349db5..a3c125c306 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/InputDStream.scala @@ -107,8 +107,8 @@ abstract class InputDStream[T: ClassTag](_ssc: StreamingContext) } /** Method called to start receiving data. Subclasses must implement this method. */ - def start() + def start(): Unit /** Method called to stop receiving data. Subclasses must implement this method. 
*/ - def stop() + def stop(): Unit } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala index e42bea6ec6..4592e015ed 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/BlockGenerator.scala @@ -37,7 +37,7 @@ private[streaming] trait BlockGeneratorListener { * that will be useful when a block is generated. Any long blocking operation in this callback * will hurt the throughput. */ - def onAddData(data: Any, metadata: Any) + def onAddData(data: Any, metadata: Any): Unit /** * Called when a new block of data is generated by the block generator. The block generation @@ -47,7 +47,7 @@ private[streaming] trait BlockGeneratorListener { * be useful when the block has been successfully stored. Any long blocking operation in this * callback will hurt the throughput. */ - def onGenerateBlock(blockId: StreamBlockId) + def onGenerateBlock(blockId: StreamBlockId): Unit /** * Called when a new block is ready to be pushed. Callers are supposed to store the block into @@ -55,13 +55,13 @@ private[streaming] trait BlockGeneratorListener { * thread, that is not synchronized with any other callbacks. Hence it is okay to do long * blocking operation in this callback. */ - def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]) + def onPushBlock(blockId: StreamBlockId, arrayBuffer: ArrayBuffer[_]): Unit /** * Called when an error has occurred in the BlockGenerator. Can be called form many places * so better to not do any long block operation in this callback. */ - def onError(message: String, throwable: Throwable) + def onError(message: String, throwable: Throwable): Unit } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala index 85350ff658..7aea1c9b64 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceivedBlockHandler.scala @@ -48,7 +48,7 @@ private[streaming] trait ReceivedBlockHandler { def storeBlock(blockId: StreamBlockId, receivedBlock: ReceivedBlock): ReceivedBlockStoreResult /** Cleanup old blocks older than the given threshold time */ - def cleanupOldBlocks(threshTime: Long) + def cleanupOldBlocks(threshTime: Long): Unit } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala index 3376cd557d..5157ca62dc 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/Receiver.scala @@ -99,13 +99,13 @@ abstract class Receiver[T](val storageLevel: StorageLevel) extends Serializable * (iii) `restart(...)` can be called to restart the receiver. This will call `onStop()` * immediately, and then `onStart()` after a delay. */ - def onStart() + def onStart(): Unit /** * This method is called by the system when the receiver is stopped. All resources * (threads, buffers, etc.) set up in `onStart()` must be cleaned up in this method. */ - def onStop() + def onStop(): Unit /** Override this to specify a preferred location (hostname). 
*/ def preferredLocation: Option[String] = None diff --git a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala index e0fe8d2206..42fc84c19b 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/receiver/ReceiverSupervisor.scala @@ -70,28 +70,28 @@ private[streaming] abstract class ReceiverSupervisor( @volatile private[streaming] var receiverState = Initialized /** Push a single data item to backend data store. */ - def pushSingle(data: Any) + def pushSingle(data: Any): Unit /** Store the bytes of received data as a data block into Spark's memory. */ def pushBytes( bytes: ByteBuffer, optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] - ) + ): Unit /** Store a iterator of received data as a data block into Spark's memory. */ def pushIterator( iterator: Iterator[_], optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] - ) + ): Unit /** Store an ArrayBuffer of received data as a data block into Spark's memory. */ def pushArrayBuffer( arrayBuffer: ArrayBuffer[_], optionalMetadata: Option[Any], optionalBlockId: Option[StreamBlockId] - ) + ): Unit /** * Create a custom [[BlockGenerator]] that the receiver implementation can directly control @@ -103,7 +103,7 @@ private[streaming] abstract class ReceiverSupervisor( def createBlockGenerator(blockGeneratorListener: BlockGeneratorListener): BlockGenerator /** Report errors. */ - def reportError(message: String, throwable: Throwable) + def reportError(message: String, throwable: Throwable): Unit /** * Called when supervisor is started. -- cgit v1.2.3 From 28efdd3fd789fa2ebed5be03b36ca0f682e37669 Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 14 Apr 2016 11:08:08 -0700 Subject: [SPARK-14592][SQL] Native support for CREATE TABLE LIKE DDL command ## What changes were proposed in this pull request? JIRA: https://issues.apache.org/jira/browse/SPARK-14592 This patch adds native support for DDL command `CREATE TABLE LIKE`. The SQL syntax is like: CREATE TABLE table_name LIKE existing_table CREATE TABLE IF NOT EXISTS table_name LIKE existing_table ## How was this patch tested? `HiveDDLCommandSuite`. `HiveQuerySuite` already tests `CREATE TABLE LIKE`. Author: Liang-Chi Hsieh This patch had conflicts when merged, resolved by Committer: Andrew Or Closes #12362 from viirya/create-table-like. --- .../apache/spark/sql/catalyst/parser/SqlBase.g4 | 7 ++-- .../spark/sql/execution/command/tables.scala | 40 ++++++++++++++++++++-- .../hive/execution/HiveCompatibilitySuite.scala | 4 ++- .../spark/sql/hive/execution/HiveSqlParser.scala | 13 ++++++- .../spark/sql/hive/HiveDDLCommandSuite.scala | 24 ++++++++++++- 5 files changed, 79 insertions(+), 9 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index a937ad1eb7..9cf2dd257e 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -55,6 +55,8 @@ statement rowFormat? createFileFormat? locationSpec? (TBLPROPERTIES tablePropertyList)? (AS? query)? #createTable + | CREATE TABLE (IF NOT EXISTS)? 
target=tableIdentifier + LIKE source=tableIdentifier #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq?)? #analyze | ALTER (TABLE | VIEW) from=tableIdentifier @@ -136,10 +138,7 @@ statement ; hiveNativeCommands - : createTableHeader LIKE tableIdentifier - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? - | DELETE FROM tableIdentifier (WHERE booleanExpression)? + : DELETE FROM tableIdentifier (WHERE booleanExpression)? | TRUNCATE TABLE tableIdentifier partitionSpec? (COLUMNS identifierList)? | SHOW COLUMNS (FROM | IN) tableIdentifier ((FROM|IN) identifier)? diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index e315598daa..0b41985174 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -17,9 +17,45 @@ package org.apache.spark.sql.execution.command -import org.apache.spark.sql.{Row, SQLContext} +import org.apache.spark.sql.{AnalysisException, Row, SQLContext} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogTableType} + +/** + * A command to create a table with the same definition of the given existing table. + * + * The syntax of using this command in SQL is: + * {{{ + * CREATE TABLE [IF NOT EXISTS] [db_name.]table_name + * LIKE [other_db_name.]existing_table_name + * }}} + */ +case class CreateTableLike( + targetTable: TableIdentifier, + sourceTable: TableIdentifier, + ifNotExists: Boolean) extends RunnableCommand { + + override def run(sqlContext: SQLContext): Seq[Row] = { + val catalog = sqlContext.sessionState.catalog + if (!catalog.tableExists(sourceTable)) { + throw new AnalysisException( + s"Source table in CREATE TABLE LIKE does not exist: '$sourceTable'") + } + if (catalog.isTemporaryTable(sourceTable)) { + throw new AnalysisException( + s"Source table in CREATE TABLE LIKE cannot be temporary: '$sourceTable'") + } + + val tableToCreate = catalog.getTableMetadata(sourceTable).copy( + identifier = targetTable, + tableType = CatalogTableType.MANAGED_TABLE, + createTime = System.currentTimeMillis, + lastAccessTime = -1).withNewStorage(locationUri = None) + + catalog.createTable(tableToCreate, ifNotExists) + Seq.empty[Row] + } +} // TODO: move the rest of the table commands from ddl.scala to this file diff --git a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala index a45d180464..989e68aebe 100644 --- a/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala +++ b/sql/hive/compatibility/src/test/scala/org/apache/spark/sql/hive/execution/HiveCompatibilitySuite.scala @@ -416,6 +416,9 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "skewjoinopt18", "skewjoinopt9", + // This test tries to create a table like with TBLPROPERTIES clause, which we don't support. 
+ "create_like_tbl_props", + // Index commands are not supported "drop_index", "drop_index_removes_partition_dirs", @@ -537,7 +540,6 @@ class HiveCompatibilitySuite extends HiveQueryFileTest with BeforeAndAfter { "count", "cp_mj_rc", "create_insert_outputformat", - "create_like_tbl_props", "create_nested_type", "create_struct_table", "create_view_translate", diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala index 8c707079a1..a97b65e27b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveSqlParser.scala @@ -31,8 +31,10 @@ import org.apache.spark.sql.catalyst.parser._ import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkSqlAstBuilder -import org.apache.spark.sql.execution.command.CreateTable +import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike} import org.apache.spark.sql.hive.{CreateTableAsSelect => CTAS, CreateViewAsSelect => CreateView, HiveSerDe} +import org.apache.spark.sql.hive.{HiveGenericUDTF, HiveMetastoreTypes, HiveSerDe} +import org.apache.spark.sql.hive.HiveShim.HiveFunctionWrapper /** * Concrete parser for HiveQl statements. @@ -231,6 +233,15 @@ class HiveSqlAstBuilder extends SparkSqlAstBuilder { } } + /** + * Create a [[CreateTableLike]] command. + */ + override def visitCreateTableLike(ctx: CreateTableLikeContext): LogicalPlan = withOrigin(ctx) { + val targetTable = visitTableIdentifier(ctx.target) + val sourceTable = visitTableIdentifier(ctx.source) + CreateTableLike(targetTable, sourceTable, ctx.EXISTS != null) + } + /** * Create or replace a view. This creates a [[CreateViewAsSelect]] command. 
* diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala index b87f035cd1..110c6d19d8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveDDLCommandSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Generate, ScriptTransformation} -import org.apache.spark.sql.execution.command.CreateTable +import org.apache.spark.sql.execution.command.{CreateTable, CreateTableLike} import org.apache.spark.sql.hive.execution.{HiveNativeCommand, HiveSqlParser} class HiveDDLCommandSuite extends PlanTest { @@ -557,4 +557,26 @@ class HiveDDLCommandSuite extends PlanTest { assertUnsupported("MSCK REPAIR TABLE tab1") } + test("create table like") { + val v1 = "CREATE TABLE table1 LIKE table2" + val (target, source, exists) = parser.parsePlan(v1).collect { + case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + }.head + assert(exists == false) + assert(target.database.isEmpty) + assert(target.table == "table1") + assert(source.database.isEmpty) + assert(source.table == "table2") + + val v2 = "CREATE TABLE IF NOT EXISTS table1 LIKE table2" + val (target2, source2, exists2) = parser.parsePlan(v2).collect { + case CreateTableLike(t, s, allowExisting) => (t, s, allowExisting) + }.head + assert(exists2) + assert(target2.database.isEmpty) + assert(target2.table == "table1") + assert(source2.database.isEmpty) + assert(source2.table == "table2") + } + } -- cgit v1.2.3 From d7e124edfe2578ecdf8e816a4dda3ce430a09172 Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Thu, 14 Apr 2016 13:34:29 -0700 Subject: [SPARK-14545][SQL] Improve `LikeSimplification` by adding `a%b` rule ## What changes were proposed in this pull request? Current `LikeSimplification` handles the following four rules. - 'a%' => expr.StartsWith("a") - '%b' => expr.EndsWith("b") - '%a%' => expr.Contains("a") - 'a' => EqualTo("a") This PR adds the following rule. - 'a%b' => expr.Length() >= 2 && expr.StartsWith("a") && expr.EndsWith("b") Here, 2 is statically calculated from "a".size + "b".size. **Before** ``` scala> sql("select a from (select explode(array('abc','adc')) a) T where a like 'a%c'").explain() == Physical Plan == WholeStageCodegen : +- Filter a#5 LIKE a%c : +- INPUT +- Generate explode([abc,adc]), false, false, [a#5] +- Scan OneRowRelation[] ``` **After** ``` scala> sql("select a from (select explode(array('abc','adc')) a) T where a like 'a%c'").explain() == Physical Plan == WholeStageCodegen : +- Filter ((length(a#5) >= 2) && (StartsWith(a#5, a) && EndsWith(a#5, c))) : +- INPUT +- Generate explode([abc,adc]), false, false, [a#5] +- Scan OneRowRelation[] ``` ## How was this patch tested? Pass the Jenkins tests (including new testcase). Author: Dongjoon Hyun Closes #12312 from dongjoon-hyun/SPARK-14545. 
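As an illustration of the rule set described above (not part of the patch; the object and method names below are invented for the example), the same regular expressions the optimizer rule relies on can be exercised in plain Scala to see which rewrite each pattern shape would receive:

```scala
object LikePatternShapes {
  // Same regexes as in LikeSimplification; Scala's Regex extractor matches the whole string.
  private val startsWith = "([^_%]+)%".r
  private val endsWith = "%([^_%]+)".r
  private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r
  private val contains = "%([^_%]+)%".r
  private val equalTo = "([^_%]*)".r

  /** Returns a description of the rewrite the rule would pick for a LIKE pattern. */
  def classify(pattern: String): String = pattern match {
    case startsWith(prefix) if !prefix.endsWith("\\") => s"StartsWith($prefix)"
    case endsWith(postfix) => s"EndsWith($postfix)"
    case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") =>
      s"Length >= ${prefix.length + postfix.length} && StartsWith($prefix) && EndsWith($postfix)"
    case contains(infix) if !infix.endsWith("\\") => s"Contains($infix)"
    case equalTo(str) => s"EqualTo($str)"
    case _ => "Like (unchanged)"
  }

  def main(args: Array[String]): Unit = {
    Seq("a%", "%b", "%a%", "a%b", "a", "a%b%c").foreach { p =>
      println(s"'$p' -> ${classify(p)}")
    }
  }
}
```

Running it prints, for instance, `'a%b' -> Length >= 2 && StartsWith(a) && EndsWith(b)`, while a pattern with two wildcards such as `'a%b%c'` matches none of the shapes and is left as a `Like`.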
--- .../spark/sql/catalyst/optimizer/Optimizer.scala | 28 +++++++++++++--------- .../optimizer/LikeSimplificationSuite.scala | 14 +++++++++++ 2 files changed, 31 insertions(+), 11 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index aeb1842677..f5172b213a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -517,22 +517,28 @@ object LikeSimplification extends Rule[LogicalPlan] { // Cases like "something\%" are not optimized, but this does not affect correctness. private val startsWith = "([^_%]+)%".r private val endsWith = "%([^_%]+)".r + private val startsAndEndsWith = "([^_%]+)%([^_%]+)".r private val contains = "%([^_%]+)%".r private val equalTo = "([^_%]*)".r def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case Like(l, Literal(utf, StringType)) => - utf.toString match { - case startsWith(pattern) if !pattern.endsWith("\\") => - StartsWith(l, Literal(pattern)) - case endsWith(pattern) => - EndsWith(l, Literal(pattern)) - case contains(pattern) if !pattern.endsWith("\\") => - Contains(l, Literal(pattern)) - case equalTo(pattern) => - EqualTo(l, Literal(pattern)) + case Like(input, Literal(pattern, StringType)) => + pattern.toString match { + case startsWith(prefix) if !prefix.endsWith("\\") => + StartsWith(input, Literal(prefix)) + case endsWith(postfix) => + EndsWith(input, Literal(postfix)) + // 'a%a' pattern is basically same with 'a%' && '%a'. + // However, the additional `Length` condition is required to prevent 'a' match 'a%a'. + case startsAndEndsWith(prefix, postfix) if !prefix.endsWith("\\") => + And(GreaterThanOrEqual(Length(input), Literal(prefix.size + postfix.size)), + And(StartsWith(input, Literal(prefix)), EndsWith(input, Literal(postfix)))) + case contains(infix) if !infix.endsWith("\\") => + Contains(input, Literal(infix)) + case equalTo(str) => + EqualTo(input, Literal(str)) case _ => - Like(l, Literal.create(utf, StringType)) + Like(input, Literal.create(pattern, StringType)) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala index 741bc113cf..fdde89d079 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/LikeSimplificationSuite.scala @@ -61,6 +61,20 @@ class LikeSimplificationSuite extends PlanTest { comparePlans(optimized, correctAnswer) } + test("simplify Like into startsWith and EndsWith") { + val originalQuery = + testRelation + .where(('a like "abc\\%def") || ('a like "abc%def")) + + val optimized = Optimize.execute(originalQuery.analyze) + val correctAnswer = testRelation + .where(('a like "abc\\%def") || + (Length('a) >= 6 && (StartsWith('a, "abc") && EndsWith('a, "def")))) + .analyze + + comparePlans(optimized, correctAnswer) + } + test("simplify Like into Contains") { val originalQuery = testRelation -- cgit v1.2.3
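A note on the `Length` guard added in the `startsAndEndsWith` branch above: without it, the rewrite `StartsWith && EndsWith` alone would wrongly accept strings in which the prefix and the postfix overlap. The snippet below (plain Scala, names invented for the illustration, not part of the patch) checks that reasoning for the pattern 'a%a':

```scala
object LengthGuardCheck {
  // Candidate strings are tested against the LIKE pattern 'a%a' (prefix "a", postfix "a").
  val prefix = "a"
  val postfix = "a"

  def withoutGuard(s: String): Boolean =
    s.startsWith(prefix) && s.endsWith(postfix)

  def withGuard(s: String): Boolean =
    s.length >= prefix.length + postfix.length && withoutGuard(s)

  def main(args: Array[String]): Unit = {
    // "a" must NOT match 'a%a', but it passes the unguarded rewrite.
    println(withoutGuard("a")) // true  -> wrong answer without the guard
    println(withGuard("a"))    // false -> correct
    // "aa" and "aba" genuinely match 'a%a' (the % matches zero or more characters).
    println(withGuard("aa"))   // true
    println(withGuard("aba"))  // true
  }
}
```

`"a"` satisfies `startsWith("a") && endsWith("a")` yet is too short to match `'a%a'`, so the unguarded rewrite returns true where LIKE would return false; requiring `length >= 2` restores the correct result while still accepting genuine matches such as `"aa"` and `"aba"`.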