author     Feynman Liang <fliang@databricks.com>    2015-07-14 23:50:57 -0700
committer  Xiangrui Meng <meng@databricks.com>      2015-07-14 23:50:57 -0700
commit     1bb8accbc95a0f0856a8bb715f1e94c3ff96a8c7 (patch)
tree       9fcababb455b12e3f02773f6b723dfbe2a73ebcf /mllib
parent     f0e129740dc2442a21dfa7fbd97360df87291095 (diff)
[SPARK-8997] [MLLIB] Performance improvements in LocalPrefixSpan
Improves the performance of LocalPrefixSpan by implementing optimizations proposed in [SPARK-8997](https://issues.apache.org/jira/browse/SPARK-8997).

Author: Feynman Liang <fliang@databricks.com>
Author: Feynman Liang <feynman.liang@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #7360 from feynmanliang/SPARK-8997-improve-prefixspan and squashes the following commits:

59db2f5 [Feynman Liang] Merge pull request #1 from mengxr/SPARK-8997
91e4357 [Xiangrui Meng] update LocalPrefixSpan impl
9212256 [Feynman Liang] MengXR code review comments
f055d82 [Feynman Liang] Fix failing scalatest
2e00cba [Feynman Liang] Depth first projections
70b93e3 [Feynman Liang] Performance improvements in LocalPrefixSpan, fix tests
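At its core, the change replaces the old breadth-first expansion, which materialized a projected database for every frequent prefix before recursing, with a depth-first, iterator-based recursion. Below is a minimal, self-contained Scala sketch of that idea; the names mirror the patched LocalPrefixSpan shown in the diff, but the code is an illustrative reconstruction, not the verbatim Spark source.

// Minimal sketch of the depth-first projection strategy in this patch.
// Illustrative reconstruction; names follow the patched LocalPrefixSpan.
object LocalPrefixSpanSketch {

  // Items that occur in at least minCount sequences, with their counts.
  def getFreqItemAndCounts(minCount: Long, database: Array[Array[Int]]): Map[Int, Long] =
    database.flatMap(_.distinct)
      .groupBy(identity)
      .map { case (item, occurrences) => (item, occurrences.length.toLong) }
      .filter(_._2 >= minCount)

  // Suffix of a sequence strictly after the first occurrence of item.
  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
    val index = sequence.indexOf(item)
    if (index == -1) Array.empty else sequence.drop(index + 1)
  }

  // Depth-first search: prefixes grow by consing (so they are reversed),
  // only one branch's projected database is alive at a time, and results
  // stream out through an Iterator instead of being concatenated into arrays.
  def run(
      minCount: Long,
      maxPatternLength: Int,
      prefixes: List[Int],
      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
    if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
    val freq = getFreqItemAndCounts(minCount, database)
    val filtered = database.map(_.filter(freq.contains))
    freq.iterator.flatMap { case (item, count) =>
      val newPrefixes = item :: prefixes
      val projected = filtered.map(getSuffix(item, _)).filter(_.nonEmpty)
      Iterator.single((newPrefixes, count)) ++
        run(minCount, maxPatternLength, newPrefixes, projected)
    }
  }
}

Because each branch projects its database only when visited and results stream out through an Iterator, peak memory is bounded by the current recursion path rather than by the full pattern tree.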
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala   95
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala          5
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala    14
3 files changed, 44 insertions(+), 70 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 39c48b084e..7ead632748 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -17,58 +17,49 @@
 
 package org.apache.spark.mllib.fpm
 
+import scala.collection.mutable
+
 import org.apache.spark.Logging
-import org.apache.spark.annotation.Experimental
 
 /**
- *
- * :: Experimental ::
- *
  * Calculate all patterns of a projected database in local.
  */
-@Experimental
 private[fpm] object LocalPrefixSpan extends Logging with Serializable {
 
   /**
    * Calculate all patterns of a projected database.
    * @param minCount minimum count
    * @param maxPatternLength maximum pattern length
-   * @param prefix prefix
-   * @param projectedDatabase the projected dabase
+   * @param prefixes prefixes in reversed order
+   * @param database the projected database
    * @return a set of sequential pattern pairs,
-   *         the key of pair is sequential pattern (a list of items),
+   *         the key of pair is sequential pattern (a list of items in reversed order),
    *         the value of pair is the pattern's count.
    */
   def run(
       minCount: Long,
       maxPatternLength: Int,
-      prefix: Array[Int],
-      projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
-    val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
-    val frequentPatternAndCounts = frequentPrefixAndCounts
-      .map(x => (prefix ++ Array(x._1), x._2))
-    val prefixProjectedDatabases = getPatternAndProjectedDatabase(
-      prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
-
-    val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
-    if (continueProcess) {
-      val nextPatterns = prefixProjectedDatabases
-        .map(x => run(minCount, maxPatternLength, x._1, x._2))
-        .reduce(_ ++ _)
-      frequentPatternAndCounts ++ nextPatterns
-    } else {
-      frequentPatternAndCounts
+      prefixes: List[Int],
+      database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+    if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
+    val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
+    val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
+    frequentItemAndCounts.iterator.flatMap { case (item, count) =>
+      val newPrefixes = item :: prefixes
+      val newProjected = project(filteredDatabase, item)
+      Iterator.single((newPrefixes, count)) ++
+        run(minCount, maxPatternLength, newPrefixes, newProjected)
     }
   }
 
   /**
-   * calculate suffix sequence following a prefix in a sequence
-   * @param prefix prefix
-   * @param sequence sequence
+   * Calculate suffix sequence immediately after the first occurrence of an item.
+   * @param item item to get suffix after
+   * @param sequence sequence to extract suffix from
    * @return suffix sequence
    */
-  def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
-    val index = sequence.indexOf(prefix)
+  def getSuffix(item: Int, sequence: Array[Int]): Array[Int] = {
+    val index = sequence.indexOf(item)
     if (index == -1) {
       Array()
     } else {
@@ -76,38 +67,28 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
     }
   }
 
+  def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+    database
+      .map(getSuffix(prefix, _))
+      .filter(_.nonEmpty)
+  }
+
   /**
    * Generates frequent items by filtering the input data using minimal count level.
-   * @param minCount the absolute minimum count
-   * @param sequences sequences data
-   * @return array of item and count pair
+   * @param minCount the minimum count for an item to be frequent
+   * @param database database of sequences
+   * @return freq item to count map
    */
   private def getFreqItemAndCounts(
       minCount: Long,
-      sequences: Array[Array[Int]]): Array[(Int, Long)] = {
-    sequences.flatMap(_.distinct)
-      .groupBy(x => x)
-      .mapValues(_.length.toLong)
-      .filter(_._2 >= minCount)
-      .toArray
-  }
-
-  /**
-   * Get the frequent prefixes' projected database.
-   * @param prePrefix the frequent prefixes' prefix
-   * @param frequentPrefixes frequent prefixes
-   * @param sequences sequences data
-   * @return prefixes and projected database
-   */
-  private def getPatternAndProjectedDatabase(
-      prePrefix: Array[Int],
-      frequentPrefixes: Array[Int],
-      sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
-    val filteredProjectedDatabase = sequences
-      .map(x => x.filter(frequentPrefixes.contains(_)))
-    frequentPrefixes.map { x =>
-      val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
-      (prePrefix ++ Array(x), sub)
-    }.filter(x => x._2.nonEmpty)
+      database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+    // TODO: use PrimitiveKeyOpenHashMap
+    val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
+    database.foreach { sequence =>
+      sequence.distinct.foreach { item =>
+        counts(item) += 1L
+      }
+    }
+    counts.filter(_._2 >= minCount)
   }
 }
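To make the new project/getSuffix semantics concrete, here is a small hypothetical database (values invented for illustration): the suffix starts strictly after the first occurrence of the item, and empty projections are dropped.

// Hypothetical data, illustrating project(database, 2) per the patched code.
val database = Array(Array(1, 2, 3, 2), Array(2, 4), Array(3, 1))
// getSuffix(2, Array(1, 2, 3, 2)) == Array(3, 2)
// getSuffix(2, Array(2, 4))       == Array(4)
// getSuffix(2, Array(3, 1))       == Array()  -- filtered out by project
// project(database, 2)            == Array(Array(3, 2), Array(4))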
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 9d8c60ef0f..6f52db7b07 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -150,8 +150,9 @@ class PrefixSpan private (
   private def getPatternsInLocal(
       minCount: Long,
       data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
-    data.flatMap { x =>
-      LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
+    data.flatMap { case (prefix, projDB) =>
+      LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
+        .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
     }
   }
 }
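The map over the local results exists because run now builds prefixes with item :: prefixes, an O(1) cons that leaves each pattern in reversed order; the pattern is reversed exactly once on the way out rather than copied at every recursion level. A small hypothetical illustration:

// A pattern accumulated as 3 :: 2 :: 1 :: Nil during the descent...
val local: (List[Int], Long) = (List(3, 2, 1), 2L)
val (pattern, count) = local
// ...is restored to discovery order exactly once when emitted.
assert(pattern.toArray.reverse.sameElements(Array(1, 2, 3)))
assert(count == 2L)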
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 413436d3db..9f107c89f6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -18,9 +18,8 @@ package org.apache.spark.mllib.fpm
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.rdd.RDD
 
-class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
+class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
 
   test("PrefixSpan using Integer type") {
 
@@ -48,15 +47,8 @@ class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext {
     def compareResult(
         expectedValue: Array[(Array[Int], Long)],
         actualValue: Array[(Array[Int], Long)]): Boolean = {
-      val sortedExpectedValue = expectedValue.sortWith{ (x, y) =>
-        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
-      }
-      val sortedActualValue = actualValue.sortWith{ (x, y) =>
-        x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
-      }
-      sortedExpectedValue.zip(sortedActualValue)
-        .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
-        .reduce(_&&_)
+      expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+        actualValue.map(x => (x._1.toSeq, x._2)).toSet
     }
 
     val prefixspan = new PrefixSpan()
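The simplified compareResult leans on a Scala detail: Array equality is reference equality, so converting each pattern to a Seq is what makes the set comparison structural and order-insensitive. For example:

// Why compareResult converts to Seq before building the sets:
Array(1, 2) == Array(1, 2)                        // false: arrays compare by reference
Array(1, 2).toSeq == Array(1, 2).toSeq            // true: Seqs compare structurally
Set((Seq(1, 2), 2L)) == Set((Seq(1, 2), 2L))      // true: order-insensitive comparison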