aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2016-04-13 13:18:02 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2016-04-13 13:18:02 -0700
commit: b0adb9f543fbac16ea14c64eef6ba032a9919039 (patch)
tree: 9da12072bdf70d2b1c70e15b13c2f02e1a098d96 /mllib/src/test
parent: dbbe149070052af5cda04f7b110d65de73766ded (diff)
download: spark-b0adb9f543fbac16ea14c64eef6ba032a9919039.tar.gz
spark-b0adb9f543fbac16ea14c64eef6ba032a9919039.tar.bz2
spark-b0adb9f543fbac16ea14c64eef6ba032a9919039.zip
[SPARK-10386][MLLIB] PrefixSpanModel supports save/load
```PrefixSpanModel``` supports ```save/load```. It is similar to #9267. cc jkbradley Author: Yanbo Liang <ybliang8@gmail.com> Closes #10664 from yanboliang/spark-10386.
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java 37
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala 31
2 files changed, 68 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
index 34daf5fbde..8a67793abc 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.fpm;
+import java.io.File;
import java.util.Arrays;
import java.util.List;
@@ -28,6 +29,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
+import org.apache.spark.util.Utils;
public class JavaPrefixSpanSuite {
private transient JavaSparkContext sc;
@@ -64,4 +66,39 @@ public class JavaPrefixSpanSuite {
long freq = freqSeq.freq();
}
}
+
+ // Checks that a PrefixSpanModel written with save() can be read back with load().
+ @Test
+ public void runPrefixSpanSaveLoad() {
+   JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+     Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
+     Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
+     Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
+     Arrays.asList(Arrays.asList(6))
+   ), 2);
+   PrefixSpan prefixSpan = new PrefixSpan()
+     .setMinSupport(0.5)
+     .setMaxPatternLength(5);
+   PrefixSpanModel<Integer> model = prefixSpan.run(sequences);
+
+   File tempDir = Utils.createTempDir(
+     System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite");
+   String outputPath = tempDir.getPath();
+   try {
+     model.save(sc.sc(), outputPath);
+     // Typed instead of raw: the model saved above holds Integer items, so the cast is safe.
+     @SuppressWarnings("unchecked")
+     PrefixSpanModel<Integer> newModel =
+       (PrefixSpanModel<Integer>) PrefixSpanModel.load(sc.sc(), outputPath);
+     List<FreqSequence<Integer>> localFreqSeqs = newModel.freqSequences().toJavaRDD().collect();
+     Assert.assertEquals(5, localFreqSeqs.size());
+     // Check that each frequent sequence could be materialized.
+     for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+       List<List<Integer>> seq = freqSeq.javaSequence();
+       long freq = freqSeq.freq();
+     }
+   } finally {
+     Utils.deleteRecursively(tempDir);
+   }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index a83e543859..6d8c7b47d8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
compareResults(expected, model.freqSequences.collect())
}
+ test("model save/load") {
+   // Round-trips a model through save/load and compares its frequent sequences.
+   val rdd = sc.parallelize(Seq(
+     Array(Array(1, 2), Array(3)),
+     Array(Array(1), Array(3, 2), Array(1, 2)),
+     Array(Array(1, 2), Array(5)),
+     Array(Array(6))), 2).cache()
+
+   val model = new PrefixSpan()
+     .setMinSupport(0.5)
+     .setMaxPatternLength(5)
+     .run(rdd)
+
+   val tempDir = Utils.createTempDir()
+   val path = tempDir.toURI.toString
+   try {
+     model.save(sc, path)
+     val reloaded = PrefixSpanModel.load(sc, path)
+     // Itemsets are converted to Sets so that item order inside an itemset does not matter.
+     val before = (for (fs <- model.freqSequences.collect())
+       yield (fs.sequence.map(_.toSet).toSeq, fs.freq)).toSet
+     val after = (for (fs <- reloaded.freqSequences.collect())
+       yield (fs.sequence.map(_.toSet).toSeq, fs.freq)).toSet
+     assert(before === after)
+   } finally {
+     // Clean up the temporary directory whether or not the checks passed.
+     Utils.deleteRecursively(tempDir)
+   }
+ }
+
private def compareResults[Item](
expectedValue: Array[(Array[Array[Item]], Long)],
actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {