author: Cheolsoo Park <cheolsoop@netflix.com> 2015-07-03 22:14:21 -0700
committer: Reynold Xin <rxin@databricks.com> 2015-07-03 22:14:21 -0700
commit: 4a22bce8fce30f86f364467a8ba51d2e744ff379
tree: a4f8c1a74b2a73a05891a55b99de45b82de79d50 /sql/catalyst/src
parent: e92c24d37cae54634e7af20cbfe313d023786f87
[SPARK-8572] [SQL] Type coercion for ScalaUDFs
Implemented type coercion for udf arguments in Scala. The changes include:
* Add `with ExpectsInputTypes` to `ScalaUDF` class.
* Pass down argument types info from `UDFRegistration` and `functions`.

With this patch, the example query in [SPARK-8572](https://issues.apache.org/jira/browse/SPARK-8572) no longer throws a type cast error at runtime.

Also added a unit test to `UDFSuite` in which a decimal type is passed to a udf that expects an int.

Author: Cheolsoo Park <cheolsoop@netflix.com>

Closes #7203 from piaozhexiu/SPARK-8572 and squashes the following commits:

2d0ed15 [Cheolsoo Park] Incorporate comments
dce1efd [Cheolsoo Park] Fix unit tests and update the codegen script
066deed [Cheolsoo Park] Type coercion for udf inputs
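
To illustrate the runtime failure the commit message describes, here is a minimal standalone sketch that does not use Spark. It assumes the UDF body is held as an untyped `AnyRef` (as the `function: AnyRef` parameter of `ScalaUDF` suggests) and is invoked through a cast to `Any => Any`. The names `UdfCastDemo`, `call`, and `toInt` are hypothetical and exist only for this illustration.

```scala
import scala.util.Try

object UdfCastDemo {
  // A function that expects an Int, stored untyped (loosely mirroring `function: AnyRef`).
  private val typed: Int => Int = x => x + 1
  val intExpected: AnyRef = typed

  // Hypothetical helper: invoke the untyped function on an arbitrary argument.
  def call(f: AnyRef, arg: Any): Any = {
    val g = f.asInstanceOf[Any => Any]
    g(arg)
  }

  // Hypothetical stand-in for the cast that type coercion would insert (decimal to int).
  def toInt(arg: Any): Any = arg match {
    case d: BigDecimal => d.toInt
    case other         => other
  }

  // Without coercion, the decimal reaches the Int-typed function and unboxing fails.
  val uncoerced: Try[Any] = Try(call(intExpected, BigDecimal("1.0")))

  // With the cast applied first, the same call succeeds.
  val coerced: Any = call(intExpected, toInt(BigDecimal("1.0")))
}
```

In this sketch `uncoerced` is a `Failure` wrapping a `ClassCastException`, while `coerced` evaluates normally. Inserting the cast during analysis, as this patch does, moves the conversion ahead of the function call.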
Diffstat (limited to 'sql/catalyst/src')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala | 2
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 7
2 files changed, 6 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 38eb8322c8..84acc0e7e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -680,7 +680,7 @@ object HiveTypeCoercion {
       // Skip nodes who's children have not been resolved yet.
       case e if !e.childrenResolved => e
 
-      case e: ExpectsInputTypes =>
+      case e: ExpectsInputTypes if (e.inputTypes.nonEmpty) =>
         val children: Seq[Expression] = e.children.zip(e.inputTypes).map { case (in, expected) =>
           // If we cannot do the implicit cast, just use the original input.
           implicitCast(in, expected).getOrElse(in)
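
The guarded rule in this hunk can be sketched with toy types. Everything below is a simplified stand-in, not Spark's real `Expression`, `DataType`, or `implicitCast`. It shows the zip-and-cast rewrite and one likely reason for the `inputTypes.nonEmpty` guard (an inference, not something the commit states): `zip` truncates to the shorter sequence, so an empty `inputTypes` would otherwise drop every child.

```scala
object CoercionRuleSketch {
  // Toy stand-ins for Catalyst types; none of these are Spark's real classes.
  sealed trait DataType
  case object IntType extends DataType
  case object DecimalType extends DataType
  case object StringType extends DataType

  sealed trait Expr { def dataType: DataType }
  case class Literal(value: Any, dataType: DataType) extends Expr
  case class Cast(child: Expr, dataType: DataType) extends Expr
  case class Udf(children: Seq[Expr], inputTypes: Seq[DataType] = Nil) extends Expr {
    val dataType: DataType = IntType
  }

  // Toy implicitCast: Some(expr) when the cast is allowed, None otherwise.
  def implicitCast(in: Expr, expected: DataType): Option[Expr] =
    (in.dataType, expected) match {
      case (actual, wanted) if actual == wanted => Some(in)
      case (DecimalType, IntType)               => Some(Cast(in, IntType))
      case _                                    => None
    }

  // Unguarded rewrite: zip pairs each child with its expected type,
  // truncating to the shorter of the two sequences.
  def coerceUnguarded(u: Udf): Udf =
    u.copy(children = u.children.zip(u.inputTypes).map { case (in, expected) =>
      // If the implicit cast is not possible, keep the original input.
      implicitCast(in, expected).getOrElse(in)
    })

  // Guarded rewrite, following the `inputTypes.nonEmpty` condition in the diff.
  def coerce(u: Udf): Udf =
    if (u.inputTypes.nonEmpty) coerceUnguarded(u) else u
}
```

With a decimal literal and `inputTypes = Seq(IntType)`, `coerce` wraps the child in a `Cast`. With the default `Nil`, `coerce` leaves the node untouched, whereas `coerceUnguarded` would return a node with no children at all.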
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
index caf021b016..fc055c97a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala
@@ -24,8 +24,11 @@ import org.apache.spark.sql.types.DataType
  * User-defined function.
  * @param dataType Return type of function.
  */
-case class ScalaUDF(function: AnyRef, dataType: DataType, children: Seq[Expression])
-  extends Expression {
+case class ScalaUDF(
+    function: AnyRef,
+    dataType: DataType,
+    children: Seq[Expression],
+    inputTypes: Seq[DataType] = Nil) extends Expression with ExpectsInputTypes {
 
   override def nullable: Boolean = true
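
The new `inputTypes` parameter defaults to `Nil`, so call sites that pass only three arguments keep compiling. The sketch below illustrates that design choice with a toy case class. `ToyUdf` and its simplified `children: Seq[Any]` are assumptions for illustration and do not reproduce Spark's `ScalaUDF`.

```scala
object DefaultInputTypesSketch {
  // Toy stand-ins; not Spark's DataType hierarchy.
  sealed trait DataType
  case object IntType extends DataType
  case object StringType extends DataType

  // Toy analogue of the new constructor shape, with a defaulted fourth parameter.
  case class ToyUdf(
      function: AnyRef,
      dataType: DataType,
      children: Seq[Any],
      inputTypes: Seq[DataType] = Nil)

  val plusOne: Int => Int = _ + 1

  // Old-style call site: three arguments, so inputTypes falls back to Nil
  // and a rule guarded by `inputTypes.nonEmpty` would skip this node.
  val legacy: ToyUdf = ToyUdf(plusOne, IntType, Seq(1))

  // New-style call site: the registering code supplies the argument types.
  val typed: ToyUdf = ToyUdf(plusOne, IntType, Seq(1), Seq(IntType))
}
```

Here `legacy.inputTypes` is empty and `typed.inputTypes` carries one entry per argument, which is the information the commit message says is passed down from `UDFRegistration` and `functions`.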