From 337c16d57e40cb4967bf85269baae14745f161db Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 11 Jun 2015 17:06:21 -0700 Subject: [SQL] Miscellaneous SQL/DF expression changes. SPARK-8201 conditional function: if SPARK-8205 conditional function: nvl SPARK-8208 math function: ceiling SPARK-8210 math function: degrees SPARK-8211 math function: radians SPARK-8219 math function: negative SPARK-8216 math function: rename log -> ln SPARK-8222 math function: alias power / pow SPARK-8225 math function: alias sign / signum SPARK-8228 conditional function: isnull SPARK-8229 conditional function: isnotnull SPARK-8250 string function: alias lower/lcase SPARK-8251 string function: alias upper / ucase Author: Reynold Xin Closes #6754 from rxin/expressions-misc and squashes the following commits: 35fce15 [Reynold Xin] Removed println. 2647067 [Reynold Xin] Promote to string type. 3c32bbc [Reynold Xin] Fixed if. de827ac [Reynold Xin] Fixed style b201cd4 [Reynold Xin] Removed if. 6b21a9b [Reynold Xin] [SQL] Miscellaneous SQL/DF expression changes. --- .../sql/catalyst/analysis/FunctionRegistry.scala | 20 ++++++--- .../sql/catalyst/analysis/HiveTypeCoercion.scala | 30 +++++++++++++ .../catalyst/analysis/HiveTypeCoercionSuite.scala | 13 ++++++ .../expressions/ConditionalExpressionSuite.scala | 43 +++++++++++++++++- .../apache/spark/sql/ColumnExpressionSuite.scala | 16 +++++++ .../apache/spark/sql/DataFrameFunctionsSuite.scala | 29 ++++++------ .../apache/spark/sql/MathExpressionsSuite.scala | 51 +++++++++++++++++++--- 7 files changed, 175 insertions(+), 27 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index a7816e3275..45bcbf73fa 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -84,43 +84,51 @@ object FunctionRegistry { type FunctionBuilder = Seq[Expression] => Expression val expressions: Map[String, FunctionBuilder] = Map( - // Non aggregate functions + // misc non-aggregate functions expression[Abs]("abs"), expression[CreateArray]("array"), expression[Coalesce]("coalesce"), expression[Explode]("explode"), + expression[If]("if"), + expression[IsNull]("isnull"), + expression[IsNotNull]("isnotnull"), + expression[Coalesce]("nvl"), expression[Rand]("rand"), expression[Randn]("randn"), expression[CreateStruct]("struct"), expression[Sqrt]("sqrt"), - // Math functions + // math functions expression[Acos]("acos"), expression[Asin]("asin"), expression[Atan]("atan"), expression[Atan2]("atan2"), expression[Cbrt]("cbrt"), expression[Ceil]("ceil"), + expression[Ceil]("ceiling"), expression[Cos]("cos"), expression[EulerNumber]("e"), expression[Exp]("exp"), expression[Expm1]("expm1"), expression[Floor]("floor"), expression[Hypot]("hypot"), - expression[Log]("log"), + expression[Log]("ln"), expression[Log10]("log10"), expression[Log1p]("log1p"), + expression[UnaryMinus]("negative"), expression[Pi]("pi"), expression[Log2]("log2"), expression[Pow]("pow"), + expression[Pow]("power"), expression[Rint]("rint"), + expression[Signum]("sign"), expression[Signum]("signum"), expression[Sin]("sin"), expression[Sinh]("sinh"), expression[Tan]("tan"), expression[Tanh]("tanh"), - expression[ToDegrees]("todegrees"), - expression[ToRadians]("toradians"), + expression[ToDegrees]("degrees"), + expression[ToRadians]("radians"), // aggregate functions expression[Average]("avg"), @@ -132,10 +140,12 @@ object FunctionRegistry { expression[Sum]("sum"), // string functions + expression[Lower]("lcase"), expression[Lower]("lower"), expression[StringLength]("length"), expression[Substring]("substr"), expression[Substring]("substring"), + expression[Upper]("ucase"), expression[Upper]("upper") ) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 737905c358..6ed192360d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -58,6 +58,15 @@ object HiveTypeCoercion { case _ => None } + /** Similar to [[findTightestCommonType]], but can promote all the way to StringType. */ + private def findTightestCommonTypeToString(left: DataType, right: DataType): Option[DataType] = { + findTightestCommonTypeOfTwo(left, right).orElse((left, right) match { + case (StringType, t2: AtomicType) if t2 != BinaryType && t2 != BooleanType => Some(StringType) + case (t1: AtomicType, StringType) if t1 != BinaryType && t1 != BooleanType => Some(StringType) + case _ => None + }) + } + /** * Find the tightest common type of a set of types by continuously applying * `findTightestCommonTypeOfTwo` on these types. @@ -91,6 +100,7 @@ trait HiveTypeCoercion { StringToIntegralCasts :: FunctionArgumentConversion :: CaseWhenCoercion :: + IfCoercion :: Division :: PropagateTypes :: ExpectedInputConversion :: @@ -652,6 +662,26 @@ trait HiveTypeCoercion { } } + /** + * Coerces the type of different branches of If statement to a common type. + */ + object IfCoercion extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions { + // Find tightest common type for If, if the true value and false value have different types. + case i @ If(pred, left, right) if left.dataType != right.dataType => + findTightestCommonTypeToString(left.dataType, right.dataType).map { widestType => + val newLeft = if (left.dataType == widestType) left else Cast(left, widestType) + val newRight = if (right.dataType == widestType) right else Cast(right, widestType) + i.makeCopy(Array(pred, newLeft, newRight)) + }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. + + // Convert If(null literal, _, _) into boolean type. + // In the optimizer, we should short-circuit this directly into false value. + case i @ If(pred, left, right) if pred.dataType == NullType => + i.makeCopy(Array(Literal.create(null, BooleanType), left, right)) + } + } + /** * Casts types according to the expected input types for Expressions that have the trait * `ExpectsInputTypes`. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 9977f7af00..f7b8e21bed 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -134,6 +134,19 @@ class HiveTypeCoercionSuite extends PlanTest { :: Nil)) } + test("type coercion for If") { + val rule = new HiveTypeCoercion { }.IfCoercion + ruleTest(rule, + If(Literal(true), Literal(1), Literal(1L)), + If(Literal(true), Cast(Literal(1), LongType), Literal(1L)) + ) + + ruleTest(rule, + If(Literal.create(null, NullType), Literal(1), Literal(1)), + If(Literal.create(null, BooleanType), Literal(1), Literal(1)) + ) + } + test("type coercion for CaseKeyWhen") { val cwc = new HiveTypeCoercion {}.CaseWhenCoercion ruleTest(cwc, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala index 152c4e4111..372848ea9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ConditionalExpressionSuite.scala @@ -19,11 +19,52 @@ package org.apache.spark.sql.catalyst.expressions import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.dsl.expressions._ -import org.apache.spark.sql.types.{IntegerType, BooleanType} +import org.apache.spark.sql.types._ class ConditionalExpressionSuite extends SparkFunSuite with ExpressionEvalHelper { + test("if") { + val testcases = Seq[(java.lang.Boolean, Integer, Integer, Integer)]( + (true, 1, 2, 1), + (false, 1, 2, 2), + (null, 1, 2, 2), + (true, null, 2, null), + (false, 1, null, null), + (null, null, 2, 2), + (null, 1, null, null) + ) + + // dataType must match T. + def testIf(convert: (Integer => Any), dataType: DataType): Unit = { + for ((predicate, trueValue, falseValue, expected) <- testcases) { + val trueValueConverted = if (trueValue == null) null else convert(trueValue) + val falseValueConverted = if (falseValue == null) null else convert(falseValue) + val expectedConverted = if (expected == null) null else convert(expected) + + checkEvaluation( + If(Literal.create(predicate, BooleanType), + Literal.create(trueValueConverted, dataType), + Literal.create(falseValueConverted, dataType)), + expectedConverted) + } + } + + testIf(_ == 1, BooleanType) + testIf(_.toShort, ShortType) + testIf(identity, IntegerType) + testIf(_.toLong, LongType) + + testIf(_.toFloat, FloatType) + testIf(_.toDouble, DoubleType) + testIf(Decimal(_), DecimalType.Unlimited) + + testIf(identity, DateType) + testIf(_.toLong, TimestampType) + + testIf(_.toString, StringType) + } + test("case when") { val row = create_row(null, false, true, "a", "b", "c") val c1 = 'a.boolean.at(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index 4f5484f136..efcdae5bce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -185,12 +185,20 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer( nullStrings.toDF.where($"s".isNull), nullStrings.collect().toSeq.filter(r => r.getString(1) eq null)) + + checkAnswer( + ctx.sql("select isnull(null), isnull(1)"), + Row(true, false)) } test("isNotNull") { checkAnswer( nullStrings.toDF.where($"s".isNotNull), nullStrings.collect().toSeq.filter(r => r.getString(1) ne null)) + + checkAnswer( + ctx.sql("select isnotnull(null), isnotnull('a')"), + Row(false, true)) } test("===") { @@ -393,6 +401,10 @@ class ColumnExpressionSuite extends QueryTest { testData.select(upper(lit(null))), (1 to 100).map(n => Row(null)) ) + + checkAnswer( + ctx.sql("SELECT upper('aB'), ucase('cDe')"), + Row("AB", "CDE")) } test("lower") { @@ -410,6 +422,10 @@ class ColumnExpressionSuite extends QueryTest { testData.select(lower(lit(null))), (1 to 100).map(n => Row(null)) ) + + checkAnswer( + ctx.sql("SELECT lower('aB'), lcase('cDe')"), + Row("ab", "cde")) } test("monotonicallyIncreasingId") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 659b64c185..cfd23867a9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -110,7 +110,20 @@ class DataFrameFunctionsSuite extends QueryTest { testData2.collect().toSeq.map(r => Row(~r.getInt(0)))) } - test("length") { + test("if function") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.selectExpr("if(a = 1, 'one', 'not_one')", "if(b = 1, 'one', 'not_one')"), + Row("one", "not_one")) + } + + test("nvl function") { + checkAnswer( + ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"), + Row("x", "y", null)) + } + + test("string length function") { checkAnswer( nullStrings.select(strlen($"s"), strlen("s")), nullStrings.collect().toSeq.map { r => @@ -127,18 +140,4 @@ class DataFrameFunctionsSuite extends QueryTest { Row(l) }) } - - test("log2 functions test") { - val df = Seq((1, 2)).toDF("a", "b") - checkAnswer( - df.select(log2("b") + log2("a")), - Row(1)) - - checkAnswer( - ctx.sql("SELECT LOG2(8)"), - Row(3)) - checkAnswer( - ctx.sql("SELECT LOG2(null)"), - Row(null)) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index 0a38af2b4c..6561c3b232 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql import org.apache.spark.sql.functions._ +import org.apache.spark.sql.functions.{log => logarithm} private object MathExpressionsTestData { @@ -151,20 +152,31 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(tanh, math.tanh) } - test("toDeg") { + test("toDegrees") { testOneToOneMathFunction(toDegrees, math.toDegrees) + checkAnswer( + ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"), + Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5))) + ) } - test("toRad") { + test("toRadians") { testOneToOneMathFunction(toRadians, math.toRadians) + checkAnswer( + ctx.sql("SELECT radians(0), radians(1), radians(1.5)"), + Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5))) + ) } test("cbrt") { testOneToOneMathFunction(cbrt, math.cbrt) } - test("ceil") { + test("ceil and ceiling") { testOneToOneMathFunction(ceil, math.ceil) + checkAnswer( + ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"), + Row(0.0, 1.0, 2.0)) } test("floor") { @@ -183,12 +195,21 @@ class MathExpressionsSuite extends QueryTest { testOneToOneMathFunction(expm1, math.expm1) } - test("signum") { + test("signum / sign") { testOneToOneMathFunction[Double](signum, math.signum) + + checkAnswer( + ctx.sql("SELECT sign(10), signum(-11)"), + Row(1, -1)) } - test("pow") { + test("pow / power") { testTwoToOneMathFunction(pow, pow, math.pow) + + checkAnswer( + ctx.sql("SELECT pow(1, 2), power(2, 1)"), + Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1))) + ) } test("hypot") { @@ -199,8 +220,12 @@ class MathExpressionsSuite extends QueryTest { testTwoToOneMathFunction(atan2, atan2, math.atan2) } - test("log") { + test("log / ln") { testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log) + checkAnswer( + ctx.sql("SELECT ln(0), ln(1), ln(1.5)"), + Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5))) + ) } test("log10") { @@ -211,4 +236,18 @@ class MathExpressionsSuite extends QueryTest { testOneToOneNonNegativeMathFunction(log1p, math.log1p) } + test("log2") { + val df = Seq((1, 2)).toDF("a", "b") + checkAnswer( + df.select(log2("b") + log2("a")), + Row(1)) + + checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null)) + } + + test("negative") { + checkAnswer( + ctx.sql("SELECT negative(1), negative(0), negative(-1)"), + Row(-1, 0, 1)) + } } -- cgit v1.2.3