aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/clustering.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/clustering.py')
-rw-r--r--python/pyspark/ml/clustering.py25
1 files changed, 14 insertions, 11 deletions
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 91278d570a..611b919049 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -70,25 +70,18 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol
True
>>> rows[2].prediction == rows[3].prediction
True
- >>> import os, tempfile
- >>> path = tempfile.mkdtemp()
- >>> kmeans_path = path + "/kmeans"
+ >>> kmeans_path = temp_path + "/kmeans"
>>> kmeans.save(kmeans_path)
>>> kmeans2 = KMeans.load(kmeans_path)
>>> kmeans2.getK()
2
- >>> model_path = path + "/kmeans_model"
+ >>> model_path = temp_path + "/kmeans_model"
>>> model.save(model_path)
>>> model2 = KMeansModel.load(model_path)
>>> model.clusterCenters()[0] == model2.clusterCenters()[0]
array([ True, True], dtype=bool)
>>> model.clusterCenters()[1] == model2.clusterCenters()[1]
array([ True, True], dtype=bool)
- >>> from shutil import rmtree
- >>> try:
- ... rmtree(path)
- ... except OSError:
- ... pass
.. versionadded:: 1.5.0
"""
@@ -310,7 +303,17 @@ if __name__ == "__main__":
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
- (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
- sc.stop()
+ import tempfile
+ temp_path = tempfile.mkdtemp()
+ globs['temp_path'] = temp_path
+ try:
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ sc.stop()
+ finally:
+ from shutil import rmtree
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
if failure_count:
exit(-1)