diff options
Diffstat (limited to 'python/pyspark/mllib/feature.py')
-rw-r--r-- | python/pyspark/mllib/feature.py | 18 |
1 files changed, 17 insertions, 1 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index 4bfe3014ef..3cda1205e1 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -337,6 +337,12 @@ class Word2VecModel(JavaVectorTransformer): words, similarity = self.call("findSynonyms", word, num) return zip(words, similarity) + def getVectors(self): + """ + Returns a map of words to their vector representations. + """ + return self.call("getVectors") + class Word2Vec(object): """ @@ -379,6 +385,7 @@ class Word2Vec(object): self.numPartitions = 1 self.numIterations = 1 self.seed = random.randint(0, sys.maxint) + self.minCount = 5 def setVectorSize(self, vectorSize): """ @@ -417,6 +424,14 @@ class Word2Vec(object): self.seed = seed return self + def setMinCount(self, minCount): + """ + Sets minCount, the minimum number of times a token must appear + to be included in the word2vec model's vocabulary (default: 5). + """ + self.minCount = minCount + return self + def fit(self, data): """ Computes the vector representation of each word in vocabulary. @@ -428,7 +443,8 @@ class Word2Vec(object): raise TypeError("data should be an RDD of list of string") jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize), float(self.learningRate), int(self.numPartitions), - int(self.numIterations), long(self.seed)) + int(self.numIterations), long(self.seed), + int(self.minCount)) return Word2VecModel(jmodel) |