aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala68
1 files changed, 68 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 4a9bfdb348..b9e997c207 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -274,4 +275,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
assert(model1.freqItemsets.count() === 65)
}
+
+ test("model save/load with String type") {
+ val transactions = Seq(
+ "r z h k p",
+ "z y x w v u t s",
+ "s x o n r",
+ "x z y m t s q e",
+ "z",
+ "x z y r q t p")
+ .map(_.split(" "))
+ val rdd = sc.parallelize(transactions, 2).cache()
+
+ val model3 = new FPGrowth()
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model3.save(sc, path)
+ val newModel = FPGrowthModel.load(sc, path)
+ val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+ assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ test("model save/load with Int type") {
+ val transactions = Seq(
+ "1 2 3",
+ "1 2 3 4",
+ "5 4 3 2 1",
+ "6 5 4 3 2 1",
+ "2 4",
+ "1 3",
+ "1 7")
+ .map(_.split(" ").map(_.toInt).toArray)
+ val rdd = sc.parallelize(transactions, 2).cache()
+
+ val model3 = new FPGrowth()
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model3.save(sc, path)
+ val newModel = FPGrowthModel.load(sc, path)
+ val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+ assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}