aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/functions.py16
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala79
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala14
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala21
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala6
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java16
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java31
9 files changed, 180 insertions, 8 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 9f0d71d796..b5c6a01f18 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1290,6 +1290,22 @@ def length(col):
return Column(sc._jvm.functions.length(_to_java_column(col)))
+@ignore_unicode_prefix
+@since(1.5)
+def translate(srcCol, matching, replace):
+ """A function translate any character in the `srcCol` by a character in `matching`.
+ The characters in `replace` is corresponding to the characters in `matching`.
+ The translate will happen when any character in the string matching with the character
+ in the `matching`.
+
+ >>> sqlContext.createDataFrame([('translate',)], ['a']) \
+ ...     .select(translate('a', "rnlt", "123").alias('r')).collect()
+ [Row(r=u'1a2s3ae')]
+ """
+ sc = SparkContext._active_spark_context
+ return Column(sc._jvm.functions.translate(_to_java_column(srcCol), matching, replace))
+
+
# ---------------------- Collection functions ------------------------------
@since(1.4)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 94c355f838..cd5a90d788 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -203,6 +203,7 @@ object FunctionRegistry {
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[SubstringIndex]("substring_index"),
+ expression[StringTranslate]("translate"),
expression[StringTrim]("trim"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ef2fc2e8c2..0b98f555a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -444,7 +444,7 @@ abstract class TernaryExpression extends Expression {
override def nullable: Boolean = children.exists(_.nullable)
/**
- * Default behavior of evaluation according to the default nullability of BinaryExpression.
+ * Default behavior of evaluation according to the default nullability of TernaryExpression.
* If subclass of BinaryExpression override nullable, probably should also override this.
*/
override def eval(input: InternalRow): Any = {
@@ -463,7 +463,7 @@ abstract class TernaryExpression extends Expression {
}
/**
- * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default
+ * Called by default [[eval]] implementation. If subclass of TernaryExpression keep the default
* nullability, they can override this method to save null-check code. If we need full control
* of evaluation process, we should override [[eval]].
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 0cc785d9f3..76666bd6b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import java.text.DecimalFormat
-import java.util.{Arrays, Locale}
+import java.util.Arrays
+import java.util.{Map => JMap, HashMap}
+import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
import org.apache.commons.lang3.StringEscapeUtils
@@ -349,6 +351,81 @@ case class EndsWith(left: Expression, right: Expression)
}
}
+object StringTranslate {
+
+ def buildDict(matchingString: UTF8String, replaceString: UTF8String)
+ : JMap[Character, Character] = {
+ val matching = matchingString.toString()
+ val replace = replaceString.toString()
+ val dict = new HashMap[Character, Character]()
+ var i = 0
+ while (i < matching.length()) {
+ val rep = if (i < replace.length()) replace.charAt(i) else '\0'
+ if (null == dict.get(matching.charAt(i))) {
+ dict.put(matching.charAt(i), rep)
+ }
+ i += 1
+ }
+ dict
+ }
+}
+
+/**
+ * Translates each character in `srcExpr` that appears in `matchingExpr`
+ * to the character at the same position in `replaceExpr`.
+ * Characters of `matchingExpr` beyond the length of `replaceExpr` are
+ * removed from the result.
+ */
+case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
+
+ @transient private var lastMatching: UTF8String = _
+ @transient private var lastReplace: UTF8String = _
+ @transient private var dict: JMap[Character, Character] = _
+
+ override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = {
+ if (matchingEval != lastMatching || replaceEval != lastReplace) {
+ lastMatching = matchingEval.asInstanceOf[UTF8String].clone()
+ lastReplace = replaceEval.asInstanceOf[UTF8String].clone()
+ dict = StringTranslate.buildDict(lastMatching, lastReplace)
+ }
+ srcEval.asInstanceOf[UTF8String].translate(dict)
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val termLastMatching = ctx.freshName("lastMatching")
+ val termLastReplace = ctx.freshName("lastReplace")
+ val termDict = ctx.freshName("dict")
+ val classNameDict = classOf[JMap[Character, Character]].getCanonicalName
+
+ ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;")
+ ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;")
+ ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;")
+
+ nullSafeCodeGen(ctx, ev, (src, matching, replace) => {
+ val check = if (matchingExpr.foldable && replaceExpr.foldable) {
+ s"${termDict} == null"
+ } else {
+ s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})"
+ }
+ s"""if ($check) {
+ // First evaluation with foldable inputs, or matching/replace changed
+ ${termLastMatching} = ${matching}.clone();
+ ${termLastReplace} = ${replace}.clone();
+ ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate
+ .buildDict(${termLastMatching}, ${termLastReplace});
+ }
+ ${ev.primitive} = ${src}.translate(${termDict});
+ """
+ })
+ }
+
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
+ override def children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil
+ override def prettyName: String = "translate"
+}
+
/**
* A function that returns the index (1-based) of the given string (left) in the comma-
* delimited list (right). Returns 0, if the string wasn't found or if the given
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 23f36ca43d..426dc27247 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -431,6 +431,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(SoundEx(Literal("!!")), "!!")
}
+ test("translate") {
+ checkEvaluation(
+ StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae")
+ checkEvaluation(StringTranslate(Literal("translate"), Literal(""), Literal("123")), "translate")
+ checkEvaluation(StringTranslate(Literal("translate"), Literal("rnlt"), Literal("")), "asae")
+ // test for multiple mapping
+ checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("123")), "12cd")
+ checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("12")), "12cd")
+ // scalastyle:off
+ // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+ checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b")
+ // scalastyle:on
+ }
+
test("TRIM/LTRIM/RTRIM") {
val s = 'a.string.at(0)
checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef "))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index 39aa905c85..79c5f59666 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1100,11 +1100,11 @@ object functions {
}
/**
- * Computes hex value of the given column.
- *
- * @group math_funcs
- * @since 1.5.0
- */
+ * Computes hex value of the given column.
+ *
+ * @group math_funcs
+ * @since 1.5.0
+ */
def hex(column: Column): Column = Hex(column.expr)
/**
@@ -1863,6 +1863,17 @@ object functions {
def substring_index(str: Column, delim: String, count: Int): Column =
SubstringIndex(str.expr, lit(delim).expr, lit(count).expr)
+ /** Translate each character in `src` that appears in `matchingString`
+ * to the character at the same position in `replaceString`.
+ * Characters of `matchingString` beyond the length of `replaceString`
+ * are removed from the result.
+ *
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def translate(src: Column, matchingString: String, replaceString: String): Column =
+ StringTranslate(src.expr, lit(matchingString).expr, lit(replaceString).expr)
+
/**
* Trim the spaces from both ends for the specified string column.
*
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index ab5da6ee79..ca298b2434 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -128,6 +128,12 @@ class StringFunctionsSuite extends QueryTest {
// scalastyle:on
}
+ test("string translate") {
+ val df = Seq(("translate", "")).toDF("a", "b")
+ checkAnswer(df.select(translate($"a", "rnlt", "123")), Row("1a2s3ae"))
+ checkAnswer(df.selectExpr("""translate(a, "rnlt", "")"""), Row("asae"))
+ }
+
test("string trim functions") {
val df = Seq((" example ", "")).toDF("a", "b")
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index febbe3d4e5..d1014426c0 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -22,6 +22,7 @@ import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.nio.ByteOrder;
import java.util.Arrays;
+import java.util.Map;
import org.apache.spark.unsafe.PlatformDependent;
import org.apache.spark.unsafe.array.ByteArrayMethods;
@@ -795,6 +796,21 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
return res;
}
+ // TODO: Need to use `Code Point` here instead of Char in case the character longer than 2 bytes
+ public UTF8String translate(Map<Character, Character> dict) {
+ String srcStr = this.toString();
+
+ StringBuilder sb = new StringBuilder();
+ for(int k = 0; k< srcStr.length(); k++) {
+ if (null == dict.get(srcStr.charAt(k))) {
+ sb.append(srcStr.charAt(k));
+ } else if ('\0' != dict.get(srcStr.charAt(k))){
+ sb.append(dict.get(srcStr.charAt(k)));
+ }
+ }
+ return fromString(sb.toString());
+ }
+
@Override
public String toString() {
try {
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index b30c94c1c1..98aa8a2469 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -19,7 +19,9 @@ package org.apache.spark.unsafe.types;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
+import java.util.HashMap;
+import com.google.common.collect.ImmutableMap;
import org.junit.Test;
import static junit.framework.Assert.*;
@@ -392,6 +394,35 @@ public class UTF8StringSuite {
}
@Test
+ public void translate() {
+ assertEquals(
+ fromString("1a2s3ae"),
+ fromString("translate").translate(ImmutableMap.of(
+ 'r', '1',
+ 'n', '2',
+ 'l', '3',
+ 't', '\0'
+ )));
+ assertEquals(
+ fromString("translate"),
+ fromString("translate").translate(new HashMap<Character, Character>()));
+ assertEquals(
+ fromString("asae"),
+ fromString("translate").translate(ImmutableMap.of(
+ 'r', '\0',
+ 'n', '\0',
+ 'l', '\0',
+ 't', '\0'
+ )));
+ assertEquals(
+ fromString("aa世b"),
+ fromString("花花世界").translate(ImmutableMap.of(
+ '花', 'a',
+ '界', 'b'
+ )));
+ }
+
+ @Test
public void createBlankString() {
assertEquals(fromString(" "), blankString(1));
assertEquals(fromString(" "), blankString(2));