aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala | 14
1 files changed, 14 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
index 076d55c180..910d4b07d1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala
@@ -103,6 +103,20 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
FPGrowthSuite.allParamSettings, checkModelData)
}
+ test("FPGrowth prediction should not contain duplicates") {
+ // Should yield rules 1 -> 3 and 2 -> 3; input ["1", "2"] matches both, so "3" must appear once.
+ val dataset = spark.createDataFrame(Seq(
+ Array("1", "3"),
+ Array("2", "3")
+ ).map(Tuple1(_))).toDF("features")
+ val model = new FPGrowth().fit(dataset)
+
+ val prediction = model.transform(
+ spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features")
+ ).first().getAs[Seq[String]]("prediction")
+
+ assert(prediction === Seq("3"))
+ }
}
object FPGrowthSuite {