diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala | 5 | ||||
-rw-r--r-- | sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala | 64 |
2 files changed, 64 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 20359c1e54..c0d00104e8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -332,11 +332,6 @@ private[sql] case class ScalaUDAF( override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate = copy(inputAggBufferOffset = newInputAggBufferOffset) - require( - children.length == udaf.inputSchema.length, - s"$udaf only accepts ${udaf.inputSchema.length} arguments, " + - s"but ${children.length} are provided.") - override def nullable: Boolean = true override def dataType: DataType = udaf.dataType diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala index 39c0a2a0de..064c0004b8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala @@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun } } +class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction { + + def inputSchema: StructType = StructType(Nil) + + def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil) + + def dataType: DataType = LongType + + def deterministic: Boolean = true + + def initialize(buffer: MutableAggregationBuffer): Unit = { + buffer.update(0, 0L) + } + + def update(buffer: MutableAggregationBuffer, input: Row): Unit = { + buffer.update(0, input.getAs[Seq[Row]](0).map(_.getAs[Int]("v")).sum + buffer.getLong(0)) + } + + def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = { + buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0)) + } + + def evaluate(buffer: Row): Any = { + buffer.getLong(0) + } +} + class LongProductSum extends UserDefinedAggregateFunction { def inputSchema: StructType = new StructType() .add("a", LongType) @@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te ) } } + + test("udaf without specifying inputSchema") { + withTempTable("noInputSchemaUDAF") { + sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema) + + val data = + Row(1, Seq(Row(1), Row(2), Row(3))) :: + Row(1, Seq(Row(4), Row(5), Row(6))) :: + Row(2, Seq(Row(-10))) :: Nil + val schema = + StructType( + StructField("key", IntegerType) :: + StructField("myArray", + ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil) + sqlContext.createDataFrame( + sparkContext.parallelize(data, 2), + schema) + .registerTempTable("noInputSchemaUDAF") + + checkAnswer( + sqlContext.sql( + """ + |SELECT key, noInputSchema(myArray) + |FROM noInputSchemaUDAF + |GROUP BY key + """.stripMargin), + Row(1, 21) :: Row(2, -10) :: Nil) + + checkAnswer( + sqlContext.sql( + """ + |SELECT noInputSchema(myArray) + |FROM noInputSchemaUDAF + """.stripMargin), + Row(11) :: Nil) + } + } } |