diff options
author | Cheng Hao <hao.cheng@intel.com> | 2014-12-16 21:21:11 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-12-16 21:21:11 -0800 |
commit | 770d8153a5fe400147cc597c8b4b703f0aa00c22 (patch) | |
tree | d7ebb94a6f975fcefe76edcb791a7ced94734c42 | |
parent | ddc7ba31cb1062acb182293b2698b1b20ea56a46 (diff) | |
download | spark-770d8153a5fe400147cc597c8b4b703f0aa00c22.tar.gz spark-770d8153a5fe400147cc597c8b4b703f0aa00c22.tar.bz2 spark-770d8153a5fe400147cc597c8b4b703f0aa00c22.zip |
[SPARK-4375] [SQL] Add 0 argument support for udf
Author: Cheng Hao <hao.cheng@intel.com>
Closes #3595 from chenghao-intel/udf0 and squashes the following commits:
a858973 [Cheng Hao] Add 0 arguments support for udf
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala | 16 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 5 |
2 files changed, 15 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 00d6b43a57..5fb472686c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -72,14 +72,13 @@ private[sql] trait UDFRegistration { functionRegistry.registerFunction(name, builder) } - /** registerFunction 1-22 were generated by this script + /** registerFunction 0-22 were generated by this script - (1 to 22).map { x => - val types = (1 to x).map(x => "_").reduce(_ + ", " + _) + (0 to 22).map { x => + val types = (1 to x).foldRight("T")((_, s) => {s"_, $s"}) s""" - def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { - def builder(e: Seq[Expression]) = - ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + def registerFunction[T: TypeTag](name: String, func: Function$x[$types]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } """ @@ -87,6 +86,11 @@ private[sql] trait UDFRegistration { */ // scalastyle:off + def registerFunction[T: TypeTag](name: String, func: Function0[T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + functionRegistry.registerFunction(name, builder) + } + def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ef9b76b1e2..720953ae37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -31,6 +31,11 @@ class UDFSuite extends QueryTest { assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) } + test("ZeroArgument UDF") { + registerFunction("random0", () => { Math.random()}) + assert(sql("SELECT random0()").first().getDouble(0) >= 0.0) + } + test("TwoArgument UDF") { registerFunction("strLenScala", (_: String).length + (_:Int)) assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) |