diff options
author | Cheng Hao <hao.cheng@intel.com> | 2015-07-21 00:48:07 -0700 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2015-07-21 00:48:07 -0700 |
commit | 8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369 (patch) | |
tree | f4a7c28f662757bc341407642c7f90357f1d4b79 | |
parent | d38c5029a2ca845e2782096044a6412b653c9f95 (diff) | |
download | spark-8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369.tar.gz spark-8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369.tar.bz2 spark-8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369.zip |
[SPARK-8255] [SPARK-8256] [SQL] Add regexp_extract/regexp_replace
Add expressions `regexp_extract` & `regexp_replace`
Author: Cheng Hao <hao.cheng@intel.com>
Closes #7468 from chenghao-intel/regexp and squashes the following commits:
e5ea476 [Cheng Hao] minor update for documentation
ef96fd6 [Cheng Hao] update the code gen
72cf28f [Cheng Hao] Add more log for compilation error
4e11381 [Cheng Hao] Add regexp_replace / regexp_extract support
8 files changed, 323 insertions, 4 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 031745a1c4..3c134faa0a 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -46,6 +46,8 @@ __all__ = [ 'monotonicallyIncreasingId', 'rand', 'randn', + 'regexp_extract', + 'regexp_replace', 'sha1', 'sha2', 'sparkPartitionId', @@ -345,6 +347,34 @@ def levenshtein(left, right): @ignore_unicode_prefix @since(1.5) +def regexp_extract(str, pattern, idx): + """Extract a specific(idx) group identified by a java regex, from the specified string column. + + >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect() + [Row(d=u'100')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) +def regexp_replace(str, pattern, replacement): + """Replace all substrings of the specified string value that match regexp with rep. + + >>> df = sqlContext.createDataFrame([('100-200',)], ['str']) + >>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect() + [Row(d=u'##-##')] + """ + sc = SparkContext._active_spark_context + jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement) + return Column(jc) + + +@ignore_unicode_prefix +@since(1.5) def md5(col): """Calculates the MD5 digest and returns the value as a 32 character hex string. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index 71e87b98d8..aec392379c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -161,6 +161,8 @@ object FunctionRegistry { expression[Lower]("lower"), expression[Length]("length"), expression[Levenshtein]("levenshtein"), + expression[RegExpExtract]("regexp_extract"), + expression[RegExpReplace]("regexp_replace"), expression[StringInstr]("instr"), expression[StringLocate]("locate"), expression[StringLPad]("lpad"), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 606f770cb4..319dcd1c04 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -297,8 +297,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin evaluator.cook(code) } catch { case e: Exception => - logError(s"failed to compile:\n $code", e) - throw e + val msg = s"failed to compile:\n $code" + logError(msg, e) + throw new Exception(msg, e) } evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass] } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala index 92fefe1585..fe57d17f1e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions import java.text.DecimalFormat import java.util.Locale -import java.util.regex.Pattern +import java.util.regex.{MatchResult, Pattern} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.UnresolvedException @@ -877,6 +877,221 @@ case class Encode(value: Expression, charset: Expression) } /** + * Replace all substrings of str that match regexp with rep. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. + */ +case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression) + extends Expression with ImplicitCastInputTypes { + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + // last replacement string, we don't want to convert a UTF8String => java.langString every time. 
+ @transient private var lastReplacement: String = _ + @transient private var lastReplacementInUTF8: UTF8String = _ + // result buffer write by Matcher + @transient private val result: StringBuffer = new StringBuffer + + override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable + override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable + + override def eval(input: InternalRow): Any = { + val s = subject.eval(input) + if (null != s) { + val p = regexp.eval(input) + if (null != p) { + val r = rep.eval(input) + if (null != r) { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + if (!r.equals(lastReplacementInUTF8)) { + // replacement string changed + lastReplacementInUTF8 = r.asInstanceOf[UTF8String] + lastReplacement = lastReplacementInUTF8.toString + } + val m = pattern.matcher(s.toString()) + result.delete(0, result.length()) + + while (m.find) { + m.appendReplacement(result, lastReplacement) + } + m.appendTail(result) + + return UTF8String.fromString(result.toString) + } + } + } + + null + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType) + override def children: Seq[Expression] = subject :: regexp :: rep :: Nil + override def prettyName: String = "regexp_replace" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + + val termLastReplacement = ctx.freshName("lastReplacement") + val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8") + + val termResult = ctx.freshName("result") + + val classNameUTF8String = classOf[UTF8String].getCanonicalName + val classNamePattern = classOf[Pattern].getCanonicalName + val classNameString = classOf[java.lang.String].getCanonicalName + 
val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName + + ctx.addMutableState(classNameUTF8String, + termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, + termPattern, s"${termPattern} = null;") + ctx.addMutableState(classNameString, + termLastReplacement, s"${termLastReplacement} = null;") + ctx.addMutableState(classNameUTF8String, + termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;") + ctx.addMutableState(classNameStringBuffer, + termResult, s"${termResult} = new $classNameStringBuffer();") + + val evalSubject = subject.gen(ctx) + val evalRegexp = regexp.gen(ctx) + val evalRep = rep.gen(ctx) + + s""" + ${evalSubject.code} + boolean ${ev.isNull} = ${evalSubject.isNull}; + ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)}; + if (!${evalSubject.isNull}) { + ${evalRegexp.code} + if (!${evalRegexp.isNull}) { + ${evalRep.code} + if (!${evalRep.isNull}) { + if (!${evalRegexp.primitive}.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = ${evalRegexp.primitive}; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) { + // replacement string changed + ${termLastReplacementInUTF8} = ${evalRep.primitive}; + ${termLastReplacement} = ${termLastReplacementInUTF8}.toString(); + } + ${termResult}.delete(0, ${termResult}.length()); + ${classOf[java.util.regex.Matcher].getCanonicalName} m = + ${termPattern}.matcher(${evalSubject.primitive}.toString()); + + while (m.find()) { + m.appendReplacement(${termResult}, ${termLastReplacement}); + } + m.appendTail(${termResult}); + ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString()); + ${ev.isNull} = false; + } + } + } + """ + } +} + +/** + * Extract a specific(idx) group identified by a Java regex. + * + * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable status. 
+ */ +case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression) + extends Expression with ImplicitCastInputTypes { + def this(s: Expression, r: Expression) = this(s, r, Literal(1)) + + // last regex in string, we will update the pattern iff regexp value changed. + @transient private var lastRegex: UTF8String = _ + // last regex pattern, we cache it for performance concern + @transient private var pattern: Pattern = _ + + override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable + override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable + + override def eval(input: InternalRow): Any = { + val s = subject.eval(input) + if (null != s) { + val p = regexp.eval(input) + if (null != p) { + val r = idx.eval(input) + if (null != r) { + if (!p.equals(lastRegex)) { + // regex value changed + lastRegex = p.asInstanceOf[UTF8String] + pattern = Pattern.compile(lastRegex.toString) + } + val m = pattern.matcher(s.toString()) + if (m.find) { + val mr: MatchResult = m.toMatchResult + return UTF8String.fromString(mr.group(r.asInstanceOf[Int])) + } + return UTF8String.EMPTY_UTF8 + } + } + } + + null + } + + override def dataType: DataType = StringType + override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType) + override def children: Seq[Expression] = subject :: regexp :: idx :: Nil + override def prettyName: String = "regexp_extract" + + override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = { + val termLastRegex = ctx.freshName("lastRegex") + val termPattern = ctx.freshName("pattern") + val classNameUTF8String = classOf[UTF8String].getCanonicalName + val classNamePattern = classOf[Pattern].getCanonicalName + + ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;") + ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;") + + val evalSubject = subject.gen(ctx) + val evalRegexp = 
regexp.gen(ctx) + val evalIdx = idx.gen(ctx) + + s""" + ${ctx.javaType(dataType)} ${ev.primitive} = null; + boolean ${ev.isNull} = true; + ${evalSubject.code} + if (!${evalSubject.isNull}) { + ${evalRegexp.code} + if (!${evalRegexp.isNull}) { + ${evalIdx.code} + if (!${evalIdx.isNull}) { + if (!${evalRegexp.primitive}.equals(${termLastRegex})) { + // regex value changed + ${termLastRegex} = ${evalRegexp.primitive}; + ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString()); + } + ${classOf[java.util.regex.Matcher].getCanonicalName} m = + ${termPattern}.matcher(${evalSubject.primitive}.toString()); + if (m.find()) { + ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult(); + ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive})); + ${ev.isNull} = false; + } else { + ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8; + ${ev.isNull} = false; + } + } + } + } + """ + } +} + +/** * Formats the number X to a format like '#,###,###.##', rounded to D decimal places, * and returns the result as a string. If D is 0, the result has no decimal point or * fractional part. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index 7a96044d35..6e17ffcda9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -79,7 +79,6 @@ trait ExpressionEvalHelper { fail( s""" |Code generation of $expression failed: - |${evaluated.code} |$e """.stripMargin) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 67d97cd30b..96c540ab36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -464,6 +464,41 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(StringSpace(s1), null, row2) } + test("RegexReplace") { + val row1 = create_row("100-200", "(\\d+)", "num") + val row2 = create_row("100-200", "(\\d+)", "###") + val row3 = create_row("100-200", "(-)", "###") + + val s = 's.string.at(0) + val p = 'p.string.at(1) + val r = 'r.string.at(2) + + val expr = RegExpReplace(s, p, r) + checkEvaluation(expr, "num-num", row1) + checkEvaluation(expr, "###-###", row2) + checkEvaluation(expr, "100###200", row3) + } + + test("RegexExtract") { + val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1) + val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2) + val row3 = create_row("100-200", "(\\d+).*", 1) + val row4 = create_row("100-200", "([a-z])", 1) + + val s = 's.string.at(0) + val p = 'p.string.at(1) + val r = 'r.int.at(2) + + val expr = RegExpExtract(s, p, r) + checkEvaluation(expr, 
"100", row1) + checkEvaluation(expr, "200", row2) + checkEvaluation(expr, "100", row3) + checkEvaluation(expr, "", row4) // will not match anything, empty string get + + val expr1 = new RegExpExtract(s, p) + checkEvaluation(expr1, "100", row1) + } + test("SPLIT") { val s1 = 'a.string.at(0) val s2 = 'b.string.at(1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 8fa017610b..6d60dae624 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1781,6 +1781,27 @@ object functions { StringLocate(lit(substr).expr, str.expr, lit(pos).expr) } + + /** + * Extract a specific(idx) group identified by a java regex, from the specified string column. + * + * @group string_funcs + * @since 1.5.0 + */ + def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = { + RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr) + } + + /** + * Replace all substrings of the specified string value that match regexp with rep. + * + * @group string_funcs + * @since 1.5.0 + */ + def regexp_replace(e: Column, pattern: String, replacement: String): Column = { + RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr) + } + /** * Computes the BASE64 encoding of a binary column and returns it as a string column. * This is the reverse of unbase64. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala index 4551192b15..d1f855903c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala @@ -56,6 +56,22 @@ class StringFunctionsSuite extends QueryTest { checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1))) } + test("string regex_replace / regex_extract") { + val df = Seq(("100-200", "")).toDF("a", "b") + + checkAnswer( + df.select( + regexp_replace($"a", "(\\d+)", "num"), + regexp_extract($"a", "(\\d+)-(\\d+)", 1)), + Row("num-num", "100")) + + checkAnswer( + df.selectExpr( + "regexp_replace(a, '(\\d+)', 'num')", + "regexp_extract(a, '(\\d+)-(\\d+)', 2)"), + Row("num-num", "200")) + } + test("string ascii function") { val df = Seq(("abc", "")).toDF("a", "b") checkAnswer( |