aboutsummaryrefslogtreecommitdiff
path: root/sql/catalyst
diff options
context:
space:
mode:
authorzhichao.li <zhichao.li@intel.com>2015-08-06 09:02:30 -0700
committerDavies Liu <davies.liu@gmail.com>2015-08-06 09:02:30 -0700
commitaead18ffca36830e854fba32a1cac11a0b2e31d5 (patch)
tree172fe1d83e6c691abad40c4114662db84fe8fbf2 /sql/catalyst
parentd5a9af3230925c347d0904fe7f2402e468e80bc8 (diff)
downloadspark-aead18ffca36830e854fba32a1cac11a0b2e31d5.tar.gz
spark-aead18ffca36830e854fba32a1cac11a0b2e31d5.tar.bz2
spark-aead18ffca36830e854fba32a1cac11a0b2e31d5.zip
[SPARK-8266] [SQL] add function translate
![translate](http://www.w3resource.com/PostgreSQL/postgresql-translate-function.png) Author: zhichao.li <zhichao.li@intel.com> Closes #7709 from zhichao-li/translate and squashes the following commits: 9418088 [zhichao.li] refine checking condition f2ab77a [zhichao.li] clone string 9d88f2d [zhichao.li] fix indent 6aa2962 [zhichao.li] style e575ead [zhichao.li] add python api 9d4bab0 [zhichao.li] add special case for fodable and refactor unittest eda7ad6 [zhichao.li] update to use TernaryExpression cdfd4be [zhichao.li] add function translate
Diffstat (limited to 'sql/catalyst')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala79
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala14
4 files changed, 95 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index 94c355f838..cd5a90d788 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -203,6 +203,7 @@ object FunctionRegistry {
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[SubstringIndex]("substring_index"),
+ expression[StringTranslate]("translate"),
expression[StringTrim]("trim"),
expression[UnBase64]("unbase64"),
expression[Upper]("ucase"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
index ef2fc2e8c2..0b98f555a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala
@@ -444,7 +444,7 @@ abstract class TernaryExpression extends Expression {
override def nullable: Boolean = children.exists(_.nullable)
/**
- * Default behavior of evaluation according to the default nullability of BinaryExpression.
+ * Default behavior of evaluation according to the default nullability of TernaryExpression.
* If subclass of BinaryExpression override nullable, probably should also override this.
*/
override def eval(input: InternalRow): Any = {
@@ -463,7 +463,7 @@ abstract class TernaryExpression extends Expression {
}
/**
- * Called by default [[eval]] implementation. If subclass of BinaryExpression keep the default
+ * Called by default [[eval]] implementation. If subclass of TernaryExpression keep the default
* nullability, they can override this method to save null-check code. If we need full control
* of evaluation process, we should override [[eval]].
*/
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 0cc785d9f3..76666bd6b3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -18,7 +18,9 @@
package org.apache.spark.sql.catalyst.expressions
import java.text.DecimalFormat
-import java.util.{Arrays, Locale}
+import java.util.Arrays
+import java.util.{Map => JMap, HashMap}
+import java.util.Locale
import java.util.regex.{MatchResult, Pattern}
import org.apache.commons.lang3.StringEscapeUtils
@@ -349,6 +351,81 @@ case class EndsWith(left: Expression, right: Expression)
}
}
+object StringTranslate {
+
+ def buildDict(matchingString: UTF8String, replaceString: UTF8String)
+ : JMap[Character, Character] = {
+ val matching = matchingString.toString()
+ val replace = replaceString.toString()
+ val dict = new HashMap[Character, Character]()
+ var i = 0
+ while (i < matching.length()) {
+ val rep = if (i < replace.length()) replace.charAt(i) else '\0'
+ if (null == dict.get(matching.charAt(i))) {
+ dict.put(matching.charAt(i), rep)
+ }
+ i += 1
+ }
+ dict
+ }
+}
+
+/**
+ * A function translate any character in the `srcExpr` by a character in `replaceExpr`.
+ * The characters in `replaceExpr` is corresponding to the characters in `matchingExpr`.
+ * The translate will happen when any character in the string matching with the character
+ * in the `matchingExpr`.
+ */
+case class StringTranslate(srcExpr: Expression, matchingExpr: Expression, replaceExpr: Expression)
+ extends TernaryExpression with ImplicitCastInputTypes {
+
+ @transient private var lastMatching: UTF8String = _
+ @transient private var lastReplace: UTF8String = _
+ @transient private var dict: JMap[Character, Character] = _
+
+ override def nullSafeEval(srcEval: Any, matchingEval: Any, replaceEval: Any): Any = {
+ if (matchingEval != lastMatching || replaceEval != lastReplace) {
+ lastMatching = matchingEval.asInstanceOf[UTF8String].clone()
+ lastReplace = replaceEval.asInstanceOf[UTF8String].clone()
+ dict = StringTranslate.buildDict(lastMatching, lastReplace)
+ }
+ srcEval.asInstanceOf[UTF8String].translate(dict)
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val termLastMatching = ctx.freshName("lastMatching")
+ val termLastReplace = ctx.freshName("lastReplace")
+ val termDict = ctx.freshName("dict")
+ val classNameDict = classOf[JMap[Character, Character]].getCanonicalName
+
+ ctx.addMutableState("UTF8String", termLastMatching, s"${termLastMatching} = null;")
+ ctx.addMutableState("UTF8String", termLastReplace, s"${termLastReplace} = null;")
+ ctx.addMutableState(classNameDict, termDict, s"${termDict} = null;")
+
+ nullSafeCodeGen(ctx, ev, (src, matching, replace) => {
+ val check = if (matchingExpr.foldable && replaceExpr.foldable) {
+ s"${termDict} == null"
+ } else {
+ s"!${matching}.equals(${termLastMatching}) || !${replace}.equals(${termLastReplace})"
+ }
+ s"""if ($check) {
+ // Not all of them is literal or matching or replace value changed
+ ${termLastMatching} = ${matching}.clone();
+ ${termLastReplace} = ${replace}.clone();
+ ${termDict} = org.apache.spark.sql.catalyst.expressions.StringTranslate
+ .buildDict(${termLastMatching}, ${termLastReplace});
+ }
+ ${ev.primitive} = ${src}.translate(${termDict});
+ """
+ })
+ }
+
+ override def dataType: DataType = StringType
+ override def inputTypes: Seq[DataType] = Seq(StringType, StringType, StringType)
+ override def children: Seq[Expression] = srcExpr :: matchingExpr :: replaceExpr :: Nil
+ override def prettyName: String = "translate"
+}
+
/**
* A function that returns the index (1-based) of the given string (left) in the comma-
* delimited list (right). Returns 0, if the string wasn't found or if the given
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index 23f36ca43d..426dc27247 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -431,6 +431,20 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(SoundEx(Literal("!!")), "!!")
}
+ test("translate") {
+ checkEvaluation(
+ StringTranslate(Literal("translate"), Literal("rnlt"), Literal("123")), "1a2s3ae")
+ checkEvaluation(StringTranslate(Literal("translate"), Literal(""), Literal("123")), "translate")
+ checkEvaluation(StringTranslate(Literal("translate"), Literal("rnlt"), Literal("")), "asae")
+ // test for multiple mapping
+ checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("123")), "12cd")
+ checkEvaluation(StringTranslate(Literal("abcd"), Literal("aba"), Literal("12")), "12cd")
+ // scalastyle:off
+ // non ascii characters are not allowed in the source code, so we disable the scalastyle.
+ checkEvaluation(StringTranslate(Literal("花花世界"), Literal("花界"), Literal("ab")), "aa世b")
+ // scalastyle:on
+ }
+
test("TRIM/LTRIM/RTRIM") {
val s = 'a.string.at(0)
checkEvaluation(StringTrim(Literal(" aa ")), "aa", create_row(" abdef "))