From af92299fdb8d358bbfeb2bf0d88ef7da8e2144b9 Mon Sep 17 00:00:00 2001
From: Dongjoon Hyun
Date: Wed, 27 Apr 2016 21:36:19 +0200
Subject: [SPARK-14664][SQL] Implement DecimalAggregates optimization for
 Window queries

## What changes were proposed in this pull request?

This PR implements the decimal aggregation optimization for window queries by improving the existing `DecimalAggregates` rule. Historically, the `DecimalAggregates` optimizer was designed to transform general `sum/avg(decimal)` expressions, but it breaks the recently added window queries shown below. These queries work correctly without the current `DecimalAggregates` optimizer, and fail with it enabled.

**Sum**
```scala
scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").head
java.lang.RuntimeException: Unsupported window function: MakeDecimal((sum(UnscaledValue(a#31)),mode=Complete,isDistinct=false),12,1)

scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#23]
:     +- INPUT
+- Window [MakeDecimal((sum(UnscaledValue(a#21)),mode=Complete,isDistinct=false),12,1) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#23]
   +- Exchange SinglePartition, None
      +- Generate explode([1.0,2.0]), false, false, [a#21]
         +- Scan OneRowRelation[]
```

**Average**
```scala
scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").head
java.lang.RuntimeException: Unsupported window function: cast(((avg(UnscaledValue(a#40)),mode=Complete,isDistinct=false) / 10.0) as decimal(6,5))

scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#44]
:     +- INPUT
+- Window [cast(((avg(UnscaledValue(a#42)),mode=Complete,isDistinct=false) / 10.0) as decimal(6,5)) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) AS avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#44]
   +- Exchange SinglePartition, None
      +- Generate explode([1.0,2.0]), false, false, [a#42]
         +- Scan OneRowRelation[]
```

After this PR, those queries work, and the new optimized physical plans look like the following.

**Sum**
```scala
scala> sql("select sum(a) over () from (select explode(array(1.0,2.0)) a) t").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#35]
:     +- INPUT
+- Window [MakeDecimal((sum(UnscaledValue(a#33)),mode=Complete,isDistinct=false) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING),12,1) AS sum(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#35]
   +- Exchange SinglePartition, None
      +- Generate explode([1.0,2.0]), false, false, [a#33]
         +- Scan OneRowRelation[]
```

**Average**
```scala
scala> sql("select avg(a) over () from (select explode(array(1.0,2.0)) a) t").explain()
== Physical Plan ==
WholeStageCodegen
:  +- Project [avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#47]
:     +- INPUT
+- Window [cast(((avg(UnscaledValue(a#45)),mode=Complete,isDistinct=false) windowspecdefinition(ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING) / 10.0) as decimal(6,5)) AS avg(a) OVER ( ROWS BETWEEN UNBOUNDED PRECEDING AND UNBOUNDED FOLLOWING)#47]
   +- Exchange SinglePartition, None
      +- Generate explode([1.0,2.0]), false, false, [a#45]
         +- Scan OneRowRelation[]
```
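The underlying rewrite itself is unchanged; this PR only extends where it may fire. As a minimal standalone sketch of the idea (plain Scala, not Catalyst code; the object name is illustrative), a `sum` over small-precision decimals can be computed on the unscaled `Long` representations and rescaled once at the end, which is what `MakeDecimal(Sum(UnscaledValue(e)), prec + 10, scale)` expresses:

```scala
object UnscaledSumSketch {
  def main(args: Array[String]): Unit = {
    // decimal(2, 1) inputs, as in the failing queries above.
    val scale = 1
    val values = Seq(BigDecimal("1.0"), BigDecimal("2.0"))

    // Direct decimal sum.
    val direct = values.sum

    // What the rule does instead: sum the unscaled Long values
    // (1.0 -> 10, 2.0 -> 20), then rebuild a single decimal at the end.
    val unscaledSum = values.map(_.underlying.unscaledValue.longValueExact).sum
    val rebuilt = BigDecimal(BigInt(unscaledSum), scale)

    assert(direct == rebuilt)  // 3.0 == 3.0
    println(s"unscaledSum=$unscaledSum rebuilt=$rebuilt")
  }
}
```

The old rule applied this rewrite to the aggregate wherever it appeared, so inside a window query the `Window` operator received `MakeDecimal`/`Cast` nodes as its window function, which it cannot evaluate (the `Unsupported window function` errors above). The new rule rewrites the aggregate inside the `WindowExpression` and applies `MakeDecimal`/`Cast` outside it, as the optimized plans show.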
In this PR, the *SUM over window* pattern matching is based on the code of hvanhovell; he should be credited for the work he did.

## How was this patch tested?

Pass the Jenkins tests (with newly added test cases).

Author: Dongjoon Hyun

Closes #12421 from dongjoon-hyun/SPARK-14664.

---
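A note on the overflow guards visible in the diff below (background only, not part of the patch): the rule fires only when the rewritten aggregate cannot overflow its narrower carrier type. `MAX_LONG_DIGITS` is defined just above the changed hunk and is not shown here; it is 18 in this file if I read the source correctly, i.e. the number of decimal digits that always fit in a `Long`, while `MAX_DOUBLE_DIGITS = 15` is the number of digits a `Double` represents exactly. A quick illustrative check:

```scala
// Why 18 and 15 are the safe digit counts (illustrative, assumes
// MAX_LONG_DIGITS = 18 as defined elsewhere in DecimalAggregates).
object DigitGuards extends App {
  // Long.MaxValue = 9223372036854775807 has 19 digits, so every 18-digit
  // number fits; Sum widens precision by 10, hence prec + 10 <= 18.
  println(Long.MaxValue.toString.length)  // 19

  // A Double has a 53-bit significand: 2^53 ~ 9.007e15, so all 15-digit
  // integers are exact; Average widens precision by 4, hence prec + 4 <= 15.
  println(math.pow(2, 53))                // 9.007199254740992E15
}
```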
 .../spark/sql/catalyst/optimizer/Optimizer.scala   |  40 +++++--
 .../optimizer/DecimalAggregatesSuite.scala         | 122 +++++++++++++++++++++
 .../apache/spark/sql/DataFrameAggregateSuite.scala |  11 +-
 3 files changed, 161 insertions(+), 12 deletions(-)
 create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index b26ceba228..54bf4a5293 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -1343,17 +1343,35 @@ object DecimalAggregates extends Rule[LogicalPlan] {
   /** Maximum number of decimal digits representable precisely in a Double */
   private val MAX_DOUBLE_DIGITS = 15

-  def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
-    case ae @ AggregateExpression(Sum(e @ DecimalType.Expression(prec, scale)), _, _, _)
-        if prec + 10 <= MAX_LONG_DIGITS =>
-      MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
-
-    case ae @ AggregateExpression(Average(e @ DecimalType.Expression(prec, scale)), _, _, _)
-        if prec + 4 <= MAX_DOUBLE_DIGITS =>
-      val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
-      Cast(
-        Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
-        DecimalType(prec + 4, scale + 4))
+  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+    case q: LogicalPlan => q transformExpressionsDown {
+      case we @ WindowExpression(ae @ AggregateExpression(af, _, _, _), _) => af match {
+        case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+          MakeDecimal(we.copy(windowFunction = ae.copy(aggregateFunction = Sum(UnscaledValue(e)))),
+            prec + 10, scale)
+
+        case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+          val newAggExpr =
+            we.copy(windowFunction = ae.copy(aggregateFunction = Average(UnscaledValue(e))))
+          Cast(
+            Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
+            DecimalType(prec + 4, scale + 4))
+
+        case _ => we
+      }
+      case ae @ AggregateExpression(af, _, _, _) => af match {
+        case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
+          MakeDecimal(ae.copy(aggregateFunction = Sum(UnscaledValue(e))), prec + 10, scale)
+
+        case Average(e @ DecimalType.Expression(prec, scale)) if prec + 4 <= MAX_DOUBLE_DIGITS =>
+          val newAggExpr = ae.copy(aggregateFunction = Average(UnscaledValue(e)))
+          Cast(
+            Divide(newAggExpr, Literal.create(math.pow(10.0, scale), DoubleType)),
+            DecimalType(prec + 4, scale + 4))
+
+        case _ => ae
+      }
+    }
   }
 }

diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
new file mode 100644
index 0000000000..711294ed61
--- /dev/null
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecimalAggregatesSuite.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.optimizer
+
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.dsl.plans._
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
+import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.DecimalType
+
+class DecimalAggregatesSuite extends PlanTest {
+
+  object Optimize extends RuleExecutor[LogicalPlan] {
+    val batches = Batch("Decimal Optimizations", FixedPoint(100),
+      DecimalAggregates) :: Nil
+  }
+
+  val testRelation = LocalRelation('a.decimal(2, 1), 'b.decimal(12, 1))
+
+  test("Decimal Sum Aggregation: Optimized") {
+    val originalQuery = testRelation.select(sum('a))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select(MakeDecimal(sum(UnscaledValue('a)), 12, 1).as("sum(a)")).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Sum Aggregation: Not Optimized") {
+    val originalQuery = testRelation.select(sum('b))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Average Aggregation: Optimized") {
+    val originalQuery = testRelation.select(avg('a))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select((avg(UnscaledValue('a)) / 10.0).cast(DecimalType(6, 5)).as("avg(a)")).analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Average Aggregation: Not Optimized") {
+    val originalQuery = testRelation.select(avg('b))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Sum Aggregation over Window: Optimized") {
+    val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame)
+    val originalQuery = testRelation.select(windowExpr(sum('a), spec).as('sum_a))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select('a)
+      .window(
+        Seq(MakeDecimal(windowExpr(sum(UnscaledValue('a)), spec), 12, 1).as('sum_a)),
+        Seq('a),
+        Nil)
+      .select('a, 'sum_a, 'sum_a)
+      .select('sum_a)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Sum Aggregation over Window: Not Optimized") {
+    val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame)
+    val originalQuery = testRelation.select(windowExpr(sum('b), spec))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Average Aggregation over Window: Optimized") {
+    val spec = windowSpec(Seq('a), Nil, UnspecifiedFrame)
+    val originalQuery = testRelation.select(windowExpr(avg('a), spec).as('avg_a))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = testRelation
+      .select('a)
+      .window(
+        Seq((windowExpr(avg(UnscaledValue('a)), spec) / 10.0).cast(DecimalType(6, 5)).as('avg_a)),
+        Seq('a),
+        Nil)
+      .select('a, 'avg_a, 'avg_a)
+      .select('avg_a)
+      .analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+
+  test("Decimal Average Aggregation over Window: Not Optimized") {
+    val spec = windowSpec('b :: Nil, Nil, UnspecifiedFrame)
+    val originalQuery = testRelation.select(windowExpr(avg('b), spec))
+    val optimized = Optimize.execute(originalQuery.analyze)
+    val correctAnswer = originalQuery.analyze
+
+    comparePlans(optimized, correctAnswer)
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index 2f685c5f9c..63f4b759a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -22,7 +22,7 @@ import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSQLContext
 import org.apache.spark.sql.test.SQLTestData.DecimalData
-import org.apache.spark.sql.types.DecimalType
+import org.apache.spark.sql.types.{Decimal, DecimalType}

 case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Double)

@@ -430,4 +430,13 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
         expr("kurtosis(a)")),
       Row(null, null, null, null, null))
   }
+
+  test("SPARK-14664: Decimal sum/avg over window should work.") {
+    checkAnswer(
+      sqlContext.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"),
+      Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil)
+    checkAnswer(
+      sqlContext.sql("select avg(a) over () from values 1.0, 2.0, 3.0 T(a)"),
+      Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil)
+  }
 }
--
cgit v1.2.3