aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2016-01-11 14:48:35 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-11 14:48:35 -0800
commit4f8eefa36bb90812aac61ac7a762c9452de666bf (patch)
tree289fb7093b3c83d92ead9906837b9e471e10aa9d /mllib
parentee4ee02b86be8756a6d895a2e23e80862134a6d3 (diff)
downloadspark-4f8eefa36bb90812aac61ac7a762c9452de666bf.tar.gz
spark-4f8eefa36bb90812aac61ac7a762c9452de666bf.tar.bz2
spark-4f8eefa36bb90812aac61ac7a762c9452de666bf.zip
[SPARK-12685][MLLIB] word2vec trainWordsCount gets overflow
jira: https://issues.apache.org/jira/browse/SPARK-12685 the log of `word2vec` reports trainWordsCount = -785727483 during computation over a large dataset. Update the priority as it will affect the computation process. `alpha = learningRate * (1 - numPartitions * wordCount.toDouble / (trainWordsCount + 1))` Author: Yuhao Yang <hhbyyh@gmail.com> Closes #10627 from hhbyyh/w2voverflow.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala8
1 files changed, 4 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index a7e1b76df6..dc5d070890 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -151,7 +151,7 @@ class Word2Vec extends Serializable with Logging {
/** context words from [-window, window] */
private var window = 5
- private var trainWordsCount = 0
+ private var trainWordsCount = 0L
private var vocabSize = 0
@transient private var vocab: Array[VocabWord] = null
@transient private var vocabHash = mutable.HashMap.empty[String, Int]
@@ -159,13 +159,13 @@ class Word2Vec extends Serializable with Logging {
private def learnVocab(words: RDD[String]): Unit = {
vocab = words.map(w => (w, 1))
.reduceByKey(_ + _)
+ .filter(_._2 >= minCount)
.map(x => VocabWord(
x._1,
x._2,
new Array[Int](MAX_CODE_LENGTH),
new Array[Int](MAX_CODE_LENGTH),
0))
- .filter(_.cn >= minCount)
.collect()
.sortWith((a, b) => a.cn > b.cn)
@@ -179,7 +179,7 @@ class Word2Vec extends Serializable with Logging {
trainWordsCount += vocab(a).cn
a += 1
}
- logInfo("trainWordsCount = " + trainWordsCount)
+ logInfo(s"vocabSize = $vocabSize, trainWordsCount = $trainWordsCount")
}
private def createExpTable(): Array[Float] = {
@@ -332,7 +332,7 @@ class Word2Vec extends Serializable with Logging {
val random = new XORShiftRandom(seed ^ ((idx + 1) << 16) ^ ((-k - 1) << 8))
val syn0Modify = new Array[Int](vocabSize)
val syn1Modify = new Array[Int](vocabSize)
- val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0, 0)) {
+ val model = iter.foldLeft((bcSyn0Global.value, bcSyn1Global.value, 0L, 0L)) {
case ((syn0, syn1, lastWordCount, wordCount), sentence) =>
var lwc = lastWordCount
var wc = wordCount