about | summary | refs | log | tree | commit | diff
path: root/sql
diff options
context:
space:
mode:
author	Daoyuan Wang <daoyuan.wang@intel.com>	2015-07-13 00:14:32 -0700
committer	Davies Liu <davies.liu@gmail.com>	2015-07-13 00:14:32 -0700
commit	92540d22e45f9300f413f520a1770e9f36cfa833 (patch)
tree	793bb6ebaeaa52381cf0966c47038a7a2d4f7b40 /sql
parent	20b474335c68c644150fdc8443a2d0d2dad5e27d (diff)
download	spark-92540d22e45f9300f413f520a1770e9f36cfa833.tar.gz
	spark-92540d22e45f9300f413f520a1770e9f36cfa833.tar.bz2
	spark-92540d22e45f9300f413f520a1770e9f36cfa833.zip
[SPARK-8203] [SPARK-8204] [SQL] conditional function: least/greatest
chenghao-intel zhichao-li qiansl127 Author: Daoyuan Wang <daoyuan.wang@intel.com> Closes #6851 from adrian-wang/udflg and squashes the following commits: 0f1bff2 [Daoyuan Wang] address comments from davis 7a6bdbb [Daoyuan Wang] add '.' for hex() c1f6824 [Daoyuan Wang] add codegen, test for all types ec625b0 [Daoyuan Wang] conditional function: least/greatest
Diffstat (limited to 'sql')
-rw-r--r--	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala	| 2
-rw-r--r--	sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala	| 103
-rw-r--r--	sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala	| 81
-rw-r--r--	sql/core/src/main/scala/org/apache/spark/sql/functions.scala	| 60
-rw-r--r--	sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala	| 22
5 files changed, 263 insertions, 5 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index f62d79f8ce..ed69c42dcb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -76,9 +76,11 @@ object FunctionRegistry {
expression[CreateArray]("array"),
expression[Coalesce]("coalesce"),
expression[Explode]("explode"),
+ expression[Greatest]("greatest"),
expression[If]("if"),
expression[IsNull]("isnull"),
expression[IsNotNull]("isnotnull"),
+ expression[Least]("least"),
expression[Coalesce]("nvl"),
expression[Rand]("rand"),
expression[Randn]("randn"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
index 395e84f089..e6a705fb80 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionals.scala
@@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.types.{BooleanType, DataType}
+import org.apache.spark.sql.catalyst.util.TypeUtils
+import org.apache.spark.sql.types.{NullType, BooleanType, DataType}
case class If(predicate: Expression, trueValue: Expression, falseValue: Expression)
@@ -312,3 +313,103 @@ case class CaseKeyWhen(key: Expression, branches: Seq[Expression]) extends CaseW
}.mkString
}
}
+
+case class Least(children: Expression*) extends Expression {
+ require(children.length > 1, "LEAST requires at least 2 arguments, got " + children.length)
+
+ override def nullable: Boolean = children.forall(_.nullable)
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ private lazy val ordering = TypeUtils.getOrdering(dataType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+ TypeCheckResult.TypeCheckFailure(
+ s"The expressions should all have the same type," +
+ s" got LEAST (${children.map(_.dataType)}).")
+ } else {
+ TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+ }
+ }
+
+ override def dataType: DataType = children.head.dataType
+
+ override def eval(input: InternalRow): Any = {
+ children.foldLeft[Any](null)((r, c) => {
+ val evalc = c.eval(input)
+ if (evalc != null) {
+ if (r == null || ordering.lt(evalc, r)) evalc else r
+ } else {
+ r
+ }
+ })
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val evalChildren = children.map(_.gen(ctx))
+ def updateEval(i: Int): String =
+ s"""
+ if (!${evalChildren(i).isNull} && (${ev.isNull} ||
+ ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} < 0)) {
+ ${ev.isNull} = false;
+ ${ev.primitive} = ${evalChildren(i).primitive};
+ }
+ """
+ s"""
+ ${evalChildren.map(_.code).mkString("\n")}
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ ${(0 until children.length).map(updateEval).mkString("\n")}
+ """
+ }
+}
+
+case class Greatest(children: Expression*) extends Expression {
+ require(children.length > 1, "GREATEST requires at least 2 arguments, got " + children.length)
+
+ override def nullable: Boolean = children.forall(_.nullable)
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ private lazy val ordering = TypeUtils.getOrdering(dataType)
+
+ override def checkInputDataTypes(): TypeCheckResult = {
+ if (children.map(_.dataType).distinct.count(_ != NullType) > 1) {
+ TypeCheckResult.TypeCheckFailure(
+ s"The expressions should all have the same type," +
+ s" got GREATEST (${children.map(_.dataType)}).")
+ } else {
+ TypeUtils.checkForOrderingExpr(dataType, "function " + prettyName)
+ }
+ }
+
+ override def dataType: DataType = children.head.dataType
+
+ override def eval(input: InternalRow): Any = {
+ children.foldLeft[Any](null)((r, c) => {
+ val evalc = c.eval(input)
+ if (evalc != null) {
+ if (r == null || ordering.gt(evalc, r)) evalc else r
+ } else {
+ r
+ }
+ })
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val evalChildren = children.map(_.gen(ctx))
+ def updateEval(i: Int): String =
+ s"""
+ if (!${evalChildren(i).isNull} && (${ev.isNull} ||
+ ${ctx.genComp(dataType, evalChildren(i).primitive, ev.primitive)} > 0)) {
+ ${ev.isNull} = false;
+ ${ev.primitive} = ${evalChildren(i).primitive};
+ }
+ """
+ s"""
+ ${evalChildren.map(_.code).mkString("\n")}
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ ${(0 until children.length).map(updateEval).mkString("\n")}
+ """
+ }
+}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 372848ea9a..aaf40cc83e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -17,7 +17,10 @@
package org.apache.spark.sql.catalyst.expressions
+import java.sql.{Timestamp, Date}
+
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.types._
@@ -134,4 +137,82 @@ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(CaseKeyWhen(literalNull, Seq(c2, c5, c1, c6)), "c", row)
}
+ test("function least") {
+ val row = create_row(1, 2, "a", "b", "c")
+ val c1 = 'a.int.at(0)
+ val c2 = 'a.int.at(1)
+ val c3 = 'a.string.at(2)
+ val c4 = 'a.string.at(3)
+ val c5 = 'a.string.at(4)
+ checkEvaluation(Least(c4, c3, c5), "a", row)
+ checkEvaluation(Least(c1, c2), 1, row)
+ checkEvaluation(Least(c1, c2, Literal(-1)), -1, row)
+ checkEvaluation(Least(c4, c5, c3, c3, Literal("a")), "a", row)
+
+ checkEvaluation(Least(Literal(null), Literal(null)), null, InternalRow.empty)
+ checkEvaluation(Least(Literal(-1.0), Literal(2.5)), -1.0, InternalRow.empty)
+ checkEvaluation(Least(Literal(-1), Literal(2)), -1, InternalRow.empty)
+ checkEvaluation(
+ Least(Literal((-1.0).toFloat), Literal(2.5.toFloat)), (-1.0).toFloat, InternalRow.empty)
+ checkEvaluation(
+ Least(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MinValue, InternalRow.empty)
+ checkEvaluation(Least(Literal(1.toByte), Literal(2.toByte)), 1.toByte, InternalRow.empty)
+ checkEvaluation(
+ Least(Literal(1.toShort), Literal(2.toByte.toShort)), 1.toShort, InternalRow.empty)
+ checkEvaluation(Least(Literal("abc"), Literal("aaaa")), "aaaa", InternalRow.empty)
+ checkEvaluation(Least(Literal(true), Literal(false)), false, InternalRow.empty)
+ checkEvaluation(
+ Least(
+ Literal(BigDecimal("1234567890987654321123456")),
+ Literal(BigDecimal("1234567890987654321123458"))),
+ BigDecimal("1234567890987654321123456"), InternalRow.empty)
+ checkEvaluation(
+ Least(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))),
+ Date.valueOf("2015-01-01"), InternalRow.empty)
+ checkEvaluation(
+ Least(
+ Literal(Timestamp.valueOf("2015-07-01 08:00:00")),
+ Literal(Timestamp.valueOf("2015-07-01 10:00:00"))),
+ Timestamp.valueOf("2015-07-01 08:00:00"), InternalRow.empty)
+ }
+
+ test("function greatest") {
+ val row = create_row(1, 2, "a", "b", "c")
+ val c1 = 'a.int.at(0)
+ val c2 = 'a.int.at(1)
+ val c3 = 'a.string.at(2)
+ val c4 = 'a.string.at(3)
+ val c5 = 'a.string.at(4)
+ checkEvaluation(Greatest(c4, c5, c3), "c", row)
+ checkEvaluation(Greatest(c2, c1), 2, row)
+ checkEvaluation(Greatest(c1, c2, Literal(2)), 2, row)
+ checkEvaluation(Greatest(c4, c5, c3, Literal("ccc")), "ccc", row)
+
+ checkEvaluation(Greatest(Literal(null), Literal(null)), null, InternalRow.empty)
+ checkEvaluation(Greatest(Literal(-1.0), Literal(2.5)), 2.5, InternalRow.empty)
+ checkEvaluation(Greatest(Literal(-1), Literal(2)), 2, InternalRow.empty)
+ checkEvaluation(
+ Greatest(Literal((-1.0).toFloat), Literal(2.5.toFloat)), 2.5.toFloat, InternalRow.empty)
+ checkEvaluation(
+ Greatest(Literal(Long.MaxValue), Literal(Long.MinValue)), Long.MaxValue, InternalRow.empty)
+ checkEvaluation(Greatest(Literal(1.toByte), Literal(2.toByte)), 2.toByte, InternalRow.empty)
+ checkEvaluation(
+ Greatest(Literal(1.toShort), Literal(2.toByte.toShort)), 2.toShort, InternalRow.empty)
+ checkEvaluation(Greatest(Literal("abc"), Literal("aaaa")), "abc", InternalRow.empty)
+ checkEvaluation(Greatest(Literal(true), Literal(false)), true, InternalRow.empty)
+ checkEvaluation(
+ Greatest(
+ Literal(BigDecimal("1234567890987654321123456")),
+ Literal(BigDecimal("1234567890987654321123458"))),
+ BigDecimal("1234567890987654321123458"), InternalRow.empty)
+ checkEvaluation(
+ Greatest(Literal(Date.valueOf("2015-01-01")), Literal(Date.valueOf("2015-07-01"))),
+ Date.valueOf("2015-07-01"), InternalRow.empty)
+ checkEvaluation(
+ Greatest(
+ Literal(Timestamp.valueOf("2015-07-01 08:00:00")),
+ Literal(Timestamp.valueOf("2015-07-01 10:00:00"))),
+ Timestamp.valueOf("2015-07-01 10:00:00"), InternalRow.empty)
+ }
+
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 08bf37a5c2..ffa52f6258 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -599,7 +599,7 @@ object functions {
/**
* Creates a new row for each element in the given array or map column.
*/
- def explode(e: Column): Column = Explode(e.expr)
+ def explode(e: Column): Column = Explode(e.expr)
/**
* Converts a string exprsesion to lower case.
@@ -1073,15 +1073,41 @@ object functions {
def floor(columnName: String): Column = floor(Column(columnName))
/**
- * Computes hex value of the given column
+ * Returns the greatest value of the list of values.
*
- * @group math_funcs
+ * @group normal_funcs
* @since 1.5.0
*/
+ @scala.annotation.varargs
+ def greatest(exprs: Column*): Column = if (exprs.length < 2) {
+ sys.error("GREATEST takes at least 2 parameters")
+ } else {
+ Greatest(exprs.map(_.expr): _*)
+ }
+
+ /**
+ * Returns the greatest value of the list of column names.
+ *
+ * @group normal_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def greatest(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
+ sys.error("GREATEST takes at least 2 parameters")
+ } else {
+ greatest((columnName +: columnNames).map(Column.apply): _*)
+ }
+
+ /**
+ * Computes hex value of the given column.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
def hex(column: Column): Column = Hex(column.expr)
/**
- * Computes hex value of the given input
+ * Computes hex value of the given input.
*
* @group math_funcs
* @since 1.5.0
@@ -1172,6 +1198,32 @@ object functions {
def hypot(l: Double, rightName: String): Column = hypot(l, Column(rightName))
/**
+ * Returns the least value of the list of values.
+ *
+ * @group normal_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def least(exprs: Column*): Column = if (exprs.length < 2) {
+ sys.error("LEAST takes at least 2 parameters")
+ } else {
+ Least(exprs.map(_.expr): _*)
+ }
+
+ /**
+ * Returns the least value of the list of column names.
+ *
+ * @group normal_funcs
+ * @since 1.5.0
+ */
+ @scala.annotation.varargs
+ def least(columnName: String, columnNames: String*): Column = if (columnNames.isEmpty) {
+ sys.error("LEAST takes at least 2 parameters")
+ } else {
+ least((columnName +: columnNames).map(Column.apply): _*)
+ }
+
+ /**
* Computes the natural logarithm of the given value.
*
* @group math_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 173280375c..6cebec95d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -381,4 +381,26 @@ class DataFrameFunctionsSuite extends QueryTest {
df.selectExpr("split(a, '[1-9]+')"),
Row(Seq("aa", "bb", "cc")))
}
+
+ test("conditional function: least") {
+ checkAnswer(
+ testData2.select(least(lit(-1), lit(0), col("a"), col("b"))).limit(1),
+ Row(-1)
+ )
+ checkAnswer(
+ ctx.sql("SELECT least(a, 2) as l from testData2 order by l"),
+ Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2))
+ )
+ }
+
+ test("conditional function: greatest") {
+ checkAnswer(
+ testData2.select(greatest(lit(2), lit(3), col("a"), col("b"))).limit(1),
+ Row(3)
+ )
+ checkAnswer(
+ ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"),
+ Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
+ )
+ }
}