aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorFeynman Liang <fliang@databricks.com>2015-07-08 16:27:11 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-08 16:27:11 -0700
commit8c32b2e870c7c250a63e838718df833edf6dea07 (patch)
tree238d1b23e3bcfbd976ebdbf3be29136a3613a531 /mllib
parent381cb161ba4e3a30f2da3c4ef4ee19869d51f101 (diff)
downloadspark-8c32b2e870c7c250a63e838718df833edf6dea07.tar.gz
spark-8c32b2e870c7c250a63e838718df833edf6dea07.tar.bz2
spark-8c32b2e870c7c250a63e838718df833edf6dea07.zip
[SPARK-8877] [MLLIB] Public API for association rule generation
Adds FPGrowth.generateAssociationRules to public API for generating association rules after mining frequent itemsets. Author: Feynman Liang <fliang@databricks.com> Closes #7271 from feynmanliang/SPARK-8877 and squashes the following commits: 83b8baf [Feynman Liang] Add API Doc 867abff [Feynman Liang] Add FPGrowth.generateAssociationRules and change access modifiers for AssociationRules
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala11
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala42
3 files changed, 55 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
index 4a0f842f33..7e2bbfe31c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala
@@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD
* association rules which have a single item as the consequent.
*/
@Experimental
-class AssociationRules private (
+class AssociationRules private[fpm] (
private var minConfidence: Double) extends Logging with Serializable {
/**
@@ -45,6 +45,7 @@ class AssociationRules private (
* Sets the minimal confidence (default: `0.8`).
*/
def setMinConfidence(minConfidence: Double): this.type = {
+ require(minConfidence >= 0.0 && minConfidence <= 1.0)
this.minConfidence = minConfidence
this
}
@@ -91,7 +92,7 @@ object AssociationRules {
* @tparam Item item type
*/
@Experimental
- class Rule[Item] private[mllib] (
+ class Rule[Item] private[fpm] (
val antecedent: Array[Item],
val consequent: Array[Item],
freqUnion: Double,
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 0da59e812d..9cb9a00dbd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -40,7 +40,16 @@ import org.apache.spark.storage.StorageLevel
* @tparam Item item type
*/
@Experimental
-class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable
+class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
+ /**
+ * Generates association rules for the [[Item]]s in [[freqItemsets]].
+ * @param confidence minimal confidence of the rules produced
+ */
+ def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = {
+ val associationRules = new AssociationRules(confidence)
+ associationRules.run(freqItemsets)
+ }
+}
/**
* :: Experimental ::
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index ddc296a428..4a9bfdb348 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -132,6 +132,48 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model1.freqItemsets.count() === 625)
}
+ test("FP-Growth String type association rule generation") {
+ val transactions = Seq(
+ "r z h k p",
+ "z y x w v u t s",
+ "s x o n r",
+ "x z y m t s q e",
+ "z",
+ "x z y r q t p")
+ .map(_.split(" "))
+ val rdd = sc.parallelize(transactions, 2).cache()
+
+ /* Verify results using the `R` code:
+ transactions = as(sapply(
+ list("r z h k p",
+ "z y x w v u t s",
+ "s x o n r",
+ "x z y m t s q e",
+ "z",
+ "x z y r q t p"),
+ FUN=function(x) strsplit(x," ",fixed=TRUE)),
+ "transactions")
+ ars = apriori(transactions,
+ parameter = list(support = 0.0, confidence = 0.5, target="rules", minlen=2))
+ arsDF = as(ars, "data.frame")
+ arsDF$support = arsDF$support * length(transactions)
+ names(arsDF)[names(arsDF) == "support"] = "freq"
+ > nrow(arsDF)
+ [1] 23
+ > sum(arsDF$confidence == 1)
+ [1] 23
+ */
+ val rules = (new FPGrowth())
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd)
+ .generateAssociationRules(0.9)
+ .collect()
+
+ assert(rules.size === 23)
+ assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23)
+ }
+
test("FP-Growth using Int type") {
val transactions = Seq(
"1 2 3",