aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/pipeline.py
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-05 11:45:37 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-05 11:45:37 -0700
commitee374e89cd1f08730fed9d50b742627d5b19d241 (patch)
tree9912c353fe5e563bbf7ced6dc0e0c20f20272d5d /python/pyspark/ml/pipeline.py
parent18340d7be55a6834918956555bf820c96769aa52 (diff)
downloadspark-ee374e89cd1f08730fed9d50b742627d5b19d241.tar.gz
spark-ee374e89cd1f08730fed9d50b742627d5b19d241.tar.bz2
spark-ee374e89cd1f08730fed9d50b742627d5b19d241.zip
[SPARK-7333] [MLLIB] Add BinaryClassificationEvaluator to PySpark
This PR adds `BinaryClassificationEvaluator` to Python ML Pipelines API, which is a simple wrapper of the Scala implementation. oefirouz Author: Xiangrui Meng <meng@databricks.com> Closes #5885 from mengxr/SPARK-7333 and squashes the following commits: 25d7451 [Xiangrui Meng] fix tests in python 3 babdde7 [Xiangrui Meng] fix doc cb51e6a [Xiangrui Meng] add BinaryClassificationEvaluator in PySpark
Diffstat (limited to 'python/pyspark/ml/pipeline.py')
-rw-r--r--python/pyspark/ml/pipeline.py23
1 files changed, 22 insertions, 1 deletions
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 7c1ec3026d..7b875e4b71 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -22,7 +22,7 @@ from pyspark.ml.util import keyword_only
from pyspark.mllib.common import inherit_doc
-__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel']
+__all__ = ['Estimator', 'Transformer', 'Pipeline', 'PipelineModel', 'Evaluator']
@inherit_doc
@@ -168,3 +168,24 @@ class PipelineModel(Transformer):
for t in self.transformers:
dataset = t.transform(dataset, paramMap)
return dataset
+
+
+class Evaluator(object):
+ """
+ Base class for evaluators that compute metrics from predictions.
+ """
+
+ __metaclass__ = ABCMeta
+
+ @abstractmethod
+ def evaluate(self, dataset, params={}):
+ """
+ Evaluates the output.
+
+ :param dataset: a dataset that contains labels/observations and
+ predictions
+ :param params: an optional param map that overrides embedded
+ params
+ :return: metric
+ """
+ raise NotImplementedError()