aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala30
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala5
2 files changed, 18 insertions, 17 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index dbc0c2965a..15560a2a93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -105,17 +105,18 @@ case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
- var currentMin: Any = _
+ val currentMin: MutableLiteral = MutableLiteral(null, expr.dataType)
+ val cmp = GreaterThan(currentMin, expr)
override def update(input: Row): Unit = {
- if (currentMin == null) {
- currentMin = expr.eval(input)
- } else if(GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) {
- currentMin = expr.eval(input)
+ if (currentMin.value == null) {
+ currentMin.value = expr.eval(input)
+ } else if(cmp.eval(input) == true) {
+ currentMin.value = expr.eval(input)
}
}
- override def eval(input: Row): Any = currentMin
+ override def eval(input: Row): Any = currentMin.value
}
case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -135,17 +136,18 @@ case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[
case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
- var currentMax: Any = _
+ val currentMax: MutableLiteral = MutableLiteral(null, expr.dataType)
+ val cmp = LessThan(currentMax, expr)
override def update(input: Row): Unit = {
- if (currentMax == null) {
- currentMax = expr.eval(input)
- } else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) {
- currentMax = expr.eval(input)
+ if (currentMax.value == null) {
+ currentMax.value = expr.eval(input)
+ } else if(cmp.eval(input) == true) {
+ currentMax.value = expr.eval(input)
}
}
- override def eval(input: Row): Any = currentMax
+ override def eval(input: Row): Any = currentMax.value
}
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -350,7 +352,7 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
private val zero = Cast(Literal(0), expr.dataType)
private var count: Long = _
- private val sum = MutableLiteral(zero.eval(EmptyRow))
+ private val sum = MutableLiteral(zero.eval(null), expr.dataType)
private val sumAsDouble = Cast(sum, DoubleType)
private def addFunction(value: Any) = Add(sum, Literal(value))
@@ -423,7 +425,7 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
private val zero = Cast(Literal(0), expr.dataType)
- private val sum = MutableLiteral(zero.eval(null))
+ private val sum = MutableLiteral(zero.eval(null), expr.dataType)
private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
index a8c2396d62..78a0c55e4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala
@@ -61,11 +61,10 @@ case class Literal(value: Any, dataType: DataType) extends LeafExpression {
}
// TODO: Specialize
-case class MutableLiteral(var value: Any, nullable: Boolean = true) extends LeafExpression {
+case class MutableLiteral(var value: Any, dataType: DataType, nullable: Boolean = true)
+ extends LeafExpression {
type EvaluatedType = Any
- val dataType = Literal(value).dataType
-
def update(expression: Expression, input: Row) = {
value = expression.eval(input)
}