aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala6
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala38
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala15
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala17
5 files changed, 70 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index ed19191b72..a44430059d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -329,6 +329,7 @@ object FunctionRegistry {
expression[SortArray]("sort_array"),
// misc functions
+ expression[AssertTrue]("assert_true"),
expression[Crc32]("crc32"),
expression[Md5]("md5"),
expression[Murmur3Hash]("hash"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
index 823d2495fa..5323b79c57 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercion.scala
@@ -584,10 +584,10 @@ object HiveTypeCoercion {
val newRight = if (right.dataType == widestType) right else Cast(right, widestType)
If(pred, newLeft, newRight)
}.getOrElse(i) // If there is no applicable conversion, leave expression unchanged.
- // Convert If(null literal, _, _) into boolean type.
- // In the optimizer, we should short-circuit this directly into false value.
- case If(pred, left, right) if pred.dataType == NullType =>
+ case If(Literal(null, NullType), left, right) =>
If(Literal.create(null, BooleanType), left, right)
+ case If(pred, left, right) if pred.dataType == NullType =>
+ If(Cast(pred, BooleanType), left, right)
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 113fc862c7..f2f0c2d698 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -487,6 +487,44 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
}
/**
+ * A function throws an exception if 'condition' is not true.
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(condition) - Throw an exception if 'condition' is not true.")
+case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {
+
+ override def nullable: Boolean = true
+
+ override def inputTypes: Seq[DataType] = Seq(BooleanType)
+
+ override def dataType: DataType = NullType
+
+ override def prettyName: String = "assert_true"
+
+ override def eval(input: InternalRow) : Any = {
+ val v = child.eval(input)
+ if (v == null || java.lang.Boolean.FALSE.equals(v)) {
+ throw new RuntimeException(s"'${child.simpleString}' is not true!")
+ } else {
+ null
+ }
+ }
+
+ override def genCode(ctx: CodegenContext, ev: ExprCode): String = {
+ val eval = child.gen(ctx)
+ ev.isNull = "true"
+ ev.value = "null"
+ s"""${eval.code}
+ |if (${eval.isNull} || !${eval.value}) {
+ | throw new RuntimeException("'${child.simpleString}' is not true.");
+ |}
+ """.stripMargin
+ }
+
+ override def sql: String = s"assert_true(${child.sql})"
+}
+
+/**
* A xxHash64 64-bit hash expression.
*/
case class XxHash64(children: Seq[Expression], seed: Long) extends HashExpression[Long] {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
index 883ef48984..18de8b152b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/HiveTypeCoercionSuite.scala
@@ -348,15 +348,22 @@ class HiveTypeCoercionSuite extends PlanTest {
test("type coercion for If") {
val rule = HiveTypeCoercion.IfCoercion
+
ruleTest(rule,
If(Literal(true), Literal(1), Literal(1L)),
- If(Literal(true), Cast(Literal(1), LongType), Literal(1L))
- )
+ If(Literal(true), Cast(Literal(1), LongType), Literal(1L)))
ruleTest(rule,
If(Literal.create(null, NullType), Literal(1), Literal(1)),
- If(Literal.create(null, BooleanType), Literal(1), Literal(1))
- )
+ If(Literal.create(null, BooleanType), Literal(1), Literal(1)))
+
+ ruleTest(rule,
+ If(AssertTrue(Literal.create(true, BooleanType)), Literal(1), Literal(2)),
+ If(Cast(AssertTrue(Literal.create(true, BooleanType)), BooleanType), Literal(1), Literal(2)))
+
+ ruleTest(rule,
+ If(AssertTrue(Literal.create(false, BooleanType)), Literal(1), Literal(2)),
+ If(Cast(AssertTrue(Literal.create(false, BooleanType)), BooleanType), Literal(1), Literal(2)))
}
test("type coercion for CaseKeyWhen") {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
index f5bafcc6a7..56de82237b 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MiscFunctionsSuite.scala
@@ -69,6 +69,23 @@ class MiscFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkConsistencyBetweenInterpretedAndCodegen(Crc32, BinaryType)
}
+ test("assert_true") {
+ intercept[RuntimeException] {
+ checkEvaluation(AssertTrue(Literal(false, BooleanType)), null)
+ }
+ intercept[RuntimeException] {
+ checkEvaluation(AssertTrue(Cast(Literal(0), BooleanType)), null)
+ }
+ intercept[RuntimeException] {
+ checkEvaluation(AssertTrue(Literal.create(null, NullType)), null)
+ }
+ intercept[RuntimeException] {
+ checkEvaluation(AssertTrue(Literal.create(null, BooleanType)), null)
+ }
+ checkEvaluation(AssertTrue(Literal(true, BooleanType)), null)
+ checkEvaluation(AssertTrue(Cast(Literal(1), BooleanType)), null)
+ }
+
private val structOfString = new StructType().add("str", StringType)
private val structOfUDT = new StructType().add("udt", new ExamplePointUDT, false)
private val arrayOfString = ArrayType(StringType)