aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/classification.py
diff options
context:
space:
mode:
authorReynold Xin <rxin@apache.org>2014-05-25 17:15:01 -0700
committerReynold Xin <rxin@apache.org>2014-05-25 17:15:01 -0700
commitd33d3c61ae9e4551aed0217e525a109e678298f2 (patch)
tree109ffeeaf31ae267bbe791051fd39f490af04aa4 /python/pyspark/mllib/classification.py
parent14f0358b2a0a9b92526bdad6d501ab753459eaa0 (diff)
downloadspark-d33d3c61ae9e4551aed0217e525a109e678298f2.tar.gz
spark-d33d3c61ae9e4551aed0217e525a109e678298f2.tar.bz2
spark-d33d3c61ae9e4551aed0217e525a109e678298f2.zip
Fix PEP8 violations in Python mllib.
Author: Reynold Xin <rxin@apache.org> Closes #871 from rxin/mllib-pep8 and squashes the following commits: 848416f [Reynold Xin] Fixed a typo in the previous cleanup (c -> sc). a8db4cd [Reynold Xin] Fix PEP8 violations in Python mllib.
Diffstat (limited to 'python/pyspark/mllib/classification.py')
-rw-r--r--python/pyspark/mllib/classification.py26
1 files changed, 14 insertions, 12 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 6772e4337e..1c0c536c4f 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -29,6 +29,7 @@ from pyspark.mllib.linalg import SparseVector
from pyspark.mllib.regression import LabeledPoint, LinearModel
from math import exp, log
+
class LogisticRegressionModel(LinearModel):
"""A linear binary classification model derived from logistic regression.
@@ -68,14 +69,14 @@ class LogisticRegressionModel(LinearModel):
class LogisticRegressionWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0,
- miniBatchFraction=1.0, initialWeights=None):
+ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, initialWeights=None):
"""Train a logistic regression model on the given data."""
sc = data.context
- return _regression_train_wrapper(sc, lambda d, i:
- sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(d._jrdd,
- iterations, step, miniBatchFraction, i),
- LogisticRegressionModel, data, initialWeights)
+ train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainLogisticRegressionModelWithSGD(
+ d._jrdd, iterations, step, miniBatchFraction, i)
+ return _regression_train_wrapper(sc, train_func, LogisticRegressionModel, data,
+ initialWeights)
+
class SVMModel(LinearModel):
"""A support vector machine.
@@ -106,16 +107,17 @@ class SVMModel(LinearModel):
margin = _dot(x, self._coeff) + self._intercept
return 1 if margin >= 0 else 0
+
class SVMWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, regParam=1.0,
miniBatchFraction=1.0, initialWeights=None):
"""Train a support vector machine on the given data."""
sc = data.context
- return _regression_train_wrapper(sc, lambda d, i:
- sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(d._jrdd,
- iterations, step, regParam, miniBatchFraction, i),
- SVMModel, data, initialWeights)
+ train_func = lambda d, i: sc._jvm.PythonMLLibAPI().trainSVMModelWithSGD(
+ d._jrdd, iterations, step, regParam, miniBatchFraction, i)
+ return _regression_train_wrapper(sc, train_func, SVMModel, data, initialWeights)
+
class NaiveBayesModel(object):
"""
@@ -156,6 +158,7 @@ class NaiveBayesModel(object):
"""Return the most likely class for a data vector x"""
return self.labels[numpy.argmax(self.pi + _dot(x, self.theta.transpose()))]
+
class NaiveBayes(object):
@classmethod
def train(cls, data, lambda_=1.0):
@@ -186,8 +189,7 @@ def _test():
import doctest
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
- (failure_count, test_count) = doctest.testmod(globs=globs,
- optionflags=doctest.ELLIPSIS)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
exit(-1)