aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala13
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala154
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala28
3 files changed, 97 insertions, 98 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
index afe52e6a66..a6fe730f6d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Projection.scala
@@ -19,8 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection, GenerateUnsafeProjection}
-import org.apache.spark.sql.types.{DataType, Decimal, StructType, _}
-import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
+import org.apache.spark.sql.types.{DataType, StructType}
/**
* A [[Projection]] that is calculated by calling the `eval` of each of the specified expressions.
@@ -62,6 +61,8 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
def this(expressions: Seq[Expression], inputSchema: Seq[Attribute]) =
this(expressions.map(BindReferences.bindReference(_, inputSchema)))
+ private[this] val buffer = new Array[Any](expressions.size)
+
expressions.foreach(_.foreach {
case n: Nondeterministic => n.setInitialValues()
case _ =>
@@ -79,7 +80,13 @@ case class InterpretedMutableProjection(expressions: Seq[Expression]) extends Mu
override def apply(input: InternalRow): InternalRow = {
var i = 0
while (i < exprArray.length) {
- mutableRow(i) = exprArray(i).eval(input)
+ // Store the result into buffer first, to make the projection atomic (needed by aggregation)
+ buffer(i) = exprArray(i).eval(input)
+ i += 1
+ }
+ i = 0
+ while (i < exprArray.length) {
+ mutableRow(i) = buffer(i)
i += 1
}
mutableRow
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
index 5d2eb7b017..f2c3eca095 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala
@@ -57,37 +57,37 @@ case class Average(child: Expression) extends DeclarativeAggregate {
case _ => DoubleType
}
- private val currentSum = AttributeReference("currentSum", sumDataType)()
- private val currentCount = AttributeReference("currentCount", LongType)()
+ private val sum = AttributeReference("sum", sumDataType)()
+ private val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = currentSum :: currentCount :: Nil
+ override val aggBufferAttributes = sum :: count :: Nil
override val initialValues = Seq(
- /* currentSum = */ Cast(Literal(0), sumDataType),
- /* currentCount = */ Literal(0L)
+ /* sum = */ Cast(Literal(0), sumDataType),
+ /* count = */ Literal(0L)
)
override val updateExpressions = Seq(
- /* currentSum = */
+ /* sum = */
Add(
- currentSum,
+ sum,
Coalesce(Cast(child, sumDataType) :: Cast(Literal(0), sumDataType) :: Nil)),
- /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ /* count = */ If(IsNull(child), count, count + 1L)
)
override val mergeExpressions = Seq(
- /* currentSum = */ currentSum.left + currentSum.right,
- /* currentCount = */ currentCount.left + currentCount.right
+ /* sum = */ sum.left + sum.right,
+ /* count = */ count.left + count.right
)
- // If all input are nulls, currentCount will be 0 and we will get null after the division.
+ // If all input are nulls, count will be 0 and we will get null after the division.
override val evaluateExpression = child.dataType match {
case DecimalType.Fixed(p, s) =>
// increase the precision and scale to prevent precision loss
val dt = DecimalType.bounded(p + 14, s + 4)
- Cast(Cast(currentSum, dt) / Cast(currentCount, dt), resultType)
+ Cast(Cast(sum, dt) / Cast(count, dt), resultType)
case _ =>
- Cast(currentSum, resultType) / Cast(currentCount, resultType)
+ Cast(sum, resultType) / Cast(count, resultType)
}
}
@@ -102,23 +102,23 @@ case class Count(child: Expression) extends DeclarativeAggregate {
// Expected input data type.
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)
- private val currentCount = AttributeReference("currentCount", LongType)()
+ private val count = AttributeReference("count", LongType)()
- override val aggBufferAttributes = currentCount :: Nil
+ override val aggBufferAttributes = count :: Nil
override val initialValues = Seq(
- /* currentCount = */ Literal(0L)
+ /* count = */ Literal(0L)
)
override val updateExpressions = Seq(
- /* currentCount = */ If(IsNull(child), currentCount, currentCount + 1L)
+ /* count = */ If(IsNull(child), count, count + 1L)
)
override val mergeExpressions = Seq(
- /* currentCount = */ currentCount.left + currentCount.right
+ /* count = */ count.left + count.right
)
- override val evaluateExpression = Cast(currentCount, LongType)
+ override val evaluateExpression = Cast(count, LongType)
}
/**
@@ -372,101 +372,77 @@ abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
private val resultType = DoubleType
- private val preCount = AttributeReference("preCount", resultType)()
- private val currentCount = AttributeReference("currentCount", resultType)()
- private val preAvg = AttributeReference("preAvg", resultType)()
- private val currentAvg = AttributeReference("currentAvg", resultType)()
- private val currentMk = AttributeReference("currentMk", resultType)()
+ private val count = AttributeReference("count", resultType)()
+ private val avg = AttributeReference("avg", resultType)()
+ private val mk = AttributeReference("mk", resultType)()
- override val aggBufferAttributes = preCount :: currentCount :: preAvg ::
- currentAvg :: currentMk :: Nil
+ override val aggBufferAttributes = count :: avg :: mk :: Nil
override val initialValues = Seq(
- /* preCount = */ Cast(Literal(0), resultType),
- /* currentCount = */ Cast(Literal(0), resultType),
- /* preAvg = */ Cast(Literal(0), resultType),
- /* currentAvg = */ Cast(Literal(0), resultType),
- /* currentMk = */ Cast(Literal(0), resultType)
+ /* count = */ Cast(Literal(0), resultType),
+ /* avg = */ Cast(Literal(0), resultType),
+ /* mk = */ Cast(Literal(0), resultType)
)
override val updateExpressions = {
+ val value = Cast(child, resultType)
+ val newCount = count + Cast(Literal(1), resultType)
// update average
// avg = avg + (value - avg)/count
- def avgAdd: Expression = {
- currentAvg + ((Cast(child, resultType) - currentAvg) / currentCount)
- }
+ val newAvg = avg + (value - avg) / newCount
// update sum of square of difference from mean
// Mk = Mk + (value - preAvg) * (value - updatedAvg)
- def mkAdd: Expression = {
- val delta1 = Cast(child, resultType) - preAvg
- val delta2 = Cast(child, resultType) - currentAvg
- currentMk + (delta1 * delta2)
- }
+ val newMk = mk + (value - avg) * (value - newAvg)
Seq(
- /* preCount = */ If(IsNull(child), preCount, currentCount),
- /* currentCount = */ If(IsNull(child), currentCount,
- Add(currentCount, Cast(Literal(1), resultType))),
- /* preAvg = */ If(IsNull(child), preAvg, currentAvg),
- /* currentAvg = */ If(IsNull(child), currentAvg, avgAdd),
- /* currentMk = */ If(IsNull(child), currentMk, mkAdd)
+ /* count = */ If(IsNull(child), count, newCount),
+ /* avg = */ If(IsNull(child), avg, newAvg),
+ /* mk = */ If(IsNull(child), mk, newMk)
)
}
override val mergeExpressions = {
// count merge
- def countMerge: Expression = {
- currentCount.left + currentCount.right
- }
+ val newCount = count.left + count.right
// average merge
- def avgMerge: Expression = {
- ((currentAvg.left * preCount) + (currentAvg.right * currentCount.right)) /
- (preCount + currentCount.right)
- }
+ val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount
// update sum of square differences
- def mkMerge: Expression = {
- val avgDelta = currentAvg.right - preAvg
- val mkDelta = (avgDelta * avgDelta) * (preCount * currentCount.right) /
- (preCount + currentCount.right)
-
- currentMk.left + currentMk.right + mkDelta
+ val newMk = {
+ val avgDelta = avg.right - avg.left
+ val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
+ mk.left + mk.right + mkDelta
}
Seq(
- /* preCount = */ If(IsNull(currentCount.left),
- Cast(Literal(0), resultType), currentCount.left),
- /* currentCount = */ If(IsNull(currentCount.left), currentCount.right,
- If(IsNull(currentCount.right), currentCount.left, countMerge)),
- /* preAvg = */ If(IsNull(currentAvg.left), Cast(Literal(0), resultType), currentAvg.left),
- /* currentAvg = */ If(IsNull(currentAvg.left), currentAvg.right,
- If(IsNull(currentAvg.right), currentAvg.left, avgMerge)),
- /* currentMk = */ If(IsNull(currentMk.left), currentMk.right,
- If(IsNull(currentMk.right), currentMk.left, mkMerge))
+ /* count = */ If(IsNull(count.left), count.right,
+ If(IsNull(count.right), count.left, newCount)),
+ /* avg = */ If(IsNull(avg.left), avg.right,
+ If(IsNull(avg.right), avg.left, newAvg)),
+ /* mk = */ If(IsNull(mk.left), mk.right,
+ If(IsNull(mk.right), mk.left, newMk))
)
}
override val evaluateExpression = {
- // when currentCount == 0, return null
- // when currentCount == 1, return 0
- // when currentCount >1
- // stddev_samp = sqrt (currentMk/(currentCount -1))
- // stddev_pop = sqrt (currentMk/currentCount)
- val varCol = {
+ // when count == 0, return null
+ // when count == 1, return 0
+ // when count >1
+ // stddev_samp = sqrt (mk/(count -1))
+ // stddev_pop = sqrt (mk/count)
+ val varCol =
if (isSample) {
- currentMk / Cast((currentCount - Cast(Literal(1), resultType)), resultType)
- }
- else {
- currentMk / currentCount
+ mk / Cast((count - Cast(Literal(1), resultType)), resultType)
+ } else {
+ mk / count
}
- }
- If(EqualTo(currentCount, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
- If(EqualTo(currentCount, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
+ If(EqualTo(count, Cast(Literal(0), resultType)), Cast(Literal(null), resultType),
+ If(EqualTo(count, Cast(Literal(1), resultType)), Cast(Literal(0), resultType),
Cast(Sqrt(varCol), resultType)))
}
}
@@ -499,30 +475,30 @@ case class Sum(child: Expression) extends DeclarativeAggregate {
private val sumDataType = resultType
- private val currentSum = AttributeReference("currentSum", sumDataType)()
+ private val sum = AttributeReference("sum", sumDataType)()
private val zero = Cast(Literal(0), sumDataType)
- override val aggBufferAttributes = currentSum :: Nil
+ override val aggBufferAttributes = sum :: Nil
override val initialValues = Seq(
- /* currentSum = */ Literal.create(null, sumDataType)
+ /* sum = */ Literal.create(null, sumDataType)
)
override val updateExpressions = Seq(
- /* currentSum = */
- Coalesce(Seq(Add(Coalesce(Seq(currentSum, zero)), Cast(child, sumDataType)), currentSum))
+ /* sum = */
+ Coalesce(Seq(Add(Coalesce(Seq(sum, zero)), Cast(child, sumDataType)), sum))
)
override val mergeExpressions = {
- val add = Add(Coalesce(Seq(currentSum.left, zero)), Cast(currentSum.right, sumDataType))
+ val add = Add(Coalesce(Seq(sum.left, zero)), Cast(sum.right, sumDataType))
Seq(
- /* currentSum = */
- Coalesce(Seq(add, currentSum.left))
+ /* sum = */
+ Coalesce(Seq(add, sum.left))
)
}
- override val evaluateExpression = Cast(currentSum, resultType)
+ override val evaluateExpression = Cast(sum, resultType)
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index e8ee64756d..4b66069b5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -44,28 +44,42 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
case (NoOp, _) => ""
case (e, i) =>
val evaluationCode = e.gen(ctx)
+ val isNull = s"isNull_$i"
+ val value = s"value_$i"
+ ctx.addMutableState("boolean", isNull, s"this.$isNull = true;")
+ ctx.addMutableState(ctx.javaType(e.dataType), value,
+ s"this.$value = ${ctx.defaultValue(e.dataType)};")
+ s"""
+ ${evaluationCode.code}
+ this.$isNull = ${evaluationCode.isNull};
+ this.$value = ${evaluationCode.value};
+ """
+ }
+ val updates = expressions.zipWithIndex.map {
+ case (NoOp, _) => ""
+ case (e, i) =>
if (e.dataType.isInstanceOf[DecimalType]) {
// Can't call setNullAt on DecimalType, because we need to keep the offset
s"""
- ${evaluationCode.code}
- if (${evaluationCode.isNull}) {
+ if (this.isNull_$i) {
${ctx.setColumn("mutableRow", e.dataType, i, null)};
} else {
- ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)};
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
"""
} else {
s"""
- ${evaluationCode.code}
- if (${evaluationCode.isNull}) {
+ if (this.isNull_$i) {
mutableRow.setNullAt($i);
} else {
- ${ctx.setColumn("mutableRow", e.dataType, i, evaluationCode.value)};
+ ${ctx.setColumn("mutableRow", e.dataType, i, s"this.value_$i")};
}
"""
}
}
+
val allProjections = ctx.splitExpressions(ctx.INPUT_ROW, projectionCodes)
+ val allUpdates = ctx.splitExpressions(ctx.INPUT_ROW, updates)
val code = s"""
public Object generate($exprType[] expr) {
@@ -98,6 +112,8 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
public Object apply(Object _i) {
InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
$allProjections
+ // copy all the results into MutableRow
+ $allUpdates
return mutableRow;
}
}