aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala40
1 files changed, 29 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b26ceba228..54bf4a5293 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1343,17 +1343,35 @@ object DecimalAggregates extends Rule[LogicalPlan] {
/** Maximum number of decimal digits representable precisely in a Double */
private val MAX_DOUBLE_DIGITS = 15
- def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
- case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
- if prec + 10 <= MAX_LONG_DIGITS =>
- MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
-
- case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
- if prec + 4 <= MAX_DOUBLE_DIGITS =>
- val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
- Cast(
- Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
- DecimalType(prec + 4, scale + 4))
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case q: LogicalPlan => q transformExpressionsDown {
+ case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _), _) => af match {
+ case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+ MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
+ prec + 10, scale)
+
+ case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ val newAggExpr =
+ we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
+ Cast(
+ Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
+ DecimalType(prec + 4, scale + 4))
+
+ case _ => we
+ }
+ case ae @ AggregateExpression(af, _, _, _) => af match {
+ case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+ MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
+
+ case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+ val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
+ Cast(
+ Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
+ DecimalType(prec + 4, scale + 4))
+
+ case _ => ae
+ }
+ }
}
}