aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-11-13 13:54:16 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-13 13:54:16 -0800
commit32218307edc6de2b08d5f7a0db6d566081d27197 (patch)
tree08be5a64c506f6d8f403dfae3f49ab50e8ac6e01
parent4b0c1edfdf457cde0e39083c47961184059efded (diff)
downloadspark-32218307edc6de2b08d5f7a0db6d566081d27197.tar.gz
spark-32218307edc6de2b08d5f7a0db6d566081d27197.tar.bz2
spark-32218307edc6de2b08d5f7a0db6d566081d27197.zip
[SPARK-4372][MLLIB] Make LR and SVM's default parameters consistent in Scala and Python
The current default regParam is 1.0 and regType is claimed to be none in Python (but actually it is l2), while regParam = 0.0 and regType is L2 in Scala. We should make the default values consistent. This PR sets the default regType to L2 and regParam to 0.01. Note that the default regParam value in LIBLINEAR (and hence scikit-learn) is 1.0. However, we use average loss instead of total loss in our formulation. Hence regParam=1.0 is definitely too heavy. In LinearRegression, we set regParam=0.0 and regType=None, because we have separate classes for Lasso and Ridge, both of which use regParam=0.01 as the default. davies atalwalkar Author: Xiangrui Meng <meng@databricks.com> Closes #3232 from mengxr/SPARK-4372 and squashes the following commits: 9979837 [Xiangrui Meng] update Ridge/Lasso to use default regParam 0.01 cast input arguments d3ba096 [Xiangrui Meng] change 'none' back to None 1909a6e [Xiangrui Meng] change default regParam to 0.01 and regType to L2 in LR and SVM
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala2
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala34
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala12
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala8
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala28
-rw-r--r--python/pyspark/mllib/classification.py36
-rw-r--r--python/pyspark/mllib/regression.py36
10 files changed, 95 insertions, 79 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
index 1edd2432a0..a113653810 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassification.scala
@@ -55,7 +55,7 @@ object BinaryClassification {
stepSize: Double = 1.0,
algorithm: Algorithm = LR,
regType: RegType = L2,
- regParam: Double = 0.1) extends AbstractParams[Params]
+ regParam: Double = 0.01) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
index e1f9622350..6815b1c052 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegression.scala
@@ -47,7 +47,7 @@ object LinearRegression extends App {
numIterations: Int = 100,
stepSize: Double = 1.0,
regType: RegType = L2,
- regParam: Double = 0.1) extends AbstractParams[Params]
+ regParam: Double = 0.01) extends AbstractParams[Params]
val defaultParams = Params()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 70d7138e30..c8476a5370 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -28,22 +28,22 @@ import net.razorvine.pickle._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
-import org.apache.spark.api.python.{PythonRDD, SerDeUtil}
+import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.feature._
-import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.linalg._
+import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
-import org.apache.spark.mllib.tree.DecisionTree
-import org.apache.spark.mllib.tree.impurity._
-import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.stat.test.ChiSqTestResult
+import org.apache.spark.mllib.tree.DecisionTree
+import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
+import org.apache.spark.mllib.tree.impurity._
+import org.apache.spark.mllib.tree.model.DecisionTreeModel
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
@@ -103,9 +103,11 @@ class PythonMLLibAPI extends Serializable {
lrAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
lrAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType != "none") {
- throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: [l1, l2, none].")
+ } else if (regType == null) {
+ lrAlg.optimizer.setUpdater(new SimpleUpdater)
+ } else {
+ throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
}
trainRegressionModel(
lrAlg,
@@ -180,9 +182,11 @@ class PythonMLLibAPI extends Serializable {
SVMAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
SVMAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType != "none") {
+ } else if (regType == null) {
+ SVMAlg.optimizer.setUpdater(new SimpleUpdater)
+ } else {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: [l1, l2, none].")
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
}
trainRegressionModel(
SVMAlg,
@@ -213,9 +217,11 @@ class PythonMLLibAPI extends Serializable {
LogRegAlg.optimizer.setUpdater(new SquaredL2Updater)
} else if (regType == "l1") {
LogRegAlg.optimizer.setUpdater(new L1Updater)
- } else if (regType != "none") {
+ } else if (regType == null) {
+ LogRegAlg.optimizer.setUpdater(new SimpleUpdater)
+ } else {
throw new java.lang.IllegalArgumentException("Invalid value for 'regType' parameter."
- + " Can only be initialized using the following string values: [l1, l2, none].")
+ + " Can only be initialized using the following string values: ['l1', 'l2', None].")
}
trainRegressionModel(
LogRegAlg,
@@ -250,7 +256,7 @@ class PythonMLLibAPI extends Serializable {
.setInitializationMode(initializationMode)
// Disable the uncached input warning because 'data' is a deliberately uncached MappedRDD.
.disableUncachedWarning()
- return kMeansAlg.run(data.rdd)
+ kMeansAlg.run(data.rdd)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 84d3c7cebd..18b95f1edc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -71,9 +71,10 @@ class LogisticRegressionModel (
}
/**
- * Train a classification model for Logistic Regression using Stochastic Gradient Descent.
- * NOTE: Labels used in Logistic Regression should be {0, 1}
- *
+ * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By
+ * default L2 regularization is used, which can be changed via
+ * [[LogisticRegressionWithSGD.optimizer]].
+ * NOTE: Labels used in Logistic Regression should be {0, 1}.
* Using [[LogisticRegressionWithLBFGS]] is recommended over this.
*/
class LogisticRegressionWithSGD private (
@@ -93,9 +94,10 @@ class LogisticRegressionWithSGD private (
override protected val validators = List(DataValidators.binaryLabelValidator)
/**
- * Construct a LogisticRegression object with default parameters
+ * Construct a LogisticRegression object with default parameters: {stepSize: 1.0,
+ * numIterations: 100, regParm: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 0.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index 80f8a1b2f1..ab9515b2a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -72,7 +72,8 @@ class SVMModel (
}
/**
- * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent.
+ * Train a Support Vector Machine (SVM) using Stochastic Gradient Descent. By default L2
+ * regularization is used, which can be changed via [[SVMWithSGD.optimizer]].
* NOTE: Labels used in SVM should be {0, 1}.
*/
class SVMWithSGD private (
@@ -92,9 +93,10 @@ class SVMWithSGD private (
override protected val validators = List(DataValidators.binaryLabelValidator)
/**
- * Construct a SVM object with default parameters
+ * Construct a SVM object with default parameters: {stepSize: 1.0, numIterations: 100,
+ * regParm: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 1.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new SVMModel(weights, intercept)
@@ -185,6 +187,6 @@ object SVMWithSGD {
* @return a SVMModel which has the weights and offset from training.
*/
def train(input: RDD[LabeledPoint], numIterations: Int): SVMModel = {
- train(input, numIterations, 1.0, 1.0, 1.0)
+ train(input, numIterations, 1.0, 0.01, 1.0)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
index cb0d39e759..f9791c6571 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/Lasso.scala
@@ -67,9 +67,9 @@ class LassoWithSGD private (
/**
* Construct a Lasso object with default parameters: {stepSize: 1.0, numIterations: 100,
- * regParam: 1.0, miniBatchFraction: 1.0}.
+ * regParam: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 1.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new LassoModel(weights, intercept)
@@ -161,6 +161,6 @@ object LassoWithSGD {
def train(
input: RDD[LabeledPoint],
numIterations: Int): LassoModel = {
- train(input, numIterations, 1.0, 1.0, 1.0)
+ train(input, numIterations, 1.0, 0.01, 1.0)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
index a826deb695..c8cad773f5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/RidgeRegression.scala
@@ -68,9 +68,9 @@ class RidgeRegressionWithSGD private (
/**
* Construct a RidgeRegression object with default parameters: {stepSize: 1.0, numIterations: 100,
- * regParam: 1.0, miniBatchFraction: 1.0}.
+ * regParam: 0.01, miniBatchFraction: 1.0}.
*/
- def this() = this(1.0, 100, 1.0, 1.0)
+ def this() = this(1.0, 100, 0.01, 1.0)
override protected def createModel(weights: Vector, intercept: Double) = {
new RidgeRegressionModel(weights, intercept)
@@ -143,7 +143,7 @@ object RidgeRegressionWithSGD {
numIterations: Int,
stepSize: Double,
regParam: Double): RidgeRegressionModel = {
- train(input, numIterations, stepSize, regParam, 1.0)
+ train(input, numIterations, stepSize, regParam, 0.01)
}
/**
@@ -158,6 +158,6 @@ object RidgeRegressionWithSGD {
def train(
input: RDD[LabeledPoint],
numIterations: Int): RidgeRegressionModel = {
- train(input, numIterations, 1.0, 1.0, 1.0)
+ train(input, numIterations, 1.0, 0.01, 1.0)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 6c1c784a19..4e81299440 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -80,13 +80,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
val lr = new LogisticRegressionWithSGD().setIntercept(true)
- lr.optimizer.setStepSize(10.0).setNumIterations(20)
+ lr.optimizer
+ .setStepSize(10.0)
+ .setRegParam(0.0)
+ .setNumIterations(20)
val model = lr.run(testRDD)
// Test the weights
- assert(model.weights(0) ~== -1.52 relTol 0.01)
- assert(model.intercept ~== 2.00 relTol 0.01)
+ assert(model.weights(0) ~== B relTol 0.02)
+ assert(model.intercept ~== A relTol 0.02)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
@@ -112,10 +115,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
val model = lr.run(testRDD)
// Test the weights
- assert(model.weights(0) ~== -1.52 relTol 0.01)
- assert(model.intercept ~== 2.00 relTol 0.01)
- assert(model.weights(0) ~== model.weights(0) relTol 0.01)
- assert(model.intercept ~== model.intercept relTol 0.01)
+ assert(model.weights(0) ~== B relTol 0.02)
+ assert(model.intercept ~== A relTol 0.02)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
@@ -141,13 +142,16 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
// Use half as many iterations as the previous test.
val lr = new LogisticRegressionWithSGD().setIntercept(true)
- lr.optimizer.setStepSize(10.0).setNumIterations(10)
+ lr.optimizer
+ .setStepSize(10.0)
+ .setRegParam(0.0)
+ .setNumIterations(10)
val model = lr.run(testRDD, initialWeights)
// Test the weights
- assert(model.weights(0) ~== -1.50 relTol 0.01)
- assert(model.intercept ~== 1.97 relTol 0.01)
+ assert(model.weights(0) ~== B relTol 0.02)
+ assert(model.intercept ~== A relTol 0.02)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
@@ -212,8 +216,8 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
val model = lr.run(testRDD, initialWeights)
// Test the weights
- assert(model.weights(0) ~== -1.50 relTol 0.02)
- assert(model.intercept ~== 1.97 relTol 0.02)
+ assert(model.weights(0) ~== B relTol 0.02)
+ assert(model.intercept ~== A relTol 0.02)
val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 5d90dddb5d..b654813fb4 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -76,7 +76,7 @@ class LogisticRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
- initialWeights=None, regParam=1.0, regType="none", intercept=False):
+ initialWeights=None, regParam=0.01, regType="l2", intercept=False):
"""
Train a logistic regression model on the given data.
@@ -87,16 +87,16 @@ class LogisticRegressionWithSGD(object):
:param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
:param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter (default: 1.0).
+ :param regParam: The regularizer parameter (default: 0.01).
:param regType: The type of regularizer used for training
our model.
:Allowed values:
- - "l1" for using L1Updater
- - "l2" for using SquaredL2Updater
- - "none" for no regularizer
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization
- (default: "none")
+ (default: "l2")
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
@@ -104,8 +104,9 @@ class LogisticRegressionWithSGD(object):
are activated or not).
"""
def train(rdd, i):
- return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, iterations, step,
- miniBatchFraction, i, regParam, regType, intercept)
+ return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations),
+ float(step), float(miniBatchFraction), i, float(regParam), regType,
+ bool(intercept))
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
@@ -145,8 +146,8 @@ class SVMModel(LinearModel):
class SVMWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, regParam=1.0,
- miniBatchFraction=1.0, initialWeights=None, regType="none", intercept=False):
+ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
+ miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False):
"""
Train a support vector machine on the given data.
@@ -154,7 +155,7 @@ class SVMWithSGD(object):
:param iterations: The number of iterations (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
- :param regParam: The regularizer parameter (default: 1.0).
+ :param regParam: The regularizer parameter (default: 0.01).
:param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
:param initialWeights: The initial weights (default: None).
@@ -162,11 +163,11 @@ class SVMWithSGD(object):
our model.
:Allowed values:
- - "l1" for using L1Updater
- - "l2" for using SquaredL2Updater,
- - "none" for no regularizer.
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization
- (default: "none")
+ (default: "l2")
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
@@ -174,8 +175,9 @@ class SVMWithSGD(object):
are activated or not).
"""
def train(rdd, i):
- return callMLlibFunc("trainSVMModelWithSGD", rdd, iterations, step, regParam,
- miniBatchFraction, i, regType, intercept)
+ return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step),
+ float(regParam), float(miniBatchFraction), i, regType,
+ bool(intercept))
return _regression_train_wrapper(train, SVMModel, data, initialWeights)
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 66e25a48df..f4f5e615fa 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -138,7 +138,7 @@ class LinearRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
- initialWeights=None, regParam=1.0, regType="none", intercept=False):
+ initialWeights=None, regParam=0.0, regType=None, intercept=False):
"""
Train a linear regression model on the given data.
@@ -149,16 +149,16 @@ class LinearRegressionWithSGD(object):
:param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
:param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter (default: 1.0).
+ :param regParam: The regularizer parameter (default: 0.0).
:param regType: The type of regularizer used for training
our model.
:Allowed values:
- - "l1" for using L1Updater,
- - "l2" for using SquaredL2Updater,
- - "none" for no regularizer.
+ - "l1" for using L1 regularization (lasso),
+ - "l2" for using L2 regularization (ridge),
+ - None for no regularization
- (default: "none")
+ (default: None)
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
@@ -166,11 +166,11 @@ class LinearRegressionWithSGD(object):
are activated or not).
"""
def train(rdd, i):
- return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step,
- miniBatchFraction, i, regParam, regType, intercept)
+ return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations),
+ float(step), float(miniBatchFraction), i, float(regParam),
+ regType, bool(intercept))
- return _regression_train_wrapper(train, LinearRegressionModel,
- data, initialWeights)
+ return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights)
class LassoModel(LinearRegressionModelBase):
@@ -209,12 +209,13 @@ class LassoModel(LinearRegressionModelBase):
class LassoWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, regParam=1.0,
+ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None):
"""Train a Lasso regression model on the given data."""
def train(rdd, i):
- return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam,
- miniBatchFraction, i)
+ return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step),
+ float(regParam), float(miniBatchFraction), i)
+
return _regression_train_wrapper(train, LassoModel, data, initialWeights)
@@ -254,15 +255,14 @@ class RidgeRegressionModel(LinearRegressionModelBase):
class RidgeRegressionWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, regParam=1.0,
+ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None):
"""Train a ridge regression model on the given data."""
def train(rdd, i):
- return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam,
- miniBatchFraction, i)
+ return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step),
+ float(regParam), float(miniBatchFraction), i)
- return _regression_train_wrapper(train, RidgeRegressionModel,
- data, initialWeights)
+ return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights)
def _test():