aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
author: Michael Armbrust <michael@databricks.com> 2014-05-08 01:08:43 -0400
committer: Reynold Xin <rxin@apache.org> 2014-05-08 01:08:43 -0400
commit: 19c8fb02bc2c2f76c3c45bfff4b8d093be9d7c66 (patch)
tree: 9243c147d1d2cb2e4b380ed7801187ea57ee066d /sql/catalyst
parent: 6ed7e2cd01955adfbb3960e2986b6d19eaee8717 (diff)
download: spark-19c8fb02bc2c2f76c3c45bfff4b8d093be9d7c66.tar.gz
spark-19c8fb02bc2c2f76c3c45bfff4b8d093be9d7c66.tar.bz2
spark-19c8fb02bc2c2f76c3c45bfff4b8d093be9d7c66.zip
[SQL] Improve SparkSQL Aggregates
* Add native min/max (was using hive before).
* Handle nulls correctly in Avg and Sum.

Author: Michael Armbrust <michael@databricks.com>

Closes #683 from marmbrus/aggFixes and squashes the following commits:

64fe30b [Michael Armbrust] Improve SparkSQL Aggregates * Add native min/max (was using hive before). * Handle nulls correctly in Avg and Sum.
Diffstat (limited to 'sql/catalyst')
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala | 4
-rw-r--r-- sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala | 85
2 files changed, 79 insertions, 10 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
index 8c76a3aa96..b3a3a1ef1b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/SqlParser.scala
@@ -114,6 +114,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
protected val JOIN = Keyword("JOIN")
protected val LEFT = Keyword("LEFT")
protected val LIMIT = Keyword("LIMIT")
+ protected val MAX = Keyword("MAX")
+ protected val MIN = Keyword("MIN")
protected val NOT = Keyword("NOT")
protected val NULL = Keyword("NULL")
protected val ON = Keyword("ON")
@@ -318,6 +320,8 @@ class SqlParser extends StandardTokenParsers with PackratParsers {
COUNT ~> "(" ~> DISTINCT ~> expression <~ ")" ^^ { case exp => CountDistinct(exp :: Nil) } |
FIRST ~> "(" ~> expression <~ ")" ^^ { case exp => First(exp) } |
AVG ~> "(" ~> expression <~ ")" ^^ { case exp => Average(exp) } |
+ MIN ~> "(" ~> expression <~ ")" ^^ { case exp => Min(exp) } |
+ MAX ~> "(" ~> expression <~ ")" ^^ { case exp => Max(exp) } |
IF ~> "(" ~> expression ~ "," ~ expression ~ "," ~ expression <~ ")" ^^ {
case c ~ "," ~ t ~ "," ~ f => If(c,t,f)
} |
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index b152f95f96..7777d37290 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -86,6 +86,67 @@ abstract class AggregateFunction
override def newInstance() = makeCopy(productIterator.map { case a: AnyRef => a }.toArray)
}
+case class Min(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { // MIN aggregate over `child`; evaluated by MinFunction.
+ override def references = child.references // same references as the child expression
+ override def nullable = child.nullable // nullability is taken from the child
+ override def dataType = child.dataType // result type is the child's type
+ override def toString = s"MIN($child)"
+ // Partial form: a Min over child aliased as "PartialMin", then a final Min over that alias's attribute.
+ override def asPartial: SplitEvaluation = {
+ val partialMin = Alias(Min(child), "PartialMin")()
+ SplitEvaluation(Min(partialMin.toAttribute), partialMin :: Nil)
+ }
+ // Each call builds a new MinFunction over child, with this expression as its base.
+ override def newInstance() = new MinFunction(child, this)
+}
+
+case class MinFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { // Keeps a running minimum of `expr` across update calls.
+ def this() = this(null, null) // Required for serialization.
+ // Running minimum; the `_` initializer leaves it null until an update assigns a value.
+ var currentMin: Any = _
+ // While currentMin is null the row's value is adopted; afterwards it is replaced only when the comparison below yields true.
+ override def update(input: Row): Unit = {
+ if (currentMin == null) {
+ currentMin = expr.eval(input)
+ } else if(GreaterThan(Literal(currentMin, expr.dataType), expr).eval(input) == true) { // any result other than true keeps currentMin
+ currentMin = expr.eval(input)
+ }
+ }
+ // Returns the running minimum; the input row is not consulted.
+ override def eval(input: Row): Any = currentMin
+}
+
+case class Max(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] { // MAX aggregate over `child`; evaluated by MaxFunction.
+ override def references = child.references // same references as the child expression
+ override def nullable = child.nullable // nullability is taken from the child
+ override def dataType = child.dataType // result type is the child's type
+ override def toString = s"MAX($child)"
+ // Partial form: a Max over child aliased as "PartialMax", then a final Max over that alias's attribute.
+ override def asPartial: SplitEvaluation = {
+ val partialMax = Alias(Max(child), "PartialMax")()
+ SplitEvaluation(Max(partialMax.toAttribute), partialMax :: Nil)
+ }
+ // Each call builds a new MaxFunction over child, with this expression as its base.
+ override def newInstance() = new MaxFunction(child, this)
+}
+
+case class MaxFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction { // Keeps a running maximum of `expr` across update calls.
+ def this() = this(null, null) // Required for serialization.
+ // Running maximum; the `_` initializer leaves it null until an update assigns a value.
+ var currentMax: Any = _
+ // While currentMax is null the row's value is adopted; afterwards it is replaced only when the comparison below yields true.
+ override def update(input: Row): Unit = {
+ if (currentMax == null) {
+ currentMax = expr.eval(input)
+ } else if(LessThan(Literal(currentMax, expr.dataType), expr).eval(input) == true) { // any result other than true keeps currentMax
+ currentMax = expr.eval(input)
+ }
+ }
+ // Returns the running maximum; the input row is not consulted.
+ override def eval(input: Row): Any = currentMax
+}
+
+
case class Count(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
override def references = child.references
override def nullable = false
@@ -97,7 +158,7 @@ case class Count(child: Expression) extends PartialAggregate with trees.UnaryNod
SplitEvaluation(Sum(partialCount.toAttribute), partialCount :: Nil)
}
- override def newInstance()= new CountFunction(child, this)
+ override def newInstance() = new CountFunction(child, this)
}
case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpression {
@@ -106,7 +167,7 @@ case class CountDistinct(expressions: Seq[Expression]) extends AggregateExpressi
override def nullable = false
override def dataType = IntegerType
override def toString = s"COUNT(DISTINCT ${expressions.mkString(",")}})"
- override def newInstance()= new CountDistinctFunction(expressions, this)
+ override def newInstance() = new CountDistinctFunction(expressions, this)
}
case class Average(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -126,7 +187,7 @@ case class Average(child: Expression) extends PartialAggregate with trees.UnaryN
partialCount :: partialSum :: Nil)
}
- override def newInstance()= new AverageFunction(child, this)
+ override def newInstance() = new AverageFunction(child, this)
}
case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -142,7 +203,7 @@ case class Sum(child: Expression) extends PartialAggregate with trees.UnaryNode[
partialSum :: Nil)
}
- override def newInstance()= new SumFunction(child, this)
+ override def newInstance() = new SumFunction(child, this)
}
case class SumDistinct(child: Expression)
@@ -153,7 +214,7 @@ case class SumDistinct(child: Expression)
override def dataType = child.dataType
override def toString = s"SUM(DISTINCT $child)"
- override def newInstance()= new SumDistinctFunction(child, this)
+ override def newInstance() = new SumDistinctFunction(child, this)
}
case class First(child: Expression) extends PartialAggregate with trees.UnaryNode[Expression] {
@@ -168,7 +229,7 @@ case class First(child: Expression) extends PartialAggregate with trees.UnaryNod
First(partialFirst.toAttribute),
partialFirst :: Nil)
}
- override def newInstance()= new FirstFunction(child, this)
+ override def newInstance() = new FirstFunction(child, this)
}
case class AverageFunction(expr: Expression, base: AggregateExpression)
@@ -176,11 +237,13 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
def this() = this(null, null) // Required for serialization.
+ private val zero = Cast(Literal(0), expr.dataType)
+
private var count: Long = _
- private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(EmptyRow))
+ private val sum = MutableLiteral(zero.eval(EmptyRow))
private val sumAsDouble = Cast(sum, DoubleType)
- private val addFunction = Add(sum, expr)
+ private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
override def eval(input: Row): Any =
sumAsDouble.eval(EmptyRow).asInstanceOf[Double] / count.toDouble
@@ -209,9 +272,11 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
def this() = this(null, null) // Required for serialization.
- private val sum = MutableLiteral(Cast(Literal(0), expr.dataType).eval(null))
+ private val zero = Cast(Literal(0), expr.dataType)
+
+ private val sum = MutableLiteral(zero.eval(null))
- private val addFunction = Add(sum, expr)
+ private val addFunction = Add(sum, Coalesce(Seq(expr, zero)))
override def update(input: Row): Unit = {
sum.update(addFunction, input)