aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala70
1 files changed, 70 insertions, 0 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 703a34c47e..8e5da3ac14 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -82,6 +82,76 @@ class UDFSuite extends QueryTest {
assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
+ test("UDF in a WHERE") { // a registered Boolean UDF used directly as the WHERE predicate
+ ctx.udf.register("oneArgFilter", (n: Int) => { n > 80 })
+ // Build a 100-row table from TestData(i, i.toString) for i in 1 to 100.
+ val df = ctx.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString))).toDF()
+ df.registerTempTable("integerData")
+ // The query filters on `key`; presumably TestData's first field (the Int i) -- confirm.
+ val result =
+ ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
+ assert(result.count() === 20) // 81..100 pass n > 80, assuming `key` holds i
+ }
+
+ test("UDF in a HAVING") { // a Boolean UDF applied to an aggregate alias in HAVING
+ ctx.udf.register("havingFilter", (n: Long) => { n > 5 }) // Long param; applied to alias s
+ // Five (g, v) rows in three groups: red (1, 2), blue (10), green (100, 200).
+ val df = Seq(("red", 1), ("red", 2), ("blue", 10),
+ ("green", 100), ("green", 200)).toDF("g", "v")
+ df.registerTempTable("groupData")
+ // Per-group sums are red = 3, blue = 10, green = 300; the UDF keeps sums > 5.
+ val result =
+ ctx.sql(
+ """
+ | SELECT g, SUM(v) as s
+ | FROM groupData
+ | GROUP BY g
+ | HAVING havingFilter(s)
+ """.stripMargin)
+ // blue and green pass the HAVING filter; red (sum 3) does not.
+ assert(result.count() === 2)
+ }
+
+ test("UDF in a GROUP BY") { // a UDF result is used as the grouping key
+ ctx.udf.register("groupFunction", (n: Int) => { n > 10 }) // Boolean key: is v > 10?
+ // Five (g, v) rows; the query below reads only v.
+ val df = Seq(("red", 1), ("red", 2), ("blue", 10),
+ ("green", 100), ("green", 200)).toDF("g", "v")
+ df.registerTempTable("groupData")
+ // groupFunction(v) is false for 1, 2, 10 and true for 100, 200.
+ val result =
+ ctx.sql(
+ """
+ | SELECT SUM(v)
+ | FROM groupData
+ | GROUP BY groupFunction(v)
+ """.stripMargin)
+ assert(result.count() === 2) // one row per distinct UDF value (false, true)
+ }
+
+ test("UDFs everywhere") { // UDFs in SELECT, WHERE, GROUP BY and HAVING of one query
+ ctx.udf.register("groupFunction", (n: Int) => { n > 10 })
+ ctx.udf.register("havingFilter", (n: Long) => { n > 2000 })
+ ctx.udf.register("whereFilter", (n: Int) => { n < 150 })
+ ctx.udf.register("timesHundred", (n: Long) => { n * 100 })
+ // NOTE(review): "havingFilter" is also registered above with n > 5; presumably replaced here.
+ val df = Seq(("red", 1), ("red", 2), ("blue", 10),
+ ("green", 100), ("green", 200)).toDF("g", "v")
+ df.registerTempTable("groupData")
+ // v < 150 keeps 1, 2, 10, 100; groups sum to 13 and 100, i.e. 1300 and 10000 after * 100.
+ val result =
+ ctx.sql(
+ """
+ | SELECT timesHundred(SUM(v)) as v100
+ | FROM groupData
+ | WHERE whereFilter(v)
+ | GROUP BY groupFunction(v)
+ | HAVING havingFilter(v100)
+ """.stripMargin)
+ assert(result.count() === 1) // only the 10000 group exceeds 2000
+ }
+
test("struct UDF") {
ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))