aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-03-31 16:01:08 -0700
committerXiangrui Meng <meng@databricks.com>2015-03-31 16:01:08 -0700
commit0e00f12d33d28d064c166262b14e012a1aeaa7b0 (patch)
treebc69dd88ed7ee75ec3ff6bf0a744c00f8bcc86af /mllib/src/test
parent2036bc5993022da550f0cb1c0485ae92ec3e6fb0 (diff)
downloadspark-0e00f12d33d28d064c166262b14e012a1aeaa7b0.tar.gz
spark-0e00f12d33d28d064c166262b14e012a1aeaa7b0.tar.bz2
spark-0e00f12d33d28d064c166262b14e012a1aeaa7b0.zip
[SPARK-5692] [MLlib] Word2Vec save/load
Word2Vec model now supports saving and loading. a] The Metadata stored in JSON format consists of "version", "classname", "vectorSize" and "numWords" b] The data stored in Parquet file format consists of an Array of rows with each row consisting of 2 columns, first being the word: String and the second, an Array of Floats. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #5291 from MechCoder/spark-5692 and squashes the following commits: 1142f3a [MechCoder] Add numWords to metaData bfe4c39 [MechCoder] [SPARK-5692] Word2Vec save/load
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala26
1 files changed, 26 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 52278690db..98a98a7599 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -21,6 +21,9 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
// TODO: add more tests
@@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
assert(syms(0)._1 == "taiwan")
assert(syms(1)._1 == "japan")
}
+
+ test("model load / save") {
+
+ val word2VecMap = Map(
+ ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+ ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+ ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+ ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
+ )
+ val model = new Word2VecModel(word2VecMap)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ try {
+ model.save(sc, path)
+ val sameModel = Word2VecModel.load(sc, path)
+ assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq))
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+
+ }
}