aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala9
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala5
-rw-r--r--unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java66
-rw-r--r--unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java24
4 files changed, 97 insertions, 7 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 47fc7cdaa8..57f436485b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -284,13 +284,12 @@ case class Levenshtein(left: Expression, right: Expression) extends BinaryExpres
override def dataType: DataType = IntegerType
- protected override def nullSafeEval(input1: Any, input2: Any): Any =
- StringUtils.getLevenshteinDistance(input1.toString, input2.toString)
+ protected override def nullSafeEval(leftValue: Any, rightValue: Any): Any =
+ leftValue.asInstanceOf[UTF8String].levenshteinDistance(rightValue.asInstanceOf[UTF8String])
override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
- val stringUtils = classOf[StringUtils].getName
- defineCodeGen(ctx, ev, (left, right) =>
- s"$stringUtils.getLevenshteinDistance($left.toString(), $right.toString())")
+ nullSafeCodeGen(ctx, ev, (left, right) =>
+ s"${ev.primitive} = $left.levenshteinDistance($right);")
}
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index 1efbe1a245..69bef1c63e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -282,5 +282,10 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0)
checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3)
checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1)
+ // scalastyle:off
+ // Non-ASCII characters are not allowed in the code, so scalastyle is disabled here.
+ checkEvaluation(Levenshtein(Literal("千世"), Literal("fog")), 3)
+ checkEvaluation(Levenshtein(Literal("世界千世"), Literal("大a界b")), 4)
+ // scalastyle:on
}
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
index d2a25096a5..847d80ad58 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java
@@ -99,8 +99,6 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
/**
* Returns the number of code points in it.
- *
- * This is only used by Substring() when `start` is negative.
*/
public int numChars() {
int len = 0;
@@ -254,6 +252,70 @@ public final class UTF8String implements Comparable<UTF8String>, Serializable {
}
}
+ /**
+ * Returns the Levenshtein distance between this string and `other`: the minimum number of
+ * single-character edits (insertions, deletions or substitutions) needed to turn one string
+ * into the other. Characters are compared as whole code points, not as individual bytes.
+ */
+ public int levenshteinDistance(UTF8String other) {
+ // Implementation adapted from org.apache.commons.lang3.StringUtils.getLevenshteinDistance
+
+ int n = numChars(); // lengths are counted in code points, not bytes
+ int m = other.numChars();
+
+ if (n == 0) { // this string is empty: the distance is other's length
+ return m;
+ } else if (m == 0) { // other is empty: the distance is this string's length
+ return n;
+ }
+
+ UTF8String s, t;
+
+ if (n <= m) { // make s the shorter string so both work arrays get the smaller size
+ s = this;
+ t = other;
+ } else {
+ s = other;
+ t = this;
+ int swap; // exchange n and m so that n is still the length of s
+ swap = n;
+ n = m;
+ m = swap;
+ }
+
+ int p[] = new int[n + 1]; // previous row of costs
+ int d[] = new int[n + 1]; // row currently being filled
+ int swap[]; // scratch reference used to exchange p and d
+
+ int i, i_bytes, j, j_bytes, num_bytes_j, cost; // i, j: char indexes; *_bytes: byte offsets
+
+ for (i = 0; i <= n; i++) {
+ p[i] = i; // initial row: 0, 1, ..., n
+ }
+
+ for (j = 0, j_bytes = 0; j < m; j_bytes += num_bytes_j, j++) { // one char of t per pass
+ num_bytes_j = numBytesForFirstByte(t.getByte(j_bytes)); // bytes to step past t's char
+ d[0] = j + 1;
+
+ for (i = 0, i_bytes = 0; i < n; i_bytes += numBytesForFirstByte(s.getByte(i_bytes)), i++) {
+ if (s.getByte(i_bytes) != t.getByte(j_bytes) ||
+ num_bytes_j != numBytesForFirstByte(s.getByte(i_bytes))) {
+ cost = 1; // first byte or byte width differs: the characters cannot be equal
+ } else { // same first byte and width: let arrayEquals decide over num_bytes_j bytes
+ cost = (ByteArrayMethods.arrayEquals(t.base, t.offset + j_bytes, s.base,
+ s.offset + i_bytes, num_bytes_j)) ? 0 : 1;
+ }
+ d[i + 1] = Math.min(Math.min(d[i] + 1, p[i + 1] + 1), p[i] + cost); // cheapest option
+ }
+
+ swap = p; // the row just filled becomes the previous row for the next pass
+ p = d;
+ d = swap;
+ }
+
+ return p[n]; // after the final exchange, p holds the last row that was filled
+ }
+
@Override
public int hashCode() {
int result = 1;
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
index 8ec69ebac8..fb463ba17f 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java
@@ -128,4 +128,28 @@ public class UTF8StringSuite {
assertEquals(fromString("数据砖头").substring(3, 5), fromString("头"));
assertEquals(fromString("ߵ梷").substring(0, 2), fromString("ߵ梷"));
}
+
+ @Test
+ public void levenshteinDistance() { // expected value goes first, as JUnit's assertEquals defines
+ assertEquals(0, // both strings empty
+ UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("")));
+ assertEquals(1, // empty receiver
+ UTF8String.fromString("").levenshteinDistance(UTF8String.fromString("a")));
+ assertEquals(7, // empty argument
+ UTF8String.fromString("aaapppp").levenshteinDistance(UTF8String.fromString("")));
+ assertEquals(1, // a single deletion
+ UTF8String.fromString("frog").levenshteinDistance(UTF8String.fromString("fog")));
+ assertEquals(3, // no character in common, equal lengths
+ UTF8String.fromString("fly").levenshteinDistance(UTF8String.fromString("ant")));
+ assertEquals(7, // receiver longer than argument
+ UTF8String.fromString("elephant").levenshteinDistance(UTF8String.fromString("hippo")));
+ assertEquals(7, // same pair with the roles exchanged
+ UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("elephant")));
+ assertEquals(8, // no character in common, different lengths
+ UTF8String.fromString("hippo").levenshteinDistance(UTF8String.fromString("zzzzzzzz")));
+ assertEquals(1, // a single substitution
+ UTF8String.fromString("hello").levenshteinDistance(UTF8String.fromString("hallo")));
+ assertEquals(4, // multi-byte (non-ASCII) characters
+ UTF8String.fromString("世界千世").levenshteinDistance(UTF8String.fromString("千a世b")));
+ }
}