diff options
-rw-r--r-- | sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala | 163 |
1 files changed, 101 insertions, 62 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 0e3d99127e..0b36091ece 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -33,6 +33,78 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData} import org.apache.spark.sql.types._ /** + * Common base class for [[StaticInvoke]], [[Invoke]], and [[NewInstance]]. + */ +trait InvokeLike extends Expression with NonSQLExpression { + + def arguments: Seq[Expression] + + def propagateNull: Boolean + + protected lazy val needNullCheck: Boolean = propagateNull && arguments.exists(_.nullable) + + /** + * Prepares codes for arguments. + * + * - generate codes for argument. + * - use ctx.splitExpressions() to not exceed 64kb JVM limit while preparing arguments. + * - avoid some of nullabilty checking which are not needed because the expression is not + * nullable. + * - when needNullCheck == true, short circuit if we found one of arguments is null because + * preparing rest of arguments can be skipped in the case. + * + * @param ctx a [[CodegenContext]] + * @return (code to prepare arguments, argument string, result of argument null check) + */ + def prepareArguments(ctx: CodegenContext): (String, String, String) = { + + val resultIsNull = if (needNullCheck) { + val resultIsNull = ctx.freshName("resultIsNull") + ctx.addMutableState("boolean", resultIsNull, "") + resultIsNull + } else { + "false" + } + val argValues = arguments.map { e => + val argValue = ctx.freshName("argValue") + ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") + argValue + } + + val argCodes = if (needNullCheck) { + val reset = s"$resultIsNull = false;" + val argCodes = arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + val updateResultIsNull = if (e.nullable) { + s"$resultIsNull = ${expr.isNull};" + } else { + "" + } + s""" + if (!$resultIsNull) { + ${expr.code} + $updateResultIsNull + ${argValues(i)} = ${expr.value}; + } + """ + } + reset +: argCodes + } else { + arguments.zipWithIndex.map { case (e, i) => + val expr = e.genCode(ctx) + s""" + ${expr.code} + ${argValues(i)} = ${expr.value}; + """ + } + } + val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) + + (argCode, argValues.mkString(", "), resultIsNull) + } +} + +/** * Invokes a static function, returning the result. By default, any of the arguments being null * will result in returning null instead of calling the function. * @@ -50,7 +122,7 @@ case class StaticInvoke( dataType: DataType, functionName: String, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression with NonSQLExpression { + propagateNull: Boolean = true) extends InvokeLike { val objectName = staticObject.getName.stripSuffix("$") @@ -62,16 +134,10 @@ case class StaticInvoke( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") - val callFunc = s"$objectName.$functionName($argString)" + val (argCode, argString, resultIsNull) = prepareArguments(ctx) - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${argGen.map(_.isNull).mkString(" || ")};" - } else { - s"boolean ${ev.isNull} = false;" - } + val callFunc = s"$objectName.$functionName($argString)" // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. @@ -82,9 +148,9 @@ case class StaticInvoke( } val code = s""" - ${argGen.map(_.code).mkString("\n")} - $setIsNull - final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : $callFunc; + $argCode + boolean ${ev.isNull} = $resultIsNull; + final $javaType ${ev.value} = $resultIsNull ? ${ctx.defaultValue(dataType)} : $callFunc; $postNullCheck """ ev.copy(code = code) @@ -103,13 +169,15 @@ case class StaticInvoke( * @param functionName The name of the method to call. * @param dataType The expected return type of the function. * @param arguments An optional list of expressions, whos evaluation will be passed to the function. + * @param propagateNull When true, and any of the arguments is null, null will be returned instead + * of calling the function. */ case class Invoke( targetObject: Expression, functionName: String, dataType: DataType, arguments: Seq[Expression] = Nil, - propagateNull: Boolean = true) extends Expression with NonSQLExpression { + propagateNull: Boolean = true) extends InvokeLike { override def nullable: Boolean = true override def children: Seq[Expression] = targetObject +: arguments @@ -131,8 +199,8 @@ case class Invoke( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val obj = targetObject.genCode(ctx) - val argGen = arguments.map(_.genCode(ctx)) - val argString = argGen.map(_.value).mkString(", ") + + val (argCode, argString, resultIsNull) = prepareArguments(ctx) val returnPrimitive = method.isDefined && method.get.getReturnType.isPrimitive val needTryCatch = method.isDefined && method.get.getExceptionTypes.nonEmpty @@ -164,12 +232,6 @@ case class Invoke( """ } - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s"boolean ${ev.isNull} = ${obj.isNull} || ${argGen.map(_.isNull).mkString(" || ")};" - } else { - s"boolean ${ev.isNull} = ${obj.isNull};" - } - // If the function can return null, we do an extra check to make sure our null bit is still set // correctly. val postNullCheck = if (ctx.defaultValue(dataType) == "null") { @@ -177,15 +239,19 @@ case class Invoke( } else { "" } + val code = s""" ${obj.code} - ${argGen.map(_.code).mkString("\n")} - $setIsNull + boolean ${ev.isNull} = true; $javaType ${ev.value} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - $evaluate + if (!${obj.isNull}) { + $argCode + ${ev.isNull} = $resultIsNull; + if (!${ev.isNull}) { + $evaluate + } + $postNullCheck } - $postNullCheck """ ev.copy(code = code) } @@ -223,10 +289,10 @@ case class NewInstance( arguments: Seq[Expression], propagateNull: Boolean, dataType: DataType, - outerPointer: Option[() => AnyRef]) extends Expression with NonSQLExpression { + outerPointer: Option[() => AnyRef]) extends InvokeLike { private val className = cls.getName - override def nullable: Boolean = propagateNull + override def nullable: Boolean = needNullCheck override def children: Seq[Expression] = arguments @@ -245,52 +311,25 @@ case class NewInstance( override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) - val argIsNulls = ctx.freshName("argIsNulls") - ctx.addMutableState("boolean[]", argIsNulls, - s"$argIsNulls = new boolean[${arguments.size}];") - val argValues = arguments.zipWithIndex.map { case (e, i) => - val argValue = ctx.freshName("argValue") - ctx.addMutableState(ctx.javaType(e.dataType), argValue, "") - argValue - } - val argCodes = arguments.zipWithIndex.map { case (e, i) => - val expr = e.genCode(ctx) - expr.code + s""" - $argIsNulls[$i] = ${expr.isNull}; - ${argValues(i)} = ${expr.value}; - """ - } - val argCode = ctx.splitExpressions(ctx.INPUT_ROW, argCodes) + val (argCode, argString, resultIsNull) = prepareArguments(ctx) val outer = outerPointer.map(func => Literal.fromObject(func()).genCode(ctx)) - var isNull = ev.isNull - val setIsNull = if (propagateNull && arguments.nonEmpty) { - s""" - boolean $isNull = false; - for (int idx = 0; idx < ${arguments.length}; idx++) { - if ($argIsNulls[idx]) { $isNull = true; break; } - } - """ - } else { - isNull = "false" - "" - } + ev.isNull = resultIsNull val constructorCall = outer.map { gen => - s"""${gen.value}.new ${cls.getSimpleName}(${argValues.mkString(", ")})""" + s"${gen.value}.new ${cls.getSimpleName}($argString)" }.getOrElse { - s"new $className(${argValues.mkString(", ")})" + s"new $className($argString)" } val code = s""" $argCode ${outer.map(_.code).getOrElse("")} - $setIsNull - final $javaType ${ev.value} = $isNull ? ${ctx.defaultValue(javaType)} : $constructorCall; - """ - ev.copy(code = code, isNull = isNull) + final $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(javaType)} : $constructorCall; + """ + ev.copy(code = code) } override def toString: String = s"newInstance($cls)" |