From bf7e81a51cd81706570615cd67362c86602dec88 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Sun, 10 May 2015 00:57:14 -0700 Subject: [SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib https://issues.apache.org/jira/browse/SPARK-6091 Author: Yanbo Liang Closes #6011 from yanboliang/spark-6091 and squashes the following commits: bb3e4ba [Yanbo Liang] trigger jenkins 53c045d [Yanbo Liang] keep compatibility for python 2.6 972d5ac [Yanbo Liang] Add MulticlassMetrics in PySpark/MLlib --- python/pyspark/mllib/evaluation.py | 129 +++++++++++++++++++++++++++++++++++++ 1 file changed, 129 insertions(+) (limited to 'python') diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 3e11df09da..36914597de 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -141,6 +141,135 @@ class RegressionMetrics(JavaModelWrapper): return self.call("r2") +class MulticlassMetrics(JavaModelWrapper): + """ + Evaluator for multiclass classification. + + >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), + ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) + >>> metrics = MulticlassMetrics(predictionAndLabels) + >>> metrics.falsePositiveRate(0.0) + 0.2... + >>> metrics.precision(1.0) + 0.75... + >>> metrics.recall(2.0) + 1.0... + >>> metrics.fMeasure(0.0, 2.0) + 0.52... + >>> metrics.precision() + 0.66... + >>> metrics.recall() + 0.66... + >>> metrics.weightedFalsePositiveRate + 0.19... + >>> metrics.weightedPrecision + 0.68... + >>> metrics.weightedRecall + 0.66... + >>> metrics.weightedFMeasure() + 0.66... + >>> metrics.weightedFMeasure(2.0) + 0.65... + """ + + def __init__(self, predictionAndLabels): + """ + :param predictionAndLabels an RDD of (prediction, label) pairs. + """ + sc = predictionAndLabels.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ + StructField("prediction", DoubleType(), nullable=False), + StructField("label", DoubleType(), nullable=False)])) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics + java_model = java_class(df._jdf) + super(MulticlassMetrics, self).__init__(java_model) + + def truePositiveRate(self, label): + """ + Returns true positive rate for a given label (category). + """ + return self.call("truePositiveRate", label) + + def falsePositiveRate(self, label): + """ + Returns false positive rate for a given label (category). + """ + return self.call("falsePositiveRate", label) + + def precision(self, label=None): + """ + Returns precision or precision for a given label (category) if specified. + """ + if label is None: + return self.call("precision") + else: + return self.call("precision", float(label)) + + def recall(self, label=None): + """ + Returns recall or recall for a given label (category) if specified. + """ + if label is None: + return self.call("recall") + else: + return self.call("recall", float(label)) + + def fMeasure(self, label=None, beta=None): + """ + Returns f-measure or f-measure for a given label (category) if specified. + """ + if beta is None: + if label is None: + return self.call("fMeasure") + else: + return self.call("fMeasure", label) + else: + if label is None: + raise Exception("If the beta parameter is specified, label can not be none") + else: + return self.call("fMeasure", label, beta) + + @property + def weightedTruePositiveRate(self): + """ + Returns weighted true positive rate. + (equals to precision, recall and f-measure) + """ + return self.call("weightedTruePositiveRate") + + @property + def weightedFalsePositiveRate(self): + """ + Returns weighted false positive rate. + """ + return self.call("weightedFalsePositiveRate") + + @property + def weightedRecall(self): + """ + Returns weighted averaged recall. + (equals to precision, recall and f-measure) + """ + return self.call("weightedRecall") + + @property + def weightedPrecision(self): + """ + Returns weighted averaged precision. + """ + return self.call("weightedPrecision") + + def weightedFMeasure(self, beta=None): + """ + Returns weighted averaged f-measure. + """ + if beta is None: + return self.call("weightedFMeasure") + else: + return self.call("weightedFMeasure", beta) + + def _test(): import doctest from pyspark import SparkContext -- cgit v1.2.3