aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala16
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala19
2 files changed, 31 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index be12d45286..b693f3c8e4 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -604,13 +604,21 @@ object Word2VecModel extends Loader[Word2VecModel] {
val vectorSize = model.values.head.size
val numWords = model.size
- val metadata = compact(render
- (("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
- ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
+ val metadata = compact(render(
+ ("class" -> classNameV1_0) ~ ("version" -> formatVersionV1_0) ~
+ ("vectorSize" -> vectorSize) ~ ("numWords" -> numWords)))
sc.parallelize(Seq(metadata), 1).saveAsTextFile(Loader.metadataPath(path))
+ // We want to partition the model in partitions of size 32MB
+ val partitionSize = (1L << 25)
+ // We calculate the approximate size of the model
+ // We only calculate the array size, not considering
+ // the string size, the formula is:
+ // floatSize * numWords * vectorSize
+ val approxSize = 4L * numWords * vectorSize
+ val nPartitions = ((approxSize / partitionSize) + 1).toInt
val dataArray = model.toSeq.map { case (w, v) => Data(w, v) }
- sc.parallelize(dataArray.toSeq, 1).toDF().write.parquet(Loader.dataPath(path))
+ sc.parallelize(dataArray.toSeq, nPartitions).toDF().write.parquet(Loader.dataPath(path))
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
index a864eec460..37d01e2876 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/Word2VecSuite.scala
@@ -92,4 +92,23 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+
+ test("big model load / save") {
+ // create a model bigger than 32MB since 9000 * 1000 * 4 > 2^25
+ val word2VecMap = Map((0 to 9000).map(i => s"$i" -> Array.fill(1000)(0.1f)): _*)
+ val model = new Word2VecModel(word2VecMap)
+
+ val tempDir = Utils.createTempDir()
+ val path = tempDir.toURI.toString
+
+ try {
+ model.save(sc, path)
+ val sameModel = Word2VecModel.load(sc, path)
+ assert(sameModel.getVectors.mapValues(_.toSeq) === model.getVectors.mapValues(_.toSeq))
+ } finally {
+ Utils.deleteRecursively(tempDir)
+ }
+ }
+
+
}