diff options
author | Holden Karau <holden@us.ibm.com> | 2016-02-20 09:07:19 +0000 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-02-20 09:07:19 +0000 |
commit | 9ca79c1ece5ad139719e4eea9f7d1b59aed01b20 (patch) | |
tree | 633e6e65cb71feaede920fc405e1da263807ad2d /python/pyspark/ml/regression.py | |
parent | dfb2ae2f141960c10200a870ed21583e6af5c536 (diff) | |
download | spark-9ca79c1ece5ad139719e4eea9f7d1b59aed01b20.tar.gz spark-9ca79c1ece5ad139719e4eea9f7d1b59aed01b20.tar.bz2 spark-9ca79c1ece5ad139719e4eea9f7d1b59aed01b20.zip |
[SPARK-13302][PYSPARK][TESTS] Move the temp file creation and cleanup outside of the doctests
Some of the new doctests in ml/clustering.py have a lot of setup code, move the setup code to the general test init to keep the doctest more example-style looking.
In part this is a follow up to https://github.com/apache/spark/pull/10999
Note that the same pattern is followed in regression & recommendation - might as well clean up all three at the same time.
Author: Holden Karau <holden@us.ibm.com>
Closes #11197 from holdenk/SPARK-13302-cleanup-doctests-in-ml-clustering.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r-- | python/pyspark/ml/regression.py | 25 |
1 files changed, 14 insertions, 11 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 20dc6c2db9..de4a751a54 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -68,25 +68,18 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. - >>> import os, tempfile - >>> path = tempfile.mkdtemp() - >>> lr_path = path + "/lr" + >>> lr_path = temp_path + "/lr" >>> lr.save(lr_path) >>> lr2 = LinearRegression.load(lr_path) >>> lr2.getMaxIter() 5 - >>> model_path = path + "/lr_model" + >>> model_path = temp_path + "/lr_model" >>> model.save(model_path) >>> model2 = LinearRegressionModel.load(model_path) >>> model.coefficients[0] == model2.coefficients[0] True >>> model.intercept == model2.intercept True - >>> from shutil import rmtree - >>> try: - ... rmtree(path) - ... except OSError: - ... pass .. versionadded:: 1.4.0 """ @@ -850,7 +843,17 @@ if __name__ == "__main__": sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) |