aboutsummaryrefslogtreecommitdiff
path: root/sql/hive
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2016-02-22 22:17:56 -0800
committerReynold Xin <rxin@databricks.com>2016-02-22 22:17:56 -0800
commit9dd5399d78d74a8ba2326db25704ba7cb7aa353d (patch)
tree5aa9be89454d316b9142f497b792e325300e0a70 /sql/hive
parent5d80fac58f837933b5359a8057676f45539e53af (diff)
downloadspark-9dd5399d78d74a8ba2326db25704ba7cb7aa353d.tar.gz
spark-9dd5399d78d74a8ba2326db25704ba7cb7aa353d.tar.bz2
spark-9dd5399d78d74a8ba2326db25704ba7cb7aa353d.zip
[SPARK-12723][SQL] Comprehensive Verification and Fixing of SQL Generation Support for Expressions
#### What changes were proposed in this pull request? Ensure that all built-in expressions can be mapped to their SQL representations if they have one (e.g. ScalaUDF doesn't have a SQL representation). The function lists are taken from the expression list in `FunctionRegistry`. Window functions, grouping sets functions (`cube`, `rollup`, `grouping`, `grouping_id`), and generator functions (`explode` and `json_tuple`) are covered by separate JIRAs and PRs. Thus, this PR does not cover them. Except for these functions, all the built-in expressions are covered. For details, see the list in `ExpressionToSQLSuite`. Fixed a few issues. For example, the `prettyName` of `approx_count_distinct` was not right. The `sql` of the `hash` function was not right, since the `hash` function does not accept a `seed`. Additionally, also corrected the order of expressions in `FunctionRegistry` so that it is easier to find which functions are missing. cc liancheng #### How was this patch tested? Added two test cases in LogicalPlanToSQLSuite for covering `not like` and `not in`. Added a new test suite `ExpressionToSQLSuite` to cover the functions: 1. misc non-aggregate functions + complex type creators + null expressions 2. math functions 3. aggregate functions 4. string functions 5. date time functions + calendar interval 6. collection functions 7. misc functions Author: gatorsmile <gatorsmile@gmail.com> Closes #11314 from gatorsmile/expressionToSQL.
Diffstat (limited to 'sql/hive')
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala271
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala8
2 files changed, 279 insertions, 0 deletions
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
new file mode 100644
index 0000000000..d68c602a88
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestUtils
+
+/**
+ * Round-trip tests for the SQL generation of built-in expressions.
+ *
+ * Every query handed to `checkSqlGeneration` is run through `sql`, converted back into a SQL
+ * string with `SQLBuilder`, executed again, and the two answers are compared.
+ */
+class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
+  import testImplicits._
+
+  // Fixture tables created in beforeAll() and removed again in afterAll().
+  private val testTables = Seq("t0", "t1", "t2")
+
+  /** Drops every fixture table; `IF EXISTS` makes this safe when a table was never created. */
+  private def dropTestTables(): Unit = {
+    testTables.foreach(table => sql(s"DROP TABLE IF EXISTS $table"))
+  }
+
+  override protected def beforeAll(): Unit = {
+    // Give the parent traits a chance to run their own setup before the fixtures are built.
+    super.beforeAll()
+    dropTestTables()
+
+    // t0: one row with a binary column `a` and a string column `b`.
+    val bytes = Array[Byte](1, 2, 3, 4)
+    Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0")
+
+    // t1: ten rows of (key, "val_<key>").
+    sqlContext
+      .range(10)
+      .select('id as 'key, concat(lit("val_"), 'id) as 'value)
+      .write
+      .saveAsTable("t1")
+
+    // t2: ten rows with four identical numeric columns a, b, c and d.
+    sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
+  }
+
+  override protected def afterAll(): Unit = {
+    // Always hand control to the parent teardown, even if dropping a table fails.
+    try {
+      dropTestTables()
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  /**
+   * Checks that `hiveQl` survives a round trip through `SQLBuilder`.
+   *
+   * The test fails if the query plan cannot be converted back into a SQL string, or if the
+   * converted SQL cannot be executed or yields a different answer than the original query.
+   * In both cases the underlying exception is attached as the cause of the failure.
+   *
+   * @param hiveQl the query whose expressions are to be converted back to SQL
+   */
+  private def checkSqlGeneration(hiveQl: String): Unit = {
+    val df = sql(hiveQl)
+
+    val convertedSQL = try new SQLBuilder(df).toSQL catch {
+      case NonFatal(e) =>
+        // Keep the original exception as the cause so that its stack trace is reported.
+        fail(
+          s"""Cannot convert the following HiveQL query plan back to SQL query string:
+             |
+             |# Original HiveQL query string:
+             |$hiveQl
+             |
+             |# Resolved query plan:
+             |${df.queryExecution.analyzed.treeString}
+           """.stripMargin,
+          e)
+    }
+
+    try {
+      checkAnswer(sql(convertedSQL), df)
+    } catch {
+      // NonFatal, like the handler above: fatal JVM errors must propagate untouched.
+      case NonFatal(cause) =>
+        fail(
+          s"""Failed to execute converted SQL string or got wrong answer:
+             |
+             |# Converted SQL query string:
+             |$convertedSQL
+             |
+             |# Original HiveQL query string:
+             |$hiveQl
+             |
+             |# Resolved query plan:
+             |${df.queryExecution.analyzed.treeString}
+           """.stripMargin,
+          cause)
+    }
+  }
+
+  test("misc non-aggregate functions") {
+    checkSqlGeneration("SELECT abs(15), abs(-15)")
+    checkSqlGeneration("SELECT array(1,2,3)")
+    checkSqlGeneration("SELECT coalesce(null, 1, 2)")
+    // Disabled until SPARK-12719 (SQL generation for generators) is resolved.
+    // checkSqlGeneration("SELECT explode(array(1,2,3))")
+    checkSqlGeneration("SELECT greatest(1,null,3)")
+    checkSqlGeneration("SELECT if(1==2, 'yes', 'no')")
+    checkSqlGeneration("SELECT isnan(15), isnan('invalid')")
+    checkSqlGeneration("SELECT isnull(null), isnull('a')")
+    checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')")
+    checkSqlGeneration("SELECT least(1,null,3)")
+    checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
+    checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
+    checkSqlGeneration("SELECT nvl(null, 1, 2)")
+    checkSqlGeneration("SELECT rand(1)")
+    checkSqlGeneration("SELECT randn(3)")
+    checkSqlGeneration("SELECT struct(1,2,3)")
+  }
+
+  test("math functions") {
+    checkSqlGeneration("SELECT acos(-1)")
+    checkSqlGeneration("SELECT asin(-1)")
+    checkSqlGeneration("SELECT atan(1)")
+    checkSqlGeneration("SELECT atan2(1, 1)")
+    checkSqlGeneration("SELECT bin(10)")
+    checkSqlGeneration("SELECT cbrt(1000.0)")
+    checkSqlGeneration("SELECT ceil(2.333)")
+    checkSqlGeneration("SELECT ceiling(2.333)")
+    checkSqlGeneration("SELECT cos(1.0)")
+    checkSqlGeneration("SELECT cosh(1.0)")
+    checkSqlGeneration("SELECT conv(15, 10, 16)")
+    checkSqlGeneration("SELECT degrees(pi())")
+    checkSqlGeneration("SELECT e()")
+    checkSqlGeneration("SELECT exp(1.0)")
+    checkSqlGeneration("SELECT expm1(1.0)")
+    checkSqlGeneration("SELECT floor(-2.333)")
+    checkSqlGeneration("SELECT factorial(5)")
+    checkSqlGeneration("SELECT hex(10)")
+    checkSqlGeneration("SELECT hypot(3, 4)")
+    checkSqlGeneration("SELECT log(10.0)")
+    checkSqlGeneration("SELECT log10(1000.0)")
+    checkSqlGeneration("SELECT log1p(0.0)")
+    checkSqlGeneration("SELECT log2(8.0)")
+    checkSqlGeneration("SELECT ln(10.0)")
+    checkSqlGeneration("SELECT negative(-1)")
+    checkSqlGeneration("SELECT pi()")
+    checkSqlGeneration("SELECT pmod(3, 2)")
+    checkSqlGeneration("SELECT positive(3)")
+    checkSqlGeneration("SELECT pow(2, 3)")
+    checkSqlGeneration("SELECT power(2, 3)")
+    checkSqlGeneration("SELECT radians(180.0)")
+    checkSqlGeneration("SELECT rint(1.63)")
+    checkSqlGeneration("SELECT round(31.415, -1)")
+    checkSqlGeneration("SELECT shiftleft(2, 3)")
+    checkSqlGeneration("SELECT shiftright(16, 3)")
+    checkSqlGeneration("SELECT shiftrightunsigned(16, 3)")
+    checkSqlGeneration("SELECT sign(-2.63)")
+    checkSqlGeneration("SELECT signum(-2.63)")
+    checkSqlGeneration("SELECT sin(1.0)")
+    checkSqlGeneration("SELECT sinh(1.0)")
+    checkSqlGeneration("SELECT sqrt(100.0)")
+    checkSqlGeneration("SELECT tan(1.0)")
+    checkSqlGeneration("SELECT tanh(1.0)")
+  }
+
+  test("aggregate functions") {
+    checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT covar_pop(value, key) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT covar_samp(value, key) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT first(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT first_value(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT kurtosis(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT last(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT last_value(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT stddev_samp(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT sum(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT variance(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT var_pop(value) FROM t1 GROUP BY key")
+    checkSqlGeneration("SELECT var_samp(value) FROM t1 GROUP BY key")
+  }
+
+  test("string functions") {
+    checkSqlGeneration("SELECT ascii('SparkSql')")
+    checkSqlGeneration("SELECT base64(a) FROM t0")
+    checkSqlGeneration("SELECT concat('This ', 'is ', 'a ', 'test')")
+    checkSqlGeneration("SELECT concat_ws(' ', 'This', 'is', 'a', 'test')")
+    checkSqlGeneration("SELECT decode(a, 'UTF-8') FROM t0")
+    checkSqlGeneration("SELECT encode('SparkSql', 'UTF-8')")
+    checkSqlGeneration("SELECT find_in_set('ab', 'abc,b,ab,c,def')")
+    checkSqlGeneration("SELECT format_number(1234567.890, 2)")
+    checkSqlGeneration("SELECT format_string('aa%d%s',123, 'cc')")
+    checkSqlGeneration("SELECT get_json_object('{\"a\":\"bc\"}','$.a')")
+    checkSqlGeneration("SELECT initcap('This is a test')")
+    checkSqlGeneration("SELECT instr('This is a test', 'is')")
+    checkSqlGeneration("SELECT lcase('SparkSql')")
+    checkSqlGeneration("SELECT length('This is a test')")
+    checkSqlGeneration("SELECT levenshtein('This is a test', 'Another test')")
+    checkSqlGeneration("SELECT lower('SparkSql')")
+    checkSqlGeneration("SELECT locate('is', 'This is a test', 3)")
+    checkSqlGeneration("SELECT lpad('SparkSql', 16, 'Learning')")
+    checkSqlGeneration("SELECT ltrim(' SparkSql ')")
+    // Disabled until SPARK-12719 (SQL generation for generators) is resolved.
+    // checkSqlGeneration("SELECT json_tuple('{\"f1\": \"value1\", \"f2\": \"value2\"}','f1')")
+    checkSqlGeneration("SELECT printf('aa%d%s', 123, 'cc')")
+    checkSqlGeneration("SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1)")
+    checkSqlGeneration("SELECT regexp_replace('100-200', '(\\d+)', 'num')")
+    checkSqlGeneration("SELECT repeat('SparkSql', 3)")
+    checkSqlGeneration("SELECT reverse('SparkSql')")
+    checkSqlGeneration("SELECT rpad('SparkSql', 16, ' is Cool')")
+    checkSqlGeneration("SELECT rtrim(' SparkSql ')")
+    checkSqlGeneration("SELECT soundex('SparkSql')")
+    checkSqlGeneration("SELECT space(2)")
+    checkSqlGeneration("SELECT split('aa2bb3cc', '[1-9]+')")
+    // substr/substring take a numeric start position; a string such as 'is' is not a position
+    // and would make these checks compare nothing but NULL results.
+    checkSqlGeneration("SELECT substr('This is a test', 6)")
+    checkSqlGeneration("SELECT substring('This is a test', 6)")
+    checkSqlGeneration("SELECT substring_index('www.apache.org','.',1)")
+    checkSqlGeneration("SELECT translate('translate', 'rnlt', '123')")
+    checkSqlGeneration("SELECT trim(' SparkSql ')")
+    checkSqlGeneration("SELECT ucase('SparkSql')")
+    checkSqlGeneration("SELECT unbase64('SparkSql')")
+    checkSqlGeneration("SELECT unhex(41)")
+    checkSqlGeneration("SELECT upper('SparkSql')")
+  }
+
+  test("datetime functions") {
+    checkSqlGeneration("SELECT add_months('2001-03-31', 1)")
+    // NOTE(review): the current-time functions are presumably wrapped in count() so that the
+    // compared answers do not depend on when each query runs -- confirm.
+    checkSqlGeneration("SELECT count(current_date())")
+    checkSqlGeneration("SELECT count(current_timestamp())")
+    checkSqlGeneration("SELECT datediff('2001-01-02', '2001-01-01')")
+    checkSqlGeneration("SELECT date_add('2001-01-02', 1)")
+    checkSqlGeneration("SELECT date_format('2001-05-02', 'yyyy-dd')")
+    checkSqlGeneration("SELECT date_sub('2001-01-02', 1)")
+    checkSqlGeneration("SELECT day('2001-05-02')")
+    checkSqlGeneration("SELECT dayofyear('2001-05-02')")
+    checkSqlGeneration("SELECT dayofmonth('2001-05-02')")
+    checkSqlGeneration("SELECT from_unixtime(1000, 'yyyy-MM-dd HH:mm:ss')")
+    checkSqlGeneration("SELECT from_utc_timestamp('2015-07-24 00:00:00', 'PST')")
+    checkSqlGeneration("SELECT hour('11:35:55')")
+    checkSqlGeneration("SELECT last_day('2001-01-01')")
+    checkSqlGeneration("SELECT minute('11:35:55')")
+    checkSqlGeneration("SELECT month('2001-05-02')")
+    checkSqlGeneration("SELECT months_between('2001-10-30 10:30:00', '1996-10-30')")
+    checkSqlGeneration("SELECT next_day('2001-05-02', 'TU')")
+    checkSqlGeneration("SELECT count(now())")
+    checkSqlGeneration("SELECT quarter('2001-05-02')")
+    checkSqlGeneration("SELECT second('11:35:55')")
+    checkSqlGeneration("SELECT to_date('2001-10-30 10:30:00')")
+    checkSqlGeneration("SELECT to_unix_timestamp('2015-07-24 00:00:00', 'yyyy-MM-dd HH:mm:ss')")
+    checkSqlGeneration("SELECT to_utc_timestamp('2015-07-24 00:00:00', 'PST')")
+    checkSqlGeneration("SELECT trunc('2001-10-30 10:30:00', 'YEAR')")
+    checkSqlGeneration("SELECT unix_timestamp('2001-10-30 10:30:00')")
+    checkSqlGeneration("SELECT weekofyear('2001-05-02')")
+    checkSqlGeneration("SELECT year('2001-05-02')")
+
+    // Calendar interval literal.
+    checkSqlGeneration("SELECT interval 3 years - 3 month 7 week 123 microseconds as i")
+  }
+
+  test("collection functions") {
+    checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)")
+    checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))")
+    checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))")
+  }
+
+  test("misc functions") {
+    checkSqlGeneration("SELECT crc32('Spark')")
+    checkSqlGeneration("SELECT md5('Spark')")
+    checkSqlGeneration("SELECT hash('Spark')")
+    checkSqlGeneration("SELECT sha('Spark')")
+    checkSqlGeneration("SELECT sha1('Spark')")
+    checkSqlGeneration("SELECT sha2('Spark', 0)")
+    checkSqlGeneration("SELECT spark_partition_id()")
+    checkSqlGeneration("SELECT input_file_name()")
+    checkSqlGeneration("SELECT monotonically_increasing_id()")
+  }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index 5255b150aa..b162adf215 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -86,6 +86,14 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)")
}
+  test("not in") {
+    // Negated IN predicate over a list of integer literals.
+    val notInQuery = "SELECT id FROM t0 WHERE id NOT IN (1, 2, 3)"
+    checkHiveQl(notInQuery)
+  }
+
+  test("not like") {
+    // Negated LIKE applied to an arithmetic expression rather than a plain column.
+    val notLikeQuery = "SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%'"
+    checkHiveQl(notLikeQuery)
+  }
+
test("aggregate function in having clause") {
  // The HAVING clause uses an aggregate (MAX) that differs from the one in the SELECT list.
  val havingQuery = "SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0"
  checkHiveQl(havingQuery)
}