diff options
author | Yin Huai <yhuai@databricks.com> | 2015-11-02 21:18:38 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-11-02 21:18:38 -0800 |
commit | 9cf56c96b7d02a14175d40b336da14c2e1c88339 (patch) | |
tree | e179b2b274e7f7e61683164e972de555d93bb97f /sql/core/src/test | |
parent | efaa4721b511a1d29229facde6457a6dcda18966 (diff) | |
download | spark-9cf56c96b7d02a14175d40b336da14c2e1c88339.tar.gz spark-9cf56c96b7d02a14175d40b336da14c2e1c88339.tar.bz2 spark-9cf56c96b7d02a14175d40b336da14c2e1c88339.zip |
[SPARK-11469][SQL] Allow users to define nondeterministic udfs.
This is the first task, SPARK-11469 (https://issues.apache.org/jira/browse/SPARK-11469), of SPARK-11438 (https://issues.apache.org/jira/browse/SPARK-11438).
Author: Yin Huai <yhuai@databricks.com>
Closes #9393 from yhuai/udfNondeterministic.
Diffstat (limited to 'sql/core/src/test')
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala | 105 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala | 4 |
2 files changed, 107 insertions, 2 deletions
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index e0435a0dba..6e510f0b8a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql +import org.apache.spark.sql.catalyst.expressions.ScalaUDF +import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ @@ -191,4 +193,107 @@ class UDFSuite extends QueryTest with SharedSQLContext { // pass a decimal to intExpected. assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) } + + private def checkNumUDFs(df: DataFrame, expectedNumUDFs: Int): Unit = { + val udfs = df.queryExecution.optimizedPlan.collect { + case p: logical.Project => p.projectList.flatMap { + case e => e.collect { + case udf: ScalaUDF => udf + } + } + }.flatten + assert(udfs.length === expectedNumUDFs) + } + + test("foldable udf") { + import org.apache.spark.sql.functions._ + + val myUDF = udf((x: Int) => x + 1) + + { + val df = sql("SELECT 1 as a") + .select(col("a"), myUDF(col("a")).as("b")) + .select(col("a"), col("b"), myUDF(col("b")).as("c")) + checkNumUDFs(df, 0) + checkAnswer(df, Row(1, 2, 3)) + } + } + + test("nondeterministic udf: using UDFRegistration") { + import org.apache.spark.sql.functions._ + + val myUDF = sqlContext.udf.register("plusOne1", (x: Int) => x + 1) + sqlContext.udf.register("plusOne2", myUDF.nondeterministic) + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF(col("a")).as("b")) + .select(col("a"), col("b"), myUDF(col("b")).as("c")) + checkNumUDFs(df, 3) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), callUDF("plusOne1", col("a")).as("b")) + .select(col("a"), col("b"), callUDF("plusOne1", col("b")).as("c")) + 
checkNumUDFs(df, 3) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) + .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), callUDF("plusOne2", col("a")).as("b")) + .select(col("a"), col("b"), callUDF("plusOne2", col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + } + + test("nondeterministic udf: using udf function") { + import org.apache.spark.sql.functions._ + + val myUDF = udf((x: Int) => x + 1) + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF(col("a")).as("b")) + .select(col("a"), col("b"), myUDF(col("b")).as("c")) + checkNumUDFs(df, 3) + checkAnswer(df, Row(1, 2, 3)) + } + + { + val df = sqlContext.range(1, 2).select(col("id").as("a")) + .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) + .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + + { + // nondeterministicUDF will not be foldable. 
+ val df = sql("SELECT 1 as a") + .select(col("a"), myUDF.nondeterministic(col("a")).as("b")) + .select(col("a"), col("b"), myUDF.nondeterministic(col("b")).as("c")) + checkNumUDFs(df, 2) + checkAnswer(df, Row(1, 2, 3)) + } + } + + test("override a registered udf") { + sqlContext.udf.register("intExpected", (x: Int) => x) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1) + + sqlContext.udf.register("intExpected", (x: Int) => x + 1) + assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 2) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 7274479989..f14b2886a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -381,7 +381,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) @@ -405,7 +405,7 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext { sqlContext.udf.register("div0", (x: Int) => x / 0) withTempPath { dir => intercept[org.apache.spark.SparkException] { - sqlContext.sql("select div0(1)").write.parquet(dir.getCanonicalPath) + sqlContext.range(1, 2).selectExpr("div0(id) as a").write.parquet(dir.getCanonicalPath) } val path = new Path(dir.getCanonicalPath, "_temporary") val fs = path.getFileSystem(hadoopConfiguration) |