path: root/mllib/src
author     Xiangrui Meng <meng@databricks.com>  2015-08-04 22:28:49 -0700
committer  Xiangrui Meng <meng@databricks.com>  2015-08-04 22:28:49 -0700
commit     a02bcf20c4fc9e2e182630d197221729e996afc2 (patch)
tree       addf5a311acafad2849dd32c7a8c47f88f1f702f /mllib/src
parent     f7abd6bec9d51ed4ab6359e50eac853e64ecae86 (diff)
[SPARK-9540] [MLLIB] optimize PrefixSpan implementation
This is a major refactoring of the PrefixSpan implementation. It contains the following changes:

1. Expand the prefix by one item at a time. The existing implementation generates all subsets for each itemset, which might have scalability issues when the itemset is large.
2. Use a new internal format. `<(12)(31)>` is represented by `[0, 1, 2, 0, 1, 3, 0]` internally. We use `0` as the delimiter because negative numbers are used to indicate partial prefix items, e.g., `_2` is represented by `-2`.
3. Remember the start indices of all partial projections in the projected postfix to help the next projection.
4. Reuse the original sequence array for projected postfixes.
5. Use `Prefix` IDs in aggregation rather than their content.
6. Use `ArrayBuilder` for building primitive arrays.
7. Expose `maxLocalProjDBSize`.
8. Tests are not changed except for using `0` instead of `-1` as the delimiter.

`Postfix`'s API doc should be a good place to start.

Closes #7594. feynmanliang zhangjiajin

Author: Xiangrui Meng <meng@databricks.com>

Closes #7937 from mengxr/SPARK-9540 and squashes the following commits:

2d0ec31 [Xiangrui Meng] address more comments
48f450c [Xiangrui Meng] address comments from Feynman; fixed a bug in project and added a test
65f90e8 [Xiangrui Meng] naming and documentation
8afc86a [Xiangrui Meng] refactor impl
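To make the internal format in item 2 concrete, here is a minimal illustrative sketch (not part of the patch) of how a sequence of itemsets could be encoded under the conventions stated above:

// Illustrative sketch only, assuming the stated conventions: itemsets are
// delimited by 0 (including the first and last position), and items within
// an itemset are stored as distinct, ordered positive integers.
import scala.collection.mutable

def encode(sequence: Array[Array[Int]]): Array[Int] = {
  val out = mutable.ArrayBuilder.make[Int]
  out += 0
  sequence.foreach { itemset =>
    out ++= itemset.sorted
    out += 0
  }
  out.result()
}

// encode(Array(Array(1, 2), Array(3, 1))) == Array(0, 1, 2, 0, 1, 3, 0),
// the internal form of `<(12)(31)>` quoted above.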
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala | 132
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala      | 587
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala | 271
3 files changed, 599 insertions(+), 391 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
index ccebf951c8..3ea10779a1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -22,85 +22,89 @@ import scala.collection.mutable
import org.apache.spark.Logging
/**
- * Calculate all patterns of a projected database in local.
+ * Calculate all patterns of a projected database in local mode.
+ *
+ * @param minCount minimal count for a frequent pattern
+ * @param maxPatternLength max pattern length for a frequent pattern
*/
-private[fpm] object LocalPrefixSpan extends Logging with Serializable {
- import PrefixSpan._
+private[fpm] class LocalPrefixSpan(
+ val minCount: Long,
+ val maxPatternLength: Int) extends Logging with Serializable {
+ import PrefixSpan.Postfix
+ import LocalPrefixSpan.ReversedPrefix
+
/**
- * Calculate all patterns of a projected database.
- * @param minCount minimum count
- * @param maxPatternLength maximum pattern length
- * @param prefixes prefixes in reversed order
- * @param database the projected database
- * @return a set of sequential pattern pairs,
- * the key of pair is sequential pattern (a list of items in reversed order),
- * the value of pair is the pattern's count.
+ * Generates frequent patterns on the input array of postfixes.
+ * @param postfixes an array of postfixes
+ * @return an iterator of (frequent pattern, count)
*/
- def run(
- minCount: Long,
- maxPatternLength: Int,
- prefixes: List[Set[Int]],
- database: Iterable[List[Set[Int]]]): Iterator[(List[Set[Int]], Long)] = {
- if (prefixes.length == maxPatternLength || database.isEmpty) {
- return Iterator.empty
- }
- val freqItemSetsAndCounts = getFreqItemAndCounts(minCount, database)
- val freqItems = freqItemSetsAndCounts.keys.flatten.toSet
- val filteredDatabase = database.map { suffix =>
- suffix
- .map(item => freqItems.intersect(item))
- .filter(_.nonEmpty)
- }
- freqItemSetsAndCounts.iterator.flatMap { case (item, count) =>
- val newPrefixes = item :: prefixes
- val newProjected = project(filteredDatabase, item)
- Iterator.single((newPrefixes, count)) ++
- run(minCount, maxPatternLength, newPrefixes, newProjected)
+ def run(postfixes: Array[Postfix]): Iterator[(Array[Int], Long)] = {
+ genFreqPatterns(ReversedPrefix.empty, postfixes).map { case (prefix, count) =>
+ (prefix.toSequence, count)
}
}
/**
- * Calculate suffix sequence immediately after the first occurrence of an item.
- * @param item itemset to get suffix after
- * @param sequence sequence to extract suffix from
- * @return suffix sequence
+ * Recursively generates frequent patterns.
+ * @param prefix current prefix
+ * @param postfixes projected postfixes w.r.t. the prefix
+ * @return an iterator of (prefix, count)
*/
- def getSuffix(item: Set[Int], sequence: List[Set[Int]]): List[Set[Int]] = {
- val itemsetSeq = sequence
- val index = itemsetSeq.indexWhere(item.subsetOf(_))
- if (index == -1) {
- List()
- } else {
- itemsetSeq.drop(index + 1)
+ private def genFreqPatterns(
+ prefix: ReversedPrefix,
+ postfixes: Array[Postfix]): Iterator[(ReversedPrefix, Long)] = {
+ if (maxPatternLength == prefix.length || postfixes.length < minCount) {
+ return Iterator.empty
+ }
+ // find frequent items
+ val counts = mutable.Map.empty[Int, Long].withDefaultValue(0)
+ postfixes.foreach { postfix =>
+ postfix.genPrefixItems.foreach { case (x, _) =>
+ counts(x) += 1L
+ }
+ }
+ val freqItems = counts.toSeq.filter { case (_, count) =>
+ count >= minCount
+ }.sorted
+ // project and recursively call genFreqPatterns
+ freqItems.toIterator.flatMap { case (item, count) =>
+ val newPrefix = prefix :+ item
+ Iterator.single((newPrefix, count)) ++ {
+ val projected = postfixes.map(_.project(item)).filter(_.nonEmpty)
+ genFreqPatterns(newPrefix, projected)
+ }
}
}
+}
- def project(
- database: Iterable[List[Set[Int]]],
- prefix: Set[Int]): Iterable[List[Set[Int]]] = {
- database
- .map(getSuffix(prefix, _))
- .filter(_.nonEmpty)
- }
+private object LocalPrefixSpan {
/**
- * Generates frequent items by filtering the input data using minimal count level.
- * @param minCount the minimum count for an item to be frequent
- * @param database database of sequences
- * @return freq item to count map
+ * Represents a prefix stored as a list in reversed order.
+ * @param items items in the prefix in reversed order
+ * @param length length of the prefix, not counting delimiters
*/
- private def getFreqItemAndCounts(
- minCount: Long,
- database: Iterable[List[Set[Int]]]): Map[Set[Int], Long] = {
- // TODO: use PrimitiveKeyOpenHashMap
- val counts = mutable.Map[Set[Int], Long]().withDefaultValue(0L)
- database.foreach { sequence =>
- sequence.flatMap(nonemptySubsets(_)).distinct.foreach { item =>
- counts(item) += 1L
+ class ReversedPrefix private (val items: List[Int], val length: Int) extends Serializable {
+ /**
+ * Expands the prefix by one item.
+ */
+ def :+(item: Int): ReversedPrefix = {
+ require(item != 0)
+ if (item < 0) {
+ new ReversedPrefix(-item :: items, length + 1)
+ } else {
+ new ReversedPrefix(item :: 0 :: items, length + 1)
}
}
- counts
- .filter { case (_, count) => count >= minCount }
- .toMap
+
+ /**
+ * Converts this prefix to a sequence.
+ */
+ def toSequence: Array[Int] = (0 :: items).toArray.reverse
+ }
+
+ object ReversedPrefix {
+ /** An empty prefix. */
+ val empty: ReversedPrefix = new ReversedPrefix(List.empty, 0)
}
}
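As a quick illustration of the `ReversedPrefix` behavior introduced above (a conceptual sketch; the class is private to `LocalPrefixSpan`, so this only runs from within it):

// A positive item opens a new itemset, so a 0 delimiter is pushed first;
// a negative item extends the last itemset of the prefix.
val p1 = ReversedPrefix.empty :+ 1  // prefix <1>
p1.toSequence                       // Array(0, 1, 0)
(p1 :+ -3).toSequence               // partial item, prefix <(13)>: Array(0, 1, 3, 0)
(p1 :+ 3).toSequence                // full item, prefix <13>: Array(0, 1, 0, 3, 0)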
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 9eaf733fad..d5f0c926c6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -18,9 +18,10 @@
package org.apache.spark.mllib.fpm
import java.{lang => jl, util => ju}
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.mutable
import scala.collection.JavaConverters._
-import scala.collection.mutable.ArrayBuilder
import scala.reflect.ClassTag
import org.apache.spark.Logging
@@ -31,17 +32,20 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
/**
- *
* :: Experimental ::
*
- * A parallel PrefixSpan algorithm to mine sequential pattern.
- * The PrefixSpan algorithm is described in
- * [[http://doi.org/10.1109/ICDE.2001.914830]].
+ * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
+ * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns
+ * Efficiently by Prefix-Projected Pattern Growth ([[http://doi.org/10.1109/ICDE.2001.914830]]).
*
* @param minSupport the minimal support level of the sequential pattern, any pattern that appears
* more than (minSupport * size-of-the-dataset) times will be output
* @param maxPatternLength the maximal length of the sequential pattern, any pattern whose length is
- * less than maxPatternLength will be output
+ * less than maxPatternLength will be output
+ * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal
+ * storage format) allowed in a projected database before local
+ * processing. If a projected database exceeds this size, another
+ * iteration of distributed prefix growth is run.
*
* @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
* (Wikipedia)]]
@@ -49,33 +53,28 @@ import org.apache.spark.storage.StorageLevel
@Experimental
class PrefixSpan private (
private var minSupport: Double,
- private var maxPatternLength: Int) extends Logging with Serializable {
+ private var maxPatternLength: Int,
+ private var maxLocalProjDBSize: Long) extends Logging with Serializable {
import PrefixSpan._
/**
- * The maximum number of items allowed in a projected database before local processing. If a
- * projected database exceeds this size, another iteration of distributed PrefixSpan is run.
- */
- // TODO: make configurable with a better default value
- private val maxLocalProjDBSize: Long = 32000000L
-
- /**
* Constructs a default instance with default parameters
- * {minSupport: `0.1`, maxPatternLength: `10`}.
+ * {minSupport: `0.1`, maxPatternLength: `10`, maxLocalProjDBSize: `32000000L`}.
*/
- def this() = this(0.1, 10)
+ def this() = this(0.1, 10, 32000000L)
/**
* Get the minimal support (i.e. the frequency of occurrence before a pattern is considered
* frequent).
*/
- def getMinSupport: Double = this.minSupport
+ def getMinSupport: Double = minSupport
/**
* Sets the minimal support level (default: `0.1`).
*/
def setMinSupport(minSupport: Double): this.type = {
- require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].")
+ require(minSupport >= 0 && minSupport <= 1,
+ s"The minimum support value must be in [0, 1], but got $minSupport.")
this.minSupport = minSupport
this
}
@@ -83,45 +82,115 @@ class PrefixSpan private (
/**
* Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider).
*/
- def getMaxPatternLength: Double = this.maxPatternLength
+ def getMaxPatternLength: Double = maxPatternLength
/**
* Sets maximal pattern length (default: `10`).
*/
def setMaxPatternLength(maxPatternLength: Int): this.type = {
// TODO: support unbounded pattern length when maxPatternLength = 0
- require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.")
+ require(maxPatternLength >= 1,
+ s"The maximum pattern length value must be greater than 0, but got $maxPatternLength.")
this.maxPatternLength = maxPatternLength
this
}
/**
- * Find the complete set of sequential patterns in the input sequences of itemsets.
- * @param data ordered sequences of itemsets.
- * @return a [[PrefixSpanModel]] that contains the frequent sequences
+ * Gets the maximum number of items allowed in a projected database before local processing.
+ */
+ def getMaxLocalProjDBSize: Long = maxLocalProjDBSize
+
+ /**
+ * Sets the maximum number of items (including delimiters used in the internal storage format)
+ * allowed in a projected database before local processing (default: `32000000L`).
+ */
+ def setMaxLocalProjDBSize(maxLocalProjDBSize: Long): this.type = {
+ require(maxLocalProjDBSize >= 0L,
+ s"The maximum local projected database size must be nonnegative, but got $maxLocalProjDBSize")
+ this.maxLocalProjDBSize = maxLocalProjDBSize
+ this
+ }
+
+ /**
+ * Finds the complete set of frequent sequential patterns in the input sequences of itemsets.
+ * @param data sequences of itemsets.
+ * @return a [[PrefixSpanModel]] that contains the frequent patterns
*/
def run[Item: ClassTag](data: RDD[Array[Array[Item]]]): PrefixSpanModel[Item] = {
- val itemToInt = data.aggregate(Set[Item]())(
- seqOp = { (uniqItems, item) => uniqItems ++ item.flatten.toSet },
- combOp = { _ ++ _ }
- ).zipWithIndex.toMap
- val intToItem = Map() ++ (itemToInt.map { case (k, v) => (v, k) })
-
- val dataInternalRepr = data.map { seq =>
- seq.map(itemset => itemset.map(itemToInt)).reduce((a, b) => a ++ (DELIMITER +: b))
+ if (data.getStorageLevel == StorageLevel.NONE) {
+ logWarning("Input data is not cached.")
}
- val results = run(dataInternalRepr)
- def toPublicRepr(pattern: Iterable[Int]): List[Array[Item]] = {
- pattern.span(_ != DELIMITER) match {
- case (x, xs) if xs.size > 1 => x.map(intToItem).toArray :: toPublicRepr(xs.tail)
- case (x, xs) => List(x.map(intToItem).toArray)
+ val totalCount = data.count()
+ logInfo(s"number of sequences: $totalCount")
+ val minCount = math.ceil(minSupport * totalCount).toLong
+ logInfo(s"minimum count for a frequent pattern: $minCount")
+
+ // Find frequent items.
+ val freqItemAndCounts = data.flatMap { itemsets =>
+ val uniqItems = mutable.Set.empty[Item]
+ itemsets.foreach { _.foreach { item =>
+ uniqItems += item
+ }}
+ uniqItems.toIterator.map((_, 1L))
+ }.reduceByKey(_ + _)
+ .filter { case (_, count) =>
+ count >= minCount
+ }.collect()
+ val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1)
+ logInfo(s"number of frequent items: ${freqItems.length}")
+
+ // Keep only frequent items from input sequences and convert them to internal storage.
+ val itemToInt = freqItems.zipWithIndex.toMap
+ val dataInternalRepr = data.flatMap { itemsets =>
+ val allItems = mutable.ArrayBuilder.make[Int]
+ var containsFreqItems = false
+ allItems += 0
+ itemsets.foreach { itemset =>
+ val items = mutable.ArrayBuilder.make[Int]
+ itemset.foreach { item =>
+ if (itemToInt.contains(item)) {
+ items += itemToInt(item) + 1 // using 1-indexing in internal format
+ }
+ }
+ val result = items.result()
+ if (result.nonEmpty) {
+ containsFreqItems = true
+ allItems ++= result.sorted
+ }
+ allItems += 0
+ }
+ if (containsFreqItems) {
+ Iterator.single(allItems.result())
+ } else {
+ Iterator.empty
+ }
+ }.persist(StorageLevel.MEMORY_AND_DISK)
+
+ val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize)
+
+ def toPublicRepr(pattern: Array[Int]): Array[Array[Item]] = {
+ val sequenceBuilder = mutable.ArrayBuilder.make[Array[Item]]
+ val itemsetBuilder = mutable.ArrayBuilder.make[Item]
+ val n = pattern.length
+ var i = 1
+ while (i < n) {
+ val x = pattern(i)
+ if (x == 0) {
+ sequenceBuilder += itemsetBuilder.result()
+ itemsetBuilder.clear()
+ } else {
+ itemsetBuilder += freqItems(x - 1) // using 1-indexing in internal format
+ }
+ i += 1
}
+ sequenceBuilder.result()
}
+
val freqSequences = results.map { case (seq: Array[Int], count: Long) =>
- new FreqSequence[Item](toPublicRepr(seq).toArray, count)
+ new FreqSequence(toPublicRepr(seq), count)
}
- new PrefixSpanModel[Item](freqSequences)
+ new PrefixSpanModel(freqSequences)
}
/**
@@ -131,7 +200,7 @@ class PrefixSpan private (
* @tparam Item item type
* @tparam Itemset itemset type, which is an Iterable of Items
* @tparam Sequence sequence type, which is an Iterable of Itemsets
- * @return a [[PrefixSpanModel]] that contains the frequent sequences
+ * @return a [[PrefixSpanModel]] that contains the frequent sequential patterns
*/
def run[Item, Itemset <: jl.Iterable[Item], Sequence <: jl.Iterable[Itemset]](
data: JavaRDD[Sequence]): PrefixSpanModel[Item] = {
@@ -139,200 +208,320 @@ class PrefixSpan private (
run(data.rdd.map(_.asScala.map(_.asScala.toArray).toArray))
}
+}
+
+@Experimental
+object PrefixSpan extends Logging {
+
/**
- * Find the complete set of sequential patterns in the input sequences. This method utilizes
- * the internal representation of itemsets as Array[Int] where each itemset is represented by
- * a contiguous sequence of non-negative integers and delimiters represented by [[DELIMITER]].
- * @param data ordered sequences of itemsets. Items are represented by non-negative integers.
- * Each itemset has one or more items and is delimited by [[DELIMITER]].
- * @return a set of sequential pattern pairs,
- * the key of pair is pattern (a list of elements),
- * the value of pair is the pattern's count.
+ * Find the complete set of frequent sequential patterns in the input sequences.
+ * @param data ordered sequences of itemsets. We represent a sequence internally as Array[Int],
+ * where each itemset is represented by a contiguous sequence of distinct and ordered
+ * positive integers. We use 0 as the delimiter at itemset boundaries, including the
+ * first and the last position.
+ * @return an RDD of (frequent sequential pattern, count) pairs.
+ * @see [[Postfix]]
*/
- private[fpm] def run(data: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+ private[fpm] def genFreqPatterns(
+ data: RDD[Array[Int]],
+ minCount: Long,
+ maxPatternLength: Int,
+ maxLocalProjDBSize: Long): RDD[(Array[Int], Long)] = {
val sc = data.sparkContext
if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
- // Use List[Set[Item]] for internal computation
- val sequences = data.map { seq => splitSequence(seq.toList) }
-
- // Convert min support to a min number of transactions for this dataset
- val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
-
- // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold
- val freqItemCounts = sequences
- .flatMap(seq => seq.flatMap(nonemptySubsets(_)).distinct.map(item => (item, 1L)))
- .reduceByKey(_ + _)
- .filter { case (item, count) => (count >= minCount) }
- .collect()
- .toMap
-
- // Pairs of (length 1 prefix, suffix consisting of frequent items)
- val itemSuffixPairs = {
- val freqItemSets = freqItemCounts.keys.toSet
- val freqItems = freqItemSets.flatten
- sequences.flatMap { seq =>
- val filteredSeq = seq.map(item => freqItems.intersect(item)).filter(_.nonEmpty)
- freqItemSets.flatMap { item =>
- val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq)
- candidateSuffix match {
- case suffix if !suffix.isEmpty => Some((List(item), suffix))
- case _ => None
+ val postfixes = data.map(items => new Postfix(items))
+
+ // Local frequent patterns (prefixes) and their counts.
+ val localFreqPatterns = mutable.ArrayBuffer.empty[(Array[Int], Long)]
+ // Prefixes whose projected databases are small.
+ val smallPrefixes = mutable.Map.empty[Int, Prefix]
+ val emptyPrefix = Prefix.empty
+ // Prefixes whose projected databases are large.
+ var largePrefixes = mutable.Map(emptyPrefix.id -> emptyPrefix)
+ while (largePrefixes.nonEmpty) {
+ val numLocalFreqPatterns = localFreqPatterns.length
+ logInfo(s"number of local frequent patterns: $numLocalFreqPatterns")
+ if (numLocalFreqPatterns > 1000000) {
+ logWarning(
+ s"""
+ | Collected $numLocalFreqPatterns local frequent patterns. You may want to consider:
+ | 1. increasing minSupport,
+ | 2. decreasing maxPatternLength,
+ | 3. increasing maxLocalProjDBSize.
+ """.stripMargin)
+ }
+ logInfo(s"number of small prefixes: ${smallPrefixes.size}")
+ logInfo(s"number of large prefixes: ${largePrefixes.size}")
+ val largePrefixArray = largePrefixes.values.toArray
+ val freqPrefixes = postfixes.flatMap { postfix =>
+ largePrefixArray.flatMap { prefix =>
+ postfix.project(prefix).genPrefixItems.map { case (item, postfixSize) =>
+ ((prefix.id, item), (1L, postfixSize))
+ }
+ }
+ }.reduceByKey { case ((c0, s0), (c1, s1)) =>
+ (c0 + c1, s0 + s1)
+ }.filter { case (_, (c, _)) => c >= minCount }
+ .collect()
+ val newLargePrefixes = mutable.Map.empty[Int, Prefix]
+ freqPrefixes.foreach { case ((id, item), (count, projDBSize)) =>
+ val newPrefix = largePrefixes(id) :+ item
+ localFreqPatterns += ((newPrefix.items :+ 0, count))
+ if (newPrefix.length < maxPatternLength) {
+ if (projDBSize > maxLocalProjDBSize) {
+ newLargePrefixes += newPrefix.id -> newPrefix
+ } else {
+ smallPrefixes += newPrefix.id -> newPrefix
}
}
}
+ largePrefixes = newLargePrefixes
}
- // Accumulator for the computed results to be returned, initialized to the frequent items (i.e.
- // frequent length-one prefixes)
- var resultsAccumulator = freqItemCounts.map { case (item, count) => (List(item), count) }.toList
-
- // Remaining work to be locally and distributively processed respectfully
- var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs)
-
- // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have
- // projected database sizes <= `maxLocalProjDBSize`) or `maxPatternLength` is reached
- var patternLength = 1
- while (pairsForDistributed.count() != 0 && patternLength < maxPatternLength) {
- val (nextPatternAndCounts, nextPrefixSuffixPairs) =
- extendPrefixes(minCount, pairsForDistributed)
- pairsForDistributed.unpersist()
- val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs)
- pairsForDistributed = largerPairsPart
- pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK)
- pairsForLocal ++= smallerPairsPart
- resultsAccumulator ++= nextPatternAndCounts.collect()
- patternLength += 1 // pattern length grows one per iteration
+ // Switch to local processing.
+ val bcSmallPrefixes = sc.broadcast(smallPrefixes)
+ val distributedFreqPattern = postfixes.flatMap { postfix =>
+ bcSmallPrefixes.value.values.map { prefix =>
+ (prefix.id, postfix.project(prefix).compressed)
+ }.filter(_._2.nonEmpty)
+ }.groupByKey().flatMap { case (id, projPostfixes) =>
+ val prefix = bcSmallPrefixes.value(id)
+ val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length)
+ // TODO: We collect projected postfixes into memory. We should also compare the performance
+ // TODO: of keeping them on shuffle files.
+ localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) =>
+ (prefix.items ++ pattern, count)
+ }
}
- // Process the small projected databases locally
- val remainingResults = getPatternsInLocal(
- minCount, sc.parallelize(pairsForLocal, 1).groupByKey())
-
- (sc.parallelize(resultsAccumulator, 1) ++ remainingResults)
- .map { case (pattern, count) => (flattenSequence(pattern.reverse).toArray, count) }
+ // Union local frequent patterns and distributed ones.
+ val freqPatterns = (sc.parallelize(localFreqPatterns, 1) ++ distributedFreqPattern)
+ .persist(StorageLevel.MEMORY_AND_DISK)
+ freqPatterns
}
-
/**
- * Partitions the prefix-suffix pairs by projected database size.
- * @param prefixSuffixPairs prefix (length n) and suffix pairs,
- * @return prefix-suffix pairs partitioned by whether their projected database size is <= or
- * greater than [[maxLocalProjDBSize]]
+ * Represents a prefix.
+ * @param items items in this prefix, using the internal format
+ * @param length length of this prefix, not counting the 0 delimiters
*/
- private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])])
- : (List[(List[Set[Int]], List[Set[Int]])], RDD[(List[Set[Int]], List[Set[Int]])]) = {
- val prefixToSuffixSize = prefixSuffixPairs
- .aggregateByKey(0)(
- seqOp = { case (count, suffix) => count + suffix.length },
- combOp = { _ + _ })
- val smallPrefixes = prefixToSuffixSize
- .filter(_._2 <= maxLocalProjDBSize)
- .keys
- .collect()
- .toSet
- val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) }
- val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) }
- (small.collect().toList, large)
+ private[fpm] class Prefix private (val items: Array[Int], val length: Int) extends Serializable {
+
+ /** A unique id for this prefix. */
+ val id: Int = Prefix.nextId
+
+ /** Expands this prefix by the input item. */
+ def :+(item: Int): Prefix = {
+ require(item != 0)
+ if (item < 0) {
+ new Prefix(items :+ -item, length + 1)
+ } else {
+ new Prefix(items ++ Array(0, item), length + 1)
+ }
+ }
}
- /**
- * Extends all prefixes by one itemset from their suffix and computes the resulting frequent
- * prefixes and remaining work.
- * @param minCount minimum count
- * @param prefixSuffixPairs prefix (length N) and suffix pairs,
- * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended
- * prefix, corresponding suffix) pairs.
- */
- private def extendPrefixes(
- minCount: Long,
- prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])])
- : (RDD[(List[Set[Int]], Long)], RDD[(List[Set[Int]], List[Set[Int]])]) = {
-
- // (length N prefix, itemset from suffix) pairs and their corresponding number of occurrences
- // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport`
- val prefixItemPairAndCounts = prefixSuffixPairs
- .flatMap { case (prefix, suffix) =>
- suffix.flatMap(nonemptySubsets(_)).distinct.map(y => ((prefix, y), 1L)) }
- .reduceByKey(_ + _)
- .filter { case (item, count) => (count >= minCount) }
-
- // Map from prefix to set of possible next items from suffix
- val prefixToNextItems = prefixItemPairAndCounts
- .keys
- .groupByKey()
- .mapValues(_.toSet)
- .collect()
- .toMap
-
- // Frequent patterns with length N+1 and their corresponding counts
- val extendedPrefixAndCounts = prefixItemPairAndCounts
- .map { case ((prefix, item), count) => (item :: prefix, count) }
-
- // Remaining work, all prefixes will have length N+1
- val extendedPrefixAndSuffix = prefixSuffixPairs
- .filter(x => prefixToNextItems.contains(x._1))
- .flatMap { case (prefix, suffix) =>
- val frequentNextItemSets = prefixToNextItems(prefix)
- val frequentNextItems = frequentNextItemSets.flatten
- val filteredSuffix = suffix
- .map(item => frequentNextItems.intersect(item))
- .filter(_.nonEmpty)
- frequentNextItemSets.flatMap { item =>
- LocalPrefixSpan.getSuffix(item, filteredSuffix) match {
- case suffix if !suffix.isEmpty => Some(item :: prefix, suffix)
- case _ => None
- }
- }
- }
+ private[fpm] object Prefix {
+ /** Internal counter to generate unique IDs. */
+ private val counter: AtomicInteger = new AtomicInteger(-1)
- (extendedPrefixAndCounts, extendedPrefixAndSuffix)
+ /** Gets the next unique ID. */
+ private def nextId: Int = counter.incrementAndGet()
+
+ /** An empty [[Prefix]] instance. */
+ val empty: Prefix = new Prefix(Array.empty, 0)
}
/**
- * Calculate the patterns in local.
- * @param minCount the absolute minimum count
- * @param data prefixes and projected sequences data data
- * @return patterns
+ * An internal representation of a postfix from some projection.
+ * We use one int array to store the items, which might also contain other items from the
+ * original sequence.
+ * Items are represented by positive integers, and items in each itemset must be distinct and
+ * ordered.
+ * We use 0 as the delimiter between itemsets.
+ * For example, a sequence `<(12)(31)1>` is represented by `[0, 1, 2, 0, 1, 3, 0, 1, 0]`.
+ * The postfix of this sequence w.r.t. the prefix `<1>` is `<(_2)(13)1>`.
+ * We may reuse the original items array `[0, 1, 2, 0, 1, 3, 0, 1, 0]` to represent the postfix,
+ * and mark the start index of the postfix, which is `2` in this example.
+ * So the active items in this postfix are `[2, 0, 1, 3, 0, 1, 0]`.
+ * We also remember the start indices of partial projections, the ones that split an itemset.
+ * For example, another possible partial projection w.r.t. `<1>` is `<(_3)1>`.
+ * We remember the start indices of partial projections, which are `[2, 5]` in this example.
+ * This data structure makes it easier to do projections.
+ *
+ * @param items a sequence stored as `Array[Int]` containing this postfix
+ * @param start the start index of this postfix in items
+ * @param partialStarts start indices of possible partial projections, strictly increasing
*/
- private def getPatternsInLocal(
- minCount: Long,
- data: RDD[(List[Set[Int]], Iterable[List[Set[Int]]])]): RDD[(List[Set[Int]], Long)] = {
- data.flatMap {
- case (prefix, projDB) => LocalPrefixSpan.run(minCount, maxPatternLength, prefix, projDB)
+ private[fpm] class Postfix(
+ val items: Array[Int],
+ val start: Int = 0,
+ val partialStarts: Array[Int] = Array.empty) extends Serializable {
+
+ require(items.last == 0, s"The last item in a postfix must be zero, but got ${items.last}.")
+ if (partialStarts.nonEmpty) {
+ require(partialStarts.head >= start,
+ "The first partial start cannot be smaller than the start index, " +
+ s"but got partialStarts.head = ${partialStarts.head} < start = $start.")
}
- }
-}
+ /**
+ * Start index of the first full itemset contained in this postfix.
+ */
+ private[this] def fullStart: Int = {
+ var i = start
+ while (items(i) != 0) {
+ i += 1
+ }
+ i
+ }
+
+ /**
+ * Generates length-1 prefix items of this postfix with the corresponding postfix sizes.
+ * There are two types of prefix items:
+ * a) The item can be assembled to the last itemset of the prefix. For example,
+ * the postfix of `<(12)(123)1>` w.r.t. `<1>` is `<(_2)(123)1>`. The prefix items of this
+ * postfix that can be assembled to `<1>` are `_2` and `_3`, resulting in new prefixes `<(12)>`
+ * and `<(13)>`. We flip the sign in the output to indicate that this is a partial prefix item.
+ * b) The item can be appended to the prefix. Taking the same example above, the prefix items
+ * that can be appended to `<1>` are `1`, `2`, and `3`, resulting in new prefixes `<11>`, `<12>`,
+ * and `<13>`.
+ * @return an iterator of (prefix item, corresponding postfix size). If the item is negative, it
+ * indicates a partial prefix item, which should be assembled to the last itemset of the
+ * current prefix. Otherwise, the item should be appended to the current prefix.
+ */
+ def genPrefixItems: Iterator[(Int, Long)] = {
+ val n1 = items.length - 1
+ // For each distinct item (taking sign into account) in this sequence, we output exactly one
+ // candidate prefix item.
+ // TODO: use PrimitiveKeyOpenHashMap
+ val prefixes = mutable.Map.empty[Int, Long]
+ // a) items that can be assembled to the last itemset of the prefix
+ partialStarts.foreach { start =>
+ var i = start
+ var x = -items(i)
+ while (x != 0) {
+ if (!prefixes.contains(x)) {
+ prefixes(x) = n1 - i
+ }
+ i += 1
+ x = -items(i)
+ }
+ }
+ // b) items that can be appended to the prefix
+ var i = fullStart
+ while (i < n1) {
+ val x = items(i)
+ if (x != 0 && !prefixes.contains(x)) {
+ prefixes(x) = n1 - i
+ }
+ i += 1
+ }
+ prefixes.toIterator
+ }
-object PrefixSpan {
- private[fpm] val DELIMITER = -1
+ /** Tests whether this postfix is non-empty. */
+ def nonEmpty: Boolean = items.length > start + 1
- /** Splits an array of itemsets delimited by [[DELIMITER]]. */
- private[fpm] def splitSequence(sequence: List[Int]): List[Set[Int]] = {
- sequence.span(_ != DELIMITER) match {
- case (x, xs) if xs.length > 1 => x.toSet :: splitSequence(xs.tail)
- case (x, xs) => List(x.toSet)
+ /**
+ * Projects this postfix with respect to the input prefix item.
+ * @param prefix prefix item. If prefix is positive, we match items in any full itemset; if it
+ * is negative, we do partial projections.
+ * @return the projected postfix
+ */
+ def project(prefix: Int): Postfix = {
+ require(prefix != 0)
+ val n1 = items.length - 1
+ var matched = false
+ var newStart = n1
+ val newPartialStarts = mutable.ArrayBuilder.make[Int]
+ if (prefix < 0) {
+ // Search for partial projections.
+ val target = -prefix
+ partialStarts.foreach { start =>
+ var i = start
+ var x = items(i)
+ while (x != target && x != 0) {
+ i += 1
+ x = items(i)
+ }
+ if (x == target) {
+ i += 1
+ if (!matched) {
+ newStart = i
+ matched = true
+ }
+ if (items(i) != 0) {
+ newPartialStarts += i
+ }
+ }
+ }
+ } else {
+ // Search for items in full itemsets.
+ // Though the items are ordered within each itemset, itemsets should be small in practice.
+ // So a sequential scan is sufficient here, compared to binary search.
+ val target = prefix
+ var i = fullStart
+ while (i < n1) {
+ val x = items(i)
+ if (x == target) {
+ if (!matched) {
+ newStart = i
+ matched = true
+ }
+ if (items(i + 1) != 0) {
+ newPartialStarts += i + 1
+ }
+ }
+ i += 1
+ }
+ }
+ new Postfix(items, newStart, newPartialStarts.result())
}
- }
- /** Flattens a sequence of itemsets into an Array, inserting[[DELIMITER]] between itemsets. */
- private[fpm] def flattenSequence(sequence: List[Set[Int]]): List[Int] = {
- val builder = ArrayBuilder.make[Int]()
- for (itemSet <- sequence) {
- builder += DELIMITER
- builder ++= itemSet.toSeq.sorted
+ /**
+ * Projects this postfix with respect to the input prefix.
+ */
+ private def project(prefix: Array[Int]): Postfix = {
+ var partial = true
+ var cur = this
+ var i = 0
+ val np = prefix.length
+ while (i < np && cur.nonEmpty) {
+ val x = prefix(i)
+ if (x == 0) {
+ partial = false
+ } else {
+ if (partial) {
+ cur = cur.project(-x)
+ } else {
+ cur = cur.project(x)
+ partial = true
+ }
+ }
+ i += 1
+ }
+ cur
}
- builder.result().toList.drop(1) // drop trailing delimiter
- }
- /** Returns an iterator over all non-empty subsets of `itemSet` */
- private[fpm] def nonemptySubsets(itemSet: Set[Int]): Iterator[Set[Int]] = {
- // TODO: improve complexity by using partial prefixes, considering one item at a time
- itemSet.subsets.filter(_ != Set.empty[Int])
+ /**
+ * Projects this postfix with respect to the input prefix.
+ */
+ def project(prefix: Prefix): Postfix = project(prefix.items)
+
+ /**
+ * Returns the same sequence with compressed storage if possible.
+ */
+ def compressed: Postfix = {
+ if (start > 0) {
+ new Postfix(items.slice(start, items.length), 0, partialStarts.map(_ - start))
+ } else {
+ this
+ }
+ }
}
/**
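To ground the `Postfix` semantics documented above, a small conceptual walkthrough using the class doc's own example (hedged: `Postfix` is `private[fpm]`, so this only runs inside the package):

// The sequence <(12)(31)1> in internal format; a fresh Postfix spans it all.
val postfix = new Postfix(Array(0, 1, 2, 0, 1, 3, 0, 1, 0))
// Length-1 prefix items with projected postfix sizes: item 1 first occurs at
// index 1 (size 8 - 1 = 7), item 2 at index 2 (size 6), item 3 at index 5 (size 3).
postfix.genPrefixItems.toMap        // Map(1 -> 7, 2 -> 6, 3 -> 3)
// Projecting w.r.t. item 1 records the partial itemset starts [2, 5] from the
// class doc, i.e., the `(_2)` and `(_3)` remainders.
postfix.project(1).partialStarts    // Array(2, 5)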
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index 0ae48d62cc..a83e543859 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -21,7 +21,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
- test("PrefixSpan internal (integer seq, -1 delim) run, singleton itemsets") {
+ test("PrefixSpan internal (integer seq, 0 delim) run, singleton itemsets") {
/*
library("arulesSequences")
@@ -35,83 +35,81 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
val sequences = Array(
- Array(1, -1, 3, -1, 4, -1, 5),
- Array(2, -1, 3, -1, 1),
- Array(2, -1, 4, -1, 1),
- Array(3, -1, 1, -1, 3, -1, 4, -1, 5),
- Array(3, -1, 4, -1, 4, -1, 3),
- Array(6, -1, 5, -1, 3))
+ Array(0, 1, 0, 3, 0, 4, 0, 5, 0),
+ Array(0, 2, 0, 3, 0, 1, 0),
+ Array(0, 2, 0, 4, 0, 1, 0),
+ Array(0, 3, 0, 1, 0, 3, 0, 4, 0, 5, 0),
+ Array(0, 3, 0, 4, 0, 4, 0, 3, 0),
+ Array(0, 6, 0, 5, 0, 3, 0))
val rdd = sc.parallelize(sequences, 2).cache()
- val prefixspan = new PrefixSpan()
- .setMinSupport(0.33)
- .setMaxPatternLength(50)
- val result1 = prefixspan.run(rdd)
+ val result1 = PrefixSpan.genFreqPatterns(
+ rdd, minCount = 2L, maxPatternLength = 50, maxLocalProjDBSize = 16L)
val expectedValue1 = Array(
- (Array(1), 4L),
- (Array(1, -1, 3), 2L),
- (Array(1, -1, 3, -1, 4), 2L),
- (Array(1, -1, 3, -1, 4, -1, 5), 2L),
- (Array(1, -1, 3, -1, 5), 2L),
- (Array(1, -1, 4), 2L),
- (Array(1, -1, 4, -1, 5), 2L),
- (Array(1, -1, 5), 2L),
- (Array(2), 2L),
- (Array(2, -1, 1), 2L),
- (Array(3), 5L),
- (Array(3, -1, 1), 2L),
- (Array(3, -1, 3), 2L),
- (Array(3, -1, 4), 3L),
- (Array(3, -1, 4, -1, 5), 2L),
- (Array(3, -1, 5), 2L),
- (Array(4), 4L),
- (Array(4, -1, 5), 2L),
- (Array(5), 3L)
+ (Array(0, 1, 0), 4L),
+ (Array(0, 1, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 4, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 4, 0, 5, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 5, 0), 2L),
+ (Array(0, 1, 0, 4, 0), 2L),
+ (Array(0, 1, 0, 4, 0, 5, 0), 2L),
+ (Array(0, 1, 0, 5, 0), 2L),
+ (Array(0, 2, 0), 2L),
+ (Array(0, 2, 0, 1, 0), 2L),
+ (Array(0, 3, 0), 5L),
+ (Array(0, 3, 0, 1, 0), 2L),
+ (Array(0, 3, 0, 3, 0), 2L),
+ (Array(0, 3, 0, 4, 0), 3L),
+ (Array(0, 3, 0, 4, 0, 5, 0), 2L),
+ (Array(0, 3, 0, 5, 0), 2L),
+ (Array(0, 4, 0), 4L),
+ (Array(0, 4, 0, 5, 0), 2L),
+ (Array(0, 5, 0), 3L)
)
compareInternalResults(expectedValue1, result1.collect())
- prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
- val result2 = prefixspan.run(rdd)
+ val result2 = PrefixSpan.genFreqPatterns(
+ rdd, minCount = 3, maxPatternLength = 50, maxLocalProjDBSize = 32L)
val expectedValue2 = Array(
- (Array(1), 4L),
- (Array(3), 5L),
- (Array(3, -1, 4), 3L),
- (Array(4), 4L),
- (Array(5), 3L)
+ (Array(0, 1, 0), 4L),
+ (Array(0, 3, 0), 5L),
+ (Array(0, 3, 0, 4, 0), 3L),
+ (Array(0, 4, 0), 4L),
+ (Array(0, 5, 0), 3L)
)
compareInternalResults(expectedValue2, result2.collect())
- prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
- val result3 = prefixspan.run(rdd)
+ val result3 = PrefixSpan.genFreqPatterns(
+ rdd, minCount = 2, maxPatternLength = 2, maxLocalProjDBSize = 32L)
val expectedValue3 = Array(
- (Array(1), 4L),
- (Array(1, -1, 3), 2L),
- (Array(1, -1, 4), 2L),
- (Array(1, -1, 5), 2L),
- (Array(2, -1, 1), 2L),
- (Array(2), 2L),
- (Array(3), 5L),
- (Array(3, -1, 1), 2L),
- (Array(3, -1, 3), 2L),
- (Array(3, -1, 4), 3L),
- (Array(3, -1, 5), 2L),
- (Array(4), 4L),
- (Array(4, -1, 5), 2L),
- (Array(5), 3L)
+ (Array(0, 1, 0), 4L),
+ (Array(0, 1, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 4, 0), 2L),
+ (Array(0, 1, 0, 5, 0), 2L),
+ (Array(0, 2, 0, 1, 0), 2L),
+ (Array(0, 2, 0), 2L),
+ (Array(0, 3, 0), 5L),
+ (Array(0, 3, 0, 1, 0), 2L),
+ (Array(0, 3, 0, 3, 0), 2L),
+ (Array(0, 3, 0, 4, 0), 3L),
+ (Array(0, 3, 0, 5, 0), 2L),
+ (Array(0, 4, 0), 4L),
+ (Array(0, 4, 0, 5, 0), 2L),
+ (Array(0, 5, 0), 3L)
)
compareInternalResults(expectedValue3, result3.collect())
}
test("PrefixSpan internal (integer seq, -1 delim) run, variable-size itemsets") {
val sequences = Array(
- Array(1, -1, 1, 2, 3, -1, 1, 3, -1, 4, -1, 3, 6),
- Array(1, 4, -1, 3, -1, 2, 3, -1, 1, 5),
- Array(5, 6, -1, 1, 2, -1, 4, 6, -1, 3, -1, 2),
- Array(5, -1, 7, -1, 1, 6, -1, 3, -1, 2, -1, 3))
+ Array(0, 1, 0, 1, 2, 3, 0, 1, 3, 0, 4, 0, 3, 6, 0),
+ Array(0, 1, 4, 0, 3, 0, 2, 3, 0, 1, 5, 0),
+ Array(0, 5, 6, 0, 1, 2, 0, 4, 6, 0, 3, 0, 2, 0),
+ Array(0, 5, 0, 7, 0, 1, 6, 0, 3, 0, 2, 0, 3, 0))
val rdd = sc.parallelize(sequences, 2).cache()
- val prefixspan = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5)
- val result = prefixspan.run(rdd)
+ val result = PrefixSpan.genFreqPatterns(
+ rdd, minCount = 2, maxPatternLength = 5, maxLocalProjDBSize = 128L)
/*
To verify results, create file "prefixSpanSeqs" with content
@@ -200,63 +198,87 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
53 <{1},{2},{1}> 0.50
*/
val expectedValue = Array(
- (Array(1), 4L),
- (Array(2), 4L),
- (Array(3), 4L),
- (Array(4), 3L),
- (Array(5), 3L),
- (Array(6), 3L),
- (Array(1, -1, 6), 2L),
- (Array(2, -1, 6), 2L),
- (Array(5, -1, 6), 2L),
- (Array(1, 2, -1, 6), 2L),
- (Array(1, -1, 4), 2L),
- (Array(2, -1, 4), 2L),
- (Array(1, 2, -1, 4), 2L),
- (Array(1, -1, 3), 4L),
- (Array(2, -1, 3), 3L),
- (Array(2, 3), 2L),
- (Array(3, -1, 3), 3L),
- (Array(4, -1, 3), 3L),
- (Array(5, -1, 3), 2L),
- (Array(6, -1, 3), 2L),
- (Array(5, -1, 6, -1, 3), 2L),
- (Array(6, -1, 2, -1, 3), 2L),
- (Array(5, -1, 2, -1, 3), 2L),
- (Array(5, -1, 1, -1, 3), 2L),
- (Array(2, -1, 4, -1, 3), 2L),
- (Array(1, -1, 4, -1, 3), 2L),
- (Array(1, 2, -1, 4, -1, 3), 2L),
- (Array(1, -1, 3, -1, 3), 3L),
- (Array(1, 2, -1, 3), 2L),
- (Array(1, -1, 2, -1, 3), 2L),
- (Array(1, -1, 2, 3), 2L),
- (Array(1, -1, 2), 4L),
- (Array(1, 2), 2L),
- (Array(3, -1, 2), 3L),
- (Array(4, -1, 2), 2L),
- (Array(5, -1, 2), 2L),
- (Array(6, -1, 2), 2L),
- (Array(5, -1, 6, -1, 2), 2L),
- (Array(6, -1, 3, -1, 2), 2L),
- (Array(5, -1, 3, -1, 2), 2L),
- (Array(5, -1, 1, -1, 2), 2L),
- (Array(4, -1, 3, -1, 2), 2L),
- (Array(1, -1, 3, -1, 2), 3L),
- (Array(5, -1, 6, -1, 3, -1, 2), 2L),
- (Array(5, -1, 1, -1, 3, -1, 2), 2L),
- (Array(1, -1, 1), 2L),
- (Array(2, -1, 1), 2L),
- (Array(3, -1, 1), 2L),
- (Array(5, -1, 1), 2L),
- (Array(2, 3, -1, 1), 2L),
- (Array(1, -1, 3, -1, 1), 2L),
- (Array(1, -1, 2, 3, -1, 1), 2L),
- (Array(1, -1, 2, -1, 1), 2L))
+ (Array(0, 1, 0), 4L),
+ (Array(0, 2, 0), 4L),
+ (Array(0, 3, 0), 4L),
+ (Array(0, 4, 0), 3L),
+ (Array(0, 5, 0), 3L),
+ (Array(0, 6, 0), 3L),
+ (Array(0, 1, 0, 6, 0), 2L),
+ (Array(0, 2, 0, 6, 0), 2L),
+ (Array(0, 5, 0, 6, 0), 2L),
+ (Array(0, 1, 2, 0, 6, 0), 2L),
+ (Array(0, 1, 0, 4, 0), 2L),
+ (Array(0, 2, 0, 4, 0), 2L),
+ (Array(0, 1, 2, 0, 4, 0), 2L),
+ (Array(0, 1, 0, 3, 0), 4L),
+ (Array(0, 2, 0, 3, 0), 3L),
+ (Array(0, 2, 3, 0), 2L),
+ (Array(0, 3, 0, 3, 0), 3L),
+ (Array(0, 4, 0, 3, 0), 3L),
+ (Array(0, 5, 0, 3, 0), 2L),
+ (Array(0, 6, 0, 3, 0), 2L),
+ (Array(0, 5, 0, 6, 0, 3, 0), 2L),
+ (Array(0, 6, 0, 2, 0, 3, 0), 2L),
+ (Array(0, 5, 0, 2, 0, 3, 0), 2L),
+ (Array(0, 5, 0, 1, 0, 3, 0), 2L),
+ (Array(0, 2, 0, 4, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 4, 0, 3, 0), 2L),
+ (Array(0, 1, 2, 0, 4, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 3, 0), 3L),
+ (Array(0, 1, 2, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 2, 0, 3, 0), 2L),
+ (Array(0, 1, 0, 2, 3, 0), 2L),
+ (Array(0, 1, 0, 2, 0), 4L),
+ (Array(0, 1, 2, 0), 2L),
+ (Array(0, 3, 0, 2, 0), 3L),
+ (Array(0, 4, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 2, 0), 2L),
+ (Array(0, 6, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 6, 0, 2, 0), 2L),
+ (Array(0, 6, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 1, 0, 2, 0), 2L),
+ (Array(0, 4, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 2, 0), 3L),
+ (Array(0, 5, 0, 6, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 5, 0, 1, 0, 3, 0, 2, 0), 2L),
+ (Array(0, 1, 0, 1, 0), 2L),
+ (Array(0, 2, 0, 1, 0), 2L),
+ (Array(0, 3, 0, 1, 0), 2L),
+ (Array(0, 5, 0, 1, 0), 2L),
+ (Array(0, 2, 3, 0, 1, 0), 2L),
+ (Array(0, 1, 0, 3, 0, 1, 0), 2L),
+ (Array(0, 1, 0, 2, 3, 0, 1, 0), 2L),
+ (Array(0, 1, 0, 2, 0, 1, 0), 2L))
compareInternalResults(expectedValue, result.collect())
}
+ test("PrefixSpan projections with multiple partial starts") {
+ val sequences = Seq(
+ Array(Array(1, 2), Array(1, 2, 3)))
+ val rdd = sc.parallelize(sequences, 2)
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(1.0)
+ .setMaxPatternLength(2)
+ val model = prefixSpan.run(rdd)
+ val expected = Array(
+ (Array(Array(1)), 1L),
+ (Array(Array(1, 2)), 1L),
+ (Array(Array(1), Array(1)), 1L),
+ (Array(Array(1), Array(2)), 1L),
+ (Array(Array(1), Array(3)), 1L),
+ (Array(Array(1, 3)), 1L),
+ (Array(Array(2)), 1L),
+ (Array(Array(2, 3)), 1L),
+ (Array(Array(2), Array(1)), 1L),
+ (Array(Array(2), Array(2)), 1L),
+ (Array(Array(2), Array(3)), 1L),
+ (Array(Array(3)), 1L))
+ compareResults(expected, model.freqSequences.collect())
+ }
+
test("PrefixSpan Integer type, variable-size itemsets") {
val sequences = Seq(
Array(Array(1, 2), Array(3)),
@@ -265,7 +287,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
Array(Array(6)))
val rdd = sc.parallelize(sequences, 2).cache()
- val prefixspan = new PrefixSpan()
+ val prefixSpan = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
@@ -296,7 +318,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
5 <{1,2}> 0.75
*/
- val model = prefixspan.run(rdd)
+ val model = prefixSpan.run(rdd)
val expected = Array(
(Array(Array(1)), 3L),
(Array(Array(2)), 3L),
@@ -304,7 +326,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
(Array(Array(1), Array(3)), 2L),
(Array(Array(1, 2)), 3L)
)
- compareResults(expected, model.freqSequences.collect().map(x => (x.sequence, x.freq)))
+ compareResults(expected, model.freqSequences.collect())
}
test("PrefixSpan String type, variable-size itemsets") {
@@ -318,11 +340,11 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
Array(Array(6))).map(seq => seq.map(itemSet => itemSet.map(intToString)))
val rdd = sc.parallelize(sequences, 2).cache()
- val prefixspan = new PrefixSpan()
+ val prefixSpan = new PrefixSpan()
.setMinSupport(0.5)
.setMaxPatternLength(5)
- val model = prefixspan.run(rdd)
+ val model = prefixSpan.run(rdd)
val expected = Array(
(Array(Array(1)), 3L),
(Array(Array(2)), 3L),
@@ -332,17 +354,17 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
).map { case (pattern, count) =>
(pattern.map(itemSet => itemSet.map(intToString)), count)
}
- compareResults(expected, model.freqSequences.collect().map(x => (x.sequence, x.freq)))
+ compareResults(expected, model.freqSequences.collect())
}
private def compareResults[Item](
expectedValue: Array[(Array[Array[Item]], Long)],
- actualValue: Array[(Array[Array[Item]], Long)]): Unit = {
+ actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {
val expectedSet = expectedValue.map { case (pattern: Array[Array[Item]], count: Long) =>
(pattern.map(itemSet => itemSet.toSet).toSeq, count)
}.toSet
- val actualSet = actualValue.map { case (pattern: Array[Array[Item]], count: Long) =>
- (pattern.map(itemSet => itemSet.toSet).toSeq, count)
+ val actualSet = actualValue.map { x =>
+ (x.sequence.map(_.toSet).toSeq, x.freq)
}.toSet
assert(expectedSet === actualSet)
}
@@ -354,11 +376,4 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet
assert(expectedSet === actualSet)
}
-
- private def insertDelimiter(sequence: Array[Int]): Array[Int] = {
- sequence.zip(Seq.fill(sequence.length)(PrefixSpan.DELIMITER)).map { case (a, b) =>
- List(a, b)
- }.flatten
- }
-
}
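For reference, a hedged sketch of the public API these tests exercise, assuming an active SparkContext `sc` (the data here is illustrative, mirroring the Integer-type test above):

import org.apache.spark.mllib.fpm.PrefixSpan

val sequences = sc.parallelize(Seq(
  Array(Array(1, 2), Array(3)),
  Array(Array(1), Array(3, 2), Array(1, 2))), 2).cache()
val model = new PrefixSpan()
  .setMinSupport(0.5)
  .setMaxPatternLength(5)
  .run(sequences)
// Print each frequent sequence as <{...}{...}>: count.
model.freqSequences.collect().foreach { fs =>
  println(fs.sequence.map(_.mkString("{", ",", "}")).mkString("<", "", ">") + ": " + fs.freq)
}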