about · summary · refs · log · tree · commit · diff
path: root/python
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-10-14 04:17:03 -0700
committerYanbo Liang <ybliang8@gmail.com>2016-10-14 04:17:03 -0700
commit1db8feab8c564053c05e8bdc1a7f5026fd637d4f (patch)
tree639dc612dbe8fa518e62c81e6d6da2317159e88a /python
parent2fb12b0a33deeeadfac451095f64dea6c967caac (diff)
downloadspark-1db8feab8c564053c05e8bdc1a7f5026fd637d4f.tar.gz
spark-1db8feab8c564053c05e8bdc1a7f5026fd637d4f.tar.bz2
spark-1db8feab8c564053c05e8bdc1a7f5026fd637d4f.zip
[SPARK-15402][ML][PYSPARK] PySpark ml.evaluation should support save/load
## What changes were proposed in this pull request?

Since ```ml.evaluation``` has supported save/load at Scala side, supporting it at Python side is very straightforward and easy.

## How was this patch tested?

Add python doctest.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #13194 from yanboliang/spark-15402.
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/evaluation.py | 45
1 file changed, 36 insertions(+), 9 deletions(-)
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 1fe8772da7..7aa16fa5b9 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -22,6 +22,7 @@ from pyspark.ml.wrapper import JavaParams
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol
from pyspark.ml.common import inherit_doc
+from pyspark.ml.util import JavaMLReadable, JavaMLWritable
__all__ = ['Evaluator', 'BinaryClassificationEvaluator', 'RegressionEvaluator',
'MulticlassClassificationEvaluator']
@@ -103,7 +104,8 @@ class JavaEvaluator(JavaParams, Evaluator):
@inherit_doc
-class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol):
+class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol,
+ JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -121,6 +123,11 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
0.70...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"})
0.83...
+ >>> bce_path = temp_path + "/bce"
+ >>> evaluator.save(bce_path)
+ >>> evaluator2 = BinaryClassificationEvaluator.load(bce_path)
+ >>> str(evaluator2.getRawPredictionCol())
+ 'raw'
.. versionadded:: 1.4.0
"""
@@ -172,7 +179,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction
@inherit_doc
-class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
+class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
+ JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -190,6 +198,11 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
0.993...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "mae"})
2.649...
+ >>> re_path = temp_path + "/re"
+ >>> evaluator.save(re_path)
+ >>> evaluator2 = RegressionEvaluator.load(re_path)
+ >>> str(evaluator2.getPredictionCol())
+ 'raw'
.. versionadded:: 1.4.0
"""
@@ -244,7 +257,8 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
@inherit_doc
-class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
+class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol,
+ JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -260,6 +274,11 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
0.66...
>>> evaluator.evaluate(dataset, {evaluator.metricName: "accuracy"})
0.66...
+ >>> mce_path = temp_path + "/mce"
+ >>> evaluator.save(mce_path)
+ >>> evaluator2 = MulticlassClassificationEvaluator.load(mce_path)
+ >>> str(evaluator2.getPredictionCol())
+ 'prediction'
.. versionadded:: 1.5.0
"""
@@ -311,19 +330,27 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio
if __name__ == "__main__":
import doctest
+ import tempfile
+ import pyspark.ml.evaluation
from pyspark.sql import SparkSession
- globs = globals().copy()
+ globs = pyspark.ml.evaluation.__dict__.copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
spark = SparkSession.builder\
.master("local[2]")\
.appName("ml.evaluation tests")\
.getOrCreate()
- sc = spark.sparkContext
- globs['sc'] = sc
globs['spark'] = spark
- (failure_count, test_count) = doctest.testmod(
- globs=globs, optionflags=doctest.ELLIPSIS)
- spark.stop()
+ temp_path = tempfile.mkdtemp()
+ globs['temp_path'] = temp_path
+ try:
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ spark.stop()
+ finally:
+ from shutil import rmtree
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
if failure_count:
exit(-1)