path: root/python/pyspark/mllib/evaluation.py
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from pyspark.mllib.common import JavaModelWrapper
from pyspark.sql import SQLContext
from pyspark.sql.types import StructField, StructType, DoubleType


class BinaryClassificationMetrics(JavaModelWrapper):
    """
    Evaluator for binary classification.

    >>> scoreAndLabels = sc.parallelize([
    ...     (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
    >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
    >>> metrics.areaUnderROC()
    0.70...
    >>> metrics.areaUnderPR()
    0.83...
    >>> metrics.unpersist()
    """

    def __init__(self, scoreAndLabels):
        """
        :param scoreAndLabels: an RDD of (score, label) pairs, where the
          label is either 0.0 or 1.0
        """
        sc = scoreAndLabels.ctx
        sql_ctx = SQLContext(sc)
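        # Convert the (score, label) pairs into a two-column DataFrame of
        # non-nullable doubles so they can be handed across the Py4J bridge.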
        df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
            StructField("score", DoubleType(), nullable=False),
            StructField("label", DoubleType(), nullable=False)]))
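        # Look up the Scala-side BinaryClassificationMetrics class through the
        # JVM gateway and wrap it; the metric computations run on the JVM side.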
        java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
        java_model = java_class(df._jdf)
        super(BinaryClassificationMetrics, self).__init__(java_model)

    def areaUnderROC(self):
        """
        Computes the area under the receiver operating characteristic
        (ROC) curve.
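        An area of 1.0 corresponds to a perfect classifier; 0.5 is no
        better than random guessing.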
        """
        return self.call("areaUnderROC")

    def areaUnderPR(self):
        """
        Computes the area under the precision-recall curve.
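        This is often more informative than the area under the ROC curve
        when the positive class is rare.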
        """
        return self.call("areaUnderPR")

    def unpersist(self):
        """
        Unpersists intermediate RDDs used in the computation.
        """
        self.call("unpersist")


def _test():
    import doctest
    import sys
    from pyspark import SparkContext
    import pyspark.mllib.evaluation
    globs = pyspark.mllib.evaluation.__dict__.copy()
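    # The doctests above reference `sc`, so run them against a local
    # SparkContext injected into their globals.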
    globs['sc'] = SparkContext('local[4]', 'PythonTest')
    (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
    globs['sc'].stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()