author    Yanbo Liang <ybliang8@gmail.com>    2015-05-20 07:55:51 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-05-20 07:55:51 -0700
commit    98a46f9dffec294386f6c39acafa7f11adb87a8f (patch)
tree      a53b1011e52f440e97780de1bf1adc43624cc276 /python
parent    589b12f8e62ec5d10713ce057756ebc791e7ddc6 (diff)
[SPARK-6094] [MLLIB] Add MultilabelMetrics in PySpark/MLlib
Add MultilabelMetrics in PySpark/MLlib

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #6276 from yanboliang/spark-6094 and squashes the following commits:

b8e3343 [Yanbo Liang] Add MultilabelMetrics in PySpark/MLlib
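The patch wraps the Scala-side org.apache.spark.mllib.evaluation.MultilabelMetrics for Python callers. The input is an RDD of (predictions, labels) pairs, where each element is a list of float label indices. A minimal usage sketch, assuming a live SparkContext bound to sc (the doctest in the diff below exercises the full API):

    from pyspark.mllib.evaluation import MultilabelMetrics

    # Each pair is (predicted labels, true labels); labels are float indices.
    predictionAndLabels = sc.parallelize([
        ([0.0, 1.0], [0.0, 2.0]),
        ([2.0], [2.0]),
    ])
    metrics = MultilabelMetrics(predictionAndLabels)
    print(metrics.microF1Measure)  # micro-averaged F1 across all labels
    print(metrics.precision(0.0))  # precision restricted to label 0.0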
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/evaluation.py | 117
1 file changed, 117 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index a5e5ddc8fe..aab5e5f4b7 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -343,6 +343,123 @@ class RankingMetrics(JavaModelWrapper):
        return self.call("ndcgAt", int(k))
+class MultilabelMetrics(JavaModelWrapper):
+    """
+    Evaluator for multilabel classification.
+
+    >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
+    ...                                       ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
+    ...                                       ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
+    >>> metrics = MultilabelMetrics(predictionAndLabels)
+    >>> metrics.precision(0.0)
+    1.0
+    >>> metrics.recall(1.0)
+    0.66...
+    >>> metrics.f1Measure(2.0)
+    0.5
+    >>> metrics.precision()
+    0.66...
+    >>> metrics.recall()
+    0.64...
+    >>> metrics.f1Measure()
+    0.63...
+    >>> metrics.microPrecision
+    0.72...
+    >>> metrics.microRecall
+    0.66...
+    >>> metrics.microF1Measure
+    0.69...
+    >>> metrics.hammingLoss
+    0.33...
+    >>> metrics.subsetAccuracy
+    0.28...
+    >>> metrics.accuracy
+    0.54...
+    """
+
+    def __init__(self, predictionAndLabels):
+        sc = predictionAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndLabels,
+                                     schema=sql_ctx._inferSchema(predictionAndLabels))
+        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
+        java_model = java_class(df._jdf)
+        super(MultilabelMetrics, self).__init__(java_model)
+
+    def precision(self, label=None):
+        """
+        Returns precision averaged across documents, or the precision for a
+        given label (category) if specified.
+        """
+        if label is None:
+            return self.call("precision")
+        else:
+            return self.call("precision", float(label))
+
+    def recall(self, label=None):
+        """
+        Returns recall averaged across documents, or the recall for a given
+        label (category) if specified.
+        """
+        if label is None:
+            return self.call("recall")
+        else:
+            return self.call("recall", float(label))
+
+    def f1Measure(self, label=None):
+        """
+        Returns the F1 measure averaged across documents, or the F1 measure
+        for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("f1Measure")
+        else:
+            return self.call("f1Measure", float(label))
+
+    @property
+    def microPrecision(self):
+        """
+        Returns micro-averaged label-based precision
+        (equal to micro-averaged document-based precision).
+        """
+        return self.call("microPrecision")
+
+    @property
+    def microRecall(self):
+        """
+        Returns micro-averaged label-based recall
+        (equal to micro-averaged document-based recall).
+        """
+        return self.call("microRecall")
+
+    @property
+    def microF1Measure(self):
+        """
+        Returns micro-averaged label-based F1 measure
+        (equal to micro-averaged document-based F1 measure).
+        """
+        return self.call("microF1Measure")
+
+    @property
+    def hammingLoss(self):
+        """
+        Returns the Hamming loss.
+        """
+        return self.call("hammingLoss")
+
+    @property
+    def subsetAccuracy(self):
+        """
+        Returns subset accuracy: the fraction of documents whose predicted
+        label set exactly matches the true label set.
+        """
+        return self.call("subsetAccuracy")
+
+    @property
+    def accuracy(self):
+        """
+        Returns accuracy.
+        """
+        return self.call("accuracy")
+
+
def _test():
    import doctest
    from pyspark import SparkContext
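For readers checking the doctest values by hand: the new methods follow the standard set-based multilabel definitions. Below is a plain-Python reference sketch of four of them over the same seven pairs; it illustrates the textbook formulas rather than Spark's implementation, and treating an empty prediction set as contributing zero precision is an assumption that happens to reproduce the doctest output.

    from __future__ import division  # true division on Python 2 as well

    # The same seven (predictions, labels) pairs as the doctest above.
    data = [([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
            ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
            ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])]
    pairs = [(set(p), set(l)) for p, l in data]
    n = len(pairs)

    # Document-based precision: mean of |pred & true| / |pred| per document;
    # an empty prediction contributes 0 (assumption consistent with 0.66...).
    precision = sum(len(p & l) / len(p) if p else 0.0 for p, l in pairs) / n

    # Micro-averaged precision: pool counts over all documents, then divide.
    micro_precision = (sum(len(p & l) for p, l in pairs) /
                       sum(len(p) for p, l in pairs))

    # Hamming loss: total symmetric difference over (documents x labels).
    all_labels = set().union(*(p | l for p, l in pairs))
    hamming_loss = sum(len(p ^ l) for p, l in pairs) / (n * len(all_labels))

    # Subset accuracy: fraction of documents predicted exactly right.
    subset_accuracy = sum(p == l for p, l in pairs) / n

    print(precision, micro_precision, hamming_loss, subset_accuracy)
    # 0.666..., 0.727..., 0.333..., 0.285..., matching the doctest values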