aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-02-08 16:26:20 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-08 16:26:20 -0800
commit5c299c58fb9a5434a40be82150d4725bba805adf (patch)
tree8ea7856b545cd902fb30a92e14a9bf631b757936 /mllib
parent804949d519e2caa293a409d84b4e6190c1105444 (diff)
downloadspark-5c299c58fb9a5434a40be82150d4725bba805adf.tar.gz
spark-5c299c58fb9a5434a40be82150d4725bba805adf.tar.bz2
spark-5c299c58fb9a5434a40be82150d4725bba805adf.zip
[SPARK-5598][MLLIB] model save/load for ALS
following #4233. jkbradley Author: Xiangrui Meng <meng@databricks.com> Closes #4422 from mengxr/SPARK-5598 and squashes the following commits: a059394 [Xiangrui Meng] SaveLoad not extending Loader 14b7ea6 [Xiangrui Meng] address comments f487cb2 [Xiangrui Meng] add unit tests 62fc43c [Xiangrui Meng] implement save/load for MFM
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala82
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala19
3 files changed, 100 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
index 4bb28d1b1e..caacab9430 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.recommendation
import org.apache.spark.Logging
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.ml.recommendation.{ALS => NewALS}
import org.apache.spark.rdd.RDD
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index ed2f8b41bc..9ff06ac362 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -17,13 +17,17 @@
package org.apache.spark.mllib.recommendation
+import java.io.IOException
import java.lang.{Integer => JavaInteger}
+import org.apache.hadoop.fs.Path
import org.jblas.DoubleMatrix
-import org.apache.spark.Logging
+import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.api.java.{JavaPairRDD, JavaRDD}
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.storage.StorageLevel
/**
@@ -41,7 +45,8 @@ import org.apache.spark.storage.StorageLevel
class MatrixFactorizationModel(
val rank: Int,
val userFeatures: RDD[(Int, Array[Double])],
- val productFeatures: RDD[(Int, Array[Double])]) extends Serializable with Logging {
+ val productFeatures: RDD[(Int, Array[Double])])
+ extends Saveable with Serializable with Logging {
require(rank > 0)
validateFeatures("User", userFeatures)
@@ -125,6 +130,12 @@ class MatrixFactorizationModel(
recommend(productFeatures.lookup(product).head, userFeatures, num)
.map(t => Rating(t._1, product, t._2))
+ protected override val formatVersion: String = "1.0"
+
+ override def save(sc: SparkContext, path: String): Unit = {
+ MatrixFactorizationModel.SaveLoadV1_0.save(this, path)
+ }
+
private def recommend(
recommendToFeatures: Array[Double],
recommendableFeatures: RDD[(Int, Array[Double])],
@@ -136,3 +147,70 @@ class MatrixFactorizationModel(
scored.top(num)(Ordering.by(_._2))
}
}
+
+object MatrixFactorizationModel extends Loader[MatrixFactorizationModel] {
+
+ import org.apache.spark.mllib.util.Loader._
+
+ override def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+ val (loadedClassName, formatVersion, metadata) = loadMetadata(sc, path)
+ val classNameV1_0 = SaveLoadV1_0.thisClassName
+ (loadedClassName, formatVersion) match {
+ case (className, "1.0") if className == classNameV1_0 =>
+ SaveLoadV1_0.load(sc, path)
+ case _ =>
+ throw new IOException("MatrixFactorizationModel.load did not recognize model with" +
+ s"(class: $loadedClassName, version: $formatVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+
+ private[recommendation]
+ object SaveLoadV1_0 {
+
+ private val thisFormatVersion = "1.0"
+
+ private[recommendation]
+ val thisClassName = "org.apache.spark.mllib.recommendation.MatrixFactorizationModel"
+
+ /**
+ * Saves a [[MatrixFactorizationModel]], where user features are saved under `data/users` and
+ * product features are saved under `data/products`.
+ */
+ def save(model: MatrixFactorizationModel, path: String): Unit = {
+ val sc = model.userFeatures.sparkContext
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits.createDataFrame
+ val metadata = (thisClassName, thisFormatVersion, model.rank)
+ val metadataRDD = sc.parallelize(Seq(metadata), 1).toDataFrame("class", "version", "rank")
+ metadataRDD.toJSON.saveAsTextFile(metadataPath(path))
+ model.userFeatures.toDataFrame("id", "features").saveAsParquetFile(userPath(path))
+ model.productFeatures.toDataFrame("id", "features").saveAsParquetFile(productPath(path))
+ }
+
+ def load(sc: SparkContext, path: String): MatrixFactorizationModel = {
+ val sqlContext = new SQLContext(sc)
+ val (className, formatVersion, metadata) = loadMetadata(sc, path)
+ assert(className == thisClassName)
+ assert(formatVersion == thisFormatVersion)
+ val rank = metadata.select("rank").first().getInt(0)
+ val userFeatures = sqlContext.parquetFile(userPath(path))
+ .map { case Row(id: Int, features: Seq[Double]) =>
+ (id, features.toArray)
+ }
+ val productFeatures = sqlContext.parquetFile(productPath(path))
+ .map { case Row(id: Int, features: Seq[Double]) =>
+ (id, features.toArray)
+ }
+ new MatrixFactorizationModel(rank, userFeatures, productFeatures)
+ }
+
+ private def userPath(path: String): String = {
+ new Path(dataPath(path), "user").toUri.toString
+ }
+
+ private def productPath(path: String): String = {
+ new Path(dataPath(path), "product").toUri.toString
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
index b9caecc904..9801e87576 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala
@@ -22,6 +22,7 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext {
@@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext
new MatrixFactorizationModel(rank, userFeatures, prodFeatures1)
}
}
+
+ test("save/load") {
+ val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures)
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+ def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = {
+ features.mapValues(_.toSeq).collect().toSet
+ }
+ try {
+ model.save(sc, path)
+ val newModel = MatrixFactorizationModel.load(sc, path)
+ assert(newModel.rank === rank)
+ assert(collect(newModel.userFeatures) === collect(userFeatures))
+ assert(collect(newModel.productFeatures) === collect(prodFeatures))
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
}