aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/functions.py14
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala32
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala9
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/functions.scala23
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala6
6 files changed, 81 insertions, 4 deletions
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 69e563ef36..49dd0332af 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -325,6 +325,20 @@ def explode(col):
@ignore_unicode_prefix
@since(1.5)
+def levenshtein(left, right):
+ """Computes the Levenshtein distance of the two given strings.
+
+ >>> df0 = sqlContext.createDataFrame([('kitten', 'sitting',)], ['l', 'r'])
+ >>> df0.select(levenshtein('l', 'r').alias('d')).collect()
+ [Row(d=3)]
+ """
+ sc = SparkContext._active_spark_context
+ jc = sc._jvm.functions.levenshtein(_to_java_column(left), _to_java_column(right))
+ return Column(jc)
+
+
+@ignore_unicode_prefix
+@since(1.5)
def md5(col):
"""Calculates the MD5 digest and returns the value as a 32 character hex string.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index e249b58927..92a50e7092 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -163,6 +163,7 @@ object FunctionRegistry {
expression[Lower]("lcase"),
expression[Lower]("lower"),
expression[StringLength]("length"),
+ expression[Levenshtein]("levenshtein"),
expression[Substring]("substr"),
expression[Substring]("substring"),
expression[UnBase64]("unbase64"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
index 154ac3508c..6de40629ff 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringOperations.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.expressions
import java.util.regex.Pattern
+import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.catalyst.analysis.UnresolvedException
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.types._
@@ -300,6 +301,37 @@ case class StringLength(child: Expression) extends UnaryExpression with ExpectsI
}
/**
+ * A function that returns the Levenshtein distance between the two given strings.
+ */
+case class Levenshtein(left: Expression, right: Expression) extends BinaryExpression
+ with ExpectsInputTypes {
+
+ override def inputTypes: Seq[AbstractDataType] = Seq(StringType, StringType)
+
+ override def dataType: DataType = IntegerType
+
+ override def eval(input: InternalRow): Any = {
+ val leftValue = left.eval(input)
+ if (leftValue == null) {
+ null
+ } else {
+ val rightValue = right.eval(input)
+ if(rightValue == null) {
+ null
+ } else {
+ StringUtils.getLevenshteinDistance(leftValue.toString, rightValue.toString)
+ }
+ }
+ }
+
+ override def genCode(ctx: CodeGenContext, ev: GeneratedExpressionCode): String = {
+ val stringUtils = classOf[StringUtils].getName
+ nullSafeCodeGen(ctx, ev, (res, left, right) =>
+ s"$res = $stringUtils.getLevenshteinDistance($left.toString(), $right.toString());")
+ }
+}
+
+/**
* Returns the numeric value of the first character of str.
*/
case class Ascii(child: Expression) extends UnaryExpression with ExpectsInputTypes {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
index 468df20442..1efbe1a245 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringFunctionsSuite.scala
@@ -274,4 +274,13 @@ class StringFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(Decode(Literal.create(null, BinaryType), Literal("utf-8")), null)
checkEvaluation(Decode(b, Literal.create(null, StringType)), null, create_row(null))
}
+
+ test("Levenshtein distance") {
+ checkEvaluation(Levenshtein(Literal.create(null, StringType), Literal("")), null)
+ checkEvaluation(Levenshtein(Literal(""), Literal.create(null, StringType)), null)
+ checkEvaluation(Levenshtein(Literal(""), Literal("")), 0)
+ checkEvaluation(Levenshtein(Literal("abc"), Literal("abc")), 0)
+ checkEvaluation(Levenshtein(Literal("kitten"), Literal("sitting")), 3)
+ checkEvaluation(Levenshtein(Literal("frog"), Literal("fog")), 1)
+ }
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
index b63c6ee8ab..e4109da08e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala
@@ -1580,22 +1580,37 @@ object functions {
//////////////////////////////////////////////////////////////////////////////////////////////
/**
- * Computes the length of a given string value
- *
+ * Computes the length of a given string value.
+ *
* @group string_funcs
* @since 1.5.0
*/
def strlen(e: Column): Column = StringLength(e.expr)
/**
- * Computes the length of a given string column
- *
+ * Computes the length of a given string column.
+ *
* @group string_funcs
* @since 1.5.0
*/
def strlen(columnName: String): Column = strlen(Column(columnName))
/**
+ * Computes the Levenshtein distance of the two given strings.
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def levenshtein(l: Column, r: Column): Column = Levenshtein(l.expr, r.expr)
+
+ /**
+ * Computes the Levenshtein distance of the two given strings.
+ * @group string_funcs
+ * @since 1.5.0
+ */
+ def levenshtein(leftColumnName: String, rightColumnName: String): Column =
+ levenshtein(Column(leftColumnName), Column(rightColumnName))
+
+ /**
* Computes the numeric value of the first character of the specified string value.
*
* @group string_funcs
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index bd9fa400e5..bc455a922d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -226,6 +226,12 @@ class DataFrameFunctionsSuite extends QueryTest {
})
}
+ test("Levenshtein distance") {
+ val df = Seq(("kitten", "sitting"), ("frog", "fog")).toDF("l", "r")
+ checkAnswer(df.select(levenshtein("l", "r")), Seq(Row(3), Row(1)))
+ checkAnswer(df.selectExpr("levenshtein(l, r)"), Seq(Row(3), Row(1)))
+ }
+
test("string ascii function") {
val df = Seq(("abc", "")).toDF("a", "b")
checkAnswer(