author     Cheng Hao <hao.cheng@intel.com>          2014-12-16 21:21:11 -0800
committer  Michael Armbrust <michael@databricks.com>  2014-12-16 21:21:11 -0800
commit     770d8153a5fe400147cc597c8b4b703f0aa00c22 (patch)
tree       d7ebb94a6f975fcefe76edcb791a7ced94734c42 /sql
parent     ddc7ba31cb1062acb182293b2698b1b20ea56a46 (diff)
[SPARK-4375] [SQL] Add 0 argument support for udf
Author: Cheng Hao <hao.cheng@intel.com>

Closes #3595 from chenghao-intel/udf0 and squashes the following commits:

a858973 [Cheng Hao] Add 0 arguments support for udf
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala  16
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala          5
2 files changed, 15 insertions, 6 deletions
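
In short, the patch adds a registerFunction overload that accepts a Function0, so a UDF with no parameters can be registered and called from SQL. A minimal usage sketch against the Spark 1.2-era API shown in the diff below (the SparkContext setup and the UDF name now_millis are illustrative, not part of the patch):

    import org.apache.spark.SparkContext
    import org.apache.spark.sql.SQLContext

    object ZeroArgUdfExample {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "ZeroArgUdfExample")
        val sqlContext = new SQLContext(sc)

        // Register a UDF that takes no arguments (the Function0 overload added by this patch).
        sqlContext.registerFunction("now_millis", () => System.currentTimeMillis())

        // Invoke it from SQL like any other function.
        println(sqlContext.sql("SELECT now_millis()").first().getLong(0))

        sc.stop()
      }
    }
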
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 00d6b43a57..5fb472686c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -72,14 +72,13 @@ private[sql] trait UDFRegistration {
     functionRegistry.registerFunction(name, builder)
   }
 
-  /** registerFunction 1-22 were generated by this script
+  /** registerFunction 0-22 were generated by this script
 
-    (1 to 22).map { x =>
-      val types = (1 to x).map(x => "_").reduce(_ + ", " + _)
+    (0 to 22).map { x =>
+      val types = (1 to x).foldRight("T")((_, s) => {s"_, $s"})
       s"""
-        def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = {
-          def builder(e: Seq[Expression]) =
-            ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
+        def registerFunction[T: TypeTag](name: String, func: Function$x[$types]): Unit = {
+          def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
           functionRegistry.registerFunction(name, builder)
         }
       """
@@ -87,6 +86,11 @@ private[sql] trait UDFRegistration {
 
   */
   // scalastyle:off
+  def registerFunction[T: TypeTag](name: String, func: Function0[T]): Unit = {
+    def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
+    functionRegistry.registerFunction(name, builder)
+  }
+
   def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = {
     def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
     functionRegistry.registerFunction(name, builder)
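
For reference, the reworked generator in the comment builds the type-parameter list by folding down to the result type T, so x = 0 produces Function0[T] and x = 2 produces Function2[_, _, T]; the Function0 overload added after the comment corresponds to the x = 0 case. A standalone sketch of that expansion (illustrative only, mirroring the script embedded in the comment above):

    // Illustrative expansion of the generator snippet from the comment.
    object GeneratorSketch extends App {
      (0 to 2).foreach { x =>
        val types = (1 to x).foldRight("T")((_, s) => s"_, $s")
        println(s"x = $x  =>  Function$x[$types]")
      }
      // Prints:
      // x = 0  =>  Function0[T]
      // x = 1  =>  Function1[_, T]
      // x = 2  =>  Function2[_, _, T]
    }
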
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index ef9b76b1e2..720953ae37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -31,6 +31,11 @@ class UDFSuite extends QueryTest {
     assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4)
   }
 
+  test("ZeroArgument UDF") {
+    registerFunction("random0", () => { Math.random()})
+    assert(sql("SELECT random0()").first().getDouble(0) >= 0.0)
+  }
+
   test("TwoArgument UDF") {
     registerFunction("strLenScala", (_: String).length + (_:Int))
     assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)
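
The new test only asserts that Math.random() is non-negative; a deterministic variant of the same pattern could pin the result exactly (hypothetical sketch, not part of the patch):

    // Hypothetical companion test using a constant zero-argument UDF.
    test("ZeroArgument UDF with constant result") {
      registerFunction("answer", () => 42)
      assert(sql("SELECT answer()").first().getInt(0) === 42)
    }
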