From 4d2864b2d7cac027cf6bb78fc3cac9bc37534b07 Mon Sep 17 00:00:00 2001 From: Kai Jiang Date: Wed, 24 Feb 2016 23:22:14 -0800 Subject: [SPARK-7106][MLLIB][PYSPARK] Support model save/load in Python's FPGrowth ## What changes were proposed in this pull request? Python API supports mode save/load in FPGrowth JIRA: [https://issues.apache.org/jira/browse/SPARK-7106](https://issues.apache.org/jira/browse/SPARK-7106) ## How was the this patch tested? The patch is tested with Python doctest. Author: Kai Jiang Closes #11321 from vectorijk/spark-7106. --- python/pyspark/mllib/fpm.py | 35 +++++++++++++++++++++++++++++++---- 1 file changed, 31 insertions(+), 4 deletions(-) (limited to 'python') diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py index 7a2d77a4da..5c9706cb8c 100644 --- a/python/pyspark/mllib/fpm.py +++ b/python/pyspark/mllib/fpm.py @@ -21,14 +21,15 @@ from collections import namedtuple from pyspark import SparkContext, since from pyspark.rdd import ignore_unicode_prefix -from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc +from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc +from pyspark.mllib.util import JavaSaveable, JavaLoader, inherit_doc __all__ = ['FPGrowth', 'FPGrowthModel', 'PrefixSpan', 'PrefixSpanModel'] @inherit_doc @ignore_unicode_prefix -class FPGrowthModel(JavaModelWrapper): +class FPGrowthModel(JavaModelWrapper, JavaSaveable, JavaLoader): """ .. note:: Experimental @@ -40,6 +41,11 @@ class FPGrowthModel(JavaModelWrapper): >>> model = FPGrowth.train(rdd, 0.6, 2) >>> sorted(model.freqItemsets().collect()) [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ... + >>> model_path = temp_path + "/fpm" + >>> model.save(sc, model_path) + >>> sameModel = FPGrowthModel.load(sc, model_path) + >>> sorted(model.freqItemsets().collect()) == sorted(sameModel.freqItemsets().collect()) + True .. versionadded:: 1.4.0 """ @@ -51,6 +57,16 @@ class FPGrowthModel(JavaModelWrapper): """ return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1]))) + @classmethod + @since("2.0.0") + def load(cls, sc, path): + """ + Load a model from the given path. + """ + model = cls._load_java(sc, path) + wrapper = sc._jvm.FPGrowthModelWrapper(model) + return FPGrowthModel(wrapper) + class FPGrowth(object): """ @@ -170,8 +186,19 @@ def _test(): import pyspark.mllib.fpm globs = pyspark.mllib.fpm.__dict__.copy() globs['sc'] = SparkContext('local[4]', 'PythonTest') - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - globs['sc'].stop() + import tempfile + + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + globs['sc'].stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) -- cgit v1.2.3