diff options
author | Joseph J.C. Tang <jinntrance@gmail.com> | 2015-01-30 10:07:26 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-01-30 10:07:26 -0800 |
commit | 54d95758fcbe29a9af0f59673ac0b8a8c72b778e (patch) | |
tree | b5f049abcc78cb82226dac701087e26902e272b4 | |
parent | 254eaa4d350dafe19f1715e80eb816856a126c21 (diff) | |
download | spark-54d95758fcbe29a9af0f59673ac0b8a8c72b778e.tar.gz spark-54d95758fcbe29a9af0f59673ac0b8a8c72b778e.tar.bz2 spark-54d95758fcbe29a9af0f59673ac0b8a8c72b778e.zip |
[MLLIB] SPARK-4846: throw a RuntimeException and give users hints to increase the minCount
When the vocabSize\*vectorSize is larger than Int.MaxValue/8, we try to throw a RuntimeException. Because under this circumstance it would definitely throw an OOM when allocating memory to serialize the arrays syn0Global&syn1Global. syn0Global&syn1Global are float arrays. Serializing them should need a byte array of more than 8 times of syn0Global's size.
Also if we catch an OOM even if vocabSize\*vectorSize is less than Int.MaxValue/8, we should give users hints to increase the minCount or decrease the vectorSize.
Author: Joseph J.C. Tang <jinntrance@gmail.com>
Closes #4247 from jinntrance/w2v-fix and squashes the following commits:
b5eb71f [Joseph J.C. Tang] throw a RuntimeException and give users hints regarding the vectorSize&minCount
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala | 7 |
1 files changed, 7 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index d25a7cd5b4..a3e40200bc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -290,6 +290,13 @@ class Word2Vec extends Serializable with Logging { val newSentences = sentences.repartition(numPartitions).cache() val initRandom = new XORShiftRandom(seed) + + if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) { + throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" + + " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " + + "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.") + } + val syn0Global = Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize) val syn1Global = new Array[Float](vocabSize * vectorSize) |