author Bryan Cutler <bjcutler@us.ibm.com> 2015-10-08 22:21:07 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-10-08 22:21:07 -0700
commit 5410747a84e9be1cea44159dfc2216d5e0728ab4
tree 98b3f9b8ed5e2fddbc7213a8115402abfd21472e /python
parent 67fbecbf32fced87d3accd2618fef2af9f44fae2
[SPARK-10959] [PYSPARK] StreamingLogisticRegressionWithSGD does not train with given regParam and convergenceTol parameters
These params were passed into the StreamingLogisticRegressionWithSGD constructor but were not forwarded to the call that trains the model. The same was true of StreamingLinearRegressionWithSGD. I added the params as named arguments to the call and also fixed the intercept parameter, which was being passed positionally and so was taken as the regularization value.

Author: Bryan Cutler <bjcutler@us.ibm.com>

Closes #9002 from BryanCutler/StreamingSGD-convergenceTol-bug-10959.
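The intercept bug comes down to positional argument binding. A minimal sketch follows, using a hypothetical stand-in `train` function whose parameter order is assumed to mirror `LinearRegressionWithSGD.train` (with `regParam` directly after `initialWeights`). It does no training and only reports how its arguments were bound; the values passed in are illustrative.

```python
def train(data, iterations=100, step=1.0, miniBatchFraction=1.0,
          initialWeights=None, regParam=0.0, regType=None,
          intercept=False, validateData=True, convergenceTol=0.001):
    """Hypothetical stand-in: return the hyperparameters as they were bound."""
    return {
        "regParam": regParam,
        "intercept": intercept,
        "convergenceTol": convergenceTol,
    }


model_intercept = 0.5  # illustrative value, not taken from the commit

# Before the fix: the sixth positional argument lands in regParam.
before = train([], 50, 0.1, 1.0, [0.0], model_intercept)

# After the fix: keyword arguments bind to the intended parameters.
after = train([], 50, 0.1, 1.0, [0.0],
              intercept=model_intercept, convergenceTol=1e-6)

print(before)
print(after)
```

Passing the trailing arguments by name keeps the call correct even if parameters are later inserted into the middle of the signature.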
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/classification.py | 3
-rw-r--r-- python/pyspark/mllib/regression.py | 2
2 files changed, 3 insertions, 2 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index cb4ee83678..b77754500b 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -639,7 +639,8 @@ class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm):
             if not rdd.isEmpty():
                 self._model = LogisticRegressionWithSGD.train(
                     rdd, self.numIterations, self.stepSize,
-                    self.miniBatchFraction, self._model.weights)
+                    self.miniBatchFraction, self._model.weights,
+                    regParam=self.regParam, convergenceTol=self.convergenceTol)
         dstream.foreachRDD(update)
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 256b7537fe..961b5e80b0 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -679,7 +679,7 @@ class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm):
                 self._model = LinearRegressionWithSGD.train(
                     rdd, self.numIterations, self.stepSize,
                     self.miniBatchFraction, self._model.weights,
-                    self._model.intercept)
+                    intercept=self._model.intercept, convergenceTol=self.convergenceTol)
         dstream.foreachRDD(update)
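The pattern both hunks restore is that hyperparameters stored by the constructor must be forwarded on every per-batch training call. Otherwise the trainer silently falls back to its own defaults. A minimal sketch, assuming a hypothetical `StreamingTrainer` class and a stand-in `train` function (neither is the PySpark API, and the default values are illustrative):

```python
def train(data, iterations=100, step=1.0, miniBatchFraction=1.0,
          initialWeights=None, regParam=0.01, convergenceTol=0.001):
    """Hypothetical stand-in: record the hyperparameters used for one batch."""
    return {"regParam": regParam, "convergenceTol": convergenceTol}


class StreamingTrainer:
    """Hypothetical wrapper that retrains on each non-empty batch."""

    def __init__(self, stepSize=0.1, numIterations=50, miniBatchFraction=1.0,
                 regParam=0.0, convergenceTol=0.001):
        self.stepSize = stepSize
        self.numIterations = numIterations
        self.miniBatchFraction = miniBatchFraction
        self.regParam = regParam
        self.convergenceTol = convergenceTol
        self.weights = [0.0]
        self.calls = []  # one entry per training call, for inspection

    def update(self, batch):
        # Skip empty batches, as the guarded call in the diff does.
        if not batch:
            return
        self.calls.append(train(
            batch, self.numIterations, self.stepSize,
            self.miniBatchFraction, self.weights,
            # Forward the constructor's settings by name on every call.
            regParam=self.regParam, convergenceTol=self.convergenceTol))


trainer = StreamingTrainer(regParam=0.5, convergenceTol=1e-5)
trainer.update([])            # empty batch: no training call
trainer.update([1.0, 2.0])    # non-empty batch: trains with the given settings
print(trainer.calls)
```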