diff options
Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r-- | python/pyspark/ml/pipeline.py | 23 |
1 files changed, 22 insertions, 1 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 7c1ec3026d..7b875e4b71 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -22,7 +22,7 @@ from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc -__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel'] +__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator'] @inherit_doc @@ -168,3 +168,24 @@ class PipelineModel(Transformer): for t in self.transformers: dataset = t.transform(dataset, paramMap) return dataset + + +class Evaluator(object): + """ + Base class for evaluators that compute metrics from predictions. + """ + + __metaclass__ = ABCMeta + + @abstractmethod + def evaluate(self, dataset, params={}): + """ + Evaluates the output. + + :param dataset: a dataset that contains labels/observations and + predictions + :param params: an optional param map that overrides embedded + params + :return: metric + """ + raise NotImplementedError() |