diff options
Diffstat (limited to 'core')
-rwxr-xr-x | core/source/core/scala/org/hashtree/stringmetric/LevenshteinMetric.scala | 60 | ||||
-rwxr-xr-x | core/source/test/scala/org/hashtree/stringmetric/LevenshteinMetricSpec.scala | 30 |
2 files changed, 90 insertions, 0 deletions
// File: core/source/core/scala/org/hashtree/stringmetric/LevenshteinMetric.scala
package org.hashtree.stringmetric

import scala.math

/**
 * An implementation of the Levenshtein [[org.hashtree.stringmetric.StringMetric]].
 *
 * The distance is the minimum number of single-character insertions, deletions and
 * substitutions required to turn one (cleaned) input into the other.
 */
object LevenshteinMetric extends StringMetric {
  /**
   * Default cleaner made available to callers that import it: a `StringCleanerDelegate`
   * with `CaseStringCleaner` mixed in.
   * NOTE(review): presumably normalises letter case before comparison; confirm against
   * `CaseStringCleaner`.
   */
  implicit val stringCleaner = new StringCleanerDelegate with CaseStringCleaner

  /**
   * Computes the Levenshtein distance between two character arrays.
   *
   * Both arrays are first passed through `stringCleaner.clean`.
   *
   * @param charArray1 first sequence of characters
   * @param charArray2 second sequence of characters
   * @param stringCleaner cleaner applied to both inputs before they are compared
   * @return `None` when both cleaned inputs are empty, otherwise `Some(distance)`
   */
  override def compare(charArray1: Array[Char], charArray2: Array[Char])(implicit stringCleaner: StringCleaner): Option[Int] = {
    val ca1 = stringCleaner.clean(charArray1)
    val ca2 = stringCleaner.clean(charArray2)

    if (ca1.length == 0 && ca2.length == 0)
      None
    else
      Some(levenshtein((ca1, ca2)))
  }

  /**
   * Computes the Levenshtein distance between two strings.
   *
   * The strings are cleaned here with the supplied cleaner and then handed to the
   * character-array overload together with a plain `StringCleanerDelegate`, so the
   * supplied cleaner is not applied a second time by that overload.
   *
   * @param string1 first string
   * @param string2 second string
   * @param stringCleaner cleaner applied to both inputs before they are compared
   * @return `None` when both cleaned inputs are empty, otherwise `Some(distance)`
   */
  override def compare(string1: String, string2: String)(implicit stringCleaner: StringCleaner): Option[Int] = {
    compare(
      stringCleaner.clean(string1.toCharArray),
      stringCleaner.clean(string2.toCharArray)
    )(new StringCleanerDelegate)
  }

  /**
   * Row-by-row dynamic programme (Wagner-Fischer) for the edit distance.
   *
   * This replaces a memoised recursion keyed on the compared tuple: JVM arrays compare and
   * hash by identity and every `tail` allocates a fresh array, so such a cache can never hit
   * and the recursion degrades to exponential time. Working on indices needs neither a cache
   * nor array copies and runs in O(m * n) time with O(n) extra memory.
   *
   * @param ct pair of already-cleaned character sequences
   * @return the number of single-character edits separating the two sequences
   */
  private[this] def levenshtein(ct: CompareTuple): Int = {
    val source = ct._1
    val target = ct._2

    // Row 0: turning an empty prefix of `source` into the first j characters of `target`
    // takes exactly j insertions.
    val firstRow: IndexedSeq[Int] = 0 to target.length

    val lastRow = (1 to source.length).foldLeft(firstRow) { (previousRow, i) =>
      // Column 0 of row i is i: deleting the first i characters of `source`.
      (1 to target.length).scanLeft(i) { (left, j) =>
        val substitutionCost = if (source(i - 1) != target(j - 1)) 1 else 0

        math.min(
          math.min(
            previousRow(j) + 1, // deletion from source
            left + 1            // insertion into source
          ),
          previousRow(j - 1) + substitutionCost
        )
      }
    }

    lastRow(target.length)
  }
}
// File: core/source/test/scala/org/hashtree/stringmetric/LevenshteinMetricSpec.scala
package org.hashtree.stringmetric

import org.hashtree.stringmetric.LevenshteinMetric.stringCleaner
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner

/** Checks the distances reported by [[org.hashtree.stringmetric.LevenshteinMetric]]. */
@RunWith(classOf[JUnitRunner])
final class LevenshteinMetricSpec extends ScalaTest {
  "LevenshteinMetric" should provide {
    "compare method" when passed {
      "valid arguments" should returns {
        "Int indicating distance" in {
          // Two empty inputs yield no result at all rather than a distance.
          LevenshteinMetric.compare("", "").isDefined should be (false)

          // (first input, second input, expected distance), checked in this order.
          val expectations = Seq(
            ("abc", "", 3),
            ("", "xyz", 3),
            ("abc", "abc", 0),
            ("abc", "xyz", 3),
            ("abc", "a", 2),
            ("a", "abc", 2),
            ("abc", "c", 2),
            ("c", "abc", 2),
            ("kitten", "sitting", 3),
            ("drake", "cake", 2)
          )

          expectations.foreach { case (first, second, distance) =>
            LevenshteinMetric.compare(first, second).get should be (distance)
          }
        }
      }
    }
  }
}
\ No newline at end of file |