author     Yuhao Yang <yuhao.yang@intel.com>    2017-04-04 17:51:45 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2017-04-04 17:51:45 -0700
commit     b28bbffbadf7ebc4349666e8f17111f6fca18c9a (patch)
tree       fd1eea16ecd6599fb1ba984f4b0232024780041d /mllib/src
parent     a59759e6c059617b2fc8102cbf41acc5d409b34a (diff)
download   spark-b28bbffbadf7ebc4349666e8f17111f6fca18c9a.tar.gz
           spark-b28bbffbadf7ebc4349666e8f17111f6fca18c9a.tar.bz2
           spark-b28bbffbadf7ebc4349666e8f17111f6fca18c9a.zip
[SPARK-20003][ML] FPGrowthModel setMinConfidence should affect rules generation and transform
## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-20003

I was doing some testing and found the issue: ml.fpm.FPGrowthModel `setMinConfidence` should always affect rules generation and transform. Currently associationRules in FPGrowthModel is a lazy val, so `setMinConfidence` has no effect once associationRules has been computed. I cache the associationRules to avoid re-computation when `minConfidence` is unchanged, but this makes FPGrowthModel somewhat stateful. Let me know if there's any concern.

## How was this patch tested?

New unit test; the existing unit test for model save/load was also strengthened to verify the caching mechanism.

Author: Yuhao Yang <yuhao.yang@intel.com>

Closes #17336 from hhbyyh/fpmodelminconf.
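For reference, a minimal usage sketch of the behavior this change enables (not part of the patch; `itemsDF` is a hypothetical DataFrame with an array column named "items", created from an active SparkSession):

```scala
import org.apache.spark.ml.fpm.FPGrowth

// itemsDF: DataFrame[items: array<string>] -- hypothetical input data.
val model = new FPGrowth()
  .setItemsCol("items")
  .setMinSupport(0.1)
  .setMinConfidence(0.1)
  .fit(itemsDF)

val looseRules = model.associationRules   // generated with minConfidence = 0.1
model.setMinConfidence(0.9)
val strictRules = model.associationRules  // re-generated with minConfidence = 0.9
model.setMinConfidence(0.9)
val sameRules = model.associationRules    // threshold unchanged, cached rules reused
```

Before this patch, `associationRules` was a lazy val, so the second and third reads above would have kept returning the rules computed with the original threshold.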
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala      | 21
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala | 56
2 files changed, 56 insertions, 21 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 65cc806195..d604c1ac00 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -218,13 +218,28 @@ class FPGrowthModel private[ml] (
def setPredictionCol(value: String): this.type = set(predictionCol, value)
/**
- * Get association rules fitted by AssociationRules using the minConfidence. Returns a dataframe
+ * Cache minConfidence and associationRules to avoid redundant computation for association rules
+ * during transform. The associationRules will only be re-computed when minConfidence changed.
+ */
+ @transient private var _cachedMinConf: Double = Double.NaN
+
+ @transient private var _cachedRules: DataFrame = _
+
+ /**
+ * Get association rules fitted using the minConfidence. Returns a dataframe
* with three fields, "antecedent", "consequent" and "confidence", where "antecedent" and
* "consequent" are Array[T] and "confidence" is Double.
*/
@Since("2.2.0")
- @transient lazy val associationRules: DataFrame = {
- AssociationRules.getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
+ @transient def associationRules: DataFrame = {
+ if ($(minConfidence) == _cachedMinConf) {
+ _cachedRules
+ } else {
+ _cachedRules = AssociationRules
+ .getAssociationRulesFromFP(freqItemsets, "items", "freq", $(minConfidence))
+ _cachedMinConf = $(minConfidence)
+ _cachedRules
+ }
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 4603a618d2..6bec057511 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ml.fpm
import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
import org.apache.spark.sql.functions._
@@ -85,38 +85,58 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
assert(prediction.select("prediction").where("id=3").first().getSeq[String](0).isEmpty)
}
+ test("FPGrowth prediction should not contain duplicates") {
+ // This should generate rule 1 -> 3, 2 -> 3
+ val dataset = spark.createDataFrame(Seq(
+ Array("1", "3"),
+ Array("2", "3")
+ ).map(Tuple1(_))).toDF("items")
+ val model = new FPGrowth().fit(dataset)
+
+ val prediction = model.transform(
+ spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
+ ).first().getAs[Seq[String]]("prediction")
+
+ assert(prediction === Seq("3"))
+ }
+
+ test("FPGrowthModel setMinConfidence should affect rules generation and transform") {
+ val model = new FPGrowth().setMinSupport(0.1).setMinConfidence(0.1).fit(dataset)
+ val oldRulesNum = model.associationRules.count()
+ val oldPredict = model.transform(dataset)
+
+ model.setMinConfidence(0.8765)
+ assert(oldRulesNum > model.associationRules.count())
+ assert(!model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
+
+ // association rules should stay the same for same minConfidence
+ model.setMinConfidence(0.1)
+ assert(oldRulesNum === model.associationRules.count())
+ assert(model.transform(dataset).collect().toSet.equals(oldPredict.collect().toSet))
+ }
+
test("FPGrowth parameter check") {
val fpGrowth = new FPGrowth().setMinSupport(0.4567)
val model = fpGrowth.fit(dataset)
.setMinConfidence(0.5678)
assert(fpGrowth.getMinSupport === 0.4567)
assert(model.getMinConfidence === 0.5678)
+ MLTestingUtils.checkCopy(model)
}
test("read/write") {
def checkModelData(model: FPGrowthModel, model2: FPGrowthModel): Unit = {
- assert(model.freqItemsets.sort("items").collect() ===
- model2.freqItemsets.sort("items").collect())
+ assert(model.freqItemsets.collect().toSet.equals(
+ model2.freqItemsets.collect().toSet))
+ assert(model.associationRules.collect().toSet.equals(
+ model2.associationRules.collect().toSet))
+ assert(model.setMinConfidence(0.9).associationRules.collect().toSet.equals(
+ model2.setMinConfidence(0.9).associationRules.collect().toSet))
}
val fPGrowth = new FPGrowth()
testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings,
FPGrowthSuite.allParamSettings, checkModelData)
}
-
- test("FPGrowth prediction should not contain duplicates") {
- // This should generate rule 1 -> 3, 2 -> 3
- val dataset = spark.createDataFrame(Seq(
- Array("1", "3"),
- Array("2", "3")
- ).map(Tuple1(_))).toDF("items")
- val model = new FPGrowth().fit(dataset)
-
- val prediction = model.transform(
- spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("items")
- ).first().getAs[Seq[String]]("prediction")
-
- assert(prediction === Seq("3"))
- }
}
object FPGrowthSuite {