aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorKai Jiang <jiangkai@gmail.com>2016-02-24 23:22:14 -0800
committerXiangrui Meng <meng@databricks.com>2016-02-24 23:22:14 -0800
commit4d2864b2d7cac027cf6bb78fc3cac9bc37534b07 (patch)
tree2780785bc19ac493fdbaf719bf7283b974aae38b /python
parent13ce10e95401b21fa40ca0bb27ebf9a0bfffe70f (diff)
downloadspark-4d2864b2d7cac027cf6bb78fc3cac9bc37534b07.tar.gz
spark-4d2864b2d7cac027cf6bb78fc3cac9bc37534b07.tar.bz2
spark-4d2864b2d7cac027cf6bb78fc3cac9bc37534b07.zip
[SPARK-7106][MLLIB][PYSPARK] Support model save/load in Python's FPGrowth
## What changes were proposed in this pull request? The Python API now supports model save/load in FPGrowth. JIRA: [https://issues.apache.org/jira/browse/SPARK-7106](https://issues.apache.org/jira/browse/SPARK-7106) ## How was this patch tested? The patch is tested with a Python doctest. Author: Kai Jiang <jiangkai@gmail.com> Closes #11321 from vectorijk/spark-7106.
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/fpm.py35
1 files changed, 31 insertions, 4 deletions
diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
index 7a2d77a4da..5c9706cb8c 100644
--- a/python/pyspark/mllib/fpm.py
+++ b/python/pyspark/mllib/fpm.py
@@ -21,14 +21,15 @@ from collections import namedtuple
from pyspark import SparkContext, since
from pyspark.rdd import ignore_unicode_prefix
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
+from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc
__all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel']
@inherit_doc
@ignore_unicode_prefix
-class FPGrowthModel(JavaModelWrapper):
+class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader):
"""
.. note:: Experimental
@@ -40,6 +41,11 @@ class FPGrowthModel(JavaModelWrapper):
>>> model = FPGrowth.train(rdd, 0.6, 2)
>>> sorted(model.freqItemsets().collect())
[FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
+ >>> model_path = temp_path + "/fpm"
+ >>> model.save(sc, model_path)
+ >>> sameModel = FPGrowthModel.load(sc, model_path)
+ >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect())
+ True
.. versionadded:: 1.4.0
"""
@@ -51,6 +57,16 @@ class FPGrowthModel(JavaModelWrapper):
"""
return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1])))
+ @classmethod
+ @since("2.0.0")
+ def load(cls, sc, path):
+ """
+ Load a model from the given path.
+ """
+ model = cls._load_java(sc, path)
+ wrapper = sc._jvm.FPGrowthModelWrapper(model)
+ return FPGrowthModel(wrapper)
+
class FPGrowth(object):
"""
@@ -170,8 +186,19 @@ def _test():
import pyspark.mllib.fpm
globs = pyspark.mllib.fpm.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest')
- (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
- globs['sc'].stop()
+ import tempfile
+
+ temp_path = tempfile.mkdtemp()
+ globs['temp_path'] = temp_path
+ try:
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ finally:
+ from shutil import rmtree
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
if failure_count:
exit(-1)