aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala5
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala64
2 files changed, 64 insertions, 5 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 20359c1e54..c0d00104e8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -332,11 +332,6 @@ private[sql] case class ScalaUDAF(
override def withNewInputAggBufferOffset(newInputAggBufferOffset: Int): ImperativeAggregate =
copy(inputAggBufferOffset = newInputAggBufferOffset)
- require(
- children.length == udaf.inputSchema.length,
- s"$udaf only accepts ${udaf.inputSchema.length} arguments, " +
- s"but ${children.length} are provided.")
-
override def nullable: Boolean = true
override def dataType: DataType = udaf.dataType
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 39c0a2a0de..064c0004b8 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -66,6 +66,33 @@ class ScalaAggregateFunction(schema: StructType) extends UserDefinedAggregateFun
}
}
+class ScalaAggregateFunctionWithoutInputSchema extends UserDefinedAggregateFunction {
+
+ def inputSchema: StructType = StructType(Nil)
+
+ def bufferSchema: StructType = StructType(StructField("value", LongType) :: Nil)
+
+ def dataType: DataType = LongType
+
+ def deterministic: Boolean = true
+
+ def initialize(buffer: MutableAggregationBuffer): Unit = {
+ buffer.update(0, 0L)
+ }
+
+ def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
+ buffer.update(0, input.getAs[Seq[Row]](0).map(_.getAs[Int]("v")).sum + buffer.getLong(0))
+ }
+
+ def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
+ buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
+ }
+
+ def evaluate(buffer: Row): Any = {
+ buffer.getLong(0)
+ }
+}
+
class LongProductSum extends UserDefinedAggregateFunction {
def inputSchema: StructType = new StructType()
.add("a", LongType)
@@ -858,6 +885,43 @@ abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with Te
)
}
}
+
+ test("udaf without specifying inputSchema") {
+ withTempTable("noInputSchemaUDAF") {
+ sqlContext.udf.register("noInputSchema", new ScalaAggregateFunctionWithoutInputSchema)
+
+ val data =
+ Row(1, Seq(Row(1), Row(2), Row(3))) ::
+ Row(1, Seq(Row(4), Row(5), Row(6))) ::
+ Row(2, Seq(Row(-10))) :: Nil
+ val schema =
+ StructType(
+ StructField("key", IntegerType) ::
+ StructField("myArray",
+ ArrayType(StructType(StructField("v", IntegerType) :: Nil))) :: Nil)
+ sqlContext.createDataFrame(
+ sparkContext.parallelize(data, 2),
+ schema)
+ .registerTempTable("noInputSchemaUDAF")
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT key, noInputSchema(myArray)
+ |FROM noInputSchemaUDAF
+ |GROUP BY key
+ """.stripMargin),
+ Row(1, 21) :: Row(2, -10) :: Nil)
+
+ checkAnswer(
+ sqlContext.sql(
+ """
+ |SELECT noInputSchema(myArray)
+ |FROM noInputSchemaUDAF
+ """.stripMargin),
+ Row(11) :: Nil)
+ }
+ }
}