aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph J.C. Tang <jinntrance@gmail.com>2015-01-30 10:07:26 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-30 10:07:26 -0800
commit54d95758fcbe29a9af0f59673ac0b8a8c72b778e (patch)
treeb5f049abcc78cb82226dac701087e26902e272b4 /mllib
parent254eaa4d350dafe19f1715e80eb816856a126c21 (diff)
downloadspark-54d95758fcbe29a9af0f59673ac0b8a8c72b778e.tar.gz
spark-54d95758fcbe29a9af0f59673ac0b8a8c72b778e.tar.bz2
spark-54d95758fcbe29a9af0f59673ac0b8a8c72b778e.zip
[MLLIB] SPARK-4846: throw a RuntimeException and give users hints to increase the minCount
When the vocabSize\*vectorSize is larger than Int.MaxValue/8, we try to throw a RuntimeException. Because under this circumstance it would definitely throw an OOM when allocating memory to serialize the arrays syn0Global&syn1Global. syn0Global&syn1Global are float arrays. Serializing them should need a byte array of more than 8 times of syn0Global's size. Also if we catch an OOM even if vocabSize\*vectorSize is less than Int.MaxValue/8, we should give users hints to increase the minCount or decrease the vectorSize. Author: Joseph J.C. Tang <jinntrance@gmail.com> Closes #4247 from jinntrance/w2v-fix and squashes the following commits: b5eb71f [Joseph J.C. Tang] throw a RuntimeException and give users hints regarding the vectorSize&minCount
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala7
1 files changed, 7 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index d25a7cd5b4..a3e40200bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -290,6 +290,13 @@ class Word2Vec extends Serializable with Logging {
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
+
+ if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) {
+ throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +
+ " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +
+ "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.")
+ }
+
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)