aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-04-16 17:50:20 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-16 17:50:20 -0700
commit1e43851d6455f65b850ea0327d0e92f65395d23f (patch)
tree02aa667acc861cf357bae14bddd9851a8718ef27 /sql
parentd96608674f6c2ff3abb13c65d80c1a3872206710 (diff)
downloadspark-1e43851d6455f65b850ea0327d0e92f65395d23f.tar.gz
spark-1e43851d6455f65b850ea0327d0e92f65395d23f.tar.bz2
spark-1e43851d6455f65b850ea0327d0e92f65395d23f.zip
[SPARK-6899][SQL] Fix type mismatch when using codegen with Average on DecimalType
JIRA https://issues.apache.org/jira/browse/SPARK-6899 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #5517 from viirya/fix_codegen_average and squashes the following commits: 8ae5f65 [Liang-Chi Hsieh] Add the case of DecimalType.Unlimited to Average.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala9
2 files changed, 10 insertions, 1 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index 14a855054b..f3830c6d3b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -326,7 +326,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
override def asPartial: SplitEvaluation = {
child.dataType match {
- case DecimalType.Fixed(_, _) =>
+ case DecimalType.Fixed(_, _) | DecimalType.Unlimited =>
// Turn the child to unlimited decimals for calculation, before going back to fixed
val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")()
val partialCount = Alias(Count(child), "PartialCount")()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 34b2cb054a..44a7d1e7bb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -537,4 +537,13 @@ class DataFrameSuite extends QueryTest {
val df = TestSQLContext.createDataFrame(rowRDD, schema)
df.rdd.collect()
}
+
+ test("SPARK-6899") {
+ val originalValue = TestSQLContext.conf.codegenEnabled
+ TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true")
+ checkAnswer(
+ decimalData.agg(avg('a)),
+ Row(new java.math.BigDecimal(2.0)))
+ TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString)
+ }
}