author    petermaxlee <petermaxlee@gmail.com>  2016-07-28 13:13:17 +0800
committer Wenchen Fan <wenchen@databricks.com>  2016-07-28 13:13:17 +0800
commit    11d427c924d303e20af90c0179a105f6ff4d89e2 (patch)
tree      a82e104985a113ad95602c204f5076f478790948
parent    b14d7b5cf4f173a1e45a4b1ae2a5e4e7ac5e9bb1 (diff)
[SPARK-16730][SQL] Implement function aliases for type casts
## What changes were proposed in this pull request?

Spark 1.x supports using Hive type names as function names for performing casts, e.g.

```sql
SELECT int(1.0);
SELECT string(2.0);
```

The queries above worked in Spark 1.x because Spark 1.x fell back to Hive for unimplemented functions, but they break in Spark 2.0 because that fallback was removed. This patch implements function aliases in the FunctionRegistry for the following cast functions:

- boolean
- tinyint
- smallint
- int
- bigint
- float
- double
- decimal
- date
- timestamp
- binary
- string

## How was this patch tested?

Added end-to-end tests in SQLCompatibilityFunctionSuite.

Author: petermaxlee <petermaxlee@gmail.com>

Closes #14364 from petermaxlee/SPARK-16730-2.
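For illustration only, a sketch of the restored behavior in a Spark 2.x session (the `spark` session name is the one spark-shell provides; printed values follow standard cast semantics and are assumed):

```scala
// Sketch of a Spark 2.x session using the restored Hive-style cast syntax.
// `spark` is the SparkSession provided by spark-shell.
spark.sql("SELECT int(1.0), string(2.0)").show()
// int(1.0)    resolves to cast(1.0 AS INT)    -> 1
// string(2.0) resolves to cast(2.0 AS STRING) -> 2.0
```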
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala | 51
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala          |  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala          | 26
3 files changed, 73 insertions(+), 7 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 65168998c8..c5f91c1590 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.xml._
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
+import org.apache.spark.sql.types._
/**
@@ -408,8 +409,21 @@ object FunctionRegistry {
expression[BitwiseAnd]("&"),
expression[BitwiseNot]("~"),
expression[BitwiseOr]("|"),
- expression[BitwiseXor]("^")
-
+ expression[BitwiseXor]("^"),
+
+ // Cast aliases (SPARK-16730)
+ castAlias("boolean", BooleanType),
+ castAlias("tinyint", ByteType),
+ castAlias("smallint", ShortType),
+ castAlias("int", IntegerType),
+ castAlias("bigint", LongType),
+ castAlias("float", FloatType),
+ castAlias("double", DoubleType),
+ castAlias("decimal", DecimalType.USER_DEFAULT),
+ castAlias("date", DateType),
+ castAlias("timestamp", TimestampType),
+ castAlias("binary", BinaryType),
+ castAlias("string", StringType)
)
val builtin: SimpleFunctionRegistry = {
@@ -452,14 +466,37 @@ object FunctionRegistry {
}
}
- val clazz = tag.runtimeClass
+ (name, (expressionInfo[T](name), builder))
+ }
+
+ /**
+ * Creates a function registry lookup entry for cast aliases (SPARK-16730).
+ * For example, if name is "int", and dataType is IntegerType, this means int(x) would become
+ * an alias for cast(x as IntegerType).
+ * See usage above.
+ */
+ private def castAlias(
+ name: String,
+ dataType: DataType): (String, (ExpressionInfo, FunctionBuilder)) = {
+ val builder = (args: Seq[Expression]) => {
+ if (args.size != 1) {
+ throw new AnalysisException(s"Function $name accepts only one argument")
+ }
+ Cast(args.head, dataType)
+ }
+ (name, (expressionInfo[Cast](name), builder))
+ }
+
+ /**
+ * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name.
+ */
+ private def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = {
+ val clazz = scala.reflect.classTag[T].runtimeClass
val df = clazz.getAnnotation(classOf[ExpressionDescription])
if (df != null) {
- (name,
- (new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended()),
- builder))
+ new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended())
} else {
- (name, (new ExpressionInfo(clazz.getCanonicalName, name), builder))
+ new ExpressionInfo(clazz.getCanonicalName, name)
}
}
}
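Outside of Spark's internals, the `castAlias` pattern above boils down to a name-keyed map of arity-checked builders. A minimal self-contained sketch, with simplified expression types that are only stand-ins (this is not Spark's actual `FunctionRegistry` API):

```scala
object CastAliasSketch {
  // Simplified stand-ins for Catalyst's Expression and Cast.
  sealed trait Expr
  final case class Lit(value: Any) extends Expr
  final case class CastExpr(child: Expr, toType: String) extends Expr

  type Builder = Seq[Expr] => Expr

  // Mirrors castAlias above: accept exactly one argument, then wrap it in a cast.
  def castAlias(name: String, toType: String): (String, Builder) = {
    val builder: Builder = args => {
      require(args.size == 1, s"Function $name accepts only one argument")
      CastExpr(args.head, toType)
    }
    (name, builder)
  }

  // A tiny registry keyed by function name, mirroring FunctionRegistry's map.
  val registry: Map[String, Builder] = Map(
    castAlias("int", "IntegerType"),
    castAlias("string", "StringType")
  )

  def main(args: Array[String]): Unit = {
    // Resolving int(1.0) yields CastExpr(Lit(1.0), "IntegerType").
    println(registry("int")(Seq(Lit(1.0))))
  }
}
```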
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a12fba047b..c452765af2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -113,6 +113,9 @@ object Cast {
}
/** Cast the child expression to the target data type. */
+@ExpressionDescription(
+ usage = " - Cast value v to the target data type.",
+ extended = "> SELECT _FUNC_('10' as int);\n 10")
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {
override def toString: String = s"cast($child as ${dataType.simpleString})"
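Since the cast aliases reuse Cast's `ExpressionInfo`, the annotation above is what `DESCRIBE FUNCTION` should report for them. A hedged sketch, assuming a spark-shell session named `spark` (exact output layout assumed):

```scala
// Sketch: the usage string above should surface through DESCRIBE FUNCTION,
// with _FUNC_ replaced by the queried name (exact output layout assumed).
spark.sql("DESCRIBE FUNCTION int").show(truncate = false)
// Expected to include a line along the lines of:
//   Usage: int - Cast value v to the target data type.
```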
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
index 1e3239550f..27b60e0d9d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
@@ -17,10 +17,14 @@
package org.apache.spark.sql
+import java.math.BigDecimal
+import java.sql.Timestamp
+
import org.apache.spark.sql.test.SharedSQLContext
/**
* A test suite for functions added for compatibility with other databases such as Oracle and MSSQL.
+ *
* These functions are typically implemented using the trait
* [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]].
*/
@@ -69,4 +73,26 @@ class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext {
sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"),
Row(2.1, 1.0))
}
+
+ test("SPARK-16730 cast alias functions for Hive compatibility") {
+ checkAnswer(
+ sql("SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1)"),
+ Row(true, 1.toByte, 1.toShort, 1, 1L))
+
+ checkAnswer(
+ sql("SELECT float(1), double(1), decimal(1)"),
+ Row(1.toFloat, 1.0, new BigDecimal(1)))
+
+ checkAnswer(
+ sql("SELECT date(\"2014-04-04\"), timestamp(date(\"2014-04-04\"))"),
+ Row(new java.sql.Date(114, 3, 4), new Timestamp(114, 3, 4, 0, 0, 0, 0)))
+
+ checkAnswer(
+ sql("SELECT string(1)"),
+ Row("1"))
+
+ // Error handling: only one argument
+ val errorMsg = intercept[AnalysisException](sql("SELECT string(1, 2)")).getMessage
+ assert(errorMsg.contains("Function string accepts only one argument"))
+ }
}
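Finally, a sketch of how the arity guard in `castAlias` surfaces interactively; the error message comes from the builder above, and a live SparkSession named `spark` (as in spark-shell) is assumed:

```scala
// Sketch: a second argument trips the single-argument guard in castAlias.
// Assumes a live SparkSession named `spark`, as in spark-shell.
try {
  spark.sql("SELECT string(1, 2)")
} catch {
  case e: org.apache.spark.sql.AnalysisException =>
    println(e.getMessage) // "Function string accepts only one argument"
}
```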