aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala12
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala13
3 files changed, 29 insertions, 8 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
index 369207587d..92f8fb85fc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala
@@ -477,10 +477,13 @@ case class PrintToStderr(child: Expression) extends UnaryExpression {
protected override def nullSafeEval(input: Any): Any = input
+ private val outputPrefix = s"Result of ${child.simpleString} is "
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val outputPrefixField = ctx.addReferenceObj("outputPrefix", outputPrefix)
nullSafeCodeGen(ctx, ev, c =>
s"""
- | System.err.println("Result of ${child.simpleString} is " + $c);
+ | System.err.println($outputPrefixField + $c);
| ${ev.value} = $c;
""".stripMargin)
}
@@ -501,10 +504,12 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
override def prettyName: String = "assert_true"
+ private val errMsg = s"'${child.simpleString}' is not true!"
+
override def eval(input: InternalRow) : Any = {
val v = child.eval(input)
if (v == null || java.lang.Boolean.FALSE.equals(v)) {
- throw new RuntimeException(s"'${child.simpleString}' is not true!")
+ throw new RuntimeException(errMsg)
} else {
null
}
@@ -512,9 +517,10 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val eval = child.genCode(ctx)
+ val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
ExprCode(code = s"""${eval.code}
|if (${eval.isNull} || !${eval.value}) {
- | throw new RuntimeException("'${child.simpleString}' is not true.");
+ | throw new RuntimeException($errMsgField);
|}""".stripMargin, isNull = "true", value = "null")
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 4da74a0a27..faf8fecd79 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -938,7 +938,10 @@ case class GetExternalRowField(
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ private val errMsg = s"The ${index}th field '$fieldName' of input row cannot be null."
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val row = child.genCode(ctx)
val code = s"""
${row.code}
@@ -948,8 +951,7 @@ case class GetExternalRowField(
}
if (${row.value}.isNullAt($index)) {
- throw new RuntimeException("The ${index}th field '$fieldName' of input row " +
- "cannot be null.");
+ throw new RuntimeException($errMsgField);
}
final Object ${ev.value} = ${row.value}.get($index);
@@ -974,7 +976,10 @@ case class ValidateExternalType(child: Expression, expected: DataType)
override def eval(input: InternalRow): Any =
throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+ private val errMsg = s" is not a valid external type for schema of ${expected.simpleString}"
+
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+ val errMsgField = ctx.addReferenceObj("errMsg", errMsg)
val input = child.genCode(ctx)
val obj = input.value
@@ -995,8 +1000,7 @@ case class ValidateExternalType(child: Expression, expected: DataType)
if ($typeCheck) {
${ev.value} = (${ctx.boxedType(dataType)}) $obj;
} else {
- throw new RuntimeException($obj.getClass().getName() + " is not a valid " +
- "external type for schema of ${expected.simpleString}");
+ throw new RuntimeException($obj.getClass().getName() + $errMsgField);
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
index 0532cf5113..45dcfcaf23 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala
@@ -23,7 +23,7 @@ import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.expressions.objects.CreateExternalRow
+import org.apache.spark.sql.catalyst.expressions.objects.{CreateExternalRow, GetExternalRowField, ValidateExternalType}
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -265,4 +265,15 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
Literal.create("\\\\u001/Compilation error occurs", StringType) :: Nil)
}
+
+ test("SPARK-17160: field names are properly escaped by GetExternalRowField") {
+ val inputObject = BoundReference(0, ObjectType(classOf[Row]), nullable = true)
+ GenerateUnsafeProjection.generate(
+ ValidateExternalType(
+ GetExternalRowField(inputObject, index = 0, fieldName = "\"quote"), IntegerType) :: Nil)
+ }
+
+ test("SPARK-17160: field names are properly escaped by AssertTrue") {
+ GenerateUnsafeProjection.generate(AssertTrue(Cast(Literal("\""), BooleanType)) :: Nil)
+ }
}