aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorzero323 <zero323@users.noreply.github.com>2017-03-14 07:34:44 -0700
committerJoseph K. Bradley <joseph@databricks.com>2017-03-14 07:34:44 -0700
commitd4a637cd46b6dd5cc71ea17a55c4a26186e592c7 (patch)
tree376c8eddf64281e0bb049ea4bc1a462a78a161bd /mllib
parent5e96a57b2f383d4b33735681b41cd3ec06570671 (diff)
downloadspark-d4a637cd46b6dd5cc71ea17a55c4a26186e592c7.tar.gz
spark-d4a637cd46b6dd5cc71ea17a55c4a26186e592c7.tar.bz2
spark-d4a637cd46b6dd5cc71ea17a55c4a26186e592c7.zip
[SPARK-19940][ML][MINOR] FPGrowthModel.transform should skip duplicated items
## What changes were proposed in this pull request? This commit moves `distinct` to its intended place to avoid duplicated predictions and adds a unit test covering the issue. ## How was this patch tested? Unit tests. Author: zero323 <zero323@users.noreply.github.com> Closes #17283 from zero323/SPARK-19940.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala14
2 files changed, 16 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
index 417968d9b8..fa39dd954a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/fpm/FPGrowth.scala
@@ -245,10 +245,10 @@ class FPGrowthModel private[ml] (
rule._2.filter(item => !itemset.contains(item))
} else {
Seq.empty
- })
+ }).distinct
} else {
Seq.empty
- }.distinct }, dt)
+ }}, dt)
dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 076d55c180..910d4b07d1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -103,6 +103,20 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
FPGrowthSuite.allParamSettings, checkModelData)
}
+ test("FPGrowth prediction should not contain duplicates") {
+ // This should generate the rules 1 -> 3 and 2 -> 3
+ val dataset = spark.createDataFrame(Seq(
+ Array("1", "3"),
+ Array("2", "3")
+ ).map(Tuple1(_))).toDF("features")
+ val model = new FPGrowth().fit(dataset)
+
+ val prediction = model.transform(
+ spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
+ ).first().getAs[Seq[String]]("prediction")
+
+ assert(prediction === Seq("3"))
+ }
}
object FPGrowthSuite {