aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/classification.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/classification.py')
-rw-r--r--python/pyspark/mllib/classification.py17
1 files changed, 9 insertions, 8 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index eda0b60f8b..a70c664a71 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -86,7 +86,7 @@ class LogisticRegressionModel(LinearClassificationModel):
... LabeledPoint(0.0, [0.0, 1.0]),
... LabeledPoint(1.0, [1.0, 0.0]),
... ]
- >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data))
+ >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data), iterations=10)
>>> lrm.predict([1.0, 0.0])
1
>>> lrm.predict([0.0, 1.0])
@@ -95,7 +95,7 @@ class LogisticRegressionModel(LinearClassificationModel):
[1, 0]
>>> lrm.clearThreshold()
>>> lrm.predict([0.0, 1.0])
- 0.123...
+ 0.279...
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
@@ -103,7 +103,7 @@ class LogisticRegressionModel(LinearClassificationModel):
... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
- >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data))
+ >>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data), iterations=10)
>>> lrm.predict(array([0.0, 1.0]))
1
>>> lrm.predict(array([1.0, 0.0]))
@@ -129,7 +129,8 @@ class LogisticRegressionModel(LinearClassificationModel):
... LabeledPoint(1.0, [1.0, 0.0, 0.0]),
... LabeledPoint(2.0, [0.0, 0.0, 1.0])
... ]
- >>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3)
+ >>> data = sc.parallelize(multi_class_data)
+ >>> mcm = LogisticRegressionWithLBFGS.train(data, iterations=10, numClasses=3)
>>> mcm.predict([0.0, 0.5, 0.0])
0
>>> mcm.predict([0.8, 0.0, 0.0])
@@ -298,7 +299,7 @@ class LogisticRegressionWithLBFGS(object):
... LabeledPoint(0.0, [0.0, 1.0]),
... LabeledPoint(1.0, [1.0, 0.0]),
... ]
- >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data))
+ >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data), iterations=10)
>>> lrm.predict([1.0, 0.0])
1
>>> lrm.predict([0.0, 1.0])
@@ -330,14 +331,14 @@ class SVMModel(LinearClassificationModel):
... LabeledPoint(1.0, [2.0]),
... LabeledPoint(1.0, [3.0])
... ]
- >>> svm = SVMWithSGD.train(sc.parallelize(data))
+ >>> svm = SVMWithSGD.train(sc.parallelize(data), iterations=10)
>>> svm.predict([1.0])
1
>>> svm.predict(sc.parallelize([[1.0]])).collect()
[1]
>>> svm.clearThreshold()
>>> svm.predict(array([1.0]))
- 1.25...
+ 1.44...
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: -1.0})),
@@ -345,7 +346,7 @@ class SVMModel(LinearClassificationModel):
... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
- >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data))
+ >>> svm = SVMWithSGD.train(sc.parallelize(sparse_data), iterations=10)
>>> svm.predict(SparseVector(2, {1: 1.0}))
1
>>> svm.predict(SparseVector(2, {0: -1.0}))