aboutsummaryrefslogtreecommitdiff
path: root/sql/hive
diff options
context:
space:
mode:
Diffstat (limited to 'sql/hive')
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala3
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala11
2 files changed, 11 insertions, 3 deletions
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index bfe43373d9..47305571e5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -375,9 +375,8 @@ private[hive] case class HiveUdafFunction(
private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors)
- // Cast required to avoid type inference selecting a deprecated Hive API.
private val buffer =
- function.getNewAggregationBuffer.asInstanceOf[GenericUDAFEvaluator.AbstractAggregationBuffer]
+ function.getNewAggregationBuffer
override def eval(input: Row): Any = unwrap(function.evaluate(buffer), returnInspector)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
index cb405f56bf..d7c5d1a25a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala
@@ -22,7 +22,7 @@ import java.util
import java.util.Properties
import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.hive.ql.udf.generic.GenericUDF
+import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF}
import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory
import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory}
@@ -93,6 +93,15 @@ class HiveUdfSuite extends QueryTest {
sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf")
}
+ test("SPARK-6409 UDAFAverage test") {
+ sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'")
+ checkAnswer(
+ sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"),
+ Seq(Row(1.0, 260.182)))
+ sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg")
+ TestHive.reset()
+ }
+
test("SPARK-2693 udaf aggregates test") {
checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"),
sql("SELECT max(key) FROM src").collect().toSeq)