diff options
author | Zhan Zhang <zhanzhang@fb.com> | 2016-12-09 16:35:06 +0800 |
---|---|---|
committer | Wenchen Fan <wenchen@databricks.com> | 2016-12-09 16:35:06 +0800 |
commit | 67587d961d5f94a8639c20cb80127c86bf79d5a8 (patch) | |
tree | c9cd42fac2a18ec79e073eeb1c2e587c2c2d2a52 /sql | |
parent | c074c96dc57bf18b28fafdcac0c768d75c642cba (diff) | |
download | spark-67587d961d5f94a8639c20cb80127c86bf79d5a8.tar.gz spark-67587d961d5f94a8639c20cb80127c86bf79d5a8.tar.bz2 spark-67587d961d5f94a8639c20cb80127c86bf79d5a8.zip |
[SPARK-18637][SQL] Stateful UDF should be considered as nondeterministic
## What changes were proposed in this pull request?
Treat stateful Hive UDFs as nondeterministic, so that they are no longer considered foldable and replaced by a constant.
## How was this patch tested?
Added new test cases covering both a stateful and a stateless UDF.
Without the patch, the new test fails with the following exception:
1 did not equal 10
ScalaTestFailureLocation: org.apache.spark.sql.hive.execution.HiveUDFSuite$$anonfun$21 at (HiveUDFSuite.scala:501)
org.scalatest.exceptions.TestFailedException: 1 did not equal 10
at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:500)
at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555)
...
Author: Zhan Zhang <zhanzhang@fb.com>
Closes #16068 from zhzhan/state.
Diffstat (limited to 'sql')
-rw-r--r-- | sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala | 4 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala | 45 |
2 files changed, 45 insertions, 4 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 349faae40b..26dc372d7c 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -61,7 +61,7 @@ private[hive] case class HiveSimpleUDF( @transient private lazy val isUDFDeterministic = { val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) - udfType != null && udfType.deterministic() + udfType != null && udfType.deterministic() && !udfType.stateful() } override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) @@ -144,7 +144,7 @@ private[hive] case class HiveGenericUDF( @transient private lazy val isUDFDeterministic = { val udfType = function.getClass.getAnnotation(classOf[HiveUDFType]) - udfType != null && udfType.deterministic() + udfType != null && udfType.deterministic() && !udfType.stateful() } @transient diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala index 48adc833f4..4098bb597b 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -21,15 +21,17 @@ import java.io.{DataInput, DataOutput, File, PrintWriter} import java.util.{ArrayList, Arrays, Properties} import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.UDAFPercentile +import org.apache.hadoop.hive.ql.exec.UDF +import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType} import org.apache.hadoop.hive.ql.udf.generic._ import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} import 
org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.io.Writable +import org.apache.hadoop.io.{LongWritable, Writable} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.functions.max import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.util.Utils @@ -487,6 +489,26 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils { assert(count4 == 1) sql("DROP TABLE parquet_tmp") } + + test("Hive Stateful UDF") { + withUserDefinedFunction("statefulUDF" -> true, "statelessUDF" -> true) { + sql(s"CREATE TEMPORARY FUNCTION statefulUDF AS '${classOf[StatefulUDF].getName}'") + sql(s"CREATE TEMPORARY FUNCTION statelessUDF AS '${classOf[StatelessUDF].getName}'") + val testData = spark.range(10).repartition(1) + + // Expected Max(s) is 10 as statefulUDF returns the sequence number starting from 1. + checkAnswer(testData.selectExpr("statefulUDF() as s").agg(max($"s")), Row(10)) + + // Expected Max(s) is 5 as statefulUDF returns the sequence number starting from 1, + // and the data is evenly distributed into 2 partitions. + checkAnswer(testData.repartition(2) + .selectExpr("statefulUDF() as s").agg(max($"s")), Row(5)) + + // Expected Max(s) is 1, as stateless UDF is deterministic and foldable and replaced + // by constant 1 by ConstantFolding optimizer. 
+ checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1)) + } + } } class TestPair(x: Int, y: Int) extends Writable with Serializable { @@ -551,3 +573,22 @@ class PairUDF extends GenericUDF { override def getDisplayString(p1: Array[String]): String = "" } + +@UDFType(stateful = true) +class StatefulUDF extends UDF { + private val result = new LongWritable(0) + + def evaluate(): LongWritable = { + result.set(result.get() + 1) + result + } +} + +class StatelessUDF extends UDF { + private val result = new LongWritable(0) + + def evaluate(): LongWritable = { + result.set(result.get() + 1) + result + } +} |