author     zhangjiajin <zhangjiajin@huawei.com>    2015-07-10 21:11:46 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-07-10 21:11:46 -0700
commit     7f6be1f24d4f2fcb3d3ec181b5abf241709a8b6d (patch)
tree       6cf93413bdf4408f4940e74aae62d5e785eb9b38 /mllib
parent     9c5075775d741eacbeeb2df77ea30611356b6e1a (diff)
[SPARK-6487] [MLLIB] Add sequential pattern mining algorithm PrefixSpan to Spark MLlib
Add parallel PrefixSpan algorithm and test file. Support non-temporal sequences.

Author: zhangjiajin <zhangjiajin@huawei.com>
Author: zhang jiajin <zhangjiajin@huawei.com>

Closes #7258 from zhangjiajin/master and squashes the following commits:

ca9c4c8 [zhangjiajin] Modified the code according to the review comments.
574e56c [zhangjiajin] Add new object LocalPrefixSpan, and do some optimization.
ba5df34 [zhangjiajin] Fix a Scala style error.
4c60fb3 [zhangjiajin] Fix some Scala style errors.
1dd33ad [zhangjiajin] Modified the code according to the review comments.
89bc368 [zhangjiajin] Fixed a Scala style error.
a2eb14c [zhang jiajin] Delete PrefixspanSuite.scala
951fd42 [zhang jiajin] Delete Prefixspan.scala
575995f [zhangjiajin] Modified the code according to the review comments.
91fd7e6 [zhangjiajin] Add new algorithm PrefixSpan and test file.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala  113
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala        157
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala   120
3 files changed, 390 insertions(+), 0 deletions(-)
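
Before the file-by-file diff, a minimal usage sketch of the API this patch adds. This is illustrative only: it assumes a live SparkContext named sc, and the setter values are simply the documented defaults.

import org.apache.spark.mllib.fpm.PrefixSpan
import org.apache.spark.rdd.RDD

// Each input sequence is an ordered list of integer items.
val sequences: RDD[Array[Int]] = sc.parallelize(Seq(
  Array(1, 3, 4, 5),
  Array(2, 3, 1),
  Array(2, 4, 1))).cache()  // run() logs a warning when the input is not cached

val prefixSpan = new PrefixSpan()
  .setMinSupport(0.1)       // fraction of sequences a pattern must appear in
  .setMaxPatternLength(10)  // upper bound on the length of mined patterns

// The result pairs each frequent pattern (a list of items) with its count.
val patterns: RDD[(Array[Int], Long)] = prefixSpan.run(sequences)
patterns.collect().foreach { case (pattern, count) =>
  println(pattern.mkString("[", ",", "]") + ": " + count)
}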
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
new file mode 100644
index 0000000000..39c48b084e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/LocalPrefixSpan.scala
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+
+/**
+ *
+ * :: Experimental ::
+ *
+ * Calculate all patterns of a projected database locally.
+ */
+@Experimental
+private[fpm] object LocalPrefixSpan extends Logging with Serializable {
+
+ /**
+ * Calculate all patterns of a projected database.
+ * @param minCount the minimum count threshold
+ * @param maxPatternLength the maximum pattern length
+ * @param prefix the current prefix
+ * @param projectedDatabase the projected database
+ * @return a set of sequential pattern pairs,
+ * the key of each pair is a sequential pattern (a list of items),
+ * the value is the pattern's count.
+ */
+ def run(
+ minCount: Long,
+ maxPatternLength: Int,
+ prefix: Array[Int],
+ projectedDatabase: Array[Array[Int]]): Array[(Array[Int], Long)] = {
+ val frequentPrefixAndCounts = getFreqItemAndCounts(minCount, projectedDatabase)
+ val frequentPatternAndCounts = frequentPrefixAndCounts
+ .map(x => (prefix ++ Array(x._1), x._2))
+ val prefixProjectedDatabases = getPatternAndProjectedDatabase(
+ prefix, frequentPrefixAndCounts.map(_._1), projectedDatabase)
+
+ val continueProcess = prefixProjectedDatabases.nonEmpty && prefix.length + 1 < maxPatternLength
+ if (continueProcess) {
+ val nextPatterns = prefixProjectedDatabases
+ .map(x => run(minCount, maxPatternLength, x._1, x._2))
+ .reduce(_ ++ _)
+ frequentPatternAndCounts ++ nextPatterns
+ } else {
+ frequentPatternAndCounts
+ }
+ }
+
+ /**
+ * Calculate the suffix sequence following the first occurrence of a prefix item.
+ * @param prefix the prefix item
+ * @param sequence the full sequence
+ * @return the suffix following the first occurrence of prefix, or an empty array if absent
+ */
+ def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
+ val index = sequence.indexOf(prefix)
+ if (index == -1) {
+ Array()
+ } else {
+ sequence.drop(index + 1)
+ }
+ }
+
+ /**
+ * Generates frequent items by filtering the input data against the minimum count threshold.
+ * @param minCount the absolute minimum count
+ * @param sequences sequences data
+ * @return array of item and count pair
+ */
+ private def getFreqItemAndCounts(
+ minCount: Long,
+ sequences: Array[Array[Int]]): Array[(Int, Long)] = {
+ sequences.flatMap(_.distinct)
+ .groupBy(x => x)
+ .mapValues(_.length.toLong)
+ .filter(_._2 >= minCount)
+ .toArray
+ }
+
+ /**
+ * Get the frequent prefixes' projected database.
+ * @param prePrefix the prefix common to all the frequent prefixes
+ * @param frequentPrefixes frequent prefixes
+ * @param sequences sequences data
+ * @return prefixes and projected database
+ */
+ private def getPatternAndProjectedDatabase(
+ prePrefix: Array[Int],
+ frequentPrefixes: Array[Int],
+ sequences: Array[Array[Int]]): Array[(Array[Int], Array[Array[Int]])] = {
+ val filteredProjectedDatabase = sequences
+ .map(x => x.filter(frequentPrefixes.contains(_)))
+ frequentPrefixes.map { x =>
+ val sub = filteredProjectedDatabase.map(y => getSuffix(x, y)).filter(_.nonEmpty)
+ (prePrefix ++ Array(x), sub)
+ }.filter(x => x._2.nonEmpty)
+ }
+}
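
The projection step above hinges on getSuffix: it drops everything up to and including the first occurrence of the prefix item, returning an empty array when the item is absent. A standalone sketch of that behavior (the object is private[fpm], so its logic is reproduced here for illustration):

// Same logic as LocalPrefixSpan.getSuffix, restated for a quick REPL check.
def getSuffix(prefix: Int, sequence: Array[Int]): Array[Int] = {
  val index = sequence.indexOf(prefix)
  if (index == -1) Array() else sequence.drop(index + 1)
}

getSuffix(3, Array(1, 3, 4, 5)).mkString(",")  // "4,5": the items after the first 3
getSuffix(2, Array(1, 3, 4, 5)).mkString(",")  // "": 2 never occurs, so no suffix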
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
new file mode 100644
index 0000000000..9d8c60ef0f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ *
+ * :: Experimental ::
+ *
+ * A parallel PrefixSpan algorithm to mine frequent sequential patterns.
+ * The PrefixSpan algorithm is described in
+ * [[http://doi.org/10.1109/ICDE.2001.914830]].
+ *
+ * @param minSupport the minimal support level of a sequential pattern; any pattern that appears
+ * at least (minSupport * size-of-the-dataset) times will be output
+ * @param maxPatternLength the maximal length of a sequential pattern; any frequent pattern whose
+ * length is no greater than maxPatternLength will be output
+ *
+ * @see [[https://en.wikipedia.org/wiki/Sequential_Pattern_Mining Sequential Pattern Mining
+ * (Wikipedia)]]
+ */
+@Experimental
+class PrefixSpan private (
+ private var minSupport: Double,
+ private var maxPatternLength: Int) extends Logging with Serializable {
+
+ /**
+ * Constructs a default instance with default parameters
+ * {minSupport: `0.1`, maxPatternLength: `10`}.
+ */
+ def this() = this(0.1, 10)
+
+ /**
+ * Sets the minimal support level (default: `0.1`).
+ */
+ def setMinSupport(minSupport: Double): this.type = {
+ require(minSupport >= 0 && minSupport <= 1,
+ "The minimum support value must be between 0 and 1, including 0 and 1.")
+ this.minSupport = minSupport
+ this
+ }
+
+ /**
+ * Sets maximal pattern length (default: `10`).
+ */
+ def setMaxPatternLength(maxPatternLength: Int): this.type = {
+ require(maxPatternLength >= 1,
+ "The maximum pattern length value must be greater than 0.")
+ this.maxPatternLength = maxPatternLength
+ this
+ }
+
+ /**
+ * Find the complete set of sequential patterns in the input sequences.
+ * @param sequences input data set, a collection of sequences,
+ * where each sequence is an ordered list of elements
+ * @return a set of sequential pattern pairs,
+ * the key of each pair is a pattern (a list of elements),
+ * the value is the pattern's count.
+ */
+ def run(sequences: RDD[Array[Int]]): RDD[(Array[Int], Long)] = {
+ if (sequences.getStorageLevel == StorageLevel.NONE) {
+ logWarning("Input data is not cached.")
+ }
+ val minCount = getMinCount(sequences)
+ val lengthOnePatternsAndCounts =
+ getFreqItemAndCounts(minCount, sequences).collect()
+ val prefixAndProjectedDatabase = getPrefixAndProjectedDatabase(
+ lengthOnePatternsAndCounts.map(_._1), sequences)
+ val groupedProjectedDatabase = prefixAndProjectedDatabase
+ .map(x => (x._1.toSeq, x._2))
+ .groupByKey()
+ .map(x => (x._1.toArray, x._2.toArray))
+ val nextPatterns = getPatternsInLocal(minCount, groupedProjectedDatabase)
+ val lengthOnePatternsAndCountsRdd =
+ sequences.sparkContext.parallelize(
+ lengthOnePatternsAndCounts.map(x => (Array(x._1), x._2)))
+ val allPatterns = lengthOnePatternsAndCountsRdd ++ nextPatterns
+ allPatterns
+ }
+
+ /**
+ * Get the minimum count (sequences count * minSupport, rounded up).
+ * @param sequences input data set, a collection of sequences
+ * @return the absolute minimum count
+ */
+ private def getMinCount(sequences: RDD[Array[Int]]): Long = {
+ if (minSupport == 0) 0L else math.ceil(sequences.count() * minSupport).toLong
+ }
+
+ /**
+ * Generates frequent items by filtering the input data against the minimum count threshold.
+ * @param minCount the absolute minimum count
+ * @param sequences original sequences data
+ * @return array of item and count pair
+ */
+ private def getFreqItemAndCounts(
+ minCount: Long,
+ sequences: RDD[Array[Int]]): RDD[(Int, Long)] = {
+ sequences.flatMap(_.distinct.map((_, 1L)))
+ .reduceByKey(_ + _)
+ .filter(_._2 >= minCount)
+ }
+
+ /**
+ * Get the frequent prefixes' projected database.
+ * @param frequentPrefixes frequent prefixes
+ * @param sequences sequences data
+ * @return prefixes and projected database
+ */
+ private def getPrefixAndProjectedDatabase(
+ frequentPrefixes: Array[Int],
+ sequences: RDD[Array[Int]]): RDD[(Array[Int], Array[Int])] = {
+ val filteredSequences = sequences.map { p =>
+ p.filter(frequentPrefixes.contains(_))
+ }
+ filteredSequences.flatMap { x =>
+ frequentPrefixes.map { y =>
+ val sub = LocalPrefixSpan.getSuffix(y, x)
+ (Array(y), sub)
+ }.filter(_._2.nonEmpty)
+ }
+ }
+
+ /**
+ * Calculate the patterns locally.
+ * @param minCount the absolute minimum count
+ * @param data patterns and projected sequences data
+ * @return patterns
+ */
+ private def getPatternsInLocal(
+ minCount: Long,
+ data: RDD[(Array[Int], Array[Array[Int]])]): RDD[(Array[Int], Long)] = {
+ data.flatMap { x =>
+ LocalPrefixSpan.run(minCount, maxPatternLength, x._1, x._2)
+ }
+ }
+}
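
Two details of the driver-side flow in run() are easy to miss: only the length-one patterns are collected to the driver (longer patterns stay distributed until getPatternsInLocal mines them per projected database), and the support fraction becomes an absolute count by rounding up in getMinCount. With the six test sequences used in the suite below, the thresholds work out as follows (plain Scala arithmetic):

math.ceil(6 * 0.33).toLong  // 2: minSupport = 0.33 requires a pattern in at least 2 of 6 sequences
math.ceil(6 * 0.5).toLong   // 3: minSupport = 0.5 tightens the requirement to 3 sequences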
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
new file mode 100644
index 0000000000..413436d3db
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -0,0 +1,120 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.rdd.RDD
+
+class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("PrefixSpan using Integer type") {
+
+ /*
+ library("arulesSequences")
+ prefixSpanSeqs = read_baskets("prefixSpanSeqs", info = c("sequenceID","eventID","SIZE"))
+ freqItemSeq = cspade(
+ prefixSpanSeqs,
+ parameter = list(support =
+ 2 / length(unique(transactionInfo(prefixSpanSeqs)$sequenceID)), maxlen = 2 ))
+ resSeq = as(freqItemSeq, "data.frame")
+ resSeq
+ */
+
+ val sequences = Array(
+ Array(1, 3, 4, 5),
+ Array(2, 3, 1),
+ Array(2, 4, 1),
+ Array(3, 1, 3, 4, 5),
+ Array(3, 4, 4, 3),
+ Array(6, 5, 3))
+
+ val rdd = sc.parallelize(sequences, 2).cache()
+
+ def compareResult(
+ expectedValue: Array[(Array[Int], Long)],
+ actualValue: Array[(Array[Int], Long)]): Boolean = {
+ // Arrays compare by reference in Scala, so sort canonical string renderings instead.
+ val sortedExpectedValue = expectedValue.sortWith { (x, y) =>
+ x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
+ }
+ val sortedActualValue = actualValue.sortWith { (x, y) =>
+ x._1.mkString(",") + ":" + x._2 < y._1.mkString(",") + ":" + y._2
+ }
+ // Check lengths explicitly: zip would silently truncate the longer side,
+ // and forall (unlike reduce) is safe on empty results.
+ expectedValue.length == actualValue.length &&
+ sortedExpectedValue.zip(sortedActualValue)
+ .forall(x => x._1._1.mkString(",") == x._2._1.mkString(",") && x._1._2 == x._2._2)
+ }
+
+ val prefixspan = new PrefixSpan()
+ .setMinSupport(0.33)
+ .setMaxPatternLength(50)
+ val result1 = prefixspan.run(rdd)
+ val expectedValue1 = Array(
+ (Array(1), 4L),
+ (Array(1, 3), 2L),
+ (Array(1, 3, 4), 2L),
+ (Array(1, 3, 4, 5), 2L),
+ (Array(1, 3, 5), 2L),
+ (Array(1, 4), 2L),
+ (Array(1, 4, 5), 2L),
+ (Array(1, 5), 2L),
+ (Array(2), 2L),
+ (Array(2, 1), 2L),
+ (Array(3), 5L),
+ (Array(3, 1), 2L),
+ (Array(3, 3), 2L),
+ (Array(3, 4), 3L),
+ (Array(3, 4, 5), 2L),
+ (Array(3, 5), 2L),
+ (Array(4), 4L),
+ (Array(4, 5), 2L),
+ (Array(5), 3L)
+ )
+ assert(compareResult(expectedValue1, result1.collect()))
+
+ prefixspan.setMinSupport(0.5).setMaxPatternLength(50)
+ val result2 = prefixspan.run(rdd)
+ val expectedValue2 = Array(
+ (Array(1), 4L),
+ (Array(3), 5L),
+ (Array(3, 4), 3L),
+ (Array(4), 4L),
+ (Array(5), 3L)
+ )
+ assert(compareResult(expectedValue2, result2.collect()))
+
+ prefixspan.setMinSupport(0.33).setMaxPatternLength(2)
+ val result3 = prefixspan.run(rdd)
+ val expectedValue3 = Array(
+ (Array(1), 4L),
+ (Array(1, 3), 2L),
+ (Array(1, 4), 2L),
+ (Array(1, 5), 2L),
+ (Array(2, 1), 2L),
+ (Array(2), 2L),
+ (Array(3), 5L),
+ (Array(3, 1), 2L),
+ (Array(3, 3), 2L),
+ (Array(3, 4), 3L),
+ (Array(3, 5), 2L),
+ (Array(4), 4L),
+ (Array(4, 5), 2L),
+ (Array(5), 3L)
+ )
+ assert(compareResult(expectedValue3, result3.collect()))
+ }
+}
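
A design note on compareResult: Scala arrays compare by reference, which is why the suite sorts and compares string renderings. An equivalent check could normalize each side to a Set of (Seq, count) pairs, since Seq has structural equality. A sketch, not part of the patch:

def toComparable(result: Array[(Array[Int], Long)]): Set[(Seq[Int], Long)] =
  result.map { case (pattern, count) => (pattern.toSeq, count) }.toSet

assert(toComparable(expectedValue1) == toComparable(result1.collect()))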