 python/pyspark/mllib/evaluation.py | 11 +++++++++++
 1 file changed, 11 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index c5cf3a4e7f..f21403707e 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -152,6 +152,10 @@ class MulticlassMetrics(JavaModelWrapper):
     >>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
     ... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
     >>> metrics = MulticlassMetrics(predictionAndLabels)
+    >>> metrics.confusionMatrix().toArray()
+    array([[ 2.,  1.,  1.],
+           [ 1.,  3.,  0.],
+           [ 0.,  0.,  1.]])
     >>> metrics.falsePositiveRate(0.0)
     0.2...
     >>> metrics.precision(1.0)
@@ -186,6 +190,13 @@ class MulticlassMetrics(JavaModelWrapper):
         java_model = java_class(df._jdf)
         super(MulticlassMetrics, self).__init__(java_model)
 
+    def confusionMatrix(self):
+        """
+        Returns confusion matrix: predicted classes are in columns,
+        they are ordered by class label ascending, as in "labels".
+        """
+        return self.call("confusionMatrix")
+
     def truePositiveRate(self, label):
         """
         Returns true positive rate for a given label (category).
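
Note (not part of the patch): below is a minimal, self-contained sketch of how the new confusionMatrix() method could be exercised from a standalone PySpark script, assuming a Spark build that includes this change. The application name and the small (prediction, label) RDD are made up for illustration.

from pyspark import SparkContext
from pyspark.mllib.evaluation import MulticlassMetrics

sc = SparkContext(appName="ConfusionMatrixExample")

# MulticlassMetrics expects an RDD of (prediction, label) pairs.
pairs = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (1.0, 1.0), (1.0, 1.0)])
metrics = MulticlassMetrics(pairs)

# Rows are actual labels, columns are predicted labels, both ordered by
# class label ascending. Here label 1.0 is predicted as 0.0 once and as
# 1.0 twice, so the expected counts are [[1., 0.], [1., 2.]].
print(metrics.confusionMatrix().toArray())

sc.stop()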