From cc491f69cd239ae7572f1f5f55a2452f7f417dc1 Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Tue, 5 Aug 2014 16:22:41 -0700
Subject: [SPARK-2864][MLLIB] fix random seed in word2vec; move model to local

It also moves the model to local in order to map `RDD[String]` to
`RDD[Vector]`.

Ishiihara

Author: Xiangrui Meng

Closes #1790 from mengxr/word2vec-fix and squashes the following commits:

a87146c [Xiangrui Meng] add setters and make a default constructor
e5c923b [Xiangrui Meng] fix random seed in word2vec; move model to local
---
 .../org/apache/spark/mllib/feature/Word2Vec.scala | 188 +++++++++++----------
 1 file changed, 102 insertions(+), 86 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 87c81e7b0b..3bf44ad7c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature
 
 import scala.collection.mutable
 import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
 
 import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.apache.spark.{HashPartitioner, Logging}
+
+import org.apache.spark.Logging
 import org.apache.spark.SparkContext._
 import org.apache.spark.annotation.Experimental
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.rdd.RDDFunctions._
 import org.apache.spark.rdd._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
 
 /**
  * Entry in vocabulary
@@ -58,29 +59,63 @@ private case class VocabWord(
  * Efficient Estimation of Word Representations in Vector Space
  * and
  * Distributed Representations of Words and Phrases and their Compositionality.
- * @param size vector dimension
- * @param startingAlpha initial learning rate
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
- * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
  */
 @Experimental
-class Word2Vec(
-    val size: Int,
-    val startingAlpha: Double,
-    val parallelism: Int,
-    val numIterations: Int) extends Serializable with Logging {
+class Word2Vec extends Serializable with Logging {
+
+  private var vectorSize = 100
+  private var startingAlpha = 0.025
+  private var numPartitions = 1
+  private var numIterations = 1
+  private var seed = Utils.random.nextLong()
+
+  /**
+   * Sets vector size (default: 100).
+   */
+  def setVectorSize(vectorSize: Int): this.type = {
+    this.vectorSize = vectorSize
+    this
+  }
+
+  /**
+   * Sets initial learning rate (default: 0.025).
+   */
+  def setLearningRate(learningRate: Double): this.type = {
+    this.startingAlpha = learningRate
+    this
+  }
 
   /**
-   * Word2Vec with a single thread.
+   * Sets number of partitions (default: 1). Use a small number for accuracy.
    */
-  def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+  def setNumPartitions(numPartitions: Int): this.type = {
+    require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
+    this.numPartitions = numPartitions
+    this
+  }
+
+  /**
+   * Sets number of iterations (default: 1), which should be smaller than or equal to number of
+   * partitions.
+   */
+  def setNumIterations(numIterations: Int): this.type = {
+    this.numIterations = numIterations
+    this
+  }
+
+  /**
+   * Sets random seed (default: a random long integer).
+   */
+  def setSeed(seed: Long): this.type = {
+    this.seed = seed
+    this
+  }
 
   private val EXP_TABLE_SIZE = 1000
   private val MAX_EXP = 6
   private val MAX_CODE_LENGTH = 40
   private val MAX_SENTENCE_LENGTH = 1000
-  private val layer1Size = size
-  private val modelPartitionNum = 100
+  private val layer1Size = vectorSize
 
   /** context words from [-window, window] */
   private val window = 5
@@ -94,12 +129,12 @@ class Word2Vec(
   private var vocabHash = mutable.HashMap.empty[String, Int]
   private var alpha = startingAlpha
 
-  private def learnVocab(words:RDD[String]): Unit = {
+  private def learnVocab(words: RDD[String]): Unit = {
     vocab = words.map(w => (w, 1))
       .reduceByKey(_ + _)
      .map(x => VocabWord(
-        x._1, 
-        x._2, 
+        x._1,
+        x._2,
         new Array[Int](MAX_CODE_LENGTH),
         new Array[Int](MAX_CODE_LENGTH),
         0))
@@ -245,23 +280,24 @@ class Word2Vec(
       }
     }
 
-    val newSentences = sentences.repartition(parallelism).cache()
+    val newSentences = sentences.repartition(numPartitions).cache()
+    val initRandom = new XORShiftRandom(seed)
     var syn0Global =
-      Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+      Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
     var syn1Global = new Array[Float](vocabSize * layer1Size)
-    
-    for(iter <- 1 to numIterations) {
-      val (aggSyn0, aggSyn1, _, _) =
-        // TODO: broadcast temp instead of serializing it directly
-        // or initialize the model in each executor
-        newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
-          seqOp = (c, v) => (c, v) match {
+
+    for (k <- 1 to numIterations) {
+      val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
+        val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+        val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
          case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
            var lwc = lastWordCount
-            var wc = wordCount 
+            var wc = wordCount
            if (wordCount - lastWordCount > 10000) {
              lwc = wordCount
-              alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+              // TODO: discount by iteration?
+              alpha =
+                startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
              if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
              logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
            }
@@ -269,8 +305,7 @@ class Word2Vec(
            var pos = 0
            while (pos < sentence.size) {
              val word = sentence(pos)
-              // TODO: fix random seed
-              val b = Random.nextInt(window)
+              val b = random.nextInt(window)
              // Train Skip-gram
              var a = b
              while (a < window * 2 + 1 - b) {
@@ -280,7 +315,7 @@ class Word2Vec(
                  val lastWord = sentence(c)
                  val l1 = lastWord * layer1Size
                  val neu1e = new Array[Float](layer1Size)
-                  // Hierarchical softmax 
+                  // Hierarchical softmax
                  var d = 0
                  while (d < bcVocab.value(word).codeLen) {
                    val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +338,44 @@ class Word2Vec(
              pos += 1
            }
            (syn0, syn1, lwc, wc)
-          },
-          combOp = (c1, c2) => (c1, c2) match {
-            case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
-              val n = syn0_1.length
-              val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
-              val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
-              blas.sscal(n, weight1, syn0_1, 1)
-              blas.sscal(n, weight1, syn1_1, 1)
-              blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
-              blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
-              (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
-          })
+        }
+        Iterator(model)
+      }
+      val (aggSyn0, aggSyn1, _, _) =
+        partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+          val n = syn0_1.length
+          val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+          val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+          blas.sscal(n, weight1, syn0_1, 1)
+          blas.sscal(n, weight1, syn1_1, 1)
+          blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+          blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+          (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+        }
      syn0Global = aggSyn0
      syn1Global = aggSyn1
    }
    newSentences.unpersist()
 
-    val wordMap = new Array[(String, Array[Float])](vocabSize)
+    val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
    var i = 0
    while (i < vocabSize) {
      val word = bcVocab.value(i).word
      val vector = new Array[Float](layer1Size)
      Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
-      wordMap(i) = (word, vector)
+      word2VecMap += word -> vector
      i += 1
    }
-    val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
-      .partitionBy(new HashPartitioner(modelPartitionNum))
-      .persist(StorageLevel.MEMORY_AND_DISK)
-
-    new Word2VecModel(modelRDD)
+
+    new Word2VecModel(word2VecMap.toMap)
  }
}
 
/**
 * Word2Vec model
-*/
-class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
+ */
+class Word2VecModel private[mllib] (
+    private val model: Map[String, Array[Float]]) extends Serializable {
 
  private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
    require(v1.length == v2.length, "Vectors should have the same length")
@@ -357,11 +392,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
   * @return vector representation of word
   */
  def transform(word: String): Vector = {
-    val result = model.lookup(word)
-    if (result.isEmpty) {
-      throw new IllegalStateException(s"$word not in vocabulary")
+    model.get(word) match {
+      case Some(vec) =>
+        Vectors.dense(vec.map(_.toDouble))
+      case None =>
+        throw new IllegalStateException(s"$word not in vocabulary")
    }
-    else Vectors.dense(result(0).map(_.toDouble))
  }
 
  /**
@@ -392,33 +428,13 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
   */
  def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
    require(num > 0, "Number of similar words should > 0")
-    val topK = model.map { case(w, vec) =>
-      (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
-      .sortByKey(ascending = false)
-      .take(num + 1)
-      .map(_.swap)
-      .tail
-
-    topK
-  }
-}
-
-object Word2Vec{
-  /**
-   * Train Word2Vec model
-   * @param input RDD of words
-   * @param size vector dimension
-   * @param startingAlpha initial learning rate
-   * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
-   * @param numIterations number of iterations, should be smaller than or equal to parallelism
-   * @return Word2Vec model
-   */
-  def train[S <: Iterable[String]](
-      input: RDD[S],
-      size: Int,
-      startingAlpha: Double,
-      parallelism: Int = 1,
-      numIterations:Int = 1): Word2VecModel = {
-    new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
+    // TODO: optimize top-k
+    val fVector = vector.toArray.map(_.toFloat)
+    model.mapValues(vec => cosineSimilarity(fVector, vec))
+      .toSeq
+      .sortBy(- _._2)
+      .take(num + 1)
+      .tail
+      .toArray
  }
}
-- 
cgit v1.2.3
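A quick usage sketch of the post-patch API: the default constructor and
chainable setters replace the old multi-argument constructor, and the fitted
model is now a plain local Map. The corpus path, the seed value, and `sc` (an
existing SparkContext) are illustrative assumptions, not part of the patch:

    import org.apache.spark.mllib.feature.Word2Vec
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    // Hypothetical corpus: one sentence per line, whitespace-tokenized.
    val sentences: RDD[Seq[String]] =
      sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)

    val model = new Word2Vec()  // default constructor added by this patch
      .setVectorSize(100)       // default: 100
      .setLearningRate(0.025)   // default: 0.025
      .setNumPartitions(1)      // keep small for accuracy
      .setNumIterations(1)      // should be <= number of partitions
      .setSeed(42L)             // fixed seed gives reproducible vectors (SPARK-2864)
      .fit(sentences)

    // Because the model is local, it can be captured in a closure and applied
    // inside RDD transformations, the motivation stated in the commit message:
    val words: RDD[String] = sc.parallelize(Seq("spark", "mllib"))
    val vectors: RDD[Vector] = words.map(model.transform)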
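The heart of the SPARK-2864 fix is the seed derivation inside
mapPartitionsWithIndex: instead of drawing from the shared, unseeded
scala.util.Random, each (iteration, partition) pair derives its own
deterministic generator from the user-visible seed. A standalone sketch of
that derivation, with java.util.Random standing in for Spark's XORShiftRandom
(which is private[spark]); the values drawn differ from XORShiftRandom's, but
the seed mixing is the same:

    import java.util.Random

    val seed = 42L  // the value passed to setSeed (default: Utils.random.nextLong())
    val window = 5  // the class's fixed context-window radius

    // Mix iteration number k and partition index idx into the base seed,
    // exactly as the patch does, so each task draws an independent but
    // reproducible stream of window offsets.
    for (k <- 1 to 2; idx <- 0 until 3) {
      val random = new Random(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
      println(s"k=$k, partition=$idx, first b=${random.nextInt(window)}")
    }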
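After each iteration, the per-partition models are merged with treeReduce,
weighting each partial syn0/syn1 by the number of words that partition
processed (the patch does this with BLAS sscal/saxpy calls). A dependency-free
toy sketch of the same weighted average, using plain loops instead of BLAS:

    // Weighted average of two partial weight arrays, mirroring the treeReduce
    // combiner: out = wc1/(wc1+wc2) * v1 + wc2/(wc1+wc2) * v2.
    def merge(v1: Array[Float], wc1: Int, v2: Array[Float], wc2: Int): Array[Float] = {
      val weight1 = 1.0f * wc1 / (wc1 + wc2)
      val weight2 = 1.0f * wc2 / (wc1 + wc2)
      val out = new Array[Float](v1.length)
      var i = 0
      while (i < v1.length) {
        out(i) = weight1 * v1(i) + weight2 * v2(i)
        i += 1
      }
      out
    }

    // A partition that saw 3x more words dominates the average:
    merge(Array(1f, 1f), 300, Array(5f, 5f), 100)  // Array(2.0, 2.0)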