aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala163
1 files changed, 101 insertions, 62 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 0e3d99127e..0b36091ece 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -33,6 +33,78 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
import org.apache.spark.sql.types._
/**
+ * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]].
+ */
+trait InvokeLike extends Expression with NonSQLExpression {
+
+ def arguments: Seq[Expression]
+
+ def propagateNull: Boolean
+
+ protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable)
+
+ /**
+ * Prepares codes for arguments.
+ *
+ * - generate codes for argument.
+ * - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments.
+ * - avoid some of nullabilty checking which are not needed because the expression is not
+ * nullable.
+ * - when needNullCheck == true, short circuit if we found one of arguments is null because
+ * preparing rest of arguments can be skipped in the case.
+ *
+ * @param ctx a [[CodegenContext]]
+ * @return (code to prepare arguments, argument string, result of argument null check)
+ */
+ def prepareArguments(ctx: CodegenContext): (String, String, String) = {
+
+ val resultIsNull = if (needNullCheck) {
+ val resultIsNull = ctx.freshName("resultIsNull")
+ ctx.addMutableState("boolean", resultIsNull, "")
+ resultIsNull
+ } else {
+ "false"
+ }
+ val argValues = arguments.map { e =>
+ val argValue = ctx.freshName("argValue")
+ ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
+ argValue
+ }
+
+ val argCodes = if (needNullCheck) {
+ val reset = s"$resultIsNull = false;"
+ val argCodes = arguments.zipWithIndex.map { case (e, i) =>
+ val expr = e.genCode(ctx)
+ val updateResultIsNull = if (e.nullable) {
+ s"$resultIsNull = ${expr.isNull};"
+ } else {
+ ""
+ }
+ s"""
+ if (!$resultIsNull) {
+ ${expr.code}
+ $updateResultIsNull
+ ${argValues(i)} = ${expr.value};
+ }
+ """
+ }
+ reset +: argCodes
+ } else {
+ arguments.zipWithIndex.map { case (e, i) =>
+ val expr = e.genCode(ctx)
+ s"""
+ ${expr.code}
+ ${argValues(i)} = ${expr.value};
+ """
+ }
+ }
+ val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
+
+ (argCode, argValues.mkString(", "), resultIsNull)
+ }
+}
+
+/**
* Invokes a static function, returning the result. By default, any of the arguments being null
* will result in returning null instead of calling the function.
*
@@ -50,7 +122,7 @@ case class StaticInvoke(
dataType: DataType,
functionName: String,
arguments: Seq[Expression] = Nil,
- propagateNull: Boolean = true) extends Expression with NonSQLExpression {
+ propagateNull: Boolean = true) extends InvokeLike {
val objectName = staticObject.getName.stripSuffix("$")
@@ -62,16 +134,10 @@ case class StaticInvoke(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
- val argGen = arguments.map(_.genCode(ctx))
- val argString = argGen.map(_.value).mkString(", ")
- val callFunc = s"$objectName.$functionName($argString)"
+ val (argCode, argString, resultIsNull) = prepareArguments(ctx)
- val setIsNull = if (propagateNull && arguments.nonEmpty) {
- s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};"
- } else {
- s"boolean ${ev.isNull} = false;"
- }
+ val callFunc = s"$objectName.$functionName($argString)"
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
@@ -82,9 +148,9 @@ case class StaticInvoke(
}
val code = s"""
- ${argGen.map(_.code).mkString("\n")}
- $setIsNull
- final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc;
+ $argCode
+ boolean ${ev.isNull} = $resultIsNull;
+ final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc;
$postNullCheck
"""
ev.copy(code = code)
@@ -103,13 +169,15 @@ case class StaticInvoke(
* @param functionName The name of the method to call.
* @param dataType The expected return type of the function.
* @param arguments An optional list of expressions, whos evaluation will be passed to the function.
+ * @param propagateNull When true, and any of the arguments is null, null will be returned instead
+ * of calling the function.
*/
case class Invoke(
targetObject: Expression,
functionName: String,
dataType: DataType,
arguments: Seq[Expression] = Nil,
- propagateNull: Boolean = true) extends Expression with NonSQLExpression {
+ propagateNull: Boolean = true) extends InvokeLike {
override def nullable: Boolean = true
override def children: Seq[Expression] = targetObject +: arguments
@@ -131,8 +199,8 @@ case class Invoke(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
val obj = targetObject.genCode(ctx)
- val argGen = arguments.map(_.genCode(ctx))
- val argString = argGen.map(_.value).mkString(", ")
+
+ val (argCode, argString, resultIsNull) = prepareArguments(ctx)
val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive
val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty
@@ -164,12 +232,6 @@ case class Invoke(
"""
}
- val setIsNull = if (propagateNull && arguments.nonEmpty) {
- s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};"
- } else {
- s"boolean ${ev.isNull} = ${obj.isNull};"
- }
-
// If the function can return null, we do an extra check to make sure our null bit is still set
// correctly.
val postNullCheck = if (ctx.defaultValue(dataType) == "null") {
@@ -177,15 +239,19 @@ case class Invoke(
} else {
""
}
+
val code = s"""
${obj.code}
- ${argGen.map(_.code).mkString("\n")}
- $setIsNull
+ boolean ${ev.isNull} = true;
$javaType ${ev.value} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- $evaluate
+ if (!${obj.isNull}) {
+ $argCode
+ ${ev.isNull} = $resultIsNull;
+ if (!${ev.isNull}) {
+ $evaluate
+ }
+ $postNullCheck
}
- $postNullCheck
"""
ev.copy(code = code)
}
@@ -223,10 +289,10 @@ case class NewInstance(
arguments: Seq[Expression],
propagateNull: Boolean,
dataType: DataType,
- outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression {
+ outerPointer: Option[() => AnyRef]) extends InvokeLike {
private val className = cls.getName
- override def nullable: Boolean = propagateNull
+ override def nullable: Boolean = needNullCheck
override def children: Seq[Expression] = arguments
@@ -245,52 +311,25 @@ case class NewInstance(
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val javaType = ctx.javaType(dataType)
- val argIsNulls = ctx.freshName("argIsNulls")
- ctx.addMutableState("boolean[]", argIsNulls,
- s"$argIsNulls = new boolean[${arguments.size}];")
- val argValues = arguments.zipWithIndex.map { case (e, i) =>
- val argValue = ctx.freshName("argValue")
- ctx.addMutableState(ctx.javaType(e.dataType), argValue, "")
- argValue
- }
- val argCodes = arguments.zipWithIndex.map { case (e, i) =>
- val expr = e.genCode(ctx)
- expr.code + s"""
- $argIsNulls[$i] = ${expr.isNull};
- ${argValues(i)} = ${expr.value};
- """
- }
- val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes)
+ val (argCode, argString, resultIsNull) = prepareArguments(ctx)
val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx))
- var isNull = ev.isNull
- val setIsNull = if (propagateNull && arguments.nonEmpty) {
- s"""
- boolean $isNull = false;
- for (int idx = 0; idx < ${arguments.length}; idx++) {
- if ($argIsNulls[idx]) { $isNull = true; break; }
- }
- """
- } else {
- isNull = "false"
- ""
- }
+ ev.isNull = resultIsNull
val constructorCall = outer.map { gen =>
- s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})"""
+ s"${gen.value}.new ${cls.getSimpleName}($argString)"
}.getOrElse {
- s"new $className(${argValues.mkString(", ")})"
+ s"new $className($argString)"
}
val code = s"""
$argCode
${outer.map(_.code).getOrElse("")}
- $setIsNull
- final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall;
- """
- ev.copy(code = code, isNull = isNull)
+ final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall;
+ """
+ ev.copy(code = code)
}
override def toString: String = s"newInstance($cls)"