author    Xiangrui Meng <meng@databricks.com>    2014-08-05 16:22:41 -0700
committer Xiangrui Meng <meng@databricks.com>    2014-08-05 16:22:41 -0700
commit    cc491f69cd239ae7572f1f5f55a2452f7f417dc1 (patch)
tree      3d7dfa9e75f9f621c9b837c99b979e61d617521c /mllib
parent    41e0a21b22ccd2788dc079790788e505b0d4e37d (diff)
[SPARK-2864][MLLIB] fix random seed in word2vec; move model to local
It also moves the model to local in order to map `RDD[String]` to `RDD[Vector]`.

Ishiihara

Author: Xiangrui Meng <meng@databricks.com>

Closes #1790 from mengxr/word2vec-fix and squashes the following commits:

a87146c [Xiangrui Meng] add setters and make a default constructor
e5c923b [Xiangrui Meng] fix random seed in word2vec; move model to local
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala       188
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala   15
2 files changed, 106 insertions, 97 deletions
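
For context, a minimal usage sketch of the setter-based API this patch introduces. The toy corpus, the repeat count, and an existing SparkContext `sc` are assumptions for illustration; the call chain mirrors the updated test further below.

    import org.apache.spark.mllib.feature.Word2Vec

    // Assumes a running SparkContext `sc`; the corpus is hypothetical and
    // repeated so each word clears the vocabulary's minimum-count threshold.
    val doc = sc.parallelize(Seq.fill(10)("a b c d e"))
      .map(line => line.split(" ").toSeq)

    // Every setter is optional; defaults are vectorSize = 100, learningRate = 0.025,
    // numPartitions = 1, numIterations = 1, and a randomly chosen seed.
    val model = new Word2Vec()
      .setVectorSize(10)
      .setNumPartitions(1)
      .setNumIterations(1)
      .setSeed(42L)
      .fit(doc)

    val vec = model.transform("a")            // dense Vector for a vocabulary word
    val synonyms = model.findSynonyms("a", 2) // Array[(String, Double)], best match first
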
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 87c81e7b0b..3bf44ad7c4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -19,16 +19,17 @@ package org.apache.spark.mllib.feature
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
import com.github.fommil.netlib.BLAS.{getInstance => blas}
-import org.apache.spark.{HashPartitioner, Logging}
+
+import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.rdd.RDDFunctions._
import org.apache.spark.rdd._
-import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.Utils
+import org.apache.spark.util.random.XORShiftRandom
/**
* Entry in vocabulary
@@ -58,29 +59,63 @@ private case class VocabWord(
* Efficient Estimation of Word Representations in Vector Space
* and
* Distributed Representations of Words and Phrases and their Compositionality.
- * @param size vector dimension
- * @param startingAlpha initial learning rate
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
- * @param numIterations number of iterations to run, should be smaller than or equal to parallelism
*/
@Experimental
-class Word2Vec(
- val size: Int,
- val startingAlpha: Double,
- val parallelism: Int,
- val numIterations: Int) extends Serializable with Logging {
+class Word2Vec extends Serializable with Logging {
+
+ private var vectorSize = 100
+ private var startingAlpha = 0.025
+ private var numPartitions = 1
+ private var numIterations = 1
+ private var seed = Utils.random.nextLong()
+
+ /**
+ * Sets vector size (default: 100).
+ */
+ def setVectorSize(vectorSize: Int): this.type = {
+ this.vectorSize = vectorSize
+ this
+ }
+
+ /**
+ * Sets initial learning rate (default: 0.025).
+ */
+ def setLearningRate(learningRate: Double): this.type = {
+ this.startingAlpha = learningRate
+ this
+ }
/**
- * Word2Vec with a single thread.
+ * Sets number of partitions (default: 1). Use a small number for accuracy.
*/
- def this(size: Int, startingAlpha: Int) = this(size, startingAlpha, 1, 1)
+ def setNumPartitions(numPartitions: Int): this.type = {
+ require(numPartitions > 0, s"numPartitions must be greater than 0 but got $numPartitions")
+ this.numPartitions = numPartitions
+ this
+ }
+
+ /**
+ * Sets number of iterations (default: 1), which should be smaller than or equal to number of
+ * partitions.
+ */
+ def setNumIterations(numIterations: Int): this.type = {
+ this.numIterations = numIterations
+ this
+ }
+
+ /**
+ * Sets random seed (default: a random long integer).
+ */
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
private val EXP_TABLE_SIZE = 1000
private val MAX_EXP = 6
private val MAX_CODE_LENGTH = 40
private val MAX_SENTENCE_LENGTH = 1000
- private val layer1Size = size
- private val modelPartitionNum = 100
+ private val layer1Size = vectorSize
/** context words from [-window, window] */
private val window = 5
@@ -94,12 +129,12 @@ class Word2Vec(
private var vocabHash = mutable.HashMap.empty[String, Int]
private var alpha = startingAlpha
- private def learnVocab(words:RDD[String]): Unit = {
+ private def learnVocab(words: RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
.map(x => VocabWord(
- x._1,
- x._2,
+ x._1,
+ x._2,
new Array[Int](MAX_CODE_LENGTH),
new Array[Int](MAX_CODE_LENGTH),
0))
@@ -245,23 +280,24 @@ class Word2Vec(
}
}
- val newSentences = sentences.repartition(parallelism).cache()
+ val newSentences = sentences.repartition(numPartitions).cache()
+ val initRandom = new XORShiftRandom(seed)
var syn0Global =
- Array.fill[Float](vocabSize * layer1Size)((Random.nextFloat() - 0.5f) / layer1Size)
+ Array.fill[Float](vocabSize * layer1Size)((initRandom.nextFloat() - 0.5f) / layer1Size)
var syn1Global = new Array[Float](vocabSize * layer1Size)
-
- for(iter <- 1 to numIterations) {
- val (aggSyn0, aggSyn1, _, _) =
- // TODO: broadcast temp instead of serializing it directly
- // or initialize the model in each executor
- newSentences.treeAggregate((syn0Global, syn1Global, 0, 0))(
- seqOp = (c, v) => (c, v) match {
+
+ for (k <- 1 to numIterations) {
+ val partial = newSentences.mapPartitionsWithIndex { case (idx, iter) =>
+ val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
+ val model = iter.foldLeft((syn0Global, syn1Global, 0, 0)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
- var wc = wordCount
+ var wc = wordCount
if (wordCount - lastWordCount > 10000) {
lwc = wordCount
- alpha = startingAlpha * (1 - parallelism * wordCount.toDouble / (trainWordsCount + 1))
+ // TODO: discount by iteration?
+ alpha =
+ startingAlpha * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))
if (alpha < startingAlpha * 0.0001) alpha = startingAlpha * 0.0001
logInfo("wordCount = " + wordCount + ", alpha = " + alpha)
}
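
The hunk above contains the actual seed fix: each partition and iteration derives its own deterministic generator from the user-settable seed, so runs with the same seed are reproducible and partitions do not share an RNG stream. A standalone sketch of that mixing arithmetic follows, using scala.util.Random in place of Spark's internal XORShiftRandom purely for illustration.

    import scala.util.Random

    // Same mixing as in the patch: seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8),
    // where idx is the partition index and k the iteration number.
    def partitionSeed(seed: Long, idx: Int, k: Int): Long =
      seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)

    val seed = 42L
    val thisPartition  = new Random(partitionSeed(seed, idx = 0, k = 1)) // reproducible per (idx, k)
    val otherPartition = new Random(partitionSeed(seed, idx = 1, k = 1)) // distinct stream
    assert(partitionSeed(seed, 0, 1) != partitionSeed(seed, 1, 1))
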
@@ -269,8 +305,7 @@ class Word2Vec(
var pos = 0
while (pos < sentence.size) {
val word = sentence(pos)
- // TODO: fix random seed
- val b = Random.nextInt(window)
+ val b = random.nextInt(window)
// Train Skip-gram
var a = b
while (a < window * 2 + 1 - b) {
@@ -280,7 +315,7 @@ class Word2Vec(
val lastWord = sentence(c)
val l1 = lastWord * layer1Size
val neu1e = new Array[Float](layer1Size)
- // Hierarchical softmax
+ // Hierarchical softmax
var d = 0
while (d < bcVocab.value(word).codeLen) {
val l2 = bcVocab.value(word).point(d) * layer1Size
@@ -303,44 +338,44 @@ class Word2Vec(
pos += 1
}
(syn0, syn1, lwc, wc)
- },
- combOp = (c1, c2) => (c1, c2) match {
- case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
- val n = syn0_1.length
- val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
- val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
- blas.sscal(n, weight1, syn0_1, 1)
- blas.sscal(n, weight1, syn1_1, 1)
- blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
- blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
- (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
- })
+ }
+ Iterator(model)
+ }
+ val (aggSyn0, aggSyn1, _, _) =
+ partial.treeReduce { case ((syn0_1, syn1_1, lwc_1, wc_1), (syn0_2, syn1_2, lwc_2, wc_2)) =>
+ val n = syn0_1.length
+ val weight1 = 1.0f * wc_1 / (wc_1 + wc_2)
+ val weight2 = 1.0f * wc_2 / (wc_1 + wc_2)
+ blas.sscal(n, weight1, syn0_1, 1)
+ blas.sscal(n, weight1, syn1_1, 1)
+ blas.saxpy(n, weight2, syn0_2, 1, syn0_1, 1)
+ blas.saxpy(n, weight2, syn1_2, 1, syn1_1, 1)
+ (syn0_1, syn1_1, lwc_1 + lwc_2, wc_1 + wc_2)
+ }
syn0Global = aggSyn0
syn1Global = aggSyn1
}
newSentences.unpersist()
- val wordMap = new Array[(String, Array[Float])](vocabSize)
+ val word2VecMap = mutable.HashMap.empty[String, Array[Float]]
var i = 0
while (i < vocabSize) {
val word = bcVocab.value(i).word
val vector = new Array[Float](layer1Size)
Array.copy(syn0Global, i * layer1Size, vector, 0, layer1Size)
- wordMap(i) = (word, vector)
+ word2VecMap += word -> vector
i += 1
}
- val modelRDD = sc.parallelize(wordMap, modelPartitionNum)
- .partitionBy(new HashPartitioner(modelPartitionNum))
- .persist(StorageLevel.MEMORY_AND_DISK)
-
- new Word2VecModel(modelRDD)
+
+ new Word2VecModel(word2VecMap.toMap)
}
}
/**
* Word2Vec model
-*/
-class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Serializable {
+ */
+class Word2VecModel private[mllib] (
+ private val model: Map[String, Array[Float]]) extends Serializable {
private def cosineSimilarity(v1: Array[Float], v2: Array[Float]): Double = {
require(v1.length == v2.length, "Vectors should have the same length")
@@ -357,11 +392,12 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
* @return vector representation of word
*/
def transform(word: String): Vector = {
- val result = model.lookup(word)
- if (result.isEmpty) {
- throw new IllegalStateException(s"$word not in vocabulary")
+ model.get(word) match {
+ case Some(vec) =>
+ Vectors.dense(vec.map(_.toDouble))
+ case None =>
+ throw new IllegalStateException(s"$word not in vocabulary")
}
- else Vectors.dense(result(0).map(_.toDouble))
}
/**
@@ -392,33 +428,13 @@ class Word2VecModel(private val model: RDD[(String, Array[Float])]) extends Seri
*/
def findSynonyms(vector: Vector, num: Int): Array[(String, Double)] = {
require(num > 0, "Number of similar words should > 0")
- val topK = model.map { case(w, vec) =>
- (cosineSimilarity(vector.toArray.map(_.toFloat), vec), w) }
- .sortByKey(ascending = false)
- .take(num + 1)
- .map(_.swap)
- .tail
-
- topK
- }
-}
-
-object Word2Vec{
- /**
- * Train Word2Vec model
- * @param input RDD of words
- * @param size vector dimension
- * @param startingAlpha initial learning rate
- * @param parallelism number of partitions to run Word2Vec (using a small number for accuracy)
- * @param numIterations number of iterations, should be smaller than or equal to parallelism
- * @return Word2Vec model
- */
- def train[S <: Iterable[String]](
- input: RDD[S],
- size: Int,
- startingAlpha: Double,
- parallelism: Int = 1,
- numIterations:Int = 1): Word2VecModel = {
- new Word2Vec(size,startingAlpha, parallelism, numIterations).fit[S](input)
+ // TODO: optimize top-k
+ val fVector = vector.toArray.map(_.toFloat)
+ model.mapValues(vec => cosineSimilarity(fVector, vec))
+ .toSeq
+ .sortBy(- _._2)
+ .take(num + 1)
+ .tail
+ .toArray
}
}
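
Between iterations the per-partition models are combined by the treeReduce shown above as a word-count-weighted average of the weight arrays. A self-contained sketch of that combiner, using the same netlib BLAS calls as the patch; the helper name and sample values are illustrative only.

    import com.github.fommil.netlib.BLAS.{getInstance => blas}

    // Weighted in-place merge of two partial weight arrays, as in the combiner:
    // syn1 := w1 * syn1 + w2 * syn2, with weights proportional to word counts.
    def merge(syn1: Array[Float], wc1: Int, syn2: Array[Float], wc2: Int): Array[Float] = {
      val n = syn1.length
      val w1 = 1.0f * wc1 / (wc1 + wc2)
      val w2 = 1.0f * wc2 / (wc1 + wc2)
      blas.sscal(n, w1, syn1, 1)          // scale the first model by its weight
      blas.saxpy(n, w2, syn2, 1, syn1, 1) // add the weighted second model
      syn1
    }

    // merge(Array(1f, 2f), 10, Array(3f, 4f), 30) yields Array(2.5f, 3.5f)
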
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index b5db39b68a..e34335d89e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -30,29 +30,22 @@ class Word2VecSuite extends FunSuite with LocalSparkContext {
val localDoc = Seq(sentence, sentence)
val doc = sc.parallelize(localDoc)
.map(line => line.split(" ").toSeq)
- val size = 10
- val startingAlpha = 0.025
- val window = 2
- val minCount = 2
- val num = 2
-
- val model = Word2Vec.train(doc, size, startingAlpha)
+ val model = new Word2Vec().setVectorSize(10).setSeed(42L).fit(doc)
val syms = model.findSynonyms("a", 2)
- assert(syms.length == num)
+ assert(syms.length == 2)
assert(syms(0)._1 == "b")
assert(syms(1)._1 == "c")
}
-
test("Word2VecModel") {
val num = 2
- val localModel = Seq(
+ val word2VecMap = Map(
("china", Array(0.50f, 0.50f, 0.50f, 0.50f)),
("japan", Array(0.40f, 0.50f, 0.50f, 0.50f)),
("taiwan", Array(0.60f, 0.50f, 0.50f, 0.50f)),
("korea", Array(0.45f, 0.60f, 0.60f, 0.60f))
)
- val model = new Word2VecModel(sc.parallelize(localModel, 2))
+ val model = new Word2VecModel(word2VecMap)
val syms = model.findSynonyms("china", num)
assert(syms.length == num)
assert(syms(0)._1 == "taiwan")