about summary refs log tree commit diff
path: root/python/pyspark/mllib/classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/classification.py')
-rw-r--r-- python/pyspark/mllib/classification.py 65
1 files changed, 59 insertions, 6 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 70de332d34..03ff5a572e 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -15,6 +15,8 @@
# limitations under the License.
#
+import numpy
+
from numpy import array, dot, shape
from pyspark import SparkContext
from pyspark.mllib._common import \
@@ -29,8 +31,8 @@ class LogisticRegressionModel(LinearModel):
"""A linear binary classification model derived from logistic regression.
>>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
- >>> lrm = LogisticRegressionWithSGD.train(sc, sc.parallelize(data))
- >>> lrm.predict(array([1.0])) != None
+ >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data))
+ >>> lrm.predict(array([1.0])) > 0
True
"""
def predict(self, x):
@@ -41,9 +43,10 @@ class LogisticRegressionModel(LinearModel):
class LogisticRegressionWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0,
+ def train(cls, data, iterations=100, step=1.0,
mini_batch_fraction=1.0, initial_weights=None):
"""Train a logistic regression model on the given data."""
+ sc = data.context
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(d._jrdd,
iterations, step, mini_batch_fraction, i),
@@ -53,8 +56,8 @@ class SVMModel(LinearModel):
"""A support vector machine.
>>> data = array([0.0, 0.0, 1.0, 1.0, 1.0, 2.0, 1.0, 3.0]).reshape(4,2)
- >>> svm = SVMWithSGD.train(sc, sc.parallelize(data))
- >>> svm.predict(array([1.0])) != None
+ >>> svm = SVMWithSGD.train(sc.parallelize(data))
+ >>> svm.predict(array([1.0])) > 0
True
"""
def predict(self, x):
@@ -64,14 +67,64 @@ class SVMModel(LinearModel):
class SVMWithSGD(object):
@classmethod
- def train(cls, sc, data, iterations=100, step=1.0, reg_param=1.0,
+ def train(cls, data, iterations=100, step=1.0, reg_param=1.0,
mini_batch_fraction=1.0, initial_weights=None):
"""Train a support vector machine on the given data."""
+ # The SparkContext is read from the RDD itself, replacing the separate
+ # sc argument that the old signature required.
+ sc = data.context
+ # NOTE(review): d and i are supplied by _regression_train_wrapper;
+ # presumably the prepared data RDD and the initial weights - confirm
+ # against pyspark.mllib._common.
return _regression_train_wrapper(sc, lambda d, i:
sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(d._jrdd,
iterations, step, reg_param, mini_batch_fraction, i),
SVMModel, data, initial_weights)
+class NaiveBayesModel(object):
+    """
+    Model for Naive Bayes classifiers.
+
+    Contains two parameters:
+    - pi: vector of logs of class priors (dimension C)
+    - theta: matrix of logs of class conditional probabilities (CxD)
+
+    >>> data = array([0.0, 0.0, 1.0, 0.0, 0.0, 2.0, 1.0, 1.0, 0.0]).reshape(3,3)
+    >>> model = NaiveBayes.train(sc.parallelize(data))
+    >>> model.predict(array([0.0, 1.0]))
+    0
+    >>> model.predict(array([1.0, 0.0]))
+    1
+    """
+
+    def __init__(self, pi, theta):
+        """
+        Store the model parameters.
+
+        @param pi: logs of class priors, one entry per class (dimension C)
+        @param theta: logs of class conditional probabilities, one row per
+               class and one column per feature (CxD)
+        """
+        self.pi = pi
+        self.theta = theta
+
+    def predict(self, x):
+        """
+        Return the most likely class for a data vector x.
+
+        @param x: feature vector of dimension D
+        @return: index of the class that maximizes pi + theta . x
+        """
+        # theta is CxD and x has D entries, so theta must be the left
+        # operand: dot(theta, x) yields one score per class (dimension C),
+        # which lines up with pi.  dot(x, theta) fails whenever C != D and,
+        # when C == D, sums over classes instead of features.
+        return numpy.argmax(self.pi + dot(self.theta, x))
+
+class NaiveBayes(object):
+    @classmethod
+    def train(cls, data, lambda_=1.0):
+        """
+        Fit a Naive Bayes classifier to an RDD of labeled NumPy vectors.
+
+        The variant trained here is Multinomial NB
+        (U{http://tinyurl.com/lsdw6p}), which can handle all kinds of
+        discrete data. Documents converted into TF-IDF vectors can be
+        classified this way, and restricting every vector to 0-1 entries
+        makes it usable as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
+
+        @param data: RDD of NumPy vectors, one per element; the first
+               coordinate of each vector is the label and the remaining
+               coordinates are the features (e.g. a count vector).
+        @param lambda_: The smoothing parameter
+        @return: a NaiveBayesModel built from the two values returned by
+                 the JVM trainer (a vector, then a matrix)
+        """
+        serialized = _get_unmangled_double_vector_rdd(data)
+        api = data.context._jvm.PythonMLLibAPI()
+        trained = api.trainNaiveBayes(serialized._jrdd, lambda_)
+        # The JVM side answers with two serialized values: element 0 is
+        # decoded as a vector (pi), element 1 as a matrix (theta).
+        pi = _deserialize_double_vector(trained[0])
+        theta = _deserialize_double_matrix(trained[1])
+        return NaiveBayesModel(pi, theta)
+
+
+
+
def _test():
import doctest
globs = globals().copy()