Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java | 41
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala   | 70
2 files changed, 111 insertions(+), 0 deletions(-)
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 154f75d75e..eeeabfe359 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -17,6 +17,7 @@
 
 package org.apache.spark.mllib.fpm;
 
+import java.io.File;
 import java.io.Serializable;
 import java.util.Arrays;
 import java.util.List;
@@ -28,6 +29,7 @@ import static org.junit.Assert.*;
 
 import org.apache.spark.api.java.JavaRDD;
 import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.util.Utils;
 
 public class JavaFPGrowthSuite implements Serializable {
   private transient JavaSparkContext sc;
@@ -69,4 +71,43 @@ public class JavaFPGrowthSuite implements Serializable {
       long freq = itemset.freq();
     }
   }
+
+  @Test
+  public void runFPGrowthSaveLoad() {
+
+    @SuppressWarnings("unchecked")
+    JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+      Arrays.asList("r z h k p".split(" ")),
+      Arrays.asList("z y x w v u t s".split(" ")),
+      Arrays.asList("s x o n r".split(" ")),
+      Arrays.asList("x z y m t s q e".split(" ")),
+      Arrays.asList("z".split(" ")),
+      Arrays.asList("x z y r q t p".split(" "))), 2);
+
+    FPGrowthModel<String> model = new FPGrowth()
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd);
+
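+    // Save the fitted model to a temp directory, reload it, and verify below.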
+    File tempDir = Utils.createTempDir(
+      System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite");
+    String outputPath = tempDir.getPath();
+
+    try {
+      model.save(sc.sc(), outputPath);
+      FPGrowthModel newModel = FPGrowthModel.load(sc.sc(), outputPath);
+      List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
+        .collect();
+      assertEquals(18, freqItemsets.size());
+
+      for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+        // Test return types.
+        List<String> items = itemset.javaItems();
+        long freq = itemset.freq();
+      }
+    } finally {
+      Utils.deleteRecursively(tempDir);
+    }
+  }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 4a9bfdb348..b9e997c207 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
 
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
 
 class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
 
@@ -274,4 +275,73 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
      */
    assert(model1.freqItemsets.count() === 65)
   }
+
+  test("model save/load with String type") {
+    val transactions = Seq(
+      "r z h k p",
+      "z y x w v u t s",
+      "s x o n r",
+      "x z y m t s q e",
+      "z",
+      "x z y r q t p")
+      .map(_.split(" "))
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val model3 = new FPGrowth()
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd)
+    val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+      (itemset.items.toSet, itemset.freq)
+    }
+
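+    // Save to a temporary directory, reload, and compare itemsets as sets.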
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model3.save(sc, path)
+      val newModel = FPGrowthModel.load(sc, path)
+      val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+        (itemset.items.toSet, itemset.freq)
+      }
+      assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
+
+  test("model save/load with Int type") {
+    val transactions = Seq(
+      "1 2 3",
+      "1 2 3 4",
+      "5 4 3 2 1",
+      "6 5 4 3 2 1",
+      "2 4",
+      "1 3",
+      "1 7")
+      .map(_.split(" ").map(_.toInt).toArray)
+    val rdd = sc.parallelize(transactions, 2).cache()
+
+    val model3 = new FPGrowth()
+      .setMinSupport(0.5)
+      .setNumPartitions(2)
+      .run(rdd)
+    val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+      (itemset.items.toSet, itemset.freq)
+    }
+
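+    // Round-trip through save/load and check the itemsets are unchanged.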
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+    try {
+      model3.save(sc, path)
+      val newModel = FPGrowthModel.load(sc, path)
+      val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+        (itemset.items.toSet, itemset.freq)
+      }
+      assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+  }
 }
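
Note for readers: the persistence API exercised by these tests is used the same way outside of test code. A minimal Scala sketch, assuming an active SparkContext `sc`; the sample transactions and the /tmp output path below are illustrative only, not part of this commit:

    import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
    import org.apache.spark.rdd.RDD

    // Each transaction is an array of items; FPGrowth is generic in the item type.
    val transactions: RDD[Array[String]] =
      sc.parallelize(Seq("a b c", "a b", "b c").map(_.split(" ")))
    val model: FPGrowthModel[String] = new FPGrowth()
      .setMinSupport(0.5)
      .run(transactions)

    // save() writes the model to a Hadoop-supported path; load() restores it.
    model.save(sc, "/tmp/fpGrowthModel")
    val sameModel = FPGrowthModel.load(sc, "/tmp/fpGrowthModel")
    sameModel.freqItemsets.collect().foreach { itemset =>
      println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq)
    }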