aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-07-24 14:58:07 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-24 14:58:07 -0700
commita400ab516fa93185aa683a596f9d7c6c1a02f330 (patch)
tree0dfc4adc09cd782fedb5a0d30d09e061f55bee61
parent64135cbb3363e3b74dad3c0498cb9959c047d381 (diff)
downloadspark-a400ab516fa93185aa683a596f9d7c6c1a02f330.tar.gz
spark-a400ab516fa93185aa683a596f9d7c6c1a02f330.tar.bz2
spark-a400ab516fa93185aa683a596f9d7c6c1a02f330.zip
[SPARK-7045] [MLLIB] Avoid intermediate representation when creating model
Word2Vec used to convert from an Array[Float] representation to a Map[String, Array[Float]] and then back to an Array[Float] through Word2VecModel. This change avoids that intermediate conversion while still supporting the older method of supplying a Map. Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #5748 from MechCoder/spark-7045 and squashes the following commits: e308913 [MechCoder] move docs 5703116 [MechCoder] minor fa04313 [MechCoder] style fixes b1d61c4 [MechCoder] better errors and tests 3b32c8c [MechCoder] [SPARK-7045] Avoid intermediate representation when creating model
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala85
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala6
2 files changed, 55 insertions, 36 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index f087d06d2a..cbbd2b0c8d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -403,17 +403,8 @@ class Word2Vec extends Serializable with Logging {
}
newSentences.unpersist()
- val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
- var i = 0
- while (i < vocabSize) {
- val word = bcVocab.value(i).word
- val vector = new Array[Float](vectorSize)
- Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize)
- word2VecMap += word -> vector
- i += 1
- }
-
- new Word2VecModel(word2VecMap.toMap)
+ val wordArray = vocab.map(_.word)
+ new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global)
}
/**
@@ -429,38 +420,42 @@ class Word2Vec extends Serializable with Logging {
/**
* :: Experimental ::
* Word2Vec model
+ * @param wordIndex maps each word to an index, which can retrieve the corresponding
+ * vector from wordVectors
+ * @param wordVectors array of length numWords * vectorSize, vector corresponding
+ * to the word mapped with index i can be retrieved by the slice
+ * (i * vectorSize, i * vectorSize + vectorSize)
*/
@Experimental
-class Word2VecModel private[spark] (
- model: Map[String, Array[Float]]) extends Serializable with Saveable {
-
- // wordList: Ordered list of words obtained from model.
- private val wordList: Array[String] = model.keys.toArray
-
- // wordIndex: Maps each word to an index, which can retrieve the corresponding
- // vector from wordVectors (see below).
- private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap
+class Word2VecModel private[mllib] (
+ private val wordIndex: Map[String, Int],
+ private val wordVectors: Array[Float]) extends Serializable with Saveable {
- // vectorSize: Dimension of each word's vector.
- private val vectorSize = model.head._2.size
private val numWords = wordIndex.size
+ // vectorSize: Dimension of each word's vector.
+ private val vectorSize = wordVectors.length / numWords
+
+ // wordList: Ordered list of words obtained from wordIndex.
+ private val wordList: Array[String] = {
+ val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip
+ wl.toArray
+ }
- // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word
- // mapped with index i can be retrieved by the slice
- // (ind * vectorSize, ind * vectorSize + vectorSize)
// wordVecNorms: Array of length numWords, each value being the Euclidean norm
// of the wordVector.
- private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = {
- val wordVectors = new Array[Float](vectorSize * numWords)
+ private val wordVecNorms: Array[Double] = {
val wordVecNorms = new Array[Double](numWords)
var i = 0
while (i < numWords) {
- val vec = model.get(wordList(i)).get
- Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize)
+ val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize)
wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1)
i += 1
}
- (wordVectors, wordVecNorms)
+ wordVecNorms
+ }
+
+ def this(model: Map[String, Array[Float]]) = {
+ this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model))
}
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
@@ -484,8 +479,9 @@ class Word2VecModel private[spark] (
* @return vector representation of word
*/
def transform(word: String): Vector = {
- model.get(word) match {
- case Some(vec) =>
+ wordIndex.get(word) match {
+ case Some(ind) =>
+ val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize)
Vectors.dense(vec.map(_.toDouble))
case None =>
throw new IllegalStateException(s"$word not in vocabulary")
@@ -511,7 +507,7 @@ class Word2VecModel private[spark] (
*/
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
-
+ // TODO: optimize top-k
val fVector = vector.toArray.map(_.toFloat)
val cosineVec = Array.fill[Float](numWords)(0)
val alpha: Float = 1
@@ -521,13 +517,13 @@ class Word2VecModel private[spark] (
"T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1)
// Need not divide with the norm of the given vector since it is constant.
- val updatedCosines = new Array[Double](numWords)
+ val cosVec = cosineVec.map(_.toDouble)
var ind = 0
while (ind < numWords) {
- updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind)
+ cosVec(ind) /= wordVecNorms(ind)
ind += 1
}
- wordList.zip(updatedCosines)
+ wordList.zip(cosVec)
.toSeq
.sortBy(- _._2)
.take(num + 1)
@@ -548,6 +544,23 @@ class Word2VecModel private[spark] (
@Experimental
object Word2VecModel extends Loader[Word2VecModel] {
+ private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = {
+ model.keys.zipWithIndex.toMap
+ }
+
+ private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = {
+ require(model.nonEmpty, "Word2VecMap should be non-empty")
+ val (vectorSize, numWords) = (model.head._2.size, model.size)
+ val wordList = model.keys.toArray
+ val wordVectors = new Array[Float](vectorSize * numWords)
+ var i = 0
+ while (i < numWords) {
+ Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize)
+ i += 1
+ }
+ wordVectors
+ }
+
private object SaveLoadV1_0 {
val formatVersionV1_0 = "1.0"
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index b681836920..4cc8d1129b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -37,6 +37,12 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(syms.length == 2)
assert(syms(0)._1 == "b")
assert(syms(1)._1 == "c")
+
+ // Test that a model built directly by Word2Vec (i.e. from wordVectors and wordIndex)
+ // and a model built from a Word2VecMap give the same values.
+ val word2VecMap = model.getVectors
+ val newModel = new Word2VecModel(word2VecMap)
+ assert(newModel.getVectors.mapValues(_.toSeq) === word2VecMap.mapValues(_.toSeq))
}
test("Word2VecModel") {