about | summary | refs | log | tree | commit | diff
path: root/python
diff options
context:
space:
mode:
author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>, 2015-07-02 15:55:16 -0700
committer: Joseph K. Bradley <joseph@databricks.com>, 2015-07-02 15:55:16 -0700
commit: 488bad319a70975733e83c83490240a70beb0c90 (patch)
tree: 52eaf67312340ffbe8ce20b9589c4d82dd268e15 /python
parent: fc7aebd94a3c09657fc4dbded0997ed068304e0a (diff)
download: spark-488bad319a70975733e83c83490240a70beb0c90.tar.gz
spark-488bad319a70975733e83c83490240a70beb0c90.tar.bz2
spark-488bad319a70975733e83c83490240a70beb0c90.zip
[SPARK-7104] [MLLIB] Support model save/load in Python's Word2Vec
Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>

Closes #6821 from yu-iskw/SPARK-7104 and squashes the following commits:

975136b [Yu ISHIKAWA] Organize import
0ef58b6 [Yu ISHIKAWA] Use rmtree, instead of removedirs
cb21653 [Yu ISHIKAWA] Add an explicit type for `Word2VecModelWrapper.save`
1d468ef [Yu ISHIKAWA] [SPARK-7104][MLlib] Support model save/load in Python's Word2Vec
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/feature.py 21
1 files changed, 20 insertions, 1 deletions
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index b5138773fd..f921e3ad1a 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -36,6 +36,7 @@ from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
from pyspark.mllib.linalg import (
Vector, Vectors, DenseVector, SparseVector, _convert_to_vector)
from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.util import JavaLoader, JavaSaveable
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel',
@@ -416,7 +417,7 @@ class IDF(object):
return IDFModel(jmodel)
-class Word2VecModel(JavaVectorTransformer):
+class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader):
"""
class for Word2Vec model
"""
@@ -455,6 +456,12 @@ class Word2VecModel(JavaVectorTransformer):
"""
return self.call("getVectors")
+ @classmethod
+ def load(cls, sc, path):
+ jmodel = sc._jvm.org.apache.spark.mllib.feature \
+ .Word2VecModel.load(sc._jsc.sc(), path)
+ return Word2VecModel(jmodel)
+
@ignore_unicode_prefix
class Word2Vec(object):
@@ -488,6 +495,18 @@ class Word2Vec(object):
>>> syms = model.findSynonyms(vec, 2)
>>> [s[0] for s in syms]
[u'b', u'c']
+
+ >>> import os, tempfile
+ >>> path = tempfile.mkdtemp()
+ >>> model.save(sc, path)
+ >>> sameModel = Word2VecModel.load(sc, path)
+ >>> model.transform("a") == sameModel.transform("a")
+ True
+ >>> from shutil import rmtree
+ >>> try:
+ ... rmtree(path)
+ ... except OSError:
+ ... pass
"""
def __init__(self):
"""