diff options
-rw-r--r-- | python/pyspark/ml/clustering.py | 25 | ||||
-rw-r--r-- | python/pyspark/ml/recommendation.py | 25 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 25 |
3 files changed, 42 insertions, 33 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 91278d570a..611b919049 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -70,25 +70,18 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol True >>> rows[2].prediction == rows[3].prediction True - >>> import os, tempfile - >>> path = tempfile.mkdtemp() - >>> kmeans_path = path + "/kmeans" + >>> kmeans_path = temp_path + "/kmeans" >>> kmeans.save(kmeans_path) >>> kmeans2 = KMeans.load(kmeans_path) >>> kmeans2.getK() 2 - >>> model_path = path + "/kmeans_model" + >>> model_path = temp_path + "/kmeans_model" >>> model.save(model_path) >>> model2 = KMeansModel.load(model_path) >>> model.clusterCenters()[0] == model2.clusterCenters()[0] array([ True, True], dtype=bool) >>> model.clusterCenters()[1] == model2.clusterCenters()[1] array([ True, True], dtype=bool) - >>> from shutil import rmtree - >>> try: - ... rmtree(path) - ... except OSError: - ... pass .. versionadded:: 1.5.0 """ @@ -310,7 +303,17 @@ if __name__ == "__main__": sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index ef9448855e..2b605e5c50 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -82,14 +82,12 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha Row(user=1, item=0, prediction=2.6258413791656494) >>> predictions[2] Row(user=2, item=0, prediction=-1.5018409490585327) - >>> import os, tempfile - >>> path = tempfile.mkdtemp() - >>> als_path = path + "/als" + >>> als_path = temp_path + "/als" >>> als.save(als_path) >>> als2 = ALS.load(als_path) >>> als.getMaxIter() 5 - >>> model_path = path + "/als_model" + >>> model_path = temp_path + "/als_model" >>> model.save(model_path) >>> model2 = ALSModel.load(model_path) >>> model.rank == model2.rank @@ -98,11 +96,6 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha True >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect()) True - >>> from shutil import rmtree - >>> try: - ... rmtree(path) - ... except OSError: - ... pass .. versionadded:: 1.4.0 """ @@ -340,7 +333,17 @@ if __name__ == "__main__": sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 20dc6c2db9..de4a751a54 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -68,25 +68,18 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction Traceback (most recent call last): ... TypeError: Method setParams forces keyword arguments. - >>> import os, tempfile - >>> path = tempfile.mkdtemp() - >>> lr_path = path + "/lr" + >>> lr_path = temp_path + "/lr" >>> lr.save(lr_path) >>> lr2 = LinearRegression.load(lr_path) >>> lr2.getMaxIter() 5 - >>> model_path = path + "/lr_model" + >>> model_path = temp_path + "/lr_model" >>> model.save(model_path) >>> model2 = LinearRegressionModel.load(model_path) >>> model.coefficients[0] == model2.coefficients[0] True >>> model.intercept == model2.intercept True - >>> from shutil import rmtree - >>> try: - ... rmtree(path) - ... except OSError: - ... pass .. versionadded:: 1.4.0 """ @@ -850,7 +843,17 @@ if __name__ == "__main__": sqlContext = SQLContext(sc) globs['sc'] = sc globs['sqlContext'] = sqlContext - (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) - sc.stop() + import tempfile + temp_path = tempfile.mkdtemp() + globs['temp_path'] = temp_path + try: + (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + finally: + from shutil import rmtree + try: + rmtree(temp_path) + except OSError: + pass if failure_count: exit(-1) |