aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
authorlewuathe <lewuathe@me.com>2015-10-19 10:46:10 -0700
committerDB Tsai <dbt@netflix.com>2015-10-19 10:46:10 -0700
commit4c33a34ba3167ae67fdb4978ea2166ce65638fb9 (patch)
tree5a5dbae89a230ad0c82acab25ba98e9121b1af6b /mllib/src/main/scala/org
parentdfa41e63b98c28b087c56f94658b5e99e8a7758c (diff)
downloadspark-4c33a34ba3167ae67fdb4978ea2166ce65638fb9.tar.gz
spark-4c33a34ba3167ae67fdb4978ea2166ce65638fb9.tar.bz2
spark-4c33a34ba3167ae67fdb4978ea2166ce65638fb9.zip
[SPARK-10668] [ML] Use WeightedLeastSquares in LinearRegression with L…
…2 regularization if the number of features is small Author: lewuathe <lewuathe@me.com> Author: Lewuathe <sasaki@treasure-data.com> Author: Kai Sasaki <sasaki@treasure-data.com> Author: Lewuathe <lewuathe@me.com> Closes #8884 from Lewuathe/SPARK-10668.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala17
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala50
4 files changed, 70 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 8cb6b5493c..c7bca12430 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -73,7 +73,9 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."),
ParamDesc[String]("weightCol", "weight column name. If this is not set or empty, we treat " +
- "all instance weights as 1.0."))
+ "all instance weights as 1.0."),
+ ParamDesc[String]("solver", "the solver algorithm for optimization. If this is not set or " +
+ "empty, default value is 'auto'.", Some("\"auto\"")))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index e3625212e5..cb2a060a34 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -357,4 +357,21 @@ private[ml] trait HasWeightCol extends Params {
/** @group getParam */
final def getWeightCol: String = $(weightCol)
}
+
+/**
+ * Trait for shared param solver (default: "auto").
+ */
+private[ml] trait HasSolver extends Params {
+
+ /**
+ * Param for the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.
+ * @group param
+ */
+ final val solver: Param[String] = new Param[String](this, "solver", "the solver algorithm for optimization. If this is not set or empty, default value is 'auto'.")
+
+ setDefault(solver, "auto")
+
+ /** @group getParam */
+ final def getSolver: String = $(solver)
+}
// scalastyle:on
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index f5a022c31e..fec61fed3c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -30,13 +30,15 @@ private[r] object SparkRWrappers {
df: DataFrame,
family: String,
lambda: Double,
- alpha: Double): PipelineModel = {
+ alpha: Double,
+ solver: String): PipelineModel = {
val formula = new RFormula().setFormula(value)
val estimator = family match {
case "gaussian" => new LinearRegression()
.setRegParam(lambda)
.setElasticNetParam(alpha)
.setFitIntercept(formula.hasIntercept)
+ .setSolver(solver)
case "binomial" => new LogisticRegression()
.setRegParam(lambda)
.setElasticNetParam(alpha)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index dd09667ef5..573a61a6ea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -25,6 +25,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.optim.WeightedLeastSquares
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
@@ -43,7 +44,7 @@ import org.apache.spark.storage.StorageLevel
*/
private[regression] trait LinearRegressionParams extends PredictorParams
with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
- with HasFitIntercept with HasStandardization with HasWeightCol
+ with HasFitIntercept with HasStandardization with HasWeightCol with HasSolver
/**
* :: Experimental ::
@@ -130,9 +131,53 @@ class LinearRegression(override val uid: String)
def setWeightCol(value: String): this.type = set(weightCol, value)
setDefault(weightCol -> "")
+ /**
+ * Set the solver algorithm used for optimization.
+ * In case of linear regression, this can be "l-bfgs", "normal" and "auto".
+ * The default value is "auto" which means that the solver algorithm is
+ * selected automatically.
+ * @group setParam
+ */
+ def setSolver(value: String): this.type = set(solver, value)
+ setDefault(solver -> "auto")
+
override protected def train(dataset: DataFrame): LinearRegressionModel = {
- // Extract columns from data. If dataset is persisted, do not persist instances.
+ // Extract the number of features before deciding optimization solver.
+ val numFeatures = dataset.select(col($(featuresCol))).limit(1).map {
+ case Row(features: Vector) => features.size
+ }.toArray()(0)
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+
+ if (($(solver) == "auto" && $(elasticNetParam) == 0.0 && numFeatures <= 4096) ||
+ $(solver) == "normal") {
+ require($(elasticNetParam) == 0.0, "Only L2 regularization can be used when normal " +
+ "solver is used.'")
+ // For low dimensional data, WeightedLeastSquares is more efficient since the
+ // training algorithm only requires one pass through the data. (SPARK-10668)
+ val instances: RDD[WeightedLeastSquares.Instance] = dataset.select(
+ col($(labelCol)), w, col($(featuresCol))).map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ WeightedLeastSquares.Instance(weight, features, label)
+ }
+
+ val optimizer = new WeightedLeastSquares($(fitIntercept), $(regParam),
+ $(standardization), true)
+ val model = optimizer.fit(instances)
+ // When the model is trained by WeightedLeastSquares, the training summary is not
+ // attached to the returned model.
+ val lrModel = copyValues(new LinearRegressionModel(uid, model.coefficients, model.intercept))
+ // WeightedLeastSquares does not run through iterations. So it does not generate
+ // an objective history.
+ val (summaryModel, predictionColName) = lrModel.findSummaryModelAndPredictionCol()
+ val trainingSummary = new LinearRegressionTrainingSummary(
+ summaryModel.transform(dataset),
+ predictionColName,
+ $(labelCol),
+ $(featuresCol),
+ Array(0D))
+ return lrModel.setSummary(trainingSummary)
+ }
+
val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
@@ -155,7 +200,6 @@ class LinearRegression(override val uid: String)
new MultivariateOnlineSummarizer, new MultivariateOnlineSummarizer)(seqOp, combOp)
}
- val numFeatures = featuresSummarizer.mean.size
val yMean = ySummarizer.mean(0)
val yStd = math.sqrt(ySummarizer.variance(0))