aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-11-03 11:42:08 +0100
committerMichael Armbrust <michael@databricks.com>2015-11-03 11:42:08 +0100
commit67e23b39ac3cdee06668fa9131951278b9731e29 (patch)
treeeb03786e2392d69c0e6964df7d24739825b35c1c
parentd728d5c98658c44ed2949b55d36edeaa46f8c980 (diff)
downloadspark-67e23b39ac3cdee06668fa9131951278b9731e29.tar.gz
spark-67e23b39ac3cdee06668fa9131951278b9731e29.tar.bz2
spark-67e23b39ac3cdee06668fa9131951278b9731e29.zip
[SPARK-10429] [SQL] make mutableProjection atomic
Right now, SQL's mutable projection updates each slot of the mutable row immediately after it evaluates the corresponding expression. This makes the behavior of MutableProjection confusing and complicates the implementation of common aggregate functions like stddev, because developers need to be aware that when evaluating the {{i+1}}th expression of a mutable projection, the {{i}}th slot of the mutable row has already been updated. This PR makes the MutableProjection atomic by generating all the results of the expressions first, then copying them into mutableRow. A micro-benchmark showed no notable performance difference between using class members and local variables. cc yhuai Author: Davies Liu <davies@databricks.com> Closes #9422 from davies/atomic_mutable and squashes the following commits: bbc1758 [Davies Liu] support wide table 8a0ae14 [Davies Liu] fix bug bec07da [Davies Liu] refactor 2891628 [Davies Liu] make mutableProjection atomic
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala154
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala28
3 files changed, 97 insertions, 98 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index afe52e6a66..a6fe730f6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.types.{DataType, Decimal, StructType, _}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
@@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+ private[this] val buffer = new Array[Any](expressions.size)
+
expressions.foreach(_.foreach {
case n: Nondeterministic => n.setInitialValues()
case _ =>
@@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
override def apply(input: InternalRow): InternalRow = {
var i = 0
while (i < exprArray.length) {
- mutableRow(i) = exprArray(i).eval(input)
+ // Store the result into buffer first, to make the projection atomic (needed by aggregation)
+ buffer(i) = exprArray(i).eval(input)
+ i += 1
+ }
+ i = 0
+ while (i < exprArray.length) {
+ mutableRow(i) = buffer(i)
i += 1
}
mutableRow
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 5d2eb7b017..f2c3eca095 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -57,37 +57,37 @@ case class Average(child: Expression) extends DeclarativeAggregate {
case _ => DoubleType
}
- private val currentSum = AttributeReference("currentSum", sumDataType)()
- private val currentCount = AttributeReference("currentCount", LongType)()
+ private val sum = AttributeReference("sum", sumDataType)()
+ private val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = currentSum :: currentCount :: Nil
+ override val aggBufferAttributes = sum :: count :: Nil
override val initialValues = Seq(
- /* currentSum = */ Cast(Literal(0), sumDataType),
- /* currentCount = */ Literal(0L)
+ /* sum = */ Cast(Literal(0), sumDataType),
+ /* count = */ Literal(0L)
)
override val updateExpressions = Seq(
- /* currentSum = */
+ /* sum = */
Add(
- currentSum,
+ sum,
Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
- /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ /* count = */ If(IsNull(child), count, count + 1L)
)
override val mergeExpressions = Seq(
- /* currentSum = */ currentSum.left + currentSum.right,
- /* currentCount = */ currentCount.left + currentCount.right
+ /* sum = */ sum.left + sum.right,
+ /* count = */ count.left + count.right
)
- // If all input are nulls, currentCount will be 0 and we will get null after the division.
+ // If all input are nulls, count will be 0 and we will get null after the division.
override val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
- Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType)
+ Cast(Cast(sum, dt) / Cast(count, dt), resultType)
case _ =>
- Cast(currentSum, resultType) / Cast(currentCount, resultType)
+ Cast(sum, resultType) / Cast(count, resultType)
}
}
@@ -102,23 +102,23 @@ case class Count(child: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val currentCount = AttributeReference("currentCount", LongType)()
+ private val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = currentCount :: Nil
+ override val aggBufferAttributes = count :: Nil
override val initialValues = Seq(
- /* currentCount = */ Literal(0L)
+ /* count = */ Literal(0L)
)
override val updateExpressions = Seq(
- /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ /* count = */ If(IsNull(child), count, count + 1L)
)
override val mergeExpressions = Seq(
- /* currentCount = */ currentCount.left + currentCount.right
+ /* count = */ count.left + count.right
)
- override val evaluateExpression = Cast(currentCount, LongType)
+ override val evaluateExpression = Cast(count, LongType)
}
/**
@@ -372,101 +372,77 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
private val resultType = DoubleType
- private val preCount = AttributeReference("preCount", resultType)()
- private val currentCount = AttributeReference("currentCount", resultType)()
- private val preAvg = AttributeReference("preAvg", resultType)()
- private val currentAvg = AttributeReference("currentAvg", resultType)()
- private val currentMk = AttributeReference("currentMk", resultType)()
+ private val count = AttributeReference("count", resultType)()
+ private val avg = AttributeReference("avg", resultType)()
+ private val mk = AttributeReference("mk", resultType)()
- override val aggBufferAttributes = preCount :: currentCount :: preAvg ::
- currentAvg :: currentMk :: Nil
+ override val aggBufferAttributes = count :: avg :: mk :: Nil
override val initialValues = Seq(
- /* preCount = */ Cast(Literal(0), resultType),
- /* currentCount = */ Cast(Literal(0), resultType),
- /* preAvg = */ Cast(Literal(0), resultType),
- /* currentAvg = */ Cast(Literal(0), resultType),
- /* currentMk = */ Cast(Literal(0), resultType)
+ /* count = */ Cast(Literal(0), resultType),
+ /* avg = */ Cast(Literal(0), resultType),
+ /* mk = */ Cast(Literal(0), resultType)
)
override val updateExpressions = {
+ val value = Cast(child, resultType)
+ val newCount = count + Cast(Literal(1), resultType)
// update average
// avg = avg + (value - avg)/count
- def avgAdd: Expression = {
- currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
- }
+ val newAvg = avg + (value - avg) / newCount
// update sum of square of difference from mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg)
- def mkAdd: Expression = {
- val delta1 = Cast(child, resultType) - preAvg
- val delta2 = Cast(child, resultType) - currentAvg
- currentMk + (delta1 * delta2)
- }
+ val newMk = mk + (value - avg) * (value - newAvg)
Seq(
- /* preCount = */ If(IsNull(child), preCount, currentCount),
- /* currentCount = */ If(IsNull(child), currentCount,
- Add(currentCount, Cast(Literal(1), resultType))),
- /* preAvg = */ If(IsNull(child), preAvg, currentAvg),
- /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
- /* currentMk = */ If(IsNull(child), currentMk, mkAdd)
+ /* count = */ If(IsNull(child), count, newCount),
+ /* avg = */ If(IsNull(child), avg, newAvg),
+ /* mk = */ If(IsNull(child), mk, newMk)
)
}
override val mergeExpressions = {
// count merge
- def countMerge: Expression = {
- currentCount.left + currentCount.right
- }
+ val newCount = count.left + count.right
// average merge
- def avgMerge: Expression = {
- ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) /
- (preCount + currentCount.right)
- }
+ val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount
// update sum of square differences
- def mkMerge: Expression = {
- val avgDelta = currentAvg.right - preAvg
- val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) /
- (preCount + currentCount.right)
-
- currentMk.left + currentMk.right + mkDelta
+ val newMk = {
+ val avgDelta = avg.right - avg.left
+ val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
+ mk.left + mk.right + mkDelta
}
Seq(
- /* preCount = */ If(IsNull(currentCount.left),
- Cast(Literal(0), resultType), currentCount.left),
- /* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
- If(IsNull(currentCount.right), currentCount.left, countMerge)),
- /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left),
- /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
- If(IsNull(currentAvg.right), currentAvg.left, avgMerge)),
- /* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
- If(IsNull(currentMk.right), currentMk.left, mkMerge))
+ /* count = */ If(IsNull(count.left), count.right,
+ If(IsNull(count.right), count.left, newCount)),
+ /* avg = */ If(IsNull(avg.left), avg.right,
+ If(IsNull(avg.right), avg.left, newAvg)),
+ /* mk = */ If(IsNull(mk.left), mk.right,
+ If(IsNull(mk.right), mk.left, newMk))
)
}
override val evaluateExpression = {
- // when currentCount == 0, return null
- // when currentCount == 1, return 0
- // when currentCount >1
- // stddev_samp = sqrt (currentMk/(currentCount -1))
- // stddev_pop = sqrt (currentMk/currentCount)
- val varCol = {
+ // when count == 0, return null
+ // when count == 1, return 0
+ // when count >1
+ // stddev_samp = sqrt (mk/(count -1))
+ // stddev_pop = sqrt (mk/count)
+ val varCol =
if (isSample) {
- currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType)
- }
- else {
- currentMk / currentCount
+ mk / Cast((count - Cast(Literal(1), resultType)), resultType)
+ } else {
+ mk / count
}
- }
- If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
- If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
+ If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
+ If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
Cast(Sqrt(varCol), resultType)))
}
}
@@ -499,30 +475,30 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
private val sumDataType = resultType
- private val currentSum = AttributeReference("currentSum", sumDataType)()
+ private val sum = AttributeReference("sum", sumDataType)()
private val zero = Cast(Literal(0), sumDataType)
- override val aggBufferAttributes = currentSum :: Nil
+ override val aggBufferAttributes = sum :: Nil
override val initialValues = Seq(
- /* currentSum = */ Literal.create(null, sumDataType)
+ /* sum = */ Literal.create(null, sumDataType)
)
override val updateExpressions = Seq(
- /* currentSum = */
- Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum))
+ /* sum = */
+ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
)
override val mergeExpressions = {
- val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType))
+ val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq(
- /* currentSum = */
- Coalesce(Seq(add, currentSum.left))
+ /* sum = */
+ Coalesce(Seq(add, sum.left))
)
}
- override val evaluateExpression = Cast(currentSum, resultType)
+ override val evaluateExpression = Cast(sum, resultType)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e8ee64756d..4b66069b5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -44,28 +44,42 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
+ val isNull = s"isNull_$i"
+ val value = s"value_$i"
+ ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
+ ctx.addMutableState(ctx.javaType(e.dataType), value,
+ s"this.$value = ${ctx.defaultValue(e.dataType)};")
+ s"""
+ ${evaluationCode.code}
+ this.$isNull = ${evaluationCode.isNull};
+ this.$value = ${evaluationCode.value};
+ """
+ }
+ val updates = expressions.zipWithIndex.map {
+ case (NoOp, _) => ""
+ case (e, i) =>
if (e.dataType.isInstanceOf[DecimalType]) {
// Can't call setNullAt on DecimalType, because we need to keep the offset
s"""
- ${evaluationCode.code}
- if (${evaluationCode.isNull}) {
+ if (this.isNull_$i) {
${ctx.setColumn("mutableRow", e.dataType, i, null)};
} else {
- ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)};
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
"""
} else {
s"""
- ${evaluationCode.code}
- if (${evaluationCode.isNull}) {
+ if (this.isNull_$i) {
mutableRow.setNullAt($i);
} else {
- ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)};
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
"""
}
}
+
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
+ val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
val code = s"""
public Object generate($exprType[] expr) {
@@ -98,6 +112,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
public Object apply(Object _i) {
InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
$allProjections
+ // copy all the results into MutableRow
+ $allUpdates
return mutableRow;
}
}