diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala | 31 |
1 files changed, 31 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index a83e543859..6d8c7b47d8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { compareResults(expected, model.freqSequences.collect()) } + test("model save/load") { + val sequences = Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6))) + val rdd = sc.parallelize(sequences, 2).cache() + + val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + val model = prefixSpan.run(rdd) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model.save(sc, path) + val newModel = PrefixSpanModel.load(sc, path) + val originalSet = model.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + val newSet = newModel.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + assert(originalSet === newSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } + private def compareResults[Item]( expectedValue: Array[(Array[Array[Item]], Long)], actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = { |