diff options
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 076d55c180..910d4b07d1 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -103,6 +103,20 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul FPGrowthSuite.allParamSettings, checkModelData) } + test("FPGrowth prediction should not contain duplicates") { + // This should generate rule 1 -> 3, 2 -> 3 + val dataset = spark.createDataFrame(Seq( + Array("1", "3"), + Array("2", "3") + ).map(Tuple1(_))).toDF("features") + val model = new FPGrowth().fit(dataset) + + val prediction = model.transform( + spark.createDataFrame(Seq(Tuple1(Array("1", "2")))).toDF("features") + ).first().getAs[Seq[String]]("prediction") + + assert(prediction === Seq("3")) + } } object FPGrowthSuite { |