aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-03-31 16:01:08 -0700
committerXiangrui Meng <meng@databricks.com>2015-03-31 16:01:08 -0700
commit0e00f12d33d28d064c166262b14e012a1aeaa7b0 (patch)
treebc69dd88ed7ee75ec3ff6bf0a744c00f8bcc86af
parent2036bc5993022da550f0cb1c0485ae92ec3e6fb0 (diff)
downloadspark-0e00f12d33d28d064c166262b14e012a1aeaa7b0.tar.gz
spark-0e00f12d33d28d064c166262b14e012a1aeaa7b0.tar.bz2
spark-0e00f12d33d28d064c166262b14e012a1aeaa7b0.zip
[SPARK-5692] [MLlib] Word2Vec save/load
Word2Vec model now supports saving and loading. a] The metadata, stored in JSON format, consists of "version", "classname", "vectorSize" and "numWords". b] The data, stored in Parquet file format, consists of an array of rows, each row having 2 columns: the first being the word (String) and the second an Array of Floats. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #5291 from MechCoder/spark-5692 and squashes the following commits: 1142f3a [MechCoder] Add numWords to metaData bfe4c39 [MechCoder] [SPARK-5692] Word2Vec save/load
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala87
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala26
2 files changed, 110 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 59a79e5c6a..9ee7e4a66b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -25,14 +25,21 @@ import scala.collection.mutable.ArrayBuilder
import com.github.fommil.netlib.BLAS.{getInstance => blas}
+import org.json4s.DefaultFormats
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
import org.apache.spark.Logging
+import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.{Loader, Saveable}
import org.apache.spark.rdd._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandom
+import org.apache.spark.sql.{SQLContext, Row}
/**
* Entry in vocabulary
@@ -422,7 +429,7 @@ class Word2Vec extends Serializable with Logging {
*/
@Experimental
class Word2VecModel private[mllib] (
- private val model: Map[String, Array[Float]]) extends Serializable {
+ private val model: Map[String, Array[Float]]) extends Serializable with Saveable {
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
@@ -432,7 +439,13 @@ class Word2VecModel private[mllib] (
if (norm1 == 0 || norm2 == 0) return 0.0
blas.sdot(n, v1, 1, v2,1) / norm1 / norm2
}
-
+
+ override protected def formatVersion = "1.0"
+
+ def save(sc: SparkContext, path: String): Unit = {
+ Word2VecModel.SaveLoadV1_0.save(sc, path, model)
+ }
+
/**
* Transforms a word to its vector representation
* @param word a word
@@ -475,7 +488,7 @@ class Word2VecModel private[mllib] (
.tail
.toArray
}
-
+
/**
* Returns a map of words to their vector representations.
*/
@@ -483,3 +496,71 @@ class Word2VecModel private[mllib] (
model
}
}
+
+@Experimental
+object Word2VecModel extends Loader[Word2VecModel] {
+
+ private object SaveLoadV1_0 {
+
+ val formatVersionV1_0 = "1.0"
+
+ val classNameV1_0 = "org.apache.spark.mllib.feature.Word2VecModel"
+
+ case class Data(word: String, vector: Array[Float])
+
+ def load(sc: SparkContext, path: String): Word2VecModel = {
+ val dataPath = Loader.dataPath(path)
+ val sqlContext = new SQLContext(sc)
+ val dataFrame = sqlContext.parquetFile(dataPath)
+
+ val dataArray = dataFrame.select("word", "vector").collect()
+
+ // Check schema explicitly since erasure makes it hard to use match-case for checking.
+ Loader.checkSchema[Data](dataFrame.schema)
+
+ val word2VecMap = dataArray.map(i => (i.getString(0), i.getSeq[Float](1).toArray)).toMap
+ new Word2VecModel(word2VecMap)
+ }
+
+ def save(sc: SparkContext, path: String, model: Map[String, Array[Float]]) = {
+
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ val vectorSize = model.values.head.size
+ val numWords = model.size
+ val metadata = compact(render
+ (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
+ ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
+ sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+
+ val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
+ sc.parallelize(dataArray.toSeq, 1).toDF().saveAsParquetFile(Loader.dataPath(path))
+ }
+ }
+
+ override def load(sc: SparkContext, path: String): Word2VecModel = {
+
+ val (loadedClassName, loadedVersion, metadata) = Loader.loadMetadata(sc, path)
+ implicit val formats = DefaultFormats
+ val expectedVectorSize = (metadata \ "vectorSize").extract[Int]
+ val expectedNumWords = (metadata \ "numWords").extract[Int]
+ val classNameV1_0 = SaveLoadV1_0.classNameV1_0
+ (loadedClassName, loadedVersion) match {
+ case (classNameV1_0, "1.0") =>
+ val model = SaveLoadV1_0.load(sc, path)
+ val vectorSize = model.getVectors.values.head.size
+ val numWords = model.getVectors.size
+ require(expectedVectorSize == vectorSize,
+ s"Word2VecModel requires each word to be mapped to a vector of size " +
+ s"$expectedVectorSize, got vector of size $vectorSize")
+ require(expectedNumWords == numWords,
+ s"Word2VecModel requires $expectedNumWords words, but got $numWords")
+ model
+ case _ => throw new Exception(
+ s"Word2VecModel.load did not recognize model with (className, format version):" +
+ s"($loadedClassName, $loadedVersion). Supported:\n" +
+ s" ($classNameV1_0, 1.0)")
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index 52278690db..98a98a7599 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -21,6 +21,9 @@ import org.scalatest.FunSuite
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.util.Utils
+
class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
// TODO: add more tests
@@ -51,4 +54,27 @@ class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
assert(syms(0)._1 == "taiwan")
assert(syms(1)._1 == "japan")
}
+
+ test("model load / save") {
+
+ val word2VecMap = Map(
+ ("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
+ ("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
+ ("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
+ ("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
+ )
+ val model = new Word2VecModel(word2VecMap)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ try {
+ model.save(sc, path)
+ val sameModel = Word2VecModel.load(sc, path)
+ assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq))
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+
+ }
}