aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala16
1 files changed, 13 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 15ca2547d5..e391567347 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -111,9 +111,11 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector,
regParam: Double,
regType: String,
- intercept: Boolean): JList[Object] = {
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val lrAlg = new LinearRegressionWithSGD()
lrAlg.setIntercept(intercept)
+ .setValidateData(validateData)
lrAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -135,8 +137,12 @@ private[python] class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Vector): JList[Object] = {
+ initialWeights: Vector,
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val lassoAlg = new LassoWithSGD()
+ lassoAlg.setIntercept(intercept)
+ .setValidateData(validateData)
lassoAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
@@ -157,8 +163,12 @@ private[python] class PythonMLLibAPI extends Serializable {
stepSize: Double,
regParam: Double,
miniBatchFraction: Double,
- initialWeights: Vector): JList[Object] = {
+ initialWeights: Vector,
+ intercept: Boolean,
+ validateData: Boolean): JList[Object] = {
val ridgeAlg = new RidgeRegressionWithSGD()
+ ridgeAlg.setIntercept(intercept)
+ .setValidateData(validateData)
ridgeAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)