aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-07-21 00:48:07 -0700
committerDavies Liu <davies.liu@gmail.com>2015-07-21 00:48:07 -0700
commit8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369 (patch)
treef4a7c28f662757bc341407642c7f90357f1d4b79
parentd38c5029a2ca845e2782096044a6412b653c9f95 (diff)
downloadspark-8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369.tar.gz
spark-8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369.tar.bz2
spark-8c8f0ef59e12b6f13d5a0bf2d7bf1248b5c1e369.zip
[SPARK-8255] [SPARK-8256] [SQL] Add regex_extract/regex_replace
Add expressions `regex_extract` & `regex_replace` Author: Cheng Hao <hao.cheng@intel.com> Closes #7468 from chenghao-intel/regexp and squashes the following commits: e5ea476 [Cheng Hao] minor update for documentation ef96fd6 [Cheng Hao] update the code gen 72cf28f [Cheng Hao] Add more log for compilation error 4e11381 [Cheng Hao] Add regexp_replace / regexp_extract support
-rw-r--r--python/pyspark/sql/functions.py30
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala5
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala217
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala1
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala35
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala21
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala16
8 files changed, 323 insertions, 4 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 031745a1c4..3c134faa0a 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -46,6 +46,8 @@ __all__ = [
'monotonicallyIncreasingId',
'rand',
'randn',
+ 'regexp_extract',
+ 'regexp_replace',
'sha1',
'sha2',
'sparkPartitionId',
@@ -345,6 +347,34 @@ def levenshtein(left, right):
@ignore_unicode_prefix
@since(1.5)
+def regexp_extract(str, pattern, idx):
+ """Extract a specific(idx) group identified by a java regex, from the specified string column.
+
+ >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
+ >>> df.select(regexp_extract('str', '(\d+)-(\d+)', 1).alias('d')).collect()
+ [Row(d=u'100')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.regexp_extract(_to_java_column(str), pattern, idx)
+ return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
+def regexp_replace(str, pattern, replacement):
+ """Replace all substrings of the specified string value that match regexp with rep.
+
+ >>> df = sqlContext.createDataFrame([('100-200',)], ['str'])
+ >>> df.select(regexp_replace('str', '(\\d+)', '##').alias('d')).collect()
+ [Row(d=u'##-##')]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.regexp_replace(_to_java_column(str), pattern, replacement)
+ return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
def md5(col):
"""Calculates the MD5 digest and returns the value as a 32 character hex string.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 71e87b98d8..aec392379c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -161,6 +161,8 @@ object FunctionRegistry {
expression[Lower]("lower"),
expression[Length]("length"),
expression[Levenshtein]("levenshtein"),
+ expression[RegExpExtract]("regexp_extract"),
+ expression[RegExpReplace]("regexp_replace"),
expression[StringInstr]("instr"),
expression[StringLocate]("locate"),
expression[StringLPad]("lpad"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
index 606f770cb4..319dcd1c04 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -297,8 +297,9 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
evaluator.cook(code)
} catch {
case e: Exception =>
- logError(s"failed to compile:\n $code", e)
- throw e
+ val msg = s"failed to compile:\n $code"
+ logError(msg, e)
+ throw new Exception(msg, e)
}
evaluator.getClazz().newInstance().asInstanceOf[GeneratedClass]
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 92fefe1585..fe57d17f1e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.text.DecimalFormat
import java.util.Locale
-import java.util.regex.Pattern
+import java.util.regex.{MatchResult, Pattern}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
@@ -877,6 +877,221 @@ case class Encode(value: Expression, charset: Expression)
}
/**
+ * Replace all substrings of str that match regexp with rep.
+ *
+ * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable state.
+ */
+case class RegExpReplace(subject: Expression, regexp: Expression, rep: Expression)
+ extends Expression with ImplicitCastInputTypes {
+
+ // last regex in string, we will update the pattern iff regexp value changed.
+ @transient private var lastRegex: UTF8String = _
+ // last regex pattern, we cache it for performance concern
+ @transient private var pattern: Pattern = _
+ // last replacement string; we don't want to convert a UTF8String => java.lang.String every time.
+ @transient private var lastReplacement: String = _
+ @transient private var lastReplacementInUTF8: UTF8String = _
+ // result buffer written by the Matcher
+ @transient private val result: StringBuffer = new StringBuffer
+
+ override def nullable: Boolean = subject.nullable || regexp.nullable || rep.nullable
+ override def foldable: Boolean = subject.foldable && regexp.foldable && rep.foldable
+
+ override def eval(input: InternalRow): Any = {
+ val s = subject.eval(input)
+ if (null != s) {
+ val p = regexp.eval(input)
+ if (null != p) {
+ val r = rep.eval(input)
+ if (null != r) {
+ if (!p.equals(lastRegex)) {
+ // regex value changed
+ lastRegex = p.asInstanceOf[UTF8String]
+ pattern = Pattern.compile(lastRegex.toString)
+ }
+ if (!r.equals(lastReplacementInUTF8)) {
+ // replacement string changed
+ lastReplacementInUTF8 = r.asInstanceOf[UTF8String]
+ lastReplacement = lastReplacementInUTF8.toString
+ }
+ val m = pattern.matcher(s.toString())
+ result.delete(0, result.length())
+
+ while (m.find) {
+ m.appendReplacement(result, lastReplacement)
+ }
+ m.appendTail(result)
+
+ return UTF8String.fromString(result.toString)
+ }
+ }
+ }
+
+ null
+ }
+
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, StringType)
+ override def children: Seq[Expression] = subject :: regexp :: rep :: Nil
+ override def prettyName: String = "regexp_replace"
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val termLastRegex = ctx.freshName("lastRegex")
+ val termPattern = ctx.freshName("pattern")
+
+ val termLastReplacement = ctx.freshName("lastReplacement")
+ val termLastReplacementInUTF8 = ctx.freshName("lastReplacementInUTF8")
+
+ val termResult = ctx.freshName("result")
+
+ val classNameUTF8String = classOf[UTF8String].getCanonicalName
+ val classNamePattern = classOf[Pattern].getCanonicalName
+ val classNameString = classOf[java.lang.String].getCanonicalName
+ val classNameStringBuffer = classOf[java.lang.StringBuffer].getCanonicalName
+
+ ctx.addMutableState(classNameUTF8String,
+ termLastRegex, s"${termLastRegex} = null;")
+ ctx.addMutableState(classNamePattern,
+ termPattern, s"${termPattern} = null;")
+ ctx.addMutableState(classNameString,
+ termLastReplacement, s"${termLastReplacement} = null;")
+ ctx.addMutableState(classNameUTF8String,
+ termLastReplacementInUTF8, s"${termLastReplacementInUTF8} = null;")
+ ctx.addMutableState(classNameStringBuffer,
+ termResult, s"${termResult} = new $classNameStringBuffer();")
+
+ val evalSubject = subject.gen(ctx)
+ val evalRegexp = regexp.gen(ctx)
+ val evalRep = rep.gen(ctx)
+
+ s"""
+ ${evalSubject.code}
+ boolean ${ev.isNull} = ${evalSubject.isNull};
+ ${ctx.javaType(dataType)} ${ev.primitive} = ${ctx.defaultValue(dataType)};
+ if (!${evalSubject.isNull}) {
+ ${evalRegexp.code}
+ if (!${evalRegexp.isNull}) {
+ ${evalRep.code}
+ if (!${evalRep.isNull}) {
+ if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
+ // regex value changed
+ ${termLastRegex} = ${evalRegexp.primitive};
+ ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
+ }
+ if (!${evalRep.primitive}.equals(${termLastReplacementInUTF8})) {
+ // replacement string changed
+ ${termLastReplacementInUTF8} = ${evalRep.primitive};
+ ${termLastReplacement} = ${termLastReplacementInUTF8}.toString();
+ }
+ ${termResult}.delete(0, ${termResult}.length());
+ ${classOf[java.util.regex.Matcher].getCanonicalName} m =
+ ${termPattern}.matcher(${evalSubject.primitive}.toString());
+
+ while (m.find()) {
+ m.appendReplacement(${termResult}, ${termLastReplacement});
+ }
+ m.appendTail(${termResult});
+ ${ev.primitive} = ${classNameUTF8String}.fromString(${termResult}.toString());
+ ${ev.isNull} = false;
+ }
+ }
+ }
+ """
+ }
+}
+
+/**
+ * Extract a specific(idx) group identified by a Java regex.
+ *
+ * NOTE: this expression is not THREAD-SAFE, as it has some internal mutable state.
+ */
+case class RegExpExtract(subject: Expression, regexp: Expression, idx: Expression)
+ extends Expression with ImplicitCastInputTypes {
+ def this(s: Expression, r: Expression) = this(s, r, Literal(1))
+
+ // last regex in string, we will update the pattern iff regexp value changed.
+ @transient private var lastRegex: UTF8String = _
+ // last regex pattern, we cache it for performance concern
+ @transient private var pattern: Pattern = _
+
+ override def nullable: Boolean = subject.nullable || regexp.nullable || idx.nullable
+ override def foldable: Boolean = subject.foldable && regexp.foldable && idx.foldable
+
+ override def eval(input: InternalRow): Any = {
+ val s = subject.eval(input)
+ if (null != s) {
+ val p = regexp.eval(input)
+ if (null != p) {
+ val r = idx.eval(input)
+ if (null != r) {
+ if (!p.equals(lastRegex)) {
+ // regex value changed
+ lastRegex = p.asInstanceOf[UTF8String]
+ pattern = Pattern.compile(lastRegex.toString)
+ }
+ val m = pattern.matcher(s.toString())
+ if (m.find) {
+ val mr: MatchResult = m.toMatchResult
+ return UTF8String.fromString(mr.group(r.asInstanceOf[Int]))
+ }
+ return UTF8String.EMPTY_UTF8
+ }
+ }
+ }
+
+ null
+ }
+
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType, IntegerType)
+ override def children: Seq[Expression] = subject :: regexp :: idx :: Nil
+ override def prettyName: String = "regexp_extract"
+
+ override protected def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val termLastRegex = ctx.freshName("lastRegex")
+ val termPattern = ctx.freshName("pattern")
+ val classNameUTF8String = classOf[UTF8String].getCanonicalName
+ val classNamePattern = classOf[Pattern].getCanonicalName
+
+ ctx.addMutableState(classNameUTF8String, termLastRegex, s"${termLastRegex} = null;")
+ ctx.addMutableState(classNamePattern, termPattern, s"${termPattern} = null;")
+
+ val evalSubject = subject.gen(ctx)
+ val evalRegexp = regexp.gen(ctx)
+ val evalIdx = idx.gen(ctx)
+
+ s"""
+ ${ctx.javaType(dataType)} ${ev.primitive} = null;
+ boolean ${ev.isNull} = true;
+ ${evalSubject.code}
+ if (!${evalSubject.isNull}) {
+ ${evalRegexp.code}
+ if (!${evalRegexp.isNull}) {
+ ${evalIdx.code}
+ if (!${evalIdx.isNull}) {
+ if (!${evalRegexp.primitive}.equals(${termLastRegex})) {
+ // regex value changed
+ ${termLastRegex} = ${evalRegexp.primitive};
+ ${termPattern} = ${classNamePattern}.compile(${termLastRegex}.toString());
+ }
+ ${classOf[java.util.regex.Matcher].getCanonicalName} m =
+ ${termPattern}.matcher(${evalSubject.primitive}.toString());
+ if (m.find()) {
+ ${classOf[java.util.regex.MatchResult].getCanonicalName} mr = m.toMatchResult();
+ ${ev.primitive} = ${classNameUTF8String}.fromString(mr.group(${evalIdx.primitive}));
+ ${ev.isNull} = false;
+ } else {
+ ${ev.primitive} = ${classNameUTF8String}.EMPTY_UTF8;
+ ${ev.isNull} = false;
+ }
+ }
+ }
+ }
+ """
+ }
+}
+
+/**
* Formats the number X to a format like '#,###,###.##', rounded to D decimal places,
* and returns the result as a string. If D is 0, the result has no decimal point or
* fractional part.
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
index 7a96044d35..6e17ffcda9 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala
@@ -79,7 +79,6 @@ trait ExpressionEvalHelper {
fail(
s"""
|Code generation of $expression failed:
- |${evaluated.code}
|$e
""".stripMargin)
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 67d97cd30b..96c540ab36 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -464,6 +464,41 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(StringSpace(s1), null, row2)
}
+ test("RegexReplace") {
+ val row1 = create_row("100-200", "(\\d+)", "num")
+ val row2 = create_row("100-200", "(\\d+)", "###")
+ val row3 = create_row("100-200", "(-)", "###")
+
+ val s = 's.string.at(0)
+ val p = 'p.string.at(1)
+ val r = 'r.string.at(2)
+
+ val expr = RegExpReplace(s, p, r)
+ checkEvaluation(expr, "num-num", row1)
+ checkEvaluation(expr, "###-###", row2)
+ checkEvaluation(expr, "100###200", row3)
+ }
+
+ test("RegexExtract") {
+ val row1 = create_row("100-200", "(\\d+)-(\\d+)", 1)
+ val row2 = create_row("100-200", "(\\d+)-(\\d+)", 2)
+ val row3 = create_row("100-200", "(\\d+).*", 1)
+ val row4 = create_row("100-200", "([a-z])", 1)
+
+ val s = 's.string.at(0)
+ val p = 'p.string.at(1)
+ val r = 'r.int.at(2)
+
+ val expr = RegExpExtract(s, p, r)
+ checkEvaluation(expr, "100", row1)
+ checkEvaluation(expr, "200", row2)
+ checkEvaluation(expr, "100", row3)
+ checkEvaluation(expr, "", row4) // will not match anything, empty string get
+
+ val expr1 = new RegExpExtract(s, p)
+ checkEvaluation(expr1, "100", row1)
+ }
+
test("SPLIT") {
val s1 = 'a.string.at(0)
val s2 = 'b.string.at(1)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 8fa017610b..6d60dae624 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1781,6 +1781,27 @@ object functions {
StringLocate(lit(substr).expr, str.expr, lit(pos).expr)
}
+
+ /**
+ * Extract a specific(idx) group identified by a java regex, from the specified string column.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def regexp_extract(e: Column, exp: String, groupIdx: Int): Column = {
+ RegExpExtract(e.expr, lit(exp).expr, lit(groupIdx).expr)
+ }
+
+ /**
+ * Replace all substrings of the specified string value that match regexp with rep.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def regexp_replace(e: Column, pattern: String, replacement: String): Column = {
+ RegExpReplace(e.expr, lit(pattern).expr, lit(replacement).expr)
+ }
+
/**
* Computes the BASE64 encoding of a binary column and returns it as a string column.
* This is the reverse of unbase64.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index 4551192b15..d1f855903c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -56,6 +56,22 @@ class StringFunctionsSuite extends QueryTest {
checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
}
+ test("string regex_replace / regex_extract") {
+ val df = Seq(("100-200", "")).toDF("a", "b")
+
+ checkAnswer(
+ df.select(
+ regexp_replace($"a", "(\\d+)", "num"),
+ regexp_extract($"a", "(\\d+)-(\\d+)", 1)),
+ Row("num-num", "100"))
+
+ checkAnswer(
+ df.selectExpr(
+ "regexp_replace(a, '(\\d+)', 'num')",
+ "regexp_extract(a, '(\\d+)-(\\d+)', 2)"),
+ Row("num-num", "200"))
+ }
+
test("string ascii function") {
val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(