aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r--python/pyspark/ml/feature.py40
1 files changed, 40 insertions, 0 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 3f04c41ac5..cb4dfa2129 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -15,11 +15,16 @@
# limitations under the License.
#
+import sys
+if sys.version > '3':
+ basestring = str
+
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import *
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
from pyspark.mllib.common import inherit_doc
+from pyspark.mllib.linalg import _convert_to_vector
__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
@@ -954,6 +959,23 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
>>> sent = ("a b " * 100 + "a c " * 10).split(" ")
>>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"])
>>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc)
+ >>> model.getVectors().show()
+ +----+--------------------+
+ |word| vector|
+ +----+--------------------+
+ | a|[-0.3511952459812...|
+ | b|[0.29077222943305...|
+ | c|[0.02315592765808...|
+ +----+--------------------+
+ ...
+ >>> model.findSynonyms("a", 2).show()
+ +----+-------------------+
+ |word| similarity|
+ +----+-------------------+
+ | b|0.29255685145799626|
+ | c|-0.5414068302988307|
+ +----+-------------------+
+ ...
>>> model.transform(doc).head().model
DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276])
"""
@@ -1047,6 +1069,24 @@ class Word2VecModel(JavaModel):
Model fitted by Word2Vec.
"""
+ def getVectors(self):
+ """
+ Returns the vector representation of the words as a dataframe
+ with two fields, word and vector.
+ """
+ return self._call_java("getVectors")
+
+ def findSynonyms(self, word, num):
+ """
+ Find "num" number of words closest in similarity to "word".
+ word can be a string or vector representation.
+ Returns a dataframe with two fields word and similarity (which
+ gives the cosine similarity).
+ """
+ if not isinstance(word, basestring):
+ word = _convert_to_vector(word)
+ return self._call_java("findSynonyms", word, num)
+
@inherit_doc
class PCA(JavaEstimator, HasInputCol, HasOutputCol):