author     Xiangrui Meng <meng@databricks.com>  2015-03-05 11:50:09 -0800
committer  Xiangrui Meng <meng@databricks.com>  2015-03-05 11:50:09 -0800
commit     0bfacd5c5dd7d10a69bcbcbda630f0843d1cf285 (patch)
tree       2b13352131bb3dbd88e4214c6c7728d26898d25e /python
parent     c9cfba0cebe3eb546e3e96f3e5b9b89a74c5b7de (diff)
[SPARK-6090][MLLIB] add a basic BinaryClassificationMetrics to PySpark/MLlib
A simple wrapper around the Scala implementation. `DataFrame` is used for serialization/deserialization. Methods that return `RDD`s are not supported in this PR.

davies: If we recognize Scala's `Product`s in Py4J, we can easily add wrappers for Scala methods that return `RDD[(Double, Double)]`. Is it easy to register a serializer for `Product` in PySpark?

Author: Xiangrui Meng <meng@databricks.com>

Closes #4863 from mengxr/SPARK-6090 and squashes the following commits:

009a3a3 [Xiangrui Meng] provide schema
dcddab5 [Xiangrui Meng] add a basic BinaryClassificationMetrics to PySpark/MLlib
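For quick reference, a minimal usage sketch of the new evaluator, mirroring the doctest added in evaluation.py below; the local[4] master and the app name are illustrative only, not part of the patch:

    from pyspark import SparkContext
    from pyspark.mllib.evaluation import BinaryClassificationMetrics

    sc = SparkContext("local[4]", "BinaryClassificationMetricsExample")

    # RDD of (score, label) pairs, as expected by the constructor.
    score_and_labels = sc.parallelize(
        [(0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0),
         (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)

    metrics = BinaryClassificationMetrics(score_and_labels)
    print(metrics.areaUnderROC())  # ~0.70
    print(metrics.areaUnderPR())   # ~0.83
    metrics.unpersist()            # release intermediate RDDs cached on the JVM side
    sc.stop()

Under the hood, the constructor converts the (score, label) RDD into a DataFrame with an explicit schema and hands its _jdf to the Scala BinaryClassificationMetrics, which is why no custom Py4J serializer is needed here.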
Diffstat (limited to 'python')
-rw-r--r--  python/docs/pyspark.mllib.rst       |  7
-rw-r--r--  python/pyspark/mllib/evaluation.py  | 83
-rwxr-xr-x  python/run-tests                    |  1
3 files changed, 91 insertions(+), 0 deletions(-)
diff --git a/python/docs/pyspark.mllib.rst b/python/docs/pyspark.mllib.rst
index b706c5e376..15101470af 100644
--- a/python/docs/pyspark.mllib.rst
+++ b/python/docs/pyspark.mllib.rst
@@ -16,6 +16,13 @@ pyspark.mllib.clustering module
:members:
:undoc-members:
+pyspark.mllib.evaluation module
+-------------------------------
+
+.. automodule:: pyspark.mllib.evaluation
+ :members:
+ :undoc-members:
+
pyspark.mllib.feature module
-------------------------------
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
new file mode 100644
index 0000000000..16cb49cc0c
--- /dev/null
+++ b/python/pyspark/mllib/evaluation.py
@@ -0,0 +1,83 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from pyspark.mllib.common import JavaModelWrapper
+from pyspark.sql import SQLContext
+from pyspark.sql.types import StructField, StructType, DoubleType
+
+
+class BinaryClassificationMetrics(JavaModelWrapper):
+ """
+ Evaluator for binary classification.
+
+ >>> scoreAndLabels = sc.parallelize([
+ ... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
+ >>> metrics = BinaryClassificationMetrics(scoreAndLabels)
+ >>> metrics.areaUnderROC()
+ 0.70...
+ >>> metrics.areaUnderPR()
+ 0.83...
+ >>> metrics.unpersist()
+ """
+
+ def __init__(self, scoreAndLabels):
+ """
+ :param scoreAndLabels: an RDD of (score, label) pairs
+ """
+ sc = scoreAndLabels.ctx
+ sql_ctx = SQLContext(sc)
+ df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
+ StructField("score", DoubleType(), nullable=False),
+ StructField("label", DoubleType(), nullable=False)]))
+ java_class = sc._jvm.org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+ java_model = java_class(df._jdf)
+ super(BinaryClassificationMetrics, self).__init__(java_model)
+
+ def areaUnderROC(self):
+ """
+ Computes the area under the receiver operating characteristic
+ (ROC) curve.
+ """
+ return self.call("areaUnderROC")
+
+ def areaUnderPR(self):
+ """
+ Computes the area under the precision-recall curve.
+ """
+ return self.call("areaUnderPR")
+
+ def unpersist(self):
+ """
+ Unpersists intermediate RDDs used in the computation.
+ """
+ self.call("unpersist")
+
+
+def _test():
+ import doctest
+ from pyspark import SparkContext
+ import pyspark.mllib.evaluation
+ globs = pyspark.mllib.evaluation.__dict__.copy()
+ globs['sc'] = SparkContext('local[4]', 'PythonTest')
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ globs['sc'].stop()
+ if failure_count:
+ exit(-1)
+
+
+if __name__ == "__main__":
+ _test()
diff --git a/python/run-tests b/python/run-tests
index a2c2f37a54..b7630c356c 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -75,6 +75,7 @@ function run_mllib_tests() {
echo "Run mllib tests ..."
run_test "pyspark/mllib/classification.py"
run_test "pyspark/mllib/clustering.py"
+ run_test "pyspark/mllib/evaluation.py"
run_test "pyspark/mllib/feature.py"
run_test "pyspark/mllib/linalg.py"
run_test "pyspark/mllib/rand.py"