author     Jacky Li <jacky.likun@huawei.com>      2015-02-03 17:02:42 -0800
committer  Xiangrui Meng <meng@databricks.com>    2015-02-03 17:02:42 -0800
commit     e380d2d46c92b319eafe30974ac7c1509081fca4 (patch)
tree       687991ded2c9e13324585d8160397fc89ca75478 /mllib
parent     068c0e2ee05ee8b133c2dc26b8fa094ab2712d45 (diff)
[SPARK-5520][MLlib] Make FP-Growth implementation take generic item types (WIP)
Make the FPGrowth.run API take generic item types:

    def run[Item: ClassTag, Basket <: Iterable[Item]](data: RDD[Basket]): FPGrowthModel[Item]

so that users can invoke it as run[String, Seq[String]], run[Int, Seq[Int]], run[Int, List[Int]], etc. The Scala part is done, while the Java part is still in progress.

Author: Jacky Li <jacky.likun@huawei.com>
Author: Jacky Li <jackylk@users.noreply.github.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #4340 from jackylk/SPARK-5520-WIP and squashes the following commits:

f5acf84 [Jacky Li] Merge pull request #2 from mengxr/SPARK-5520
63073d0 [Xiangrui Meng] update to make generic FPGrowth Java-friendly
737d8bb [Jacky Li] fix scalastyle
793f85c [Jacky Li] add Java test case
7783351 [Jacky Li] add generic support in FPGrowth
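As a quick illustration of the new generic API (a minimal sketch: `sc` is assumed to be an existing SparkContext and the transaction values are made up; the test suites in this diff exercise the same calls):

    import org.apache.spark.mllib.fpm.FPGrowth

    // Transactions may now hold any item type, e.g. Int instead of String.
    val transactions = Seq(
      Array(1, 2, 3),
      Array(1, 2, 3, 4),
      Array(2, 4))
    val rdd = sc.parallelize(transactions, 2).cache()

    val model = new FPGrowth()
      .setMinSupport(0.5)
      .setNumPartitions(2)
      .run(rdd)            // Item is inferred as Int

    model.freqItemsets.collect().foreach { case (itemset, count) =>
      println(itemset.mkString("{", ",", "}") + ": " + count)
    }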
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala         50
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java  84
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala    51
3 files changed, 170 insertions(+), 15 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 9591c7966e..1433ee9a0d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -18,14 +18,31 @@
package org.apache.spark.mllib.fpm
import java.{util => ju}
+import java.lang.{Iterable => JavaIterable}
import scala.collection.mutable
+import scala.collection.JavaConverters._
+import scala.reflect.ClassTag
-import org.apache.spark.{SparkException, HashPartitioner, Logging, Partitioner}
+import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
+import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
-class FPGrowthModel(val freqItemsets: RDD[(Array[String], Long)]) extends Serializable
+/**
+ * Model trained by [[FPGrowth]], which holds frequent itemsets.
+ * @param freqItemsets frequent itemset, which is an RDD of (itemset, frequency) pairs
+ * @tparam Item item type
+ */
+class FPGrowthModel[Item: ClassTag](
+ val freqItemsets: RDD[(Array[Item], Long)]) extends Serializable {
+
+ /** Returns frequent itemsets as a [[org.apache.spark.api.java.JavaPairRDD]]. */
+ def javaFreqItemsets(): JavaPairRDD[Array[Item], java.lang.Long] = {
+ JavaPairRDD.fromRDD(freqItemsets).asInstanceOf[JavaPairRDD[Array[Item], java.lang.Long]]
+ }
+}
/**
* This class implements Parallel FP-growth algorithm to do frequent pattern matching on input data.
@@ -69,7 +86,7 @@ class FPGrowth private (
* @param data input data set, each element contains a transaction
* @return an [[FPGrowthModel]]
*/
- def run(data: RDD[Array[String]]): FPGrowthModel = {
+ def run[Item: ClassTag](data: RDD[Array[Item]]): FPGrowthModel[Item] = {
if (data.getStorageLevel == StorageLevel.NONE) {
logWarning("Input data is not cached.")
}
@@ -82,19 +99,24 @@ class FPGrowth private (
new FPGrowthModel(freqItemsets)
}
+ def run[Item, Basket <: JavaIterable[Item]](data: JavaRDD[Basket]): FPGrowthModel[Item] = {
+ implicit val tag = fakeClassTag[Item]
+ run(data.rdd.map(_.asScala.toArray))
+ }
+
/**
* Generates frequent items by filtering the input data using minimal support level.
* @param minCount minimum count for frequent itemsets
* @param partitioner partitioner used to distribute items
* @return array of frequent pattern ordered by their frequencies
*/
- private def genFreqItems(
- data: RDD[Array[String]],
+ private def genFreqItems[Item: ClassTag](
+ data: RDD[Array[Item]],
minCount: Long,
- partitioner: Partitioner): Array[String] = {
+ partitioner: Partitioner): Array[Item] = {
data.flatMap { t =>
val uniq = t.toSet
- if (t.length != uniq.size) {
+ if (t.size != uniq.size) {
throw new SparkException(s"Items in a transaction must be unique but got ${t.toSeq}.")
}
t
@@ -114,11 +136,11 @@ class FPGrowth private (
* @param partitioner partitioner used to distribute transactions
* @return an RDD of (frequent itemset, count)
*/
- private def genFreqItemsets(
- data: RDD[Array[String]],
+ private def genFreqItemsets[Item: ClassTag](
+ data: RDD[Array[Item]],
minCount: Long,
- freqItems: Array[String],
- partitioner: Partitioner): RDD[(Array[String], Long)] = {
+ freqItems: Array[Item],
+ partitioner: Partitioner): RDD[(Array[Item], Long)] = {
val itemToRank = freqItems.zipWithIndex.toMap
data.flatMap { transaction =>
genCondTransactions(transaction, itemToRank, partitioner)
@@ -139,9 +161,9 @@ class FPGrowth private (
* @param partitioner partitioner used to distribute transactions
* @return a map of (target partition, conditional transaction)
*/
- private def genCondTransactions(
- transaction: Array[String],
- itemToRank: Map[String, Int],
+ private def genCondTransactions[Item: ClassTag](
+ transaction: Array[Item],
+ itemToRank: Map[Item, Int],
partitioner: Partitioner): mutable.Map[Int, Array[Int]] = {
val output = mutable.Map.empty[Int, Array[Int]]
// Filter the basket by frequent items pattern and sort their ranks.
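To make the role of genCondTransactions above more concrete, here is a minimal standalone sketch of the Parallel FP-Growth step it implements, with a modulo in place of partitioner.getPartition (the helper name condTransactions and the toy partitioning are illustrative, not part of the patch):

    import scala.collection.mutable

    def condTransactions(
        transaction: Array[String],
        itemToRank: Map[String, Int],
        numPartitions: Int): mutable.Map[Int, Array[Int]] = {
      val output = mutable.Map.empty[Int, Array[Int]]
      // Keep only frequent items, replaced by their ranks (rank 0 = most
      // frequent), and sort so that more frequent items come first.
      val filtered = transaction.flatMap(itemToRank.get).sorted
      // Walk from the least frequent item: each item's target partition
      // receives the prefix ending at that item, at most once.
      var i = filtered.length - 1
      while (i >= 0) {
        val part = filtered(i) % numPartitions // stand-in for partitioner.getPartition
        if (!output.contains(part)) {
          output(part) = filtered.slice(0, i + 1)
        }
        i -= 1
      }
      output
    }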
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
new file mode 100644
index 0000000000..851707c8a1
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.fpm;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import com.google.common.collect.Lists;
+import static org.junit.Assert.*;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaFPGrowthSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaFPGrowth");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void runFPGrowth() {
+
+ @SuppressWarnings("unchecked")
+ JavaRDD<ArrayList<String>> rdd = sc.parallelize(Lists.newArrayList(
+ Lists.newArrayList("r z h k p".split(" ")),
+ Lists.newArrayList("z y x w v u t s".split(" ")),
+ Lists.newArrayList("s x o n r".split(" ")),
+ Lists.newArrayList("x z y m t s q e".split(" ")),
+ Lists.newArrayList("z".split(" ")),
+ Lists.newArrayList("x z y r q t p".split(" "))), 2);
+
+ FPGrowth fpg = new FPGrowth();
+
+ FPGrowthModel<String> model6 = fpg
+ .setMinSupport(0.9)
+ .setNumPartitions(1)
+ .run(rdd);
+ assertEquals(0, model6.javaFreqItemsets().count());
+
+ FPGrowthModel<String> model3 = fpg
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd);
+ assertEquals(18, model3.javaFreqItemsets().count());
+
+ FPGrowthModel<String> model2 = fpg
+ .setMinSupport(0.3)
+ .setNumPartitions(4)
+ .run(rdd);
+ assertEquals(54, model2.javaFreqItemsets().count());
+
+ FPGrowthModel<String> model1 = fpg
+ .setMinSupport(0.1)
+ .setNumPartitions(8)
+ .run(rdd);
+ assertEquals(625, model1.javaFreqItemsets().count());
+ }
+}
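As a hypothetical follow-up to the test above (not part of the patch), the pairs returned by javaFreqItemsets() can also be read out from Java rather than just counted:

    // Assumes model3 from the test body above; collect() brings the
    // (itemset, count) pairs back to the driver as scala.Tuple2 values.
    for (scala.Tuple2<String[], Long> entry : model3.javaFreqItemsets().collect()) {
      System.out.println(java.util.Arrays.toString(entry._1()) + ": " + entry._2());
    }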
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 71ef60da6d..68128284b8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -22,7 +22,8 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
- test("FP-Growth") {
+
+ test("FP-Growth using String type") {
val transactions = Seq(
"r z h k p",
"z y x w v u t s",
@@ -70,4 +71,52 @@ class FPGrowthSuite extends FunSuite with MLlibTestSparkContext {
.run(rdd)
assert(model1.freqItemsets.count() === 625)
}
+
+ test("FP-Growth using Int type") {
+ val transactions = Seq(
+ "1 2 3",
+ "1 2 3 4",
+ "5 4 3 2 1",
+ "6 5 4 3 2 1",
+ "2 4",
+ "1 3",
+ "1 7")
+ .map(_.split(" ").map(_.toInt).toArray)
+ val rdd = sc.parallelize(transactions, 2).cache()
+
+ val fpg = new FPGrowth()
+
+ val model6 = fpg
+ .setMinSupport(0.9)
+ .setNumPartitions(1)
+ .run(rdd)
+ assert(model6.freqItemsets.count() === 0)
+
+ val model3 = fpg
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd)
+ assert(model3.freqItemsets.first()._1.getClass === Array(1).getClass,
+ "frequent itemsets should use primitive arrays")
+ val freqItemsets3 = model3.freqItemsets.collect().map { case (items, count) =>
+ (items.toSet, count)
+ }
+ val expected = Set(
+ (Set(1), 6L), (Set(2), 5L), (Set(3), 5L), (Set(4), 4L),
+ (Set(1, 2), 4L), (Set(1, 3), 5L), (Set(2, 3), 4L),
+ (Set(2, 4), 4L), (Set(1, 2, 3), 4L))
+ assert(freqItemsets3.toSet === expected)
+
+ val model2 = fpg
+ .setMinSupport(0.3)
+ .setNumPartitions(4)
+ .run(rdd)
+ assert(model2.freqItemsets.count() === 15)
+
+ val model1 = fpg
+ .setMinSupport(0.1)
+ .setNumPartitions(8)
+ .run(rdd)
+ assert(model1.freqItemsets.count() === 65)
+ }
}