aboutsummaryrefslogtreecommitdiff
path: root/sql/hive
diff options
context:
space:
mode:
author Yin Huai <yhuai@databricks.com> 2015-09-17 11:14:52 -0700
committer Michael Armbrust <michael@databricks.com> 2015-09-17 11:14:52 -0700
commit aad644fbe29151aec9004817d42e4928bdb326f3 (patch)
tree 77bfa902698f82e6e2547e9bc70dfec46bd0970f /sql/hive
parent e0dc2bc232206d2f4da4278502c1f88babc8b55a (diff)
download spark-aad644fbe29151aec9004817d42e4928bdb326f3.tar.gz
spark-aad644fbe29151aec9004817d42e4928bdb326f3.tar.bz2
spark-aad644fbe29151aec9004817d42e4928bdb326f3.zip
[SPARK-10639] [SQL] Need to convert UDAF's result from scala to sql type
https://issues.apache.org/jira/browse/SPARK-10639 Author: Yin Huai <yhuai@databricks.com> Closes #8788 from yhuai/udafConversion.
Diffstat (limited to 'sql/hive')
-rw-r--r-- sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala 108
1 files changed, 107 insertions, 1 deletions
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index a73b1bd52c..24b1846923 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -17,13 +17,55 @@
package org.apache.spark.sql.hive.execution
+import scala.collection.JavaConverters._
+
import org.apache.spark.sql._
+import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.aggregate
+import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
+import org.apache.spark.sql.functions._
import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.apache.spark.sql.types._
import org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
import org.apache.spark.sql.hive.test.TestHiveSingleton
+/**
+ * Test-only UDAF whose input schema, buffer schema and result type are all `schema`.
+ *
+ * The buffer starts out with every field set to null. A row (or a partial buffer) is copied
+ * into the buffer field by field only when its field 0 is non-null and equals 50; any other
+ * row leaves the buffer untouched. `evaluate` hands the buffer contents back as a [[Row]].
+ */
+class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFunction {
+
+  def inputSchema: StructType = schema
+
+  def bufferSchema: StructType = schema
+
+  def dataType: DataType = schema
+
+  def deterministic: Boolean = true
+
+  /** Resets every field of the aggregation buffer to null. */
+  def initialize(buffer: MutableAggregationBuffer): Unit = {
+    for (ordinal <- 0 until schema.length) {
+      buffer.update(ordinal, null)
+    }
+  }
+
+  /** Folds one input row into the buffer; only a row whose field 0 is 50 has any effect. */
+  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+    copyIfFirstFieldIs50(buffer, input)
+  }
+
+  /** Folds a partial buffer into `buffer1`, using the same rule as `update`. */
+  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+    copyIfFirstFieldIs50(buffer1, buffer2)
+  }
+
+  /** Returns the values held by the buffer, wrapped in a [[Row]]. */
+  def evaluate(buffer: Row): Any = {
+    val values = buffer.toSeq
+    Row.fromSeq(values)
+  }
+
+  /**
+   * Overwrites every field of `target` with the matching field of `source` when field 0 of
+   * `source` is non-null and equal to 50; otherwise does nothing.
+   */
+  private def copyIfFirstFieldIs50(target: MutableAggregationBuffer, source: Row): Unit = {
+    val firstFieldIs50 = !source.isNullAt(0) && source.getInt(0) == 50
+    if (firstFieldIs50) {
+      for (ordinal <- 0 until schema.length) {
+        target.update(ordinal, source.get(ordinal))
+      }
+    }
+  }
+}
+
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
@@ -508,6 +550,70 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
assert(errorMessage.contains("implemented based on the new Aggregate Function interface"))
}
}
+
+ // Feeds 50 randomly generated rows (ids 1 to 50) through ScalaAggregateFunction, which keeps
+ // only the row whose id is 50, and checks that the row comes back unchanged. This is run once
+ // for a broad list of data types and once for UnsafeRow's mutable field types.
+ test("udaf with all data types") {
+   val struct =
+     StructType(
+       StructField("f1", FloatType, true) ::
+         StructField("f2", ArrayType(BooleanType), true) :: Nil)
+   val dataTypes = Seq(StringType, BinaryType, NullType, BooleanType,
+     ByteType, ShortType, IntegerType, LongType,
+     FloatType, DoubleType, DecimalType(25, 5), DecimalType(6, 5),
+     DateType, TimestampType,
+     ArrayType(IntegerType), MapType(StringType, LongType), struct,
+     new MyDenseVectorUDT())
+   // Right now, we use SortBasedAggregate to handle UDAFs.
+   // UnsafeRow.mutableFieldTypes.asScala.toSeq makes SortBasedAggregate use an UnsafeRow as
+   // the aggregation buffer, whereas dataTypes makes it use a safe row as the aggregation
+   // buffer.
+   Seq(dataTypes, UnsafeRow.mutableFieldTypes.asScala.toSeq).foreach { columnTypes =>
+     val fields = columnTypes.zipWithIndex.map { case (dataType, index) =>
+       StructField(s"col$index", dataType, nullable = true)
+     }
+     // The schema used for data generator.
+     val schemaForGenerator = StructType(fields)
+     // The schema used for the DataFrame df.
+     val schema = StructType(StructField("id", IntegerType) +: fields)
+
+     logInfo(s"Testing schema: ${schema.treeString}")
+
+     val udaf = new ScalaAggregateFunction(schema)
+     // The seed changes on every run; log it so that a failing run can be reproduced.
+     val randomSeed = System.nanoTime()
+     logInfo(s"Seed used for random data generation: $randomSeed")
+     // Generate data at the driver side. We need to materialize the data first and then
+     // create RDD.
+     val maybeDataGenerator =
+       RandomDataGenerator.forType(
+         dataType = schemaForGenerator,
+         nullable = true,
+         seed = Some(randomSeed))
+     val dataGenerator =
+       maybeDataGenerator
+         .getOrElse(fail(s"Failed to create data generator for schema $schemaForGenerator"))
+     val data = (1 to 50).map { i =>
+       dataGenerator.apply() match {
+         case row: Row => Row.fromSeq(i +: row.toSeq)
+         // A null struct becomes a row whose generated columns are all null.
+         case null => Row.fromSeq(i +: Seq.fill(schemaForGenerator.length)(null))
+         case other =>
+           fail(s"Row or null is expected to be generated, " +
+             s"but a ${other.getClass.getCanonicalName} is generated.")
+       }
+     }
+
+     // Create a DF for the schema with random data.
+     val rdd = sqlContext.sparkContext.parallelize(data, 1)
+     val df = sqlContext.createDataFrame(rdd, schema)
+
+     val allColumns = df.schema.fields.map(f => col(f.name))
+     val expectedAnswer =
+       data
+         .find(r => r.getInt(0) == 50)
+         .getOrElse(fail("A row with id 50 should be the expected answer."))
+     checkAnswer(
+       df.groupBy().agg(udaf(allColumns: _*)),
+       // udaf returns a Row as the output value.
+       Row(expectedAnswer)
+     )
+   }
+ }
}
class SortBasedAggregationQuerySuite extends AggregationQuerySuite {