diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2015-05-10 00:57:14 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-10 00:57:14 -0700 |
commit | bf7e81a51cd81706570615cd67362c86602dec88 (patch) | |
tree | dc3d55d57d58606fe4c10f8bb2ec0be428ec24b6 /python/pyspark/mllib/evaluation.py | |
parent | b13162b364aeff35e3bdeea9c9a31e5ce66f8c9a (diff) | |
download | spark-bf7e81a51cd81706570615cd67362c86602dec88.tar.gz spark-bf7e81a51cd81706570615cd67362c86602dec88.tar.bz2 spark-bf7e81a51cd81706570615cd67362c86602dec88.zip |
[SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib
https://issues.apache.org/jira/browse/SPARK-6091
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #6011 from yanboliang/spark-6091 and squashes the following commits:
bb3e4ba [Yanbo Liang] trigger jenkins
53c045d [Yanbo Liang] keep compatibility for python 2.6
972d5ac [Yanbo Liang] Add MulticlassMetrics in PySpark/MLlib
Diffstat (limited to 'python/pyspark/mllib/evaluation.py')
-rw-r--r-- | python/pyspark/mllib/evaluation.py | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py index 3e11df09da..36914597de 100644 --- a/python/pyspark/mllib/evaluation.py +++ b/python/pyspark/mllib/evaluation.py @@ -141,6 +141,135 @@ class RegressionMetrics(JavaModelWrapper): return self.call("r2") +class MulticlassMetrics(JavaModelWrapper): + """ + Evaluator for multiclass classification. + + >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0), + ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)]) + >>> metrics = MulticlassMetrics(predictionAndLabels) + >>> metrics.falsePositiveRate(0.0) + 0.2... + >>> metrics.precision(1.0) + 0.75... + >>> metrics.recall(2.0) + 1.0... + >>> metrics.fMeasure(0.0, 2.0) + 0.52... + >>> metrics.precision() + 0.66... + >>> metrics.recall() + 0.66... + >>> metrics.weightedFalsePositiveRate + 0.19... + >>> metrics.weightedPrecision + 0.68... + >>> metrics.weightedRecall + 0.66... + >>> metrics.weightedFMeasure() + 0.66... + >>> metrics.weightedFMeasure(2.0) + 0.65... + """ + + def __init__(self, predictionAndLabels): + """ + :param predictionAndLabels an RDD of (prediction, label) pairs. + """ + sc = predictionAndLabels.ctx + sql_ctx = SQLContext(sc) + df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([ + StructField("prediction", DoubleType(), nullable=False), + StructField("label", DoubleType(), nullable=False)])) + java_class = sc._jvm.org.apache.spark.mllib.evaluation.MulticlassMetrics + java_model = java_class(df._jdf) + super(MulticlassMetrics, self).__init__(java_model) + + def truePositiveRate(self, label): + """ + Returns true positive rate for a given label (category). + """ + return self.call("truePositiveRate", label) + + def falsePositiveRate(self, label): + """ + Returns false positive rate for a given label (category). + """ + return self.call("falsePositiveRate", label) + + def precision(self, label=None): + """ + Returns precision or precision for a given label (category) if specified. + """ + if label is None: + return self.call("precision") + else: + return self.call("precision", float(label)) + + def recall(self, label=None): + """ + Returns recall or recall for a given label (category) if specified. + """ + if label is None: + return self.call("recall") + else: + return self.call("recall", float(label)) + + def fMeasure(self, label=None, beta=None): + """ + Returns f-measure or f-measure for a given label (category) if specified. + """ + if beta is None: + if label is None: + return self.call("fMeasure") + else: + return self.call("fMeasure", label) + else: + if label is None: + raise Exception("If the beta parameter is specified, label can not be none") + else: + return self.call("fMeasure", label, beta) + + @property + def weightedTruePositiveRate(self): + """ + Returns weighted true positive rate. + (equals to precision, recall and f-measure) + """ + return self.call("weightedTruePositiveRate") + + @property + def weightedFalsePositiveRate(self): + """ + Returns weighted false positive rate. + """ + return self.call("weightedFalsePositiveRate") + + @property + def weightedRecall(self): + """ + Returns weighted averaged recall. + (equals to precision, recall and f-measure) + """ + return self.call("weightedRecall") + + @property + def weightedPrecision(self): + """ + Returns weighted averaged precision. + """ + return self.call("weightedPrecision") + + def weightedFMeasure(self, beta=None): + """ + Returns weighted averaged f-measure. + """ + if beta is None: + return self.call("weightedFMeasure") + else: + return self.call("weightedFMeasure", beta) + + def _test(): import doctest from pyspark import SparkContext |