diff options
Diffstat (limited to 'sql')
4 files changed, 195 insertions, 52 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala index 3ceb5ecaf6..0cd90866e1 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala @@ -158,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod override def asPartial: SplitEvaluation = { val partialCount = Alias(Count(child), "PartialCount")() - SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil) + SplitEvaluation(Coalesce(Seq(Sum(partialCount.toAttribute), Literal(0L))), partialCount :: Nil) } override def newInstance() = new CountFunction(child, this) @@ -285,7 +285,7 @@ case class ApproxCountDistinct(child: Expression, relativeSD: Double = 0.05) case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false + override def nullable = true override def dataType = child.dataType match { case DecimalType.Fixed(precision, scale) => @@ -299,12 +299,12 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN override def toString = s"AVG($child)" override def asPartial: SplitEvaluation = { - val partialSum = Alias(Sum(child), "PartialSum")() - val partialCount = Alias(Count(child), "PartialCount")() - child.dataType match { case DecimalType.Fixed(_, _) => - // Turn the results to unlimited decimals for the division, before going back to fixed + // Turn the child to unlimited decimals for calculation, before going back to fixed + val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + val partialCount = Alias(Count(child), "PartialCount")() + val castedSum = Cast(Sum(partialSum.toAttribute), DecimalType.Unlimited) val castedCount = Cast(Sum(partialCount.toAttribute), DecimalType.Unlimited) SplitEvaluation( @@ -312,6 +312,9 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN partialCount :: partialSum :: Nil) case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + val partialCount = Alias(Count(child), "PartialCount")() + val castedSum = Cast(Sum(partialSum.toAttribute), dataType) val castedCount = Cast(Sum(partialCount.toAttribute), dataType) SplitEvaluation( @@ -325,7 +328,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { - override def nullable = false + override def nullable = true override def dataType = child.dataType match { case DecimalType.Fixed(precision, scale) => @@ -339,10 +342,19 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ override def toString = s"SUM($child)" override def asPartial: SplitEvaluation = { - val partialSum = Alias(Sum(child), "PartialSum")() - SplitEvaluation( - Sum(partialSum.toAttribute), - partialSum :: Nil) + child.dataType match { + case DecimalType.Fixed(_, _) => + val partialSum = Alias(Sum(Cast(child, DecimalType.Unlimited)), "PartialSum")() + SplitEvaluation( + Cast(Sum(partialSum.toAttribute), dataType), + partialSum :: Nil) + + case _ => + val partialSum = Alias(Sum(child), "PartialSum")() + SplitEvaluation( + Sum(partialSum.toAttribute), + partialSum :: Nil) + } } override def newInstance() = new SumFunction(child, this) @@ -351,7 +363,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[ case class SumDistinct(child: Expression) extends AggregateExpression with trees.UnaryNode[Expression] { - override def nullable = false + override def nullable = true override def dataType = child.dataType match { case DecimalType.Fixed(precision, scale) => @@ -401,16 +413,37 @@ case class AverageFunction(expr: Expression, base: AggregateExpression) def this() = this(null, null) // Required for serialization. - private val zero = Cast(Literal(0), expr.dataType) + private val calcType = + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } + + private val zero = Cast(Literal(0), calcType) private var count: Long = _ - private val sum = MutableLiteral(zero.eval(null), expr.dataType) - private val sumAsDouble = Cast(sum, DoubleType) + private val sum = MutableLiteral(zero.eval(null), calcType) - private def addFunction(value: Any) = Add(sum, Literal(value)) + private def addFunction(value: Any) = Add(sum, Cast(Literal(value, expr.dataType), calcType)) - override def eval(input: Row): Any = - sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble + override def eval(input: Row): Any = { + if (count == 0L) { + null + } else { + expr.dataType match { + case DecimalType.Fixed(_, _) => + Cast(Divide( + Cast(sum, DecimalType.Unlimited), + Cast(Literal(count), DecimalType.Unlimited)), dataType).eval(null) + case _ => + Divide( + Cast(sum, dataType), + Cast(Literal(count), dataType)).eval(null) + } + } + } override def update(input: Row): Unit = { val evaluatedExpr = expr.eval(input) @@ -475,17 +508,31 @@ case class ApproxCountDistinctMergeFunction( case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { def this() = this(null, null) // Required for serialization. - private val zero = Cast(Literal(0), expr.dataType) + private val calcType = + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } + + private val zero = Cast(Literal(0), calcType) - private val sum = MutableLiteral(zero.eval(null), expr.dataType) + private val sum = MutableLiteral(null, calcType) - private val addFunction = Add(sum, Coalesce(Seq(expr, zero))) + private val addFunction = Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(expr, calcType)), sum)) override def update(input: Row): Unit = { sum.update(addFunction, input) } - override def eval(input: Row): Any = sum.eval(null) + override def eval(input: Row): Any = { + expr.dataType match { + case DecimalType.Fixed(_, _) => + Cast(sum, dataType).eval(null) + case _ => sum.eval(null) + } + } } case class SumDistinctFunction(expr: Expression, base: AggregateExpression) @@ -502,8 +549,16 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression) } } - override def eval(input: Row): Any = - seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus) + override def eval(input: Row): Any = { + if (seen.size == 0) { + null + } else { + Cast(Literal( + seen.reduceLeft( + dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)), + dataType).eval(null) + } + } } case class CountDistinctFunction( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala index 087b0ecbb2..18afc5d741 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala @@ -83,29 +83,45 @@ case class GeneratedAggregate( AggregateEvaluation(currentCount :: Nil, initialValue :: Nil, updateFunction :: Nil, result) - case Sum(expr) => - val resultType = expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 10, scale) - case _ => - expr.dataType - } + case s @ Sum(expr) => + val calcType = + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } - val currentSum = AttributeReference("currentSum", resultType, nullable = false)() - val initialValue = Cast(Literal(0L), resultType) + val currentSum = AttributeReference("currentSum", calcType, nullable = true)() + val initialValue = Literal(null, calcType) // Coalasce avoids double calculation... // but really, common sub expression elimination would be better.... - val updateFunction = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) - val result = currentSum + val zero = Cast(Literal(0), calcType) + val updateFunction = Coalesce( + Add(Coalesce(currentSum :: zero :: Nil), Cast(expr, calcType)) :: currentSum :: Nil) + val result = + expr.dataType match { + case DecimalType.Fixed(_, _) => + Cast(currentSum, s.dataType) + case _ => currentSum + } AggregateEvaluation(currentSum :: Nil, initialValue :: Nil, updateFunction :: Nil, result) case a @ Average(expr) => + val calcType = + expr.dataType match { + case DecimalType.Fixed(_, _) => + DecimalType.Unlimited + case _ => + expr.dataType + } + val currentCount = AttributeReference("currentCount", LongType, nullable = false)() - val currentSum = AttributeReference("currentSum", expr.dataType, nullable = false)() + val currentSum = AttributeReference("currentSum", calcType, nullable = false)() val initialCount = Literal(0L) - val initialSum = Cast(Literal(0L), expr.dataType) + val initialSum = Cast(Literal(0L), calcType) // If we're evaluating UnscaledValue(x), we can do Count on x directly, since its // UnscaledValue will be null if and only if x is null; helps with Average on decimals @@ -115,17 +131,21 @@ case class GeneratedAggregate( } val updateCount = If(IsNotNull(toCount), Add(currentCount, Literal(1L)), currentCount) - val updateSum = Coalesce(Add(expr, currentSum) :: currentSum :: Nil) - - val resultType = expr.dataType match { - case DecimalType.Fixed(precision, scale) => - DecimalType(precision + 4, scale + 4) - case DecimalType.Unlimited => - DecimalType.Unlimited - case _ => - DoubleType - } - val result = Divide(Cast(currentSum, resultType), Cast(currentCount, resultType)) + val updateSum = Coalesce(Add(Cast(expr, calcType), currentSum) :: currentSum :: Nil) + + val result = + expr.dataType match { + case DecimalType.Fixed(_, _) => + If(EqualTo(currentCount, Literal(0L)), + Literal(null, a.dataType), + Cast(Divide( + Cast(currentSum, DecimalType.Unlimited), + Cast(currentCount, DecimalType.Unlimited)), a.dataType)) + case _ => + If(EqualTo(currentCount, Literal(0L)), + Literal(null, a.dataType), + Divide(Cast(currentSum, a.dataType), Cast(currentCount, a.dataType))) + } AggregateEvaluation( currentCount :: currentSum :: Nil, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala index e70ad891ee..94bd97758f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala @@ -156,22 +156,58 @@ class DslQuerySuite extends QueryTest { test("average") { checkAnswer( - testData2.groupBy()(avg('a)), + testData2.aggregate(avg('a)), 2.0) + + checkAnswer( + testData2.aggregate(avg('a), sumDistinct('a)), // non-partial + (2.0, 6.0) :: Nil) + + checkAnswer( + decimalData.aggregate(avg('a)), + BigDecimal(2.0)) + checkAnswer( + decimalData.aggregate(avg('a), sumDistinct('a)), // non-partial + (BigDecimal(2.0), BigDecimal(6)) :: Nil) + + checkAnswer( + decimalData.aggregate(avg('a cast DecimalType(10, 2))), + BigDecimal(2.0)) + checkAnswer( + decimalData.aggregate(avg('a cast DecimalType(10, 2)), sumDistinct('a cast DecimalType(10, 2))), // non-partial + (BigDecimal(2.0), BigDecimal(6)) :: Nil) } test("null average") { checkAnswer( - testData3.groupBy()(avg('b)), + testData3.aggregate(avg('b)), 2.0) checkAnswer( - testData3.groupBy()(avg('b), countDistinct('b)), + testData3.aggregate(avg('b), countDistinct('b)), (2.0, 1) :: Nil) + + checkAnswer( + testData3.aggregate(avg('b), sumDistinct('b)), // non-partial + (2.0, 2.0) :: Nil) + } + + test("zero average") { + checkAnswer( + emptyTableData.aggregate(avg('a)), + null) + + checkAnswer( + emptyTableData.aggregate(avg('a), sumDistinct('b)), // non-partial + (null, null) :: Nil) } test("count") { assert(testData2.count() === testData2.map(_ => 1).count()) + + checkAnswer( + testData2.aggregate(count('a), sumDistinct('a)), // non-partial + (6, 6.0) :: Nil) } test("null count") { @@ -186,13 +222,34 @@ class DslQuerySuite extends QueryTest { ) checkAnswer( - testData3.groupBy()(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), + testData3.aggregate(count('a), count('b), count(1), countDistinct('a), countDistinct('b)), (2, 1, 2, 2, 1) :: Nil ) + + checkAnswer( + testData3.aggregate(count('b), countDistinct('b), sumDistinct('b)), // non-partial + (1, 1, 2) :: Nil + ) } test("zero count") { assert(emptyTableData.count() === 0) + + checkAnswer( + emptyTableData.aggregate(count('a), sumDistinct('a)), // non-partial + (0, null) :: Nil) + } + + test("zero sum") { + checkAnswer( + emptyTableData.aggregate(sum('a)), + null) + } + + test("zero sum distinct") { + checkAnswer( + emptyTableData.aggregate(sumDistinct('a)), + null) } test("except") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala index 92b49e8155..933e027436 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala @@ -54,6 +54,17 @@ object TestData { TestData2(3, 2) :: Nil).toSchemaRDD testData2.registerTempTable("testData2") + case class DecimalData(a: BigDecimal, b: BigDecimal) + val decimalData = + TestSQLContext.sparkContext.parallelize( + DecimalData(1, 1) :: + DecimalData(1, 2) :: + DecimalData(2, 1) :: + DecimalData(2, 2) :: + DecimalData(3, 1) :: + DecimalData(3, 2) :: Nil).toSchemaRDD + decimalData.registerTempTable("decimalData") + case class BinaryData(a: Array[Byte], b: Int) val binaryData = TestSQLContext.sparkContext.parallelize( |