From a02bcf20c4fc9e2e182630d197221729e996afc2 Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Tue, 4 Aug 2015 22:28:49 -0700 Subject: [SPARK-9540] [MLLIB] optimize PrefixSpan implementation This is a major refactoring of the PrefixSpan implementation. It contains the following changes: 1. Expand the prefix with one item at a time. The existing implementation generates all subsets for each itemset, which might have scalability issues when the itemset is large. 2. Use a new internal format. `<(12)(31)>` is represented by `[0, 1, 2, 0, 1, 3, 0]` internally. We use `0` because negative numbers are used to indicate partial prefix items, e.g., `_2` is represented by `-2`. 3. Remember the start indices of all partial projections in the projected postfix to help the next projection. 4. Reuse the original sequence array for projected postfixes. 5. Use `Prefix` IDs in aggregation rather than their contents. 6. Use `ArrayBuilder` for building primitive arrays. 7. Expose `maxLocalProjDBSize`. 8. Tests are not changed except using `0` instead of `-1` as the delimiter. `Postfix`'s API doc should be a good place to start. Closes #7594 feynmanliang zhangjiajin Author: Xiangrui Meng Closes #7937 from mengxr/SPARK-9540 and squashes the following commits: 2d0ec31 [Xiangrui Meng] address more comments 48f450c [Xiangrui Meng] address comments from Feynman; fixed a bug in project and added a test 65f90e8 [Xiangrui Meng] naming and documentation 8afc86a [Xiangrui Meng] refactor impl --- .../apache/spark/mllib/fpm/LocalPrefixSpan.scala | 132 ++--- .../org/apache/spark/mllib/fpm/PrefixSpan.scala | 587 ++++++++++++++------- .../apache/spark/mllib/fpm/PrefixSpanSuite.scala | 271 +++++----- 3 files changed, 599 insertions(+), 391 deletions(-) (limited to 'mllib/src') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala index ccebf951c8..3ea10779a1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -22,85 +22,89 @@ import scala.collection.mutable import org.apache.spark.Logging /** - * Calculate all patterns of a projected database in local. + * Calculate all patterns of a projected database in local mode. + * + * @param minCount minimal count for a frequent pattern + * @param maxPatternLength max pattern length for a frequent pattern */ -private[fpm] object LocalPrefixSpan extends Logging with Serializable { - import PrefixSpan._ +private[fpm] class LocalPrefixSpan( + val minCount: Long, + val maxPatternLength: Int) extends Logging with Serializable { + import PrefixSpan.Postfix + import LocalPrefixSpan.ReversedPrefix + /** - * Calculate all patterns of a projected database. - * @param minCount minimum count - * @param maxPatternLength maximum pattern length - * @param prefixes prefixes in reversed order - * @param database the projected database - * @return a set of sequential pattern pairs, - * the key of pair is sequential pattern (a list of items in reversed order), - * the value of pair is the pattern's count. + * Generates frequent patterns on the input array of postfixes.
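+ * For example, with the `0`-delimited internal format described in [[PrefixSpan.Postfix]], the pattern `<(12)3>` is returned as `[0, 1, 2, 0, 3, 0]`.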
+ * @param postfixes an array of postfixes + * @return an iterator of (frequent pattern, count) */ - def run( - minCount: Long, - maxPatternLength: Int, - prefixes: List[Set[Int]], - database: Iterable[List[Set[Int]]]): Iterator[(List[Set[Int]], Long)] = { - if (prefixes.length == maxPatternLength || database.isEmpty) { - return Iterator.empty - } - val freqItemSetsAndCounts = getFreqItemAndCounts(minCount, database) - val freqItems = freqItemSetsAndCounts.keys.flatten.toSet - val filteredDatabase = database.map { suffix => - suffix - .map(item => freqItems.intersect(item)) - .filter(_.nonEmpty) - } - freqItemSetsAndCounts.iterator.flatMap { case (item, count) => - val newPrefixes = item :: prefixes - val newProjected = project(filteredDatabase, item) - Iterator.single((newPrefixes, count)) ++ - run(minCount, maxPatternLength, newPrefixes, newProjected) + def run(postfixes: Array[Postfix]): Iterator[(Array[Int], Long)] = { + genFreqPatterns(ReversedPrefix.empty, postfixes).map { case (prefix, count) => + (prefix.toSequence, count) } } /** - * Calculate suffix sequence immediately after the first occurrence of an item. - * @param item itemset to get suffix after - * @param sequence sequence to extract suffix from - * @return suffix sequence + * Recursively generates frequent patterns. + * @param prefix current prefix + * @param postfixes projected postfixes w.r.t. the prefix + * @return an iterator of (prefix, count) */ - def getSuffix(item: Set[Int], sequence: List[Set[Int]]): List[Set[Int]] = { - val itemsetSeq = sequence - val index = itemsetSeq.indexWhere(item.subsetOf(_)) - if (index == -1) { - List() - } else { - itemsetSeq.drop(index + 1) + private def genFreqPatterns( + prefix: ReversedPrefix, + postfixes: Array[Postfix]): Iterator[(ReversedPrefix, Long)] = { + if (maxPatternLength == prefix.length || postfixes.length < minCount) { + return Iterator.empty + } + // find frequent items + val counts = mutable.Map.empty[Int, Long].withDefaultValue(0) + postfixes.foreach { postfix => + postfix.genPrefixItems.foreach { case (x, _) => + counts(x) += 1L + } + } + val freqItems = counts.toSeq.filter { case (_, count) => + count >= minCount + }.sorted + // project and recursively call genFreqPatterns + freqItems.toIterator.flatMap { case (item, count) => + val newPrefix = prefix :+ item + Iterator.single((newPrefix, count)) ++ { + val projected = postfixes.map(_.project(item)).filter(_.nonEmpty) + genFreqPatterns(newPrefix, projected) + } } } +} - def project( - database: Iterable[List[Set[Int]]], - prefix: Set[Int]): Iterable[List[Set[Int]]] = { - database - .map(getSuffix(prefix, _)) - .filter(_.nonEmpty) - } +private object LocalPrefixSpan { /** - * Generates frequent items by filtering the input data using minimal count level. - * @param minCount the minimum count for an item to be frequent - * @param database database of sequences - * @return freq item to count map + * Represents a prefix stored as a list in reversed order. 
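+ * For example, appending `1`, `-2`, and `3` in turn yields the prefix `<(12)3>`, stored as `List(3, 0, 2, 1, 0)`; `toSequence` reverses the stored items and appends the closing delimiter, returning `[0, 1, 2, 0, 3, 0]`.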
+ * @param items items in the prefix in reversed order + * @param length length of the prefix, not counting delimiters */ - private def getFreqItemAndCounts( - minCount: Long, - database: Iterable[List[Set[Int]]]): Map[Set[Int], Long] = { - // TODO: use PrimitiveKeyOpenHashMap - val counts = mutable.Map[Set[Int], Long]().withDefaultValue(0L) - database.foreach { sequence => - sequence.flatMap(nonemptySubsets(_)).distinct.foreach { item => - counts(item) += 1L + class ReversedPrefix private (val items: List[Int], val length: Int) extends Serializable { + /** + * Expands the prefix by one item. + */ + def :+(item: Int): ReversedPrefix = { + require(item != 0) + if (item < 0) { + new ReversedPrefix(-item :: items, length + 1) + } else { + new ReversedPrefix(item :: 0 :: items, length + 1) } } - counts - .filter { case (_, count) => count >= minCount } - .toMap + + /** + * Converts this prefix to a sequence. + */ + def toSequence: Array[Int] = (0 :: items).toArray.reverse + } + + object ReversedPrefix { + /** An empty prefix. */ + val empty: ReversedPrefix = new ReversedPrefix(List.empty, 0) } } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 9eaf733fad..d5f0c926c6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -18,9 +18,10 @@ package org.apache.spark.mllib.fpm import java.{lang => jl, util => ju} +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.mutable import scala.collection.JavaConverters._ -import scala.collection.mutable.ArrayBuilder import scala.reflect.ClassTag import org.apache.spark.Logging @@ -31,17 +32,20 @@ import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel /** - * * :: Experimental :: * - * A parallel PrefixSpan algorithm to mine sequential pattern. - * The PrefixSpan algorithm is described in - * [[http://doi.org/10.1109/ICDE.2001.914830]]. + * A parallel PrefixSpan algorithm to mine frequent sequential patterns. + * The PrefixSpan algorithm is described in J. Pei, et al., PrefixSpan: Mining Sequential Patterns + * Efficiently by Prefix-Projected Pattern Growth ([[http://doi.org/10.1109/ICDE.2001.914830]]). * * @param minSupport the minimal support level of the sequential pattern, any pattern appears * more than (minSupport * size-of-the-dataset) times will be output * @param maxPatternLength the maximal length of the sequential pattern, any pattern appears - * less than maxPatternLength will be output + * less than maxPatternLength will be output + * @param maxLocalProjDBSize The maximum number of items (including delimiters used in the internal + * storage format) allowed in a projected database before local + * processing. If a projected database exceeds this size, another + * iteration of distributed prefix growth is run. * * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining * (Wikipedia)]] @@ -49,33 +53,28 @@ import org.apache.spark.storage.StorageLevel @Experimental class PrefixSpan private ( private var minSupport: Double, - private var maxPatternLength: Int) extends Logging with Serializable { + private var maxPatternLength: Int, + private var maxLocalProjDBSize: Long) extends Logging with Serializable { import PrefixSpan._ - /** - * The maximum number of items allowed in a projected database before local processing. 
If a - * projected database exceeds this size, another iteration of distributed PrefixSpan is run. - */ - // TODO: make configurable with a better default value - private val maxLocalProjDBSize: Long = 32000000L - /** * Constructs a default instance with default parameters - * {minSupport: `0.1`, maxPatternLength: `10`}. + * {minSupport: `0.1`, maxPatternLength: `10`, maxLocalProjDBSize: `32000000L`}. */ - def this() = this(0.1, 10) + def this() = this(0.1, 10, 32000000L) /** * Get the minimal support (i.e. the frequency of occurrence before a pattern is considered * frequent). */ - def getMinSupport: Double = this.minSupport + def getMinSupport: Double = minSupport /** * Sets the minimal support level (default: `0.1`). */ def setMinSupport(minSupport: Double): this.type = { - require(minSupport >= 0 && minSupport <= 1, "The minimum support value must be in [0, 1].") + require(minSupport >= 0 && minSupport <= 1, + s"The minimum support value must be in [0, 1], but got $minSupport.") this.minSupport = minSupport this } @@ -83,45 +82,115 @@ class PrefixSpan private ( /** * Gets the maximal pattern length (i.e. the length of the longest sequential pattern to consider. */ - def getMaxPatternLength: Double = this.maxPatternLength + def getMaxPatternLength: Double = maxPatternLength /** * Sets maximal pattern length (default: `10`). */ def setMaxPatternLength(maxPatternLength: Int): this.type = { // TODO: support unbounded pattern length when maxPatternLength = 0 - require(maxPatternLength >= 1, "The maximum pattern length value must be greater than 0.") + require(maxPatternLength >= 1, + s"The maximum pattern length value must be greater than 0, but got $maxPatternLength.") this.maxPatternLength = maxPatternLength this } /** - * Find the complete set of sequential patterns in the input sequences of itemsets. - * @param data ordered sequences of itemsets. - * @return a [[PrefixSpanModel]] that contains the frequent sequences + * Gets the maximum number of items allowed in a projected database before local processing. + */ + def getMaxLocalProjDBSize: Long = maxLocalProjDBSize + + /** + * Sets the maximum number of items (including delimiters used in the internal storage format) + * allowed in a projected database before local processing (default: `32000000L`). + */ + def setMaxLocalProjDBSize(maxLocalProjDBSize: Long): this.type = { + require(maxLocalProjDBSize >= 0L, + s"The maximum local projected database size must be nonnegative, but got $maxLocalProjDBSize") + this.maxLocalProjDBSize = maxLocalProjDBSize + this + } + + /** + * Finds the complete set of frequent sequential patterns in the input sequences of itemsets. + * @param data sequences of itemsets. 
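+ * Each sequence is an `Array` of itemsets and each itemset an `Array` of items; e.g., `<(12)3>` is passed in as `Array(Array(1, 2), Array(3))`.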
+ * @return a [[PrefixSpanModel]] that contains the frequent patterns */ def run[Item: ClassTag](data: RDD[Array[Array[Item]]]): PrefixSpanModel[Item] = { - val itemToInt = data.aggregate(Set[Item]())( - seqOp = { (uniqItems, item) => uniqItems ++ item.flatten.toSet }, - combOp = { _ ++ _ } - ).zipWithIndex.toMap - val intToItem = Map() ++ (itemToInt.map { case (k, v) => (v, k) }) - - val dataInternalRepr = data.map { seq => - seq.map(itemset => itemset.map(itemToInt)).reduce((a, b) => a ++ (DELIMITER +: b)) + if (data.getStorageLevel == StorageLevel.NONE) { + logWarning("Input data is not cached.") } - val results = run(dataInternalRepr) - def toPublicRepr(pattern: Iterable[Int]): List[Array[Item]] = { - pattern.span(_ != DELIMITER) match { - case (x, xs) if xs.size > 1 => x.map(intToItem).toArray :: toPublicRepr(xs.tail) - case (x, xs) => List(x.map(intToItem).toArray) + val totalCount = data.count() + logInfo(s"number of sequences: $totalCount") + val minCount = math.ceil(minSupport * totalCount).toLong + logInfo(s"minimum count for a frequent pattern: $minCount") + + // Find frequent items. + val freqItemAndCounts = data.flatMap { itemsets => + val uniqItems = mutable.Set.empty[Item] + itemsets.foreach { _.foreach { item => + uniqItems += item + }} + uniqItems.toIterator.map((_, 1L)) + }.reduceByKey(_ + _) + .filter { case (_, count) => + count >= minCount + }.collect() + val freqItems = freqItemAndCounts.sortBy(-_._2).map(_._1) + logInfo(s"number of frequent items: ${freqItems.length}") + + // Keep only frequent items from input sequences and convert them to internal storage. + val itemToInt = freqItems.zipWithIndex.toMap + val dataInternalRepr = data.flatMap { itemsets => + val allItems = mutable.ArrayBuilder.make[Int] + var containsFreqItems = false + allItems += 0 + itemsets.foreach { itemsets => + val items = mutable.ArrayBuilder.make[Int] + itemsets.foreach { item => + if (itemToInt.contains(item)) { + items += itemToInt(item) + 1 // using 1-indexing in internal format + } + } + val result = items.result() + if (result.nonEmpty) { + containsFreqItems = true + allItems ++= result.sorted + } + allItems += 0 + } + if (containsFreqItems) { + Iterator.single(allItems.result()) + } else { + Iterator.empty + } + }.persist(StorageLevel.MEMORY_AND_DISK) + + val results = genFreqPatterns(dataInternalRepr, minCount, maxPatternLength, maxLocalProjDBSize) + + def toPublicRepr(pattern: Array[Int]): Array[Array[Item]] = { + val sequenceBuilder = mutable.ArrayBuilder.make[Array[Item]] + val itemsetBuilder = mutable.ArrayBuilder.make[Item] + val n = pattern.length + var i = 1 + while (i < n) { + val x = pattern(i) + if (x == 0) { + sequenceBuilder += itemsetBuilder.result() + itemsetBuilder.clear() + } else { + itemsetBuilder += freqItems(x - 1) // using 1-indexing in internal format + } + i += 1 } + sequenceBuilder.result() } + val freqSequences = results.map { case (seq: Array[Int], count: Long) => - new FreqSequence[Item](toPublicRepr(seq).toArray, count) + new FreqSequence(toPublicRepr(seq), count) } - new PrefixSpanModel[Item](freqSequences) + new PrefixSpanModel(freqSequences) } /** @@ -131,7 +200,7 @@ class PrefixSpan private ( * @tparam Item item type * @tparam Itemset itemset type, which is an Iterable of Items * @tparam Sequence sequence type, which is an Iterable of Itemsets - * @return a [[PrefixSpanModel]] that contains the frequent sequences + * @return a [[PrefixSpanModel]] that contains the frequent sequential patterns */ def run[Item, Itemset <: jl.Iterable[Item], Sequence <: 
jl.Iterable[Itemset]]( data: JavaRDD[Sequence]): PrefixSpanModel[Item] = { @@ -139,200 +208,320 @@ class PrefixSpan private ( run(data.rdd.map(_.asScala.map(_.asScala.toArray).toArray)) } +} + +@Experimental +object PrefixSpan extends Logging { + /** - * Find the complete set of sequential patterns in the input sequences. This method utilizes - * the internal representation of itemsets as Array[Int] where each itemset is represented by - * a contiguous sequence of non-negative integers and delimiters represented by [[DELIMITER]]. - * @param data ordered sequences of itemsets. Items are represented by non-negative integers. - * Each itemset has one or more items and is delimited by [[DELIMITER]]. - * @return a set of sequential pattern pairs, - * the key of pair is pattern (a list of elements), - * the value of pair is the pattern's count. + * Find the complete set of frequent sequential patterns in the input sequences. + * @param data ordered sequences of itemsets. We represent a sequence internally as Array[Int], + * where each itemset is represented by a contiguous sequence of distinct and ordered + * positive integers. We use 0 as the delimiter at itemset boundaries, including the + * first and the last position. + * @return an RDD of (frequent sequential pattern, count) pairs, + * @see [[Postfix]] */ - private[fpm] def run(data: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { + private[fpm] def genFreqPatterns( + data: RDD[Array[Int]], + minCount: Long, + maxPatternLength: Int, + maxLocalProjDBSize: Long): RDD[(Array[Int], Long)] = { val sc = data.sparkContext if (data.getStorageLevel == StorageLevel.NONE) { logWarning("Input data is not cached.") } - // Use List[Set[Item]] for internal computation - val sequences = data.map { seq => splitSequence(seq.toList) } - - // Convert min support to a min number of transactions for this dataset - val minCount = if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong - - // (Frequent items -> number of occurrences, all items here satisfy the `minSupport` threshold - val freqItemCounts = sequences - .flatMap(seq => seq.flatMap(nonemptySubsets(_)).distinct.map(item => (item, 1L))) - .reduceByKey(_ + _) - .filter { case (item, count) => (count >= minCount) } - .collect() - .toMap - - // Pairs of (length 1 prefix, suffix consisting of frequent items) - val itemSuffixPairs = { - val freqItemSets = freqItemCounts.keys.toSet - val freqItems = freqItemSets.flatten - sequences.flatMap { seq => - val filteredSeq = seq.map(item => freqItems.intersect(item)).filter(_.nonEmpty) - freqItemSets.flatMap { item => - val candidateSuffix = LocalPrefixSpan.getSuffix(item, filteredSeq) - candidateSuffix match { - case suffix if !suffix.isEmpty => Some((List(item), suffix)) - case _ => None + val postfixes = data.map(items => new Postfix(items)) + + // Local frequent patterns (prefixes) and their counts. + val localFreqPatterns = mutable.ArrayBuffer.empty[(Array[Int], Long)] + // Prefixes whose projected databases are small. + val smallPrefixes = mutable.Map.empty[Int, Prefix] + val emptyPrefix = Prefix.empty + // Prefixes whose projected databases are large. + var largePrefixes = mutable.Map(emptyPrefix.id -> emptyPrefix) + while (largePrefixes.nonEmpty) { + val numLocalFreqPatterns = localFreqPatterns.length + logInfo(s"number of local frequent patterns: $numLocalFreqPatterns") + if (numLocalFreqPatterns > 1000000) { + logWarning( + s""" + | Collected $numLocalFreqPatterns local frequent patterns. You may want to consider: + | 1. 
increase minSupport, + | 2. decrease maxPatternLength, + | 3. increase maxLocalProjDBSize. + """.stripMargin) + } + logInfo(s"number of small prefixes: ${smallPrefixes.size}") + logInfo(s"number of large prefixes: ${largePrefixes.size}") + val largePrefixArray = largePrefixes.values.toArray + val freqPrefixes = postfixes.flatMap { postfix => + largePrefixArray.flatMap { prefix => + postfix.project(prefix).genPrefixItems.map { case (item, postfixSize) => + ((prefix.id, item), (1L, postfixSize)) + } + } + }.reduceByKey { case ((c0, s0), (c1, s1)) => + (c0 + c1, s0 + s1) + }.filter { case (_, (c, _)) => c >= minCount } + .collect() + val newLargePrefixes = mutable.Map.empty[Int, Prefix] + freqPrefixes.foreach { case ((id, item), (count, projDBSize)) => + val newPrefix = largePrefixes(id) :+ item + localFreqPatterns += ((newPrefix.items :+ 0, count)) + if (newPrefix.length < maxPatternLength) { + if (projDBSize > maxLocalProjDBSize) { + newLargePrefixes += newPrefix.id -> newPrefix + } else { + smallPrefixes += newPrefix.id -> newPrefix } } } + largePrefixes = newLargePrefixes } - // Accumulator for the computed results to be returned, initialized to the frequent items (i.e. - // frequent length-one prefixes) - var resultsAccumulator = freqItemCounts.map { case (item, count) => (List(item), count) }.toList - - // Remaining work to be locally and distributively processed respectfully - var (pairsForLocal, pairsForDistributed) = partitionByProjDBSize(itemSuffixPairs) - - // Continue processing until no pairs for distributed processing remain (i.e. all prefixes have - // projected database sizes <= `maxLocalProjDBSize`) or `maxPatternLength` is reached - var patternLength = 1 - while (pairsForDistributed.count() != 0 && patternLength < maxPatternLength) { - val (nextPatternAndCounts, nextPrefixSuffixPairs) = - extendPrefixes(minCount, pairsForDistributed) - pairsForDistributed.unpersist() - val (smallerPairsPart, largerPairsPart) = partitionByProjDBSize(nextPrefixSuffixPairs) - pairsForDistributed = largerPairsPart - pairsForDistributed.persist(StorageLevel.MEMORY_AND_DISK) - pairsForLocal ++= smallerPairsPart - resultsAccumulator ++= nextPatternAndCounts.collect() - patternLength += 1 // pattern length grows one per iteration + // Switch to local processing. + val bcSmallPrefixes = sc.broadcast(smallPrefixes) + val distributedFreqPattern = postfixes.flatMap { postfix => + bcSmallPrefixes.value.values.map { prefix => + (prefix.id, postfix.project(prefix).compressed) + }.filter(_._2.nonEmpty) + }.groupByKey().flatMap { case (id, projPostfixes) => + val prefix = bcSmallPrefixes.value(id) + val localPrefixSpan = new LocalPrefixSpan(minCount, maxPatternLength - prefix.length) + // TODO: We collect projected postfixes into memory. We should also compare the performance + // TODO: of keeping them on shuffle files. + localPrefixSpan.run(projPostfixes.toArray).map { case (pattern, count) => + (prefix.items ++ pattern, count) + } } - // Process the small projected databases locally - val remainingResults = getPatternsInLocal( - minCount, sc.parallelize(pairsForLocal, 1).groupByKey()) - - (sc.parallelize(resultsAccumulator, 1) ++ remainingResults) - .map { case (pattern, count) => (flattenSequence(pattern.reverse).toArray, count) } + // Union local frequent patterns and distributed ones. + val freqPatterns = (sc.parallelize(localFreqPatterns, 1) ++ distributedFreqPattern) + .persist(StorageLevel.MEMORY_AND_DISK) + freqPatterns } - /** - * Partitions the prefix-suffix pairs by projected database size. 
- * @param prefixSuffixPairs prefix (length n) and suffix pairs, - * @return prefix-suffix pairs partitioned by whether their projected database size is <= or - * greater than [[maxLocalProjDBSize]] + * Represents a prefix. + * @param items items in this prefix, using the internal format + * @param length length of this prefix, not counting 0 */ - private def partitionByProjDBSize(prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])]) - : (List[(List[Set[Int]], List[Set[Int]])], RDD[(List[Set[Int]], List[Set[Int]])]) = { - val prefixToSuffixSize = prefixSuffixPairs - .aggregateByKey(0)( - seqOp = { case (count, suffix) => count + suffix.length }, - combOp = { _ + _ }) - val smallPrefixes = prefixToSuffixSize - .filter(_._2 <= maxLocalProjDBSize) - .keys - .collect() - .toSet - val small = prefixSuffixPairs.filter { case (prefix, _) => smallPrefixes.contains(prefix) } - val large = prefixSuffixPairs.filter { case (prefix, _) => !smallPrefixes.contains(prefix) } - (small.collect().toList, large) + private[fpm] class Prefix private (val items: Array[Int], val length: Int) extends Serializable { + + /** A unique id for this prefix. */ + val id: Int = Prefix.nextId + + /** Expands this prefix by the input item. */ + def :+(item: Int): Prefix = { + require(item != 0) + if (item < 0) { + new Prefix(items :+ -item, length + 1) + } else { + new Prefix(items ++ Array(0, item), length + 1) + } + } } - /** - * Extends all prefixes by one itemset from their suffix and computes the resulting frequent - * prefixes and remaining work. - * @param minCount minimum count - * @param prefixSuffixPairs prefix (length N) and suffix pairs, - * @return (frequent length N+1 extended prefix, count) pairs and (frequent length N+1 extended - * prefix, corresponding suffix) pairs. - */ - private def extendPrefixes( - minCount: Long, - prefixSuffixPairs: RDD[(List[Set[Int]], List[Set[Int]])]) - : (RDD[(List[Set[Int]], Long)], RDD[(List[Set[Int]], List[Set[Int]])]) = { - - // (length N prefix, itemset from suffix) pairs and their corresponding number of occurrences - // Every (prefix :+ suffix) is guaranteed to have support exceeding `minSupport` - val prefixItemPairAndCounts = prefixSuffixPairs - .flatMap { case (prefix, suffix) => - suffix.flatMap(nonemptySubsets(_)).distinct.map(y => ((prefix, y), 1L)) } - .reduceByKey(_ + _) - .filter { case (item, count) => (count >= minCount) } - - // Map from prefix to set of possible next items from suffix - val prefixToNextItems = prefixItemPairAndCounts - .keys - .groupByKey() - .mapValues(_.toSet) - .collect() - .toMap - - // Frequent patterns with length N+1 and their corresponding counts - val extendedPrefixAndCounts = prefixItemPairAndCounts - .map { case ((prefix, item), count) => (item :: prefix, count) } - - // Remaining work, all prefixes will have length N+1 - val extendedPrefixAndSuffix = prefixSuffixPairs - .filter(x => prefixToNextItems.contains(x._1)) - .flatMap { case (prefix, suffix) => - val frequentNextItemSets = prefixToNextItems(prefix) - val frequentNextItems = frequentNextItemSets.flatten - val filteredSuffix = suffix - .map(item => frequentNextItems.intersect(item)) - .filter(_.nonEmpty) - frequentNextItemSets.flatMap { item => - LocalPrefixSpan.getSuffix(item, filteredSuffix) match { - case suffix if !suffix.isEmpty => Some(item :: prefix, suffix) - case _ => None - } - } - } + private[fpm] object Prefix { + /** Internal counter to generate unique IDs. 
*/ + private val counter: AtomicInteger = new AtomicInteger(-1) - (extendedPrefixAndCounts, extendedPrefixAndSuffix) + /** Gets the next unique ID. */ + private def nextId: Int = counter.incrementAndGet() + + /** An empty [[Prefix]] instance. */ + val empty: Prefix = new Prefix(Array.empty, 0) } /** - * Calculate the patterns in local. - * @param minCount the absolute minimum count - * @param data prefixes and projected sequences data data - * @return patterns + * An internal representation of a postfix from some projection. + * We use one int array to store the items, which might also contains other items from the + * original sequence. + * Items are represented by positive integers, and items in each itemset must be distinct and + * ordered. + * we use 0 as the delimiter between itemsets. + * For example, a sequence `<(12)(31)1>` is represented by `[0, 1, 2, 0, 1, 3, 0, 1, 0]`. + * The postfix of this sequence w.r.t. to prefix `<1>` is `<(_2)(13)1>`. + * We may reuse the original items array `[0, 1, 2, 0, 1, 3, 0, 1, 0]` to represent the postfix, + * and mark the start index of the postfix, which is `2` in this example. + * So the active items in this postfix are `[2, 0, 1, 3, 0, 1, 0]`. + * We also remember the start indices of partial projections, the ones that split an itemset. + * For example, another possible partial projection w.r.t. `<1>` is `<(_3)1>`. + * We remember the start indices of partial projections, which is `[2, 5]` in this example. + * This data structure makes it easier to do projections. + * + * @param items a sequence stored as `Array[Int]` containing this postfix + * @param start the start index of this postfix in items + * @param partialStarts start indices of possible partial projections, strictly increasing */ - private def getPatternsInLocal( - minCount: Long, - data: RDD[(List[Set[Int]], Iterable[List[Set[Int]]])]): RDD[(List[Set[Int]], Long)] = { - data.flatMap { - case (prefix, projDB) => LocalPrefixSpan.run(minCount, maxPatternLength, prefix, projDB) + private[fpm] class Postfix( + val items: Array[Int], + val start: Int = 0, + val partialStarts: Array[Int] = Array.empty) extends Serializable { + + require(items.last == 0, s"The last item in a postfix must be zero, but got ${items.last}.") + if (partialStarts.nonEmpty) { + require(partialStarts.head >= start, + "The first partial start cannot be smaller than the start index," + + s"but got partialStarts.head = ${partialStarts.head} < start = $start.") } - } -} + /** + * Start index of the first full itemset contained in this postfix. + */ + private[this] def fullStart: Int = { + var i = start + while (items(i) != 0) { + i += 1 + } + i + } + + /** + * Generates length-1 prefix items of this postfix with the corresponding postfix sizes. + * There are two types of prefix items: + * a) The item can be assembled to the last itemset of the prefix. For example, + * the postfix of `<(12)(123)>1` w.r.t. `<1>` is `<(_2)(123)1>`. The prefix items of this + * postfix can be assembled to `<1>` is `_2` and `_3`, resulting new prefixes `<(12)>` and + * `<(13)>`. We flip the sign in the output to indicate that this is a partial prefix item. + * b) The item can be appended to the prefix. Taking the same example above, the prefix items + * can be appended to `<1>` is `1`, `2`, and `3`, resulting new prefixes `<11>`, `<12>`, + * and `<13>`. + * @return an iterator of (prefix item, corresponding postfix size). 
If the item is negative, it + * indicates a partial prefix item, which should be assembled to the last itemset of the + * current prefix. Otherwise, the item should be appended to the current prefix. + */ + def genPrefixItems: Iterator[(Int, Long)] = { + val n1 = items.length - 1 + // For each unique item (subject to sign) in this sequence, we output exact one split. + // TODO: use PrimitiveKeyOpenHashMap + val prefixes = mutable.Map.empty[Int, Long] + // a) items that can be assembled to the last itemset of the prefix + partialStarts.foreach { start => + var i = start + var x = -items(i) + while (x != 0) { + if (!prefixes.contains(x)) { + prefixes(x) = n1 - i + } + i += 1 + x = -items(i) + } + } + // b) items that can be appended to the prefix + var i = fullStart + while (i < n1) { + val x = items(i) + if (x != 0 && !prefixes.contains(x)) { + prefixes(x) = n1 - i + } + i += 1 + } + prefixes.toIterator + } -object PrefixSpan { - private[fpm] val DELIMITER = -1 + /** Tests whether this postfix is non-empty. */ + def nonEmpty: Boolean = items.length > start + 1 - /** Splits an array of itemsets delimited by [[DELIMITER]]. */ - private[fpm] def splitSequence(sequence: List[Int]): List[Set[Int]] = { - sequence.span(_ != DELIMITER) match { - case (x, xs) if xs.length > 1 => x.toSet :: splitSequence(xs.tail) - case (x, xs) => List(x.toSet) + /** + * Projects this postfix with respect to the input prefix item. + * @param prefix prefix item. If prefix is positive, we match items in any full itemset; if it + * is negative, we do partial projections. + * @return the projected postfix + */ + def project(prefix: Int): Postfix = { + require(prefix != 0) + val n1 = items.length - 1 + var matched = false + var newStart = n1 + val newPartialStarts = mutable.ArrayBuilder.make[Int] + if (prefix < 0) { + // Search for partial projections. + val target = -prefix + partialStarts.foreach { start => + var i = start + var x = items(i) + while (x != target && x != 0) { + i += 1 + x = items(i) + } + if (x == target) { + i += 1 + if (!matched) { + newStart = i + matched = true + } + if (items(i) != 0) { + newPartialStarts += i + } + } + } + } else { + // Search for items in full itemsets. + // Though the items are ordered in each itemsets, they should be small in practice. + // So a sequential scan is sufficient here, compared to bisection search. + val target = prefix + var i = fullStart + while (i < n1) { + val x = items(i) + if (x == target) { + if (!matched) { + newStart = i + matched = true + } + if (items(i + 1) != 0) { + newPartialStarts += i + 1 + } + } + i += 1 + } + } + new Postfix(items, newStart, newPartialStarts.result()) } - } - /** Flattens a sequence of itemsets into an Array, inserting[[DELIMITER]] between itemsets. */ - private[fpm] def flattenSequence(sequence: List[Set[Int]]): List[Int] = { - val builder = ArrayBuilder.make[Int]() - for (itemSet <- sequence) { - builder += DELIMITER - builder ++= itemSet.toSeq.sorted + /** + * Projects this postfix with respect to the input prefix. 
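+ * The prefix is expected in the internal format (e.g., `[0, 1, 2, 0, 3]` for `<(12)3>`); it is applied one item at a time through the single-item `project`, using negated values for partial projections.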
+ */ + private def project(prefix: Array[Int]): Postfix = { + var partial = true + var cur = this + var i = 0 + val np = prefix.length + while (i < np && cur.nonEmpty) { + val x = prefix(i) + if (x == 0) { + partial = false + } else { + if (partial) { + cur = cur.project(-x) + } else { + cur = cur.project(x) + partial = true + } + } + i += 1 + } + cur } - builder.result().toList.drop(1) // drop trailing delimiter - } - /** Returns an iterator over all non-empty subsets of `itemSet` */ - private[fpm] def nonemptySubsets(itemSet: Set[Int]): Iterator[Set[Int]] = { - // TODO: improve complexity by using partial prefixes, considering one item at a time - itemSet.subsets.filter(_ != Set.empty[Int]) + /** + * Projects this postfix with respect to the input prefix. + */ + def project(prefix: Prefix): Postfix = project(prefix.items) + + /** + * Returns the same sequence with compressed storage if possible. + */ + def compressed: Postfix = { + if (start > 0) { + new Postfix(items.slice(start, items.length), 0, partialStarts.map(_ - start)) + } else { + this + } + } } /** diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index 0ae48d62cc..a83e543859 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -21,7 +21,7 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { - test("PrefixSpan internal (integer seq, -1 delim) run, singleton itemsets") { + test("PrefixSpan internal (integer seq, 0 delim) run, singleton itemsets") { /* library("arulesSequences") @@ -35,83 +35,81 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { */ val sequences = Array( - Array(1, -1, 3, -1, 4, -1, 5), - Array(2, -1, 3, -1, 1), - Array(2, -1, 4, -1, 1), - Array(3, -1, 1, -1, 3, -1, 4, -1, 5), - Array(3, -1, 4, -1, 4, -1, 3), - Array(6, -1, 5, -1, 3)) + Array(0, 1, 0, 3, 0, 4, 0, 5, 0), + Array(0, 2, 0, 3, 0, 1, 0), + Array(0, 2, 0, 4, 0, 1, 0), + Array(0, 3, 0, 1, 0, 3, 0, 4, 0, 5, 0), + Array(0, 3, 0, 4, 0, 4, 0, 3, 0), + Array(0, 6, 0, 5, 0, 3, 0)) val rdd = sc.parallelize(sequences, 2).cache() - val prefixspan = new PrefixSpan() - .setMinSupport(0.33) - .setMaxPatternLength(50) - val result1 = prefixspan.run(rdd) + val result1 = PrefixSpan.genFreqPatterns( + rdd, minCount = 2L, maxPatternLength = 50, maxLocalProjDBSize = 16L) val expectedValue1 = Array( - (Array(1), 4L), - (Array(1, -1, 3), 2L), - (Array(1, -1, 3, -1, 4), 2L), - (Array(1, -1, 3, -1, 4, -1, 5), 2L), - (Array(1, -1, 3, -1, 5), 2L), - (Array(1, -1, 4), 2L), - (Array(1, -1, 4, -1, 5), 2L), - (Array(1, -1, 5), 2L), - (Array(2), 2L), - (Array(2, -1, 1), 2L), - (Array(3), 5L), - (Array(3, -1, 1), 2L), - (Array(3, -1, 3), 2L), - (Array(3, -1, 4), 3L), - (Array(3, -1, 4, -1, 5), 2L), - (Array(3, -1, 5), 2L), - (Array(4), 4L), - (Array(4, -1, 5), 2L), - (Array(5), 3L) + (Array(0, 1, 0), 4L), + (Array(0, 1, 0, 3, 0), 2L), + (Array(0, 1, 0, 3, 0, 4, 0), 2L), + (Array(0, 1, 0, 3, 0, 4, 0, 5, 0), 2L), + (Array(0, 1, 0, 3, 0, 5, 0), 2L), + (Array(0, 1, 0, 4, 0), 2L), + (Array(0, 1, 0, 4, 0, 5, 0), 2L), + (Array(0, 1, 0, 5, 0), 2L), + (Array(0, 2, 0), 2L), + (Array(0, 2, 0, 1, 0), 2L), + (Array(0, 3, 0), 5L), + (Array(0, 3, 0, 1, 0), 2L), + (Array(0, 3, 0, 3, 0), 2L), + (Array(0, 3, 0, 4, 0), 3L), + (Array(0, 3, 0, 4, 0, 5, 0), 2L), + (Array(0, 3, 0, 5, 0), 
2L), + (Array(0, 4, 0), 4L), + (Array(0, 4, 0, 5, 0), 2L), + (Array(0, 5, 0), 3L) ) compareInternalResults(expectedValue1, result1.collect()) - prefixspan.setMinSupport(0.5).setMaxPatternLength(50) - val result2 = prefixspan.run(rdd) + val result2 = PrefixSpan.genFreqPatterns( + rdd, minCount = 3, maxPatternLength = 50, maxLocalProjDBSize = 32L) val expectedValue2 = Array( - (Array(1), 4L), - (Array(3), 5L), - (Array(3, -1, 4), 3L), - (Array(4), 4L), - (Array(5), 3L) + (Array(0, 1, 0), 4L), + (Array(0, 3, 0), 5L), + (Array(0, 3, 0, 4, 0), 3L), + (Array(0, 4, 0), 4L), + (Array(0, 5, 0), 3L) ) compareInternalResults(expectedValue2, result2.collect()) - prefixspan.setMinSupport(0.33).setMaxPatternLength(2) - val result3 = prefixspan.run(rdd) + val result3 = PrefixSpan.genFreqPatterns( + rdd, minCount = 2, maxPatternLength = 2, maxLocalProjDBSize = 32L) val expectedValue3 = Array( - (Array(1), 4L), - (Array(1, -1, 3), 2L), - (Array(1, -1, 4), 2L), - (Array(1, -1, 5), 2L), - (Array(2, -1, 1), 2L), - (Array(2), 2L), - (Array(3), 5L), - (Array(3, -1, 1), 2L), - (Array(3, -1, 3), 2L), - (Array(3, -1, 4), 3L), - (Array(3, -1, 5), 2L), - (Array(4), 4L), - (Array(4, -1, 5), 2L), - (Array(5), 3L) + (Array(0, 1, 0), 4L), + (Array(0, 1, 0, 3, 0), 2L), + (Array(0, 1, 0, 4, 0), 2L), + (Array(0, 1, 0, 5, 0), 2L), + (Array(0, 2, 0, 1, 0), 2L), + (Array(0, 2, 0), 2L), + (Array(0, 3, 0), 5L), + (Array(0, 3, 0, 1, 0), 2L), + (Array(0, 3, 0, 3, 0), 2L), + (Array(0, 3, 0, 4, 0), 3L), + (Array(0, 3, 0, 5, 0), 2L), + (Array(0, 4, 0), 4L), + (Array(0, 4, 0, 5, 0), 2L), + (Array(0, 5, 0), 3L) ) compareInternalResults(expectedValue3, result3.collect()) } test("PrefixSpan internal (integer seq, -1 delim) run, variable-size itemsets") { val sequences = Array( - Array(1, -1, 1, 2, 3, -1, 1, 3, -1, 4, -1, 3, 6), - Array(1, 4, -1, 3, -1, 2, 3, -1, 1, 5), - Array(5, 6, -1, 1, 2, -1, 4, 6, -1, 3, -1, 2), - Array(5, -1, 7, -1, 1, 6, -1, 3, -1, 2, -1, 3)) + Array(0, 1, 0, 1, 2, 3, 0, 1, 3, 0, 4, 0, 3, 6, 0), + Array(0, 1, 4, 0, 3, 0, 2, 3, 0, 1, 5, 0), + Array(0, 5, 6, 0, 1, 2, 0, 4, 6, 0, 3, 0, 2, 0), + Array(0, 5, 0, 7, 0, 1, 6, 0, 3, 0, 2, 0, 3, 0)) val rdd = sc.parallelize(sequences, 2).cache() - val prefixspan = new PrefixSpan().setMinSupport(0.5).setMaxPatternLength(5) - val result = prefixspan.run(rdd) + val result = PrefixSpan.genFreqPatterns( + rdd, minCount = 2, maxPatternLength = 5, maxLocalProjDBSize = 128L) /* To verify results, create file "prefixSpanSeqs" with content @@ -200,63 +198,87 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { 53 <{1},{2},{1}> 0.50 */ val expectedValue = Array( - (Array(1), 4L), - (Array(2), 4L), - (Array(3), 4L), - (Array(4), 3L), - (Array(5), 3L), - (Array(6), 3L), - (Array(1, -1, 6), 2L), - (Array(2, -1, 6), 2L), - (Array(5, -1, 6), 2L), - (Array(1, 2, -1, 6), 2L), - (Array(1, -1, 4), 2L), - (Array(2, -1, 4), 2L), - (Array(1, 2, -1, 4), 2L), - (Array(1, -1, 3), 4L), - (Array(2, -1, 3), 3L), - (Array(2, 3), 2L), - (Array(3, -1, 3), 3L), - (Array(4, -1, 3), 3L), - (Array(5, -1, 3), 2L), - (Array(6, -1, 3), 2L), - (Array(5, -1, 6, -1, 3), 2L), - (Array(6, -1, 2, -1, 3), 2L), - (Array(5, -1, 2, -1, 3), 2L), - (Array(5, -1, 1, -1, 3), 2L), - (Array(2, -1, 4, -1, 3), 2L), - (Array(1, -1, 4, -1, 3), 2L), - (Array(1, 2, -1, 4, -1, 3), 2L), - (Array(1, -1, 3, -1, 3), 3L), - (Array(1, 2, -1, 3), 2L), - (Array(1, -1, 2, -1, 3), 2L), - (Array(1, -1, 2, 3), 2L), - (Array(1, -1, 2), 4L), - (Array(1, 2), 2L), - (Array(3, -1, 2), 3L), - (Array(4, -1, 2), 2L), - (Array(5, 
-1, 2), 2L), - (Array(6, -1, 2), 2L), - (Array(5, -1, 6, -1, 2), 2L), - (Array(6, -1, 3, -1, 2), 2L), - (Array(5, -1, 3, -1, 2), 2L), - (Array(5, -1, 1, -1, 2), 2L), - (Array(4, -1, 3, -1, 2), 2L), - (Array(1, -1, 3, -1, 2), 3L), - (Array(5, -1, 6, -1, 3, -1, 2), 2L), - (Array(5, -1, 1, -1, 3, -1, 2), 2L), - (Array(1, -1, 1), 2L), - (Array(2, -1, 1), 2L), - (Array(3, -1, 1), 2L), - (Array(5, -1, 1), 2L), - (Array(2, 3, -1, 1), 2L), - (Array(1, -1, 3, -1, 1), 2L), - (Array(1, -1, 2, 3, -1, 1), 2L), - (Array(1, -1, 2, -1, 1), 2L)) + (Array(0, 1, 0), 4L), + (Array(0, 2, 0), 4L), + (Array(0, 3, 0), 4L), + (Array(0, 4, 0), 3L), + (Array(0, 5, 0), 3L), + (Array(0, 6, 0), 3L), + (Array(0, 1, 0, 6, 0), 2L), + (Array(0, 2, 0, 6, 0), 2L), + (Array(0, 5, 0, 6, 0), 2L), + (Array(0, 1, 2, 0, 6, 0), 2L), + (Array(0, 1, 0, 4, 0), 2L), + (Array(0, 2, 0, 4, 0), 2L), + (Array(0, 1, 2, 0, 4, 0), 2L), + (Array(0, 1, 0, 3, 0), 4L), + (Array(0, 2, 0, 3, 0), 3L), + (Array(0, 2, 3, 0), 2L), + (Array(0, 3, 0, 3, 0), 3L), + (Array(0, 4, 0, 3, 0), 3L), + (Array(0, 5, 0, 3, 0), 2L), + (Array(0, 6, 0, 3, 0), 2L), + (Array(0, 5, 0, 6, 0, 3, 0), 2L), + (Array(0, 6, 0, 2, 0, 3, 0), 2L), + (Array(0, 5, 0, 2, 0, 3, 0), 2L), + (Array(0, 5, 0, 1, 0, 3, 0), 2L), + (Array(0, 2, 0, 4, 0, 3, 0), 2L), + (Array(0, 1, 0, 4, 0, 3, 0), 2L), + (Array(0, 1, 2, 0, 4, 0, 3, 0), 2L), + (Array(0, 1, 0, 3, 0, 3, 0), 3L), + (Array(0, 1, 2, 0, 3, 0), 2L), + (Array(0, 1, 0, 2, 0, 3, 0), 2L), + (Array(0, 1, 0, 2, 3, 0), 2L), + (Array(0, 1, 0, 2, 0), 4L), + (Array(0, 1, 2, 0), 2L), + (Array(0, 3, 0, 2, 0), 3L), + (Array(0, 4, 0, 2, 0), 2L), + (Array(0, 5, 0, 2, 0), 2L), + (Array(0, 6, 0, 2, 0), 2L), + (Array(0, 5, 0, 6, 0, 2, 0), 2L), + (Array(0, 6, 0, 3, 0, 2, 0), 2L), + (Array(0, 5, 0, 3, 0, 2, 0), 2L), + (Array(0, 5, 0, 1, 0, 2, 0), 2L), + (Array(0, 4, 0, 3, 0, 2, 0), 2L), + (Array(0, 1, 0, 3, 0, 2, 0), 3L), + (Array(0, 5, 0, 6, 0, 3, 0, 2, 0), 2L), + (Array(0, 5, 0, 1, 0, 3, 0, 2, 0), 2L), + (Array(0, 1, 0, 1, 0), 2L), + (Array(0, 2, 0, 1, 0), 2L), + (Array(0, 3, 0, 1, 0), 2L), + (Array(0, 5, 0, 1, 0), 2L), + (Array(0, 2, 3, 0, 1, 0), 2L), + (Array(0, 1, 0, 3, 0, 1, 0), 2L), + (Array(0, 1, 0, 2, 3, 0, 1, 0), 2L), + (Array(0, 1, 0, 2, 0, 1, 0), 2L)) compareInternalResults(expectedValue, result.collect()) } + test("PrefixSpan projections with multiple partial starts") { + val sequences = Seq( + Array(Array(1, 2), Array(1, 2, 3))) + val rdd = sc.parallelize(sequences, 2) + val prefixSpan = new PrefixSpan() + .setMinSupport(1.0) + .setMaxPatternLength(2) + val model = prefixSpan.run(rdd) + val expected = Array( + (Array(Array(1)), 1L), + (Array(Array(1, 2)), 1L), + (Array(Array(1), Array(1)), 1L), + (Array(Array(1), Array(2)), 1L), + (Array(Array(1), Array(3)), 1L), + (Array(Array(1, 3)), 1L), + (Array(Array(2)), 1L), + (Array(Array(2, 3)), 1L), + (Array(Array(2), Array(1)), 1L), + (Array(Array(2), Array(2)), 1L), + (Array(Array(2), Array(3)), 1L), + (Array(Array(3)), 1L)) + compareResults(expected, model.freqSequences.collect()) + } + test("PrefixSpan Integer type, variable-size itemsets") { val sequences = Seq( Array(Array(1, 2), Array(3)), @@ -265,7 +287,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { Array(Array(6))) val rdd = sc.parallelize(sequences, 2).cache() - val prefixspan = new PrefixSpan() + val prefixSpan = new PrefixSpan() .setMinSupport(0.5) .setMaxPatternLength(5) @@ -296,7 +318,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { 5 <{1,2}> 0.75 */ - val model = 
prefixspan.run(rdd) + val model = prefixSpan.run(rdd) val expected = Array( (Array(Array(1)), 3L), (Array(Array(2)), 3L), @@ -304,7 +326,7 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { (Array(Array(1), Array(3)), 2L), (Array(Array(1, 2)), 3L) ) - compareResults(expected, model.freqSequences.collect().map(x => (x.sequence, x.freq))) + compareResults(expected, model.freqSequences.collect()) } test("PrefixSpan String type, variable-size itemsets") { @@ -318,11 +340,11 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { Array(Array(6))).map(seq => seq.map(itemSet => itemSet.map(intToString))) val rdd = sc.parallelize(sequences, 2).cache() - val prefixspan = new PrefixSpan() + val prefixSpan = new PrefixSpan() .setMinSupport(0.5) .setMaxPatternLength(5) - val model = prefixspan.run(rdd) + val model = prefixSpan.run(rdd) val expected = Array( (Array(Array(1)), 3L), (Array(Array(2)), 3L), @@ -332,17 +354,17 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { ).map { case (pattern, count) => (pattern.map(itemSet => itemSet.map(intToString)), count) } - compareResults(expected, model.freqSequences.collect().map(x => (x.sequence, x.freq))) + compareResults(expected, model.freqSequences.collect()) } private def compareResults[Item]( expectedValue: Array[(Array[Array[Item]], Long)], - actualValue: Array[(Array[Array[Item]], Long)]): Unit = { + actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = { val expectedSet = expectedValue.map { case (pattern: Array[Array[Item]], count: Long) => (pattern.map(itemSet => itemSet.toSet).toSeq, count) }.toSet - val actualSet = actualValue.map { case (pattern: Array[Array[Item]], count: Long) => - (pattern.map(itemSet => itemSet.toSet).toSeq, count) + val actualSet = actualValue.map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) }.toSet assert(expectedSet === actualSet) } @@ -354,11 +376,4 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { val actualSet = actualValue.map(x => (x._1.toSeq, x._2)).toSet assert(expectedSet === actualSet) } - - private def insertDelimiter(sequence: Array[Int]): Array[Int] = { - sequence.zip(Seq.fill(sequence.length)(PrefixSpan.DELIMITER)).map { case (a, b) => - List(a, b) - }.flatten - } - } -- cgit v1.2.3
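For reference, a minimal usage sketch of the public API touched by this patch. The input data mirrors the Integer-type test above; the SparkContext `sc` is assumed to already exist, and the printed format is illustrative only.

import org.apache.spark.mllib.fpm.PrefixSpan

// Each sequence is an Array of itemsets; each itemset is an Array of items.
val sequences = sc.parallelize(Seq(
  Array(Array(1, 2), Array(3)),
  Array(Array(1), Array(3, 2), Array(1, 2)),
  Array(Array(1, 2), Array(5)),
  Array(Array(6))), 2).cache()

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.5)               // a pattern must appear in at least half of the sequences
  .setMaxPatternLength(5)           // upper bound on pattern length
  .setMaxLocalProjDBSize(32000000L) // the newly exposed knob; this is its default value

val model = prefixSpan.run(sequences)
model.freqSequences.collect().foreach { fs =>
  println(fs.sequence.map(_.mkString("(", ",", ")")).mkString("<", "", ">") + ": " + fs.freq)
}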