aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala45
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala24
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala10
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala13
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala14
6 files changed, 104 insertions, 4 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
index 1f6526ef66..566b34f7c3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala
@@ -369,6 +369,51 @@ case class MaxOf(left: Expression, right: Expression) extends Expression {
override def toString: String = s"MaxOf($left, $right)"
}
+case class MinOf(left: Expression, right: Expression) extends Expression {
+ type EvaluatedType = Any
+ // Evaluates to the smaller of left and right. Foldable only if both children are foldable.
+ override def foldable: Boolean = left.foldable && right.foldable
+ // Nullable only if both children are: eval hands back the other side when one is null.
+ override def nullable: Boolean = left.nullable && right.nullable
+
+ override def children: Seq[Expression] = left :: right :: Nil
+ // Resolved once both children are resolved and report the same data type.
+ override lazy val resolved =
+ left.resolved && right.resolved &&
+ left.dataType == right.dataType
+ // Result type is left's type; asking for it while unresolved throws UnresolvedException.
+ override def dataType: DataType = {
+ if (!resolved) {
+ throw new UnresolvedException(this,
+ s"datatype. Can not resolve due to differing types ${left.dataType}, ${right.dataType}")
+ }
+ left.dataType
+ }
+ // Ordering comes from left's NativeType; any other type hits sys.error on first access.
+ lazy val ordering = left.dataType match {
+ case i: NativeType => i.ordering.asInstanceOf[Ordering[Any]]
+ case other => sys.error(s"Type $other does not support ordered operations")
+ }
+ // A null side yields the other side (null if both are null); ties return the right value.
+ override def eval(input: Row): Any = {
+ val evalE1 = left.eval(input)
+ val evalE2 = right.eval(input)
+ if (evalE1 == null) {
+ evalE2
+ } else if (evalE2 == null) {
+ evalE1
+ } else {
+ if (ordering.compare(evalE1, evalE2) < 0) {
+ evalE1
+ } else {
+ evalE2
+ }
+ }
+ }
+
+ override def toString: String = s"MinOf($left, $right)"
+}
+
/**
* A function that get the absolute value of the numeric value.
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index aac56e1568..d141354a0f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -524,6 +524,30 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
}
""".children
+ case MinOf(e1, e2) => // code-generated MinOf: smaller value wins, a null side is skipped
+ val eval1 = expressionEvaluator(e1)
+ val eval2 = expressionEvaluator(e2)
+ // Both children's code is emitted first; the snippet below then selects between their results.
+ eval1.code ++ eval2.code ++
+ q"""
+ var $nullTerm = false
+ var $primitiveTerm: ${termForType(e1.dataType)} = ${defaultPrimitive(e1.dataType)}
+
+ if (${eval1.nullTerm}) {
+ $nullTerm = ${eval2.nullTerm}
+ $primitiveTerm = ${eval2.primitiveTerm}
+ } else if (${eval2.nullTerm}) {
+ $nullTerm = ${eval1.nullTerm}
+ $primitiveTerm = ${eval1.primitiveTerm}
+ } else {
+ if (${eval1.primitiveTerm} < ${eval2.primitiveTerm}) {
+ $primitiveTerm = ${eval1.primitiveTerm}
+ } else {
+ $primitiveTerm = ${eval2.primitiveTerm}
+ }
+ }
+ """.children
+
case UnscaledValue(child) =>
val childEval = expressionEvaluator(child)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
index d2b1090a0c..d4362a91d9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvaluationSuite.scala
@@ -233,6 +233,16 @@ class ExpressionEvaluationSuite extends ExpressionEvaluationBaseSuite {
checkEvaluation(MaxOf(2, Literal.create(null, IntegerType)), 2)
}
+ test("MinOf") { // checks MinOf on Int and Long literals, in both argument orders
+ checkEvaluation(MinOf(1, 2), 1)
+ checkEvaluation(MinOf(2, 1), 1)
+ checkEvaluation(MinOf(1L, 2L), 1L)
+ checkEvaluation(MinOf(2L, 1L), 1L)
+ // A typed null on either side is ignored: the non-null argument is the expected result.
+ checkEvaluation(MinOf(Literal.create(null, IntegerType), 1), 1)
+ checkEvaluation(MinOf(1, Literal.create(null, IntegerType)), 1)
+ }
+
test("LIKE literal Regular Expression") {
checkEvaluation(Literal.create(null, StringType).like("a"), null)
checkEvaluation(Literal.create("a", StringType).like(Literal.create(null, StringType)), null)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index b510cf033c..b1ef6556de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -164,6 +164,17 @@ case class GeneratedAggregate(
updateMax :: Nil,
currentMax)
+ case m @ Min(expr) => // Min aggregate expressed as a running MinOf over a single buffer slot
+ val currentMin = AttributeReference("currentMin", expr.dataType, nullable = true)()
+ val initialValue = Literal.create(null, expr.dataType) // typed null as the starting value
+ val updateMin = MinOf(currentMin, expr)
+ // NOTE(review): args look like (buffer attrs, initial values, updates, result) - confirm.
+ AggregateEvaluation(
+ currentMin :: Nil,
+ initialValue :: Nil,
+ updateMin :: Nil,
+ currentMin)
+
case CollectHashSet(Seq(expr)) =>
val set =
AttributeReference("hashSet", new OpenHashSetUDT(expr.dataType), nullable = false)()
@@ -188,6 +199,8 @@ case class GeneratedAggregate(
initialValue :: Nil,
collectSets :: Nil,
CountSet(set))
+
+ case o => sys.error(s"$o can't be codegened.") // fallback: no earlier case matched
}
val computationSchema = computeFunctions.flatMap(_.schema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index f0d92ffffc..5b99e40c2f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -155,7 +155,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
def canBeCodeGened(aggs: Seq[AggregateExpression]): Boolean = !aggs.exists {
- case _: CombineSum | _: Sum | _: Count | _: Max | _: CombineSetsAndCount => false
+ case _: CombineSum | _: Sum | _: Count | _: Max | _: Min | _: CombineSetsAndCount => false
// The generated set implementation is pretty limited ATM.
case CollectHashSet(exprs) if exprs.size == 1 &&
Seq(IntegerType, LongType).contains(exprs.head.dataType) => false
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5e453e05e2..73fb791c3e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -172,6 +172,13 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
testCodeGen(
"SELECT max(key) FROM testData3x",
Row(100) :: Nil)
+ // MIN
+ testCodeGen(
+ "SELECT value, min(key) FROM testData3x GROUP BY value",
+ (1 to 100).map(i => Row(i.toString, i)))
+ testCodeGen(
+ "SELECT min(key) FROM testData3x",
+ Row(1) :: Nil)
// Some combinations.
testCodeGen(
"""
@@ -179,16 +186,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll {
| value,
| sum(key),
| max(key),
+ | min(key),
| avg(key),
| count(key),
| count(distinct key)
|FROM testData3x
|GROUP BY value
""".stripMargin,
- (1 to 100).map(i => Row(i.toString, i*3, i, i, 3, 1)))
+ (1 to 100).map(i => Row(i.toString, i*3, i, i, i, 3, 1)))
testCodeGen(
- "SELECT max(key), avg(key), count(key), count(distinct key) FROM testData3x",
- Row(100, 50.5, 300, 100) :: Nil)
+ "SELECT max(key), min(key), avg(key), count(key), count(distinct key) FROM testData3x",
+ Row(100, 1, 50.5, 300, 100) :: Nil)
// Aggregate with Code generation handling all null values
testCodeGen(
"SELECT sum('a'), avg('a'), count(null) FROM testData",