diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-08-06 10:09:58 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-06 10:10:06 -0700 |
commit | e24b976506dd8563e4fe9cc295c756a1ce979e0d (patch) | |
tree | 46266809d58eefc242c31b630e16185230fb480a /python | |
parent | 70b9ed11d08014b96da9d5747c0cebb4927c0459 (diff) | |
download | spark-e24b976506dd8563e4fe9cc295c756a1ce979e0d.tar.gz spark-e24b976506dd8563e4fe9cc295c756a1ce979e0d.tar.bz2 spark-e24b976506dd8563e4fe9cc295c756a1ce979e0d.zip |
[SPARK-9533] [PYSPARK] [ML] Add missing methods in Word2Vec ML
After https://github.com/apache/spark/pull/7263 it is pretty straightforward to add Python wrappers.
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #7930 from MechCoder/spark-9533 and squashes the following commits:
1bea394 [MechCoder] make getVectors a lazy val
5522756 [MechCoder] [SPARK-9533] [PySpark] [ML] Add missing methods in Word2Vec ML
(cherry picked from commit 076ec056818a65216eaf51aa5b3bd8f697c34748)
Signed-off-by: Joseph K. Bradley <joseph@databricks.com>
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/feature.py | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3f04c41ac5..cb4dfa2129 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -15,11 +15,16 @@ # limitations under the License. # +import sys +if sys.version > '3': + basestring = str + from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer from pyspark.mllib.common import inherit_doc +from pyspark.mllib.linalg import _convert_to_vector __all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder', 'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel', @@ -954,6 +959,23 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has >>> sent = ("a b " * 100 + "a c " * 10).split(" ") >>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"]) >>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc) + >>> model.getVectors().show() + +----+--------------------+ + |word| vector| + +----+--------------------+ + | a|[-0.3511952459812...| + | b|[0.29077222943305...| + | c|[0.02315592765808...| + +----+--------------------+ + ... + >>> model.findSynonyms("a", 2).show() + +----+-------------------+ + |word| similarity| + +----+-------------------+ + | b|0.29255685145799626| + | c|-0.5414068302988307| + +----+-------------------+ + ... >>> model.transform(doc).head().model DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276]) """ @@ -1047,6 +1069,24 @@ class Word2VecModel(JavaModel): Model fitted by Word2Vec. """ + def getVectors(self): + """ + Returns the vector representation of the words as a dataframe + with two fields, word and vector. + """ + return self._call_java("getVectors") + + def findSynonyms(self, word, num): + """ + Find "num" number of words closest in similarity to "word". 
+ word can be a string or vector representation. + Returns a dataframe with two fields word and similarity (which + gives the cosine similarity). + """ + if not isinstance(word, basestring): + word = _convert_to_vector(word) + return self._call_java("findSynonyms", word, num) + @inherit_doc class PCA(JavaEstimator, HasInputCol, HasOutputCol): |