aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-07-31 08:28:05 -0700
committerDavies Liu <davies.liu@gmail.com>2015-07-31 08:28:05 -0700
commit6bba7509a932aa4d39266df2d15b1370b7aabbec (patch)
tree46a3eb3da7ba11fa6b0319d9a9d90a5970d6e999
parenta3a85d73da053c8e2830759fbc68b734081fa4f3 (diff)
downloadspark-6bba7509a932aa4d39266df2d15b1370b7aabbec.tar.gz
spark-6bba7509a932aa4d39266df2d15b1370b7aabbec.tar.bz2
spark-6bba7509a932aa4d39266df2d15b1370b7aabbec.zip
[SPARK-9500] add TernaryExpression to simplify ternary expressions
There lots of duplicated code in ternary expressions, create a TernaryExpression for them to reduce duplicated code. cc chenghao-intel Author: Davies Liu <davies@databricks.com> Closes #7816 from davies/ternary and squashes the following commits: ed2bf76 [Davies Liu] add TernaryExpression
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala85
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala66
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala356
4 files changed, 183 insertions, 326 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index 8fc182607c..2842b3ec5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -432,3 +432,88 @@ abstract class BinaryOperator extends BinaryExpression with ExpectsInputTypes {
private[sql] object BinaryOperator {
def unapply(e: BinaryOperator): Option[(Expression, Expression)] = Some((e.left, e.right))
}
+
+/**
+ * An expression with three inputs and one output. The output is by default evaluated to null
+ * if any input is evaluated to null.
+ */
+abstract class TernaryExpression extends Expression {
+
+ override def foldable: Boolean = children.forall(_.foldable)
+
+ override def nullable: Boolean = children.exists(_.nullable)
+
+ /**
+ * Default behavior of evaluation according to the default nullability of BinaryExpression.
+ * If subclass of BinaryExpression override nullable, probably should also override this.
+ */
+ override def eval(input: InternalRow): Any = {
+ val exprs = children
+ val value1 = exprs(0).eval(input)
+ if (value1 != null) {
+ val value2 = exprs(1).eval(input)
+ if (value2 != null) {
+ val value3 = exprs(2).eval(input)
+ if (value3 != null) {
+ return nullSafeEval(value1, value2, value3)
+ }
+ }
+ }
+ null
+ }
+
+ /**
+ * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default
+ * nullability, they can override this method to save null-check code. If we need full control
+ * of evaluation process, we should override [[eval]].
+ */
+ protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any =
+ sys.error(s"BinaryExpressions must override either eval or nullSafeEval")
+
+ /**
+ * Short hand for generating binary evaluation code.
+ * If either of the sub-expressions is null, the result of this computation
+ * is assumed to be null.
+ *
+ * @param f accepts two variable names and returns Java code to compute the output.
+ */
+ protected def defineCodeGen(
+ ctx: CodeGenContext,
+ ev: GeneratedExpressionCode,
+ f: (String, String, String) => String): String = {
+ nullSafeCodeGen(ctx, ev, (eval1, eval2, eval3) => {
+ s"${ev.primitive} = ${f(eval1, eval2, eval3)};"
+ })
+ }
+
+ /**
+ * Short hand for generating binary evaluation code.
+ * If either of the sub-expressions is null, the result of this computation
+ * is assumed to be null.
+ *
+ * @param f function that accepts the 2 non-null evaluation result names of children
+ * and returns Java code to compute the output.
+ */
+ protected def nullSafeCodeGen(
+ ctx: CodeGenContext,
+ ev: GeneratedExpressionCode,
+ f: (String, String, String) => String): String = {
+ val evals = children.map(_.gen(ctx))
+ val resultCode = f(evals(0).primitive, evals(1).primitive, evals(2).primitive)
+ s"""
+ ${evals(0).code}
+ boolean ${ev.isNull} = true;
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${evals(0).isNull}) {
+ ${evals(1).code}
+ if (!${evals(1).isNull}) {
+ ${evals(2).code}
+ if (!${evals(2).isNull}) {
+ ${ev.isNull} = false; // resultCode could change nullability
+ $resultCode
+ }
+ }
+ }
+ """
+ }
+}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 60e2863f7b..e50ec27fc2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -305,7 +305,7 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
evaluator.cook(code)
} catch {
case e: Exception =>
- val msg = "failed to compile:\n " + CodeFormatter.format(code)
+ val msg = s"failed to compile: $e\n" + CodeFormatter.format(code)
logError(msg, e)
throw new Exception(msg, e)
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
index e6d807f6d8..15ceb9193a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/math.scala
@@ -165,69 +165,29 @@ case class Cosh(child: Expression) extends UnaryMathExpression(math.cosh, "COSH"
* @param toBaseExpr to which base
*/
case class Conv(numExpr: Expression, fromBaseExpr: Expression, toBaseExpr: Expression)
- extends Expression with ImplicitCastInputTypes {
-
- override def foldable: Boolean = numExpr.foldable && fromBaseExpr.foldable && toBaseExpr.foldable
-
- override def nullable: Boolean = numExpr.nullable || fromBaseExpr.nullable || toBaseExpr.nullable
+ extends TernaryExpression with ImplicitCastInputTypes {
override def children: Seq[Expression] = Seq(numExpr, fromBaseExpr, toBaseExpr)
-
override def inputTypes: Seq[AbstractDataType] = Seq(StringType, IntegerType, IntegerType)
-
override def dataType: DataType = StringType
- /** Returns the result of evaluating this expression on a given input Row */
- override def eval(input: InternalRow): Any = {
- val num = numExpr.eval(input)
- if (num != null) {
- val fromBase = fromBaseExpr.eval(input)
- if (fromBase != null) {
- val toBase = toBaseExpr.eval(input)
- if (toBase != null) {
- NumberConverter.convert(
- num.asInstanceOf[UTF8String].getBytes,
- fromBase.asInstanceOf[Int],
- toBase.asInstanceOf[Int])
- } else {
- null
- }
- } else {
- null
- }
- } else {
- null
- }
+ override def nullSafeEval(num: Any, fromBase: Any, toBase: Any): Any = {
+ NumberConverter.convert(
+ num.asInstanceOf[UTF8String].getBytes,
+ fromBase.asInstanceOf[Int],
+ toBase.asInstanceOf[Int])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val numGen = numExpr.gen(ctx)
- val from = fromBaseExpr.gen(ctx)
- val to = toBaseExpr.gen(ctx)
-
val numconv = NumberConverter.getClass.getName.stripSuffix("$")
- s"""
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- ${numGen.code}
- boolean ${ev.isNull} = ${numGen.isNull};
- if (!${ev.isNull}) {
- ${from.code}
- if (!${from.isNull}) {
- ${to.code}
- if (!${to.isNull}) {
- ${ev.primitive} = $numconv.convert(${numGen.primitive}.getBytes(),
- ${from.primitive}, ${to.primitive});
- if (${ev.primitive} == null) {
- ${ev.isNull} = true;
- }
- } else {
- ${ev.isNull} = true;
- }
- } else {
- ${ev.isNull} = true;
- }
+ nullSafeCodeGen(ctx, ev, (num, from, to) =>
+ s"""
+ ${ev.primitive} = $numconv.convert($num.getBytes(), $from, $to);
+ if (${ev.primitive} == null) {
+ ${ev.isNull} = true;
}
- """
+ """
+ )
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 99a62343f1..684eac12bd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -426,15 +426,13 @@ case class StringInstr(str: Expression, substr: Expression)
* in given string after position pos.
*/
case class StringLocate(substr: Expression, str: Expression, start: Expression)
- extends Expression with ImplicitCastInputTypes with CodegenFallback {
+ extends TernaryExpression with ImplicitCastInputTypes with CodegenFallback {
def this(substr: Expression, str: Expression) = {
this(substr, str, Literal(0))
}
override def children: Seq[Expression] = substr :: str :: start :: Nil
- override def foldable: Boolean = children.forall(_.foldable)
- override def nullable: Boolean = substr.nullable || str.nullable
override def dataType: DataType = IntegerType
override def inputTypes: Seq[DataType] = Seq(StringType, StringType, IntegerType)
@@ -467,60 +465,18 @@ case class StringLocate(substr: Expression, str: Expression, start: Expression)
* Returns str, left-padded with pad to a length of len.
*/
case class StringLPad(str: Expression, len: Expression, pad: Expression)
- extends Expression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes {
override def children: Seq[Expression] = str :: len :: pad :: Nil
- override def foldable: Boolean = children.forall(_.foldable)
- override def nullable: Boolean = children.exists(_.nullable)
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
- override def eval(input: InternalRow): Any = {
- val s = str.eval(input)
- if (s == null) {
- null
- } else {
- val l = len.eval(input)
- if (l == null) {
- null
- } else {
- val p = pad.eval(input)
- if (p == null) {
- null
- } else {
- val len = l.asInstanceOf[Int]
- val str = s.asInstanceOf[UTF8String]
- val pad = p.asInstanceOf[UTF8String]
-
- str.lpad(len, pad)
- }
- }
- }
+ override def nullSafeEval(str: Any, len: Any, pad: Any): Any = {
+ str.asInstanceOf[UTF8String].lpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
}
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val lenGen = len.gen(ctx)
- val strGen = str.gen(ctx)
- val padGen = pad.gen(ctx)
-
- s"""
- ${lenGen.code}
- boolean ${ev.isNull} = ${lenGen.isNull};
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${strGen.code}
- if (!${strGen.isNull}) {
- ${padGen.code}
- if (!${padGen.isNull}) {
- ${ev.primitive} = ${strGen.primitive}.lpad(${lenGen.primitive}, ${padGen.primitive});
- } else {
- ${ev.isNull} = true;
- }
- } else {
- ${ev.isNull} = true;
- }
- }
- """
+ defineCodeGen(ctx, ev, (str, len, pad) => s"$str.lpad($len, $pad)")
}
override def prettyName: String = "lpad"
@@ -530,60 +486,18 @@ case class StringLPad(str: Expression, len: Expression, pad: Expression)
* Returns str, right-padded with pad to a length of len.
*/
case class StringRPad(str: Expression, len: Expression, pad: Expression)
- extends Expression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes {
override def children: Seq[Expression] = str :: len :: pad :: Nil
- override def foldable: Boolean = children.forall(_.foldable)
- override def nullable: Boolean = children.exists(_.nullable)
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, StringType)
- override def eval(input: InternalRow): Any = {
- val s = str.eval(input)
- if (s == null) {
- null
- } else {
- val l = len.eval(input)
- if (l == null) {
- null
- } else {
- val p = pad.eval(input)
- if (p == null) {
- null
- } else {
- val len = l.asInstanceOf[Int]
- val str = s.asInstanceOf[UTF8String]
- val pad = p.asInstanceOf[UTF8String]
-
- str.rpad(len, pad)
- }
- }
- }
+ override def nullSafeEval(str: Any, len: Any, pad: Any): Any = {
+ str.asInstanceOf[UTF8String].rpad(len.asInstanceOf[Int], pad.asInstanceOf[UTF8String])
}
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val lenGen = len.gen(ctx)
- val strGen = str.gen(ctx)
- val padGen = pad.gen(ctx)
-
- s"""
- ${lenGen.code}
- boolean ${ev.isNull} = ${lenGen.isNull};
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${strGen.code}
- if (!${strGen.isNull}) {
- ${padGen.code}
- if (!${padGen.isNull}) {
- ${ev.primitive} = ${strGen.primitive}.rpad(${lenGen.primitive}, ${padGen.primitive});
- } else {
- ${ev.isNull} = true;
- }
- } else {
- ${ev.isNull} = true;
- }
- }
- """
+ defineCodeGen(ctx, ev, (str, len, pad) => s"$str.rpad($len, $pad)")
}
override def prettyName: String = "rpad"
@@ -745,68 +659,24 @@ case class StringSplit(str: Expression, pattern: Expression)
* Defined for String and Binary types.
*/
case class Substring(str: Expression, pos: Expression, len: Expression)
- extends Expression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes {
def this(str: Expression, pos: Expression) = {
this(str, pos, Literal(Integer.MAX_VALUE))
}
- override def foldable: Boolean = str.foldable && pos.foldable && len.foldable
- override def nullable: Boolean = str.nullable || pos.nullable || len.nullable
-
override def dataType: DataType = StringType
override def inputTypes: Seq[DataType] = Seq(StringType, IntegerType, IntegerType)
override def children: Seq[Expression] = str :: pos :: len :: Nil
- override def eval(input: InternalRow): Any = {
- val stringEval = str.eval(input)
- if (stringEval != null) {
- val posEval = pos.eval(input)
- if (posEval != null) {
- val lenEval = len.eval(input)
- if (lenEval != null) {
- stringEval.asInstanceOf[UTF8String]
- .substringSQL(posEval.asInstanceOf[Int], lenEval.asInstanceOf[Int])
- } else {
- null
- }
- } else {
- null
- }
- } else {
- null
- }
+ override def nullSafeEval(string: Any, pos: Any, len: Any): Any = {
+ string.asInstanceOf[UTF8String].substringSQL(pos.asInstanceOf[Int], len.asInstanceOf[Int])
}
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val strGen = str.gen(ctx)
- val posGen = pos.gen(ctx)
- val lenGen = len.gen(ctx)
-
- val start = ctx.freshName("start")
- val end = ctx.freshName("end")
-
- s"""
- ${strGen.code}
- boolean ${ev.isNull} = ${strGen.isNull};
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${ev.isNull}) {
- ${posGen.code}
- if (!${posGen.isNull}) {
- ${lenGen.code}
- if (!${lenGen.isNull}) {
- ${ev.primitive} = ${strGen.primitive}
- .substringSQL(${posGen.primitive}, ${lenGen.primitive});
- } else {
- ${ev.isNull} = true;
- }
- } else {
- ${ev.isNull} = true;
- }
- }
- """
+ defineCodeGen(ctx, ev, (str, pos, len) => s"$str.substringSQL($pos, $len)")
}
}
@@ -986,7 +856,7 @@ case class Encode(value: Expression, charset: Expression)
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
- extends Expression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes {
// last regex in string, we will update the pattern iff regexp value changed.
@transient private var lastRegex: UTF8String = _
@@ -998,40 +868,26 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
// result buffer write by Matcher
@transient private val result: StringBuffer = new StringBuffer
- override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable
- override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable
-
- override def eval(input: InternalRow): Any = {
- val s = subject.eval(input)
- if (null != s) {
- val p = regexp.eval(input)
- if (null != p) {
- val r = rep.eval(input)
- if (null != r) {
- if (!p.equals(lastRegex)) {
- // regex value changed
- lastRegex = p.asInstanceOf[UTF8String]
- pattern = Pattern.compile(lastRegex.toString)
- }
- if (!r.equals(lastReplacementInUTF8)) {
- // replacement string changed
- lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
- lastReplacement = lastReplacementInUTF8.toString
- }
- val m = pattern.matcher(s.toString())
- result.delete(0, result.length())
-
- while (m.find) {
- m.appendReplacement(result, lastReplacement)
- }
- m.appendTail(result)
+ override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
+ if (!p.equals(lastRegex)) {
+ // regex value changed
+ lastRegex = p.asInstanceOf[UTF8String]
+ pattern = Pattern.compile(lastRegex.toString)
+ }
+ if (!r.equals(lastReplacementInUTF8)) {
+ // replacement string changed
+ lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
+ lastReplacement = lastReplacementInUTF8.toString
+ }
+ val m = pattern.matcher(s.toString())
+ result.delete(0, result.length())
- return UTF8String.fromString(result.toString)
- }
- }
+ while (m.find) {
+ m.appendReplacement(result, lastReplacement)
}
+ m.appendTail(result)
- null
+ UTF8String.fromString(result.toString)
}
override def dataType: DataType = StringType
@@ -1048,59 +904,43 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
val termResult = ctx.freshName("result")
- val classNameUTF8String = classOf[UTF8String].getCanonicalName
val classNamePattern = classOf[Pattern].getCanonicalName
- val classNameString = classOf[java.lang.String].getCanonicalName
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
- ctx.addMutableState(classNameUTF8String,
+ ctx.addMutableState("UTF8String",
termLastRegex, s"${termLastRegex} = null;")
ctx.addMutableState(classNamePattern,
termPattern, s"${termPattern} = null;")
- ctx.addMutableState(classNameString,
+ ctx.addMutableState("String",
termLastReplacement, s"${termLastReplacement} = null;")
- ctx.addMutableState(classNameUTF8String,
+ ctx.addMutableState("UTF8String",
termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
ctx.addMutableState(classNameStringBuffer,
termResult, s"${termResult} = new $classNameStringBuffer();")
- val evalSubject = subject.gen(ctx)
- val evalRegexp = regexp.gen(ctx)
- val evalRep = rep.gen(ctx)
-
+ nullSafeCodeGen(ctx, ev, (subject, regexp, rep) => {
s"""
- ${evalSubject.code}
- boolean ${ev.isNull} = true;
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- if (!${evalSubject.isNull}) {
- ${evalRegexp.code}
- if (!${evalRegexp.isNull}) {
- ${evalRep.code}
- if (!${evalRep.isNull}) {
- if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
- // regex value changed
- ${termLastRegex} = ${evalRegexp.primitive};
- ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
- }
- if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) {
- // replacement string changed
- ${termLastReplacementInUTF8} = ${evalRep.primitive};
- ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
- }
- ${termResult}.delete(0, ${termResult}.length());
- ${classOf[java.util.regex.Matcher].getCanonicalName} m =
- ${termPattern}.matcher(${evalSubject.primitive}.toString());
+ if (!$regexp.equals(${termLastRegex})) {
+ // regex value changed
+ ${termLastRegex} = $regexp;
+ ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
+ }
+ if (!$rep.equals(${termLastReplacementInUTF8})) {
+ // replacement string changed
+ ${termLastReplacementInUTF8} = $rep;
+ ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
+ }
+ ${termResult}.delete(0, ${termResult}.length());
+ java.util.regex.Matcher m = ${termPattern}.matcher($subject.toString());
- while (m.find()) {
- m.appendReplacement(${termResult}, ${termLastReplacement});
- }
- m.appendTail(${termResult});
- ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString());
- ${ev.isNull} = false;
- }
- }
+ while (m.find()) {
+ m.appendReplacement(${termResult}, ${termLastReplacement});
}
+ m.appendTail(${termResult});
+ ${ev.primitive} = UTF8String.fromString(${termResult}.toString());
+ ${ev.isNull} = false;
"""
+ })
}
}
@@ -1110,7 +950,7 @@ case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expressio
* NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
*/
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
- extends Expression with ImplicitCastInputTypes {
+ extends TernaryExpression with ImplicitCastInputTypes {
def this(s: Expression, r: Expression) = this(s, r, Literal(1))
// last regex in string, we will update the pattern iff regexp value changed.
@@ -1118,32 +958,19 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
// last regex pattern, we cache it for performance concern
@transient private var pattern: Pattern = _
- override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable
- override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable
-
- override def eval(input: InternalRow): Any = {
- val s = subject.eval(input)
- if (null != s) {
- val p = regexp.eval(input)
- if (null != p) {
- val r = idx.eval(input)
- if (null != r) {
- if (!p.equals(lastRegex)) {
- // regex value changed
- lastRegex = p.asInstanceOf[UTF8String]
- pattern = Pattern.compile(lastRegex.toString)
- }
- val m = pattern.matcher(s.toString())
- if (m.find) {
- val mr: MatchResult = m.toMatchResult
- return UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
- }
- return UTF8String.EMPTY_UTF8
- }
- }
+ override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
+ if (!p.equals(lastRegex)) {
+ // regex value changed
+ lastRegex = p.asInstanceOf[UTF8String]
+ pattern = Pattern.compile(lastRegex.toString)
+ }
+ val m = pattern.matcher(s.toString())
+ if (m.find) {
+ val mr: MatchResult = m.toMatchResult
+ UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
+ } else {
+ UTF8String.EMPTY_UTF8
}
-
- null
}
override def dataType: DataType = StringType
@@ -1154,44 +981,29 @@ case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expressio
override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
val termLastRegex = ctx.freshName("lastRegex")
val termPattern = ctx.freshName("pattern")
- val classNameUTF8String = classOf[UTF8String].getCanonicalName
val classNamePattern = classOf[Pattern].getCanonicalName
- ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;")
+ ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
- val evalSubject = subject.gen(ctx)
- val evalRegexp = regexp.gen(ctx)
- val evalIdx = idx.gen(ctx)
-
- s"""
- ${evalSubject.code}
- ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
- boolean ${ev.isNull} = true;
- if (!${evalSubject.isNull}) {
- ${evalRegexp.code}
- if (!${evalRegexp.isNull}) {
- ${evalIdx.code}
- if (!${evalIdx.isNull}) {
- if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
- // regex value changed
- ${termLastRegex} = ${evalRegexp.primitive};
- ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
- }
- ${classOf[java.util.regex.Matcher].getCanonicalName} m =
- ${termPattern}.matcher(${evalSubject.primitive}.toString());
- if (m.find()) {
- ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult();
- ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive}));
- ${ev.isNull} = false;
- } else {
- ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8;
- ${ev.isNull} = false;
- }
- }
- }
+ nullSafeCodeGen(ctx, ev, (subject, regexp, idx) => {
+ s"""
+ if (!$regexp.equals(${termLastRegex})) {
+ // regex value changed
+ ${termLastRegex} = $regexp;
+ ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
- """
+ java.util.regex.Matcher m =
+ ${termPattern}.matcher($subject.toString());
+ if (m.find()) {
+ java.util.regex.MatchResult mr = m.toMatchResult();
+ ${ev.primitive} = UTF8String.fromString(mr.group($idx));
+ ${ev.isNull} = false;
+ } else {
+ ${ev.primitive} = UTF8String.EMPTY_UTF8;
+ ${ev.isNull} = false;
+ }"""
+ })
}
}