about summary refs log tree commit diff
path: root/mllib/src
diff options
context:
space:
mode:
author    Yanbo Liang <ybliang8@gmail.com>    2016-04-13 13:18:02 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2016-04-13 13:18:02 -0700
commit    b0adb9f543fbac16ea14c64eef6ba032a9919039 (patch)
tree      9da12072bdf70d2b1c70e15b13c2f02e1a098d96 /mllib/src
parent    dbbe149070052af5cda04f7b110d65de73766ded (diff)
download  spark-b0adb9f543fbac16ea14c64eef6ba032a9919039.tar.gz
          spark-b0adb9f543fbac16ea14c64eef6ba032a9919039.tar.bz2
          spark-b0adb9f543fbac16ea14c64eef6ba032a9919039.zip
[SPARK-10386][MLLIB] PrefixSpanModel supports save/load
```PrefixSpanModel``` supports ```save/load```. It's similar to #9267. cc jkbradley Author: Yanbo Liang <ybliang8@gmail.com> Closes #10664 from yanboliang/spark-10386.
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala       | 96
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java | 37
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala   | 31
3 files changed, 163 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
index 4455681e50..4344ab1bad 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/PrefixSpan.scala
@@ -23,12 +23,22 @@ import java.util.concurrent.atomic.AtomicInteger
import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
+
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.internal.Logging
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
/**
@@ -566,4 +576,88 @@ object PrefixSpan extends Logging {
@Since("1.5.0")
class PrefixSpanModel[Item] @Since("1.5.0") (
@Since("1.5.0") val freqSequences: RDD[PrefixSpan.FreqSequence[Item]])
- extends Serializable
+ extends Saveable with Serializable {
+
+ /**
+ * Save this model to the given path.
+ * It only works for Item datatypes supported by DataFrames.
+ *
+ * This saves:
+ * - human-readable (JSON) model metadata to path/metadata/
+ * - Parquet formatted data to path/data/
+ *
+ * The model may be loaded using [[PrefixSpanModel.load]].
+ *
+ * @param sc Spark context used to save model data.
+ * @param path Path specifying the directory in which to save this model.
+ * If the directory already exists, this method throws an exception.
+ */
+ @Since("2.0.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ PrefixSpanModel.SaveLoadV1_0.save(this, path)
+ }
+
+ override protected val formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object PrefixSpanModel extends Loader[PrefixSpanModel[_]] {
+
+ @Since("2.0.0")
+ override def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
+ PrefixSpanModel.SaveLoadV1_0.load(sc, path)
+ }
+
+ private[fpm] object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ private val thisClassName = "org.apache.spark.mllib.fpm.PrefixSpanModel"
+
+ def save(model: PrefixSpanModel[_], path: String): Unit = {
+ val sc = model.freqSequences.sparkContext
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Get the type of item class
+ val sample = model.freqSequences.first().sequence(0)(0)
+ val className = sample.getClass.getCanonicalName
+ val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
+ val tpe = classSymbol.selfType
+
+ val itemType = ScalaReflection.schemaFor(tpe).dataType
+ val fields = Array(StructField("sequence", ArrayType(ArrayType(itemType))),
+ StructField("freq", LongType))
+ val schema = StructType(fields)
+ val rowDataRDD = model.freqSequences.map { x =>
+ Row(x.sequence, x.freq)
+ }
+ sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): PrefixSpanModel[_] = {
+ implicit val formats = DefaultFormats
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+
+ val freqSequences = sqlContext.read.parquet(Loader.dataPath(path))
+ val sample = freqSequences.select("sequence").head().get(0)
+ loadImpl(freqSequences, sample)
+ }
+
+ def loadImpl[Item: ClassTag](freqSequences: DataFrame, sample: Item): PrefixSpanModel[Item] = {
+ val freqSequencesRDD = freqSequences.select("sequence", "freq").rdd.map { x =>
+ val sequence = x.getAs[Seq[Seq[Item]]](0).map(_.toArray).toArray
+ val freq = x.getLong(1)
+ new PrefixSpan.FreqSequence(sequence, freq)
+ }
+ new PrefixSpanModel(freqSequencesRDD)
+ }
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
index 34daf5fbde..8a67793abc 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaPrefixSpanSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.fpm;
+import java.io.File;
import java.util.Arrays;
import java.util.List;
@@ -28,6 +29,7 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.fpm.PrefixSpan.FreqSequence;
+import org.apache.spark.util.Utils;
public class JavaPrefixSpanSuite {
private transient JavaSparkContext sc;
@@ -64,4 +66,39 @@ public class JavaPrefixSpanSuite {
long freq = freqSeq.freq();
}
}
+
+ @Test
+ public void runPrefixSpanSaveLoad() {
+ JavaRDD<List<List<Integer>>> sequences = sc.parallelize(Arrays.asList(
+ Arrays.asList(Arrays.asList(1, 2), Arrays.asList(3)),
+ Arrays.asList(Arrays.asList(1), Arrays.asList(3, 2), Arrays.asList(1, 2)),
+ Arrays.asList(Arrays.asList(1, 2), Arrays.asList(5)),
+ Arrays.asList(Arrays.asList(6))
+ ), 2);
+ PrefixSpan prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5);
+ PrefixSpanModel<Integer> model = prefixSpan.run(sequences);
+
+ File tempDir = Utils.createTempDir(
+ System.getProperty("java.io.tmpdir"), "JavaPrefixSpanSuite");
+ String outputPath = tempDir.getPath();
+
+ try {
+ model.save(sc.sc(), outputPath);
+ PrefixSpanModel newModel = PrefixSpanModel.load(sc.sc(), outputPath);
+ JavaRDD<FreqSequence<Integer>> freqSeqs = newModel.freqSequences().toJavaRDD();
+ List<FreqSequence<Integer>> localFreqSeqs = freqSeqs.collect();
+ Assert.assertEquals(5, localFreqSeqs.size());
+ // Check that each frequent sequence could be materialized.
+ for (PrefixSpan.FreqSequence<Integer> freqSeq: localFreqSeqs) {
+ List<List<Integer>> seq = freqSeq.javaSequence();
+ long freq = freqSeq.freq();
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+
+
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
index a83e543859..6d8c7b47d8 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/PrefixSpanSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -357,6 +358,36 @@ class PrefixSpanSuite extends SparkFunSuite with MLlibTestSparkContext {
compareResults(expected, model.freqSequences.collect())
}
+ test("model save/load") {
+ val sequences = Seq(
+ Array(Array(1, 2), Array(3)),
+ Array(Array(1), Array(3, 2), Array(1, 2)),
+ Array(Array(1, 2), Array(5)),
+ Array(Array(6)))
+ val rdd = sc.parallelize(sequences, 2).cache()
+
+ val prefixSpan = new PrefixSpan()
+ .setMinSupport(0.5)
+ .setMaxPatternLength(5)
+ val model = prefixSpan.run(rdd)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model.save(sc, path)
+ val newModel = PrefixSpanModel.load(sc, path)
+ val originalSet = model.freqSequences.collect().map { x =>
+ (x.sequence.map(_.toSet).toSeq, x.freq)
+ }.toSet
+ val newSet = newModel.freqSequences.collect().map { x =>
+ (x.sequence.map(_.toSet).toSeq, x.freq)
+ }.toSet
+ assert(originalSet === newSet)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
private def compareResults[Item](
expectedValue: Array[(Array[Array[Item]], Long)],
actualValue: Array[PrefixSpan.FreqSequence[Item]]): Unit = {