aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala16
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala5
2 files changed, 15 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
index 00d6b43a57..5fb472686c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/UdfRegistration.scala
@@ -72,14 +72,13 @@ private[sql] trait UDFRegistration {
functionRegistry.registerFunction(name, builder)
}
- /** registerFunction 1-22 were generated by this script
+ /** registerFunction 0-22 were generated by this script
- (1 to 22).map { x =>
- val types = (1 to x).map(x => "_").reduce(_ + ", " + _)
+ (0 to 22).map { x =>
+ val types = (1 to x).foldRight("T")((_, s) => {s"_, $s"})
s"""
- def registerFunction[T: TypeTag](name: String, func: Function$x[$types, T]): Unit = {
- def builder(e: Seq[Expression]) =
- ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
+ def registerFunction[T: TypeTag](name: String, func: Function$x[$types]): Unit = {
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
}
"""
@@ -87,6 +86,11 @@ private[sql] trait UDFRegistration {
*/
// scalastyle:off
+ def registerFunction[T: TypeTag](name: String, func: Function0[T]): Unit = {
+ def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
+ functionRegistry.registerFunction(name, builder)
+ }
+
def registerFunction[T: TypeTag](name: String, func: Function1[_, T]): Unit = {
def builder(e: Seq[Expression]) = ScalaUdf(func, ScalaReflection.schemaFor[T].dataType, e)
functionRegistry.registerFunction(name, builder)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index ef9b76b1e2..720953ae37 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -31,6 +31,11 @@ class UDFSuite extends QueryTest {
assert(sql("SELECT strLenScala('test')").first().getInt(0) === 4)
}
+ test("ZeroArgument UDF") {
+ registerFunction("random0", () => { Math.random()})
+ assert(sql("SELECT random0()").first().getDouble(0) >= 0.0)
+ }
+
test("TwoArgument UDF") {
registerFunction("strLenScala", (_: String).length + (_:Int))
assert(sql("SELECT strLenScala('test', 1)").first().getInt(0) === 5)