author     Davies Liu <davies@databricks.com>       2014-11-18 10:11:13 -0800
committer  Xiangrui Meng <meng@databricks.com>      2014-11-18 10:11:13 -0800
commit     8fbf72b7903b5bbec8d949151aa4693b4af26ff5 (patch)
tree       47878ffe9262b30cc626c173c341caadb88390ce /python
parent     cedc3b5aa43a16e2da62f12a36317f00aa1002cc (diff)
download   spark-8fbf72b7903b5bbec8d949151aa4693b4af26ff5.tar.gz
           spark-8fbf72b7903b5bbec8d949151aa4693b4af26ff5.tar.bz2
           spark-8fbf72b7903b5bbec8d949151aa4693b4af26ff5.zip
[SPARK-4435] [MLlib] [PySpark] improve classification
This PR adds setThreshold() and clearThreshold() to LogisticRegressionModel and SVMModel, and adds support for an RDD of vectors in LogisticRegressionModel.predict(), SVMModel.predict(), and NaiveBayes.predict().

Author: Davies Liu <davies@databricks.com>

Closes #3305 from davies/setThreshold and squashes the following commits:

d0b835f [Davies Liu] Merge branch 'master' of github.com:apache/spark into setThreshold
e4acd76 [Davies Liu] address comments
2231a5f [Davies Liu] bugfix
7bd9009 [Davies Liu] address comments
0b0a8a7 [Davies Liu] address comments
c1e5573 [Davies Liu] improve classification
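For reference, a minimal usage sketch of the API this patch adds, adapted from its doctests (the printed values are illustrative and depend on the trained weights):

    from pyspark import SparkContext
    from pyspark.mllib.classification import LogisticRegressionWithSGD
    from pyspark.mllib.regression import LabeledPoint

    sc = SparkContext('local', 'threshold-demo')
    data = [LabeledPoint(0.0, [0.0, 1.0]), LabeledPoint(1.0, [1.0, 0.0])]
    lrm = LogisticRegressionWithSGD.train(sc.parallelize(data))

    lrm.predict([1.0, 0.0])            # 0/1 label; threshold defaults to 0.5
    lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()  # [1, 0]
    lrm.setThreshold(0.3)              # shift the decision boundary
    lrm.clearThreshold()               # predict() now returns raw scores
    lrm.predict([0.0, 1.0])            # raw probability, e.g. 0.123...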
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/mllib/classification.py | 135
1 file changed, 106 insertions(+), 29 deletions(-)
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index b654813fb4..ee0729b1eb 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -20,6 +20,7 @@ from math import exp
import numpy
from numpy import array
+from pyspark import RDD
from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
@@ -29,39 +30,88 @@ __all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel',
'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
-class LogisticRegressionModel(LinearModel):
+class LinearBinaryClassificationModel(LinearModel):
+ """
+ Represents a linear binary classification model that predicts whether an
+ example is positive (1.0) or negative (0.0).
+ """
+ def __init__(self, weights, intercept):
+ super(LinearBinaryClassificationModel, self).__init__(weights, intercept)
+ self._threshold = None
+
+ def setThreshold(self, value):
+ """
+ :: Experimental ::
+
+ Sets the threshold that separates positive predictions from negative
+ predictions. An example with a prediction score greater than this
+ threshold is identified as positive, and negative otherwise.
+ """
+ self._threshold = value
+
+ def clearThreshold(self):
+ """
+ :: Experimental ::
+
+ Clears the threshold so that `predict` will output raw prediction scores.
+ """
+ self._threshold = None
+
+ def predict(self, test):
+ """
+ Predict values for a single data point or an RDD of points using
+ the trained model.
+ """
+ raise NotImplementedError
+
+
+class LogisticRegressionModel(LinearBinaryClassificationModel):
"""A linear binary classification model derived from logistic regression.
>>> data = [
- ... LabeledPoint(0.0, [0.0]),
- ... LabeledPoint(1.0, [1.0]),
- ... LabeledPoint(1.0, [2.0]),
- ... LabeledPoint(1.0, [3.0])
+ ... LabeledPoint(0.0, [0.0, 1.0]),
+ ... LabeledPoint(1.0, [1.0, 0.0]),
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data))
- >>> lrm.predict(array([1.0])) > 0
- True
- >>> lrm.predict(array([0.0])) <= 0
- True
+ >>> lrm.predict([1.0, 0.0])
+ 1
+ >>> lrm.predict([0.0, 1.0])
+ 0
+ >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()
+ [1, 0]
+ >>> lrm.clearThreshold()
+ >>> lrm.predict([0.0, 1.0])
+ 0.123...
+
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
- ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
+ ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data))
- >>> lrm.predict(array([0.0, 1.0])) > 0
- True
- >>> lrm.predict(array([0.0, 0.0])) <= 0
- True
- >>> lrm.predict(SparseVector(2, {1: 1.0})) > 0
- True
- >>> lrm.predict(SparseVector(2, {1: 0.0})) <= 0
- True
+ >>> lrm.predict(array([0.0, 1.0]))
+ 1
+ >>> lrm.predict(array([1.0, 0.0]))
+ 0
+ >>> lrm.predict(SparseVector(2, {1: 1.0}))
+ 1
+ >>> lrm.predict(SparseVector(2, {0: 1.0}))
+ 0
"""
+ def __init__(self, weights, intercept):
+ super(LogisticRegressionModel, self).__init__(weights, intercept)
+ self._threshold = 0.5
def predict(self, x):
+ """
+ Predict values for a single data point or an RDD of points using
+ the trained model.
+ """
+ if isinstance(x, RDD):
+ return x.map(lambda v: self.predict(v))
+
x = _convert_to_vector(x)
margin = self.weights.dot(x) + self._intercept
if margin > 0:
@@ -69,7 +119,10 @@ class LogisticRegressionModel(LinearModel):
else:
exp_margin = exp(margin)
prob = exp_margin / (1 + exp_margin)
- return 1 if prob > 0.5 else 0
+ if self._threshold is None:
+ return prob
+ else:
+ return 1 if prob > self._threshold else 0
class LogisticRegressionWithSGD(object):
@@ -111,7 +164,7 @@ class LogisticRegressionWithSGD(object):
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
-class SVMModel(LinearModel):
+class SVMModel(LinearBinaryClassificationModel):
"""A support vector machine.
@@ -122,8 +175,14 @@ class SVMModel(LinearModel):
... LabeledPoint(1.0, [3.0])
... ]
>>> svm = SVMWithSGD.train(sc.parallelize(data))
- >>> svm.predict(array([1.0])) > 0
- True
+ >>> svm.predict([1.0])
+ 1
+ >>> svm.predict(sc.parallelize([[1.0]])).collect()
+ [1]
+ >>> svm.clearThreshold()
+ >>> svm.predict(array([1.0]))
+ 1.25...
+
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: -1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
@@ -131,16 +190,29 @@ class SVMModel(LinearModel):
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>> svm = SVMWithSGD.train(sc.parallelize(sparse_data))
- >>> svm.predict(SparseVector(2, {1: 1.0})) > 0
- True
- >>> svm.predict(SparseVector(2, {0: -1.0})) <= 0
- True
+ >>> svm.predict(SparseVector(2, {1: 1.0}))
+ 1
+ >>> svm.predict(SparseVector(2, {0: -1.0}))
+ 0
"""
+ def __init__(self, weights, intercept):
+ super(SVMModel, self).__init__(weights, intercept)
+ self._threshold = 0.0
def predict(self, x):
+ """
+ Predict values for a single data point or an RDD of points using
+ the trained model.
+ """
+ if isinstance(x, RDD):
+ return x.map(lambda v: self.predict(v))
+
x = _convert_to_vector(x)
margin = self.weights.dot(x) + self.intercept
- return 1 if margin >= 0 else 0
+ if self._threshold is None:
+ return margin
+ else:
+ return 1 if margin > self._threshold else 0
class SVMWithSGD(object):
@@ -201,6 +273,8 @@ class NaiveBayesModel(object):
0.0
>>> model.predict(array([1.0, 0.0]))
1.0
+ >>> model.predict(sc.parallelize([[1.0, 0.0]])).collect()
+ [1.0]
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {1: 0.0})),
... LabeledPoint(0.0, SparseVector(2, {1: 1.0})),
@@ -219,7 +293,9 @@ class NaiveBayesModel(object):
self.theta = theta
def predict(self, x):
- """Return the most likely class for a data vector x"""
+ """Return the most likely class for a data vector or an RDD of vectors"""
+ if isinstance(x, RDD):
+ return x.map(lambda v: self.predict(v))
x = _convert_to_vector(x)
return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
@@ -250,7 +326,8 @@ class NaiveBayes(object):
def _test():
import doctest
from pyspark import SparkContext
- globs = globals().copy()
+ import pyspark.mllib.classification
+ globs = pyspark.mllib.classification.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
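After this patch, both linear models share the same decision rule: predict() maps itself over RDD input (x.map(lambda v: self.predict(v))) and otherwise thresholds a raw score. A standalone sketch of that rule, with a hypothetical helper rather than the actual Spark source:

    def _apply_threshold(score, threshold):
        # threshold is None after clearThreshold(): expose the raw score
        if threshold is None:
            return score
        return 1 if score > threshold else 0

    assert _apply_threshold(0.8, 0.5) == 1     # above threshold -> positive
    assert _apply_threshold(0.3, 0.5) == 0     # at/below -> negative
    assert _apply_threshold(0.8, None) == 0.8  # cleared -> raw score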