diff options
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 85 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala | 6 |
2 files changed, 55 insertions, 36 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index f087d06d2a..cbbd2b0c8d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -403,17 +403,8 @@ class Word2Vec extends Serializable with Logging { } newSentences.unpersist() - val word2VecMap = mutable.HashMap.empty[String, Array[Float]] - var i = 0 - while (i < vocabSize) { - val word = bcVocab.value(i).word - val vector = new Array[Float](vectorSize) - Array.copy(syn0Global, i * vectorSize, vector, 0, vectorSize) - word2VecMap += word -> vector - i += 1 - } - - new Word2VecModel(word2VecMap.toMap) + val wordArray = vocab.map(_.word) + new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) } /** @@ -429,38 +420,42 @@ class Word2Vec extends Serializable with Logging { /** * :: Experimental :: * Word2Vec model + * @param wordIndex maps each word to an index, which can retrieve the corresponding + * vector from wordVectors + * @param wordVectors array of length numWords * vectorSize, vector corresponding + * to the word mapped with index i can be retrieved by the slice + * (i * vectorSize, i * vectorSize + vectorSize) */ @Experimental -class Word2VecModel private[spark] ( - model: Map[String, Array[Float]]) extends Serializable with Saveable { - - // wordList: Ordered list of words obtained from model. - private val wordList: Array[String] = model.keys.toArray - - // wordIndex: Maps each word to an index, which can retrieve the corresponding - // vector from wordVectors (see below). - private val wordIndex: Map[String, Int] = wordList.zip(0 until model.size).toMap +class Word2VecModel private[mllib] ( + private val wordIndex: Map[String, Int], + private val wordVectors: Array[Float]) extends Serializable with Saveable { - // vectorSize: Dimension of each word's vector. - private val vectorSize = model.head._2.size private val numWords = wordIndex.size + // vectorSize: Dimension of each word's vector. + private val vectorSize = wordVectors.length / numWords + + // wordList: Ordered list of words obtained from wordIndex. + private val wordList: Array[String] = { + val (wl, _) = wordIndex.toSeq.sortBy(_._2).unzip + wl.toArray + } - // wordVectors: Array of length numWords * vectorSize, vector corresponding to the word - // mapped with index i can be retrieved by the slice - // (ind * vectorSize, ind * vectorSize + vectorSize) // wordVecNorms: Array of length numWords, each value being the Euclidean norm // of the wordVector. - private val (wordVectors: Array[Float], wordVecNorms: Array[Double]) = { - val wordVectors = new Array[Float](vectorSize * numWords) + private val wordVecNorms: Array[Double] = { val wordVecNorms = new Array[Double](numWords) var i = 0 while (i < numWords) { - val vec = model.get(wordList(i)).get - Array.copy(vec, 0, wordVectors, i * vectorSize, vectorSize) + val vec = wordVectors.slice(i * vectorSize, i * vectorSize + vectorSize) wordVecNorms(i) = blas.snrm2(vectorSize, vec, 1) i += 1 } - (wordVectors, wordVecNorms) + wordVecNorms + } + + def this(model: Map[String, Array[Float]]) = { + this(Word2VecModel.buildWordIndex(model), Word2VecModel.buildWordVectors(model)) } private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = { @@ -484,8 +479,9 @@ class Word2VecModel private[spark] ( * @return vector representation of word */ def transform(word: String): Vector = { - model.get(word) match { - case Some(vec) => + wordIndex.get(word) match { + case Some(ind) => + val vec = wordVectors.slice(ind * vectorSize, ind * vectorSize + vectorSize) Vectors.dense(vec.map(_.toDouble)) case None => throw new IllegalStateException(s"$word not in vocabulary") @@ -511,7 +507,7 @@ class Word2VecModel private[spark] ( */ def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = { require(num > 0, "Number of similar words should > 0") - + // TODO: optimize top-k val fVector = vector.toArray.map(_.toFloat) val cosineVec = Array.fill[Float](numWords)(0) val alpha: Float = 1 @@ -521,13 +517,13 @@ class Word2VecModel private[spark] ( "T", vectorSize, numWords, alpha, wordVectors, vectorSize, fVector, 1, beta, cosineVec, 1) // Need not divide with the norm of the given vector since it is constant. - val updatedCosines = new Array[Double](numWords) + val cosVec = cosineVec.map(_.toDouble) var ind = 0 while (ind < numWords) { - updatedCosines(ind) = cosineVec(ind) / wordVecNorms(ind) + cosVec(ind) /= wordVecNorms(ind) ind += 1 } - wordList.zip(updatedCosines) + wordList.zip(cosVec) .toSeq .sortBy(- _._2) .take(num + 1) @@ -548,6 +544,23 @@ class Word2VecModel private[spark] ( @Experimental object Word2VecModel extends Loader[Word2VecModel] { + private def buildWordIndex(model: Map[String, Array[Float]]): Map[String, Int] = { + model.keys.zipWithIndex.toMap + } + + private def buildWordVectors(model: Map[String, Array[Float]]): Array[Float] = { + require(model.nonEmpty, "Word2VecMap should be non-empty") + val (vectorSize, numWords) = (model.head._2.size, model.size) + val wordList = model.keys.toArray + val wordVectors = new Array[Float](vectorSize * numWords) + var i = 0 + while (i < numWords) { + Array.copy(model(wordList(i)), 0, wordVectors, i * vectorSize, vectorSize) + i += 1 + } + wordVectors + } + private object SaveLoadV1_0 { val formatVersionV1_0 = "1.0" diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala index b681836920..4cc8d1129b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala @@ -37,6 +37,12 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext { assert(syms.length == 2) assert(syms(0)._1 == "b") assert(syms(1)._1 == "c") + + // Test that model built using Word2Vec, i.e wordVectors and wordIndec + // and a Word2VecMap give the same values. + val word2VecMap = model.getVectors + val newModel = new Word2VecModel(word2VecMap) + assert(newModel.getVectors.mapValues(_.toSeq) === word2VecMap.mapValues(_.toSeq)) } test("Word2VecModel") { |