about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/ml/evaluation.py 66
1 files changed, 66 insertions, 0 deletions
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index 595593a7f2..06e8093522 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -214,6 +214,72 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
kwargs = self.setParams._input_kwargs
return self._set(**kwargs)
+
+@inherit_doc
+class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol):
+ """
+ Evaluator for Multiclass Classification, which expects two input
+ columns: prediction and label.
+ >>> scoreAndLabels = [(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
+ ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]
+ >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["prediction", "label"])
+ ...
+ >>> evaluator = MulticlassClassificationEvaluator(predictionCol="prediction")
+ >>> evaluator.evaluate(dataset)
+ 0.66...
+ >>> evaluator.evaluate(dataset, {evaluator.metricName: "precision"})
+ 0.66...
+ >>> evaluator.evaluate(dataset, {evaluator.metricName: "recall"})
+ 0.66...
+ """
+ # a placeholder to make the metricName param appear in the generated doc
+ metricName = Param(Params._dummy(), "metricName",
+ "metric name in evaluation "
+ "(f1|precision|recall|weightedPrecision|weightedRecall)")
+
+ @keyword_only
+ def __init__(self, predictionCol="prediction", labelCol="label",
+ metricName="f1"):
+ """
+ __init__(self, predictionCol="prediction", labelCol="label", \
+ metricName="f1")
+ """
+ super(MulticlassClassificationEvaluator, self).__init__()
+ self._java_obj = self._new_java_obj(
+ "org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator", self.uid)
+ # param for metric name in evaluation (f1|precision|recall|weightedPrecision|weightedRecall)
+ self.metricName = Param(self, "metricName",
+ "metric name in evaluation"
+ " (f1|precision|recall|weightedPrecision|weightedRecall)")
+ self._setDefault(predictionCol="prediction", labelCol="label",
+ metricName="f1")
+ kwargs = self.__init__._input_kwargs
+ self._set(**kwargs)
+
+ def setMetricName(self, value):
+ """
+ Sets the value of :py:attr:`metricName`.
+ """
+ self._paramMap[self.metricName] = value
+ return self
+
+ def getMetricName(self):
+ """
+ Gets the value of :py:attr:`metricName` or its default value.
+ """
+ return self.getOrDefault(self.metricName)
+
+ @keyword_only
+ def setParams(self, predictionCol="prediction", labelCol="label",
+ metricName="f1"):
+ """
+ setParams(self, predictionCol="prediction", labelCol="label", \
+ metricName="f1")
+ Sets params for multiclass classification evaluator.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext