aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/evaluation.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/evaluation.py')
-rw-r--r--python/pyspark/mllib/evaluation.py26
1 files changed, 14 insertions, 12 deletions
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index aab5e5f4b7..c5cf3a4e7f 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -27,6 +27,8 @@ class BinaryClassificationMetrics(JavaModelWrapper):
"""
Evaluator for binary classification.
+ :param scoreAndLabels: an RDD of (score, label) pairs.
+
>>> scoreAndLabels = sc.parallelize([
... (0.1, 0.0), (0.1, 1.0), (0.4, 0.0), (0.6, 0.0), (0.6, 1.0), (0.6, 1.0), (0.8, 1.0)], 2)
>>> metrics = BinaryClassificationMetrics(scoreAndLabels)
@@ -38,9 +40,6 @@ class BinaryClassificationMetrics(JavaModelWrapper):
"""
def __init__(self, scoreAndLabels):
- """
- :param scoreAndLabels: an RDD of (score, label) pairs
- """
sc = scoreAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(scoreAndLabels, schema=StructType([
@@ -76,6 +75,9 @@ class RegressionMetrics(JavaModelWrapper):
"""
Evaluator for regression.
+ :param predictionAndObservations: an RDD of (prediction,
+ observation) pairs.
+
>>> predictionAndObservations = sc.parallelize([
... (2.5, 3.0), (0.0, -0.5), (2.0, 2.0), (8.0, 7.0)])
>>> metrics = RegressionMetrics(predictionAndObservations)
@@ -92,9 +94,6 @@ class RegressionMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndObservations):
- """
- :param predictionAndObservations: an RDD of (prediction, observation) pairs.
- """
sc = predictionAndObservations.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndObservations, schema=StructType([
@@ -148,6 +147,8 @@ class MulticlassMetrics(JavaModelWrapper):
"""
Evaluator for multiclass classification.
+ :param predictionAndLabels: an RDD of (prediction, label) pairs.
+
>>> predictionAndLabels = sc.parallelize([(0.0, 0.0), (0.0, 1.0), (0.0, 0.0),
... (1.0, 0.0), (1.0, 1.0), (1.0, 1.0), (1.0, 1.0), (2.0, 2.0), (2.0, 0.0)])
>>> metrics = MulticlassMetrics(predictionAndLabels)
@@ -176,9 +177,6 @@ class MulticlassMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndLabels):
- """
- :param predictionAndLabels an RDD of (prediction, label) pairs.
- """
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndLabels, schema=StructType([
@@ -277,6 +275,9 @@ class RankingMetrics(JavaModelWrapper):
"""
Evaluator for ranking algorithms.
+ :param predictionAndLabels: an RDD of (predicted ranking,
+ ground truth set) pairs.
+
>>> predictionAndLabels = sc.parallelize([
... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
@@ -298,9 +299,6 @@ class RankingMetrics(JavaModelWrapper):
"""
def __init__(self, predictionAndLabels):
- """
- :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
- """
sc = predictionAndLabels.ctx
sql_ctx = SQLContext(sc)
df = sql_ctx.createDataFrame(predictionAndLabels,
@@ -347,6 +345,10 @@ class MultilabelMetrics(JavaModelWrapper):
"""
Evaluator for multilabel classification.
+ :param predictionAndLabels: an RDD of (predictions, labels) pairs,
+ both are non-null Arrays, each with
+ unique elements.
+
>>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
... ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
... ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])