about | summary | refs | log | tree | commit | diff
path: root/python
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-07-08 16:21:28 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-07-08 16:21:28 -0700
commit 381cb161ba4e3a30f2da3c4ef4ee19869d51f101 (patch)
tree 75aca18f8b40042db17df575be395be4639963b0 /python
parent 4ffc27caaf46ffac56c3c0b3e928f1aff227a184 (diff)
download spark-381cb161ba4e3a30f2da3c4ef4ee19869d51f101.tar.gz
spark-381cb161ba4e3a30f2da3c4ef4ee19869d51f101.tar.bz2
spark-381cb161ba4e3a30f2da3c4ef4ee19869d51f101.zip
[SPARK-8068] [MLLIB] Add confusionMatrix method at class MulticlassMetrics in pyspark/mllib
Add confusionMatrix method at class MulticlassMetrics in pyspark/mllib

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #7286 from yanboliang/spark-8068 and squashes the following commits:

6109fe1 [Yanbo Liang] Add confusionMatrix method at class MulticlassMetrics in pyspark/mllib
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/evaluation.py 11
1 files changed, 11 insertions, 0 deletions
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index c5cf3a4e7f..f21403707e 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -152,6 +152,10 @@ class MulticlassMetrics(JavaModelWrapper):
>>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
>>> metrics = MulticlassMetrics(predictionAndLabels)
+ >>> metrics.confusionMatrix().toArray()
+ array([[ 2., 1., 1.],
+ [ 1., 3., 0.],
+ [ 0., 0., 1.]])
>>> metrics.falsePositiveRate(0.0)
0.2...
>>> metrics.precision(1.0)
@@ -186,6 +190,13 @@ class MulticlassMetrics(JavaModelWrapper):
java_model = java_class(df._jdf)
super(MulticlassMetrics, self).__init__(java_model)
+ def confusionMatrix(self):
+ """
+ Returns confusion matrix: predicted classes are in columns,
+ they are ordered by class label ascending, as in "labels".
+ """
+ return self.call("confusionMatrix")
+
def truePositiveRate(self, label):
"""
Returns true positive rate for a given label (category).