author    Jacky Li <jacky.likun@huawei.com>    2015-02-01 20:07:25 -0800
committer Xiangrui Meng <meng@databricks.com>  2015-02-01 20:07:25 -0800
commit    859f7249a614c86fc1691cc3116463f85f33f153 (patch)
tree      7f16495e4023248f5620b5454f582070b4bdf68f /mllib/src
parent    d85cd4eb1479f8d37dab360530dc2c71216b4a8d (diff)
[SPARK-4001][MLlib] adding parallel FP-Growth algorithm for frequent pattern mining in MLlib
Apriori is the classic algorithm for frequent itemset mining in a transactional data set, and it would be useful to have such an algorithm in MLlib in Spark. This PR adds an implementation. There is one point I am not sure is most efficient: in order to filter out the eligible frequent itemsets, I am currently using a cartesian operation on two RDDs to calculate the degree of support of each itemset; I am not sure whether it would be better to use a broadcast variable to achieve the same. I will add a usage example for this algorithm if required.

Author: Jacky Li <jacky.likun@huawei.com>
Author: Jacky Li <jackylk@users.noreply.github.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #2847 from jackylk/apriori and squashes the following commits:

bee3093 [Jacky Li] Merge pull request #1 from mengxr/SPARK-4001
7e69725 [Xiangrui Meng] simplify FPTree and update FPGrowth
ec21f7d [Jacky Li] fix scalastyle
93f3280 [Jacky Li] create FPTree class
d110ab2 [Jacky Li] change test case to use MLlibTestSparkContext
a6c5081 [Jacky Li] Add Parallel FPGrowth algorithm
eb3e4ca [Jacky Li] add FPGrowth
03df2b6 [Jacky Li] refactory according to comments
7b77ad7 [Jacky Li] fix scalastyle check
f68a0bd [Jacky Li] add 2 apriori implemenation and fp-growth implementation
889b33f [Jacky Li] modify per scalastyle check
da2cba7 [Jacky Li] adding apriori algorithm for frequent item set mining in Spark
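As a quick reference, here is a minimal usage sketch of the API added by this commit. It is not part of the patch; the input path and the space-separated transaction format are assumptions for illustration.

    import org.apache.spark.SparkContext
    import org.apache.spark.mllib.fpm.FPGrowth

    val sc = new SparkContext("local", "FPGrowthExample")
    // Hypothetical input file: one transaction per line, items separated by spaces.
    val transactions = sc.textFile("data/transactions.txt").map(_.split(" ")).cache()

    val model = new FPGrowth()
      .setMinSupport(0.3)   // keep patterns appearing in at least 30% of transactions
      .setNumPartitions(4)  // number of partitions for the parallel mining step
      .run(transactions)

    model.freqItemsets.collect().foreach { case (items, count) =>
      println(items.mkString("[", ",", "]") + ": " + count)
    }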
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala        162
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala          134
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala    73
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala     115
4 files changed, 484 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
new file mode 100644
index 0000000000..9591c7966e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -0,0 +1,162 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm
+
+import java.{util => ju}
+
+import scala.collection.mutable
+
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+
+/**
+ * This class implements the Parallel FP-growth algorithm for frequent pattern mining on input
+ * data. Parallel FP-growth (PFP) partitions the computation in such a way that each machine
+ * executes an independent group of mining tasks. More details on this algorithm can be found at
+ * [[http://dx.doi.org/10.1145/1454008.1454027, PFP]], and the original FP-growth paper can be
+ * found at [[http://dx.doi.org/10.1145/335191.335372, FP-growth]].
+ *
+ * @param minSupport the minimal support level of a frequent pattern; any pattern that appears
+ *                   more than (minSupport * size-of-the-dataset) times will be output
+ * @param numPartitions number of partitions used by parallel FP-growth
+ */
+class FPGrowth private (
+ private var minSupport: Double,
+ private var numPartitions: Int) extends Logging with Serializable {
+
+ /**
+ * Constructs an FPGrowth instance with default parameters:
+ * {minSupport: 0.3, numPartitions: auto}
+ */
+ def this() = this(0.3, -1)
+
+ /**
+ * Sets the minimal support level (default: 0.3).
+ */
+ def setMinSupport(minSupport: Double): this.type = {
+ this.minSupport = minSupport
+ this
+ }
+
+ /**
+ * Sets the number of partitions used by parallel FP-growth (default: same as input data).
+ */
+ def setNumPartitions(numPartitions: Int): this.type = {
+ this.numPartitions = numPartitions
+ this
+ }
+
+ /**
+ * Computes an FP-Growth model that contains frequent itemsets.
+ * @param data input data set, where each element contains a transaction
+ * @return an [[FPGrowthModel]]
+ */
+ def run(data: RDD[Array[String]]): FPGrowthModel = {
+ if (data.getStorageLevel == StorageLevel.NONE) {
+ logWarning("Input data is not cached.")
+ }
+ val count = data.count()
+ val minCount = math.ceil(minSupport * count).toLong
+ val numParts = if (numPartitions > 0) numPartitions else data.partitions.length
+ val partitioner = new HashPartitioner(numParts)
+ val freqItems = genFreqItems(data, minCount, partitioner)
+ val freqItemsets = genFreqItemsets(data, minCount, freqItems, partitioner)
+ new FPGrowthModel(freqItemsets)
+ }
+
+ /**
+ * Generates frequent items by filtering the input data using the minimal support level.
+ * @param minCount minimum count for frequent itemsets
+ * @param partitioner partitioner used to distribute items
+ * @return array of frequent items ordered by descending frequency
+ */
+ private def genFreqItems(
+ data: RDD[Array[String]],
+ minCount: Long,
+ partitioner: Partitioner): Array[String] = {
+ data.flatMap { t =>
+ val uniq = t.toSet
+ if (t.length != uniq.size) {
+ throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
+ }
+ t
+ }.map(v => (v, 1L))
+ .reduceByKey(partitioner, _ + _)
+ .filter(_._2 >= minCount)
+ .collect()
+ .sortBy(-_._2)
+ .map(_._1)
+ }
+
+ /**
+ * Generates frequent itemsets by building FP-Trees; the extraction is done on each partition.
+ * @param data transactions
+ * @param minCount minimum count for frequent itemsets
+ * @param freqItems frequent items
+ * @param partitioner partitioner used to distribute transactions
+ * @return an RDD of (frequent itemset, count)
+ */
+ private def genFreqItemsets(
+ data: RDD[Array[String]],
+ minCount: Long,
+ freqItems: Array[String],
+ partitioner: Partitioner): RDD[(Array[String], Long)] = {
+ val itemToRank = freqItems.zipWithIndex.toMap
+ data.flatMap { transaction =>
+ genCondTransactions(transaction, itemToRank, partitioner)
+ }.aggregateByKey(new FPTree[Int], partitioner.numPartitions)(
+ (tree, transaction) => tree.add(transaction, 1L),
+ (tree1, tree2) => tree1.merge(tree2))
+ .flatMap { case (part, tree) =>
+ tree.extract(minCount, x => partitioner.getPartition(x) == part)
+ }.map { case (ranks, count) =>
+ (ranks.map(i => freqItems(i)).toArray, count)
+ }
+ }
+
+ /**
+ * Generates conditional transactions.
+ * @param transaction a transaction
+ * @param itemToRank map from items to their ranks
+ * @param partitioner partitioner used to distribute transactions
+ * @return a map of (target partition, conditional transaction)
+ */
+ private def genCondTransactions(
+ transaction: Array[String],
+ itemToRank: Map[String, Int],
+ partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
+ val output = mutable.Map.empty[Int, Array[Int]]
+ // Keep only frequent items, map them to their ranks, and sort the ranks.
+ val filtered = transaction.flatMap(itemToRank.get)
+ ju.Arrays.sort(filtered)
+ val n = filtered.length
+ var i = n - 1
+ while (i >= 0) {
+ val item = filtered(i)
+ val part = partitioner.getPartition(item)
+ if (!output.contains(part)) {
+ output(part) = filtered.slice(0, i + 1)
+ }
+ i -= 1
+ }
+ output
+ }
+}
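A note on genCondTransactions above: for a transaction whose surviving items are sorted by rank, each distinct target partition receives exactly one conditional transaction, namely the prefix ending at the last (highest-rank) item that hashes to that partition. A standalone sketch of this scheme, with made-up ranks and a modulo stand-in for the HashPartitioner:

    // Illustration only; the ranks and the modulo "partitioner" are assumptions.
    val ranks = Array(0, 1, 3, 4)   // a filtered transaction, sorted by rank
    val numParts = 2
    val output = scala.collection.mutable.Map.empty[Int, Array[Int]]
    var i = ranks.length - 1
    while (i >= 0) {
      val part = ranks(i) % numParts    // stand-in for partitioner.getPartition
      if (!output.contains(part)) {     // keep only the longest prefix per partition
        output(part) = ranks.slice(0, i + 1)
      }
      i -= 1
    }
    // output == Map(0 -> Array(0, 1, 3, 4), 1 -> Array(0, 1, 3))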
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala
new file mode 100644
index 0000000000..1d2d777c00
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPTree.scala
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm
+
+import scala.collection.mutable
+import scala.collection.mutable.ListBuffer
+
+/**
+ * FP-Tree data structure used in FP-Growth.
+ * @tparam T item type
+ */
+private[fpm] class FPTree[T] extends Serializable {
+
+ import FPTree._
+
+ val root: Node[T] = new Node(null)
+
+ private val summaries: mutable.Map[T, Summary[T]] = mutable.Map.empty
+
+ /** Adds a transaction with count. */
+ def add(t: Iterable[T], count: Long = 1L): this.type = {
+ require(count > 0)
+ var curr = root
+ curr.count += count
+ t.foreach { item =>
+ val summary = summaries.getOrElseUpdate(item, new Summary)
+ summary.count += count
+ val child = curr.children.getOrElseUpdate(item, {
+ val newNode = new Node(curr)
+ newNode.item = item
+ summary.nodes += newNode
+ newNode
+ })
+ child.count += count
+ curr = child
+ }
+ this
+ }
+
+ /** Merges another FP-Tree. */
+ def merge(other: FPTree[T]): this.type = {
+ other.transactions.foreach { case (t, c) =>
+ add(t, c)
+ }
+ this
+ }
+
+  /** Returns the projected (conditional) tree for the given suffix item. */
+ private def project(suffix: T): FPTree[T] = {
+ val tree = new FPTree[T]
+ if (summaries.contains(suffix)) {
+ val summary = summaries(suffix)
+ summary.nodes.foreach { node =>
+ var t = List.empty[T]
+ var curr = node.parent
+ while (!curr.isRoot) {
+ t = curr.item :: t
+ curr = curr.parent
+ }
+ tree.add(t, node.count)
+ }
+ }
+ tree
+ }
+
+ /** Returns all transactions in an iterator. */
+ def transactions: Iterator[(List[T], Long)] = getTransactions(root)
+
+ /** Returns all transactions under this node. */
+ private def getTransactions(node: Node[T]): Iterator[(List[T], Long)] = {
+ var count = node.count
+ node.children.iterator.flatMap { case (item, child) =>
+ getTransactions(child).map { case (t, c) =>
+ count -= c
+ (item :: t, c)
+ }
+ } ++ {
+ if (count > 0) {
+ Iterator.single((Nil, count))
+ } else {
+ Iterator.empty
+ }
+ }
+ }
+
+  /** Extracts all patterns whose suffix is valid and whose count is at least minCount. */
+ def extract(
+ minCount: Long,
+ validateSuffix: T => Boolean = _ => true): Iterator[(List[T], Long)] = {
+ summaries.iterator.flatMap { case (item, summary) =>
+ if (validateSuffix(item) && summary.count >= minCount) {
+ Iterator.single((item :: Nil, summary.count)) ++
+ project(item).extract(minCount).map { case (t, c) =>
+ (item :: t, c)
+ }
+ } else {
+ Iterator.empty
+ }
+ }
+ }
+}
+
+private[fpm] object FPTree {
+
+  /** Represents a node in an FP-Tree. */
+ class Node[T](val parent: Node[T]) extends Serializable {
+ var item: T = _
+ var count: Long = 0L
+ val children: mutable.Map[T, Node[T]] = mutable.Map.empty
+
+ def isRoot: Boolean = parent == null
+ }
+
+  /** Summary of an item in an FP-Tree. */
+ private class Summary[T] extends Serializable {
+ var count: Long = 0L
+ val nodes: ListBuffer[Node[T]] = ListBuffer.empty
+ }
+}
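Since FPTree is private[fpm], it can only be exercised from inside the org.apache.spark.mllib.fpm package, as the test suites below do. A minimal sketch of the add/extract cycle, assuming the snippet lives in that package:

    val tree = new FPTree[String]
      .add(Seq("a", "b"))
      .add(Seq("a", "b", "c"))
      .add(Seq("a"))

    // Extract all itemsets that appear at least twice.
    tree.extract(2L).foreach { case (items, count) =>
      println(items.mkString(",") + " -> " + count)
    }
    // Yields, in some order: a -> 3, b -> 2, b,a -> 2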
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
new file mode 100644
index 0000000000..71ef60da6d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
+
+ test("FP-Growth") {
+ val transactions = Seq(
+ "r z h k p",
+ "z y x w v u t s",
+ "s x o n r",
+ "x z y m t s q e",
+ "z",
+ "x z y r q t p")
+ .map(_.split(" "))
+ val rdd = sc.parallelize(transactions, 2).cache()
+
+ val fpg = new FPGrowth()
+
+ val model6 = fpg
+ .setMinSupport(0.9)
+ .setNumPartitions(1)
+ .run(rdd)
+ assert(model6.freqItemsets.count() === 0)
+
+ val model3 = fpg
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd)
+ val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
+ (items.toSet, count)
+ }
+ val expected = Set(
+ (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
+ (Set("r"), 3L),
+ (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
+ (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
+ (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
+ (Set("t", "y", "x"), 3L),
+ (Set("t", "y", "x", "z"), 3L))
+ assert(freqItemsets3.toSet === expected)
+
+ val model2 = fpg
+ .setMinSupport(0.3)
+ .setNumPartitions(4)
+ .run(rdd)
+ assert(model2.freqItemsets.count() === 54)
+
+ val model1 = fpg
+ .setMinSupport(0.1)
+ .setNumPartitions(8)
+ .run(rdd)
+ assert(model1.freqItemsets.count() === 625)
+ }
+}
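The thresholds in this suite follow from run(), which converts the relative minSupport into an absolute minCount via math.ceil(minSupport * count). A small sketch of how each setting above maps onto the 6 transactions in the test:

    // How each minSupport in the suite maps to an absolute count threshold.
    val numTransactions = 6L
    Seq(0.9, 0.5, 0.3, 0.1).foreach { s =>
      println(s"minSupport=$s -> minCount=${math.ceil(s * numTransactions).toLong}")
    }
    // 0.9 -> 6 (no itemset qualifies), 0.5 -> 3, 0.3 -> 2, 0.1 -> 1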
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
new file mode 100644
index 0000000000..04017f67c3
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPTreeSuite.scala
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm
+
+import scala.language.existentials
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class FPTreeSuite extends FunSuite with MLlibTestSparkContext {
+
+ test("add transaction") {
+ val tree = new FPTree[String]
+ .add(Seq("a", "b", "c"))
+ .add(Seq("a", "b", "y"))
+ .add(Seq("b"))
+
+ assert(tree.root.children.size == 2)
+ assert(tree.root.children.contains("a"))
+ assert(tree.root.children("a").item.equals("a"))
+ assert(tree.root.children("a").count == 2)
+ assert(tree.root.children.contains("b"))
+ assert(tree.root.children("b").item.equals("b"))
+ assert(tree.root.children("b").count == 1)
+ var child = tree.root.children("a")
+ assert(child.children.size == 1)
+ assert(child.children.contains("b"))
+ assert(child.children("b").item.equals("b"))
+ assert(child.children("b").count == 2)
+ child = child.children("b")
+ assert(child.children.size == 2)
+ assert(child.children.contains("c"))
+ assert(child.children.contains("y"))
+ assert(child.children("c").item.equals("c"))
+ assert(child.children("y").item.equals("y"))
+ assert(child.children("c").count == 1)
+ assert(child.children("y").count == 1)
+ }
+
+ test("merge tree") {
+ val tree1 = new FPTree[String]
+ .add(Seq("a", "b", "c"))
+ .add(Seq("a", "b", "y"))
+ .add(Seq("b"))
+
+ val tree2 = new FPTree[String]
+ .add(Seq("a", "b"))
+ .add(Seq("a", "b", "c"))
+ .add(Seq("a", "b", "c", "d"))
+ .add(Seq("a", "x"))
+ .add(Seq("a", "x", "y"))
+ .add(Seq("c", "n"))
+ .add(Seq("c", "m"))
+
+ val tree3 = tree1.merge(tree2)
+
+ assert(tree3.root.children.size == 3)
+ assert(tree3.root.children("a").count == 7)
+ assert(tree3.root.children("b").count == 1)
+ assert(tree3.root.children("c").count == 2)
+ val child1 = tree3.root.children("a")
+ assert(child1.children.size == 2)
+ assert(child1.children("b").count == 5)
+ assert(child1.children("x").count == 2)
+ val child2 = child1.children("b")
+ assert(child2.children.size == 2)
+ assert(child2.children("y").count == 1)
+ assert(child2.children("c").count == 3)
+ val child3 = child2.children("c")
+ assert(child3.children.size == 1)
+ assert(child3.children("d").count == 1)
+ val child4 = child1.children("x")
+ assert(child4.children.size == 1)
+ assert(child4.children("y").count == 1)
+ val child5 = tree3.root.children("c")
+ assert(child5.children.size == 2)
+ assert(child5.children("n").count == 1)
+ assert(child5.children("m").count == 1)
+ }
+
+ test("extract freq itemsets") {
+ val tree = new FPTree[String]
+ .add(Seq("a", "b", "c"))
+ .add(Seq("a", "b", "y"))
+ .add(Seq("a", "b"))
+ .add(Seq("a"))
+ .add(Seq("b"))
+ .add(Seq("b", "n"))
+
+ val freqItemsets = tree.extract(3L).map { case (items, count) =>
+ (items.toSet, count)
+ }.toSet
+ val expected = Set(
+ (Set("a"), 4L),
+ (Set("b"), 5L),
+ (Set("a", "b"), 3L))
+ assert(freqItemsets === expected)
+ }
+}