From 770d8153a5fe400147cc597c8b4b703f0aa00c22 Mon Sep 17 00:00:00 2001 From: Cheng Hao Date: Tue, 16 Dec 2014 21:21:11 -0800 Subject: [SPARK-4375] [SQL] Add 0 argument support for udf Author: Cheng Hao Closes #3595 from chenghao-intel/udf0 and squashes the following commits: a858973 [Cheng Hao] Add 0 arguments support for udf --- .../scala/org/apache/spark/sql/UdfRegistration.scala | 16 ++++++++++------ .../src/test/scala/org/apache/spark/sql/UDFSuite.scala | 5 +++++ 2 files changed, 15 insertions(+), 6 deletions(-) (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala index 00d6b43a57..5fb472686c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala @@ -72,14 +72,13 @@ private[sql] trait UDFRegistration { functionRegistry.registerFunction(name, builder) } - /** registerFunction 1-22 were generated by this script + /** registerFunction 0-22 were generated by this script - (1 to 22).map { x => - val types = (1 to x).map(x => "_").reduce(_ + ", " + _) + (0 to 22).map { x => + val types = (1 to x).foldRight("T")((_, s) => {s"_, $s"}) s""" - def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = { - def builder(e: Seq[Expression]) = - ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + def registerFunction[T: TypeTag](name: String, func: Function$x[$types]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) } """ @@ -87,6 +86,11 @@ private[sql] trait UDFRegistration { */ // scalastyle:off + def registerFunction[T: TypeTag](name: String, func: Function0[T]): Unit = { + def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) + functionRegistry.registerFunction(name, builder) + } + def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = { def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e) functionRegistry.registerFunction(name, builder) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index ef9b76b1e2..720953ae37 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -31,6 +31,11 @@ class UDFSuite extends QueryTest { assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4) } + test("ZeroArgument UDF") { + registerFunction("random0", () => { Math.random()}) + assert(sql("SELECT random0()").first().getDouble(0) >= 0.0) + } + test("TwoArgument UDF") { registerFunction("strLenScala", (_: String).length + (_:Int)) assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5) -- cgit v1.2.3