From b0adb9f543fbac16ea14c64eef6ba032a9919039 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 13 Apr 2016 13:18:02 -0700 Subject: [SPARK-10386][MLLIB] PrefixSpanModel supports save/load ```PrefixSpanModel``` supports ```save/load```. It's similar with #9267. cc jkbradley Author: Yanbo Liang Closes #10664 from yanboliang/spark-10386. --- .../org/apache/spark/mllib/fpm/PrefixSpan.scala | 96 +++++++++++++++++++++- .../spark/mllib/fpm/JavaPrefixSpanSuite.java | 37 +++++++++ .../apache/spark/mllib/fpm/PrefixSpanSuite.scala | 31 +++++++ 3 files changed, 163 insertions(+), 1 deletion(-) (limited to 'mllib/src') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala index 4455681e50..4344ab1bad 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala @@ -23,12 +23,22 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.JavaConverters._ import scala.collection.mutable import scala.reflect.ClassTag +import scala.reflect.runtime.universe._ +import org.json4s.DefaultFormats +import org.json4s.JsonDSL._ +import org.json4s.jackson.JsonMethods.{compact, render} + +import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.api.java.JavaRDD import org.apache.spark.api.java.JavaSparkContext.fakeClassTag import org.apache.spark.internal.Logging +import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD +import org.apache.spark.sql.{DataFrame, Row, SQLContext} +import org.apache.spark.sql.catalyst.ScalaReflection +import org.apache.spark.sql.types._ import org.apache.spark.storage.StorageLevel /** @@ -566,4 +576,88 @@ object PrefixSpan extends Logging { @Since("1.5.0") class PrefixSpanModel[Item] @Since("1.5.0") ( @Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]]) - extends Serializable + extends Saveable with Serializable { + + /** + * Save this model to the given path. + * It only works for Item datatypes supported by DataFrames. + * + * This saves: + * - human-readable (JSON) model metadata to path/metadata/ + * - Parquet formatted data to path/data/ + * + * The model may be loaded using [[PrefixSpanModel.load]]. + * + * @param sc Spark context used to save model data. + * @param path Path specifying the directory in which to save this model. + * If the directory already exists, this method throws an exception. + */ + @Since("2.0.0") + override def save(sc: SparkContext, path: String): Unit = { + PrefixSpanModel.SaveLoadV1_0.save(this, path) + } + + override protected val formatVersion: String = "1.0" +} + +@Since("2.0.0") +object PrefixSpanModel extends Loader[PrefixSpanModel[_]] { + + @Since("2.0.0") + override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { + PrefixSpanModel.SaveLoadV1_0.load(sc, path) + } + + private[fpm] object SaveLoadV1_0 { + + private val thisFormatVersion = "1.0" + + private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel" + + def save(model: PrefixSpanModel[_], path: String): Unit = { + val sc = model.freqSequences.sparkContext + val sqlContext = SQLContext.getOrCreate(sc) + + val metadata = compact(render( + ("class" -> thisClassName) ~ ("version" -> thisFormatVersion))) + sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path)) + + // Get the type of item class + val sample = model.freqSequences.first().sequence(0)(0) + val className = sample.getClass.getCanonicalName + val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className) + val tpe = classSymbol.selfType + + val itemType = ScalaReflection.schemaFor(tpe).dataType + val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))), + StructField("freq", LongType)) + val schema = StructType(fields) + val rowDataRDD = model.freqSequences.map { x => + Row(x.sequence, x.freq) + } + sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path)) + } + + def load(sc: SparkContext, path: String): PrefixSpanModel[_] = { + implicit val formats = DefaultFormats + val sqlContext = SQLContext.getOrCreate(sc) + + val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path) + assert(className == thisClassName) + assert(formatVersion == thisFormatVersion) + + val freqSequences = sqlContext.read.parquet(Loader.dataPath(path)) + val sample = freqSequences.select("sequence").head().get(0) + loadImpl(freqSequences, sample) + } + + def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = { + val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x => + val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray + val freq = x.getLong(1) + new PrefixSpan.FreqSequence(sequence, freq) + } + new PrefixSpanModel(freqSequencesRDD) + } + } +} diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java index 34daf5fbde..8a67793abc 100644 --- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java +++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java @@ -17,6 +17,7 @@ package org.apache.spark.mllib.fpm; +import java.io.File; import java.util.Arrays; import java.util.List; @@ -28,6 +29,7 @@ import org.junit.Test; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence; +import org.apache.spark.util.Utils; public class JavaPrefixSpanSuite { private transient JavaSparkContext sc; @@ -64,4 +66,39 @@ public class JavaPrefixSpanSuite { long freq = freqSeq.freq(); } } + + @Test + public void runPrefixSpanSaveLoad() { + JavaRDD>> sequences = sc.parallelize(Arrays.asList( + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)), + Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)), + Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)), + Arrays.asList(Arrays.asList(6)) + ), 2); + PrefixSpan prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5); + PrefixSpanModel model = prefixSpan.run(sequences); + + File tempDir = Utils.createTempDir( + System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite"); + String outputPath = tempDir.getPath(); + + try { + model.save(sc.sc(), outputPath); + PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath); + JavaRDD> freqSeqs = newModel.freqSequences().toJavaRDD(); + List> localFreqSeqs = freqSeqs.collect(); + Assert.assertEquals(5, localFreqSeqs.size()); + // Check that each frequent sequence could be materialized. + for (PrefixSpan.FreqSequence freqSeq: localFreqSeqs) { + List> seq = freqSeq.javaSequence(); + long freq = freqSeq.freq(); + } + } finally { + Utils.deleteRecursively(tempDir); + } + + + } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala index a83e543859..6d8c7b47d8 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.util.Utils class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext { compareResults(expected, model.freqSequences.collect()) } + test("model save/load") { + val sequences = Seq( + Array(Array(1, 2), Array(3)), + Array(Array(1), Array(3, 2), Array(1, 2)), + Array(Array(1, 2), Array(5)), + Array(Array(6))) + val rdd = sc.parallelize(sequences, 2).cache() + + val prefixSpan = new PrefixSpan() + .setMinSupport(0.5) + .setMaxPatternLength(5) + val model = prefixSpan.run(rdd) + + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + try { + model.save(sc, path) + val newModel = PrefixSpanModel.load(sc, path) + val originalSet = model.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + val newSet = newModel.freqSequences.collect().map { x => + (x.sequence.map(_.toSet).toSeq, x.freq) + }.toSet + assert(originalSet === newSet) + } finally { + Utils.deleteRecursively(tempDir) + } + } + private def compareResults[Item]( expectedValue: Array[(Array[Array[Item]], Long)], actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = { -- cgit v1.2.3