author    zhangjiajin <zhangjiajin@huawei.com>    2015-07-30 08:14:09 -0700
committer Xiangrui Meng <meng@databricks.com>    2015-07-30 08:14:09 -0700
commit d212a314227dec26c0dbec8ed3422d0ec8f818f9 (patch)
tree   32775371b13cab56481318e6133bb6e136e63ad0 /mllib
parent c5815930be46a89469440b7c61b59764fb67a54c (diff)
[SPARK-8998] [MLLIB] Distribute PrefixSpan computation for large projected databases
Continuation of work by zhangjiajin. Closes #7412.

Author: zhangjiajin <zhangjiajin@huawei.com>
Author: Feynman Liang <fliang@databricks.com>
Author: zhang jiajin <zhangjiajin@huawei.com>

Closes #7783 from feynmanliang/SPARK-8998-improve-distributed and squashes the following commits:

a61943d [Feynman Liang] Collect small patterns to local
4ddf479 [Feynman Liang] Parallelize freqItemCounts
ad23aa9 [zhang jiajin] Merge pull request #1 from feynmanliang/SPARK-8998-collectBeforeLocal
87fa021 [Feynman Liang] Improve extend prefix readability
c2caa5c [Feynman Liang] Readability improvements and comments
1235cfc [Feynman Liang] Use Iterable[Array[_]] over Array[Array[_]] for database
da0091b [Feynman Liang] Use lists for prefixes to reuse data
cb2a4fc [Feynman Liang] Inline code for readability
01c9ae9 [Feynman Liang] Add getters
6e149fa [Feynman Liang] Fix splitPrefixSuffixPairs
64271b3 [zhangjiajin] Modified codes according to comments.
d2250b7 [zhangjiajin] remove minPatternsBeforeLocalProcessing, add maxSuffixesBeforeLocalProcessing.
b07e20c [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark into CollectEnoughPrefixes
095aa3a [zhangjiajin] Modified the code according to the review comments.
baa2885 [zhangjiajin] Modified the code according to the review comments.
6560c69 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixSpan
a8fde87 [zhangjiajin] Merge branch 'master' of https://github.com/apache/spark
4dd1c8a [zhangjiajin] initialize file before rebase.
078d410 [zhangjiajin] fix a scala style error.
22b0ef4 [zhangjiajin] Add feature: Collect enough frequent prefixes before projection in PrefixSpan.
ca9c4c8 [zhangjiajin] Modified the code according to the review comments.
574e56c [zhangjiajin] Add new object LocalPrefixSpan, and do some optimization.
ba5df34 [zhangjiajin] Fix a Scala style error.
4c60fb3 [zhangjiajin] Fix some Scala style errors.
1dd33ad [zhangjiajin] Modified the code according to the review comments.
89bc368 [zhangjiajin] Fixed a Scala style error.
a2eb14c [zhang jiajin] Delete PrefixspanSuite.scala
951fd42 [zhang jiajin] Delete Prefixspan.scala
575995f [zhangjiajin] Modified the code according to the review comments.
91fd7e6 [zhangjiajin] Add new algorithm PrefixSpan and test file.
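For context, here is a minimal usage sketch of the API this patch touches. The setters and the run signature are taken from the diff below; the sequence values and the SparkContext `sc` are illustrative assumptions.

import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

// Each input sequence is an array of integer item IDs.
val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(2, 4, 1)
), 2).cache()  // run() logs a warning if the input is not cached

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.33)     // a pattern must occur in at least 33% of the sequences
  .setMaxPatternLength(5)  // ignore patterns longer than 5 items

// Each result pair is (frequent sequential pattern, number of occurrences).
val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)
patterns.collect().foreach { case (pattern, count) =>
  println(pattern.mkString("[", ",", "]") + s": $count")
}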
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala |   6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala      | 203
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala |  21
3 files changed, 161 insertions(+), 69 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index 7ead632748..0ea7920810 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -40,7 +40,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
minCount: Long,
maxPatternLength: Int,
prefixes: List[Int],
- database: Array[Array[Int]]): Iterator[(List[Int], Long)] = {
+ database: Iterable[Array[Int]]): Iterator[(List[Int], Long)] = {
if (prefixes.length == maxPatternLength || database.isEmpty) return Iterator.empty
val frequentItemAndCounts = getFreqItemAndCounts(minCount, database)
val filteredDatabase = database.map(x => x.filter(frequentItemAndCounts.contains))
@@ -67,7 +67,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
}
}
- def project(database: Array[Array[Int]], prefix: Int): Array[Array[Int]] = {
+ def project(database: Iterable[Array[Int]], prefix: Int): Iterable[Array[Int]] = {
database
.map(getSuffix(prefix, _))
.filter(_.nonEmpty)
@@ -81,7 +81,7 @@ private[fpm] object LocalPrefixSpan extends Logging with Serializable {
*/
private def getFreqItemAndCounts(
minCount: Long,
- database: Array[Array[Int]]): mutable.Map[Int, Long] = {
+ database: Iterable[Array[Int]]): mutable.Map[Int, Long] = {
// TODO: use PrimitiveKeyOpenHashMap
val counts = mutable.Map[Int, Long]().withDefaultValue(0L)
database.foreach { sequence =>
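The three signature changes above only generalize the local routines from Array[Array[Int]] to Iterable[Array[Int]], so the projected databases built by the driver (which now arrive as Iterables out of groupByKey) can be mined without an extra copy. As a toy illustration of what project computes, assuming getSuffix(prefix, seq) returns the remainder of seq after the first occurrence of prefix (and an empty array when prefix is absent):

// Hypothetical example, not part of the patch. Note that LocalPrefixSpan is
// private[fpm], so this only compiles inside org.apache.spark.mllib.fpm.
val database: Iterable[Array[Int]] = Seq(
  Array(1, 3, 4, 5),  // suffix after 3 -> [4, 5]
  Array(2, 3),        // suffix after 3 -> []  (dropped by the nonEmpty filter)
  Array(2, 4, 1)      // no 3 present   -> []  (dropped)
)
val projected = LocalPrefixSpan.project(database, 3)
// projected contains the single sequence Array(4, 5)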
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 6f52db7b07..e6752332cd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -17,6 +17,8 @@
package org.apache.spark.mllib.fpm
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.Logging
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
@@ -44,27 +46,44 @@ class PrefixSpan private (
private var maxPatternLength: Int) extends Logging with Serializable {
/**
+ * The maximum number of items allowed in a projected database before local processing. If a
+ * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
+ */
+ // TODO: make configurable with a better default value, 10000 may be too small
+ private val maxLocalProjDBSize: Long = 10000
+
+ /**
* Constructs a default instance with default parameters
* {minSupport: `0.1`, maxPatternLength: `10`}.
*/
def this() = this(0.1, 10)
/**
+ * Gets the minimal support (i.e. the minimum frequency of occurrence required for a pattern to
+ * be considered frequent).
+ */
+ def getMinSupport: Double = this.minSupport
+
+ /**
* Sets the minimal support level (default: `0.1`).
*/
def setMinSupport(minSupport: Double): this.type = {
- require(minSupport >= 0 && minSupport <= 1,
- "The minimum support value must be between 0 and 1, including 0 and 1.")
+ require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
this.minSupport = minSupport
this
}
/**
+ * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
+ */
+ def getMaxPatternLength: Int = this.maxPatternLength
+
+ /**
* Sets maximal pattern length (default: `10`).
*/
def setMaxPatternLength(maxPatternLength: Int): this.type = {
- require(maxPatternLength >= 1,
- "The maximum pattern length value must be greater than 0.")
+ // TODO: support unbounded pattern length when maxPatternLength = 0
+ require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
this.maxPatternLength = maxPatternLength
this
}
@@ -78,81 +97,153 @@ class PrefixSpan private (
* the value of pair is the pattern's count.
*/
def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+ val sc = sequences.sparkContext
+
if (sequences.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
- val minCount = getMinCount(sequences)
- val lengthOnePatternsAndCounts =
- getFreqItemAndCounts(minCount, sequences).collect()
- val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
- lengthOnePatternsAndCounts.map(_._1), sequences)
- val groupedProjectedDatabase = prefixAndProjectedDatabase
- .map(x => (x._1.toSeq, x._2))
- .groupByKey()
- .map(x => (x._1.toArray, x._2.toArray))
- val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
- val lengthOnePatternsAndCountsRdd =
- sequences.sparkContext.parallelize(
- lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
- val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
- allPatterns
+
+ // Convert min support to a min number of transactions for this dataset
+ val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+
+ // Frequent items -> number of occurrences; all items here satisfy the `minSupport` threshold
+ val freqItemCounts = sequences
+ .flatMap(seq => seq.distinct.map(item => (item, 1L)))
+ .reduceByKey(_ + _)
+ .filter(_._2 >= minCount)
+ .collect()
+
+ // Pairs of (length 1 prefix, suffix consisting of frequent items)
+ val itemSuffixPairs = {
+ val freqItems = freqItemCounts.map(_._1).toSet
+ sequences.flatMap { seq =>
+ val filteredSeq = seq.filter(freqItems.contains(_))
+ freqItems.flatMap { item =>
+ val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
+ candidateSuffix match {
+ case suffix if !suffix.isEmpty => Some((List(item), suffix))
+ case _ => None
+ }
+ }
+ }
+ }
+
+ // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
+ // frequent length-one prefixes)
+ var resultsAccumulator = freqItemCounts.map(x => (List(x._1), x._2))
+
+ // Remaining work to be processed locally and distributively, respectively
+ var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
+
+ // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
+ // projected database sizes <= `maxLocalProjDBSize`)
+ while (pairsForDistributed.count() != 0) {
+ val (nextPatternAndCounts, nextPrefixSuffixPairs) =
+ extendPrefixes(minCount, pairsForDistributed)
+ pairsForDistributed.unpersist()
+ val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
+ pairsForDistributed = largerPairsPart
+ pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
+ pairsForLocal ++= smallerPairsPart
+ resultsAccumulator ++= nextPatternAndCounts.collect()
+ }
+
+ // Process the small projected databases locally
+ val remainingResults = getPatternsInLocal(
+ minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
+
+ (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
+ .map { case (pattern, count) => (pattern.toArray, count) }
}
+
/**
- * Get the minimum count (sequences count * minSupport).
- * @param sequences input data set, contains a set of sequences,
- * @return minimum count,
+ * Partitions the prefix-suffix pairs by projected database size.
+ * @param prefixSuffixPairs prefix (length N) and suffix pairs
+ * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
+ * greater than [[maxLocalProjDBSize]]
*/
- private def getMinCount(sequences: RDD[Array[Int]]): Long = {
- if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+ private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+ : (Array[(List[Int], Array[Int])], RDD[(List[Int], Array[Int])]) = {
+ val prefixToSuffixSize = prefixSuffixPairs
+ .aggregateByKey(0)(
+ seqOp = { case (count, suffix) => count + suffix.length },
+ combOp = { _ + _ })
+ val smallPrefixes = prefixToSuffixSize
+ .filter(_._2 <= maxLocalProjDBSize)
+ .keys
+ .collect()
+ .toSet
+ val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
+ val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
+ (small.collect(), large)
}
/**
- * Generates frequent items by filtering the input data using minimal count level.
- * @param minCount the absolute minimum count
- * @param sequences original sequences data
- * @return array of item and count pair
+ * Extends all prefixes by one item from their suffix and computes the resulting frequent prefixes
+ * and remaining work.
+ * @param minCount minimum count
+ * @param prefixSuffixPairs prefix (length N) and suffix pairs
+ * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
+ * prefix, corresponding suffix) pairs.
*/
- private def getFreqItemAndCounts(
+ private def extendPrefixes(
minCount: Long,
- sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
- sequences.flatMap(_.distinct.map((_, 1L)))
+ prefixSuffixPairs: RDD[(List[Int], Array[Int])])
+ : (RDD[(List[Int], Long)], RDD[(List[Int], Array[Int])]) = {
+
+ // (length N prefix, item from suffix) pairs and their corresponding number of occurrences
+ // Every (prefix :+ item) extension here is guaranteed to have support of at least `minSupport`
+ val prefixItemPairAndCounts = prefixSuffixPairs
+ .flatMap { case (prefix, suffix) => suffix.distinct.map(y => ((prefix, y), 1L)) }
.reduceByKey(_ + _)
.filter(_._2 >= minCount)
- }
- /**
- * Get the frequent prefixes' projected database.
- * @param frequentPrefixes frequent prefixes
- * @param sequences sequences data
- * @return prefixes and projected database
- */
- private def getPrefixAndProjectedDatabase(
- frequentPrefixes: Array[Int],
- sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
- val filteredSequences = sequences.map { p =>
- p.filter (frequentPrefixes.contains(_) )
- }
- filteredSequences.flatMap { x =>
- frequentPrefixes.map { y =>
- val sub = LocalPrefixSpan.getSuffix(y, x)
- (Array(y), sub)
- }.filter(_._2.nonEmpty)
- }
+ // Map from prefix to set of possible next items from suffix
+ val prefixToNextItems = prefixItemPairAndCounts
+ .keys
+ .groupByKey()
+ .mapValues(_.toSet)
+ .collect()
+ .toMap
+
+ // Frequent patterns with length N+1 and their corresponding counts
+ val extendedPrefixAndCounts = prefixItemPairAndCounts
+ .map { case ((prefix, item), count) => (item :: prefix, count) }
+
+ // Remaining work, all prefixes will have length N+1
+ val extendedPrefixAndSuffix = prefixSuffixPairs
+ .filter(x => prefixToNextItems.contains(x._1))
+ .flatMap { case (prefix, suffix) =>
+ val frequentNextItems = prefixToNextItems(prefix)
+ val filteredSuffix = suffix.filter(frequentNextItems.contains(_))
+ frequentNextItems.flatMap { item =>
+ LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
+ case suffix if !suffix.isEmpty => Some((item :: prefix, suffix))
+ case _ => None
+ }
+ }
+ }
+
+ (extendedPrefixAndCounts, extendedPrefixAndSuffix)
}
/**
- * calculate the patterns in local.
+ * Calculates the patterns locally.
* @param minCount the absolute minimum count
- * @param data patterns and projected sequences data data
+ * @param data prefixes and projected sequences data
* @return patterns
*/
private def getPatternsInLocal(
minCount: Long,
- data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
- data.flatMap { case (prefix, projDB) =>
- LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList, projDB)
- .map { case (pattern: List[Int], count: Long) => (pattern.toArray.reverse, count) }
+ data: RDD[(List[Int], Iterable[Array[Int]])]): RDD[(List[Int], Long)] = {
+ data.flatMap {
+ case (prefix, projDB) =>
+ LocalPrefixSpan.run(minCount, maxPatternLength, prefix.toList.reverse, projDB)
+ .map { case (pattern: List[Int], count: Long) =>
+ (pattern.reverse, count)
+ }
}
}
}
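Taken together, the rewritten run method alternates distributed prefix extension with deferred local mining. A condensed paraphrase of the driver loop (not the literal code; names mirror the diff above):

// Seed results with the frequent single items; split the initial
// (prefix, suffix) pairs by projected-database size.
var results = freqItemCounts.map { case (item, count) => (List(item), count) }
var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)

while (pairsForDistributed.count() != 0) {
  // Grow each surviving prefix by one frequent item drawn from its suffix.
  val (newPatterns, newPairs) = extendPrefixes(minCount, pairsForDistributed)
  pairsForDistributed.unpersist()
  val (small, large) = partitionByProjDBSize(newPairs)
  pairsForDistributed = large
  pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
  pairsForLocal ++= small               // small projected DBs wait for local mining
  results ++= newPatterns.collect()     // extended patterns are final results
}
// Everything left in pairsForLocal is grouped by prefix and handed to
// LocalPrefixSpan via getPatternsInLocal.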
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 9f107c89f6..6dd2dc926a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -44,13 +44,6 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
val rdd = sc.parallelize(sequences, 2).cache()
- def compareResult(
- expectedValue: Array[(Array[Int], Long)],
- actualValue: Array[(Array[Int], Long)]): Boolean = {
- expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
- actualValue.map(x => (x._1.toSeq, x._2)).toSet
- }
-
val prefixspan = new PrefixSpan()
.setMinSupport(0.33)
.setMaxPatternLength(50)
@@ -76,7 +69,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
- assert(compareResult(expectedValue1, result1.collect()))
+ assert(compareResults(expectedValue1, result1.collect()))
prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
val result2 = prefixspan.run(rdd)
@@ -87,7 +80,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4), 4L),
(Array(5), 3L)
)
- assert(compareResult(expectedValue2, result2.collect()))
+ assert(compareResults(expectedValue2, result2.collect()))
prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
val result3 = prefixspan.run(rdd)
@@ -107,6 +100,14 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(4, 5), 2L),
(Array(5), 3L)
)
- assert(compareResult(expectedValue3, result3.collect()))
+ assert(compareResults(expectedValue3, result3.collect()))
+ }
+
+ private def compareResults(
+ expectedValue: Array[(Array[Int], Long)],
+ actualValue: Array[(Array[Int], Long)]): Boolean = {
+ expectedValue.map(x => (x._1.toSeq, x._2)).toSet ==
+ actualValue.map(x => (x._1.toSeq, x._2)).toSet
}
+
}
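One note on the relocated test helper: compareResults compares expected and actual patterns order-insensitively, converting each Array to a Seq first so that equality is structural rather than by reference. To run just this suite, something like the following should work from the repository root (assuming Spark's bundled sbt launcher):

build/sbt "mllib/testOnly org.apache.spark.mllib.fpm.PrefixSpanSuite"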