aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala31
1 files changed, 31 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index a83e543859..6d8c7b47d8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
compareResults(expected, model.freqSequences.collect())
}
+ test("model save/load") {
+ val sequences = Seq(
+ Array(Array(1, 2), Array(3)),
+ Array(Array(1), Array(3, 2), Array(1, 2)),
+ Array(Array(1, 2), Array(5)),
+ Array(Array(6)))
+ val rdd = sc.parallelize(sequences, 2).cache()
+
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ val model = prefixSpan.run(rdd)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model.save(sc, path)
+ val newModel = PrefixSpanModel.load(sc, path)
+ val originalSet = model.freqSequences.collect().map { x =>
+ (x.sequence.map(_.toSet).toSeq, x.freq)
+ }.toSet
+ val newSet = newModel.freqSequences.collect().map { x =>
+ (x.sequence.map(_.toSet).toSeq, x.freq)
+ }.toSet
+ assert(originalSet === newSet)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
private def compareResults[Item](
expectedValue: Array[(Array[Array[Item]], Long)],
actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {