aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-05 13:31:59 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-05 13:31:59 -0800
commit13a3b636d9425c5713cd1381203ee1b60f71b8c8 (patch)
tree2b973d5901dd6e4115d11c4a90d3940e13129252 /mllib/src/test/scala/org/apache
parent047a31bb1042867b20132b347b1e08feab4562eb (diff)
downloadspark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.tar.gz
spark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.tar.bz2
spark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.zip
[SPARK-6724][MLLIB] Support model save/load for FPGrowthModel
Support model save/load for FPGrowthModel Author: Yanbo Liang <ybliang8@gmail.com> Closes #9267 from yanboliang/spark-6724.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala68
1 files changed, 68 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 4a9bfdb348..b9e997c207 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -274,4 +275,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
assert(model1.freqItemsets.count() === 65)
}
+
+ test("model save/load with String type") {
+ val transactions = Seq(
+ "r z h k p",
+ "z y x w v u t s",
+ "s x o n r",
+ "x z y m t s q e",
+ "z",
+ "x z y r q t p")
+ .map(_.split(" "))
+ val rdd = sc.parallelize(transactions, 2).cache()
+
+ val model3 = new FPGrowth()
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model3.save(sc, path)
+ val newModel = FPGrowthModel.load(sc, path)
+ val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+ assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ test("model save/load with Int type") {
+ val transactions = Seq(
+ "1 2 3",
+ "1 2 3 4",
+ "5 4 3 2 1",
+ "6 5 4 3 2 1",
+ "2 4",
+ "1 3",
+ "1 7")
+ .map(_.split(" ").map(_.toInt).toArray)
+ val rdd = sc.parallelize(transactions, 2).cache()
+
+ val model3 = new FPGrowth()
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model3.save(sc, path)
+ val newModel = FPGrowthModel.load(sc, path)
+ val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+ assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}