aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
author Bryan Cutler <bjcutler@us.ibm.com> 2015-10-08 22:21:07 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-10-08 22:21:07 -0700
commit 5410747a84e9be1cea44159dfc2216d5e0728ab4 (patch)
tree 98b3f9b8ed5e2fddbc7213a8115402abfd21472e
parent 67fbecbf32fced87d3accd2618fef2af9f44fae2 (diff)
download spark-5410747a84e9be1cea44159dfc2216d5e0728ab4.tar.gz
spark-5410747a84e9be1cea44159dfc2216d5e0728ab4.tar.bz2
spark-5410747a84e9be1cea44159dfc2216d5e0728ab4.zip
[SPARK-10959] [PYSPARK] StreamingLogisticRegressionWithSGD does not train with given regParam and convergenceTol parameters
These params were being passed into the StreamingLogisticRegressionWithSGD constructor, but not transferred to the call for model training. Same with StreamingLinearRegressionWithSGD. I added the params as named arguments to the call and also fixed the intercept parameter, which was being passed as the regularization value.

Author: Bryan Cutler <bjcutler@us.ibm.com>

Closes #9002 from BryanCutler/StreamingSGD-convergenceTol-bug-10959.
-rw-r--r-- python/pyspark/mllib/classification.py 3
-rw-r--r-- python/pyspark/mllib/regression.py 2
2 files changed, 3 insertions, 2 deletions
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index cb4ee83678..b77754500b 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -639,7 +639,8 @@ class StreamingLogisticRegressionWithSGD(StreamingLinearAlgorithm):
if not rdd.isEmpty():
self._model = LogisticRegressionWithSGD.train(
rdd, self.numIterations, self.stepSize,
- self.miniBatchFraction, self._model.weights)
+ self.miniBatchFraction, self._model.weights,
+ regParam=self.regParam, convergenceTol=self.convergenceTol)
dstream.foreachRDD(update)
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 256b7537fe..961b5e80b0 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -679,7 +679,7 @@ class StreamingLinearRegressionWithSGD(StreamingLinearAlgorithm):
self._model = LinearRegressionWithSGD.train(
rdd, self.numIterations, self.stepSize,
self.miniBatchFraction, self._model.weights,
- self._model.intercept)
+ intercept=self._model.intercept, convergenceTol=self.convergenceTol)
dstream.foreachRDD(update)