diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-02-08 16:26:20 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-02-08 16:26:20 -0800 |
commit | 5c299c58fb9a5434a40be82150d4725bba805adf (patch) | |
tree | 8ea7856b545cd902fb30a92e14a9bf631b757936 /mllib/src/test | |
parent | 804949d519e2caa293a409d84b4e6190c1105444 (diff) | |
download | spark-5c299c58fb9a5434a40be82150d4725bba805adf.tar.gz spark-5c299c58fb9a5434a40be82150d4725bba805adf.tar.bz2 spark-5c299c58fb9a5434a40be82150d4725bba805adf.zip |
[SPARK-5598][MLLIB] model save/load for ALS
following #4233. jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #4422 from mengxr/SPARK-5598 and squashes the following commits:
a059394 [Xiangrui Meng] SaveLoad not extending Loader
14b7ea6 [Xiangrui Meng] address comments
f487cb2 [Xiangrui Meng] add unit tests
62fc43c [Xiangrui Meng] implement save/load for MFM
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala index b9caecc904..9801e87576 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModelSuite.scala @@ -22,6 +22,7 @@ import org.scalatest.FunSuite import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext { @@ -53,4 +54,22 @@ class MatrixFactorizationModelSuite extends FunSuite with MLlibTestSparkContext new MatrixFactorizationModel(rank, userFeatures, prodFeatures1) } } + + test("save/load") { + val model = new MatrixFactorizationModel(rank, userFeatures, prodFeatures) + val tempDir = Utils.createTempDir() + val path = tempDir.toURI.toString + def collect(features: RDD[(Int, Array[Double])]): Set[(Int, Seq[Double])] = { + features.mapValues(_.toSeq).collect().toSet + } + try { + model.save(sc, path) + val newModel = MatrixFactorizationModel.load(sc, path) + assert(newModel.rank === rank) + assert(collect(newModel.userFeatures) === collect(userFeatures)) + assert(collect(newModel.productFeatures) === collect(prodFeatures)) + } finally { + Utils.deleteRecursively(tempDir) + } + } } |