authorMechCoder <manojkumarsivaraj334@gmail.com>2015-08-06 10:09:58 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-06 10:09:58 -0700
commit076ec056818a65216eaf51aa5b3bd8f697c34748 (patch)
tree681bd1b6621e4de217d71d5b102edc6382a0dd5f /python
parentc5c6aded641048a3e66ac79d9e84d34e4b1abae7 (diff)
[SPARK-9533] [PYSPARK] [ML] Add missing methods in Word2Vec ML
After https://github.com/apache/spark/pull/7263 it is pretty straightforward to add Python wrappers.

Author: MechCoder <manojkumarsivaraj334@gmail.com>

Closes #7930 from MechCoder/spark-9533 and squashes the following commits:

1bea394 [MechCoder] make getVectors a lazy val
5522756 [MechCoder] [SPARK-9533] [PySpark] [ML] Add missing methods in Word2Vec ML
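For context, a minimal usage sketch of the two model methods this patch exposes in Python. The training setup mirrors the doctest in the diff below; an existing SQLContext named sqlContext is assumed and is not part of the change:

from pyspark.ml.feature import Word2Vec

# Same toy corpus as the doctest below.
sent = ("a b " * 100 + "a c " * 10).split(" ")
doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"])
model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc)

# New in this patch: per-word vectors as a DataFrame with columns "word" and "vector".
model.getVectors().show()

# New in this patch: the 2 words most similar to "a", with their cosine similarities.
model.findSynonyms("a", 2).show()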
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/feature.py | 40
1 file changed, 40 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 3f04c41ac5..cb4dfa2129 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -15,11 +15,16 @@
# limitations under the License.
#
+import sys
+if sys.version > '3':
+ basestring = str
+
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import *
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer
from pyspark.mllib.common import inherit_doc
+from pyspark.mllib.linalg import _convert_to_vector
__all__ = ['Binarizer', 'HashingTF', 'IDF', 'IDFModel', 'NGram', 'Normalizer', 'OneHotEncoder',
'PolynomialExpansion', 'RegexTokenizer', 'StandardScaler', 'StandardScalerModel',
@@ -954,6 +959,23 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
>>> sent = ("a b " * 100 + "a c " * 10).split(" ")
>>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"])
>>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc)
+ >>> model.getVectors().show()
+ +----+--------------------+
+ |word| vector|
+ +----+--------------------+
+ | a|[-0.3511952459812...|
+ | b|[0.29077222943305...|
+ | c|[0.02315592765808...|
+ +----+--------------------+
+ ...
+ >>> model.findSynonyms("a", 2).show()
+ +----+-------------------+
+ |word| similarity|
+ +----+-------------------+
+ | b|0.29255685145799626|
+ | c|-0.5414068302988307|
+ +----+-------------------+
+ ...
>>> model.transform(doc).head().model
DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276])
"""
@@ -1047,6 +1069,24 @@ class Word2VecModel(JavaModel):
Model fitted by Word2Vec.
"""
+ def getVectors(self):
+ """
+ Returns the vector representation of the words as a dataframe
+ with two fields, word and vector.
+ """
+ return self._call_java("getVectors")
+
+ def findSynonyms(self, word, num):
+ """
+ Find "num" number of words closest in similarity to "word".
+ word can be a string or vector representation.
+ Returns a dataframe with two fields word and similarity (which
+ gives the cosine similarity).
+ """
+ if not isinstance(word, basestring):
+ word = _convert_to_vector(word)
+ return self._call_java("findSynonyms", word, num)
+
@inherit_doc
class PCA(JavaEstimator, HasInputCol, HasOutputCol):
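Because the new findSynonyms converts any non-string argument with _convert_to_vector before calling into the JVM, the query can also be a vector rather than a word. A rough sketch, assuming the fitted model from the doctest above; the vector values are illustrative (taken from the doctest output):

from pyspark.mllib.linalg import Vectors

# Query by vector instead of by word; the patch converts non-string input
# via _convert_to_vector before delegating to the Scala implementation.
vec = Vectors.dense([-0.0422, -0.5138, -0.2546, 0.6885, 0.276])
model.findSynonyms(vec, 2).show()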