aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorReynold Xin <rxin@databricks.com>2015-06-11 17:06:21 -0700
committerReynold Xin <rxin@databricks.com>2015-06-11 17:06:21 -0700
commit337c16d57e40cb4967bf85269baae14745f161db (patch)
treeab201ed2a080bf1cc60502e62d1b00db16ece1cc /sql
parent7914c720bf7447e8c9d96d564eafd6b687d2fc1a (diff)
downloadspark-337c16d57e40cb4967bf85269baae14745f161db.tar.gz
spark-337c16d57e40cb4967bf85269baae14745f161db.tar.bz2
spark-337c16d57e40cb4967bf85269baae14745f161db.zip
[SQL] Miscellaneous SQL/DF expression changes.
SPARK-8201 conditional function: if SPARK-8205 conditional function: nvl SPARK-8208 math function: ceiling SPARK-8210 math function: degrees SPARK-8211 math function: radians SPARK-8219 math function: negative SPARK-8216 math function: rename log -> ln SPARK-8222 math function: alias power / pow SPARK-8225 math function: alias sign / signum SPARK-8228 conditional function: isnull SPARK-8229 conditional function: isnotnull SPARK-8250 string function: alias lower/lcase SPARK-8251 string function: alias upper / ucase Author: Reynold Xin <rxin@databricks.com> Closes #6754 from rxin/expressions-misc and squashes the following commits: 35fce15 [Reynold Xin] Removed println. 2647067 [Reynold Xin] Promote to string type. 3c32bbc [Reynold Xin] Fixed if. de827ac [Reynold Xin] Fixed style b201cd4 [Reynold Xin] Removed if. 6b21a9b [Reynold Xin] [SQL] Miscellaneous SQL/DF expression changes.
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala20
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala30
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala13
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala43
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala16
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala29
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala51
7 files changed, 175 insertions, 27 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index a7816e3275..45bcbf73fa 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -84,43 +84,51 @@ object FunctionRegistry {
type FunctionBuilder = Seq[Expression] => Expression
val expressions: Map[String, FunctionBuilder] = Map(
- // Non aggregate functions
+ // misc non-aggregate functions
expression[Abs]("abs"),
expression[CreateArray]("array"),
expression[Coalesce]("coalesce"),
expression[Explode]("explode"),
+ expression[If]("if"),
+ expression[IsNull]("isnull"),
+ expression[IsNotNull]("isnotnull"),
+ expression[Coalesce]("nvl"),
expression[Rand]("rand"),
expression[Randn]("randn"),
expression[CreateStruct]("struct"),
expression[Sqrt]("sqrt"),
- // Math functions
+ // math functions
expression[Acos]("acos"),
expression[Asin]("asin"),
expression[Atan]("atan"),
expression[Atan2]("atan2"),
expression[Cbrt]("cbrt"),
expression[Ceil]("ceil"),
+ expression[Ceil]("ceiling"),
expression[Cos]("cos"),
expression[EulerNumber]("e"),
expression[Exp]("exp"),
expression[Expm1]("expm1"),
expression[Floor]("floor"),
expression[Hypot]("hypot"),
- expression[Log]("log"),
+ expression[Log]("ln"),
expression[Log10]("log10"),
expression[Log1p]("log1p"),
+ expression[UnaryMinus]("negative"),
expression[Pi]("pi"),
expression[Log2]("log2"),
expression[Pow]("pow"),
+ expression[Pow]("power"),
expression[Rint]("rint"),
+ expression[Signum]("sign"),
expression[Signum]("signum"),
expression[Sin]("sin"),
expression[Sinh]("sinh"),
expression[Tan]("tan"),
expression[Tanh]("tanh"),
- expression[ToDegrees]("todegrees"),
- expression[ToRadians]("toradians"),
+ expression[ToDegrees]("degrees"),
+ expression[ToRadians]("radians"),
// aggregate functions
expression[Average]("avg"),
@@ -132,10 +140,12 @@ object FunctionRegistry {
expression[Sum]("sum"),
// string functions
+ expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
expression[Substring]("substr"),
expression[Substring]("substring"),
+ expression[Upper]("ucase"),
expression[Upper]("upper")
)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 737905c358..6ed192360d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -58,6 +58,15 @@ object HiveTypeCoercion {
case _ => None
}
+ /** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */
+ private def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = {
+ findTightestCommonTypeOfTwo(left, right).orElse((left, right) match {
+ case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType)
+ case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType)
+ case _ => None
+ })
+ }
+
/**
* Find the tightest common type of a set of types by continuously applying
* `findTightestCommonTypeOfTwo` on these types.
@@ -91,6 +100,7 @@ trait HiveTypeCoercion {
StringToIntegralCasts ::
FunctionArgumentConversion ::
CaseWhenCoercion ::
+ IfCoercion ::
Division ::
PropagateTypes ::
ExpectedInputConversion ::
@@ -653,6 +663,26 @@ trait HiveTypeCoercion {
}
/**
+ * Coerces the type of different branches of If statement to a common type.
+ */
+ object IfCoercion extends Rule[LogicalPlan] {
+ def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+ // Find tightest common type for If, if the true value and false value have different types.
+ case i @ If(pred, left, right) if left.dataType != right.dataType =>
+ findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType =>
+ val newLeft = if (left.dataType == widestType) left else Cast(left, widestType)
+ val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
+ i.makeCopy(Array(pred, newLeft, newRight))
+ }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
+
+ // Convert If(null literal, _, _) into boolean type.
+ // In the optimizer, we should short-circuit this directly into false value.
+ case i @ If(pred, left, right) if pred.dataType == NullType =>
+ i.makeCopy(Array(Literal.create(null, BooleanType), left, right))
+ }
+ }
+
+ /**
* Casts types according to the expected input types for Expressions that have the trait
* `ExpectsInputTypes`.
*/
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 9977f7af00..f7b8e21bed 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -134,6 +134,19 @@ class HiveTypeCoercionSuite extends PlanTest {
:: Nil))
}
+ test("type coercion for If") {
+ val rule = new HiveTypeCoercion { }.IfCoercion
+ ruleTest(rule,
+ If(Literal(true), Literal(1), Literal(1L)),
+ If(Literal(true), Cast(Literal(1), LongType), Literal(1L))
+ )
+
+ ruleTest(rule,
+ If(Literal.create(null, NullType), Literal(1), Literal(1)),
+ If(Literal.create(null, BooleanType), Literal(1), Literal(1))
+ )
+ }
+
test("type coercion for CaseKeyWhen") {
val cwc = new HiveTypeCoercion {}.CaseWhenCoercion
ruleTest(cwc,
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
index 152c4e4111..372848ea9a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala
@@ -19,11 +19,52 @@ package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.types.{IntegerType, BooleanType}
+import org.apache.spark.sql.types._
class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper {
+ test("if") {
+ val testcases = Seq[(java.lang.Boolean, Integer, Integer, Integer)](
+ (true, 1, 2, 1),
+ (false, 1, 2, 2),
+ (null, 1, 2, 2),
+ (true, null, 2, null),
+ (false, 1, null, null),
+ (null, null, 2, 2),
+ (null, 1, null, null)
+ )
+
+ // dataType must match T.
+ def testIf(convert: (Integer => Any), dataType: DataType): Unit = {
+ for ((predicate, trueValue, falseValue, expected) <- testcases) {
+ val trueValueConverted = if (trueValue == null) null else convert(trueValue)
+ val falseValueConverted = if (falseValue == null) null else convert(falseValue)
+ val expectedConverted = if (expected == null) null else convert(expected)
+
+ checkEvaluation(
+ If(Literal.create(predicate, BooleanType),
+ Literal.create(trueValueConverted, dataType),
+ Literal.create(falseValueConverted, dataType)),
+ expectedConverted)
+ }
+ }
+
+ testIf(_ == 1, BooleanType)
+ testIf(_.toShort, ShortType)
+ testIf(identity, IntegerType)
+ testIf(_.toLong, LongType)
+
+ testIf(_.toFloat, FloatType)
+ testIf(_.toDouble, DoubleType)
+ testIf(Decimal(_), DecimalType.Unlimited)
+
+ testIf(identity, DateType)
+ testIf(_.toLong, TimestampType)
+
+ testIf(_.toString, StringType)
+ }
+
test("case when") {
val row = create_row(null, false, true, "a", "b", "c")
val c1 = 'a.boolean.at(0)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 4f5484f136..efcdae5bce 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -185,12 +185,20 @@ class ColumnExpressionSuite extends QueryTest {
checkAnswer(
nullStrings.toDF.where($"s".isNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
+
+ checkAnswer(
+ ctx.sql("select isnull(null), isnull(1)"),
+ Row(true, false))
}
test("isNotNull") {
checkAnswer(
nullStrings.toDF.where($"s".isNotNull),
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
+
+ checkAnswer(
+ ctx.sql("select isnotnull(null), isnotnull('a')"),
+ Row(false, true))
}
test("===") {
@@ -393,6 +401,10 @@ class ColumnExpressionSuite extends QueryTest {
testData.select(upper(lit(null))),
(1 to 100).map(n => Row(null))
)
+
+ checkAnswer(
+ ctx.sql("SELECT upper('aB'), ucase('cDe')"),
+ Row("AB", "CDE"))
}
test("lower") {
@@ -410,6 +422,10 @@ class ColumnExpressionSuite extends QueryTest {
testData.select(lower(lit(null))),
(1 to 100).map(n => Row(null))
)
+
+ checkAnswer(
+ ctx.sql("SELECT lower('aB'), lcase('cDe')"),
+ Row("ab", "cde"))
}
test("monotonicallyIncreasingId") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 659b64c185..cfd23867a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -110,7 +110,20 @@ class DataFrameFunctionsSuite extends QueryTest {
testData2.collect().toSeq.map(r => Row(~r.getInt(0))))
}
- test("length") {
+ test("if function") {
+ val df = Seq((1, 2)).toDF("a", "b")
+ checkAnswer(
+ df.selectExpr("if(a = 1, 'one', 'not_one')", "if(b = 1, 'one', 'not_one')"),
+ Row("one", "not_one"))
+ }
+
+ test("nvl function") {
+ checkAnswer(
+ ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
+ Row("x", "y", null))
+ }
+
+ test("string length function") {
checkAnswer(
nullStrings.select(strlen($"s"), strlen("s")),
nullStrings.collect().toSeq.map { r =>
@@ -127,18 +140,4 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(l)
})
}
-
- test("log2 functions test") {
- val df = Seq((1, 2)).toDF("a", "b")
- checkAnswer(
- df.select(log2("b") + log2("a")),
- Row(1))
-
- checkAnswer(
- ctx.sql("SELECT LOG2(8)"),
- Row(3))
- checkAnswer(
- ctx.sql("SELECT LOG2(null)"),
- Row(null))
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 0a38af2b4c..6561c3b232 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.functions.{log => logarithm}
private object MathExpressionsTestData {
@@ -151,20 +152,31 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction(tanh, math.tanh)
}
- test("toDeg") {
+ test("toDegrees") {
testOneToOneMathFunction(toDegrees, math.toDegrees)
+ checkAnswer(
+ ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
+ Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5)))
+ )
}
- test("toRad") {
+ test("toRadians") {
testOneToOneMathFunction(toRadians, math.toRadians)
+ checkAnswer(
+ ctx.sql("SELECT radians(0), radians(1), radians(1.5)"),
+ Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5)))
+ )
}
test("cbrt") {
testOneToOneMathFunction(cbrt, math.cbrt)
}
- test("ceil") {
+ test("ceil and ceiling") {
testOneToOneMathFunction(ceil, math.ceil)
+ checkAnswer(
+ ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
+ Row(0.0, 1.0, 2.0))
}
test("floor") {
@@ -183,12 +195,21 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction(expm1, math.expm1)
}
- test("signum") {
+ test("signum / sign") {
testOneToOneMathFunction[Double](signum, math.signum)
+
+ checkAnswer(
+ ctx.sql("SELECT sign(10), signum(-11)"),
+ Row(1, -1))
}
- test("pow") {
+ test("pow / power") {
testTwoToOneMathFunction(pow, pow, math.pow)
+
+ checkAnswer(
+ ctx.sql("SELECT pow(1, 2), power(2, 1)"),
+ Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1)))
+ )
}
test("hypot") {
@@ -199,8 +220,12 @@ class MathExpressionsSuite extends QueryTest {
testTwoToOneMathFunction(atan2, atan2, math.atan2)
}
- test("log") {
+ test("log / ln") {
testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
+ checkAnswer(
+ ctx.sql("SELECT ln(0), ln(1), ln(1.5)"),
+ Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5)))
+ )
}
test("log10") {
@@ -211,4 +236,18 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneNonNegativeMathFunction(log1p, math.log1p)
}
+ test("log2") {
+ val df = Seq((1, 2)).toDF("a", "b")
+ checkAnswer(
+ df.select(log2("b") + log2("a")),
+ Row(1))
+
+ checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
+ }
+
+ test("negative") {
+ checkAnswer(
+ ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
+ Row(-1, 0, 1))
+ }
}