From 1e43851d6455f65b850ea0327d0e92f65395d23f Mon Sep 17 00:00:00 2001 From: Liang-Chi Hsieh Date: Thu, 16 Apr 2015 17:50:20 -0700 Subject: [SPARK-6899][SQL] Fix type mismatch when using codegen with Average on DecimalType JIRA https://issues.apache.org/jira/browse/SPARK-6899 Author: Liang-Chi Hsieh Closes #5517 from viirya/fix_codegen_average and squashes the following commits: 8ae5f65 [Liang-Chi Hsieh] Add the case of DecimalType.Unlimited to Average. --- .../org/apache/spark/sql/catalyst/expressions/aggregates.scala | 2 +- .../src/test/scala/org/apache/spark/sql/DataFrameSuite.scala | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 14a855054b..f3830c6d3b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -326,7 +326,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN override def asPartial: SplitEvaluation = { child.dataType match { - case DecimalType.Fixed(_, _) => + case DecimalType.Fixed(_, _) | DecimalType.Unlimited => // Turn the child to unlimited decimals for calculation, before going back to fixed val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() val partialCount = Alias(Count(child), "PartialCount")() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 34b2cb054a..44a7d1e7bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -537,4 +537,13 @@ class DataFrameSuite extends QueryTest { val df = TestSQLContext.createDataFrame(rowRDD, schema) df.rdd.collect() } + + test("SPARK-6899") { + val originalValue = TestSQLContext.conf.codegenEnabled + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + checkAnswer( + decimalData.agg(avg('a)), + Row(new java.math.BigDecimal(2.0))) + TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + } } -- cgit v1.2.3