diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-01-05 13:31:59 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-05 13:31:59 -0800 |
commit | 13a3b636d9425c5713cd1381203ee1b60f71b8c8 (patch) | |
tree | 2b973d5901dd6e4115d11c4a90d3940e13129252 /mllib/src/test/scala/org/apache | |
parent | 047a31bb1042867b20132b347b1e08feab4562eb (diff) | |
download | spark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.tar.gz spark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.tar.bz2 spark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.zip |
[SPARK-6724][MLLIB] Support model save/load for FPGrowthModel
Support model save/load for FPGrowthModel
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9267 from yanboliang/spark-6724.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala | 68 |
1 file changed, 68 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala index 4a9bfdb348..b9e997c207 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -274,4 +275,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext { */ assert(model1.freqItemsets.count() === 65) } + + test("model save/load with String type") { + val transactions = Seq( + "r z h k p", + "z y x w v u t s", + "s x o n r", + "x z y m t s q e", + "z", + "x z y r q t p") + .map(_.split(" ")) + val rdd = sc.parallelize(transactions, 2).cache() + + val model3 = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model3.save(sc, path) + val newModel = FPGrowthModel.load(sc, path) + val newFreqItemsets = newModel.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + assert(freqItemsets3.toSet === newFreqItemsets.toSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } + + test("model save/load with Int type") { + val transactions = Seq( + "1 2 3", + "1 2 3 4", + "5 4 3 2 1", + "6 5 4 3 2 1", + "2 4", + "1 3", + "1 7") + .map(_.split(" ").map(_.toInt).toArray) + val rdd = sc.parallelize(transactions, 2).cache() + + val model3 = new FPGrowth() + .setMinSupport(0.5) + .setNumPartitions(2) + .run(rdd) + val freqItemsets3 = model3.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) 
+ } + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model3.save(sc, path) + val newModel = FPGrowthModel.load(sc, path) + val newFreqItemsets = newModel.freqItemsets.collect().map { itemset => + (itemset.items.toSet, itemset.freq) + } + assert(freqItemsets3.toSet === newFreqItemsets.toSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } } |