path: root/sql/core/src/test
author     Yin Huai <yhuai@databricks.com>  2015-11-02 21:18:38 -0800
committer  Yin Huai <yhuai@databricks.com>  2015-11-02 21:18:38 -0800
commit     9cf56c96b7d02a14175d40b336da14c2e1c88339 (patch)
tree       e179b2b274e7f7e61683164e972de555d93bb97f /sql/core/src/test
parent     efaa4721b511a1d29229facde6457a6dcda18966 (diff)
[SPARK-11469][SQL] Allow users to define nondeterministic udfs.
This is the first task (https://issues.apache.org/jira/browse/SPARK-11469) of
https://issues.apache.org/jira/browse/SPARK-11438.

Author: Yin Huai <yhuai@databricks.com>

Closes #9393 from yhuai/udfNondeterministic.
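In short, the change lets a UDF be marked nondeterministic so the optimizer will neither constant-fold it nor collapse projections across it. A minimal usage sketch in Scala, inferred from the API the tests below exercise (the plusOneNd name is illustrative, not part of the patch; sqlContext is assumed in scope as in the suite):

    import org.apache.spark.sql.functions.udf

    // Deterministic by default: with this patch such a UDF becomes
    // foldable when applied to foldable (literal) inputs.
    val plusOne = udf((x: Int) => x + 1)

    // nondeterministic returns a variant the optimizer must keep in
    // place; it can be applied directly or registered under a new name.
    val plusOneNd = plusOne.nondeterministic
    sqlContext.udf.register("plusOneNd", plusOneNd)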
Diffstat (limited to 'sql/core/src/test')
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala                                      | 105
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala  |   4
2 files changed, 107 insertions(+), 2 deletions(-)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index e0435a0dba..6e510f0b8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql
 
+import org.apache.spark.sql.catalyst.expressions.ScalaUDF
+import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData._
@@ -191,4 +193,107 @@ class UDFSuite extends QueryTest with SharedSQLContext {
     // pass a decimal to intExpected.
     assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
   }
+
+  private def checkNumUDFs(df: DataFrame, expectedNumUDFs: Int): Unit = {
+    val udfs = df.queryExecution.optimizedPlan.collect {
+      case p: logical.Project => p.projectList.flatMap {
+        case e => e.collect {
+          case udf: ScalaUDF => udf
+        }
+      }
+    }.flatten
+    assert(udfs.length === expectedNumUDFs)
+  }
+
+  test("foldable udf") {
+    import org.apache.spark.sql.functions._
+
+    val myUDF = udf((x: Int) => x + 1)
+
+    {
+      val df = sql("SELECT 1 as a")
+        .select(col("a"), myUDF(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF(col("b")).as("c"))
+      checkNumUDFs(df, 0)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+  }
+
+  test("nondeterministic udf: using UDFRegistration") {
+    import org.apache.spark.sql.functions._
+
+    val myUDF = sqlContext.udf.register("plusOne1", (x: Int) => x + 1)
+    sqlContext.udf.register("plusOne2", myUDF.nondeterministic)
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), myUDF(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF(col("b")).as("c"))
+      checkNumUDFs(df, 3)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), callUDF("plusOne1", col("a")).as("b"))
+        .select(col("a"), col("b"), callUDF("plusOne1", col("b")).as("c"))
+      checkNumUDFs(df, 3)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), myUDF.nondeterministic(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c"))
+      checkNumUDFs(df, 2)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), callUDF("plusOne2", col("a")).as("b"))
+        .select(col("a"), col("b"), callUDF("plusOne2", col("b")).as("c"))
+      checkNumUDFs(df, 2)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+  }
+
+  test("nondeterministic udf: using udf function") {
+    import org.apache.spark.sql.functions._
+
+    val myUDF = udf((x: Int) => x + 1)
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), myUDF(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF(col("b")).as("c"))
+      checkNumUDFs(df, 3)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      val df = sqlContext.range(1, 2).select(col("id").as("a"))
+        .select(col("a"), myUDF.nondeterministic(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c"))
+      checkNumUDFs(df, 2)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+
+    {
+      // A nondeterministic UDF will not be foldable.
+      val df = sql("SELECT 1 as a")
+        .select(col("a"), myUDF.nondeterministic(col("a")).as("b"))
+        .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c"))
+      checkNumUDFs(df, 2)
+      checkAnswer(df, Row(1, 2, 3))
+    }
+  }
+
+  test("override a registered udf") {
+    sqlContext.udf.register("intExpected", (x: Int) => x)
+    assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
+
+    sqlContext.udf.register("intExpected", (x: Int) => x + 1)
+    assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 2)
+  }
 }
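Why the expected counts differ between the deterministic and nondeterministic variants (my reading of the optimized plans, not something the patch states): for a deterministic UDF, project collapsing may inline b into c, producing a single Project with plusOne(a) and plusOne(plusOne(a)), hence three ScalaUDF nodes; a nondeterministic UDF must not be duplicated or inlined, so each call stays in its own Project and only two remain. A sketch for inspecting the plans, reusing the suite's setup and an assumed sqlContext:

    import org.apache.spark.sql.functions._

    val plusOne = udf((x: Int) => x + 1)
    val base = sqlContext.range(1, 2).select(col("id").as("a"))

    // Deterministic: expect one collapsed Project holding three ScalaUDFs.
    base.select(col("a"), plusOne(col("a")).as("b"))
      .select(col("a"), col("b"), plusOne(col("b")).as("c"))
      .explain(true)

    // Nondeterministic: expect two stacked Projects, one ScalaUDF in each.
    base.select(col("a"), plusOne.nondeterministic(col("a")).as("b"))
      .select(col("a"), col("b"), plusOne.nondeterministic(col("b")).as("c"))
      .explain(true)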
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index 7274479989..f14b2886a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -381,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
intercept[org.apache.spark.SparkException] {
- sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
+ sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
val fs = path.getFileSystem(hadoopConfiguration)
@@ -405,7 +405,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
sqlContext.udf.register("div0", (x: Int) => x / 0)
withTempPath { dir =>
intercept[org.apache.spark.SparkException] {
- sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath)
+ sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath)
}
val path = new Path(dir.getCanonicalPath, "_temporary")
val fs = path.getFileSystem(hadoopConfiguration)
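For context on why both hunks swap the input (an inference from the UDF change, not stated in the commit message): once a deterministic UDF applied to a literal is foldable, div0(1) can be evaluated during optimization, so the division by zero would no longer surface as a SparkException inside the write job these tests intercept. Drawing id from range(1, 2) keeps the expression non-foldable and defers the failure to execution. A sketch of the distinction:

    sqlContext.udf.register("div0", (x: Int) => x / 0)

    // Foldable after this patch: the optimizer may compute div0(1) eagerly,
    // raising the error before any parquet write job starts.
    val foldable = sqlContext.sql("select div0(1)")

    // Non-foldable: id comes from a relation, so div0 only runs when the
    // write executes, keeping the failure inside the intercepted job.
    val deferred = sqlContext.range(1, 2).selectExpr("div0(id) as a")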