| author | Daoyuan Wang <daoyuan.wang@intel.com> | 2015-07-13 00:14:32 -0700 |
|---|---|---|
| committer | Davies Liu <davies.liu@gmail.com> | 2015-07-13 00:14:32 -0700 |
| commit | 92540d22e45f9300f413f520a1770e9f36cfa833 (patch) | |
| tree | 793bb6ebaeaa52381cf0966c47038a7a2d4f7b40 /sql | |
| parent | 20b474335c68c644150fdc8443a2d0d2dad5e27d (diff) | |
| download | spark-92540d22e45f9300f413f520a1770e9f36cfa833.tar.gz spark-92540d22e45f9300f413f520a1770e9f36cfa833.tar.bz2 spark-92540d22e45f9300f413f520a1770e9f36cfa833.zip | |
[SPARK-8203] [SPARK-8204] [SQL] conditional function: least/greatest
chenghao-intel zhichao-li qiansl127
Author: Daoyuan Wang <daoyuan.wang@intel.com>
Closes #6851 from adrian-wang/udflg and squashes the following commits:
0f1bff2 [Daoyuan Wang] address comments from davis
7a6bdbb [Daoyuan Wang] add '.' for hex()
c1f6824 [Daoyuan Wang] add codegen, test for all types
ec625b0 [Daoyuan Wang] conditional function: least/greatest
Diffstat (limited to 'sql')
5 files changed, 263 insertions, 5 deletions
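For orientation, here is a minimal sketch of how the two new functions are exercised once the patch is applied, adapted from the tests added below. The `sqlContext`, `df`, and `testData2` names are placeholders (an existing SQLContext, a DataFrame with integer columns `a` and `b`, and a registered temp table), not part of this change:

```scala
import org.apache.spark.sql.functions.{col, greatest, least, lit}

// SQL side: the names registered in FunctionRegistry.scala.
sqlContext.sql("SELECT least(a, 2) AS l, greatest(a, 2) AS g FROM testData2")

// DataFrame side: the varargs wrappers added in functions.scala
// (both functions require at least two arguments).
df.select(
  least(lit(-1), lit(0), col("a"), col("b")),
  greatest(lit(2), lit(3), col("a"), col("b")))
```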
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f62d79f8ce..ed69c42dcb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -76,9 +76,11 @@ object FunctionRegistry {
     expression[CreateArray]("array"),
     expression[Coalesce]("coalesce"),
     expression[Explode]("explode"),
+    expression[Greatest]("greatest"),
     expression[If]("if"),
     expression[IsNull]("isnull"),
     expression[IsNotNull]("isnotnull"),
+    expression[Least]("least"),
     expression[Coalesce]("nvl"),
     expression[Rand]("rand"),
     expression[Randn]("randn"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index 395e84f089..e6a705fb80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
 import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types.{NullType, BooleanType, DataType}
 
 case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
@@ -312,3 +313,103 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
     }.mkString
   }
 }
+
+case class Least(children: Expression*) extends Expression {
+  require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length)
+
+  override def nullable: Boolean = children.forall(_.nullable)
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  private lazy val ordering = TypeUtils.getOrdering(dataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"The expressions should all have the same type," +
+          s" got LEAST (${children.map(_.dataType)}).")
+    } else {
+      TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+    }
+  }
+
+  override def dataType: DataType = children.head.dataType
+
+  override def eval(input: InternalRow): Any = {
+    children.foldLeft[Any](null)((r, c) => {
+      val evalc = c.eval(input)
+      if (evalc != null) {
+        if (r == null || ordering.lt(evalc, r)) evalc else r
+      } else {
+        r
+      }
+    })
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val evalChildren = children.map(_.gen(ctx))
+    def updateEval(i: Int): String =
+      s"""
+        if (!${evalChildren(i).isNull} && (${ev.isNull} ||
+          ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) {
+          ${ev.isNull} = false;
+          ${ev.primitive} = ${evalChildren(i).primitive};
+        }
+      """
+    s"""
+      ${evalChildren.map(_.code).mkString("\n")}
+      boolean ${ev.isNull} = true;
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${(0 until children.length).map(updateEval).mkString("\n")}
+    """
+  }
+}
+
+case class Greatest(children: Expression*) extends Expression {
+  require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length)
+
+  override def nullable: Boolean = children.forall(_.nullable)
+  override def foldable: Boolean = children.forall(_.foldable)
+
+  private lazy val ordering = TypeUtils.getOrdering(dataType)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+      TypeCheckResult.TypeCheckFailure(
+        s"The expressions should all have the same type," +
+          s" got GREATEST (${children.map(_.dataType)}).")
+    } else {
+      TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+    }
+  }
+
+  override def dataType: DataType = children.head.dataType
+
+  override def eval(input: InternalRow): Any = {
+    children.foldLeft[Any](null)((r, c) => {
+      val evalc = c.eval(input)
+      if (evalc != null) {
+        if (r == null || ordering.gt(evalc, r)) evalc else r
+      } else {
+        r
+      }
+    })
+  }
+
+  override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+    val evalChildren = children.map(_.gen(ctx))
+    def updateEval(i: Int): String =
+      s"""
+        if (!${evalChildren(i).isNull} && (${ev.isNull} ||
+          ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) {
+          ${ev.isNull} = false;
+          ${ev.primitive} = ${evalChildren(i).primitive};
+        }
+      """
+    s"""
+      ${evalChildren.map(_.code).mkString("\n")}
+      boolean ${ev.isNull} = true;
+      ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+      ${(0 until children.length).map(updateEval).mkString("\n")}
+    """
+  }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 372848ea9a..aaf40cc83e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -17,7 +17,10 @@
 
 package org.apache.spark.sql.catalyst.expressions
 
+import java.sql.{Timestamp, Date}
+
 import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.types._
 
@@ -134,4 +137,82 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
   }
 
+  test("function least") {
+    val row = create_row(1, 2, "a", "b", "c")
+    val c1 = 'a.int.at(0)
+    val c2 = 'a.int.at(1)
+    val c3 = 'a.string.at(2)
+    val c4 = 'a.string.at(3)
+    val c5 = 'a.string.at(4)
+    checkEvaluation(Least(c4, c3, c5), "a", row)
+    checkEvaluation(Least(c1, c2), 1, row)
+    checkEvaluation(Least(c1, c2, Literal(-1)), -1, row)
+    checkEvaluation(Least(c4, c5, c3, c3, Literal("a")), "a", row)
+
+    checkEvaluation(Least(Literal(null), Literal(null)), null, InternalRow.empty)
+    checkEvaluation(Least(Literal(-1.0), Literal(2.5)), -1.0, InternalRow.empty)
+    checkEvaluation(Least(Literal(-1), Literal(2)), -1, InternalRow.empty)
+    checkEvaluation(
+      Least(Literal((-1.0).toFloat), Literal(2.5.toFloat)), (-1.0).toFloat, InternalRow.empty)
+    checkEvaluation(
+      Least(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MinValue, InternalRow.empty)
+    checkEvaluation(Least(Literal(1.toByte), Literal(2.toByte)), 1.toByte, InternalRow.empty)
+    checkEvaluation(
+      Least(Literal(1.toShort), Literal(2.toByte.toShort)), 1.toShort, InternalRow.empty)
+    checkEvaluation(Least(Literal("abc"), Literal("aaaa")), "aaaa", InternalRow.empty)
+    checkEvaluation(Least(Literal(true), Literal(false)), false, InternalRow.empty)
+    checkEvaluation(
+      Least(
+        Literal(BigDecimal("1234567890987654321123456")),
+        Literal(BigDecimal("1234567890987654321123458"))),
+      BigDecimal("1234567890987654321123456"), InternalRow.empty)
+    checkEvaluation(
+      Least(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))),
+      Date.valueOf("2015-01-01"), InternalRow.empty)
+    checkEvaluation(
+      Least(
+        Literal(Timestamp.valueOf("2015-07-01 08:00:00")),
+        Literal(Timestamp.valueOf("2015-07-01 10:00:00"))),
+      Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty)
+  }
+
+  test("function greatest") {
+    val row = create_row(1, 2, "a", "b", "c")
+    val c1 = 'a.int.at(0)
+    val c2 = 'a.int.at(1)
+    val c3 = 'a.string.at(2)
+    val c4 = 'a.string.at(3)
+    val c5 = 'a.string.at(4)
+    checkEvaluation(Greatest(c4, c5, c3), "c", row)
+    checkEvaluation(Greatest(c2, c1), 2, row)
+    checkEvaluation(Greatest(c1, c2, Literal(2)), 2, row)
+    checkEvaluation(Greatest(c4, c5, c3, Literal("ccc")), "ccc", row)
+
+    checkEvaluation(Greatest(Literal(null), Literal(null)), null, InternalRow.empty)
+    checkEvaluation(Greatest(Literal(-1.0), Literal(2.5)), 2.5, InternalRow.empty)
+    checkEvaluation(Greatest(Literal(-1), Literal(2)), 2, InternalRow.empty)
+    checkEvaluation(
+      Greatest(Literal((-1.0).toFloat), Literal(2.5.toFloat)), 2.5.toFloat, InternalRow.empty)
+    checkEvaluation(
+      Greatest(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MaxValue, InternalRow.empty)
+    checkEvaluation(Greatest(Literal(1.toByte), Literal(2.toByte)), 2.toByte, InternalRow.empty)
+    checkEvaluation(
+      Greatest(Literal(1.toShort), Literal(2.toByte.toShort)), 2.toShort, InternalRow.empty)
+    checkEvaluation(Greatest(Literal("abc"), Literal("aaaa")), "abc", InternalRow.empty)
+    checkEvaluation(Greatest(Literal(true), Literal(false)), true, InternalRow.empty)
+    checkEvaluation(
+      Greatest(
+        Literal(BigDecimal("1234567890987654321123456")),
+        Literal(BigDecimal("1234567890987654321123458"))),
+      BigDecimal("1234567890987654321123458"), InternalRow.empty)
+    checkEvaluation(
+      Greatest(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))),
+      Date.valueOf("2015-07-01"), InternalRow.empty)
+    checkEvaluation(
+      Greatest(
+        Literal(Timestamp.valueOf("2015-07-01 08:00:00")),
+        Literal(Timestamp.valueOf("2015-07-01 10:00:00"))),
+      Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty)
+  }
+
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 08bf37a5c2..ffa52f6258 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -599,7 +599,7 @@ object functions {
   /**
    * Creates a new row for each element in the given array or map column.
    */
-  def explode(e: Column): Column = Explode(e.expr)
+  def explode(e: Column): Column = Explode(e.expr)
 
   /**
    * Converts a string exprsesion to lower case.
@@ -1073,15 +1073,41 @@ object functions {
   def floor(columnName: String): Column = floor(Column(columnName))
 
   /**
-   * Computes hex value of the given column
+   * Returns the greatest value of the list of values.
    *
-   * @group math_funcs
+   * @group normal_funcs
    * @since 1.5.0
    */
+  @scala.annotation.varargs
+  def greatest(exprs: Column*): Column = if (exprs.length < 2) {
+    sys.error("GREATEST takes at least 2 parameters")
+  } else {
+    Greatest(exprs.map(_.expr): _*)
+  }
+
+  /**
+   * Returns the greatest value of the list of column names.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def greatest(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
+    sys.error("GREATEST takes at least 2 parameters")
+  } else {
+    greatest((columnName +: columnNames).map(Column.apply): _*)
+  }
+
+  /**
+   * Computes hex value of the given column.
+   *
+   * @group math_funcs
+   * @since 1.5.0
+   */
   def hex(column: Column): Column = Hex(column.expr)
 
   /**
-   * Computes hex value of the given input
+   * Computes hex value of the given input.
    *
    * @group math_funcs
    * @since 1.5.0
@@ -1172,6 +1198,32 @@ object functions {
   def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName))
 
   /**
+   * Returns the least value of the list of values.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def least(exprs: Column*): Column = if (exprs.length < 2) {
+    sys.error("LEAST takes at least 2 parameters")
+  } else {
+    Least(exprs.map(_.expr): _*)
+  }
+
+  /**
+   * Returns the least value of the list of column names.
+   *
+   * @group normal_funcs
+   * @since 1.5.0
+   */
+  @scala.annotation.varargs
+  def least(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
+    sys.error("LEAST takes at least 2 parameters")
+  } else {
+    least((columnName +: columnNames).map(Column.apply): _*)
+  }
+
+  /**
    * Computes the natural logarithm of the given value.
    *
    * @group math_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 173280375c..6cebec95d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -381,4 +381,26 @@ class DataFrameFunctionsSuite extends QueryTest {
       df.selectExpr("split(a, '[1-9]+')"),
       Row(Seq("aa", "bb", "cc")))
   }
+
+  test("conditional function: least") {
+    checkAnswer(
+      testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1),
+      Row(-1)
+    )
+    checkAnswer(
+      ctx.sql("SELECT least(a, 2) as l from testData2 order by l"),
+      Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2))
+    )
+  }
+
+  test("conditional function: greatest") {
+    checkAnswer(
+      testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1),
+      Row(3)
+    )
+    checkAnswer(
+      ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"),
+      Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
+    )
+  }
 }
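As a standalone illustration (not part of the patch), the null handling that `Least.eval` and `Greatest.eval` implement above boils down to: null children are skipped, and the result is null only when every input is null, while `checkInputDataTypes` separately rejects mixed argument types at analysis time. A plain-Scala sketch of the `Least` fold, over strings for simplicity:

```scala
// Mirrors the foldLeft in Least.eval: keep the smallest non-null value seen so far.
def leastOf(values: Seq[String]): String =
  values.foldLeft[String](null) { (r, c) =>
    if (c == null) r                    // skip null children
    else if (r == null || c < r) c      // first non-null value, or a new minimum
    else r
  }

leastOf(Seq("b", null, "a"))  // "a"
leastOf(Seq(null, null))      // null -- all inputs were null
```

`Greatest` is symmetric, using `ordering.gt` in place of `ordering.lt`.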