diff options
author | Davies Liu <davies@databricks.com> | 2015-11-03 11:42:08 +0100 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-03 11:42:08 +0100 |
commit | 67e23b39ac3cdee06668fa9131951278b9731e29 (patch) | |
tree | eb03786e2392d69c0e6964df7d24739825b35c1c | |
parent | d728d5c98658c44ed2949b55d36edeaa46f8c980 (diff) | |
download | spark-67e23b39ac3cdee06668fa9131951278b9731e29.tar.gz spark-67e23b39ac3cdee06668fa9131951278b9731e29.tar.bz2 spark-67e23b39ac3cdee06668fa9131951278b9731e29.zip |
[SPARK-10429] [SQL] make mutableProjection atomic
Right now, SQL's mutable projection updates every value of the mutable project after it evaluates the corresponding expression. This makes the behavior of MutableProjection confusing and complicate the implementation of common aggregate functions like stddev because developers need to be aware that when evaluating {{i+1}}th expression of a mutable projection, {{i}}th slot of the mutable row has already been updated.
This PR make the MutableProjection atomic, by generating all the results of expressions first, then copy them into mutableRow.
Had run a mircro-benchmark, there is no notable performance difference between using class members and local variables.
cc yhuai
Author: Davies Liu <davies@databricks.com>
Closes #9422 from davies/atomic_mutable and squashes the following commits:
bbc1758 [Davies Liu] support wide table
8a0ae14 [Davies Liu] fix bug
bec07da [Davies Liu] refactor
2891628 [Davies Liu] make mutableProjection atomic
3 files changed, 97 insertions, 98 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala index afe52e6a66..a6fe730f6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala @@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection} -import org.apache.spark.sql.types.{DataType, Decimal, StructType, _} -import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} +import org.apache.spark.sql.types.{DataType, StructType} /** * A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions. @@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) = this(expressions.map(BindReferences.bindReference(_, inputSchema))) + private[this] val buffer = new Array[Any](expressions.size) + expressions.foreach(_.foreach { case n: Nondeterministic => n.setInitialValues() case _ => @@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu override def apply(input: InternalRow): InternalRow = { var i = 0 while (i < exprArray.length) { - mutableRow(i) = exprArray(i).eval(input) + // Store the result into buffer first, to make the projection atomic (needed by aggregation) + buffer(i) = exprArray(i).eval(input) + i += 1 + } + i = 0 + while (i < exprArray.length) { + mutableRow(i) = buffer(i) i += 1 } mutableRow diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 5d2eb7b017..f2c3eca095 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -57,37 +57,37 @@ case class Average(child: Expression) extends DeclarativeAggregate { case _ => DoubleType } - private val currentSum = AttributeReference("currentSum", sumDataType)() - private val currentCount = AttributeReference("currentCount", LongType)() + private val sum = AttributeReference("sum", sumDataType)() + private val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = currentSum :: currentCount :: Nil + override val aggBufferAttributes = sum :: count :: Nil override val initialValues = Seq( - /* currentSum = */ Cast(Literal(0), sumDataType), - /* currentCount = */ Literal(0L) + /* sum = */ Cast(Literal(0), sumDataType), + /* count = */ Literal(0L) ) override val updateExpressions = Seq( - /* currentSum = */ + /* sum = */ Add( - currentSum, + sum, Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)), - /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + /* count = */ If(IsNull(child), count, count + 1L) ) override val mergeExpressions = Seq( - /* currentSum = */ currentSum.left + currentSum.right, - /* currentCount = */ currentCount.left + currentCount.right + /* sum = */ sum.left + sum.right, + /* count = */ count.left + count.right ) - // If all input are nulls, currentCount will be 0 and we will get null after the division. + // If all input are nulls, count will be 0 and we will get null after the division. override val evaluateExpression = child.dataType match { case DecimalType.Fixed(p, s) => // increase the precision and scale to prevent precision loss val dt = DecimalType.bounded(p + 14, s + 4) - Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType) + Cast(Cast(sum, dt) / Cast(count, dt), resultType) case _ => - Cast(currentSum, resultType) / Cast(currentCount, resultType) + Cast(sum, resultType) / Cast(count, resultType) } } @@ -102,23 +102,23 @@ case class Count(child: Expression) extends DeclarativeAggregate { // Expected input data type. override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) - private val currentCount = AttributeReference("currentCount", LongType)() + private val count = AttributeReference("count", LongType)() - override val aggBufferAttributes = currentCount :: Nil + override val aggBufferAttributes = count :: Nil override val initialValues = Seq( - /* currentCount = */ Literal(0L) + /* count = */ Literal(0L) ) override val updateExpressions = Seq( - /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L) + /* count = */ If(IsNull(child), count, count + 1L) ) override val mergeExpressions = Seq( - /* currentCount = */ currentCount.left + currentCount.right + /* count = */ count.left + count.right ) - override val evaluateExpression = Cast(currentCount, LongType) + override val evaluateExpression = Cast(count, LongType) } /** @@ -372,101 +372,77 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate { private val resultType = DoubleType - private val preCount = AttributeReference("preCount", resultType)() - private val currentCount = AttributeReference("currentCount", resultType)() - private val preAvg = AttributeReference("preAvg", resultType)() - private val currentAvg = AttributeReference("currentAvg", resultType)() - private val currentMk = AttributeReference("currentMk", resultType)() + private val count = AttributeReference("count", resultType)() + private val avg = AttributeReference("avg", resultType)() + private val mk = AttributeReference("mk", resultType)() - override val aggBufferAttributes = preCount :: currentCount :: preAvg :: - currentAvg :: currentMk :: Nil + override val aggBufferAttributes = count :: avg :: mk :: Nil override val initialValues = Seq( - /* preCount = */ Cast(Literal(0), resultType), - /* currentCount = */ Cast(Literal(0), resultType), - /* preAvg = */ Cast(Literal(0), resultType), - /* currentAvg = */ Cast(Literal(0), resultType), - /* currentMk = */ Cast(Literal(0), resultType) + /* count = */ Cast(Literal(0), resultType), + /* avg = */ Cast(Literal(0), resultType), + /* mk = */ Cast(Literal(0), resultType) ) override val updateExpressions = { + val value = Cast(child, resultType) + val newCount = count + Cast(Literal(1), resultType) // update average // avg = avg + (value - avg)/count - def avgAdd: Expression = { - currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount) - } + val newAvg = avg + (value - avg) / newCount // update sum of square of difference from mean // Mk = Mk + (value - preAvg) * (value - updatedAvg) - def mkAdd: Expression = { - val delta1 = Cast(child, resultType) - preAvg - val delta2 = Cast(child, resultType) - currentAvg - currentMk + (delta1 * delta2) - } + val newMk = mk + (value - avg) * (value - newAvg) Seq( - /* preCount = */ If(IsNull(child), preCount, currentCount), - /* currentCount = */ If(IsNull(child), currentCount, - Add(currentCount, Cast(Literal(1), resultType))), - /* preAvg = */ If(IsNull(child), preAvg, currentAvg), - /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd), - /* currentMk = */ If(IsNull(child), currentMk, mkAdd) + /* count = */ If(IsNull(child), count, newCount), + /* avg = */ If(IsNull(child), avg, newAvg), + /* mk = */ If(IsNull(child), mk, newMk) ) } override val mergeExpressions = { // count merge - def countMerge: Expression = { - currentCount.left + currentCount.right - } + val newCount = count.left + count.right // average merge - def avgMerge: Expression = { - ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) / - (preCount + currentCount.right) - } + val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount // update sum of square differences - def mkMerge: Expression = { - val avgDelta = currentAvg.right - preAvg - val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) / - (preCount + currentCount.right) - - currentMk.left + currentMk.right + mkDelta + val newMk = { + val avgDelta = avg.right - avg.left + val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount + mk.left + mk.right + mkDelta } Seq( - /* preCount = */ If(IsNull(currentCount.left), - Cast(Literal(0), resultType), currentCount.left), - /* currentCount = */ If(IsNull(currentCount.left), currentCount.right, - If(IsNull(currentCount.right), currentCount.left, countMerge)), - /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left), - /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right, - If(IsNull(currentAvg.right), currentAvg.left, avgMerge)), - /* currentMk = */ If(IsNull(currentMk.left), currentMk.right, - If(IsNull(currentMk.right), currentMk.left, mkMerge)) + /* count = */ If(IsNull(count.left), count.right, + If(IsNull(count.right), count.left, newCount)), + /* avg = */ If(IsNull(avg.left), avg.right, + If(IsNull(avg.right), avg.left, newAvg)), + /* mk = */ If(IsNull(mk.left), mk.right, + If(IsNull(mk.right), mk.left, newMk)) ) } override val evaluateExpression = { - // when currentCount == 0, return null - // when currentCount == 1, return 0 - // when currentCount >1 - // stddev_samp = sqrt (currentMk/(currentCount -1)) - // stddev_pop = sqrt (currentMk/currentCount) - val varCol = { + // when count == 0, return null + // when count == 1, return 0 + // when count >1 + // stddev_samp = sqrt (mk/(count -1)) + // stddev_pop = sqrt (mk/count) + val varCol = if (isSample) { - currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType) - } - else { - currentMk / currentCount + mk / Cast((count - Cast(Literal(1), resultType)), resultType) + } else { + mk / count } - } - If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), - If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), + If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType), + If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType), Cast(Sqrt(varCol), resultType))) } } @@ -499,30 +475,30 @@ case class Sum(child: Expression) extends DeclarativeAggregate { private val sumDataType = resultType - private val currentSum = AttributeReference("currentSum", sumDataType)() + private val sum = AttributeReference("sum", sumDataType)() private val zero = Cast(Literal(0), sumDataType) - override val aggBufferAttributes = currentSum :: Nil + override val aggBufferAttributes = sum :: Nil override val initialValues = Seq( - /* currentSum = */ Literal.create(null, sumDataType) + /* sum = */ Literal.create(null, sumDataType) ) override val updateExpressions = Seq( - /* currentSum = */ - Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum)) + /* sum = */ + Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum)) ) override val mergeExpressions = { - val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType)) + val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType)) Seq( - /* currentSum = */ - Coalesce(Seq(add, currentSum.left)) + /* sum = */ + Coalesce(Seq(add, sum.left)) ) } - override val evaluateExpression = Cast(currentSum, resultType) + override val evaluateExpression = Cast(sum, resultType) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index e8ee64756d..4b66069b5f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -44,28 +44,42 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu case (NoOp, _) => "" case (e, i) => val evaluationCode = e.gen(ctx) + val isNull = s"isNull_$i" + val value = s"value_$i" + ctx.addMutableState("boolean", isNull, s"this.$isNull = true;") + ctx.addMutableState(ctx.javaType(e.dataType), value, + s"this.$value = ${ctx.defaultValue(e.dataType)};") + s""" + ${evaluationCode.code} + this.$isNull = ${evaluationCode.isNull}; + this.$value = ${evaluationCode.value}; + """ + } + val updates = expressions.zipWithIndex.map { + case (NoOp, _) => "" + case (e, i) => if (e.dataType.isInstanceOf[DecimalType]) { // Can't call setNullAt on DecimalType, because we need to keep the offset s""" - ${evaluationCode.code} - if (${evaluationCode.isNull}) { + if (this.isNull_$i) { ${ctx.setColumn("mutableRow", e.dataType, i, null)}; } else { - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)}; + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } """ } else { s""" - ${evaluationCode.code} - if (${evaluationCode.isNull}) { + if (this.isNull_$i) { mutableRow.setNullAt($i); } else { - ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)}; + ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")}; } """ } } + val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes) + val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates) val code = s""" public Object generate($exprType[] expr) { @@ -98,6 +112,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu public Object apply(Object _i) { InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i; $allProjections + // copy all the results into MutableRow + $allUpdates return mutableRow; } } |