diff options
author | Yin Huai <yhuai@databricks.com> | 2015-12-10 12:03:29 -0800 |
---|---|---|
committer | Yin Huai <yhuai@databricks.com> | 2015-12-10 12:03:29 -0800 |
commit | bc5f56aa60a430244ffa0cacd81c0b1ecbf8d68f (patch) | |
tree | 71dfe00487afad94ad2f83939af1676f55c1cce2 /sql/hive/src | |
parent | d9d354ed40eec56b3f03d32f4e2629d367b1bf02 (diff) | |
download | spark-bc5f56aa60a430244ffa0cacd81c0b1ecbf8d68f.tar.gz spark-bc5f56aa60a430244ffa0cacd81c0b1ecbf8d68f.tar.bz2 spark-bc5f56aa60a430244ffa0cacd81c0b1ecbf8d68f.zip |
[SPARK-12250][SQL] Allow users to define a UDAF without providing details of its inputSchema
https://issues.apache.org/jira/browse/SPARK-12250
Author: Yin Huai <yhuai@databricks.com>
Closes #10236 from yhuai/SPARK-12250.
Diffstat (limited to 'sql/hive/src')
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 64 |
1 files changed, 64 insertions, 0 deletions
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 39c0a2a0de..064c0004b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun } } +class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction { + + def inputSchema: StructType = StructType(Nil) + + def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer.update(0, 0L) + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + buffer.update(0, input.getAs[Seq[Row]](0).map(_.getAs[Int]("v")).sum + buffer.getLong(0)) + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0)) + } + + def evaluate(buffer: Row): Any = { + buffer.getLong(0) + } +} + class LongProductSum extends UserDefinedAggregateFunction { def inputSchema: StructType = new StructType() .add("a", LongType) @@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te ) } } + + test("udaf without specifying inputSchema") { + withTempTable("noInputSchemaUDAF") { + sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema) + + val data = + Row(1, Seq(Row(1), Row(2), Row(3))) :: + Row(1, Seq(Row(4), Row(5), Row(6))) :: + Row(2, Seq(Row(-10))) :: Nil + val schema = + StructType( + StructField("key", IntegerType) :: + StructField("myArray", + ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil) + sqlContext.createDataFrame( + sparkContext.parallelize(data, 2), + schema) + .registerTempTable("noInputSchemaUDAF") + + checkAnswer( + sqlContext.sql( + """ + |SELECT key, noInputSchema(myArray) + |FROM noInputSchemaUDAF + |GROUP BY key + """.stripMargin), + Row(1, 21) :: Row(2, -10) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT noInputSchema(myArray) + |FROM noInputSchemaUDAF + """.stripMargin), + Row(11) :: Nil) + } + } } |