/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
import org.apache.commons.lang3.StringEscapeUtils
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.util.{GenericArrayData, StringUtils}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
/**
 * Base class for binary string predicates whose right operand is compiled into a
 * [[java.util.regex.Pattern]] and applied to the left operand.
 *
 * Subclasses provide `escape`, which turns the raw right-hand string into regex source, and
 * `matches`, which decides how the compiled pattern is applied to the left-hand string.
 * The result type is boolean; both inputs are implicitly cast to strings.
 */
abstract class StringRegexExpression extends BinaryExpression
  with ImplicitCastInputTypes with NullIntolerant {

  /** Converts the raw pattern string `v` into the regex source passed to `Pattern.compile`. */
  def escape(v: String): String

  /** Applies the compiled `regex` to `str` and reports whether it is considered a match. */
  def matches(regex: Pattern, str: String): Boolean

  override def dataType: DataType = BooleanType

  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

  // Compile the pattern once when the right operand is a string literal.
  // StringType values are handled as UTF8String at evaluation time (see the casts in
  // nullSafeEval), so the UTF8String case is the one that makes the cache effective; the
  // java.lang.String case is kept so a literal carrying a raw String is still cached.
  private lazy val cache: Pattern = right match {
    case Literal(value: UTF8String, StringType) => compile(value.toString)
    case Literal(value: String, StringType) => compile(value)
    case _ => null
  }

  /**
   * Compiles `str` (after `escape`) into a pattern.
   *
   * @return the compiled pattern, or null when `str` is null
   */
  protected def compile(str: String): Pattern = if (str == null) {
    null
  } else {
    // Let it raise exception if couldn't compile the regex string
    Pattern.compile(escape(str))
  }

  /** Returns the cached literal pattern when there is one, otherwise compiles `str`. */
  protected def pattern(str: String): Pattern = if (cache == null) compile(str) else cache

  /**
   * Evaluates the predicate for non-null inputs.
   *
   * @param input1 the left (subject) value, a UTF8String
   * @param input2 the right (pattern) value, a UTF8String
   * @return the boolean match result, or null when no pattern could be produced
   */
  protected override def nullSafeEval(input1: Any, input2: Any): Any = {
    val regex = pattern(input2.asInstanceOf[UTF8String].toString)
    if (regex == null) {
      null
    } else {
      matches(regex, input1.asInstanceOf[UTF8String].toString)
    }
  }

  // Rendered as an infix operator, e.g. `a LIKE b`, using the upper-cased pretty name.
  override def sql: String = s"${left.sql} ${prettyName.toUpperCase(Locale.ROOT)} ${right.sql}"
}
/**
 * Simple RegEx pattern matching function
 *
 * The LIKE pattern is translated to a Java regex by `StringUtils.escapeLikeRegex` and must
 * match the whole input string.
 */
@ExpressionDescription(
  usage = "str _FUNC_ pattern - Returns true if str matches pattern, " +
    "null if any arguments are null, false otherwise.",
  extended = """
Arguments:
str - a string expression
pattern - a string expression. The pattern is a string which is matched literally, with
exception to the following special symbols:
_ matches any one character in the input (similar to . in posix regular expressions)
% matches zero or more characters in the input (similar to .* in posix regular
expressions)
The escape character is '\'. If an escape character precedes a special symbol or another
escape character, the following character is matched literally. It is invalid to escape
any other character.
Examples:
> SELECT '%SystemDrive%\Users\John' _FUNC_ '\%SystemDrive\%\\Users%'
true
See also:
Use RLIKE to match with standard regular expressions.
""")
case class Like(left: Expression, right: Expression) extends StringRegexExpression {

  /** Translates the LIKE pattern `v` into regex source. */
  override def escape(v: String): String = StringUtils.escapeLikeRegex(v)

  /** A LIKE pattern has to match the entire input, hence `matches()` rather than `find()`. */
  override def matches(regex: Pattern, str: String): Boolean = regex.matcher(str).matches()

  override def toString: String = s"$left LIKE $right"

  /**
   * Generates Java code for the predicate.
   *
   * When the pattern is foldable it is evaluated once here: a null pattern yields code that
   * always produces null, otherwise the compiled pattern is stored in a mutable state field.
   * For a non-foldable pattern the regex is escaped and compiled in the generated code for
   * every evaluation.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val patternClass = classOf[Pattern].getName
    val pattern = ctx.freshName("pattern")

    if (right.foldable) {
      val rVal = right.eval()
      if (rVal != null) {
        // Escape twice: first LIKE -> regex, then regex -> Java string literal.
        val regexStr =
          StringEscapeUtils.escapeJava(escape(rVal.asInstanceOf[UTF8String].toString()))
        ctx.addMutableState(patternClass, pattern,
          s"""$pattern = ${patternClass}.compile("$regexStr");""")

        // We don't use nullSafeCodeGen here because we don't want to re-evaluate right again.
        val eval = left.genCode(ctx)
        ev.copy(code = s"""
${eval.code}
boolean ${ev.isNull} = ${eval.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $pattern.matcher(${eval.value}.toString()).matches();
}
""")
      } else {
        // Constant null pattern: the result is always null.
        ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
""")
      }
    } else {
      // Only the dynamic path needs to call the escape helper from generated code.
      val escapeFunc = StringUtils.getClass.getName.stripSuffix("$") + ".escapeLikeRegex"
      val rightStr = ctx.freshName("rightStr")
      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
        s"""
String $rightStr = ${eval2}.toString();
${patternClass} $pattern = ${patternClass}.compile($escapeFunc($rightStr));
${ev.value} = $pattern.matcher(${eval1}.toString()).matches();
"""
      })
    }
  }
}
@ExpressionDescription(
  usage = "str _FUNC_ regexp - Returns true if `str` matches `regexp`, or false otherwise.")
case class RLike(left: Expression, right: Expression) extends StringRegexExpression {

  /** The right operand is already a regular expression, so it is used verbatim. */
  override def escape(v: String): String = v

  /** Succeeds when the pattern is found anywhere in `str`, searching from index 0. */
  override def matches(regex: Pattern, str: String): Boolean = {
    val matcher = regex.matcher(str)
    matcher.find(0)
  }

  override def toString: String = s"$left RLIKE $right"

  /**
   * Emits Java code for the predicate.
   *
   * A non-foldable regex is compiled inside the generated code on each evaluation. A foldable
   * regex is evaluated once here: null produces code that always yields null, any other value
   * is compiled into a mutable state field that the generated code reuses.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val regexClass = classOf[Pattern].getName
    val compiled = ctx.freshName("pattern")

    if (!right.foldable) {
      val patternText = ctx.freshName("rightStr")
      nullSafeCodeGen(ctx, ev, (subject, regex) => {
        s"""
String $patternText = ${regex}.toString();
${regexClass} $compiled = ${regexClass}.compile($patternText);
${ev.value} = $compiled.matcher(${subject}.toString()).find(0);
"""
      })
    } else {
      right.eval() match {
        case null =>
          // Constant null regex: the result is always null.
          ev.copy(code = s"""
boolean ${ev.isNull} = true;
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
""")

        case constant =>
          // Embed the regex as a Java string literal and compile it once as mutable state.
          val javaLiteral =
            StringEscapeUtils.escapeJava(constant.asInstanceOf[UTF8String].toString())
          ctx.addMutableState(regexClass, compiled,
            s"""$compiled = ${regexClass}.compile("$javaLiteral");""")

          // nullSafeCodeGen is avoided on purpose so the right side is not evaluated again.
          val subject = left.genCode(ctx)
          ev.copy(code = s"""
${subject.code}
boolean ${ev.isNull} = ${subject.isNull};
${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
if (!${ev.isNull}) {
${ev.value} = $compiled.matcher(${subject.value}.toString()).find(0);
}
""")
      }
    }
  }
}
/**
 * Splits `str` around matches of `pattern`, which is interpreted as a regular expression.
 * The split is performed with a limit of -1.
 */
@ExpressionDescription(
  usage = "_FUNC_(str, regex) - Splits `str` around occurrences that match `regex`.",
  extended = """
Examples:
> SELECT _FUNC_('oneAtwoBthreeC', '[ABC]');
["one","two","three",""]
""")
case class StringSplit(str: Expression, pattern: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def left: Expression = str

  override def right: Expression = pattern

  override def prettyName: String = "split"

  override def inputTypes: Seq[DataType] = Seq(StringType, StringType)

  override def dataType: DataType = ArrayType(StringType)

  /**
   * Splits the non-null input and wraps the pieces in array data.
   *
   * @param string the value to split, a UTF8String
   * @param regex  the delimiter regex, a UTF8String
   */
  override def nullSafeEval(string: Any, regex: Any): Any = {
    val input = string.asInstanceOf[UTF8String]
    val delimiter = regex.asInstanceOf[UTF8String]
    val pieces = input.split(delimiter, -1)
    new GenericArrayData(pieces.asInstanceOf[Array[Any]])
  }

  /** Emits a single assignment that performs the same split in generated Java code. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val arrayClass = classOf[GenericArrayData].getName
    nullSafeCodeGen(ctx, ev, (input, delimiter) => {
      // Array in java is covariant, so we don't need to cast UTF8String[] to Object[].
      s"""${ev.value} = new $arrayClass($input.split($delimiter, -1));"""
    })
  }
}
/**
 * Replace all substrings of str that match regexp with rep.
 *
 * The most recently used regex and replacement are remembered so they are only converted
 * again when their values change between rows.
 *
 * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(str, regexp, rep) - Replaces all substrings of `str` that match `regexp` with `rep`.",
  extended = """
Examples:
> SELECT _FUNC_('100-200', '(\d+)', 'num');
num-num
""")
// scalastyle:on line.size.limit
case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
  extends TernaryExpression with ImplicitCastInputTypes {

  // Regex value seen on the previous evaluation; the pattern is rebuilt only when it changes.
  @transient private var lastRegex: UTF8String = _
  // Compiled form of `lastRegex`, kept to avoid recompiling for every row.
  @transient private var pattern: Pattern = _
  // Replacement seen on the previous evaluation, as a java.lang.String and as UTF8String,
  // so the UTF8String => String conversion is not repeated for every row.
  @transient private var lastReplacement: String = _
  @transient private var lastReplacementInUTF8: UTF8String = _
  // Scratch buffer the Matcher writes the output into.
  @transient private lazy val result: StringBuffer = new StringBuffer

  override def dataType: DataType = StringType

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)

  override def children: Seq[Expression] = subject :: regexp :: rep :: Nil

  override def prettyName: String = "regexp_replace"

  /**
   * Performs the replacement for non-null inputs.
   *
   * @param s the subject string
   * @param p the regex, a UTF8String
   * @param r the replacement, a UTF8String
   * @return the subject with every match replaced, as a UTF8String
   */
  override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
    val regexChanged = !p.equals(lastRegex)
    if (regexChanged) {
      lastRegex = p.asInstanceOf[UTF8String].clone()
      pattern = Pattern.compile(lastRegex.toString)
    }

    val replacementChanged = !r.equals(lastReplacementInUTF8)
    if (replacementChanged) {
      lastReplacementInUTF8 = r.asInstanceOf[UTF8String].clone()
      lastReplacement = lastReplacementInUTF8.toString
    }

    val matcher = pattern.matcher(s.toString())
    // Empty the shared buffer before writing this row's output.
    result.setLength(0)
    while (matcher.find()) {
      matcher.appendReplacement(result, lastReplacement)
    }
    matcher.appendTail(result)
    UTF8String.fromString(result.toString)
  }

  /**
   * Emits Java code that mirrors `nullSafeEval`, keeping the last regex, its compiled
   * pattern, the last replacement and the output buffer in mutable state fields.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Names are requested in a fixed order so the generated identifiers stay stable.
    val termLastRegex = ctx.freshName("lastRegex")
    val termPattern = ctx.freshName("pattern")
    val termLastReplacement = ctx.freshName("lastReplacement")
    val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")
    val termResult = ctx.freshName("result")

    val classNamePattern = classOf[Pattern].getCanonicalName
    val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName

    val matcher = ctx.freshName("matcher")

    ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
    ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
    ctx.addMutableState("String", termLastReplacement, s"${termLastReplacement} = null;")
    ctx.addMutableState("UTF8String",
      termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
    ctx.addMutableState(classNameStringBuffer,
      termResult, s"${termResult} = new $classNameStringBuffer();")

    // Only emit the isNull reset when the expression can be null at all.
    val setEvNotNull = if (nullable) s"${ev.isNull} = false;" else ""

    nullSafeCodeGen(ctx, ev, (input, regex, replacement) => {
      s"""
if (!$regex.equals(${termLastRegex})) {
// regex value changed
${termLastRegex} = $regex.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
if (!$replacement.equals(${termLastReplacementInUTF8})) {
// replacement string changed
${termLastReplacementInUTF8} = $replacement.clone();
${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
}
${termResult}.delete(0, ${termResult}.length());
java.util.regex.Matcher ${matcher} = ${termPattern}.matcher($input.toString());
while (${matcher}.find()) {
${matcher}.appendReplacement(${termResult}, ${termLastReplacement});
}
${matcher}.appendTail(${termResult});
${ev.value} = UTF8String.fromString(${termResult}.toString());
$setEvNotNull
"""
    })
  }
}
/**
 * Extract a specific(idx) group identified by a Java regex.
 *
 * Returns the empty string when the regex does not match or the requested group is null.
 *
 * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status.
 */
@ExpressionDescription(
  usage = "_FUNC_(str, regexp[, idx]) - Extracts a group that matches `regexp`.",
  extended = """
Examples:
> SELECT _FUNC_('100-200', '(\d+)-(\d+)', 1);
100
""")
case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
  extends TernaryExpression with ImplicitCastInputTypes {

  /** Two-argument form: the group index defaults to 1. */
  def this(s: Expression, r: Expression) = this(s, r, Literal(1))

  // Regex value seen on the previous evaluation; the pattern is rebuilt only when it changes.
  @transient private var lastRegex: UTF8String = _
  // Compiled form of `lastRegex`, kept to avoid recompiling for every row.
  @transient private var pattern: Pattern = _

  override def dataType: DataType = StringType

  override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)

  override def children: Seq[Expression] = subject :: regexp :: idx :: Nil

  override def prettyName: String = "regexp_extract"

  /**
   * Extracts the requested group for non-null inputs.
   *
   * @param s the subject string
   * @param p the regex, a UTF8String
   * @param r the group index, an Int
   * @return the group text as a UTF8String, or the empty UTF8String when there is no match
   *         or the group is null
   */
  override def nullSafeEval(s: Any, p: Any, r: Any): Any = {
    val regexChanged = !p.equals(lastRegex)
    if (regexChanged) {
      lastRegex = p.asInstanceOf[UTF8String].clone()
      pattern = Pattern.compile(lastRegex.toString)
    }

    val matcher = pattern.matcher(s.toString)
    if (!matcher.find()) {
      UTF8String.EMPTY_UTF8
    } else {
      val found: MatchResult = matcher.toMatchResult
      // A null group means the regex matched overall but this group did not take part.
      Option(found.group(r.asInstanceOf[Int])) match {
        case Some(text) => UTF8String.fromString(text)
        case None => UTF8String.EMPTY_UTF8
      }
    }
  }

  /**
   * Emits Java code that mirrors `nullSafeEval`, keeping the last regex and its compiled
   * pattern in mutable state fields.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    // Names are requested in a fixed order so the generated identifiers stay stable.
    val termLastRegex = ctx.freshName("lastRegex")
    val termPattern = ctx.freshName("pattern")
    val classNamePattern = classOf[Pattern].getCanonicalName
    val matcher = ctx.freshName("matcher")
    val matchResult = ctx.freshName("matchResult")

    ctx.addMutableState("UTF8String", termLastRegex, s"${termLastRegex} = null;")
    ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")

    // Only emit the isNull reset when the expression can be null at all.
    val setEvNotNull = if (nullable) s"${ev.isNull} = false;" else ""

    nullSafeCodeGen(ctx, ev, (input, regex, groupIdx) => {
      s"""
if (!$regex.equals(${termLastRegex})) {
// regex value changed
${termLastRegex} = $regex.clone();
${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
}
java.util.regex.Matcher ${matcher} =
${termPattern}.matcher($input.toString());
if (${matcher}.find()) {
java.util.regex.MatchResult ${matchResult} = ${matcher}.toMatchResult();
if (${matchResult}.group($groupIdx) == null) {
${ev.value} = UTF8String.EMPTY_UTF8;
} else {
${ev.value} = UTF8String.fromString(${matchResult}.group($groupIdx));
}
$setEvNotNull
} else {
${ev.value} = UTF8String.EMPTY_UTF8;
$setEvNotNull
}"""
    })
  }
}