author    Zhan Zhang <zhanzhang@fb.com>  2016-12-09 16:35:06 +0800
committer Wenchen Fan <wenchen@databricks.com>  2016-12-09 16:35:06 +0800
commit    67587d961d5f94a8639c20cb80127c86bf79d5a8 (patch)
tree      c9cd42fac2a18ec79e073eeb1c2e587c2c2d2a52
parent    c074c96dc57bf18b28fafdcac0c768d75c642cba (diff)
[SPARK-18637][SQL] Stateful UDF should be considered as nondeterministic
## What changes were proposed in this pull request?

Mark a stateful UDF as nondeterministic, so that the optimizer does not assume it returns the same value on every invocation.

## How was this patch tested?

Added new test cases covering both a stateful and a stateless UDF. Without the patch, the test cases fail with:

1 did not equal 10
ScalaTestFailureLocation: org.apache.spark.sql.hive.execution.HiveUDFSuite$$anonfun$21 at (HiveUDFSuite.scala:501)
org.scalatest.exceptions.TestFailedException: 1 did not equal 10
at org.scalatest.Assertions$class.newAssertionFailedException(Assertions.scala:500)
at org.scalatest.FunSuite.newAssertionFailedException(FunSuite.scala:1555)
...

Author: Zhan Zhang <zhanzhang@fb.com>

Closes #16068 from zhzhan/state.
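For a concrete picture of the failure (an illustrative sketch, not part of the original commit message): a zero-argument UDF that is believed deterministic is also foldable, so the ConstantFolding optimizer evaluates it once at planning time and every row receives that single value, which is exactly the "1 did not equal 10" failure quoted above. Assuming statefulUDF is registered as in the new test case below:

    import org.apache.spark.sql.functions.max
    import spark.implicits._  // for the $"s" column syntax

    // Sketch: a 10-row, single-partition input through the stateful UDF.
    val testData = spark.range(10).repartition(1)
    val result = testData.selectExpr("statefulUDF() as s").agg(max($"s"))
    // before the patch: Row(1)  -- the call is folded to the literal 1
    // after the patch:  Row(10) -- the UDF really runs once per row, counting 1..10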
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala               |  4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala | 45
2 files changed, 45 insertions(+), 4 deletions(-)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
index 349faae40b..26dc372d7c 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala
@@ -61,7 +61,7 @@ private[hive] case class HiveSimpleUDF(
   @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic()
+    udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
   override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable)
@@ -144,7 +144,7 @@ private[hive] case class HiveGenericUDF(
   @transient
   private lazy val isUDFDeterministic = {
     val udfType = function.getClass.getAnnotation(classOf[HiveUDFType])
-    udfType != null && udfType.deterministic()
+    udfType != null && udfType.deterministic() && !udfType.stateful()
   }
 
   @transient
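The patched predicate, pulled out as a standalone sketch (an illustrative free function, not the verbatim source; HiveUDFType in the hunks above is an import alias for Hive's org.apache.hadoop.hive.ql.udf.UDFType annotation):

    import org.apache.hadoop.hive.ql.udf.UDFType

    // Deterministic only when the annotation is present, declares
    // deterministic, and does NOT declare stateful; a missing annotation is
    // treated as nondeterministic. Combined with the foldable override in
    // the first hunk, a stateful UDF can never be constant-folded.
    def isUDFDeterministic(udfClass: Class[_]): Boolean = {
      val udfType = udfClass.getAnnotation(classOf[UDFType])
      udfType != null && udfType.deterministic() && !udfType.stateful()
    }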
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
index 48adc833f4..4098bb597b 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala
@@ -21,15 +21,17 @@ import java.io.{DataInput, DataOutput, File, PrintWriter}
import java.util.{ArrayList, Arrays, Properties}
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.ql.udf.UDAFPercentile
+import org.apache.hadoop.hive.ql.exec.UDF
+import org.apache.hadoop.hive.ql.udf.{UDAFPercentile, UDFType}
import org.apache.hadoop.hive.ql.udf.generic._
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats}
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
-import org.apache.hadoop.io.Writable
+import org.apache.hadoop.io.{LongWritable, Writable}
import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.apache.spark.sql.functions.max
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
@@ -487,6 +489,26 @@ class HiveUDFSuite extends QueryTest with TestHiveSingleton with SQLTestUtils {
     assert(count4 == 1)
     sql("DROP TABLE parquet_tmp")
   }
+
+  test("Hive Stateful UDF") {
+    withUserDefinedFunction("statefulUDF" -> true, "statelessUDF" -> true) {
+      sql(s"CREATE TEMPORARY FUNCTION statefulUDF AS '${classOf[StatefulUDF].getName}'")
+      sql(s"CREATE TEMPORARY FUNCTION statelessUDF AS '${classOf[StatelessUDF].getName}'")
+      val testData = spark.range(10).repartition(1)
+
+      // Expected Max(s) is 10 as statefulUDF returns the sequence number starting from 1.
+      checkAnswer(testData.selectExpr("statefulUDF() as s").agg(max($"s")), Row(10))
+
+      // Expected Max(s) is 5 as statefulUDF returns the sequence number starting from 1,
+      // and the data is evenly distributed into 2 partitions.
+      checkAnswer(testData.repartition(2)
+        .selectExpr("statefulUDF() as s").agg(max($"s")), Row(5))
+
+      // Expected Max(s) is 1, as stateless UDF is deterministic and foldable and replaced
+      // by constant 1 by ConstantFolding optimizer.
+      checkAnswer(testData.selectExpr("statelessUDF() as s").agg(max($"s")), Row(1))
+    }
+  }
}
class TestPair(x: Int, y: Int) extends Writable with Serializable {
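An aside on the Row(5) expectation in the test above (illustration, not part of the patch): the counter lives in the UDF instance, and that state is per partition task, so each of the two 5-row partitions counts 1 through 5 independently:

    // Sketch: with repartition(2), the counter restarts in each partition:
    //   partition 0: s = 1, 2, 3, 4, 5
    //   partition 1: s = 1, 2, 3, 4, 5
    // so max(s) over the union is 5, not 10.
    testData.repartition(2).selectExpr("statefulUDF() as s").agg(max($"s"))  // Row(5)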
@@ -551,3 +573,22 @@ class PairUDF extends GenericUDF {
override def getDisplayString(p1: Array[String]): String = ""
}
+
+@UDFType(stateful = true)
+class StatefulUDF extends UDF {
+  private val result = new LongWritable(0)
+
+  def evaluate(): LongWritable = {
+    result.set(result.get() + 1)
+    result
+  }
+}
+
+class StatelessUDF extends UDF {
+  private val result = new LongWritable(0)
+
+  def evaluate(): LongWritable = {
+    result.set(result.get() + 1)
+    result
+  }
+}
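A final note on why the stateless case folds to 1 (an aside, not in the patch): UDFType is an @Inherited annotation and Hive's UDF base class is itself annotated @UDFType(deterministic = true), so StatelessUDF picks up a deterministic, non-stateful annotation without declaring one, while StatefulUDF's own annotation keeps the default deterministic = true and adds stateful = true. A quick sanity check of that assumption:

    import org.apache.hadoop.hive.ql.udf.UDFType

    // Both classes end up annotated (the annotation is inherited from the
    // UDF base class), but only StatefulUDF sets stateful = true, so only
    // it escapes ConstantFolding under the patched check.
    val stateful  = classOf[StatefulUDF].getAnnotation(classOf[UDFType])
    val stateless = classOf[StatelessUDF].getAnnotation(classOf[UDFType])
    assert(stateful.deterministic() && stateful.stateful())
    assert(stateless.deterministic() && !stateless.stateful())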