aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala7
1 files changed, 7 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index d25a7cd5b4..a3e40200bc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -290,6 +290,13 @@ class Word2Vec extends Serializable with Logging {
val newSentences = sentences.repartition(numPartitions).cache()
val initRandom = new XORShiftRandom(seed)
+
+ if (vocabSize.toLong * vectorSize * 8 >= Int.MaxValue) {
+ throw new RuntimeException("Please increase minCount or decrease vectorSize in Word2Vec" +
+ " to avoid an OOM. You are highly recommended to make your vocabSize*vectorSize, " +
+ "which is " + vocabSize + "*" + vectorSize + " for now, less than `Int.MaxValue/8`.")
+ }
+
val syn0Global =
Array.fill[Float](vocabSize * vectorSize)((initRandom.nextFloat() - 0.5f) / vectorSize)
val syn1Global = new Array[Float](vocabSize * vectorSize)