aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala51
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala3
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala26
3 files changed, 73 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 65168998c8..c5f91c1590 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.xml._
import org.apache.spark.sql.catalyst.util.StringKeyHashMap
+import org.apache.spark.sql.types._
/**
@@ -408,8 +409,21 @@ object FunctionRegistry {
expression[BitwiseAnd]("&"),
expression[BitwiseNot]("~"),
expression[BitwiseOr]("|"),
- expression[BitwiseXor]("^")
-
+ expression[BitwiseXor]("^"),
+
+ // Cast aliases (SPARK-16730)
+ castAlias("boolean", BooleanType),
+ castAlias("tinyint", ByteType),
+ castAlias("smallint", ShortType),
+ castAlias("int", IntegerType),
+ castAlias("bigint", LongType),
+ castAlias("float", FloatType),
+ castAlias("double", DoubleType),
+ castAlias("decimal", DecimalType.USER_DEFAULT),
+ castAlias("date", DateType),
+ castAlias("timestamp", TimestampType),
+ castAlias("binary", BinaryType),
+ castAlias("string", StringType)
)
val builtin: SimpleFunctionRegistry = {
@@ -452,14 +466,37 @@ object FunctionRegistry {
}
}
- val clazz = tag.runtimeClass
+ (name, (expressionInfo[T](name), builder))
+ }
+
+ /**
+ * Creates a function registry lookup entry for cast aliases (SPARK-16730).
+ * For example, if name is "int", and dataType is IntegerType, this means int(x) would become
+ * an alias for cast(x as IntegerType).
+ * See usage above.
+ */
+ private def castAlias(
+ name: String,
+ dataType: DataType): (String, (ExpressionInfo, FunctionBuilder)) = {
+ val builder = (args: Seq[Expression]) => {
+ if (args.size != 1) {
+ throw new AnalysisException(s"Function $name accepts only one argument")
+ }
+ Cast(args.head, dataType)
+ }
+ (name, (expressionInfo[Cast](name), builder))
+ }
+
+ /**
+ * Creates an [[ExpressionInfo]] for the function as defined by expression T using the given name.
+ */
+ private def expressionInfo[T <: Expression : ClassTag](name: String): ExpressionInfo = {
+ val clazz = scala.reflect.classTag[T].runtimeClass
val df = clazz.getAnnotation(classOf[ExpressionDescription])
if (df != null) {
- (name,
- (new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended()),
- builder))
+ new ExpressionInfo(clazz.getCanonicalName, name, df.usage(), df.extended())
} else {
- (name, (new ExpressionInfo(clazz.getCanonicalName, name), builder))
+ new ExpressionInfo(clazz.getCanonicalName, name)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
index a12fba047b..c452765af2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala
@@ -113,6 +113,9 @@ object Cast {
}
/** Cast the child expression to the target data type. */
+@ExpressionDescription(
+ usage = " - Cast value v to the target data type.",
+ extended = "> SELECT _FUNC_('10' as int);\n 10")
case class Cast(child: Expression, dataType: DataType) extends UnaryExpression with NullIntolerant {
override def toString: String = s"cast($child as ${dataType.simpleString})"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
index 1e3239550f..27b60e0d9d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLCompatibilityFunctionSuite.scala
@@ -17,10 +17,14 @@
package org.apache.spark.sql
+import java.math.BigDecimal
+import java.sql.Timestamp
+
import org.apache.spark.sql.test.SharedSQLContext
/**
* A test suite for functions added for compatibility with other databases such as Oracle, MSSQL.
+ *
* These functions are typically implemented using the trait
* [[org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable]].
*/
@@ -69,4 +73,26 @@ class SQLCompatibilityFunctionSuite extends QueryTest with SharedSQLContext {
sql("SELECT nvl2(null, 1, 2.1d), nvl2('n', 1, 2.1d)"),
Row(2.1, 1.0))
}
+
+ test("SPARK-16730 cast alias functions for Hive compatibility") {
+ checkAnswer(
+ sql("SELECT boolean(1), tinyint(1), smallint(1), int(1), bigint(1)"),
+ Row(true, 1.toByte, 1.toShort, 1, 1L))
+
+ checkAnswer(
+ sql("SELECT float(1), double(1), decimal(1)"),
+ Row(1.toFloat, 1.0, new BigDecimal(1)))
+
+ checkAnswer(
+ sql("SELECT date(\"2014-04-04\"), timestamp(date(\"2014-04-04\"))"),
+ Row(new java.util.Date(114, 3, 4), new Timestamp(114, 3, 4, 0, 0, 0, 0)))
+
+ checkAnswer(
+ sql("SELECT string(1)"),
+ Row("1"))
+
+ // Error handling: only one argument
+ val errorMsg = intercept[AnalysisException](sql("SELECT string(1, 2)")).getMessage
+ assert(errorMsg.contains("Function string accepts only one argument"))
+ }
}