author     Holden Karau <holden@us.ibm.com>  2016-02-20 09:07:19 +0000
committer  Sean Owen <sowen@cloudera.com>    2016-02-20 09:07:19 +0000
commit     9ca79c1ece5ad139719e4eea9f7d1b59aed01b20 (patch)
tree       633e6e65cb71feaede920fc405e1da263807ad2d
parent     dfb2ae2f141960c10200a870ed21583e6af5c536 (diff)
[SPARK-13302][PYSPARK][TESTS] Move the temp file creation and cleanup outside of the doctests
Some of the new doctests in ml/clustering.py have a lot of setup code; move that setup into the general test init so the doctests stay example-style. In part this is a follow-up to https://github.com/apache/spark/pull/10999. The same pattern is followed in regression & recommendation, so this cleans up all three at the same time.

Author: Holden Karau <holden@us.ibm.com>

Closes #11197 from holdenk/SPARK-13302-cleanup-doctests-in-ml-clustering.
-rw-r--r--  python/pyspark/ml/clustering.py      25
-rw-r--r--  python/pyspark/ml/recommendation.py  25
-rw-r--r--  python/pyspark/ml/regression.py      25

3 files changed, 42 insertions(+), 33 deletions(-)
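The change has the same shape in all three files: the doctests now refer to a temp_path name, and each module's __main__ test harness creates that directory once, hands it to the doctests through doctest.testmod's globs, and removes it in a finally block. Below is a minimal standalone sketch of that pattern using plain doctest with no Spark dependencies; save_model here is a made-up stand-in for the real estimators' save/load methods:

    import doctest
    import os


    def save_model(path):
        """Persist a placeholder model file.

        >>> model_path = temp_path + "/model"
        >>> save_model(model_path)
        >>> os.path.isfile(model_path)
        True
        """
        with open(path, "w") as f:
            f.write("model")


    if __name__ == "__main__":
        import tempfile
        from shutil import rmtree
        globs = globals().copy()
        # Create the scratch directory once and expose it to every doctest.
        temp_path = tempfile.mkdtemp()
        globs['temp_path'] = temp_path
        try:
            (failure_count, test_count) = doctest.testmod(globs=globs)
        finally:
            # Clean up even if running the doctests raises.
            try:
                rmtree(temp_path)
            except OSError:
                pass
        if failure_count:
            exit(-1)

Keeping the setup and cleanup in the harness means the doctests themselves stay focused on the save/load behaviour being demonstrated, which is the motivation stated in the commit message.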
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 91278d570a..611b919049 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -70,25 +70,18 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
     True
     >>> rows[2].prediction == rows[3].prediction
     True
-    >>> import os, tempfile
-    >>> path = tempfile.mkdtemp()
-    >>> kmeans_path = path + "/kmeans"
+    >>> kmeans_path = temp_path + "/kmeans"
     >>> kmeans.save(kmeans_path)
     >>> kmeans2 = KMeans.load(kmeans_path)
     >>> kmeans2.getK()
     2
-    >>> model_path = path + "/kmeans_model"
+    >>> model_path = temp_path + "/kmeans_model"
     >>> model.save(model_path)
     >>> model2 = KMeansModel.load(model_path)
     >>> model.clusterCenters()[0] == model2.clusterCenters()[0]
     array([ True,  True], dtype=bool)
     >>> model.clusterCenters()[1] == model2.clusterCenters()[1]
     array([ True,  True], dtype=bool)
-    >>> from shutil import rmtree
-    >>> try:
-    ...     rmtree(path)
-    ... except OSError:
-    ...     pass

     .. versionadded:: 1.5.0
     """
@@ -310,7 +303,17 @@ if __name__ == "__main__":
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    import tempfile
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        sc.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
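An aside on the expected output kept above: the comparison model.clusterCenters()[0] == model2.clusterCenters()[0] produces array([ True,  True], dtype=bool) because the centers are numpy arrays and == on numpy arrays compares elementwise rather than collapsing to a single boolean. A small illustration (the repr in the doctest matches the numpy version of that era; newer numpy renders bool arrays slightly differently):

    import numpy as np

    a = np.array([0.5, 0.5])
    b = np.array([0.5, 0.5])
    print(a == b)                # elementwise -> [ True  True]
    print(np.array_equal(a, b))  # collapsed to a single True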
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index ef9448855e..2b605e5c50 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -82,14 +82,12 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     Row(user=1, item=0, prediction=2.6258413791656494)
     >>> predictions[2]
     Row(user=2, item=0, prediction=-1.5018409490585327)
-    >>> import os, tempfile
-    >>> path = tempfile.mkdtemp()
-    >>> als_path = path + "/als"
+    >>> als_path = temp_path + "/als"
     >>> als.save(als_path)
     >>> als2 = ALS.load(als_path)
     >>> als.getMaxIter()
     5
-    >>> model_path = path + "/als_model"
+    >>> model_path = temp_path + "/als_model"
     >>> model.save(model_path)
     >>> model2 = ALSModel.load(model_path)
     >>> model.rank == model2.rank
@@ -98,11 +96,6 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
     True
     >>> sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect())
     True
-    >>> from shutil import rmtree
-    >>> try:
-    ...     rmtree(path)
-    ... except OSError:
-    ...     pass

     .. versionadded:: 1.4.0
     """
@@ -340,7 +333,17 @@ if __name__ == "__main__":
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    import tempfile
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        sc.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
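Also worth noting in the ALS doctests kept above: the factor DataFrames are compared via sorted(model.itemFactors.collect()) == sorted(model2.itemFactors.collect()) rather than comparing the collected rows directly, since collect() makes no guarantee about row order. The same idea with plain tuples:

    # Rows may come back in any order, so sort into a canonical order first.
    run1 = [(2, 0.1), (1, 0.4)]
    run2 = [(1, 0.4), (2, 0.1)]
    print(run1 == run2)                  # False: same contents, different order
    print(sorted(run1) == sorted(run2))  # True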
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 20dc6c2db9..de4a751a54 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -68,25 +68,18 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
     Traceback (most recent call last):
         ...
     TypeError: Method setParams forces keyword arguments.
-    >>> import os, tempfile
-    >>> path = tempfile.mkdtemp()
-    >>> lr_path = path + "/lr"
+    >>> lr_path = temp_path + "/lr"
     >>> lr.save(lr_path)
     >>> lr2 = LinearRegression.load(lr_path)
     >>> lr2.getMaxIter()
     5
-    >>> model_path = path + "/lr_model"
+    >>> model_path = temp_path + "/lr_model"
     >>> model.save(model_path)
     >>> model2 = LinearRegressionModel.load(model_path)
     >>> model.coefficients[0] == model2.coefficients[0]
     True
     >>> model.intercept == model2.intercept
     True
-    >>> from shutil import rmtree
-    >>> try:
-    ...     rmtree(path)
-    ... except OSError:
-    ...     pass

     .. versionadded:: 1.4.0
     """
@@ -850,7 +843,17 @@ if __name__ == "__main__":
     sqlContext = SQLContext(sc)
     globs['sc'] = sc
     globs['sqlContext'] = sqlContext
-    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
-    sc.stop()
+    import tempfile
+    temp_path = tempfile.mkdtemp()
+    globs['temp_path'] = temp_path
+    try:
+        (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+        sc.stop()
+    finally:
+        from shutil import rmtree
+        try:
+            rmtree(temp_path)
+        except OSError:
+            pass
     if failure_count:
         exit(-1)
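Each of these __main__ harnesses runs when its module is executed directly (for example, python python/pyspark/ml/regression.py), so the scratch directory lives for exactly one doctest run per module, and exit(-1) turns any doctest failure into a non-zero exit status.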