diff options
author | Ram Sriharsha <rsriharsha@hw11853.local> | 2015-07-30 23:02:11 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-07-30 23:02:11 -0700 |
commit | 4e5919bfb47a58bcbda90ae01c1bed2128ded983 (patch) | |
tree | ed723467f08be003bd57a97c6909ee48d1b8029c /python | |
parent | 83670fc9e6fc9c7a6ae68dfdd3f9335ea72f4ab0 (diff) | |
download | spark-4e5919bfb47a58bcbda90ae01c1bed2128ded983.tar.gz spark-4e5919bfb47a58bcbda90ae01c1bed2128ded983.tar.bz2 spark-4e5919bfb47a58bcbda90ae01c1bed2128ded983.zip |
[SPARK-7690] [ML] Multiclass classification Evaluator
Multiclass Classification Evaluator for ML Pipelines. F1 score, precision, recall, weighted precision and weighted recall are supported as available metrics.
Author: Ram Sriharsha <rsriharsha@hw11853.local>
Closes #7475 from harsha2010/SPARK-7690 and squashes the following commits:
9bf4ec7 [Ram Sriharsha] fix indentation
3f09a85 [Ram Sriharsha] cleanup doc
16115ae [Ram Sriharsha] code review fixes
032d2a3 [Ram Sriharsha] fix test
eec9865 [Ram Sriharsha] Fix Python Indentation
1dbeffd [Ram Sriharsha] Merge branch 'master' into SPARK-7690
68cea85 [Ram Sriharsha] Merge branch 'master' into SPARK-7690
54c03de [Ram Sriharsha] [SPARK-7690][ml][WIP] Multiclass Evaluator for ML Pipeline
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/ml/evaluation.py | 66 |
1 files changed, 66 insertions, 0 deletions
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 595593a7f2..06e8093522 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -214,6 +214,72 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): kwargs = self.setParams._input_kwargs return self._set(**kwargs) + +@inherit_doc +class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): + """ + Evaluator for Multiclass Classification, which expects two input + columns: prediction and label. + >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), + ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)] + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"]) + ... + >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction") + >>> evaluator.evaluate(dataset) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"}) + 0.66... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"}) + 0.66... + """ + # a placeholder to make it appear in the generated doc + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation " + "(f1|precision|recall|weightedPrecision|weightedRecall)") + + @keyword_only + def __init__(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + __init__(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + """ + super(MulticlassClassificationEvaluator, self).__init__() + self._java_obj = self._new_java_obj( + "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid) + # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall) + self.metricName = Param(self, "metricName", + "metric name in evaluation" + " (f1|precision|recall|weightedPrecision|weightedRecall)") + self._setDefault(predictionCol="prediction", labelCol="label", + metricName="f1") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self._paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, predictionCol="prediction", labelCol="label", + metricName="f1"): + """ + setParams(self, predictionCol="prediction", labelCol="label", \ + metricName="f1") + Sets params for multiclass classification evaluator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + if __name__ == "__main__": import doctest from pyspark.context import SparkContext |