author    Joseph K. Bradley <joseph@databricks.com>  2015-11-18 13:16:31 -0800
committer Xiangrui Meng <meng@databricks.com>        2015-11-18 13:16:31 -0800
commit    2acdf10b1f3bb1242dba64efa798c672fde9f0d2 (patch)
tree      fd533147dd84fc1cd87f9c5ecf7d1fb59d8133d0 /mllib
parent    045a4f045821dcf60442f0600c2df1b79bddb536 (diff)
[SPARK-6789][ML] Add Readable, Writable support for spark.ml ALS, ALSModel
Also modifies DefaultParamsWriter.saveMetadata to take optional extra metadata.

CC: mengxr yanboliang

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9786 from jkbradley/als-io.
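For context, a minimal usage sketch of the persistence API this patch adds, assuming a Spark 1.6-era spark.ml setup with a SQLContext available; the ratings DataFrame and the /tmp paths are illustrative, not part of the patch:

    import org.apache.spark.ml.recommendation.{ALS, ALSModel}

    // ratingsDF is assumed to have the default "user", "item", and "rating" columns.
    val als = new ALS().setRank(10).setMaxIter(5)
    val model = als.fit(ratingsDF)

    // Estimator persistence via the new Writable trait.
    als.write.save("/tmp/als-estimator")
    val als2 = ALS.load("/tmp/als-estimator")

    // Model persistence: params plus rank in metadata, factors as DataFrames.
    model.write.save("/tmp/als-model")
    val model2 = ALSModel.load("/tmp/als-model")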
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala      | 75
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala          | 14
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala | 78
3 files changed, 150 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 535f266b9a..d92514d2e2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -27,13 +27,16 @@ import scala.util.hashing.byteswap64
import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.hadoop.fs.{FileSystem, Path}
+import org.json4s.{DefaultFormats, JValue}
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
import org.apache.spark.{Logging, Partitioner}
-import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.annotation.{Since, DeveloperApi, Experimental}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
-import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.CholeskyDecomposition
import org.apache.spark.mllib.optimization.NNLS
import org.apache.spark.rdd.RDD
@@ -182,7 +185,7 @@ class ALSModel private[ml] (
val rank: Int,
@transient val userFactors: DataFrame,
@transient val itemFactors: DataFrame)
- extends Model[ALSModel] with ALSModelParams {
+ extends Model[ALSModel] with ALSModelParams with Writable {
/** @group setParam */
def setUserCol(value: String): this.type = set(userCol, value)
@@ -220,8 +223,60 @@ class ALSModel private[ml] (
val copied = new ALSModel(uid, rank, userFactors, itemFactors)
copyValues(copied, extra).setParent(parent)
}
+
+ @Since("1.6.0")
+ override def write: Writer = new ALSModel.ALSModelWriter(this)
}
+@Since("1.6.0")
+object ALSModel extends Readable[ALSModel] {
+
+ @Since("1.6.0")
+ override def read: Reader[ALSModel] = new ALSModelReader
+
+ @Since("1.6.0")
+ override def load(path: String): ALSModel = read.load(path)
+
+ private[recommendation] class ALSModelWriter(instance: ALSModel) extends Writer {
+
+ override protected def saveImpl(path: String): Unit = {
+ val extraMetadata = render("rank" -> instance.rank)
+ DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
+ val userPath = new Path(path, "userFactors").toString
+ instance.userFactors.write.format("parquet").save(userPath)
+ val itemPath = new Path(path, "itemFactors").toString
+ instance.itemFactors.write.format("parquet").save(itemPath)
+ }
+ }
+
+ private[recommendation] class ALSModelReader extends Reader[ALSModel] {
+
+ /** Checked against metadata when loading model */
+ private val className = "org.apache.spark.ml.recommendation.ALSModel"
+
+ override def load(path: String): ALSModel = {
+ val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+ implicit val format = DefaultFormats
+ val rank: Int = metadata.extraMetadata match {
+ case Some(m: JValue) =>
+ (m \ "rank").extract[Int]
+ case None =>
+ throw new RuntimeException(s"ALSModel loader could not read rank from JSON metadata:" +
+ s" ${metadata.metadataStr}")
+ }
+
+ val userPath = new Path(path, "userFactors").toString
+ val userFactors = sqlContext.read.format("parquet").load(userPath)
+ val itemPath = new Path(path, "itemFactors").toString
+ val itemFactors = sqlContext.read.format("parquet").load(itemPath)
+
+ val model = new ALSModel(metadata.uid, rank, userFactors, itemFactors)
+
+ DefaultParamsReader.getAndSetParams(model, metadata)
+ model
+ }
+ }
+}
/**
* :: Experimental ::
@@ -254,7 +309,7 @@ class ALSModel private[ml] (
* preferences rather than explicit ratings given to items.
*/
@Experimental
-class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
+class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams with Writable {
import org.apache.spark.ml.recommendation.ALS.Rating
@@ -336,8 +391,12 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
}
override def copy(extra: ParamMap): ALS = defaultCopy(extra)
+
+ @Since("1.6.0")
+ override def write: Writer = new DefaultParamsWriter(this)
}
+
/**
* :: DeveloperApi ::
* An implementation of ALS that supports generic ID types, specialized for Int and Long. This is
@@ -347,7 +406,7 @@ class ALS(override val uid: String) extends Estimator[ALSModel] with ALSParams {
* than 2 billion.
*/
@DeveloperApi
-object ALS extends Logging {
+object ALS extends Readable[ALS] with Logging {
/**
* :: DeveloperApi ::
@@ -356,6 +415,12 @@ object ALS extends Logging {
@DeveloperApi
case class Rating[@specialized(Int, Long) ID](user: ID, item: ID, rating: Float)
+ @Since("1.6.0")
+ override def read: Reader[ALS] = new DefaultParamsReader[ALS]
+
+ @Since("1.6.0")
+ override def load(path: String): ALS = read.load(path)
+
/** Trait for least squares solvers applied to the normal equation. */
private[recommendation] trait LeastSquaresNESolver extends Serializable {
/** Solves a least squares problem with regularization (possibly with other constraints). */
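As a rough sketch of what ALSModelWriter above lays out on disk (directory names taken from its saveImpl and from ALSModelReader.load), the saved pieces can also be inspected directly; the base path is hypothetical and a SQLContext named sqlContext is assumed to be in scope:

    // <path>/metadata     - one JSON line with class, uid, paramMap, and extraMetadata ("rank")
    // <path>/userFactors  - Parquet DataFrame with "id" and "features" columns
    // <path>/itemFactors  - Parquet DataFrame with "id" and "features" columns
    val basePath = "/tmp/als-model"
    val userFactors = sqlContext.read.parquet(s"$basePath/userFactors")
    userFactors.select("id", "features").show(5)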
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index dddb72af5b..d8ce907af5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -194,7 +194,11 @@ private[ml] object DefaultParamsWriter {
* - uid
* - paramMap: These must be encodable using [[org.apache.spark.ml.param.Param.jsonEncode()]].
*/
- def saveMetadata(instance: Params, path: String, sc: SparkContext): Unit = {
+ def saveMetadata(
+ instance: Params,
+ path: String,
+ sc: SparkContext,
+ extraMetadata: Option[JValue] = None): Unit = {
val uid = instance.uid
val cls = instance.getClass.getName
val params = instance.extractParamMap().toSeq.asInstanceOf[Seq[ParamPair[Any]]]
@@ -205,7 +209,8 @@ private[ml] object DefaultParamsWriter {
("timestamp" -> System.currentTimeMillis()) ~
("sparkVersion" -> sc.version) ~
("uid" -> uid) ~
- ("paramMap" -> jsonParams)
+ ("paramMap" -> jsonParams) ~
+ ("extraMetadata" -> extraMetadata)
val metadataPath = new Path(path, "metadata").toString
val metadataJson = compact(render(metadata))
sc.parallelize(Seq(metadataJson), 1).saveAsTextFile(metadataPath)
@@ -236,6 +241,7 @@ private[ml] object DefaultParamsReader {
/**
* All info from metadata file.
* @param params paramMap, as a [[JValue]]
+ * @param extraMetadata Extra metadata saved by [[DefaultParamsWriter.saveMetadata()]]
* @param metadataStr Full metadata file String (for debugging)
*/
case class Metadata(
@@ -244,6 +250,7 @@ private[ml] object DefaultParamsReader {
timestamp: Long,
sparkVersion: String,
params: JValue,
+ extraMetadata: Option[JValue],
metadataStr: String)
/**
@@ -262,12 +269,13 @@ private[ml] object DefaultParamsReader {
val timestamp = (metadata \ "timestamp").extract[Long]
val sparkVersion = (metadata \ "sparkVersion").extract[String]
val params = metadata \ "paramMap"
+ val extraMetadata = (metadata \ "extraMetadata").extract[Option[JValue]]
if (expectedClassName.nonEmpty) {
require(className == expectedClassName, s"Error loading metadata: Expected class name" +
s" $expectedClassName but found class name $className")
}
- Metadata(className, uid, timestamp, sparkVersion, params, metadataStr)
+ Metadata(className, uid, timestamp, sparkVersion, params, extraMetadata, metadataStr)
}
/**
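For other components, a hedged sketch of how a writer/reader pair might use the extended saveMetadata signature, following the ALSModelWriter/ALSModelReader pattern above. MyModel, MyModelWriter, MyModelReader, and the numFeatures field are hypothetical, and DefaultParamsWriter/DefaultParamsReader are private[ml], so this pattern only applies inside spark.ml itself:

    import org.json4s.{DefaultFormats, JValue}
    import org.json4s.JsonDSL._
    import org.json4s.jackson.JsonMethods._

    private[ml] class MyModelWriter(instance: MyModel) extends Writer {
      override protected def saveImpl(path: String): Unit = {
        // Persist a component-specific value alongside the standard params.
        val extraMetadata = render("numFeatures" -> instance.numFeatures)
        DefaultParamsWriter.saveMetadata(instance, path, sc, Some(extraMetadata))
      }
    }

    private[ml] class MyModelReader extends Reader[MyModel] {
      override def load(path: String): MyModel = {
        val metadata = DefaultParamsReader.loadMetadata(path, sc)
        implicit val format = DefaultFormats
        // extraMetadata is None when the writer did not supply any.
        val numFeatures = metadata.extraMetadata.map(m => (m \ "numFeatures").extract[Int])
        // ... construct the model, then call DefaultParamsReader.getAndSetParams(model, metadata)
        ???
      }
    }

The metadata file then contains a single JSON line roughly of the form (values illustrative): {"class":"...","timestamp":...,"sparkVersion":"...","uid":"...","paramMap":{...},"extraMetadata":{"numFeatures":20}}.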
diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
index eadc80e0e6..2c3fb84160 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala
@@ -17,7 +17,6 @@
package org.apache.spark.ml.recommendation
-import java.io.File
import java.util.Random
import scala.collection.mutable
@@ -26,28 +25,26 @@ import scala.language.existentials
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.apache.spark.util.Utils
import org.apache.spark.{Logging, SparkException, SparkFunSuite}
import org.apache.spark.ml.recommendation.ALS._
-import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.{DataFrame, Row}
-class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
- private var tempDir: File = _
+class ALSSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest with Logging {
override def beforeAll(): Unit = {
super.beforeAll()
- tempDir = Utils.createTempDir()
sc.setCheckpointDir(tempDir.getAbsolutePath)
}
override def afterAll(): Unit = {
- Utils.deleteRecursively(tempDir)
super.afterAll()
}
@@ -186,7 +183,7 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
assert(compressed.dstPtrs.toSeq === Seq(0, 2, 3, 4, 5))
var decompressed = ArrayBuffer.empty[(Int, Int, Int, Float)]
var i = 0
- while (i < compressed.srcIds.size) {
+ while (i < compressed.srcIds.length) {
var j = compressed.dstPtrs(i)
while (j < compressed.dstPtrs(i + 1)) {
val dstEncodedIndex = compressed.dstEncodedIndices(j)
@@ -483,4 +480,67 @@ class ALSSuite extends SparkFunSuite with MLlibTestSparkContext with Logging {
ALS.train(ratings, rank = 1, maxIter = 50, numUserBlocks = 2, numItemBlocks = 2,
implicitPrefs = true, seed = 0)
}
+
+ test("read/write") {
+ import ALSSuite._
+ val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1)
+ val als = new ALS()
+ allEstimatorParamSettings.foreach { case (p, v) =>
+ als.set(als.getParam(p), v)
+ }
+ val sqlContext = this.sqlContext
+ import sqlContext.implicits._
+ val model = als.fit(ratings.toDF())
+
+ // Test Estimator save/load
+ val als2 = testDefaultReadWrite(als)
+ allEstimatorParamSettings.foreach { case (p, v) =>
+ val param = als.getParam(p)
+ assert(als.get(param).get === als2.get(param).get)
+ }
+
+ // Test Model save/load
+ val model2 = testDefaultReadWrite(model)
+ allModelParamSettings.foreach { case (p, v) =>
+ val param = model.getParam(p)
+ assert(model.get(param).get === model2.get(param).get)
+ }
+ assert(model.rank === model2.rank)
+ def getFactors(df: DataFrame): Set[(Int, Array[Float])] = {
+ df.select("id", "features").collect().map { case r =>
+ (r.getInt(0), r.getAs[Array[Float]](1))
+ }.toSet
+ }
+ assert(getFactors(model.userFactors) === getFactors(model2.userFactors))
+ assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors))
+ }
+}
+
+object ALSSuite {
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allModelParamSettings: Map[String, Any] = Map(
+ "predictionCol" -> "myPredictionCol"
+ )
+
+ /**
+ * Mapping from all Params to valid settings which differ from the defaults.
+ * This is useful for tests which need to exercise all Params, such as save/load.
+ * This excludes input columns to simplify some tests.
+ */
+ val allEstimatorParamSettings: Map[String, Any] = allModelParamSettings ++ Map(
+ "maxIter" -> 1,
+ "rank" -> 1,
+ "regParam" -> 0.01,
+ "numUserBlocks" -> 2,
+ "numItemBlocks" -> 2,
+ "implicitPrefs" -> true,
+ "alpha" -> 0.9,
+ "nonnegative" -> true,
+ "checkpointInterval" -> 20
+ )
}