about summary refs log tree commit diff
path: root/python/pyspark/ml/pipeline.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r-- python/pyspark/ml/pipeline.py 23
1 files changed, 22 insertions, 1 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 7c1ec3026d..7b875e4b71 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -22,7 +22,7 @@ from pyspark.ml.util import keyword_only
from pyspark.mllib.common import inherit_doc
-__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']
@inherit_doc
@@ -168,3 +168,24 @@ class PipelineModel(Transformer):
for t in self.transformers:
dataset = t.transform(dataset, paramMap)
return dataset
+
+
+class Evaluator(object):
+ """
+ Base class for evaluators that compute metrics from predictions; subclasses implement evaluate().
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def evaluate(self, dataset, params={}):
+ """
+ Evaluates the output; subclasses must override this method.
+
+ :param dataset: a dataset that contains labels/observations and predictions
+ :param params: an optional param map that overrides embedded params
+ (the ``{}`` default is one shared dict, so do not mutate it)
+ :return: metric
+ :raises NotImplementedError: always, in this base implementation
+ """
+ raise NotImplementedError()