diff options
author | zhangjiajin <zhangjiajin@huawei.com> | 2015-07-10 21:11:46 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-07-10 21:11:46 -0700 |
commit | 7f6be1f24d4f2fcb3d3ec181b5abf241709a8b6d (patch) | |
tree | 6cf93413bdf4408f4940e74aae62d5e785eb9b38 /mllib | |
parent | 9c5075775d741eacbeeb2df77ea30611356b6e1a (diff) | |
download | spark-7f6be1f24d4f2fcb3d3ec181b5abf241709a8b6d.tar.gz spark-7f6be1f24d4f2fcb3d3ec181b5abf241709a8b6d.tar.bz2 spark-7f6be1f24d4f2fcb3d3ec181b5abf241709a8b6d.zip |
[SPARK-6487] [MLLIB] Add sequential pattern mining algorithm PrefixSpan to Spark MLlib
Add parallel PrefixSpan algorithm and test file.
Support non-temporal sequences.
Author: zhangjiajin <zhangjiajin@huawei.com>
Author: zhang jiajin <zhangjiajin@huawei.com>
Closes #7258 from zhangjiajin/master and squashes the following commits:
ca9c4c8 [zhangjiajin] Modified the code according to the review comments.
574e56c [zhangjiajin] Add new object LocalPrefixSpan, and do some optimization.
ba5df34 [zhangjiajin] Fix a Scala style error.
4c60fb3 [zhangjiajin] Fix some Scala style errors.
1dd33ad [zhangjiajin] Modified the code according to the review comments.
89bc368 [zhangjiajin] Fixed a Scala style error.
a2eb14c [zhang jiajin] Delete PrefixspanSuite.scala
951fd42 [zhang jiajin] Delete Prefixspan.scala
575995f [zhangjiajin] Modified the code according to the review comments.
91fd7e6 [zhangjiajin] Add new algorithm PrefixSpan and test file.
Diffstat (limited to 'mllib')
3 files changed, 390 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala new file mode 100644 index 0000000000..39c48b084e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental + +/** + * + * :: Experimental :: + * + * Calculate all patterns of a projected database in local. + */ +@Experimental +private[fpm] object LocalPrefixSpan extends Logging with Serializable { + + /** + * Calculate all patterns of a projected database. + * @param minCount minimum count + * @param maxPatternLength maximum pattern length + * @param prefix prefix + * @param projectedDatabase the projected dabase + * @return a set of sequential pattern pairs, + * the key of pair is sequential pattern (a list of items), + * the value of pair is the pattern's count. + */ + def run( + minCount: Long, + maxPatternLength: Int, + prefix: Array[Int], + projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = { + val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase) + val frequentPatternAndCounts = frequentPrefixAndCounts + .map(x => (prefix ++ Array(x._1), x._2)) + val prefixProjectedDatabases = getPatternAndProjectedDatabase( + prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase) + + val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength + if (continueProcess) { + val nextPatterns = prefixProjectedDatabases + .map(x => run(minCount, maxPatternLength, x._1, x._2)) + .reduce(_ ++ _) + frequentPatternAndCounts ++ nextPatterns + } else { + frequentPatternAndCounts + } + } + + /** + * calculate suffix sequence following a prefix in a sequence + * @param prefix prefix + * @param sequence sequence + * @return suffix sequence + */ + def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = { + val index = sequence.indexOf(prefix) + if (index == -1) { + Array() + } else { + sequence.drop(index + 1) + } + } + + /** + * Generates frequent items by filtering the input data using minimal count level. + * @param minCount the absolute minimum count + * @param sequences sequences data + * @return array of item and count pair + */ + private def getFreqItemAndCounts( + minCount: Long, + sequences: Array[Array[Int]]): Array[(Int, Long)] = { + sequences.flatMap(_.distinct) + .groupBy(x => x) + .mapValues(_.length.toLong) + .filter(_._2 >= minCount) + .toArray + } + + /** + * Get the frequent prefixes' projected database. + * @param prePrefix the frequent prefixes' prefix + * @param frequentPrefixes frequent prefixes + * @param sequences sequences data + * @return prefixes and projected database + */ + private def getPatternAndProjectedDatabase( + prePrefix: Array[Int], + frequentPrefixes: Array[Int], + sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = { + val filteredProjectedDatabase = sequences + .map(x => x.filter(frequentPrefixes.contains(_))) + frequentPrefixes.map { x => + val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty) + (prePrefix ++ Array(x), sub) + }.filter(x => x._2.nonEmpty) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala new file mode 100644 index 0000000000..9d8c60ef0f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -0,0 +1,157 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.fpm + +import org.apache.spark.Logging +import org.apache.spark.annotation.Experimental +import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel + +/** + * + * :: Experimental :: + * + * A parallel PrefixSpan algorithm to mine sequential pattern. + * The PrefixSpan algorithm is described in + * [[http://doi.org/10.1109/ICDE.2001.914830]]. + * + * @param minSupport the minimal support level of the sequential pattern, any pattern appears + * more than (minSupport * size-of-the-dataset) times will be output + * @param maxPatternLength the maximal length of the sequential pattern, any pattern appears + * less than maxPatternLength will be output + * + * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining + * (Wikipedia)]] + */ +@Experimental +class PrefixSpan private ( + private var minSupport: Double, + private var maxPatternLength: Int) extends Logging with Serializable { + + /** + * Constructs a default instance with default parameters + * {minSupport: `0.1`, maxPatternLength: `10`}. + */ + def this() = this(0.1, 10) + + /** + * Sets the minimal support level (default: `0.1`). + */ + def setMinSupport(minSupport: Double): this.type = { + require(minSupport >= 0 && minSupport <= 1, + "The minimum support value must be between 0 and 1, including 0 and 1.") + this.minSupport = minSupport + this + } + + /** + * Sets maximal pattern length (default: `10`). + */ + def setMaxPatternLength(maxPatternLength: Int): this.type = { + require(maxPatternLength >= 1, + "The maximum pattern length value must be greater than 0.") + this.maxPatternLength = maxPatternLength + this + } + + /** + * Find the complete set of sequential patterns in the input sequences. + * @param sequences input data set, contains a set of sequences, + * a sequence is an ordered list of elements. + * @return a set of sequential pattern pairs, + * the key of pair is pattern (a list of elements), + * the value of pair is the pattern's count. + */ + def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = { + if (sequences.getStorageLevel == StorageLevel.NONE) { + logWarning("Input data is not cached.") + } + val minCount = getMinCount(sequences) + val lengthOnePatternsAndCounts = + getFreqItemAndCounts(minCount, sequences).collect() + val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase( + lengthOnePatternsAndCounts.map(_._1), sequences) + val groupedProjectedDatabase = prefixAndProjectedDatabase + .map(x => (x._1.toSeq, x._2)) + .groupByKey() + .map(x => (x._1.toArray, x._2.toArray)) + val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase) + val lengthOnePatternsAndCountsRdd = + sequences.sparkContext.parallelize( + lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2))) + val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns + allPatterns + } + + /** + * Get the minimum count (sequences count * minSupport). + * @param sequences input data set, contains a set of sequences, + * @return minimum count, + */ + private def getMinCount(sequences: RDD[Array[Int]]): Long = { + if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong + } + + /** + * Generates frequent items by filtering the input data using minimal count level. + * @param minCount the absolute minimum count + * @param sequences original sequences data + * @return array of item and count pair + */ + private def getFreqItemAndCounts( + minCount: Long, + sequences: RDD[Array[Int]]): RDD[(Int, Long)] = { + sequences.flatMap(_.distinct.map((_, 1L))) + .reduceByKey(_ + _) + .filter(_._2 >= minCount) + } + + /** + * Get the frequent prefixes' projected database. + * @param frequentPrefixes frequent prefixes + * @param sequences sequences data + * @return prefixes and projected database + */ + private def getPrefixAndProjectedDatabase( + frequentPrefixes: Array[Int], + sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = { + val filteredSequences = sequences.map { p => + p.filter (frequentPrefixes.contains(_) ) + } + filteredSequences.flatMap { x => + frequentPrefixes.map { y => + val sub = LocalPrefixSpan.getSuffix(y, x) + (Array(y), sub) + }.filter(_._2.nonEmpty) + } + } + + /** + * calculate the patterns in local. + * @param minCount the absolute minimum count + * @param data patterns and projected sequences data data + * @return patterns + */ + private def getPatternsInLocal( + minCount: Long, + data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = { + data.flatMap { x => + LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala new file mode 100644 index 0000000000..413436d3db --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.mllib.fpm + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.rdd.RDD + +class PrefixspanSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("PrefixSpan using Integer type") { + + /* + library("arulesSequences") + prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE")) + freqItemSeq = cspade( + prefixSpanSeqs, + parameter = list(support = + 2 / length(unique(transactionInfo(prefixSpanSeqs)$sequenceID)), maxlen = 2 )) + resSeq = as(freqItemSeq, "data.frame") + resSeq + */ + + val sequences = Array( + Array(1, 3, 4, 5), + Array(2, 3, 1), + Array(2, 4, 1), + Array(3, 1, 3, 4, 5), + Array(3, 4, 4, 3), + Array(6, 5, 3)) + + val rdd = sc.parallelize(sequences, 2).cache() + + def compareResult( + expectedValue: Array[(Array[Int], Long)], + actualValue: Array[(Array[Int], Long)]): Boolean = { + val sortedExpectedValue = expectedValue.sortWith{ (x, y) => + x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 + } + val sortedActualValue = actualValue.sortWith{ (x, y) => + x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2 + } + sortedExpectedValue.zip(sortedActualValue) + .map(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2) + .reduce(_&&_) + } + + val prefixspan = new PrefixSpan() + .setMinSupport(0.33) + .setMaxPatternLength(50) + val result1 = prefixspan.run(rdd) + val expectedValue1 = Array( + (Array(1), 4L), + (Array(1, 3), 2L), + (Array(1, 3, 4), 2L), + (Array(1, 3, 4, 5), 2L), + (Array(1, 3, 5), 2L), + (Array(1, 4), 2L), + (Array(1, 4, 5), 2L), + (Array(1, 5), 2L), + (Array(2), 2L), + (Array(2, 1), 2L), + (Array(3), 5L), + (Array(3, 1), 2L), + (Array(3, 3), 2L), + (Array(3, 4), 3L), + (Array(3, 4, 5), 2L), + (Array(3, 5), 2L), + (Array(4), 4L), + (Array(4, 5), 2L), + (Array(5), 3L) + ) + assert(compareResult(expectedValue1, result1.collect())) + + prefixspan.setMinSupport(0.5).setMaxPatternLength(50) + val result2 = prefixspan.run(rdd) + val expectedValue2 = Array( + (Array(1), 4L), + (Array(3), 5L), + (Array(3, 4), 3L), + (Array(4), 4L), + (Array(5), 3L) + ) + assert(compareResult(expectedValue2, result2.collect())) + + prefixspan.setMinSupport(0.33).setMaxPatternLength(2) + val result3 = prefixspan.run(rdd) + val expectedValue3 = Array( + (Array(1), 4L), + (Array(1, 3), 2L), + (Array(1, 4), 2L), + (Array(1, 5), 2L), + (Array(2, 1), 2L), + (Array(2), 2L), + (Array(3), 5L), + (Array(3, 1), 2L), + (Array(3, 3), 2L), + (Array(3, 4), 3L), + (Array(3, 5), 2L), + (Array(4), 4L), + (Array(4, 5), 2L), + (Array(5), 3L) + ) + assert(compareResult(expectedValue3, result3.collect())) + } +} |