about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options: context, space, mode
authorYanbo Liang <ybliang8@gmail.com>2016-01-05 13:31:59 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-05 13:31:59 -0800
commit13a3b636d9425c5713cd1381203ee1b60f71b8c8 (patch)
tree2b973d5901dd6e4115d11c4a90d3940e13129252 /mllib
parent047a31bb1042867b20132b347b1e08feab4562eb (diff)
downloadspark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.tar.gz
spark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.tar.bz2
spark-13a3b636d9425c5713cd1381203ee1b60f71b8c8.zip
[SPARK-6724][MLLIB] Support model save/load for FPGrowthModel
Support model save/load for FPGrowthModel Author: Yanbo Liang <ybliang8@gmail.com> Closes #9267 from yanboliang/spark-6724.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala100
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java40
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala68
3 files changed, 205 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
index 70ef1ed30c..5273ed4d76 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/fpm/FPGrowth.scala
@@ -17,19 +17,29 @@
package org.apache.spark.mllib.fpm
-import java.{util => ju}
import java.lang.{Iterable => JavaIterable}
+import java.{util => ju}
-import scala.collection.mutable
import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe._
+
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods.{compact, render}
import org.apache.spark.{HashPartitioner, Logging, Partitioner, SparkException}
import org.apache.spark.annotation.Since
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
import org.apache.spark.mllib.fpm.FPGrowth._
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext
+import org.apache.spark.sql.catalyst.ScalaReflection
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.types._
import org.apache.spark.storage.StorageLevel
/**
@@ -39,7 +49,8 @@ import org.apache.spark.storage.StorageLevel
*/
@Since("1.3.0")
class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
- @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]]) extends Serializable {
+ @Since("1.3.0") val freqItemsets: RDD[FreqItemset[Item]])
+ extends Saveable with Serializable {
/**
* Generates association rules for the [[Item]]s in [[freqItemsets]].
* @param confidence minimal confidence of the rules produced
@@ -49,6 +60,89 @@ class FPGrowthModel[Item: ClassTag] @Since("1.3.0") (
val associationRules = new AssociationRules(confidence)
associationRules.run(freqItemsets)
}
+
+ /**
+ * Save this model to the given path.
+ * It only works for Item datatypes supported by DataFrames.
+ *
+ * This saves:
+ * - human-readable (JSON) model metadata to path/metadata/
+ * - Parquet formatted data to path/data/
+ *
+ * The model may be loaded using [[FPGrowthModel.load]].
+ *
+ * @param sc Spark context used to save model data.
+ * @param path Path specifying the directory in which to save this model.
+ * If the directory already exists, this method throws an exception.
+ */
+ @Since("2.0.0")
+ override def save(sc: SparkContext, path: String): Unit = {
+ FPGrowthModel.SaveLoadV1_0.save(this, path)
+ }
+
+ override protected val formatVersion: String = "1.0"
+}
+
+@Since("2.0.0")
+object FPGrowthModel extends Loader[FPGrowthModel[_]] {
+
+ @Since("2.0.0")
+ override def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
+ FPGrowthModel.SaveLoadV1_0.load(sc, path)
+ }
+
+ private[fpm] object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ private val thisClassName = "org.apache.spark.mllib.fpm.FPGrowthModel"
+
+ def save(model: FPGrowthModel[_], path: String): Unit = {
+ val sc = model.freqItemsets.sparkContext
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ val metadata = compact(render(
+ ("class" -> thisClassName) ~ ("version" -> thisFormatVersion)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ // Get the type of item class
+ val sample = model.freqItemsets.first().items(0)
+ val className = sample.getClass.getCanonicalName
+ val classSymbol = runtimeMirror(getClass.getClassLoader).staticClass(className)
+ val tpe = classSymbol.selfType
+
+ val itemType = ScalaReflection.schemaFor(tpe).dataType
+ val fields = Array(StructField("items", ArrayType(itemType)),
+ StructField("freq", LongType))
+ val schema = StructType(fields)
+ val rowDataRDD = model.freqItemsets.map { x =>
+ Row(x.items, x.freq)
+ }
+ sqlContext.createDataFrame(rowDataRDD, schema).write.parquet(Loader.dataPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): FPGrowthModel[_] = {
+ implicit val formats = DefaultFormats
+ val sqlContext = SQLContext.getOrCreate(sc)
+
+ val (className, formatVersion, metadata) = Loader.loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+
+ val freqItemsets = sqlContext.read.parquet(Loader.dataPath(path))
+ val sample = freqItemsets.select("items").head().get(0)
+ loadImpl(freqItemsets, sample)
+ }
+
+ def loadImpl[Item : ClassTag](freqItemsets: DataFrame, sample: Item): FPGrowthModel[Item] = {
+ val freqItemsetsRDD = freqItemsets.select("items", "freq").map { x =>
+ val items = x.getAs[Seq[Item]](0).toArray
+ val freq = x.getLong(1)
+ new FreqItemset(items, freq)
+ }
+ new FPGrowthModel(freqItemsetsRDD)
+ }
+ }
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
index 154f75d75e..eeeabfe359 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/fpm/JavaFPGrowthSuite.java
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.fpm;
+import java.io.File;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
@@ -28,6 +29,7 @@ import static org.junit.Assert.*;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.util.Utils;
public class JavaFPGrowthSuite implements Serializable {
private transient JavaSparkContext sc;
@@ -69,4 +71,42 @@ public class JavaFPGrowthSuite implements Serializable {
long freq = itemset.freq();
}
}
+
+ @Test
+ public void runFPGrowthSaveLoad() {
+
+ @SuppressWarnings("unchecked")
+ JavaRDD<List<String>> rdd = sc.parallelize(Arrays.asList(
+ Arrays.asList("r z h k p".split(" ")),
+ Arrays.asList("z y x w v u t s".split(" ")),
+ Arrays.asList("s x o n r".split(" ")),
+ Arrays.asList("x z y m t s q e".split(" ")),
+ Arrays.asList("z".split(" ")),
+ Arrays.asList("x z y r q t p".split(" "))), 2);
+
+ FPGrowthModel<String> model = new FPGrowth()
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd);
+
+ File tempDir = Utils.createTempDir(
+ System.getProperty("java.io.tmpdir"), "JavaFPGrowthSuite");
+ String outputPath = tempDir.getPath();
+
+ try {
+ model.save(sc.sc(), outputPath);
+ FPGrowthModel newModel = FPGrowthModel.load(sc.sc(), outputPath);
+ List<FPGrowth.FreqItemset<String>> freqItemsets = newModel.freqItemsets().toJavaRDD()
+ .collect();
+ assertEquals(18, freqItemsets.size());
+
+ for (FPGrowth.FreqItemset<String> itemset: freqItemsets) {
+ // Test return types.
+ List<String> items = itemset.javaItems();
+ long freq = itemset.freq();
+ }
+ } finally {
+ Utils.deleteRecursively(tempDir);
+ }
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
index 4a9bfdb348..b9e997c207 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/fpm/FPGrowthSuite.scala
@@ -18,6 +18,7 @@ package org.apache.spark.mllib.fpm
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.util.Utils
class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -274,4 +275,71 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext {
*/
assert(model1.freqItemsets.count() === 65)
}
+
+ test("model save/load with String type") {
+ val transactions = Seq(
+ "r z h k p",
+ "z y x w v u t s",
+ "s x o n r",
+ "x z y m t s q e",
+ "z",
+ "x z y r q t p")
+ .map(_.split(" "))
+ val rdd = sc.parallelize(transactions, 2).cache()
+
+ val model3 = new FPGrowth()
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model3.save(sc, path)
+ val newModel = FPGrowthModel.load(sc, path)
+ val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+ assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+ test("model save/load with Int type") {
+ val transactions = Seq(
+ "1 2 3",
+ "1 2 3 4",
+ "5 4 3 2 1",
+ "6 5 4 3 2 1",
+ "2 4",
+ "1 3",
+ "1 7")
+ .map(_.split(" ").map(_.toInt).toArray)
+ val rdd = sc.parallelize(transactions, 2).cache()
+
+ val model3 = new FPGrowth()
+ .setMinSupport(0.5)
+ .setNumPartitions(2)
+ .run(rdd)
+ val freqItemsets3 = model3.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ try {
+ model3.save(sc, path)
+ val newModel = FPGrowthModel.load(sc, path)
+ val newFreqItemsets = newModel.freqItemsets.collect().map { itemset =>
+ (itemset.items.toSet, itemset.freq)
+ }
+ assert(freqItemsets3.toSet === newFreqItemsets.toSet)
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}