diff options
author | Yu ISHIKAWA <yuu.ishikawa@gmail.com> | 2015-07-02 15:55:16 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-07-02 15:55:16 -0700 |
commit | 488bad319a70975733e83c83490240a70beb0c90 (patch) | |
tree | 52eaf67312340ffbe8ce20b9589c4d82dd268e15 /python | |
parent | fc7aebd94a3c09657fc4dbded0997ed068304e0a (diff) | |
download | spark-488bad319a70975733e83c83490240a70beb0c90.tar.gz spark-488bad319a70975733e83c83490240a70beb0c90.tar.bz2 spark-488bad319a70975733e83c83490240a70beb0c90.zip |
[SPARK-7104] [MLLIB] Support model save/load in Python's Word2Vec
Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>
Closes #6821 from yu-iskw/SPARK-7104 and squashes the following commits:
975136b [Yu ISHIKAWA] Organize import
0ef58b6 [Yu ISHIKAWA] Use rmtree, instead of removedirs
cb21653 [Yu ISHIKAWA] Add an explicit type for `Word2VecModelWrapper.save`
1d468ef [Yu ISHIKAWA] [SPARK-7104][MLlib] Support model save/load in Python's Word2Vec
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/feature.py | 21 |
1 file changed, 20 insertions, 1 deletion
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py index b5138773fd..f921e3ad1a 100644 --- a/python/pyspark/mllib/feature.py +++ b/python/pyspark/mllib/feature.py @@ -36,6 +36,7 @@ from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper from pyspark.mllib.linalg import ( Vector, Vectors, DenseVector, SparseVector, _convert_to_vector) from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.util import JavaLoader, JavaSaveable __all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler', 'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel', @@ -416,7 +417,7 @@ class IDF(object): return IDFModel(jmodel) -class Word2VecModel(JavaVectorTransformer): +class Word2VecModel(JavaVectorTransformer, JavaSaveable, JavaLoader): """ class for Word2Vec model """ @@ -455,6 +456,12 @@ class Word2VecModel(JavaVectorTransformer): """ return self.call("getVectors") + @classmethod + def load(cls, sc, path): + jmodel = sc._jvm.org.apache.spark.mllib.feature \ + .Word2VecModel.load(sc._jsc.sc(), path) + return Word2VecModel(jmodel) + @ignore_unicode_prefix class Word2Vec(object): @@ -488,6 +495,18 @@ class Word2Vec(object): >>> syms = model.findSynonyms(vec, 2) >>> [s[0] for s in syms] [u'b', u'c'] + + >>> import os, tempfile + >>> path = tempfile.mkdtemp() + >>> model.save(sc, path) + >>> sameModel = Word2VecModel.load(sc, path) + >>> model.transform("a") == sameModel.transform("a") + True + >>> from shutil import rmtree + >>> try: + ... rmtree(path) + ... except OSError: + ... pass """ def __init__(self): """ |