diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-03-31 16:01:08 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-03-31 16:01:08 -0700 |
commit | 0e00f12d33d28d064c166262b14e012a1aeaa7b0 (patch) | |
tree | bc69dd88ed7ee75ec3ff6bf0a744c00f8bcc86af /mllib/src/test | |
parent | 2036bc5993022da550f0cb1c0485ae92ec3e6fb0 (diff) | |
download | spark-0e00f12d33d28d064c166262b14e012a1aeaa7b0.tar.gz spark-0e00f12d33d28d064c166262b14e012a1aeaa7b0.tar.bz2 spark-0e00f12d33d28d064c166262b14e012a1aeaa7b0.zip |
[SPARK-5692] [MLlib] Word2Vec save/load
Word2Vec model now supports saving and loading.
a] The metadata, stored in JSON format, consists of "version", "classname", "vectorSize" and "numWords".
b] The data, stored in Parquet file format, consists of an Array of rows, each row having 2 columns: the first is the word (a String) and the second is its vector (an Array of Floats).
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #5291 from MechCoder/spark-5692 and squashes the following commits:
1142f3a [MechCoder] Add numWords to metaData
bfe4c39 [MechCoder] [SPARK-5692] Word2Vec save/load
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala | 26 |
1 file changed, 26 insertions, 0 deletions
```diff
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 52278690db..98a98a7599 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -21,6 +21,9 @@ import org.scalatest.FunSuite
 
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
 class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
 
   // TODO: add more tests
@@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
     assert(syms(0)._1 == "taiwan")
     assert(syms(1)._1 == "japan")
   }
+
+  test("model load / save") {
+
+    val word2VecMap = Map(
+      ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+      ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+      ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+      ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
+    )
+    val model = new Word2VecModel(word2VecMap)
+
+    val tempDir = Utils.createTempDir()
+    val path = tempDir.toURI.toString
+
+    try {
+      model.save(sc, path)
+      val sameModel = Word2VecModel.load(sc, path)
+      assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq))
+    } finally {
+      Utils.deleteRecursively(tempDir)
+    }
+
+  }
 }
```