aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2016-02-22 22:17:56 -0800
committerReynold Xin <rxin@databricks.com>2016-02-22 22:17:56 -0800
commit9dd5399d78d74a8ba2326db25704ba7cb7aa353d (patch)
tree5aa9be89454d316b9142f497b792e325300e0a70 /sql
parent5d80fac58f837933b5359a8057676f45539e53af (diff)
downloadspark-9dd5399d78d74a8ba2326db25704ba7cb7aa353d.tar.gz
spark-9dd5399d78d74a8ba2326db25704ba7cb7aa353d.tar.bz2
spark-9dd5399d78d74a8ba2326db25704ba7cb7aa353d.zip
[SPARK-12723][SQL] Comprehensive Verification and Fixing of SQL Generation Support for Expressions
#### What changes were proposed in this pull request? Ensure that all built-in expressions can be mapped to their SQL representations if they have one (e.g. ScalaUDF doesn't have a SQL representation). The function lists are from the expression list in `FunctionRegistry`. Window functions, grouping sets functions (`cube`, `rollup`, `grouping`, `grouping_id`), and generator functions (`explode` and `json_tuple`) are covered by separate JIRAs and PRs. Thus, this PR does not cover them. Except for these functions, all the built-in expressions are covered. For details, see the list in `ExpressionToSQLSuite`. Fixed a few issues. For example, the `prettyName` of `approx_count_distinct` is not right. The `sql` of the `hash` function is not right, since the `hash` function does not accept a `seed`. Additionally, also corrected the order of expressions in `FunctionRegistry` so that it is easier for people to find which functions are missing. cc liancheng #### How was this patch tested? Added two test cases in LogicalPlanToSQLSuite for covering `not like` and `not in`. Added a new test suite `ExpressionToSQLSuite` to cover the functions: 1. misc non-aggregate functions + complex type creators + null expressions 2. math functions 3. aggregate functions 4. string functions 5. date time functions + calendar interval 6. collection functions 7. misc functions Author: gatorsmile <gatorsmile@gmail.com> Closes #11314 from gatorsmile/expressionToSQL.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala45
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala2
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala271
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala8
8 files changed, 306 insertions, 30 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 1be97c7b81..26bb96eb08 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -113,6 +113,7 @@ object FunctionRegistry {
type FunctionBuilder = Seq[Expression] => Expression
+ // Note: Whenever we add a new entry here, make sure we also update ExpressionToSQLSuite
val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
// misc non-aggregate functions
expression[Abs]("abs"),
@@ -125,13 +126,12 @@ object FunctionRegistry {
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
expression[Least]("least"),
+ expression[CreateNamedStruct]("named_struct"),
+ expression[NaNvl]("nanvl"),
expression[Coalesce]("nvl"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
- expression[CreateNamedStruct]("named_struct"),
- expression[Sqrt]("sqrt"),
- expression[NaNvl]("nanvl"),
// math functions
expression[Acos]("acos"),
@@ -145,24 +145,26 @@ object FunctionRegistry {
expression[Cos]("cos"),
expression[Cosh]("cosh"),
expression[Conv]("conv"),
+ expression[ToDegrees]("degrees"),
expression[EulerNumber]("e"),
expression[Exp]("exp"),
expression[Expm1]("expm1"),
expression[Floor]("floor"),
expression[Factorial]("factorial"),
- expression[Hypot]("hypot"),
expression[Hex]("hex"),
+ expression[Hypot]("hypot"),
expression[Logarithm]("log"),
- expression[Log]("ln"),
expression[Log10]("log10"),
expression[Log1p]("log1p"),
expression[Log2]("log2"),
+ expression[Log]("ln"),
expression[UnaryMinus]("negative"),
expression[Pi]("pi"),
- expression[Pow]("pow"),
- expression[Pow]("power"),
expression[Pmod]("pmod"),
expression[UnaryPositive]("positive"),
+ expression[Pow]("pow"),
+ expression[Pow]("power"),
+ expression[ToRadians]("radians"),
expression[Rint]("rint"),
expression[Round]("round"),
expression[ShiftLeft]("shiftleft"),
@@ -172,10 +174,9 @@ object FunctionRegistry {
expression[Signum]("signum"),
expression[Sin]("sin"),
expression[Sinh]("sinh"),
+ expression[Sqrt]("sqrt"),
expression[Tan]("tan"),
expression[Tanh]("tanh"),
- expression[ToDegrees]("degrees"),
- expression[ToRadians]("radians"),
// aggregate functions
expression[HyperLogLogPlusPlus]("approx_count_distinct"),
@@ -186,11 +187,13 @@ object FunctionRegistry {
expression[CovSample]("covar_samp"),
expression[First]("first"),
expression[First]("first_value"),
+ expression[Kurtosis]("kurtosis"),
expression[Last]("last"),
expression[Last]("last_value"),
expression[Max]("max"),
expression[Average]("mean"),
expression[Min]("min"),
+ expression[Skewness]("skewness"),
expression[StddevSamp]("stddev"),
expression[StddevPop]("stddev_pop"),
expression[StddevSamp]("stddev_samp"),
@@ -198,36 +201,34 @@ object FunctionRegistry {
expression[VarianceSamp]("variance"),
expression[VariancePop]("var_pop"),
expression[VarianceSamp]("var_samp"),
- expression[Skewness]("skewness"),
- expression[Kurtosis]("kurtosis"),
// string functions
expression[Ascii]("ascii"),
expression[Base64]("base64"),
expression[Concat]("concat"),
expression[ConcatWs]("concat_ws"),
- expression[Encode]("encode"),
expression[Decode]("decode"),
+ expression[Encode]("encode"),
expression[FindInSet]("find_in_set"),
expression[FormatNumber]("format_number"),
+ expression[FormatString]("format_string"),
expression[GetJsonObject]("get_json_object"),
expression[InitCap]("initcap"),
- expression[JsonTuple]("json_tuple"),
+ expression[StringInstr]("instr"),
expression[Lower]("lcase"),
- expression[Lower]("lower"),
expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
- expression[RegExpExtract]("regexp_extract"),
- expression[RegExpReplace]("regexp_replace"),
- expression[StringInstr]("instr"),
+ expression[Lower]("lower"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
expression[StringTrimLeft]("ltrim"),
- expression[FormatString]("format_string"),
+ expression[JsonTuple]("json_tuple"),
expression[FormatString]("printf"),
- expression[StringRPad]("rpad"),
+ expression[RegExpExtract]("regexp_extract"),
+ expression[RegExpReplace]("regexp_replace"),
expression[StringRepeat]("repeat"),
expression[StringReverse]("reverse"),
+ expression[StringRPad]("rpad"),
expression[StringTrimRight]("rtrim"),
expression[SoundEx]("soundex"),
expression[StringSpace]("space"),
@@ -237,8 +238,8 @@ object FunctionRegistry {
expression[SubstringIndex]("substring_index"),
expression[StringTranslate]("translate"),
expression[StringTrim]("trim"),
- expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
+ expression[UnBase64]("unbase64"),
expression[Unhex]("unhex"),
expression[Upper]("upper"),
@@ -246,7 +247,6 @@ object FunctionRegistry {
expression[AddMonths]("add_months"),
expression[CurrentDate]("current_date"),
expression[CurrentTimestamp]("current_timestamp"),
- expression[CurrentTimestamp]("now"),
expression[DateDiff]("datediff"),
expression[DateAdd]("date_add"),
expression[DateFormatClass]("date_format"),
@@ -262,6 +262,7 @@ object FunctionRegistry {
expression[Month]("month"),
expression[MonthsBetween]("months_between"),
expression[NextDay]("next_day"),
+ expression[CurrentTimestamp]("now"),
expression[Quarter]("quarter"),
expression[Second]("second"),
expression[ToDate]("to_date"),
@@ -273,9 +274,9 @@ object FunctionRegistry {
expression[Year]("year"),
// collection functions
+ expression[ArrayContains]("array_contains"),
expression[Size]("size"),
expression[SortArray]("sort_array"),
- expression[ArrayContains]("array_contains"),
// misc functions
expression[Crc32]("crc32"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
index c49c601c30..dbd0acf06c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala
@@ -35,7 +35,7 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
override def dataType: DataType = StringType
- override val prettyName = "INPUT_FILE_NAME"
+ override def prettyName: String = "input_file_name"
override protected def initInternal(): Unit = {}
@@ -48,6 +48,4 @@ case class InputFileName() extends LeafExpression with Nondeterministic {
s"final ${ctx.javaType(dataType)} ${ev.value} = " +
"org.apache.spark.rdd.SqlNewHadoopRDDState.getInputFileName();"
}
-
- override def sql: String = prettyName
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
index a474017221..32bae13360 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/HyperLogLogPlusPlus.scala
@@ -68,6 +68,8 @@ case class HyperLogLogPlusPlus(
inputAggBufferOffset = 0)
}
+ override def prettyName: String = "approx_count_distinct"
+
override def withNewMutableAggBufferOffset(newMutableAggBufferOffset: Int): ImperativeAggregate =
copy(mutableAggBufferOffset = newMutableAggBufferOffset)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 5af234609d..ed812e0679 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -95,8 +95,6 @@ case class Abs(child: Expression) extends UnaryExpression with ExpectsInputTypes
}
protected override def nullSafeEval(input: Any): Any = numeric.abs(input)
-
- override def sql: String = s"$prettyName(${child.sql})"
}
abstract class BinaryArithmetic extends BinaryOperator {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index dcbb594afd..33bd3f2095 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -234,8 +234,6 @@ case class Murmur3Hash(children: Seq[Expression], seed: Int) extends Expression
override def prettyName: String = "hash"
- override def sql: String = s"$prettyName(${children.map(_.sql).mkString(", ")}, $seed)"
-
override def eval(input: InternalRow): Any = {
var hash = seed
var i = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 510894afac..b9873d38a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1972,7 +1972,7 @@ object functions extends LegacyFunctions {
def crc32(e: Column): Column = withExpr { Crc32(e.expr) }
/**
- * Calculates the hash code of given columns, and returns the result as a int column.
+ * Calculates the hash code of given columns, and returns the result as an int column.
*
* @group misc_funcs
* @since 2.0
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
new file mode 100644
index 0000000000..d68c602a88
--- /dev/null
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ExpressionToSQLSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.hive
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SQLTestUtils
+
+class ExpressionToSQLSuite extends SQLBuilderTest with SQLTestUtils {
+ import testImplicits._
+
+ protected override def beforeAll(): Unit = {
+ sql("DROP TABLE IF EXISTS t0")
+ sql("DROP TABLE IF EXISTS t1")
+ sql("DROP TABLE IF EXISTS t2")
+
+ val bytes = Array[Byte](1, 2, 3, 4)
+ Seq((bytes, "AQIDBA==")).toDF("a", "b").write.saveAsTable("t0")
+
+ sqlContext
+ .range(10)
+ .select('id as 'key, concat(lit("val_"), 'id) as 'value)
+ .write
+ .saveAsTable("t1")
+
+ sqlContext.range(10).select('id as 'a, 'id as 'b, 'id as 'c, 'id as 'd).write.saveAsTable("t2")
+ }
+
+ override protected def afterAll(): Unit = {
+ sql("DROP TABLE IF EXISTS t0")
+ sql("DROP TABLE IF EXISTS t1")
+ sql("DROP TABLE IF EXISTS t2")
+ }
+
+ private def checkSqlGeneration(hiveQl: String): Unit = {
+ val df = sql(hiveQl)
+
+ val convertedSQL = try new SQLBuilder(df).toSQL catch {
+ case NonFatal(e) =>
+ fail(
+ s"""Cannot convert the following HiveQL query plan back to SQL query string:
+ |
+ |# Original HiveQL query string:
+ |$hiveQl
+ |
+ |# Resolved query plan:
+ |${df.queryExecution.analyzed.treeString}
+ """.stripMargin)
+ }
+
+ try {
+ checkAnswer(sql(convertedSQL), df)
+ } catch { case cause: Throwable =>
+ fail(
+ s"""Failed to execute converted SQL string or got wrong answer:
+ |
+ |# Converted SQL query string:
+ |$convertedSQL
+ |
+ |# Original HiveQL query string:
+ |$hiveQl
+ |
+ |# Resolved query plan:
+ |${df.queryExecution.analyzed.treeString}
+ """.stripMargin,
+ cause)
+ }
+ }
+
+ test("misc non-aggregate functions") {
+ checkSqlGeneration("SELECT abs(15), abs(-15)")
+ checkSqlGeneration("SELECT array(1,2,3)")
+ checkSqlGeneration("SELECT coalesce(null, 1, 2)")
+ // wait for resolution of JIRA SPARK-12719 SQL Generation for Generators
+ // checkSqlGeneration("SELECT explode(array(1,2,3))")
+ checkSqlGeneration("SELECT greatest(1,null,3)")
+ checkSqlGeneration("SELECT if(1==2, 'yes', 'no')")
+ checkSqlGeneration("SELECT isnan(15), isnan('invalid')")
+ checkSqlGeneration("SELECT isnull(null), isnull('a')")
+ checkSqlGeneration("SELECT isnotnull(null), isnotnull('a')")
+ checkSqlGeneration("SELECT least(1,null,3)")
+ checkSqlGeneration("SELECT named_struct('c1',1,'c2',2,'c3',3)")
+ checkSqlGeneration("SELECT nanvl(a, 5), nanvl(b, 10), nanvl(d, c) from t2")
+ checkSqlGeneration("SELECT nvl(null, 1, 2)")
+ checkSqlGeneration("SELECT rand(1)")
+ checkSqlGeneration("SELECT randn(3)")
+ checkSqlGeneration("SELECT struct(1,2,3)")
+ }
+
+ test("math functions") {
+ checkSqlGeneration("SELECT acos(-1)")
+ checkSqlGeneration("SELECT asin(-1)")
+ checkSqlGeneration("SELECT atan(1)")
+ checkSqlGeneration("SELECT atan2(1, 1)")
+ checkSqlGeneration("SELECT bin(10)")
+ checkSqlGeneration("SELECT cbrt(1000.0)")
+ checkSqlGeneration("SELECT ceil(2.333)")
+ checkSqlGeneration("SELECT ceiling(2.333)")
+ checkSqlGeneration("SELECT cos(1.0)")
+ checkSqlGeneration("SELECT cosh(1.0)")
+ checkSqlGeneration("SELECT conv(15, 10, 16)")
+ checkSqlGeneration("SELECT degrees(pi())")
+ checkSqlGeneration("SELECT e()")
+ checkSqlGeneration("SELECT exp(1.0)")
+ checkSqlGeneration("SELECT expm1(1.0)")
+ checkSqlGeneration("SELECT floor(-2.333)")
+ checkSqlGeneration("SELECT factorial(5)")
+ checkSqlGeneration("SELECT hex(10)")
+ checkSqlGeneration("SELECT hypot(3, 4)")
+ checkSqlGeneration("SELECT log(10.0)")
+ checkSqlGeneration("SELECT log10(1000.0)")
+ checkSqlGeneration("SELECT log1p(0.0)")
+ checkSqlGeneration("SELECT log2(8.0)")
+ checkSqlGeneration("SELECT ln(10.0)")
+ checkSqlGeneration("SELECT negative(-1)")
+ checkSqlGeneration("SELECT pi()")
+ checkSqlGeneration("SELECT pmod(3, 2)")
+ checkSqlGeneration("SELECT positive(3)")
+ checkSqlGeneration("SELECT pow(2, 3)")
+ checkSqlGeneration("SELECT power(2, 3)")
+ checkSqlGeneration("SELECT radians(180.0)")
+ checkSqlGeneration("SELECT rint(1.63)")
+ checkSqlGeneration("SELECT round(31.415, -1)")
+ checkSqlGeneration("SELECT shiftleft(2, 3)")
+ checkSqlGeneration("SELECT shiftright(16, 3)")
+ checkSqlGeneration("SELECT shiftrightunsigned(16, 3)")
+ checkSqlGeneration("SELECT sign(-2.63)")
+ checkSqlGeneration("SELECT signum(-2.63)")
+ checkSqlGeneration("SELECT sin(1.0)")
+ checkSqlGeneration("SELECT sinh(1.0)")
+ checkSqlGeneration("SELECT sqrt(100.0)")
+ checkSqlGeneration("SELECT tan(1.0)")
+ checkSqlGeneration("SELECT tanh(1.0)")
+ }
+
+ test("aggregate functions") {
+ checkSqlGeneration("SELECT approx_count_distinct(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT avg(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT corr(value, key) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT count(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT covar_pop(value, key) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT covar_samp(value, key) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT first(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT first_value(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT kurtosis(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT last(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT last_value(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT max(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT mean(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT min(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT skewness(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT stddev(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT stddev_pop(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT stddev_samp(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT sum(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT variance(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT var_pop(value) FROM t1 GROUP BY key")
+ checkSqlGeneration("SELECT var_samp(value) FROM t1 GROUP BY key")
+ }
+
+ test("string functions") {
+ checkSqlGeneration("SELECT ascii('SparkSql')")
+ checkSqlGeneration("SELECT base64(a) FROM t0")
+ checkSqlGeneration("SELECT concat('This ', 'is ', 'a ', 'test')")
+ checkSqlGeneration("SELECT concat_ws(' ', 'This', 'is', 'a', 'test')")
+ checkSqlGeneration("SELECT decode(a, 'UTF-8') FROM t0")
+ checkSqlGeneration("SELECT encode('SparkSql', 'UTF-8')")
+ checkSqlGeneration("SELECT find_in_set('ab', 'abc,b,ab,c,def')")
+ checkSqlGeneration("SELECT format_number(1234567.890, 2)")
+ checkSqlGeneration("SELECT format_string('aa%d%s',123, 'cc')")
+ checkSqlGeneration("SELECT get_json_object('{\"a\":\"bc\"}','$.a')")
+ checkSqlGeneration("SELECT initcap('This is a test')")
+ checkSqlGeneration("SELECT instr('This is a test', 'is')")
+ checkSqlGeneration("SELECT lcase('SparkSql')")
+ checkSqlGeneration("SELECT length('This is a test')")
+ checkSqlGeneration("SELECT levenshtein('This is a test', 'Another test')")
+ checkSqlGeneration("SELECT lower('SparkSql')")
+ checkSqlGeneration("SELECT locate('is', 'This is a test', 3)")
+ checkSqlGeneration("SELECT lpad('SparkSql', 16, 'Learning')")
+ checkSqlGeneration("SELECT ltrim(' SparkSql ')")
+ // wait for resolution of JIRA SPARK-12719 SQL Generation for Generators
+ // checkSqlGeneration("SELECT json_tuple('{\"f1\": \"value1\", \"f2\": \"value2\"}','f1')")
+ checkSqlGeneration("SELECT printf('aa%d%s', 123, 'cc')")
+ checkSqlGeneration("SELECT regexp_extract('100-200', '(\\d+)-(\\d+)', 1)")
+ checkSqlGeneration("SELECT regexp_replace('100-200', '(\\d+)', 'num')")
+ checkSqlGeneration("SELECT repeat('SparkSql', 3)")
+ checkSqlGeneration("SELECT reverse('SparkSql')")
+ checkSqlGeneration("SELECT rpad('SparkSql', 16, ' is Cool')")
+ checkSqlGeneration("SELECT rtrim(' SparkSql ')")
+ checkSqlGeneration("SELECT soundex('SparkSql')")
+ checkSqlGeneration("SELECT space(2)")
+ checkSqlGeneration("SELECT split('aa2bb3cc', '[1-9]+')")
+ checkSqlGeneration("SELECT space(2)")
+ checkSqlGeneration("SELECT substr('This is a test', 'is')")
+ checkSqlGeneration("SELECT substring('This is a test', 'is')")
+ checkSqlGeneration("SELECT substring_index('www.apache.org','.',1)")
+ checkSqlGeneration("SELECT translate('translate', 'rnlt', '123')")
+ checkSqlGeneration("SELECT trim(' SparkSql ')")
+ checkSqlGeneration("SELECT ucase('SparkSql')")
+ checkSqlGeneration("SELECT unbase64('SparkSql')")
+ checkSqlGeneration("SELECT unhex(41)")
+ checkSqlGeneration("SELECT upper('SparkSql')")
+ }
+
+ test("datetime functions") {
+ checkSqlGeneration("SELECT add_months('2001-03-31', 1)")
+ checkSqlGeneration("SELECT count(current_date())")
+ checkSqlGeneration("SELECT count(current_timestamp())")
+ checkSqlGeneration("SELECT datediff('2001-01-02', '2001-01-01')")
+ checkSqlGeneration("SELECT date_add('2001-01-02', 1)")
+ checkSqlGeneration("SELECT date_format('2001-05-02', 'yyyy-dd')")
+ checkSqlGeneration("SELECT date_sub('2001-01-02', 1)")
+ checkSqlGeneration("SELECT day('2001-05-02')")
+ checkSqlGeneration("SELECT dayofyear('2001-05-02')")
+ checkSqlGeneration("SELECT dayofmonth('2001-05-02')")
+ checkSqlGeneration("SELECT from_unixtime(1000, 'yyyy-MM-dd HH:mm:ss')")
+ checkSqlGeneration("SELECT from_utc_timestamp('2015-07-24 00:00:00', 'PST')")
+ checkSqlGeneration("SELECT hour('11:35:55')")
+ checkSqlGeneration("SELECT last_day('2001-01-01')")
+ checkSqlGeneration("SELECT minute('11:35:55')")
+ checkSqlGeneration("SELECT month('2001-05-02')")
+ checkSqlGeneration("SELECT months_between('2001-10-30 10:30:00', '1996-10-30')")
+ checkSqlGeneration("SELECT next_day('2001-05-02', 'TU')")
+ checkSqlGeneration("SELECT count(now())")
+ checkSqlGeneration("SELECT quarter('2001-05-02')")
+ checkSqlGeneration("SELECT second('11:35:55')")
+ checkSqlGeneration("SELECT to_date('2001-10-30 10:30:00')")
+ checkSqlGeneration("SELECT to_unix_timestamp('2015-07-24 00:00:00', 'yyyy-MM-dd HH:mm:ss')")
+ checkSqlGeneration("SELECT to_utc_timestamp('2015-07-24 00:00:00', 'PST')")
+ checkSqlGeneration("SELECT trunc('2001-10-30 10:30:00', 'YEAR')")
+ checkSqlGeneration("SELECT unix_timestamp('2001-10-30 10:30:00')")
+ checkSqlGeneration("SELECT weekofyear('2001-05-02')")
+ checkSqlGeneration("SELECT year('2001-05-02')")
+
+ checkSqlGeneration("SELECT interval 3 years - 3 month 7 week 123 microseconds as i")
+ }
+
+ test("collection functions") {
+ checkSqlGeneration("SELECT array_contains(array(2, 9, 8), 9)")
+ checkSqlGeneration("SELECT size(array('b', 'd', 'c', 'a'))")
+ checkSqlGeneration("SELECT sort_array(array('b', 'd', 'c', 'a'))")
+ }
+
+ test("misc functions") {
+ checkSqlGeneration("SELECT crc32('Spark')")
+ checkSqlGeneration("SELECT md5('Spark')")
+ checkSqlGeneration("SELECT hash('Spark')")
+ checkSqlGeneration("SELECT sha('Spark')")
+ checkSqlGeneration("SELECT sha1('Spark')")
+ checkSqlGeneration("SELECT sha2('Spark', 0)")
+ checkSqlGeneration("SELECT spark_partition_id()")
+ checkSqlGeneration("SELECT input_file_name()")
+ checkSqlGeneration("SELECT monotonically_increasing_id()")
+ }
+}
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
index 5255b150aa..b162adf215 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/LogicalPlanToSQLSuite.scala
@@ -86,6 +86,14 @@ class LogicalPlanToSQLSuite extends SQLBuilderTest with SQLTestUtils {
checkHiveQl("SELECT id FROM t0 WHERE id IN (1, 2, 3)")
}
+ test("not in") {
+ checkHiveQl("SELECT id FROM t0 WHERE id NOT IN (1, 2, 3)")
+ }
+
+ test("not like") {
+ checkHiveQl("SELECT id FROM t0 WHERE id + 5 NOT LIKE '1%'")
+ }
+
test("aggregate function in having clause") {
checkHiveQl("SELECT COUNT(value) FROM t1 GROUP BY key HAVING MAX(key) > 0")
}