aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/feature.py
diff options
context:
space:
mode:
authorlewuathe <lewuathe@me.com>2015-04-03 09:49:50 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-03 09:49:50 -0700
commit512a2f191a6b53699373b6588f316b4437050425 (patch)
tree302ae3dad48e4d261de5e87f10c16e0eac1036dd /python/pyspark/mllib/feature.py
parentb52c7f9fc87a1b9a039724e1dac8b30554f75196 (diff)
downloadspark-512a2f191a6b53699373b6588f316b4437050425.tar.gz
spark-512a2f191a6b53699373b6588f316b4437050425.tar.bz2
spark-512a2f191a6b53699373b6588f316b4437050425.zip
[SPARK-6615][MLLIB] Python API for Word2Vec
This is the sub-task of SPARK-6254. Wrap missing method for `Word2Vec` and `Word2VecModel`. Author: lewuathe <lewuathe@me.com> Closes #5296 from Lewuathe/SPARK-6615 and squashes the following commits: f14c304 [lewuathe] Reorder tests 1d326b9 [lewuathe] Merge master e2bedfb [lewuathe] Modify test cases afb866d [lewuathe] [SPARK-6615] Python API for Word2Vec
Diffstat (limited to 'python/pyspark/mllib/feature.py')
-rw-r--r--python/pyspark/mllib/feature.py18
1 files changed, 17 insertions, 1 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 4bfe3014ef..3cda1205e1 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -337,6 +337,12 @@ class Word2VecModel(JavaVectorTransformer):
words, similarity = self.call("findSynonyms", word, num)
return zip(words, similarity)
+ def getVectors(self):
+ """
+ Returns a map of words to their vector representations.
+ """
+ return self.call("getVectors")
+
class Word2Vec(object):
"""
@@ -379,6 +385,7 @@ class Word2Vec(object):
self.numPartitions = 1
self.numIterations = 1
self.seed = random.randint(0, sys.maxint)
+ self.minCount = 5
def setVectorSize(self, vectorSize):
"""
@@ -417,6 +424,14 @@ class Word2Vec(object):
self.seed = seed
return self
+ def setMinCount(self, minCount):
+ """
+ Sets minCount, the minimum number of times a token must appear
+ to be included in the word2vec model's vocabulary (default: 5).
+ """
+ self.minCount = minCount
+ return self
+
def fit(self, data):
"""
Computes the vector representation of each word in vocabulary.
@@ -428,7 +443,8 @@ class Word2Vec(object):
raise TypeError("data should be an RDD of list of string")
jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
float(self.learningRate), int(self.numPartitions),
- int(self.numIterations), long(self.seed))
+ int(self.numIterations), long(self.seed),
+ int(self.minCount))
return Word2VecModel(jmodel)