From 8c32b2e870c7c250a63e838718df833edf6dea07 Mon Sep 17 00:00:00 2001 From: Feynman Liang Date: Wed, 8 Jul 2015 16:27:11 -0700 Subject: [SPARK-8877] [MLLIB] Public API for association rule generation Adds FPGrowth.generateAssociationRules to public API for generating association rules after mining frequent itemsets. Author: Feynman Liang Closes #7271 from feynmanliang/SPARK-8877 and squashes the following commits: 83b8baf [Feynman Liang] Add API Doc 867abff [Feynman Liang] Add FPGrowth.generateAssociationRules and change access modifiers for AssociationRules --- .../apache/spark/mllib/fpm/AssociationRules.scala | 5 +-- .../org/apache/spark/mllib/fpm/FPGrowth.scala | 11 +++++- .../org/apache/spark/mllib/fpm/FPGrowthSuite.scala | 42 ++++++++++++++++++++++ 3 files changed, 55 insertions(+), 3 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala index 4a0f842f33..7e2bbfe31c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/AssociationRules.scala @@ -33,7 +33,7 @@ import org.apache.spark.rdd.RDD * association rules which have a single item as the consequent. */ @Experimental -class AssociationRules private ( +class AssociationRules private[fpm] ( private var minConfidence: Double) extends Logging with Serializable { /** @@ -45,6 +45,7 @@ class AssociationRules private ( * Sets the minimal confidence (default: `0.8`). 
*/ def setMinConfidence(minConfidence: Double): this.type = { + require(minConfidence >= 0.0 && minConfidence <= 1.0) this.minConfidence = minConfidence this } @@ -91,7 +92,7 @@ object AssociationRules { * @tparam Item item type */ @Experimental - class Rule[Item] private[mllib] ( + class Rule[Item] private[fpm] ( val antecedent: Array[Item], val consequent: Array[Item], freqUnion: Double, diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala index 0da59e812d..9cb9a00dbd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala @@ -40,7 +40,16 @@ import org.apache.spark.storage.StorageLevel * @tparam Item item type */ @Experimental -class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable +class FPGrowthModel[Item: ClassTag](val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable { + /** + * Generates association rules for the [[Item]]s in [[freqItemsets]]. 
+ * @param confidence minimal confidence of the rules produced + */ + def generateAssociationRules(confidence: Double): RDD[AssociationRules.Rule[Item]] = { + val associationRules = new AssociationRules(confidence) + associationRules.run(freqItemsets) + } +} /** * :: Experimental :: diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index ddc296a428..4a9bfdb348 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -132,6 +132,48 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { assert(model1.freqItemsets.count() === 625) } + test("FP-Growth String type association rule generation") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + /* Verify results using the `R` code: + transactions = as(sapply( + list("r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p"), + FUN=function(x) strsplit(x," ",fixed=TRUE)), + "transactions") + ars = apriori(transactions, + parameter = list(support = 0.5, confidence = 0.9, target="rules", minlen=2)) + arsDF = as(ars, "data.frame") + arsDF$support = arsDF$support * length(transactions) + names(arsDF)[names(arsDF) == "support"] = "freq" + > nrow(arsDF) + [1] 23 + > sum(arsDF$confidence == 1) + [1] 23 + */ + val rules = (new FPGrowth()) + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + .generateAssociationRules(0.9) + .collect() + + assert(rules.size === 23) + assert(rules.count(rule => math.abs(rule.confidence - 1.0D) < 1e-6) == 23) + } + test("FP-Growth using Int type") { val transactions = Seq( "1 2 3", -- cgit v1.2.3