author     Feynman Liang <fliang@databricks.com>    2015-07-07 11:34:30 -0700
committer  Xiangrui Meng <meng@databricks.com>      2015-07-07 11:34:30 -0700
commit     3336c7b148ad543d1f9b64ca2b559ea04930f5be (patch)
tree       5834336b9b0f9db2b26e236fa3fa2713e001140a /mllib
parent     70beb808e13f6371968ac87f7cf625ed110375e6 (diff)
[SPARK-8559] [MLLIB] Support Association Rule Generation
Distributed generation of single-consequent association rules from an RDD of frequent itemsets. Tests are referenced against `R`'s implementation of Apriori in [arules](http://cran.r-project.org/web/packages/arules/index.html).

Author: Feynman Liang <fliang@databricks.com>

Closes #7005 from feynmanliang/fp-association-rules-distributed and squashes the following commits:

466ced0 [Feynman Liang] Refactor AR generation impl
73c1cff [Feynman Liang] Make rule attributes public, remove numTransactions from FreqItemset
80f63ff [Feynman Liang] Change default confidence and optimize imports
04cf5b5 [Feynman Liang] Code review with @mengxr, add R to tests
0cc1a6a [Feynman Liang] Java compatibility test
f3c14b5 [Feynman Liang] Fix MiMa test
764375e [Feynman Liang] Fix tests
1187307 [Feynman Liang] Almost working tests
b20779b [Feynman Liang] Working implementation
5395c4e [Feynman Liang] Fix imports
2d34405 [Feynman Liang] Partial implementation of distributed ar
83ace4b [Feynman Liang] Local rule generation without pruning complete
69c2c87 [Feynman Liang] Working local implementation, now to parallelize..
4e1ec9a [Feynman Liang] Pull FreqItemsets out, refactor type param, tests
69ccedc [Feynman Liang] First implementation of association rule generation
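For orientation, here is a minimal usage sketch of the new API (not part of the patch): it assumes an existing `SparkContext` named `sc`, reuses the toy transactions from the R verification snippet in the test suite below, and picks illustrative thresholds.

```scala
import org.apache.spark.mllib.fpm.{AssociationRules, FPGrowth}
import org.apache.spark.rdd.RDD

// Mine frequent itemsets first, then derive single-consequent rules from them.
val transactions: RDD[Array[String]] = sc.parallelize(Seq(
  "r z h k p", "z y x w v u t s", "s x o n r",
  "x z y m t s q e", "z", "x z y r q t p").map(_.split(" ")))

// Keep itemsets that appear in at least half of the transactions.
val model = new FPGrowth()
  .setMinSupport(0.5)
  .run(transactions)

// Generate rules such as {x, z} => {y}; 0.8 is the default confidence threshold.
val rules = new AssociationRules()
  .setMinConfidence(0.8)
  .run(model.freqItemsets)

rules.collect().foreach { rule =>
  println(rule.antecedent.mkString("{", ",", "}") + " => " +
    rule.consequent.mkString("{", ",", "}") + ", confidence = " + rule.confidence)
}
```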
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala         108
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala                   2
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java   58
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java            5
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala     89
5 files changed, 258 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
new file mode 100644
index 0000000000..4a0f842f33
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.mllib.fpm.AssociationRules.Rule
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
+import org.apache.spark.rdd.RDD
+
+/**
+ * :: Experimental ::
+ *
+ * Generates association rules from an [[RDD[FreqItemset[Item]]]]. This method only generates
+ * association rules which have a single item as the consequent.
+ */
+@Experimental
+class AssociationRules private (
+    private var minConfidence: Double) extends Logging with Serializable {
+
+  /**
+   * Constructs a default instance with default parameters {minConfidence = 0.8}.
+   */
+  def this() = this(0.8)
+
+  /**
+   * Sets the minimal confidence (default: `0.8`).
+   */
+  def setMinConfidence(minConfidence: Double): this.type = {
+    this.minConfidence = minConfidence
+    this
+  }
+
+  /**
+   * Computes the association rules with confidence at or above [[minConfidence]].
+   * @param freqItemsets frequent itemset model obtained from [[FPGrowth]]
+   * @return an [[RDD[Rule[Item]]]] containing the association rules.
+   */
+  def run[Item: ClassTag](freqItemsets: RDD[FreqItemset[Item]]): RDD[Rule[Item]] = {
+    // For candidate rule X => Y, generate (X, (Y, freq(X union Y)))
+    val candidates = freqItemsets.flatMap { itemset =>
+      val items = itemset.items
+      items.flatMap { item =>
+        items.partition(_ == item) match {
+          case (consequent, antecedent) if !antecedent.isEmpty =>
+            Some((antecedent.toSeq, (consequent.toSeq, itemset.freq)))
+          case _ => None
+        }
+      }
+    }
+
+    // Join to get (X, ((Y, freq(X union Y)), freq(X))), generate rules, and filter by confidence
+    candidates.join(freqItemsets.map(x => (x.items.toSeq, x.freq)))
+      .map { case (antecedent, ((consequent, freqUnion), freqAntecedent)) =>
+        new Rule(antecedent.toArray, consequent.toArray, freqUnion, freqAntecedent)
+      }.filter(_.confidence >= minConfidence)
+  }
+
+  /** Java-friendly version of [[run]]. */
+  def run[Item](freqItemsets: JavaRDD[FreqItemset[Item]]): JavaRDD[Rule[Item]] = {
+    val tag = fakeClassTag[Item]
+    run(freqItemsets.rdd)(tag)
+  }
+}
+
+object AssociationRules {
+
+  /**
+   * :: Experimental ::
+   *
+   * An association rule between sets of items.
+   * @param antecedent hypotheses of the rule
+   * @param consequent conclusion of the rule
+   * @tparam Item item type
+   */
+  @Experimental
+  class Rule[Item] private[mllib] (
+      val antecedent: Array[Item],
+      val consequent: Array[Item],
+      freqUnion: Double,
+      freqAntecedent: Double) extends Serializable {
+
+    /** Confidence of the rule: freq(antecedent union consequent) / freq(antecedent). */
+    def confidence: Double = freqUnion.toDouble / freqAntecedent
+
+    require(antecedent.toSet.intersect(consequent.toSet).isEmpty, {
+      val sharedItems = antecedent.toSet.intersect(consequent.toSet)
+      s"A valid association rule must have disjoint antecedent and " +
+        s"consequent but ${sharedItems} is present in both."
+    })
+  }
+}
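As a sanity check on the confidence definition in `Rule` above (confidence(X => Y) = freq(X union Y) / freq(X)), here is a small illustrative sketch, again assuming an existing `SparkContext` named `sc`. The frequencies come from the Scala test suite added later in this commit; the 0.75 and 0.60 values are hand arithmetic, not output quoted from the patch.

```scala
import org.apache.spark.mllib.fpm.{AssociationRules, FPGrowth}

// freq({x}) = 4, freq({z}) = 5, freq({x, z}) = 3, so:
//   confidence({x} => {z}) = 3 / 4 = 0.75
//   confidence({z} => {x}) = 3 / 5 = 0.60
val freqItemsets = sc.parallelize(Seq(
  new FPGrowth.FreqItemset(Array("x"), 4L),
  new FPGrowth.FreqItemset(Array("z"), 5L),
  new FPGrowth.FreqItemset(Array("x", "z"), 3L)))

// With minConfidence = 0.7, only {x} => {z} survives the filter.
val rules = new AssociationRules()
  .setMinConfidence(0.7)
  .run(freqItemsets)

rules.collect().foreach { r =>
  println(r.antecedent.mkString("{", ",", "}") + " => " +
    r.consequent.mkString("{", ",", "}") + ", confidence = " + r.confidence)
}
```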
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index efa8459d3c..0da59e812d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -28,7 +28,7 @@ import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
-import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset
+import org.apache.spark.mllib.fpm.FPGrowth._
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
new file mode 100644
index 0000000000..b3815ae603
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaAssociationRulesSuite.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm;
+
+import java.io.Serializable;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import com.google.common.collect.Lists;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
+
+
+public class JavaAssociationRulesSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() {
+    sc = new JavaSparkContext("local", "JavaFPGrowth");
+  }
+
+  @After
+  public void tearDown() {
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void runAssociationRules() {
+
+    @SuppressWarnings("unchecked")
+    JavaRDD<FPGrowth.FreqItemset<String>> freqItemsets = sc.parallelize(Lists.newArrayList(
+      new FreqItemset<String>(new String[] {"a"}, 15L),
+      new FreqItemset<String>(new String[] {"b"}, 35L),
+      new FreqItemset<String>(new String[] {"a", "b"}, 12L)
+    ));
+
+    // This test only exercises Java compatibility of the API; the generated rules are not asserted on.
+    JavaRDD<AssociationRules.Rule<String>> results = (new AssociationRules()).run(freqItemsets);
+  }
+}
+
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index bd0edf2b9e..9ce2c52dca 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -29,7 +29,6 @@ import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.mllib.fpm.FPGrowth.FreqItemset;
public class JavaFPGrowthSuite implements Serializable {
  private transient JavaSparkContext sc;
@@ -62,10 +61,10 @@ public class JavaFPGrowthSuite implements Serializable {
      .setNumPartitions(2)
      .run(rdd);
-    List<FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
+    List<FPGrowth.FreqItemset<String>> freqItemsets = model.freqItemsets().toJavaRDD().collect();
    assertEquals(18, freqItemsets.size());
-    for (FreqItemset<String> itemset: freqItemsets) {
+    for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
      // Test return types.
      List<String> items = itemset.javaItems();
      long freq = itemset.freq();
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
new file mode 100644
index 0000000000..77a2773c36
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/AssociationRulesSuite.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.mllib.fpm
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class AssociationRulesSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("association rules using String type") {
+ val freqItemsets = sc.parallelize(Seq(
+ (Set("s"), 3L), (Set("z"), 5L), (Set("x"), 4L), (Set("t"), 3L), (Set("y"), 3L),
+ (Set("r"), 3L),
+ (Set("x", "z"), 3L), (Set("t", "y"), 3L), (Set("t", "x"), 3L), (Set("s", "x"), 3L),
+ (Set("y", "x"), 3L), (Set("y", "z"), 3L), (Set("t", "z"), 3L),
+ (Set("y", "x", "z"), 3L), (Set("t", "x", "z"), 3L), (Set("t", "y", "z"), 3L),
+ (Set("t", "y", "x"), 3L),
+ (Set("t", "y", "x", "z"), 3L)
+ ).map {
+ case (items, freq) => new FPGrowth.FreqItemset(items.toArray, freq)
+ })
+
+ val ar = new AssociationRules()
+
+ val results1 = ar
+ .setMinConfidence(0.9)
+ .run(freqItemsets)
+ .collect()
+
+ /* Verify results using the `R` code:
+ transactions = as(sapply(
+ list("r z h k p",
+ "z y x w v u t s",
+ "s x o n r",
+ "x z y m t s q e",
+ "z",
+ "x z y r q t p"),
+ FUN=function(x) strsplit(x," ",fixed=TRUE)),
+ "transactions")
+ ars = apriori(transactions,
+ parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
+ arsDF = as(ars, "data.frame")
+ arsDF$support = arsDF$support * length(transactions)
+ names(arsDF)[names(arsDF) == "support"] = "freq"
+ > nrow(arsDF)
+ [1] 23
+ > sum(arsDF$confidence == 1)
+ [1] 23
+ */
+ assert(results1.size === 23)
+ assert(results1.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+
+ val results2 = ar
+ .setMinConfidence(0)
+ .run(freqItemsets)
+ .collect()
+
+ /* Verify results using the `R` code:
+ ars = apriori(transactions,
+ parameter = list(support = 0.5, confidence = 0.5, target="rules", minlen=2))
+ arsDF = as(ars, "data.frame")
+ arsDF$support = arsDF$support * length(transactions)
+ names(arsDF)[names(arsDF) == "support"] = "freq"
+ nrow(arsDF)
+ sum(arsDF$confidence == 1)
+ > nrow(arsDF)
+ [1] 30
+ > sum(arsDF$confidence == 1)
+ [1] 23
+ */
+ assert(results2.size === 30)
+ assert(results2.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+ }
+}
+