From 270a659584b6c1c304a9f9a331c56287672e00b0 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 29 Dec 2015 16:58:23 -0800 Subject: [SPARK-12549][SQL] Take Option[Seq[DataType]] in UDF input type specification. In Spark we allow UDFs to declare their expected input types in order to apply type coercion. The expected input type parameter takes a Seq[DataType] and uses Nil when no type coercion is applied. It makes more sense to take Option[Seq[DataType]] instead, so we can differentiate a no-arg function vs. a function with no expected input type specified. Author: Reynold Xin Closes #10504 from rxin/SPARK-12549. --- .../org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 85faa19bbf..64d397bf84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -30,7 +30,10 @@ import org.apache.spark.sql.types.DataType * null. Use boxed type or [[Option]] if you wanna do the null-handling yourself. * @param dataType Return type of function. * @param children The input expressions of this UDF. - * @param inputTypes The expected input types of this UDF. + * @param inputTypes The expected input types of this UDF, used to perform type coercion. If we do + * not want to perform coercion, simply use "Nil". Note that it would've been + * better to use Option of Seq[DataType] so we can use "None" as the case for no + * type coercion. However, that would require more refactoring of the codebase. 
*/ case class ScalaUDF( function: AnyRef, @@ -43,7 +46,7 @@ case class ScalaUDF( override def toString: String = s"UDF(${children.mkString(",")})" - // scalastyle:off + // scalastyle:off line.size.limit /** This method has been generated by this script @@ -969,7 +972,7 @@ case class ScalaUDF( } } - // scalastyle:on + // scalastyle:on line.size.limit // Generate codes used to convert the arguments to Scala type for user-defined funtions private[this] def genCodeForConverter(ctx: CodeGenContext, index: Int): String = { @@ -1010,7 +1013,7 @@ case class ScalaUDF( // This must be called before children expressions' codegen // because ctx.references is used in genCodeForConverter - val converterTerms = (0 until children.size).map(genCodeForConverter(ctx, _)) + val converterTerms = children.indices.map(genCodeForConverter(ctx, _)) // Initialize user-defined function val funcClassName = s"scala.Function${children.size}" @@ -1054,5 +1057,6 @@ case class ScalaUDF( } private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) + override def eval(input: InternalRow): Any = converter(f(input)) } -- cgit v1.2.3