From 4eae1dbd7c4ec96727f92dd52e3fb9b26b0ec883 Mon Sep 17 00:00:00 2001 From: Sameer Agarwal Date: Mon, 18 Apr 2016 20:28:22 -0700 Subject: [SPARK-14718][SQL] Avoid mutating ExprCode in doGenCode ## What changes were proposed in this pull request? The `doGenCode` method currently takes in an `ExprCode`, mutates it and returns the java code to evaluate the given expression. It should instead just return a new `ExprCode` to avoid passing around mutable objects during code generation. ## How was this patch tested? Existing Tests Author: Sameer Agarwal Closes #12483 from sameeragarwal/new-exprcode-2. --- .../spark/sql/catalyst/analysis/unresolved.scala | 2 +- .../sql/catalyst/expressions/BoundAttribute.scala | 14 +-- .../spark/sql/catalyst/expressions/Cast.scala | 6 +- .../sql/catalyst/expressions/Expression.scala | 54 +++++------ .../sql/catalyst/expressions/InputFileName.scala | 7 +- .../expressions/MonotonicallyIncreasingID.scala | 8 +- .../expressions/ReferenceToExpressions.scala | 8 +- .../spark/sql/catalyst/expressions/ScalaUDF.scala | 7 +- .../spark/sql/catalyst/expressions/SortOrder.scala | 18 ++-- .../catalyst/expressions/SparkPartitionID.scala | 5 +- .../sql/catalyst/expressions/TimeWindow.scala | 6 +- .../sql/catalyst/expressions/arithmetic.scala | 52 +++++----- .../catalyst/expressions/bitwiseExpressions.scala | 2 +- .../expressions/codegen/CodegenFallback.scala | 12 +-- .../expressions/collectionOperations.scala | 4 +- .../catalyst/expressions/complexTypeCreator.scala | 56 +++++------ .../expressions/complexTypeExtractors.scala | 8 +- .../expressions/conditionalExpressions.scala | 28 +++--- .../catalyst/expressions/datetimeExpressions.scala | 106 ++++++++++----------- .../catalyst/expressions/decimalExpressions.scala | 8 +- .../spark/sql/catalyst/expressions/literals.scala | 16 ++-- .../sql/catalyst/expressions/mathExpressions.scala | 46 +++++---- .../spark/sql/catalyst/expressions/misc.scala | 26 +++-- .../catalyst/expressions/namedExpressions.scala | 2 +- .../sql/catalyst/expressions/nullExpressions.scala | 42 ++++---- .../spark/sql/catalyst/expressions/objects.scala | 79 +++++++-------- .../sql/catalyst/expressions/predicates.scala | 49 +++++----- .../catalyst/expressions/randomExpressions.scala | 16 ++-- .../catalyst/expressions/regexpExpressions.scala | 26 ++--- .../catalyst/expressions/stringExpressions.scala | 85 ++++++++--------- .../catalyst/expressions/NonFoldableLiteral.scala | 2 +- .../org/apache/spark/sql/execution/subquery.scala | 2 +- 32 files changed, 362 insertions(+), 440 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala index 90b7b60b1c..e83008e86e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/unresolved.scala @@ -153,7 +153,7 @@ case class UnresolvedGenerator(name: String, children: Seq[Expression]) extends override def eval(input: InternalRow = null): TraversableOnce[InternalRow] = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") override def terminate(): TraversableOnce[InternalRow] = diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index cf23884c44..99f156a935 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -58,7 +58,7 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val value = ctx.getValue(ctx.INPUT_ROW, dataType, ordinal.toString) if (ctx.currentVars != null && ctx.currentVars(ordinal) != null) { @@ -67,17 +67,13 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) ev.value = oev.value val code = oev.code oev.code = "" - code + ev.copy(code = code) } else if (nullable) { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = ${ctx.INPUT_ROW}.isNullAt($ordinal); - $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value); - """ + $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($value);""") } else { - ev.isNull = "false" - s""" - $javaType ${ev.value} = $value; - """ + ev.copy(code = s"""$javaType ${ev.value} = $value;""", isNull = "false") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index ffb100ee54..b1e89b5de8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -446,11 +446,11 @@ case class Cast(child: Expression, dataType: DataType) extends UnaryExpression w protected override def nullSafeEval(input: Any): Any = cast(input) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx) - eval.code + - castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast) + ev.copy(code = eval.code + + castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)) } // three function arguments are: child.primitive, result.primitive and result.isNull diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 11e3fd78d0..7dacdafb71 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -101,9 +101,8 @@ abstract class Expression extends TreeNode[Expression] { }.getOrElse { val isNull = ctx.freshName("isNull") val value = ctx.freshName("value") - val ve = ExprCode("", isNull, value) - ve.code = doGenCode(ctx, ve) - if (ve.code != "") { + val ve = doGenCode(ctx, ExprCode("", isNull, value)) + if (ve.code.nonEmpty) { // Add `this` in the comment. ve.copy(s"/* ${toCommentSafeString(this.toString)} */\n" + ve.code.trim) } else { @@ -119,9 +118,9 @@ abstract class Expression extends TreeNode[Expression] { * * @param ctx a [[CodegenContext]] * @param ev an [[ExprCode]] with unique terms. - * @return Java source code + * @return an [[ExprCode]] containing the Java source code to generate the given expression */ - protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode /** * Returns `true` if this expression and all its children have been resolved to a specific schema @@ -216,7 +215,7 @@ trait Unevaluable extends Expression { final override def eval(input: InternalRow = null): Any = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") - final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = + final override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = throw new UnsupportedOperationException(s"Cannot evaluate expression: $this") } @@ -316,7 +315,7 @@ abstract class UnaryExpression extends Expression { protected def defineCodeGen( ctx: CodegenContext, ev: ExprCode, - f: String => String): String = { + f: String => String): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { s"${ev.value} = ${f(eval)};" }) @@ -332,25 +331,23 @@ abstract class UnaryExpression extends Expression { protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, - f: String => String): String = { + f: String => String): ExprCode = { val childGen = child.genCode(ctx) val resultCode = f(childGen.value) if (nullable) { val nullSafeEval = ctx.nullSafeExec(child.nullable, childGen.isNull)(resultCode) - s""" + ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $nullSafeEval - """ + """) } else { - ev.isNull = "false" - s""" + ev.copy(code = s""" ${childGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode - """ + $resultCode""", isNull = "false") } } } @@ -406,7 +403,7 @@ abstract class BinaryExpression extends Expression { protected def defineCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String) => String): String = { + f: (String, String) => String): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s"${ev.value} = ${f(eval1, eval2)};" }) @@ -423,7 +420,7 @@ abstract class BinaryExpression extends Expression { protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String) => String): String = { + f: (String, String) => String): ExprCode = { val leftGen = left.genCode(ctx) val rightGen = right.genCode(ctx) val resultCode = f(leftGen.value, rightGen.value) @@ -439,19 +436,17 @@ abstract class BinaryExpression extends Expression { } } - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; $nullSafeEval - """ + """) } else { - ev.isNull = "false" - s""" + ev.copy(code = s""" ${leftGen.code} ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode - """ + $resultCode""", isNull = "false") } } } @@ -548,7 +543,7 @@ abstract class TernaryExpression extends Expression { protected def defineCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String, String) => String): String = { + f: (String, String, String) => String): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { s"${ev.value} = ${f(eval1, eval2, eval3)};" }) @@ -565,7 +560,7 @@ abstract class TernaryExpression extends Expression { protected def nullSafeCodeGen( ctx: CodegenContext, ev: ExprCode, - f: (String, String, String) => String): String = { + f: (String, String, String) => String): ExprCode = { val leftGen = children(0).genCode(ctx) val midGen = children(1).genCode(ctx) val rightGen = children(2).genCode(ctx) @@ -584,20 +579,17 @@ abstract class TernaryExpression extends Expression { } } - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $nullSafeEval - """ + $nullSafeEval""") } else { - ev.isNull = "false" - s""" + ev.copy(code = s""" ${leftGen.code} ${midGen.code} ${rightGen.code} ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $resultCode - """ + $resultCode""", isNull = "false") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala index 144efb751b..96929ecf56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/InputFileName.scala @@ -43,9 +43,8 @@ case class InputFileName() extends LeafExpression with Nondeterministic { InputFileNameHolder.getInputFileName() } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { - ev.isNull = "false" - s"final ${ctx.javaType(dataType)} ${ev.value} = " + - "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();" + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = " + + "org.apache.spark.rdd.InputFileNameHolder.getInputFileName();", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala index 9d3e80cad6..75c6bb2d84 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/MonotonicallyIncreasingID.scala @@ -65,18 +65,16 @@ private[sql] case class MonotonicallyIncreasingID() extends LeafExpression with partitionMask + currentCount } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val countTerm = ctx.freshName("count") val partitionMaskTerm = ctx.freshName("partitionMask") ctx.addMutableState(ctx.JAVA_LONG, countTerm, s"$countTerm = 0L;") ctx.addMutableState(ctx.JAVA_LONG, partitionMaskTerm, s"$partitionMaskTerm = ((long) org.apache.spark.TaskContext.getPartitionId()) << 33;") - ev.isNull = "false" - s""" + ev.copy(code = s""" final ${ctx.javaType(dataType)} ${ev.value} = $partitionMaskTerm + $countTerm; - $countTerm++; - """ + $countTerm++;""", isNull = "false") } override def prettyName: String = "monotonically_increasing_id" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala index 98710f8d78..c4cc6c39b0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ReferenceToExpressions.scala @@ -59,7 +59,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) result.eval(projection(input)) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childrenGen = children.map(_.genCode(ctx)) val childrenVars = childrenGen.zip(children).map { case (childGen, child) => LambdaVariable(childGen.value, childGen.isNull, child.dataType) @@ -69,9 +69,7 @@ case class ReferenceToExpressions(result: Expression, children: Seq[Expression]) case b: BoundReference => childrenVars(b.ordinal) }.genCode(ctx) - ev.value = resultGen.value - ev.isNull = resultGen.isNull - - childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code + ExprCode(code = childrenGen.map(_.code).mkString("\n") + "\n" + resultGen.code, + isNull = resultGen.isNull, value = resultGen.value) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala index 1b19cdbadd..0038cf65e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ScalaUDF.scala @@ -991,7 +991,7 @@ case class ScalaUDF( override def doGenCode( ctx: CodegenContext, - ev: ExprCode): String = { + ev: ExprCode): ExprCode = { ctx.references += this @@ -1042,7 +1042,7 @@ case class ScalaUDF( s"(${ctx.boxedType(dataType)})${catalystConverterTerm}" + s".apply($funcTerm.apply(${funcArguments.mkString(", ")}));" - s""" + ev.copy(code = s""" $evalCode ${converters.mkString("\n")} $callFunc @@ -1051,8 +1051,7 @@ case class ScalaUDF( ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $resultTerm; - } - """ + }""") } private[this] val converter = CatalystTypeConverters.createToCatalystConverter(dataType) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala index beced2c646..e0c3b22a3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SortOrder.scala @@ -70,7 +70,7 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { override def eval(input: InternalRow): Any = throw new UnsupportedOperationException - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childCode = child.child.genCode(ctx) val input = childCode.value val BinaryPrefixCmp = classOf[BinaryPrefixComparator].getName @@ -104,14 +104,14 @@ case class SortPrefix(child: SortOrder) extends UnaryExpression { case _ => (0L, "0L") } - childCode.code + - s""" - |long ${ev.value} = ${nullValue}L; - |boolean ${ev.isNull} = false; - |if (!${childCode.isNull}) { - | ${ev.value} = $prefixCode; - |} - """.stripMargin + ev.copy(code = childCode.code + + s""" + |long ${ev.value} = ${nullValue}L; + |boolean ${ev.isNull} = false; + |if (!${childCode.isNull}) { + | ${ev.value} = $prefixCode; + |} + """.stripMargin) } override def dataType: DataType = LongType diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala index 8ca168a85b..71af59a7a8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/SparkPartitionID.scala @@ -44,11 +44,10 @@ private[sql] case class SparkPartitionID() extends LeafExpression with Nondeterm override protected def evalInternal(input: InternalRow): Int = partitionId - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val idTerm = ctx.freshName("partitionId") ctx.addMutableState(ctx.JAVA_INT, idTerm, s"$idTerm = org.apache.spark.TaskContext.getPartitionId();") - ev.isNull = "false" - s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;" + ev.copy(code = s"final ${ctx.javaType(dataType)} ${ev.value} = $idTerm;", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala index 46cbd12496..83fa447cf8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/TimeWindow.scala @@ -158,11 +158,11 @@ object TimeWindow { case class PreciseTimestamp(child: Expression) extends UnaryExpression with ExpectsInputTypes { override def inputTypes: Seq[AbstractDataType] = Seq(TimestampType) override def dataType: DataType = LongType - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - eval.code + + ev.copy(code = eval.code + s"""boolean ${ev.isNull} = ${eval.isNull}; |${ctx.javaType(dataType)} ${ev.value} = ${eval.value}; - """.stripMargin + """.stripMargin) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala index 25806c547b..b2df79a588 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala @@ -36,7 +36,7 @@ case class UnaryMinus(child: Expression) extends UnaryExpression private lazy val numeric = TypeUtils.getNumeric(dataType) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.unary_$$minus()") case dt: NumericType => nullSafeCodeGen(ctx, ev, eval => { val originValue = ctx.freshName("origin") @@ -70,7 +70,7 @@ case class UnaryPositive(child: Expression) override def dataType: DataType = child.dataType - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = defineCodeGen(ctx, ev, c => c) protected override def nullSafeEval(input: Any): Any = input @@ -93,7 +93,7 @@ case class Abs(child: Expression) private lazy val numeric = TypeUtils.getNumeric(dataType) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, c => s"$c.abs()") case dt: NumericType => @@ -113,7 +113,7 @@ abstract class BinaryArithmetic extends BinaryOperator { def decimalMethod: String = sys.error("BinaryArithmetics must override either decimalMethod or genCode") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$decimalMethod($eval2)") // byte and short are casted into int when add, minus, times or divide @@ -147,7 +147,7 @@ case class Add(left: Expression, right: Expression) extends BinaryArithmetic wit } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$plus($eval2)") case ByteType | ShortType => @@ -179,7 +179,7 @@ case class Subtract(left: Expression, right: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = dataType match { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match { case dt: DecimalType => defineCodeGen(ctx, ev, (eval1, eval2) => s"$eval1.$$minus($eval2)") case ByteType | ShortType => @@ -241,7 +241,7 @@ case class Divide(left: Expression, right: Expression) /** * Special case handling due to division by 0 => null. */ - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { @@ -256,7 +256,7 @@ case class Divide(left: Expression, right: Expression) s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - s""" + ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; @@ -265,10 +265,9 @@ case class Divide(left: Expression, right: Expression) } else { ${eval1.code} ${ev.value} = $divide; - } - """ + }""") } else { - s""" + ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; @@ -281,8 +280,7 @@ case class Divide(left: Expression, right: Expression) } else { ${ev.value} = $divide; } - } - """ + }""") } } } @@ -320,7 +318,7 @@ case class Remainder(left: Expression, right: Expression) /** * Special case handling for x % 0 ==> null. */ - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val isZero = if (dataType.isInstanceOf[DecimalType]) { @@ -335,7 +333,7 @@ case class Remainder(left: Expression, right: Expression) s"($javaType)(${eval1.value} $symbol ${eval2.value})" } if (!left.nullable && !right.nullable) { - s""" + ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; @@ -344,10 +342,9 @@ case class Remainder(left: Expression, right: Expression) } else { ${eval1.code} ${ev.value} = $remainder; - } - """ + }""") } else { - s""" + ev.copy(code = s""" ${eval2.code} boolean ${ev.isNull} = false; $javaType ${ev.value} = ${ctx.defaultValue(javaType)}; @@ -360,8 +357,7 @@ case class Remainder(left: Expression, right: Expression) } else { ${ev.value} = $remainder; } - } - """ + }""") } } } @@ -393,12 +389,12 @@ case class MaxOf(left: Expression, right: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val compCode = ctx.genComp(dataType, eval1.value, eval2.value) - eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; ${ctx.javaType(left.dataType)} ${ev.value} = ${ctx.defaultValue(left.dataType)}; @@ -415,8 +411,7 @@ case class MaxOf(left: Expression, right: Expression) } else { ${ev.value} = ${eval2.value}; } - } - """ + }""") } override def symbol: String = "max" @@ -449,12 +444,12 @@ case class MinOf(left: Expression, right: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val compCode = ctx.genComp(dataType, eval1.value, eval2.value) - eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.isNull} = false; ${ctx.javaType(left.dataType)} ${ev.value} = ${ctx.defaultValue(left.dataType)}; @@ -471,8 +466,7 @@ case class MinOf(left: Expression, right: Expression) } else { ${ev.value} = ${eval2.value}; } - } - """ + }""") } override def symbol: String = "min" @@ -503,7 +497,7 @@ case class Pmod(left: Expression, right: Expression) extends BinaryArithmetic wi case _: DecimalType => pmod(left.asInstanceOf[Decimal], right.asInstanceOf[Decimal]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { dataType match { case dt: DecimalType => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala index 8fd8a9bd4e..3a0a882e38 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/bitwiseExpressions.scala @@ -130,7 +130,7 @@ case class BitwiseNot(child: Expression) extends UnaryExpression with ExpectsInp ((evalE: Long) => ~evalE).asInstanceOf[(Any) => Any] } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"(${ctx.javaType(dataType)}) ~($c)") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala index 1e446c498d..2bd77c65c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenFallback.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.util.toCommentSafeString */ trait CodegenFallback extends Expression { - protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { foreach { case n: Nondeterministic => n.setInitialValues() case _ => @@ -37,22 +37,20 @@ trait CodegenFallback extends Expression { ctx.references += this val objectTerm = ctx.freshName("obj") if (nullable) { - s""" + ev.copy(code = s""" /* expression: ${toCommentSafeString(this.toString)} */ Object $objectTerm = ((Expression) references[$idx]).eval($input); boolean ${ev.isNull} = $objectTerm == null; ${ctx.javaType(this.dataType)} ${ev.value} = ${ctx.defaultValue(this.dataType)}; if (!${ev.isNull}) { ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - } - """ + }""") } else { - ev.isNull = "false" - s""" + ev.copy(code = s""" /* expression: ${toCommentSafeString(this.toString)} */ Object $objectTerm = ((Expression) references[$idx]).eval($input); ${ctx.javaType(this.dataType)} ${ev.value} = (${ctx.boxedType(this.dataType)}) $objectTerm; - """ + """, isNull = "false") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 8cb691c9b1..864288394e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -37,7 +37,7 @@ case class Size(child: Expression) extends UnaryExpression with ExpectsInputType case _: MapType => value.asInstanceOf[MapData].numElements() } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => s"${ev.value} = ($c).numElements();") } } @@ -180,7 +180,7 @@ case class ArrayContains(left: Expression, right: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (arr, value) => { val i = ctx.freshName("i") val getValue = ctx.getValue(arr, right.dataType, i) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index a7a59d8784..3d4819c55a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -48,13 +48,12 @@ case class CreateArray(children: Seq[Expression]) extends Expression { new GenericArrayData(children.map(_.eval(input)).toArray) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName val values = ctx.freshName("values") - s""" + ev.copy(code = s""" final boolean ${ev.isNull} = false; - final Object[] $values = new Object[${children.size}]; - """ + + final Object[] $values = new Object[${children.size}];""" + children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) eval.code + s""" @@ -65,7 +64,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { } """ }.mkString("\n") + - s"final ArrayData ${ev.value} = new $arrayClass($values);" + s"final ArrayData ${ev.value} = new $arrayClass($values);") } override def prettyName: String = "array" @@ -115,19 +114,18 @@ case class CreateMap(children: Seq[Expression]) extends Expression { new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName val mapClass = classOf[ArrayBasedMapData].getName val keyArray = ctx.freshName("keyArray") val valueArray = ctx.freshName("valueArray") val keyData = s"new $arrayClass($keyArray)" val valueData = s"new $arrayClass($valueArray)" - s""" + ev.copy(code = s""" final boolean ${ev.isNull} = false; final Object[] $keyArray = new Object[${keys.size}]; - final Object[] $valueArray = new Object[${values.size}]; - """ + keys.zipWithIndex.map { - case (key, i) => + final Object[] $valueArray = new Object[${values.size}];""" + + keys.zipWithIndex.map { case (key, i) => val eval = key.genCode(ctx) s""" ${eval.code} @@ -148,7 +146,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { $valueArray[$i] = ${eval.value}; } """ - }.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);" + }.mkString("\n") + s"final MapData ${ev.value} = new $mapClass($keyData, $valueData);") } override def prettyName: String = "map" @@ -181,13 +179,12 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - s""" + ev.copy(code = s""" boolean ${ev.isNull} = false; - final Object[] $values = new Object[${children.size}]; - """ + + final Object[] $values = new Object[${children.size}];""" + children.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) eval.code + s""" @@ -195,10 +192,9 @@ case class CreateStruct(children: Seq[Expression]) extends Expression { $values[$i] = null; } else { $values[$i] = ${eval.value}; - } - """ + }""" }.mkString("\n") + - s"final InternalRow ${ev.value} = new $rowClass($values);" + s"final InternalRow ${ev.value} = new $rowClass($values);") } override def prettyName: String = "struct" @@ -262,13 +258,12 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { InternalRow(valExprs.map(_.eval(input)): _*) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericInternalRow].getName val values = ctx.freshName("values") - s""" + ev.copy(code = s""" boolean ${ev.isNull} = false; - final Object[] $values = new Object[${valExprs.size}]; - """ + + final Object[] $values = new Object[${valExprs.size}];""" + valExprs.zipWithIndex.map { case (e, i) => val eval = e.genCode(ctx) eval.code + s""" @@ -276,10 +271,9 @@ case class CreateNamedStruct(children: Seq[Expression]) extends Expression { $values[$i] = null; } else { $values[$i] = ${eval.value}; - } - """ + }""" }.mkString("\n") + - s"final InternalRow ${ev.value} = new $rowClass($values);" + s"final InternalRow ${ev.value} = new $rowClass($values);") } override def prettyName: String = "named_struct" @@ -314,11 +308,9 @@ case class CreateStructUnsafe(children: Seq[Expression]) extends Expression { InternalRow(children.map(_.eval(input)): _*) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, children) - ev.isNull = eval.isNull - ev.value = eval.value - eval.code + ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) } override def prettyName: String = "struct_unsafe" @@ -354,11 +346,9 @@ case class CreateNamedStructUnsafe(children: Seq[Expression]) extends Expression InternalRow(valExprs.map(_.eval(input)): _*) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = GenerateUnsafeProjection.createCode(ctx, valExprs) - ev.isNull = eval.isNull - ev.value = eval.value - eval.code + ExprCode(code = eval.code, isNull = eval.isNull, value = eval.value) } override def prettyName: String = "named_struct_unsafe" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala index b5ff9f55d5..3b4468f55c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeExtractors.scala @@ -122,7 +122,7 @@ case class GetStructField(child: Expression, ordinal: Int, name: Option[String] protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[InternalRow].get(ordinal, childSchema(ordinal).dataType) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { if (nullable) { s""" @@ -179,7 +179,7 @@ case class GetArrayStructFields( new GenericArrayData(result) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, eval => { val n = ctx.freshName("n") @@ -239,7 +239,7 @@ case class GetArrayItem(child: Expression, ordinal: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { val index = ctx.freshName("index") s""" @@ -302,7 +302,7 @@ case class GetMapValue(child: Expression, key: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = ctx.freshName("index") val length = ctx.freshName("length") val keys = ctx.freshName("keys") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala index a4c800a26c..336649c0fd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/conditionalExpressions.scala @@ -55,12 +55,12 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val condEval = predicate.genCode(ctx) val trueEval = trueValue.genCode(ctx) val falseEval = falseValue.genCode(ctx) - s""" + ev.copy(code = s""" ${condEval.code} boolean ${ev.isNull} = false; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -72,8 +72,7 @@ case class If(predicate: Expression, trueValue: Expression, falseValue: Expressi ${falseEval.code} ${ev.isNull} = ${falseEval.isNull}; ${ev.value} = ${falseEval.value}; - } - """ + }""") } override def toString: String = s"if ($predicate) $trueValue else $falseValue" @@ -147,7 +146,7 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E branches.length < CaseWhen.MAX_NUM_CASES_FOR_CODEGEN } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (!shouldCodegen) { // Fallback to interpreted mode if there are too many branches, as it may reach the // 64K limit (limit on bytecode size for a single function). @@ -198,11 +197,10 @@ case class CaseWhen(branches: Seq[(Expression, Expression)], elseValue: Option[E generatedCode += "}\n" * cases.size - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - $generatedCode - """ + $generatedCode""") } override def toString: String = { @@ -298,7 +296,7 @@ case class Least(children: Seq[Expression]) extends Expression { }) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) @@ -312,12 +310,11 @@ case class Least(children: Seq[Expression]) extends Expression { } """ } - s""" + ev.copy(code = s""" ${first.code} boolean ${ev.isNull} = ${first.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; - ${rest.map(updateEval).mkString("\n")} - """ + ${rest.map(updateEval).mkString("\n")}""") } } @@ -359,7 +356,7 @@ case class Greatest(children: Seq[Expression]) extends Expression { }) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evalChildren = children.map(_.genCode(ctx)) val first = evalChildren(0) val rest = evalChildren.drop(1) @@ -373,12 +370,11 @@ case class Greatest(children: Seq[Expression]) extends Expression { } """ } - s""" + ev.copy(code = s""" ${first.code} boolean ${ev.isNull} = ${first.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${first.value}; - ${rest.map(updateEval).mkString("\n")} - """ + ${rest.map(updateEval).mkString("\n")}""") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala index 18649a39cb..69c32f447e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/datetimeExpressions.scala @@ -91,7 +91,7 @@ case class DateAdd(startDate: Expression, days: Expression) start.asInstanceOf[Int] + d.asInstanceOf[Int] } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd + $d;""" }) @@ -119,7 +119,7 @@ case class DateSub(startDate: Expression, days: Expression) start.asInstanceOf[Int] - d.asInstanceOf[Int] } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (sd, d) => { s"""${ev.value} = $sd - $d;""" }) @@ -141,7 +141,7 @@ case class Hour(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getHours(timestamp.asInstanceOf[Long]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getHours($c)") } @@ -160,7 +160,7 @@ case class Minute(child: Expression) extends UnaryExpression with ImplicitCastIn DateTimeUtils.getMinutes(timestamp.asInstanceOf[Long]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMinutes($c)") } @@ -179,7 +179,7 @@ case class Second(child: Expression) extends UnaryExpression with ImplicitCastIn DateTimeUtils.getSeconds(timestamp.asInstanceOf[Long]) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getSeconds($c)") } @@ -198,7 +198,7 @@ case class DayOfYear(child: Expression) extends UnaryExpression with ImplicitCas DateTimeUtils.getDayInYear(date.asInstanceOf[Int]) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayInYear($c)") } @@ -217,7 +217,7 @@ case class Year(child: Expression) extends UnaryExpression with ImplicitCastInpu DateTimeUtils.getYear(date.asInstanceOf[Int]) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getYear($c)") } @@ -235,7 +235,7 @@ case class Quarter(child: Expression) extends UnaryExpression with ImplicitCastI DateTimeUtils.getQuarter(date.asInstanceOf[Int]) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getQuarter($c)") } @@ -254,7 +254,7 @@ case class Month(child: Expression) extends UnaryExpression with ImplicitCastInp DateTimeUtils.getMonth(date.asInstanceOf[Int]) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getMonth($c)") } @@ -273,7 +273,7 @@ case class DayOfMonth(child: Expression) extends UnaryExpression with ImplicitCa DateTimeUtils.getDayOfMonth(date.asInstanceOf[Int]) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, c => s"$dtu.getDayOfMonth($c)") } @@ -300,7 +300,7 @@ case class WeekOfYear(child: Expression) extends UnaryExpression with ImplicitCa c.get(Calendar.WEEK_OF_YEAR) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, time => { val cal = classOf[Calendar].getName val c = ctx.freshName("cal") @@ -335,7 +335,7 @@ case class DateFormatClass(left: Expression, right: Expression) extends BinaryEx UTF8String.fromString(sdf.format(new java.util.Date(timestamp.asInstanceOf[Long] / 1000))) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val sdf = classOf[SimpleDateFormat].getName defineCodeGen(ctx, ev, (timestamp, format) => { s"""UTF8String.fromString((new $sdf($format.toString())) @@ -430,20 +430,19 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { left.dataType match { case StringType if right.foldable => val sdf = classOf[SimpleDateFormat].getName val fString = if (constFormat == null) null else constFormat.toString val formatter = ctx.freshName("formatter") if (fString == null) { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") } else { val eval1 = left.genCode(ctx) - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -455,8 +454,7 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { } catch (java.lang.Throwable e) { ${ev.isNull} = true; } - } - """ + }""") } case StringType => val sdf = classOf[SimpleDateFormat].getName @@ -472,25 +470,23 @@ abstract class UnixTime extends BinaryExpression with ExpectsInputTypes { }) case TimestampType => val eval1 = left.genCode(ctx) - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = ${eval1.value} / 1000000L; - } - """ + }""") case DateType => val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") val eval1 = left.genCode(ctx) - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = ${eval1.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.daysToMillis(${eval1.value}) / 1000L; - } - """ + }""") } } @@ -550,17 +546,16 @@ case class FromUnixTime(sec: Expression, format: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val sdf = classOf[SimpleDateFormat].getName if (format.foldable) { if (constFormat == null) { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") } else { val t = left.genCode(ctx) - s""" + ev.copy(code = s""" ${t.code} boolean ${ev.isNull} = ${t.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -571,8 +566,7 @@ case class FromUnixTime(sec: Expression, format: Expression) } catch (java.lang.Throwable e) { ${ev.isNull} = true; } - } - """ + }""") } } else { nullSafeCodeGen(ctx, ev, (seconds, f) => { @@ -605,7 +599,7 @@ case class LastDay(startDate: Expression) extends UnaryExpression with ImplicitC DateTimeUtils.getLastDayOfMonth(date.asInstanceOf[Int]) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, sd => s"$dtu.getLastDayOfMonth($sd)") } @@ -646,7 +640,7 @@ case class NextDay(startDate: Expression, dayOfWeek: Expression) } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (sd, dowS) => { val dateTimeUtilClass = DateTimeUtils.getClass.getName.stripSuffix("$") val dayOfWeekTerm = ctx.freshName("dayOfWeek") @@ -698,7 +692,7 @@ case class TimeAdd(start: Expression, interval: Expression) start.asInstanceOf[Long], itvl.months, itvl.microseconds) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, $i.months, $i.microseconds)""" @@ -725,21 +719,21 @@ case class FromUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() if (tz == null) { - s""" + ev.copy(code = s""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; - """.stripMargin + """.stripMargin) } else { val tzTerm = ctx.freshName("tz") val tzClass = classOf[TimeZone].getName ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") val eval = left.genCode(ctx) - s""" + ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -747,7 +741,7 @@ case class FromUTCTimestamp(left: Expression, right: Expression) | ${ev.value} = ${eval.value} + | ${tzTerm}.getOffset(${eval.value} / 1000) * 1000L; |} - """.stripMargin + """.stripMargin) } } else { defineCodeGen(ctx, ev, (timestamp, format) => { @@ -777,7 +771,7 @@ case class TimeSub(start: Expression, interval: Expression) start.asInstanceOf[Long], 0 - itvl.months, 0 - itvl.microseconds) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, i) => { s"""$dtu.timestampAddInterval($sd, 0 - $i.months, 0 - $i.microseconds)""" @@ -805,7 +799,7 @@ case class AddMonths(startDate: Expression, numMonths: Expression) DateTimeUtils.dateAddMonths(start.asInstanceOf[Int], months.asInstanceOf[Int]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (sd, m) => { s"""$dtu.dateAddMonths($sd, $m)""" @@ -835,7 +829,7 @@ case class MonthsBetween(date1: Expression, date2: Expression) DateTimeUtils.monthsBetween(t1.asInstanceOf[Long], t2.asInstanceOf[Long]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") defineCodeGen(ctx, ev, (l, r) => { s"""$dtu.monthsBetween($l, $r)""" @@ -864,21 +858,21 @@ case class ToUTCTimestamp(left: Expression, right: Expression) timezone.asInstanceOf[UTF8String].toString) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (right.foldable) { val tz = right.eval() if (tz == null) { - s""" + ev.copy(code = s""" |boolean ${ev.isNull} = true; |long ${ev.value} = 0; - """.stripMargin + """.stripMargin) } else { val tzTerm = ctx.freshName("tz") val tzClass = classOf[TimeZone].getName ctx.addMutableState(tzClass, tzTerm, s"""$tzTerm = $tzClass.getTimeZone("$tz");""") val eval = left.genCode(ctx) - s""" + ev.copy(code = s""" |${eval.code} |boolean ${ev.isNull} = ${eval.isNull}; |long ${ev.value} = 0; @@ -886,7 +880,7 @@ case class ToUTCTimestamp(left: Expression, right: Expression) | ${ev.value} = ${eval.value} - | ${tzTerm}.getOffset(${eval.value} / 1000) * 1000L; |} - """.stripMargin + """.stripMargin) } } else { defineCodeGen(ctx, ev, (timestamp, format) => { @@ -912,7 +906,7 @@ case class ToDate(child: Expression) extends UnaryExpression with ImplicitCastIn override def eval(input: InternalRow): Any = child.eval(input) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, d => d) } @@ -959,25 +953,23 @@ case class TruncDate(date: Expression, format: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val dtu = DateTimeUtils.getClass.getName.stripSuffix("$") if (format.foldable) { if (truncLevel == -1) { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") } else { val d = date.genCode(ctx) - s""" + ev.copy(code = s""" ${d.code} boolean ${ev.isNull} = ${d.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $dtu.truncDate(${d.value}, $truncLevel); - } - """ + }""") } } else { nullSafeCodeGen(ctx, ev, (dateVal, fmt) => { @@ -1013,7 +1005,7 @@ case class DateDiff(endDate: Expression, startDate: Expression) end.asInstanceOf[Int] - start.asInstanceOf[Int] } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (end, start) => s"$end - $start") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala index 5629ee1a14..fa5dea6841 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/decimalExpressions.scala @@ -34,7 +34,7 @@ case class UnscaledValue(child: Expression) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = input.asInstanceOf[Decimal].toUnscaledLong - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"$c.toUnscaledLong()") } } @@ -53,7 +53,7 @@ case class MakeDecimal(child: Expression, precision: Int, scale: Int) extends Un protected override def nullSafeEval(input: Any): Any = Decimal(input.asInstanceOf[Long], precision, scale) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { s""" ${ev.value} = (new Decimal()).setOrNull($eval, $precision, $scale); @@ -71,7 +71,7 @@ case class PromotePrecision(child: Expression) extends UnaryExpression { override def dataType: DataType = child.dataType override def eval(input: InternalRow): Any = child.eval(input) override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = "" + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") override def prettyName: String = "promote_precision" override def sql: String = child.sql } @@ -93,7 +93,7 @@ case class CheckOverflow(child: Expression, dataType: DecimalType) extends Unary } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { val tmp = ctx.freshName("tmp") s""" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala index bdadbfbbb0..e9dda588de 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala @@ -191,17 +191,17 @@ case class Literal protected (value: Any, dataType: DataType) override def eval(input: InternalRow): Any = value - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // change the isNull and primitive to consts, to inline them if (value == null) { ev.isNull = "true" - s"final ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};" + ev.copy(s"final ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};") } else { dataType match { case BooleanType => ev.isNull = "false" ev.value = value.toString - "" + ev.copy("") case FloatType => val v = value.asInstanceOf[Float] if (v.isNaN || v.isInfinite) { @@ -209,7 +209,7 @@ case class Literal protected (value: Any, dataType: DataType) } else { ev.isNull = "false" ev.value = s"${value}f" - "" + ev.copy("") } case DoubleType => val v = value.asInstanceOf[Double] @@ -218,20 +218,20 @@ case class Literal protected (value: Any, dataType: DataType) } else { ev.isNull = "false" ev.value = s"${value}D" - "" + ev.copy("") } case ByteType | ShortType => ev.isNull = "false" ev.value = s"(${ctx.javaType(dataType)})$value" - "" + ev.copy("") case IntegerType | DateType => ev.isNull = "false" ev.value = value.toString - "" + ev.copy("") case TimestampType | LongType => ev.isNull = "false" ev.value = s"${value}L" - "" + ev.copy("") // eval() version may be faster for non-primitive types case other => super[CodegenFallback].doGenCode(ctx, ev) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala index 231382e6bb..5152265152 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/mathExpressions.scala @@ -70,7 +70,7 @@ abstract class UnaryMathExpression(val f: Double => Double, name: String) // name of function in java.lang.Math def funcName: String = name.toLowerCase - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"java.lang.Math.${funcName}($c)") } } @@ -88,7 +88,7 @@ abstract class UnaryLogExpression(f: Double => Double, name: String) if (d <= yAsymptote) null else f(d) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -123,7 +123,7 @@ abstract class BinaryMathExpression(f: (Double, Double) => Double, name: String) f(input1.asInstanceOf[Double], input2.asInstanceOf[Double]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.${name.toLowerCase}($c1, $c2)") } } @@ -197,7 +197,7 @@ case class Ceil(child: Expression) extends UnaryMathExpression(math.ceil, "CEIL" case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].ceil } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -242,7 +242,7 @@ case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expre toBase.asInstanceOf[Int]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val numconv = NumberConverter.getClass.getName.stripSuffix("$") nullSafeCodeGen(ctx, ev, (num, from, to) => s""" @@ -284,7 +284,7 @@ case class Floor(child: Expression) extends UnaryMathExpression(math.floor, "FLO case DecimalType.Fixed(precision, scale) => input.asInstanceOf[Decimal].floor } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case DecimalType.Fixed(_, 0) => defineCodeGen(ctx, ev, c => s"$c") case DecimalType.Fixed(precision, scale) => @@ -346,7 +346,7 @@ case class Factorial(child: Expression) extends UnaryExpression with ImplicitCas } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, eval => { s""" if ($eval > 20 || $eval < 0) { @@ -370,7 +370,7 @@ case class Log(child: Expression) extends UnaryLogExpression(math.log, "LOG") extended = "> SELECT _FUNC_(2);\n 1.0") case class Log2(child: Expression) extends UnaryLogExpression((x: Double) => math.log(x) / math.log(2), "LOG2") { - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => s""" if ($c <= $yAsymptote) { @@ -458,7 +458,7 @@ case class Bin(child: Expression) protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(jl.Long.toBinaryString(input.asInstanceOf[Long])) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c) => s"UTF8String.fromString(java.lang.Long.toBinaryString($c))") } @@ -556,7 +556,7 @@ case class Hex(child: Expression) extends UnaryExpression with ImplicitCastInput case StringType => Hex.hex(num.asInstanceOf[UTF8String].getBytes) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s"${ev.value} = " + (child.dataType match { @@ -584,7 +584,7 @@ case class Unhex(child: Expression) extends UnaryExpression with ImplicitCastInp protected override def nullSafeEval(num: Any): Any = Hex.unhex(num.asInstanceOf[UTF8String].getBytes) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (c) => { val hex = Hex.getClass.getName.stripSuffix("$") s""" @@ -613,7 +613,7 @@ case class Atan2(left: Expression, right: Expression) math.atan2(input1.asInstanceOf[Double] + 0.0, input2.asInstanceOf[Double] + 0.0) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.atan2($c1 + 0.0, $c2 + 0.0)") } } @@ -623,7 +623,7 @@ case class Atan2(left: Expression, right: Expression) extended = "> SELECT _FUNC_(2, 3);\n 8.0") case class Pow(left: Expression, right: Expression) extends BinaryMathExpression(math.pow, "POWER") { - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"java.lang.Math.pow($c1, $c2)") } } @@ -653,7 +653,7 @@ case class ShiftLeft(left: Expression, right: Expression) } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (left, right) => s"$left << $right") } } @@ -683,7 +683,7 @@ case class ShiftRight(left: Expression, right: Expression) } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (left, right) => s"$left >> $right") } } @@ -713,7 +713,7 @@ case class ShiftRightUnsigned(left: Expression, right: Expression) } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (left, right) => s"$left >>> $right") } } @@ -753,7 +753,7 @@ case class Logarithm(left: Expression, right: Expression) if (dLeft <= 0.0 || dRight <= 0.0) null else math.log(dRight) / math.log(dLeft) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (left.isInstanceOf[EulerNumber]) { nullSafeCodeGen(ctx, ev, (c1, c2) => s""" @@ -874,7 +874,7 @@ abstract class RoundBase(child: Expression, scale: Expression, } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val ce = child.genCode(ctx) val evaluationCode = child.dataType match { @@ -937,19 +937,17 @@ abstract class RoundBase(child: Expression, scale: Expression, } if (scaleV == null) { // if scale is null, no need to eval its child at all - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};""") } else { - s""" + ev.copy(code = s""" ${ce.code} boolean ${ev.isNull} = ${ce.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { $evaluationCode - } - """ + }""") } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala index 8bef2524cc..1c0787bf92 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/misc.scala @@ -49,7 +49,7 @@ case class Md5(child: Expression) extends UnaryExpression with ImplicitCastInput protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(DigestUtils.md5Hex(input.asInstanceOf[Array[Byte]])) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.md5Hex($c))") } @@ -102,7 +102,7 @@ case class Sha2(left: Expression, right: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val digestUtils = "org.apache.commons.codec.digest.DigestUtils" nullSafeCodeGen(ctx, ev, (eval1, eval2) => { s""" @@ -147,7 +147,7 @@ case class Sha1(child: Expression) extends UnaryExpression with ImplicitCastInpu protected override def nullSafeEval(input: Any): Any = UTF8String.fromString(DigestUtils.sha1Hex(input.asInstanceOf[Array[Byte]])) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"UTF8String.fromString(org.apache.commons.codec.digest.DigestUtils.sha1Hex($c))" ) @@ -173,7 +173,7 @@ case class Crc32(child: Expression) extends UnaryExpression with ImplicitCastInp checksum.getValue } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val CRC32 = "java.util.zip.CRC32" nullSafeCodeGen(ctx, ev, value => { s""" @@ -244,7 +244,7 @@ abstract class HashExpression[E] extends Expression { protected def computeHash(value: Any, dataType: DataType, seed: E): E - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { ev.isNull = "false" val childrenHash = children.map { child => val childGen = child.genCode(ctx) @@ -253,10 +253,9 @@ abstract class HashExpression[E] extends Expression { } }.mkString("\n") - s""" + ev.copy(code = s""" ${ctx.javaType(dataType)} ${ev.value} = $seed; - $childrenHash - """ + $childrenHash""") } private def nullSafeElementHash( @@ -477,7 +476,7 @@ case class PrintToStderr(child: Expression) extends UnaryExpression { protected override def nullSafeEval(input: Any): Any = input - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => s""" | System.err.println("Result of ${child.simpleString} is " + $c); @@ -510,15 +509,12 @@ case class AssertTrue(child: Expression) extends UnaryExpression with ImplicitCa } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ev.isNull = "true" - ev.value = "null" - s"""${eval.code} + ExprCode(code = s"""${eval.code} |if (${eval.isNull} || !${eval.value}) { | throw new RuntimeException("'${child.simpleString}' is not true."); - |} - """.stripMargin + |}""".stripMargin, isNull = "true", value = "null") } override def sql: String = s"assert_true(${child.sql})" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala index b0434674c6..c083f12724 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala @@ -143,7 +143,7 @@ case class Alias(child: Expression, name: String)( /** Just a simple passthrough for code generation. */ override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx) - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = "" + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = ev.copy("") override def dataType: DataType = child.dataType override def nullable: Boolean = child.nullable diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala index d9c06e3b99..421200e147 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/nullExpressions.scala @@ -64,15 +64,14 @@ case class Coalesce(children: Seq[Expression]) extends Expression { result } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val first = children(0) val rest = children.drop(1) val firstEval = first.genCode(ctx) - s""" + ev.copy(code = s""" ${firstEval.code} boolean ${ev.isNull} = ${firstEval.isNull}; - ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value}; - """ + + ${ctx.javaType(dataType)} ${ev.value} = ${firstEval.value};""" + rest.map { e => val eval = e.genCode(ctx) s""" @@ -84,7 +83,7 @@ case class Coalesce(children: Seq[Expression]) extends Expression { } } """ - }.mkString("\n") + }.mkString("\n")) } } @@ -113,16 +112,15 @@ case class IsNaN(child: Expression) extends UnaryExpression } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) child.dataType match { case DoubleType | FloatType => - s""" + ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = false; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value}); - """ + ${ev.value} = !${eval.isNull} && Double.isNaN(${eval.value});""") } } } @@ -155,12 +153,12 @@ case class NaNvl(left: Expression, right: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val leftGen = left.genCode(ctx) val rightGen = right.genCode(ctx) left.dataType match { case DoubleType | FloatType => - s""" + ev.copy(code = s""" ${leftGen.code} boolean ${ev.isNull} = false; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -177,8 +175,7 @@ case class NaNvl(left: Expression, right: Expression) ${ev.value} = ${rightGen.value}; } } - } - """ + }""") } } } @@ -196,11 +193,9 @@ case class IsNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) == null } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ev.isNull = "false" - ev.value = eval.isNull - eval.code + ExprCode(code = eval.code, isNull = "false", value = eval.isNull) } override def sql: String = s"(${child.sql} IS NULL)" @@ -219,11 +214,9 @@ case class IsNotNull(child: Expression) extends UnaryExpression with Predicate { child.eval(input) != null } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval = child.genCode(ctx) - ev.isNull = "false" - ev.value = s"(!(${eval.isNull}))" - eval.code + ExprCode(code = eval.code, isNull = "false", value = s"(!(${eval.isNull}))") } override def sql: String = s"(${child.sql} IS NOT NULL)" @@ -259,7 +252,7 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate numNonNulls >= n } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val nonnull = ctx.freshName("nonnull") val code = children.map { e => val eval = e.genCode(ctx) @@ -284,11 +277,10 @@ case class AtLeastNNonNulls(n: Int, children: Seq[Expression]) extends Predicate """ } }.mkString("\n") - s""" + ev.copy(code = s""" int $nonnull = 0; $code boolean ${ev.isNull} = false; - boolean ${ev.value} = $nonnull >= $n; - """ + boolean ${ev.value} = $nonnull >= $n;""") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala index f5f102a578..1e418540a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects.scala @@ -59,7 +59,7 @@ case class StaticInvoke( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") @@ -72,7 +72,7 @@ case class StaticInvoke( } val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" - s""" + ev.copy(code = s""" ${argGen.map(_.code).mkString("\n")} boolean ${ev.isNull} = !$argsNonNull; @@ -82,14 +82,14 @@ case class StaticInvoke( ${ev.value} = $objectName.$functionName($argString); $objNullCheck } - """ + """) } else { - s""" + ev.copy(code = s""" ${argGen.map(_.code).mkString("\n")} $javaType ${ev.value} = $objectName.$functionName($argString); final boolean ${ev.isNull} = ${ev.value} == null; - """ + """) } } } @@ -148,7 +148,7 @@ case class Invoke( case _ => identity[String] _ } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val obj = targetObject.genCode(ctx) val argGen = arguments.map(_.genCode(ctx)) @@ -178,12 +178,12 @@ case class Invoke( """ } - s""" + ev.copy(code = s""" ${obj.code} ${argGen.map(_.code).mkString("\n")} $evaluate $objNullCheck - """ + """) } override def toString: String = s"$targetObject.$functionName" @@ -239,7 +239,7 @@ case class NewInstance( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val argGen = arguments.map(_.genCode(ctx)) val argString = argGen.map(_.value).mkString(", ") @@ -261,7 +261,7 @@ case class NewInstance( if (propagateNull && argGen.nonEmpty) { val argsNonNull = s"!(${argGen.map(_.isNull).mkString(" || ")})" - s""" + ev.copy(code = s""" $setup boolean ${ev.isNull} = true; @@ -270,14 +270,14 @@ case class NewInstance( ${ev.value} = $constructorCall; ${ev.isNull} = false; } - """ + """) } else { - s""" + ev.copy(code = s""" $setup final $javaType ${ev.value} = $constructorCall; final boolean ${ev.isNull} = false; - """ + """) } } @@ -302,17 +302,17 @@ case class UnwrapOption( override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val inputObject = child.genCode(ctx) - s""" + ev.copy(code = s""" ${inputObject.code} boolean ${ev.isNull} = ${inputObject.value} == null || ${inputObject.value}.isEmpty(); $javaType ${ev.value} = ${ev.isNull} ? ${ctx.defaultValue(dataType)} : ($javaType)${inputObject.value}.get(); - """ + """) } } @@ -335,17 +335,17 @@ case class WrapOption(child: Expression, optType: DataType) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val inputObject = child.genCode(ctx) - s""" + ev.copy(code = s""" ${inputObject.code} boolean ${ev.isNull} = false; scala.Option ${ev.value} = ${inputObject.isNull} ? scala.Option$$.MODULE$$.apply(null) : new scala.Some(${inputObject.value}); - """ + """) } } @@ -444,7 +444,7 @@ case class MapObjects private( override def dataType: DataType = ArrayType(lambdaFunction.dataType) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val javaType = ctx.javaType(dataType) val elementJavaType = ctx.javaType(loopVar.dataType) ctx.addMutableState("boolean", loopVar.isNull, "") @@ -474,7 +474,7 @@ case class MapObjects private( s"${loopVar.isNull} = ${genInputData.isNull} || ${loopVar.value} == null;" } - s""" + ev.copy(code = s""" ${genInputData.code} boolean ${ev.isNull} = ${genInputData.value} == null; @@ -504,7 +504,7 @@ case class MapObjects private( ${ev.isNull} = false; ${ev.value} = new ${classOf[GenericArrayData].getName}($convertedArray); } - """ + """) } } @@ -524,7 +524,7 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rowClass = classOf[GenericRowWithSchema].getName val values = ctx.freshName("values") ctx.addMutableState("Object[]", values, "") @@ -541,12 +541,12 @@ case class CreateExternalRow(children: Seq[Expression], schema: StructType) } val childrenCode = ctx.splitExpressions(ctx.INPUT_ROW, childrenCodes) val schemaField = ctx.addReferenceObj("schema", schema) - s""" + ev.copy(code = s""" boolean ${ev.isNull} = false; $values = new Object[${children.size}]; $childrenCode final ${classOf[Row].getName} ${ev.value} = new $rowClass($values, this.$schemaField); - """ + """) } } @@ -561,7 +561,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported") - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { @@ -579,14 +579,14 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) // Code to serialize. val input = child.genCode(ctx) - s""" + ev.copy(code = s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $serializer.serialize(${input.value}, null).array(); } - """ + """) } override def dataType: DataType = BinaryType @@ -601,7 +601,7 @@ case class EncodeUsingSerializer(child: Expression, kryo: Boolean) case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: Boolean) extends UnaryExpression with NonSQLExpression { - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { // Code to initialize the serializer. val serializer = ctx.freshName("serializer") val (serializerClass, serializerInstanceClass) = { @@ -619,7 +619,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B // Code to serialize. val input = child.genCode(ctx) - s""" + ev.copy(code = s""" ${input.code} final boolean ${ev.isNull} = ${input.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -627,7 +627,7 @@ case class DecodeUsingSerializer[T](child: Expression, tag: ClassTag[T], kryo: B ${ev.value} = (${ctx.javaType(dataType)}) $serializer.deserialize(java.nio.ByteBuffer.wrap(${input.value}), null); } - """ + """) } override def dataType: DataType = ObjectType(tag.runtimeClass) @@ -646,7 +646,7 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val instanceGen = beanInstance.genCode(ctx) val initialize = setters.map { @@ -661,12 +661,12 @@ case class InitializeJavaBean(beanInstance: Expression, setters: Map[String, Exp ev.isNull = instanceGen.isNull ev.value = instanceGen.value - s""" + ev.copy(code = s""" ${instanceGen.code} if (!${instanceGen.isNull}) { ${initialize.mkString("\n")} } - """ + """) } } @@ -688,7 +688,7 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) override def eval(input: InternalRow): Any = throw new UnsupportedOperationException("Only code-generated evaluation is supported.") - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val childGen = child.genCode(ctx) val errMsg = "Null value appeared in non-nullable field:" + @@ -698,16 +698,11 @@ case class AssertNotNull(child: Expression, walkedTypePath: Seq[String]) "(e.g. java.lang.Integer instead of int/scala.Int)." val idx = ctx.references.length ctx.references += errMsg - - ev.isNull = "false" - ev.value = childGen.value - - s""" + ExprCode(code = s""" ${childGen.code} if (${childGen.isNull}) { throw new RuntimeException((String) references[$idx]); - } - """ + }""", isNull = "false", value = childGen.value) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala index b15a77a8e7..057c6545ef 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/predicates.scala @@ -99,7 +99,7 @@ case class Not(child: Expression) protected override def nullSafeEval(input: Any): Any = !input.asInstanceOf[Boolean] - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"!($c)") } @@ -157,7 +157,7 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val valueGen = value.genCode(ctx) val listGen = list.map(_.genCode(ctx)) val listCode = listGen.map(x => @@ -172,14 +172,14 @@ case class In(value: Expression, list: Seq[Expression]) extends Predicate } } """).mkString("\n") - s""" + ev.copy(code = s""" ${valueGen.code} boolean ${ev.value} = false; boolean ${ev.isNull} = ${valueGen.isNull}; if (!${ev.isNull}) { $listCode } - """ + """) } override def sql: String = { @@ -216,7 +216,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with def getHSet(): Set[Any] = hset - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val setName = classOf[Set[Any]].getName val InSetName = classOf[InSet].getName val childGen = child.genCode(ctx) @@ -226,7 +226,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ctx.addMutableState(setName, hsetTerm, s"$hsetTerm = (($InSetName)references[${ctx.references.size - 1}]).getHSet();") ctx.addMutableState("boolean", hasNullTerm, s"$hasNullTerm = $hsetTerm.contains(null);") - s""" + ev.copy(code = s""" ${childGen.code} boolean ${ev.isNull} = ${childGen.isNull}; boolean ${ev.value} = false; @@ -236,7 +236,7 @@ case class InSet(child: Expression, hset: Set[Any]) extends UnaryExpression with ${ev.isNull} = true; } } - """ + """) } override def sql: String = { @@ -274,24 +274,22 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) // The result should be `false`, if any of them is `false` whenever the other is null or not. if (!left.nullable && !right.nullable) { - ev.isNull = "false" - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = false; if (${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - } - """ + }""", isNull = "false") } else { - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = false; @@ -306,7 +304,7 @@ case class And(left: Expression, right: Expression) extends BinaryOperator with ${ev.isNull} = true; } } - """ + """) } } } @@ -339,24 +337,23 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) // The result should be `true`, if any of them is `true` whenever the other is null or not. if (!left.nullable && !right.nullable) { ev.isNull = "false" - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.value} = true; if (!${eval1.value}) { ${eval2.code} ${ev.value} = ${eval2.value}; - } - """ + }""", isNull = "false") } else { - s""" + ev.copy(code = s""" ${eval1.code} boolean ${ev.isNull} = false; boolean ${ev.value} = true; @@ -371,7 +368,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P ${ev.isNull} = true; } } - """ + """) } } } @@ -379,7 +376,7 @@ case class Or(left: Expression, right: Expression) extends BinaryOperator with P abstract class BinaryComparison extends BinaryOperator with Predicate { - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (ctx.isPrimitiveType(left.dataType) && left.dataType != BooleanType // java boolean doesn't support > or < operator && left.dataType != FloatType @@ -428,7 +425,7 @@ case class EqualTo(left: Expression, right: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => ctx.genEqual(left.dataType, c1, c2)) } } @@ -464,15 +461,13 @@ case class EqualNullSafe(left: Expression, right: Expression) extends BinaryComp } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val eval1 = left.genCode(ctx) val eval2 = right.genCode(ctx) val equalCode = ctx.genEqual(left.dataType, eval1.value, eval2.value) - ev.isNull = "false" - eval1.code + eval2.code + s""" + ev.copy(code = eval1.code + eval2.code + s""" boolean ${ev.value} = (${eval1.isNull} && ${eval2.isNull}) || - (!${eval1.isNull} && $equalCode); - """ + (!${eval1.isNull} && $equalCode);""", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala index 1eed24dd1e..ca200768b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala @@ -67,15 +67,13 @@ case class Rand(seed: Long) extends RDG { case _ => throw new AnalysisException("Input argument to rand must be an integer literal.") }) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") - ev.isNull = "false" - s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble(); - """ + ev.copy(code = s""" + final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextDouble();""", isNull = "false") } } @@ -92,14 +90,12 @@ case class Randn(seed: Long) extends RDG { case _ => throw new AnalysisException("Input argument to randn must be an integer literal.") }) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val rngTerm = ctx.freshName("rng") val className = classOf[XORShiftRandom].getName ctx.addMutableState(className, rngTerm, s"$rngTerm = new $className(${seed}L + org.apache.spark.TaskContext.getPartitionId());") - ev.isNull = "false" - s""" - final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian(); - """ + ev.copy(code = s""" + final ${ctx.javaType(dataType)} ${ev.value} = $rngTerm.nextGaussian();""", isNull = "false") } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala index 4f5b85d7f4..541b8601a3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/regexpExpressions.scala @@ -78,7 +78,7 @@ case class Like(left: Expression, right: Expression) override def toString: String = s"$left LIKE $right" - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex" val pattern = ctx.freshName("pattern") @@ -93,19 +93,19 @@ case class Like(left: Expression, right: Expression) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - s""" + ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).matches(); } - """ + """) } else { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + """) } } else { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -128,7 +128,7 @@ case class RLike(left: Expression, right: Expression) override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).find(0) override def toString: String = s"$left RLIKE $right" - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val patternClass = classOf[Pattern].getName val pattern = ctx.freshName("pattern") @@ -142,19 +142,19 @@ case class RLike(left: Expression, right: Expression) // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again. val eval = left.genCode(ctx) - s""" + ev.copy(code = s""" ${eval.code} boolean ${ev.isNull} = ${eval.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; if (!${ev.isNull}) { ${ev.value} = $pattern.matcher(${eval.value}.toString()).find(0); } - """ + """) } else { - s""" + ev.copy(code = s""" boolean ${ev.isNull} = true; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; - """ + """) } } else { nullSafeCodeGen(ctx, ev, (eval1, eval2) => { @@ -188,7 +188,7 @@ case class StringSplit(str: Expression, pattern: Expression) new GenericArrayData(strings.asInstanceOf[Array[Any]]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val arrayClass = classOf[GenericArrayData].getName nullSafeCodeGen(ctx, ev, (str, pattern) => // Array in java is covariant, so we don't need to cast UTF8String[] to Object[]. @@ -247,7 +247,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio override def children: Seq[Expression] = subject :: regexp :: rep :: Nil override def prettyName: String = "regexp_replace" - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") @@ -330,7 +330,7 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override def children: Seq[Expression] = subject :: regexp :: idx :: Nil override def prettyName: String = "regexp_extract" - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") val classNamePattern = classOf[Pattern].getCanonicalName diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 8c15357360..78e846d3f5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -51,18 +51,18 @@ case class Concat(children: Seq[Expression]) extends Expression with ImplicitCas UTF8String.concat(inputs : _*) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val evals = children.map(_.genCode(ctx)) val inputs = evals.map { eval => s"${eval.isNull} ? null : ${eval.value}" }.mkString(", ") - evals.map(_.code).mkString("\n") + s""" + ev.copy(evals.map(_.code).mkString("\n") + s""" boolean ${ev.isNull} = false; UTF8String ${ev.value} = UTF8String.concat($inputs); if (${ev.value} == null) { ${ev.isNull} = true; } - """ + """) } } @@ -106,7 +106,7 @@ case class ConcatWs(children: Seq[Expression]) UTF8String.concatWs(flatInputs.head, flatInputs.tail : _*) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { if (children.forall(_.dataType == StringType)) { // All children are strings. In that case we can construct a fixed size array. val evals = children.map(_.genCode(ctx)) @@ -115,10 +115,10 @@ case class ConcatWs(children: Seq[Expression]) s"${eval.isNull} ? (UTF8String) null : ${eval.value}" }.mkString(", ") - evals.map(_.code).mkString("\n") + s""" + ev.copy(evals.map(_.code).mkString("\n") + s""" UTF8String ${ev.value} = UTF8String.concatWs($inputs); boolean ${ev.isNull} = ${ev.value} == null; - """ + """) } else { val array = ctx.freshName("array") val varargNum = ctx.freshName("varargNum") @@ -148,7 +148,7 @@ case class ConcatWs(children: Seq[Expression]) } }.unzip - evals.map(_.code).mkString("\n") + + ev.copy(evals.map(_.code).mkString("\n") + s""" int $varargNum = ${children.count(_.dataType == StringType) - 1}; int $idxInVararg = 0; @@ -157,7 +157,7 @@ case class ConcatWs(children: Seq[Expression]) ${varargBuild.mkString("\n")} UTF8String ${ev.value} = UTF8String.concatWs(${evals.head.value}, $array); boolean ${ev.isNull} = ${ev.value} == null; - """ + """) } } } @@ -185,7 +185,7 @@ case class Upper(child: Expression) override def convert(v: UTF8String): UTF8String = v.toUpperCase - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toUpperCase()") } } @@ -200,7 +200,7 @@ case class Lower(child: Expression) extends UnaryExpression with String2StringEx override def convert(v: UTF8String): UTF8String = v.toLowerCase - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).toLowerCase()") } } @@ -225,7 +225,7 @@ trait StringPredicate extends Predicate with ImplicitCastInputTypes { case class Contains(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.contains(r) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).contains($c2)") } } @@ -236,7 +236,7 @@ case class Contains(left: Expression, right: Expression) case class StartsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.startsWith(r) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).startsWith($c2)") } } @@ -247,7 +247,7 @@ case class StartsWith(left: Expression, right: Expression) case class EndsWith(left: Expression, right: Expression) extends BinaryExpression with StringPredicate { override def compare(l: UTF8String, r: UTF8String): Boolean = l.endsWith(r) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (c1, c2) => s"($c1).endsWith($c2)") } } @@ -298,7 +298,7 @@ case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replac srcEval.asInstanceOf[UTF8String].translate(dict) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val termLastMatching = ctx.freshName("lastMatching") val termLastReplace = ctx.freshName("lastReplace") val termDict = ctx.freshName("dict") @@ -351,7 +351,7 @@ case class FindInSet(left: Expression, right: Expression) extends BinaryExpressi override protected def nullSafeEval(word: Any, set: Any): Any = set.asInstanceOf[UTF8String].findInSet(word.asInstanceOf[UTF8String]) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (word, set) => s"${ev.value} = $set.findInSet($word);" ) @@ -375,7 +375,7 @@ case class StringTrim(child: Expression) override def prettyName: String = "trim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).trim()") } } @@ -393,7 +393,7 @@ case class StringTrimLeft(child: Expression) override def prettyName: String = "ltrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).trimLeft()") } } @@ -411,7 +411,7 @@ case class StringTrimRight(child: Expression) override def prettyName: String = "rtrim" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).trimRight()") } } @@ -440,7 +440,7 @@ case class StringInstr(str: Expression, substr: Expression) override def prettyName: String = "instr" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (l, r) => s"($l).indexOf($r, 0) + 1") } @@ -475,7 +475,7 @@ case class SubstringIndex(strExpr: Expression, delimExpr: Expression, countExpr: count.asInstanceOf[Int]) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (str, delim, count) => s"$str.subStringIndex($delim, $count)") } } @@ -524,11 +524,11 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) } } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val substrGen = substr.genCode(ctx) val strGen = str.genCode(ctx) val startGen = start.genCode(ctx) - s""" + ev.copy(code = s""" int ${ev.value} = 0; boolean ${ev.isNull} = false; ${startGen.code} @@ -546,7 +546,7 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) ${ev.isNull} = true; } } - """ + """) } override def prettyName: String = "locate" @@ -571,7 +571,7 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") } @@ -597,7 +597,7 @@ case class StringRPad(str: Expression, len: Expression, pad: Expression) str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } - override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") } @@ -638,7 +638,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val pattern = children.head.genCode(ctx) val argListGen = children.tail.map(x => (x.dataType, x.genCode(ctx))) @@ -660,7 +660,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC val formatter = classOf[java.util.Formatter].getName val sb = ctx.freshName("sb") val stringBuffer = classOf[StringBuffer].getName - s""" + ev.copy(code = s""" ${pattern.code} boolean ${ev.isNull} = ${pattern.isNull}; ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)}; @@ -670,8 +670,7 @@ case class FormatString(children: Expression*) extends Expression with ImplicitC $formatter $form = new $formatter($sb, ${classOf[Locale].getName}.US); $form.format(${pattern.value}.toString() $argListString); ${ev.value} = UTF8String.fromString($sb.toString()); - } - """ + }""") } override def prettyName: String = "format_string" @@ -694,7 +693,7 @@ case class InitCap(child: Expression) extends UnaryExpression with ImplicitCastI override def nullSafeEval(string: Any): Any = { string.asInstanceOf[UTF8String].toLowerCase.toTitleCase } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, str => s"$str.toLowerCase().toTitleCase()") } } @@ -719,7 +718,7 @@ case class StringRepeat(str: Expression, times: Expression) override def prettyName: String = "repeat" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (l, r) => s"($l).repeat($r)") } } @@ -735,7 +734,7 @@ case class StringReverse(child: Expression) extends UnaryExpression with String2 override def prettyName: String = "reverse" - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"($c).reverse()") } } @@ -757,7 +756,7 @@ case class StringSpace(child: Expression) UTF8String.blankString(if (length < 0) 0 else length) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (length) => s"""${ev.value} = UTF8String.blankString(($length < 0) ? 0 : $length);""") } @@ -799,7 +798,7 @@ case class Substring(str: Expression, pos: Expression, len: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, (string, pos, len) => { str.dataType match { @@ -825,7 +824,7 @@ case class Length(child: Expression) extends UnaryExpression with ExpectsInputTy case BinaryType => value.asInstanceOf[Array[Byte]].length } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { child.dataType match { case StringType => defineCodeGen(ctx, ev, c => s"($c).numChars()") case BinaryType => defineCodeGen(ctx, ev, c => s"($c).length") @@ -848,7 +847,7 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any = leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String]) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (left, right) => s"${ev.value} = $left.levenshteinDistance($right);") } @@ -868,7 +867,7 @@ case class SoundEx(child: Expression) extends UnaryExpression with ExpectsInputT override def nullSafeEval(input: Any): Any = input.asInstanceOf[UTF8String].soundex() - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { defineCodeGen(ctx, ev, c => s"$c.soundex()") } } @@ -894,7 +893,7 @@ case class Ascii(child: Expression) extends UnaryExpression with ImplicitCastInp } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (child) => { val bytes = ctx.freshName("bytes") s""" @@ -924,7 +923,7 @@ case class Base64(child: Expression) extends UnaryExpression with ImplicitCastIn bytes.asInstanceOf[Array[Byte]])) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (child) => { s"""${ev.value} = UTF8String.fromBytes( org.apache.commons.codec.binary.Base64.encodeBase64($child)); @@ -945,7 +944,7 @@ case class UnBase64(child: Expression) extends UnaryExpression with ImplicitCast protected override def nullSafeEval(string: Any): Any = org.apache.commons.codec.binary.Base64.decodeBase64(string.asInstanceOf[UTF8String].toString) - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (child) => { s""" ${ev.value} = org.apache.commons.codec.binary.Base64.decodeBase64($child.toString()); @@ -973,7 +972,7 @@ case class Decode(bin: Expression, charset: Expression) UTF8String.fromString(new String(input1.asInstanceOf[Array[Byte]], fromCharset)) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (bytes, charset) => s""" try { @@ -1005,7 +1004,7 @@ case class Encode(value: Expression, charset: Expression) input1.asInstanceOf[UTF8String].toString.getBytes(toCharset) } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (string, charset) => s""" try { @@ -1088,7 +1087,7 @@ case class FormatNumber(x: Expression, d: Expression) } } - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (num, d) => { def typeHelper(p: String): String = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala index de410b86ea..3a24b4d7d5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/NonFoldableLiteral.scala @@ -35,7 +35,7 @@ case class NonFoldableLiteral(value: Any, dataType: DataType) extends LeafExpres override def eval(input: InternalRow): Any = value - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { Literal.create(value, dataType).doGenCode(ctx, ev) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala index 03defc121c..b3e8b37a2e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala @@ -54,7 +54,7 @@ case class ScalarSubquery( override def eval(input: InternalRow): Any = result - override def doGenCode(ctx: CodegenContext, ev: ExprCode): String = { + override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { Literal.create(result, dataType).doGenCode(ctx, ev) } } -- cgit v1.2.3