diff options
Diffstat (limited to 'sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala')
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala | 40 |
1 files changed, 29 insertions, 11 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index b26ceba228..54bf4a5293 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1343,17 +1343,35 @@ object DecimalAggregates extends Rule[LogicalPlan] { /** Maximum number of decimal digits representable precisely in a Double */ private val MAX_DOUBLE_DIGITS = 15 - def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { - case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _) - if prec + 10 <= MAX_LONG_DIGITS => - MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) - - case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _) - if prec + 4 <= MAX_DOUBLE_DIGITS => - val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) - Cast( - Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), - DecimalType(prec + 4, scale + 4)) + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case q: LogicalPlan => q transformExpressionsDown { + case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _), _) => af match { + case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))), + prec + 10, scale) + + case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = + we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e)))) + Cast( + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), + DecimalType(prec + 4, scale + 4)) + + case _ => we + } + case ae @ AggregateExpression(af, _, _, _) => af match { + case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS => + MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale) + + case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS => + val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e))) + Cast( + Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)), + DecimalType(prec + 4, scale + 4)) + + case _ => ae + } + } } } |