diff options
author | Yuhao Yang <hhbyyh@gmail.com> | 2016-01-11 14:48:35 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-01-11 14:48:35 -0800 |
commit | 4f8eefa36bb90812aac61ac7a762c9452de666bf (patch) | |
tree | 289fb7093b3c83d92ead9906837b9e471e10aa9d /mllib | |
parent | ee4ee02b86be8756a6d895a2e23e80862134a6d3 (diff) | |
download | spark-4f8eefa36bb90812aac61ac7a762c9452de666bf.tar.gz spark-4f8eefa36bb90812aac61ac7a762c9452de666bf.tar.bz2 spark-4f8eefa36bb90812aac61ac7a762c9452de666bf.zip |
[SPARK-12685][MLLIB] word2vec trainWordsCount gets overflow
jira: https://issues.apache.org/jira/browse/SPARK-12685
the log of `word2vec` reports
trainWordsCount = -785727483
during computation over a large dataset.
Update the priority as it will affect the computation process.
`alpha = learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))`
Author: Yuhao Yang <hhbyyh@gmail.com>
Closes #10627 from hhbyyh/w2voverflow.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 8 |
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index a7e1b76df6..dc5d070890 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -151,7 +151,7 @@ class Word2Vec extends Serializable with Logging { /** context words from [-window, window] */ private var window = 5 - private var trainWordsCount = 0 + private var trainWordsCount = 0L private var vocabSize = 0 @transient private var vocab: Array[VocabWord] = null @transient private var vocabHash = mutable.HashMap.empty[String, Int] @@ -159,13 +159,13 @@ class Word2Vec extends Serializable with Logging { private def learnVocab(words: RDD[String]): Unit = { vocab = words.map(w => (w, 1)) .reduceByKey(_ + _) + .filter(_._2 >= minCount) .map(x => VocabWord( x._1, x._2, new Array[Int](MAX_CODE_LENGTH), new Array[Int](MAX_CODE_LENGTH), 0)) - .filter(_.cn >= minCount) .collect() .sortWith((a, b) => a.cn > b.cn) @@ -179,7 +179,7 @@ class Word2Vec extends Serializable with Logging { trainWordsCount += vocab(a).cn a += 1 } - logInfo("trainWordsCount = " + trainWordsCount) + logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount") } private def createExpTable(): Array[Float] = { @@ -332,7 +332,7 @@ class Word2Vec extends Serializable with Logging { val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8)) val syn0Modify = new Array[Int](vocabSize) val syn1Modify = new Array[Int](vocabSize) - val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) { + val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) { case ((syn0, syn1, lastWordCount, wordCount), sentence) => var lwc = lastWordCount var wc = wordCount |