aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author: lewuathe <lewuathe@me.com> 2015-04-03 09:49:50 -0700
committer: Xiangrui Meng <meng@databricks.com> 2015-04-03 09:49:50 -0700
commit: 512a2f191a6b53699373b6588f316b4437050425 (patch)
tree: 302ae3dad48e4d261de5e87f10c16e0eac1036dd /mllib
parent: b52c7f9fc87a1b9a039724e1dac8b30554f75196 (diff)
download: spark-512a2f191a6b53699373b6588f316b4437050425.tar.gz
spark-512a2f191a6b53699373b6588f316b4437050425.tar.bz2
spark-512a2f191a6b53699373b6588f316b4437050425.zip
[SPARK-6615][MLLIB] Python API for Word2Vec
This is a sub-task of SPARK-6254. Wrap the missing methods for `Word2Vec` and `Word2VecModel`.

Author: lewuathe <lewuathe@me.com>

Closes #5296 from Lewuathe/SPARK-6615 and squashes the following commits:

f14c304 [lewuathe] Reorder tests
1d326b9 [lewuathe] Merge master
e2bedfb [lewuathe] Modify test cases
afb866d [lewuathe] [SPARK-6615] Python API for Word2Vec
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala 8
1 files changed, 7 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 5995d6df97..6c386cacfb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -476,13 +476,15 @@ private[python] class PythonMLLibAPI extends Serializable {
learningRate: Double,
numPartitions: Int,
numIterations: Int,
- seed: Long): Word2VecModelWrapper = {
+ seed: Long,
+ minCount: Int): Word2VecModelWrapper = {
val word2vec = new Word2Vec()
.setVectorSize(vectorSize)
.setLearningRate(learningRate)
.setNumPartitions(numPartitions)
.setNumIterations(numIterations)
.setSeed(seed)
+ .setMinCount(minCount)
try {
val model = word2vec.fit(dataJRDD.rdd.persist(StorageLevel.MEMORY_AND_DISK_SER))
new Word2VecModelWrapper(model)
@@ -516,6 +518,10 @@ private[python] class PythonMLLibAPI extends Serializable {
val words = result.map(_._1)
List(words, similarity).map(_.asInstanceOf[Object]).asJava
}
+
+ def getVectors: JMap[String, JList[Float]] = {
+ model.getVectors.map({case (k, v) => (k, v.toList.asJava)}).asJava
+ }
}
/**