author     Ram Sriharsha <rsriharsha@hw11853.local>    2015-07-30 23:02:11 -0700
committer  Joseph K. Bradley <joseph@databricks.com>   2015-07-30 23:02:11 -0700
commit     4e5919bfb47a58bcbda90ae01c1bed2128ded983 (patch)
tree       ed723467f08be003bd57a97c6909ee48d1b8029c /python
parent     83670fc9e6fc9c7a6ae68dfdd3f9335ea72f4ab0 (diff)
[SPARK-7690] [ML] Multiclass classification Evaluator
Multiclass Classification Evaluator for ML Pipelines. F1 score, precision, recall, weighted precision, and weighted recall are supported as available metrics.

Author: Ram Sriharsha <rsriharsha@hw11853.local>

Closes #7475 from harsha2010/SPARK-7690 and squashes the following commits:

9bf4ec7 [Ram Sriharsha] fix indentation
3f09a85 [Ram Sriharsha] cleanup doc
16115ae [Ram Sriharsha] code review fixes
032d2a3 [Ram Sriharsha] fix test
eec9865 [Ram Sriharsha] Fix Python Indentation
1dbeffd [Ram Sriharsha] Merge branch 'master' into SPARK-7690
68cea85 [Ram Sriharsha] Merge branch 'master' into SPARK-7690
54c03de [Ram Sriharsha] [SPARK-7690][ml][WIP] Multiclass Evaluator for ML Pipeline
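Before the diff, a minimal usage sketch of the new evaluator (assuming an active sqlContext, as the doctests below do; the data and variable names here are illustrative only):

    from pyspark.ml.evaluation import MulticlassClassificationEvaluator

    # Any DataFrame with double-typed "prediction" and "label" columns works.
    data = [(0.0, 0.0), (1.0, 1.0), (1.0, 0.0), (2.0, 2.0)]
    dataset = sqlContext.createDataFrame(data, ["prediction", "label"])

    evaluator = MulticlassClassificationEvaluator()  # metricName defaults to "f1"
    f1 = evaluator.evaluate(dataset)
    # Any supported metric can be requested per call through the param map:
    recall = evaluator.evaluate(dataset, {evaluator.metricName: "weightedRecall"})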
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/evaluation.py  66
1 file changed, 66 insertions, 0 deletions
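The metricName param added below accepts f1, precision, recall, weightedPrecision, and weightedRecall. For intuition, the weighted variants weight each class's one-vs-all score by that class's frequency among the labels; a plain-Python sketch of weighted precision (illustrative only, not Spark's MulticlassMetrics implementation):

    from collections import Counter

    def weighted_precision(pairs):
        # pairs: (prediction, label) tuples, as in the doctest data below.
        # weightedPrecision = sum over classes c of precision(c) * P(label == c)
        n = float(len(pairs))
        label_counts = Counter(label for _, label in pairs)
        total = 0.0
        for c, count in label_counts.items():
            predicted_c = [label for pred, label in pairs if pred == c]
            true_positives = sum(1 for label in predicted_c if label == c)
            precision_c = float(true_positives) / len(predicted_c) if predicted_c else 0.0
            total += precision_c * (count / n)
        return total

    scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), (1.0, 0.0),
                      (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
    print(weighted_precision(scoreAndLabels))  # 37/54 ~= 0.685 for this data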
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 595593a7f2..06e8093522 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -214,6 +214,72 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)
 
+
+@inherit_doc
+class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
+    """
+    Evaluator for Multiclass Classification, which expects two input
+    columns: prediction and label.
+    >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
+    ...     (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
+    >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"])
+    ...
+    >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
+    >>> evaluator.evaluate(dataset)
+    0.66...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"})
+    0.66...
+    >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"})
+    0.66...
+    """
+    # a placeholder to make it appear in the generated doc
+    metricName = Param(Params._dummy(), "metricName",
+                       "metric name in evaluation "
+                       "(f1|precision|recall|weightedPrecision|weightedRecall)")
+
+    @keyword_only
+    def __init__(self, predictionCol="prediction", labelCol="label",
+                 metricName="f1"):
+        """
+        __init__(self, predictionCol="prediction", labelCol="label", \
+                 metricName="f1")
+        """
+        super(MulticlassClassificationEvaluator, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid)
+        # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall)
+        self.metricName = Param(self, "metricName",
+                                "metric name in evaluation"
+                                " (f1|precision|recall|weightedPrecision|weightedRecall)")
+        self._setDefault(predictionCol="prediction", labelCol="label",
+                         metricName="f1")
+        kwargs = self.__init__._input_kwargs
+        self._set(**kwargs)
+
+    def setMetricName(self, value):
+        """
+        Sets the value of :py:attr:`metricName`.
+        """
+        self._paramMap[self.metricName] = value
+        return self
+
+    def getMetricName(self):
+        """
+        Gets the value of metricName or its default value.
+        """
+        return self.getOrDefault(self.metricName)
+
+    @keyword_only
+    def setParams(self, predictionCol="prediction", labelCol="label",
+                  metricName="f1"):
+        """
+        setParams(self, predictionCol="prediction", labelCol="label", \
+                  metricName="f1")
+        Sets params for multiclass classification evaluator.
+        """
+        kwargs = self.setParams._input_kwargs
+        return self._set(**kwargs)
+
 if __name__ == "__main__":
     import doctest
     from pyspark.context import SparkContext
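For context, the doctest harness at the bottom of evaluation.py (truncated above) is what provides the sc and sqlContext globals the examples rely on. In this era of the codebase it has roughly the following shape (a sketch, not the exact file contents):

    if __name__ == "__main__":
        import doctest
        from pyspark.context import SparkContext
        from pyspark.sql import SQLContext

        globs = globals().copy()
        # Inject the contexts the doctests reference.
        sc = SparkContext("local[2]", "ml.evaluation tests")
        globs['sc'] = sc
        globs['sqlContext'] = SQLContext(sc)
        # ELLIPSIS lets expected outputs like 0.66... match full floats.
        (failure_count, test_count) = doctest.testmod(
            globs=globs, optionflags=doctest.ELLIPSIS)
        sc.stop()
        if failure_count:
            exit(-1)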