diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-05 11:45:37 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-05 11:45:37 -0700 |
commit | ee374e89cd1f08730fed9d50b742627d5b19d241 (patch) | |
tree | 9912c353fe5e563bbf7ced6dc0e0c20f20272d5d /python/pyspark/ml/evaluation.py | |
parent | 18340d7be55a6834918956555bf820c96769aa52 (diff) | |
download | spark-ee374e89cd1f08730fed9d50b742627d5b19d241.tar.gz spark-ee374e89cd1f08730fed9d50b742627d5b19d241.tar.bz2 spark-ee374e89cd1f08730fed9d50b742627d5b19d241.zip |
[SPARK-7333] [MLLIB] Add BinaryClassificationEvaluator to PySpark
This PR adds `BinaryClassificationEvaluator` to Python ML Pipelines API, which is a simple wrapper of the Scala implementation. oefirouz
Author: Xiangrui Meng <meng@databricks.com>
Closes #5885 from mengxr/SPARK-7333 and squashes the following commits:
25d7451 [Xiangrui Meng] fix tests in python 3
babdde7 [Xiangrui Meng] fix doc
cb51e6a [Xiangrui Meng] add BinaryClassificationEvaluator in PySpark
Diffstat (limited to 'python/pyspark/ml/evaluation.py')
-rw-r--r-- | python/pyspark/ml/evaluation.py | 107 |
1 files changed, 107 insertions, 0 deletions
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py new file mode 100644 index 0000000000..02020ebff9 --- /dev/null +++ b/python/pyspark/ml/evaluation.py @@ -0,0 +1,107 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from pyspark.ml.wrapper import JavaEvaluator +from pyspark.ml.param import Param, Params +from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol +from pyspark.ml.util import keyword_only +from pyspark.mllib.common import inherit_doc + +__all__ = ['BinaryClassificationEvaluator'] + + +@inherit_doc +class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPredictionCol): + """ + Evaluator for binary classification, which expects two input + columns: rawPrediction and label. + + >>> from pyspark.mllib.linalg import Vectors + >>> scoreAndLabels = map(lambda x: (Vectors.dense([1.0 - x[0], x[0]]), x[1]), + ... [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)]) + >>> dataset = sqlContext.createDataFrame(scoreAndLabels, ["raw", "label"]) + ... + >>> evaluator = BinaryClassificationEvaluator(rawPredictionCol="raw") + >>> evaluator.evaluate(dataset) + 0.70... + >>> evaluator.evaluate(dataset, {evaluator.metricName: "areaUnderPR"}) + 0.83... + """ + + _java_class = "org.apache.spark.ml.evaluation.BinaryClassificationEvaluator" + + # a placeholder to make it appear in the generated doc + metricName = Param(Params._dummy(), "metricName", + "metric name in evaluation (areaUnderROC|areaUnderPR)") + + @keyword_only + def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", + metricName="areaUnderROC"): + """ + __init__(self, rawPredictionCol="rawPrediction", labelCol="label", \ + metricName="areaUnderROC") + """ + super(BinaryClassificationEvaluator, self).__init__() + #: param for metric name in evaluation (areaUnderROC|areaUnderPR) + self.metricName = Param(self, "metricName", + "metric name in evaluation (areaUnderROC|areaUnderPR)") + self._setDefault(rawPredictionCol="rawPrediction", labelCol="label", + metricName="areaUnderROC") + kwargs = self.__init__._input_kwargs + self._set(**kwargs) + + def setMetricName(self, value): + """ + Sets the value of :py:attr:`metricName`. + """ + self.paramMap[self.metricName] = value + return self + + def getMetricName(self): + """ + Gets the value of metricName or its default value. + """ + return self.getOrDefault(self.metricName) + + @keyword_only + def setParams(self, rawPredictionCol="rawPrediction", labelCol="label", + metricName="areaUnderROC"): + """ + setParams(self, rawPredictionCol="rawPrediction", labelCol="label", \ + metricName="areaUnderROC") + Sets params for binary classification evaluator. + """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + +if __name__ == "__main__": + import doctest + from pyspark.context import SparkContext + from pyspark.sql import SQLContext + globs = globals().copy() + # The small batch size here ensures that we see multiple batches, + # even in these small test examples: + sc = SparkContext("local[2]", "ml.evaluation tests") + sqlContext = SQLContext(sc) + globs['sc'] = sc + globs['sqlContext'] = sqlContext + (failure_count, test_count) = doctest.testmod( + globs=globs, optionflags=doctest.ELLIPSIS) + sc.stop() + if failure_count: + exit(-1) |