aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-01-05 13:31:59 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-05 13:31:59 -0800
commit13a3b636d9425c5713cd1381203ee1b60f71b8c8 (patch)
tree2b973d5901dd6e4115d11c4a90d3940e13129252 /mllib/src/test/java
parent047a31bb1042867b20132b347b1e08feab4562eb (diff)
downloadspark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.tar.gz
spark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.tar.bz2
spark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.zip
[SPARK-6724][MLLIB] Support model save/load for FPGrowthModel
Support model save/load for FPGrowthModel Author: Yanbo Liang <ybliang8@gmail.com> Closes #9267 from yanboliang/spark-6724.
Diffstat (limited to 'mllib/src/test/java')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java40
1 files changed, 40 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 154f75d75e..eeeabfe359 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.fpm;
+import java.io.File;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
@@ -28,6 +29,7 @@ import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.util.Utils;
public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -69,4 +71,42 @@ public class JavaFPGrowthSuite implements Serializable {
long freq = itemset.freq();
}
}
+
+ @Test
+ public void runFPGrowthSaveLoad() {
+
+ @SuppressWarnings("unchecked")
+ JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+ Arrays.asList("r z h k p".split(" ")),
+ Arrays.asList("z y x w v u t s".split(" ")),
+ Arrays.asList("s x o n r".split(" ")),
+ Arrays.asList("x z y m t s q e".split(" ")),
+ Arrays.asList("z".split(" ")),
+ Arrays.asList("x z y r q t p".split(" "))), 2);
+
+ FPGrowthModel<String> model = new FPGrowth()
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd);
+
+ File tempDir = Utils.createTempDir(
+ System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite");
+ String outputPath = tempDir.getPath();
+
+ try {
+ model.save(sc.sc(), outputPath);
+ FPGrowthModel newModel = FPGrowthModel.load(sc.sc(), outputPath);
+ List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
+ .collect();
+ assertEquals(18, freqItemsets.size());
+
+ for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+ // Test return types.
+ List<String> items = itemset.javaItems();
+ long freq = itemset.freq();
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ }
}