diff options
author | Dongjoon Hyun <dongjoon@apache.org> | 2016-04-18 12:26:56 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-18 12:26:56 -0700 |
commit | d280d1da1aec925687a0bfb496f3a6e0979e896f (patch) | |
tree | dbeb76ebe45f519f68b624e0cda78d3e719d2dcc /sql | |
parent | b64482f49f6b9c7ff0ba64bd3202fe9cc6ad119a (diff) | |
download | spark-d280d1da1aec925687a0bfb496f3a6e0979e896f.tar.gz spark-d280d1da1aec925687a0bfb496f3a6e0979e896f.tar.bz2 spark-d280d1da1aec925687a0bfb496f3a6e0979e896f.zip |
[SPARK-14580][SPARK-14655][SQL] Hive IfCoercion should preserve predicate.
## What changes were proposed in this pull request?
Currently, `HiveTypeCoercion.IfCoercion` removes all predicates whose return-type are null. However, some UDFs need evaluations because they are designed to throw exceptions. This PR fixes that to preserve the predicates. Also, `assert_true` is implemented as Spark SQL function.
**Before**
```
scala> sql("select if(assert_true(false),2,3)").head
res2: org.apache.spark.sql.Row = [3]
```
**After**
```
scala> sql("select if(assert_true(false),2,3)").head
... ASSERT_TRUE ...
```
**Hive**
```
hive> select if(assert_true(false),2,3);
OK
Failed with exception java.io.IOException:org.apache.hadoop.hive.ql.metadata.HiveException: ASSERT_TRUE(): assertion failed.
```
## How was this patch tested?
Pass the Jenkins tests (including a new testcase in `HivePlanTest`)
Author: Dongjoon Hyun <dongjoon@apache.org>
Closes #12340 from dongjoon-hyun/SPARK-14580.
Diffstat (limited to 'sql')
5 files changed, 70 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index ed19191b72..a44430059d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -329,6 +329,7 @@ object FunctionRegistry { expression[SortArray]("sort_array"), // misc functions + expression[AssertTrue]("assert_true"), expression[Crc32]("crc32"), expression[Md5]("md5"), expression[Murmur3Hash]("hash"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala index 823d2495fa..5323b79c57 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala @@ -584,10 +584,10 @@ object HiveTypeCoercion { val newRight = if (right.dataType == widestType) right else Cast(right, widestType) If(pred, newLeft, newRight) }.getOrElse(i) // If there is no applicable conversion, leave expression unchanged. - // Convert If(null literal, _, _) into boolean type. - // In the optimizer, we should short-circuit this directly into false value. - case If(pred, left, right) if pred.dataType == NullType => + case If(Literal(null, NullType), left, right) => If(Literal.create(null, BooleanType), left, right) + case If(pred, left, right) if pred.dataType == NullType => + If(Cast(pred, BooleanType), left, right) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 113fc862c7..f2f0c2d698 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -487,6 +487,44 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { } /** + * A function throws an exception if 'condition' is not true. + */ +@ExpressionDescription( + usage = "_FUNC_(condition) - Throw an exception if 'condition' is not true.") +case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCastInputTypes { + + override def nullable: Boolean = true + + override def inputTypes: Seq[DataType] = Seq(BooleanType) + + override def dataType: DataType = NullType + + override def prettyName: String = "assert_true" + + override def eval(input: InternalRow) : Any = { + val v = child.eval(input) + if (v == null || java.lang.Boolean.FALSE.equals(v)) { + throw new RuntimeException(s"'${child.simpleString}' is not true!") + } else { + null + } + } + + override def genCode(ctx: CodegenContext, ev: ExprCode): String = { + val eval = child.gen(ctx) + ev.isNull = "true" + ev.value = "null" + s"""${eval.code} + |if (${eval.isNull} || !${eval.value}) { + | throw new RuntimeException("'${child.simpleString}' is not true."); + |} + """.stripMargin + } + + override def sql: String = s"assert_true(${child.sql})" +} + +/** * A xxHash64 64-bit hash expression. */ case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpression[Long] { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala index 883ef48984..18de8b152b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala @@ -348,15 +348,22 @@ class HiveTypeCoercionSuite extends PlanTest { test("type coercion for If") { val rule = HiveTypeCoercion.IfCoercion + ruleTest(rule, If(Literal(true), Literal(1), Literal(1L)), - If(Literal(true), Cast(Literal(1), LongType), Literal(1L)) - ) + If(Literal(true), Cast(Literal(1), LongType), Literal(1L))) ruleTest(rule, If(Literal.create(null, NullType), Literal(1), Literal(1)), - If(Literal.create(null, BooleanType), Literal(1), Literal(1)) - ) + If(Literal.create(null, BooleanType), Literal(1), Literal(1))) + + ruleTest(rule, + If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)), + If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2))) + + ruleTest(rule, + If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)), + If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2))) } test("type coercion for CaseKeyWhen") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala index f5bafcc6a7..56de82237b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala @@ -69,6 +69,23 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType) } + test("assert_true") { + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Literal(false, BooleanType)), null) + } + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null) + } + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Literal.create(null, NullType)), null) + } + intercept[RuntimeException] { + checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null) + } + checkEvaluation(AssertTrue(Literal(true, BooleanType)), null) + checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null) + } + private val structOfString = new StructType().add("str", StringType) private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false) private val arrayOfString = ArrayType(StringType) |