diff options
author | gatorsmile <gatorsmile@gmail.com> | 2016-02-22 22:17:56 -0800 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-02-22 22:17:56 -0800 |
commit | 9dd5399d78d74a8ba2326db25704ba7cb7aa353d (patch) | |
tree | 5aa9be89454d316b9142f497b792e325300e0a70 | |
parent | 5d80fac58f837933b5359a8057676f45539e53af (diff) | |
download | spark-9dd5399d78d74a8ba2326db25704ba7cb7aa353d.tar.gz spark-9dd5399d78d74a8ba2326db25704ba7cb7aa353d.tar.bz2 spark-9dd5399d78d74a8ba2326db25704ba7cb7aa353d.zip |
[SPARK-12723][SQL] Comprehensive Verification and Fixing of SQL Generation Support for Expressions
#### What changes were proposed in this pull request?
Ensure that all built-in expressions can be mapped to its SQL representation if there is one (e.g. ScalaUDF doesn't have a SQL representation). The function lists are from the expression list in `FunctionRegistry`.
Window functions, grouping sets functions (`cube`, `rollup`, `grouping`, `grouping_id`), and generator functions (`explode` and `json_tuple`) are covered by separate JIRAs and PRs. Thus, this PR does not cover them. Except for these functions, all the built-in expressions are covered. For details, see the list in `ExpressionToSQLSuite`.
Fixed a few issues. For example, the `prettyName` of `approx_count_distinct` is not right. The `sql` of `hash` function is not right, since the `hash` function does not accept `seed`.
Additionally, this corrects the order of expressions in `FunctionRegistry` so that it is easier for people to find which functions are missing.
cc liancheng
#### How was this patch tested?
Added two test cases in LogicalPlanToSQLSuite for covering `not like` and `not in`.
Added a new test suite `ExpressionToSQLSuite` to cover the functions:
1. misc non-aggregate functions + complex type creators + null expressions
2. math functions
3. aggregate functions
4. string functions
5. date time functions + calendar interval
6. collection functions
7. misc functions
Author: gatorsmile <gatorsmile@gmail.com>
Closes #11314 from gatorsmile/expressionToSQL.
8 files changed, 306 insertions, 30 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 1be97c7b81..26bb96eb08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -113,6 +113,7 @@ object FunctionRegistry { type FunctionBuilder = Seq[Expression] => Expression + // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map( // misc non-aggregate functions expression[Abs]("abs"), @@ -125,13 +126,12 @@ object FunctionRegistry { expression[IsNull]("isnull"), expression[IsNotNull]("isnotnull"), expression[Least]("least"), + expression[CreateNamedStruct]("named_struct"), + expression[NaNvl]("nanvl"), expression[Coalesce]("nvl"), expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), - expression[CreateNamedStruct]("named_struct"), - expression[Sqrt]("sqrt"), - expression[NaNvl]("nanvl"), // math functions expression[Acos]("acos"), @@ -145,24 +145,26 @@ object FunctionRegistry { expression[Cos]("cos"), expression[Cosh]("cosh"), expression[Conv]("conv"), + expression[ToDegrees]("degrees"), expression[EulerNumber]("e"), expression[Exp]("exp"), expression[Expm1]("expm1"), expression[Floor]("floor"), expression[Factorial]("factorial"), - expression[Hypot]("hypot"), expression[Hex]("hex"), + expression[Hypot]("hypot"), expression[Logarithm]("log"), - expression[Log]("ln"), expression[Log10]("log10"), expression[Log1p]("log1p"), expression[Log2]("log2"), + expression[Log]("ln"), expression[UnaryMinus]("negative"), expression[Pi]("pi"), - expression[Pow]("pow"), - expression[Pow]("power"), expression[Pmod]("pmod"), expression[UnaryPositive]("positive"), + expression[Pow]("pow"), + 
expression[Pow]("power"), + expression[ToRadians]("radians"), expression[Rint]("rint"), expression[Round]("round"), expression[ShiftLeft]("shiftleft"), @@ -172,10 +174,9 @@ object FunctionRegistry { expression[Signum]("signum"), expression[Sin]("sin"), expression[Sinh]("sinh"), + expression[Sqrt]("sqrt"), expression[Tan]("tan"), expression[Tanh]("tanh"), - expression[ToDegrees]("degrees"), - expression[ToRadians]("radians"), // aggregate functions expression[HyperLogLogPlusPlus]("approx_count_distinct"), @@ -186,11 +187,13 @@ object FunctionRegistry { expression[CovSample]("covar_samp"), expression[First]("first"), expression[First]("first_value"), + expression[Kurtosis]("kurtosis"), expression[Last]("last"), expression[Last]("last_value"), expression[Max]("max"), expression[Average]("mean"), expression[Min]("min"), + expression[Skewness]("skewness"), expression[StddevSamp]("stddev"), expression[StddevPop]("stddev_pop"), expression[StddevSamp]("stddev_samp"), @@ -198,36 +201,34 @@ object FunctionRegistry { expression[VarianceSamp]("variance"), expression[VariancePop]("var_pop"), expression[VarianceSamp]("var_samp"), - expression[Skewness]("skewness"), - expression[Kurtosis]("kurtosis"), // string functions expression[Ascii]("ascii"), expression[Base64]("base64"), expression[Concat]("concat"), expression[ConcatWs]("concat_ws"), - expression[Encode]("encode"), expression[Decode]("decode"), + expression[Encode]("encode"), expression[FindInSet]("find_in_set"), expression[FormatNumber]("format_number"), + expression[FormatString]("format_string"), expression[GetJsonObject]("get_json_object"), expression[InitCap]("initcap"), - expression[JsonTuple]("json_tuple"), + expression[StringInstr]("instr"), expression[Lower]("lcase"), - expression[Lower]("lower"), expression[Length]("length"), expression[Levenshtein]("levenshtein"), - expression[RegExpExtract]("regexp_extract"), - expression[RegExpReplace]("regexp_replace"), - expression[StringInstr]("instr"), + 
expression[Lower]("lower"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), expression[StringTrimLeft]("ltrim"), - expression[FormatString]("format_string"), + expression[JsonTuple]("json_tuple"), expression[FormatString]("printf"), - expression[StringRPad]("rpad"), + expression[RegExpExtract]("regexp_extract"), + expression[RegExpReplace]("regexp_replace"), expression[StringRepeat]("repeat"), expression[StringReverse]("reverse"), + expression[StringRPad]("rpad"), expression[StringTrimRight]("rtrim"), expression[SoundEx]("soundex"), expression[StringSpace]("space"), @@ -237,8 +238,8 @@ object FunctionRegistry { expression[SubstringIndex]("substring_index"), expression[StringTranslate]("translate"), expression[StringTrim]("trim"), - expression[UnBase64]("unbase64"), expression[Upper]("ucase"), + expression[UnBase64]("unbase64"), expression[Unhex]("unhex"), expression[Upper]("upper"), @@ -246,7 +247,6 @@ object FunctionRegistry { expression[AddMonths]("add_months"), expression[CurrentDate]("current_date"), expression[CurrentTimestamp]("current_timestamp"), - expression[CurrentTimestamp]("now"), expression[DateDiff]("datediff"), expression[DateAdd]("date_add"), expression[DateFormatClass]("date_format"), @@ -262,6 +262,7 @@ object FunctionRegistry { expression[Month]("month"), expression[MonthsBetween]("months_between"), expression[NextDay]("next_day"), + expression[CurrentTimestamp]("now"), expression[Quarter]("quarter"), expression[Second]("second"), expression[ToDate]("to_date"), @@ -273,9 +274,9 @@ object FunctionRegistry { expression[Year]("year"), // collection functions + expression[ArrayContains]("array_contains"), expression[Size]("size"), expression[SortArray]("sort_array"), - expression[ArrayContains]("array_contains"), // misc functions expression[Crc32]("crc32"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index c49c601c30..dbd0acf06c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -35,7 +35,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic { override def dataType: DataType = StringType - override val prettyName = "INPUT_FILE_NAME" + override def prettyName: String = "input_file_name" override protected def initInternal(): Unit = {} @@ -48,6 +48,4 @@ case class InputFileName() extends LeafExpression with Nondeterministic { s"final ${ctx.javaType(dataType)} ${ev.value} = " + "org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();" } - - override def sql: String = prettyName } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala index a474017221..32bae13360 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala @@ -68,6 +68,8 @@ case class HyperLogLogPlusPlus( inputAggBufferOffset = 0) } + override def prettyName: String = "approx_count_distinct" + override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate = copy(mutableAggBufferOffset = newMutableAggBufferOffset) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 5af234609d..ed812e0679 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -95,8 +95,6 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes } protected override def nullSafeEval(input: Any): Any = numeric.abs(input) - - override def sql: String = s"$prettyName(${child.sql})" } abstract class BinaryArithmetic extends BinaryOperator { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index dcbb594afd..33bd3f2095 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -234,8 +234,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression override def prettyName: String = "hash" - override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)" - override def eval(input: InternalRow): Any = { var hash = seed var i = 0 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 510894afac..b9873d38a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1972,7 +1972,7 @@ object functions extends LegacyFunctions { def crc32(e: Column): Column = withExpr { Crc32(e.expr) } /** - * Calculates the hash code of given columns, and returns the result as a int column. + * Calculates the hash code of given columns, and returns the result as an int column. 
* * @group misc_funcs * @since 2.0 diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala new file mode 100644 index 0000000000..d68c602a88 --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import scala.util.control.NonFatal + +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.test.SQLTestUtils + +class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils { + import testImplicits._ + + protected override def beforeAll(): Unit = { + sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") + + val bytes = Array[Byte](1, 2, 3, 4) + Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0") + + sqlContext + .range(10) + .select('id as 'key, concat(lit("val_"), 'id) as 'value) + .write + .saveAsTable("t1") + + sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2") + } + + override protected def afterAll(): Unit = { + sql("DROP TABLE IF EXISTS t0") + sql("DROP TABLE IF EXISTS t1") + sql("DROP TABLE IF EXISTS t2") + } + + private def checkSqlGeneration(hiveQl: String): Unit = { + val df = sql(hiveQl) + + val convertedSQL = try new SQLBuilder(df).toSQL catch { + case NonFatal(e) => + fail( + s"""Cannot convert the following HiveQL query plan back to SQL query string: + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin) + } + + try { + checkAnswer(sql(convertedSQL), df) + } catch { case cause: Throwable => + fail( + s"""Failed to execute converted SQL string or got wrong answer: + | + |# Converted SQL query string: + |$convertedSQL + | + |# Original HiveQL query string: + |$hiveQl + | + |# Resolved query plan: + |${df.queryExecution.analyzed.treeString} + """.stripMargin, + cause) + } + } + + test("misc non-aggregate functions") { + checkSqlGeneration("SELECT abs(15), abs(-15)") + checkSqlGeneration("SELECT array(1,2,3)") + checkSqlGeneration("SELECT coalesce(null, 1, 2)") + // wait for resolution of JIRA SPARK-12719 SQL Generation for Generators + // checkSqlGeneration("SELECT explode(array(1,2,3))") + 
checkSqlGeneration("SELECT greatest(1,null,3)") + checkSqlGeneration("SELECT if(1==2, 'yes', 'no')") + checkSqlGeneration("SELECT isnan(15), isnan('invalid')") + checkSqlGeneration("SELECT isnull(null), isnull('a')") + checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')") + checkSqlGeneration("SELECT least(1,null,3)") + checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)") + checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2") + checkSqlGeneration("SELECT nvl(null, 1, 2)") + checkSqlGeneration("SELECT rand(1)") + checkSqlGeneration("SELECT randn(3)") + checkSqlGeneration("SELECT struct(1,2,3)") + } + + test("math functions") { + checkSqlGeneration("SELECT acos(-1)") + checkSqlGeneration("SELECT asin(-1)") + checkSqlGeneration("SELECT atan(1)") + checkSqlGeneration("SELECT atan2(1, 1)") + checkSqlGeneration("SELECT bin(10)") + checkSqlGeneration("SELECT cbrt(1000.0)") + checkSqlGeneration("SELECT ceil(2.333)") + checkSqlGeneration("SELECT ceiling(2.333)") + checkSqlGeneration("SELECT cos(1.0)") + checkSqlGeneration("SELECT cosh(1.0)") + checkSqlGeneration("SELECT conv(15, 10, 16)") + checkSqlGeneration("SELECT degrees(pi())") + checkSqlGeneration("SELECT e()") + checkSqlGeneration("SELECT exp(1.0)") + checkSqlGeneration("SELECT expm1(1.0)") + checkSqlGeneration("SELECT floor(-2.333)") + checkSqlGeneration("SELECT factorial(5)") + checkSqlGeneration("SELECT hex(10)") + checkSqlGeneration("SELECT hypot(3, 4)") + checkSqlGeneration("SELECT log(10.0)") + checkSqlGeneration("SELECT log10(1000.0)") + checkSqlGeneration("SELECT log1p(0.0)") + checkSqlGeneration("SELECT log2(8.0)") + checkSqlGeneration("SELECT ln(10.0)") + checkSqlGeneration("SELECT negative(-1)") + checkSqlGeneration("SELECT pi()") + checkSqlGeneration("SELECT pmod(3, 2)") + checkSqlGeneration("SELECT positive(3)") + checkSqlGeneration("SELECT pow(2, 3)") + checkSqlGeneration("SELECT power(2, 3)") + checkSqlGeneration("SELECT radians(180.0)") + 
checkSqlGeneration("SELECT rint(1.63)") + checkSqlGeneration("SELECT round(31.415, -1)") + checkSqlGeneration("SELECT shiftleft(2, 3)") + checkSqlGeneration("SELECT shiftright(16, 3)") + checkSqlGeneration("SELECT shiftrightunsigned(16, 3)") + checkSqlGeneration("SELECT sign(-2.63)") + checkSqlGeneration("SELECT signum(-2.63)") + checkSqlGeneration("SELECT sin(1.0)") + checkSqlGeneration("SELECT sinh(1.0)") + checkSqlGeneration("SELECT sqrt(100.0)") + checkSqlGeneration("SELECT tan(1.0)") + checkSqlGeneration("SELECT tanh(1.0)") + } + + test("aggregate functions") { + checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT covar_pop(value, key) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT covar_samp(value, key) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT first(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT first_value(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT kurtosis(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT last(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT last_value(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT stddev_samp(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT sum(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT variance(value) FROM t1 GROUP BY key") + checkSqlGeneration("SELECT var_pop(value) FROM t1 GROUP BY 
key") + checkSqlGeneration("SELECT var_samp(value) FROM t1 GROUP BY key") + } + + test("string functions") { + checkSqlGeneration("SELECT ascii('SparkSql')") + checkSqlGeneration("SELECT base64(a) FROM t0") + checkSqlGeneration("SELECT concat('This ', 'is ', 'a ', 'test')") + checkSqlGeneration("SELECT concat_ws(' ', 'This', 'is', 'a', 'test')") + checkSqlGeneration("SELECT decode(a, 'UTF-8') FROM t0") + checkSqlGeneration("SELECT encode('SparkSql', 'UTF-8')") + checkSqlGeneration("SELECT find_in_set('ab', 'abc,b,ab,c,def')") + checkSqlGeneration("SELECT format_number(1234567.890, 2)") + checkSqlGeneration("SELECT format_string('aa%d%s',123, 'cc')") + checkSqlGeneration("SELECT get_json_object('{\"a\":\"bc\"}','$.a')") + checkSqlGeneration("SELECT initcap('This is a test')") + checkSqlGeneration("SELECT instr('This is a test', 'is')") + checkSqlGeneration("SELECT lcase('SparkSql')") + checkSqlGeneration("SELECT length('This is a test')") + checkSqlGeneration("SELECT levenshtein('This is a test', 'Another test')") + checkSqlGeneration("SELECT lower('SparkSql')") + checkSqlGeneration("SELECT locate('is', 'This is a test', 3)") + checkSqlGeneration("SELECT lpad('SparkSql', 16, 'Learning')") + checkSqlGeneration("SELECT ltrim(' SparkSql ')") + // wait for resolution of JIRA SPARK-12719 SQL Generation for Generators + // checkSqlGeneration("SELECT json_tuple('{\"f1\": \"value1\", \"f2\": \"value2\"}','f1')") + checkSqlGeneration("SELECT printf('aa%d%s', 123, 'cc')") + checkSqlGeneration("SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1)") + checkSqlGeneration("SELECT regexp_replace('100-200', '(\\d+)', 'num')") + checkSqlGeneration("SELECT repeat('SparkSql', 3)") + checkSqlGeneration("SELECT reverse('SparkSql')") + checkSqlGeneration("SELECT rpad('SparkSql', 16, ' is Cool')") + checkSqlGeneration("SELECT rtrim(' SparkSql ')") + checkSqlGeneration("SELECT soundex('SparkSql')") + checkSqlGeneration("SELECT space(2)") + checkSqlGeneration("SELECT split('aa2bb3cc', 
'[1-9]+')") + checkSqlGeneration("SELECT space(2)") + checkSqlGeneration("SELECT substr('This is a test', 'is')") + checkSqlGeneration("SELECT substring('This is a test', 'is')") + checkSqlGeneration("SELECT substring_index('www.apache.org','.',1)") + checkSqlGeneration("SELECT translate('translate', 'rnlt', '123')") + checkSqlGeneration("SELECT trim(' SparkSql ')") + checkSqlGeneration("SELECT ucase('SparkSql')") + checkSqlGeneration("SELECT unbase64('SparkSql')") + checkSqlGeneration("SELECT unhex(41)") + checkSqlGeneration("SELECT upper('SparkSql')") + } + + test("datetime functions") { + checkSqlGeneration("SELECT add_months('2001-03-31', 1)") + checkSqlGeneration("SELECT count(current_date())") + checkSqlGeneration("SELECT count(current_timestamp())") + checkSqlGeneration("SELECT datediff('2001-01-02', '2001-01-01')") + checkSqlGeneration("SELECT date_add('2001-01-02', 1)") + checkSqlGeneration("SELECT date_format('2001-05-02', 'yyyy-dd')") + checkSqlGeneration("SELECT date_sub('2001-01-02', 1)") + checkSqlGeneration("SELECT day('2001-05-02')") + checkSqlGeneration("SELECT dayofyear('2001-05-02')") + checkSqlGeneration("SELECT dayofmonth('2001-05-02')") + checkSqlGeneration("SELECT from_unixtime(1000, 'yyyy-MM-dd HH:mm:ss')") + checkSqlGeneration("SELECT from_utc_timestamp('2015-07-24 00:00:00', 'PST')") + checkSqlGeneration("SELECT hour('11:35:55')") + checkSqlGeneration("SELECT last_day('2001-01-01')") + checkSqlGeneration("SELECT minute('11:35:55')") + checkSqlGeneration("SELECT month('2001-05-02')") + checkSqlGeneration("SELECT months_between('2001-10-30 10:30:00', '1996-10-30')") + checkSqlGeneration("SELECT next_day('2001-05-02', 'TU')") + checkSqlGeneration("SELECT count(now())") + checkSqlGeneration("SELECT quarter('2001-05-02')") + checkSqlGeneration("SELECT second('11:35:55')") + checkSqlGeneration("SELECT to_date('2001-10-30 10:30:00')") + checkSqlGeneration("SELECT to_unix_timestamp('2015-07-24 00:00:00', 'yyyy-MM-dd HH:mm:ss')") + 
checkSqlGeneration("SELECT to_utc_timestamp('2015-07-24 00:00:00', 'PST')") + checkSqlGeneration("SELECT trunc('2001-10-30 10:30:00', 'YEAR')") + checkSqlGeneration("SELECT unix_timestamp('2001-10-30 10:30:00')") + checkSqlGeneration("SELECT weekofyear('2001-05-02')") + checkSqlGeneration("SELECT year('2001-05-02')") + + checkSqlGeneration("SELECT interval 3 years - 3 month 7 week 123 microseconds as i") + } + + test("collection functions") { + checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)") + checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))") + checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))") + } + + test("misc functions") { + checkSqlGeneration("SELECT crc32('Spark')") + checkSqlGeneration("SELECT md5('Spark')") + checkSqlGeneration("SELECT hash('Spark')") + checkSqlGeneration("SELECT sha('Spark')") + checkSqlGeneration("SELECT sha1('Spark')") + checkSqlGeneration("SELECT sha2('Spark', 0)") + checkSqlGeneration("SELECT spark_partition_id()") + checkSqlGeneration("SELECT input_file_name()") + checkSqlGeneration("SELECT monotonically_increasing_id()") + } +} diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala index 5255b150aa..b162adf215 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala @@ -86,6 +86,14 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils { checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)") } + test("not in") { + checkHiveQl("SELECT id FROM t0 WHERE id NOT IN (1, 2, 3)") + } + + test("not like") { + checkHiveQl("SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%'") + } + test("aggregate function in having clause") { checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0") } |