From 6bba7509a932aa4d39266df2d15b1370b7aabbec Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 31 Jul 2015 08:28:05 -0700 Subject: [SPARK-9500] add TernaryExpression to simplify ternary expressions There is a lot of duplicated code in ternary expressions; create a TernaryExpression for them to reduce duplicated code. cc chenghao-intel Author: Davies Liu Closes #7816 from davies/ternary and squashes the following commits: ed2bf76 [Davies Liu] add TernaryExpression --- .../sql/catalyst/expressions/Expression.scala | 85 +++++ .../expressions/codegen/CodeGenerator.scala | 2 +- .../spark/sql/catalyst/expressions/math.scala | 66 +--- .../catalyst/expressions/stringOperations.scala | 356 +++++---------------- 4 files changed, 183 insertions(+), 326 deletions(-) (limited to 'sql/catalyst') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala index 8fc182607c..2842b3ec5a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala @@ -432,3 +432,88 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes { private[sql] object BinaryOperator { def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right)) } + +/** + * An expression with three inputs and one output. The output is by default evaluated to null + * if any input is evaluated to null. + */ +abstract class TernaryExpression extends Expression { + + override def foldable: Boolean = children.forall(_.foldable) + + override def nullable: Boolean = children.exists(_.nullable) + + /** + * Default behavior of evaluation according to the default nullability of TernaryExpression. + * If a subclass of TernaryExpression overrides nullable, it should probably also override this. 
+ */ + override def eval(input: InternalRow): Any = { + val exprs = children + val value1 = exprs(0).eval(input) + if (value1 != null) { + val value2 = exprs(1).eval(input) + if (value2 != null) { + val value3 = exprs(2).eval(input) + if (value3 != null) { + return nullSafeEval(value1, value2, value3) + } + } + } + null + } + + /** + * Called by default [[eval]] implementation. If subclasses of TernaryExpression keep the + * default nullability, they can override this method to save null-check code. If we need full + * control of the evaluation process, we should override [[eval]]. + */ + protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = + sys.error(s"BinaryExpressions must override either eval or nullSafeEval") + + /** + * Short hand for generating ternary evaluation code. + * If any of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f accepts three variable names and returns Java code to compute the output. + */ + protected def defineCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => { + s"${ev.primitive} = ${f(eval1, eval2, eval3)};" + }) + } + + /** + * Short hand for generating ternary evaluation code. + * If any of the sub-expressions is null, the result of this computation + * is assumed to be null. + * + * @param f function that accepts the 3 non-null evaluation result names of children + * and returns Java code to compute the output. 
+ */ + protected def nullSafeCodeGen( + ctx: CodeGenContext, + ev: GeneratedExpressionCode, + f: (String, String, String) => String): String = { + val evals = children.map(_.gen(ctx)) + val resultCode = f(evals(0).primitive, evals(1).primitive, evals(2).primitive) + s""" + ${evals(0).code} + boolean ${ev.isNull} = true; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${evals(0).isNull}) { + ${evals(1).code} + if (!${evals(1).isNull}) { + ${evals(2).code} + if (!${evals(2).isNull}) { + ${ev.isNull} = false; // resultCode could change nullability + $resultCode + } + } + } + """ + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 60e2863f7b..e50ec27fc2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -305,7 +305,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin evaluator.cook(code) } catch { case e: Exception => - val msg = "failed to compile:\n " + CodeFormatter.format(code) + val msg = s"failed to compile: $e\n" + CodeFormatter.format(code) logError(msg, e) throw new Exception(msg, e) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala index e6d807f6d8..15ceb9193a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala @@ -165,69 +165,29 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH" * @param toBaseExpr to which base */ case class Conv(numExpr: Expression, fromBaseExpr: 
Expression, toBaseExpr: Expression) - extends Expression with ImplicitCastInputTypes { - - override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable - - override def nullable: Boolean = numExpr.nullable || fromBaseExpr.nullable || toBaseExpr.nullable + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr) - override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType) - override def dataType: DataType = StringType - /** Returns the result of evaluating this expression on a given input Row */ - override def eval(input: InternalRow): Any = { - val num = numExpr.eval(input) - if (num != null) { - val fromBase = fromBaseExpr.eval(input) - if (fromBase != null) { - val toBase = toBaseExpr.eval(input) - if (toBase != null) { - NumberConverter.convert( - num.asInstanceOf[UTF8String].getBytes, - fromBase.asInstanceOf[Int], - toBase.asInstanceOf[Int]) - } else { - null - } - } else { - null - } - } else { - null - } + override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = { + NumberConverter.convert( + num.asInstanceOf[UTF8String].getBytes, + fromBase.asInstanceOf[Int], + toBase.asInstanceOf[Int]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val numGen = numExpr.gen(ctx) - val from = fromBaseExpr.gen(ctx) - val to = toBaseExpr.gen(ctx) - val numconv = NumberConverter.getClass.getName.stripSuffix("$") - s""" - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - ${numGen.code} - boolean ${ev.isNull} = ${numGen.isNull}; - if (!${ev.isNull}) { - ${from.code} - if (!${from.isNull}) { - ${to.code} - if (!${to.isNull}) { - ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(), - ${from.primitive}, ${to.primitive}); - if (${ev.primitive} == null) { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = 
true; - } + nullSafeCodeGen(ctx, ev, (num, from, to) => + s""" + ${ev.primitive} = $numconv.convert($num.getBytes(), $from, $to); + if (${ev.primitive} == null) { + ${ev.isNull} = true; } - """ + """ + ) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 99a62343f1..684eac12bd 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -426,15 +426,13 @@ case class StringInstr(str: Expression, substr: Expression) * in given string after position pos. */ case class StringLocate(substr: Expression, str: Expression, start: Expression) - extends Expression with ImplicitCastInputTypes with CodegenFallback { + extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback { def this(substr: Expression, str: Expression) = { this(substr, str, Literal(0)) } override def children: Seq[Expression] = substr :: str :: start :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = substr.nullable || str.nullable override def dataType: DataType = IntegerType override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType) @@ -467,60 +465,18 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression) * Returns str, left-padded with pad to a length of len. 
*/ case class StringLPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = children.exists(_.nullable) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def eval(input: InternalRow): Any = { - val s = str.eval(input) - if (s == null) { - null - } else { - val l = len.eval(input) - if (l == null) { - null - } else { - val p = pad.eval(input) - if (p == null) { - null - } else { - val len = l.asInstanceOf[Int] - val str = s.asInstanceOf[UTF8String] - val pad = p.asInstanceOf[UTF8String] - - str.lpad(len, pad) - } - } - } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { + str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val lenGen = len.gen(ctx) - val strGen = str.gen(ctx) - val padGen = pad.gen(ctx) - - s""" - ${lenGen.code} - boolean ${ev.isNull} = ${lenGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${strGen.code} - if (!${strGen.isNull}) { - ${padGen.code} - if (!${padGen.isNull}) { - ${ev.primitive} = ${strGen.primitive}.lpad(${lenGen.primitive}, ${padGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)") } override def prettyName: String = "lpad" @@ -530,60 +486,18 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression) * Returns str, right-padded with pad to a length of len. 
*/ case class StringRPad(str: Expression, len: Expression, pad: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { override def children: Seq[Expression] = str :: len :: pad :: Nil - override def foldable: Boolean = children.forall(_.foldable) - override def nullable: Boolean = children.exists(_.nullable) override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType) - override def eval(input: InternalRow): Any = { - val s = str.eval(input) - if (s == null) { - null - } else { - val l = len.eval(input) - if (l == null) { - null - } else { - val p = pad.eval(input) - if (p == null) { - null - } else { - val len = l.asInstanceOf[Int] - val str = s.asInstanceOf[UTF8String] - val pad = p.asInstanceOf[UTF8String] - - str.rpad(len, pad) - } - } - } + override def nullSafeEval(str: Any, len: Any, pad: Any): Any = { + str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String]) } override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val lenGen = len.gen(ctx) - val strGen = str.gen(ctx) - val padGen = pad.gen(ctx) - - s""" - ${lenGen.code} - boolean ${ev.isNull} = ${lenGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${strGen.code} - if (!${strGen.isNull}) { - ${padGen.code} - if (!${padGen.isNull}) { - ${ev.primitive} = ${strGen.primitive}.rpad(${lenGen.primitive}, ${padGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)") } override def prettyName: String = "rpad" @@ -745,68 +659,24 @@ case class StringSplit(str: Expression, pattern: Expression) * Defined for String and Binary types. 
*/ case class Substring(str: Expression, pos: Expression, len: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { def this(str: Expression, pos: Expression) = { this(str, pos, Literal(Integer.MAX_VALUE)) } - override def foldable: Boolean = str.foldable && pos.foldable && len.foldable - override def nullable: Boolean = str.nullable || pos.nullable || len.nullable - override def dataType: DataType = StringType override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType) override def children: Seq[Expression] = str :: pos :: len :: Nil - override def eval(input: InternalRow): Any = { - val stringEval = str.eval(input) - if (stringEval != null) { - val posEval = pos.eval(input) - if (posEval != null) { - val lenEval = len.eval(input) - if (lenEval != null) { - stringEval.asInstanceOf[UTF8String] - .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int]) - } else { - null - } - } else { - null - } - } else { - null - } + override def nullSafeEval(string: Any, pos: Any, len: Any): Any = { + string.asInstanceOf[UTF8String].substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int]) } override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { - val strGen = str.gen(ctx) - val posGen = pos.gen(ctx) - val lenGen = len.gen(ctx) - - val start = ctx.freshName("start") - val end = ctx.freshName("end") - - s""" - ${strGen.code} - boolean ${ev.isNull} = ${strGen.isNull}; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${ev.isNull}) { - ${posGen.code} - if (!${posGen.isNull}) { - ${lenGen.code} - if (!${lenGen.isNull}) { - ${ev.primitive} = ${strGen.primitive} - .substringSQL(${posGen.primitive}, ${lenGen.primitive}); - } else { - ${ev.isNull} = true; - } - } else { - ${ev.isNull} = true; - } - } - """ + defineCodeGen(ctx, ev, (str, pos, len) => s"$str.substringSQL($pos, $len)") } } @@ -986,7 +856,7 @@ case class 
Encode(value: Expression, charset: Expression) * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { // last regex in string, we will update the pattern iff regexp value changed. @transient private var lastRegex: UTF8String = _ @@ -998,40 +868,26 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio // result buffer write by Matcher @transient private val result: StringBuffer = new StringBuffer - override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable - override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable - - override def eval(input: InternalRow): Any = { - val s = subject.eval(input) - if (null != s) { - val p = regexp.eval(input) - if (null != p) { - val r = rep.eval(input) - if (null != r) { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String] - pattern = Pattern.compile(lastRegex.toString) - } - if (!r.equals(lastReplacementInUTF8)) { - // replacement string changed - lastReplacementInUTF8 = r.asInstanceOf[UTF8String] - lastReplacement = lastReplacementInUTF8.toString - } - val m = pattern.matcher(s.toString()) - result.delete(0, result.length()) - - while (m.find) { - m.appendReplacement(result, lastReplacement) - } - m.appendTail(result) + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String] + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) - return 
UTF8String.fromString(result.toString) - } - } + while (m.find) { + m.appendReplacement(result, lastReplacement) } + m.appendTail(result) - null + UTF8String.fromString(result.toString) } override def dataType: DataType = StringType @@ -1048,59 +904,43 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio val termResult = ctx.freshName("result") - val classNameUTF8String = classOf[UTF8String].getCanonicalName val classNamePattern = classOf[Pattern].getCanonicalName - val classNameString = classOf[java.lang.String].getCanonicalName val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName - ctx.addMutableState(classNameUTF8String, + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") - ctx.addMutableState(classNameString, + ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;") - ctx.addMutableState(classNameUTF8String, + ctx.addMutableState("UTF8String", termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") ctx.addMutableState(classNameStringBuffer, termResult, s"${termResult} = new $classNameStringBuffer();") - val evalSubject = subject.gen(ctx) - val evalRegexp = regexp.gen(ctx) - val evalRep = rep.gen(ctx) - + nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => { s""" - ${evalSubject.code} - boolean ${ev.isNull} = true; - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - if (!${evalSubject.isNull}) { - ${evalRegexp.code} - if (!${evalRegexp.isNull}) { - ${evalRep.code} - if (!${evalRep.isNull}) { - if (!${evalRegexp.primitive}.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = ${evalRegexp.primitive}; - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) { - // replacement string changed - ${termLastReplacementInUTF8} = 
${evalRep.primitive}; - ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); - } - ${termResult}.delete(0, ${termResult}.length()); - ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!$rep.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = $rep; + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString()); - while (m.find()) { - m.appendReplacement(${termResult}, ${termLastReplacement}); - } - m.appendTail(${termResult}); - ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString()); - ${ev.isNull} = false; - } - } + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); } + m.appendTail(${termResult}); + ${ev.primitive} = UTF8String.fromString(${termResult}.toString()); + ${ev.isNull} = false; """ + }) } } @@ -1110,7 +950,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. */ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) - extends Expression with ImplicitCastInputTypes { + extends TernaryExpression with ImplicitCastInputTypes { def this(s: Expression, r: Expression) = this(s, r, Literal(1)) // last regex in string, we will update the pattern iff regexp value changed. 
@@ -1118,32 +958,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio // last regex pattern, we cache it for performance concern @transient private var pattern: Pattern = _ - override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable - override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable - - override def eval(input: InternalRow): Any = { - val s = subject.eval(input) - if (null != s) { - val p = regexp.eval(input) - if (null != p) { - val r = idx.eval(input) - if (null != r) { - if (!p.equals(lastRegex)) { - // regex value changed - lastRegex = p.asInstanceOf[UTF8String] - pattern = Pattern.compile(lastRegex.toString) - } - val m = pattern.matcher(s.toString()) - if (m.find) { - val mr: MatchResult = m.toMatchResult - return UTF8String.fromString(mr.group(r.asInstanceOf[Int])) - } - return UTF8String.EMPTY_UTF8 - } - } + override def nullSafeEval(s: Any, p: Any, r: Any): Any = { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString()) + if (m.find) { + val mr: MatchResult = m.toMatchResult + UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } else { + UTF8String.EMPTY_UTF8 } - - null } override def dataType: DataType = StringType @@ -1154,44 +981,29 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { val termLastRegex = ctx.freshName("lastRegex") val termPattern = ctx.freshName("pattern") - val classNameUTF8String = classOf[UTF8String].getCanonicalName val classNamePattern = classOf[Pattern].getCanonicalName - ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;") ctx.addMutableState(classNamePattern, 
termPattern, s"${termPattern} = null;") - val evalSubject = subject.gen(ctx) - val evalRegexp = regexp.gen(ctx) - val evalIdx = idx.gen(ctx) - - s""" - ${evalSubject.code} - ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; - boolean ${ev.isNull} = true; - if (!${evalSubject.isNull}) { - ${evalRegexp.code} - if (!${evalRegexp.isNull}) { - ${evalIdx.code} - if (!${evalIdx.isNull}) { - if (!${evalRegexp.primitive}.equals(${termLastRegex})) { - // regex value changed - ${termLastRegex} = ${evalRegexp.primitive}; - ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); - } - ${classOf[java.util.regex.Matcher].getCanonicalName} m = - ${termPattern}.matcher(${evalSubject.primitive}.toString()); - if (m.find()) { - ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); - ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); - ${ev.isNull} = false; - } else { - ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8; - ${ev.isNull} = false; - } - } - } + nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => { + s""" + if (!$regexp.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = $regexp; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); } - """ + java.util.regex.Matcher m = + ${termPattern}.matcher($subject.toString()); + if (m.find()) { + java.util.regex.MatchResult mr = m.toMatchResult(); + ${ev.primitive} = UTF8String.fromString(mr.group($idx)); + ${ev.isNull} = false; + } else { + ${ev.primitive} = UTF8String.EMPTY_UTF8; + ${ev.isNull} = false; + }""" + }) } } -- cgit v1.2.3