author     DB Tsai <dbt@netflix.com>            2015-04-28 09:46:08 -0700
committer  Xiangrui Meng <meng@databricks.com>  2015-04-28 09:46:08 -0700
commit     6a827d5d1ec520f129e42c3818fe7d0d870dcbef (patch)
tree       bfe32cbf65e5eede7994fa41e8c28e08853c4e60
parent     268c419f1586110b90e68f98cd000a782d18828c (diff)
download   spark-6a827d5d1ec520f129e42c3818fe7d0d870dcbef.tar.gz
           spark-6a827d5d1ec520f129e42c3818fe7d0d870dcbef.tar.bz2
           spark-6a827d5d1ec520f129e42c3818fe7d0d870dcbef.zip
[SPARK-5253] [ML] LinearRegression with L1/L2 (ElasticNet) using OWLQN
Author: DB Tsai <dbt@netflix.com>
Author: DB Tsai <dbtsai@alpinenow.com>

Closes #4259 from dbtsai/lir and squashes the following commits:

a81c201 [DB Tsai] add import org.apache.spark.util.Utils back
9fc48ed [DB Tsai] rebase
2178b63 [DB Tsai] add comments
9988ca8 [DB Tsai] addressed feedback and fixed a bug. TODO: documentation and build another synthetic dataset which can catch the bug fixed in this commit.
fcbaefe [DB Tsai] Refactoring
4eb078d [DB Tsai] first commit
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala  34
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala  304
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala  6
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala  15
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala  43
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala  158
8 files changed, 508 insertions(+), 64 deletions(-)
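
For orientation, here is a minimal usage sketch of the API this patch adds. It is not part of the commit; it assumes a DataFrame `training` with "label" and "features" columns is already in scope, and the parameter values are arbitrary:

import org.apache.spark.ml.regression.LinearRegression

// elasticNetParam mixes the penalties: 0.0 is a pure L2 penalty, 1.0 a pure L1 penalty.
val lr = new LinearRegression()
  .setElasticNetParam(0.3)
  .setRegParam(1.6)
  .setMaxIter(100)
  .setTol(1E-6)

val model = lr.fit(training)
println(s"weights = ${model.weights}, intercept = ${model.intercept}")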
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index e88c48741e..3f7e8f5a0b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -46,7 +46,9 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("outputCol", "output column name"),
ParamDesc[Int]("checkpointInterval", "checkpoint interval"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
- ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")))
+ ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
+ ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
+ ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index a860b8834c..7d2c76d6c6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -276,4 +276,38 @@ trait HasSeed extends Params {
/** @group getParam */
final def getSeed: Long = getOrDefault(seed)
}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param elasticNetParam.
+ */
+@DeveloperApi
+trait HasElasticNetParam extends Params {
+
+ /**
+ * Param for the ElasticNet mixing parameter.
+ * @group param
+ */
+ final val elasticNetParam: DoubleParam = new DoubleParam(this, "elasticNetParam", "the ElasticNet mixing parameter")
+
+ /** @group getParam */
+ final def getElasticNetParam: Double = getOrDefault(elasticNetParam)
+}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param tol.
+ */
+@DeveloperApi
+trait HasTol extends Params {
+
+ /**
+ * Param for the convergence tolerance for iterative algorithms.
+ * @group param
+ */
+ final val tol: DoubleParam = new DoubleParam(this, "tol", "the convergence tolerance for iterative algorithms")
+
+ /** @group getParam */
+ final def getTol: Double = getOrDefault(tol)
+}
// scalastyle:on
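
The two new traits are designed to be mixed into any Params-based class; a brief hedged illustration (the `MyRegressorParams` trait here is hypothetical and not part of the patch):

import org.apache.spark.ml.param.Params
import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasTol}

// Mixing in the generated traits adds the `elasticNetParam` and `tol` params and their
// getters; defaults are left to the concrete estimator (as LinearRegression does below).
trait MyRegressorParams extends Params with HasElasticNetParam with HasTol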
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 26ca7459c4..f92c6816eb 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -17,21 +17,29 @@
package org.apache.spark.ml.regression
+import scala.collection.mutable
+
+import breeze.linalg.{norm => brzNorm, DenseVector => BDV}
+import breeze.optimize.{LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
+import breeze.optimize.{CachedDiffFunction, DiffFunction}
+
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param.{Params, ParamMap}
-import org.apache.spark.ml.param.shared._
-import org.apache.spark.mllib.linalg.{BLAS, Vector}
-import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.ml.param.shared.{HasElasticNetParam, HasMaxIter, HasRegParam, HasTol}
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.BLAS._
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
-
+import org.apache.spark.util.StatCounter
/**
* Params for linear regression.
*/
private[regression] trait LinearRegressionParams extends RegressorParams
- with HasRegParam with HasMaxIter
-
+ with HasRegParam with HasElasticNetParam with HasMaxIter with HasTol
/**
* :: AlphaComponent ::
@@ -42,34 +50,119 @@ private[regression] trait LinearRegressionParams extends RegressorParams
class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
with LinearRegressionParams {
- setDefault(regParam -> 0.1, maxIter -> 100)
-
- /** @group setParam */
+ /**
+ * Set the regularization parameter.
+ * Default is 0.0.
+ * @group setParam
+ */
def setRegParam(value: Double): this.type = set(regParam, value)
+ setDefault(regParam -> 0.0)
+
+ /**
+ * Set the ElasticNet mixing parameter.
+ * For alpha = 0, the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.
+ * For 0 < alpha < 1, the penalty is a combination of L1 and L2.
+ * Default is 0.0 which is an L2 penalty.
+ * @group setParam
+ */
+ def setElasticNetParam(value: Double): this.type = set(elasticNetParam, value)
+ setDefault(elasticNetParam -> 0.0)
- /** @group setParam */
+ /**
+ * Set the maximum number of iterations.
+ * Default is 100.
+ * @group setParam
+ */
def setMaxIter(value: Int): this.type = set(maxIter, value)
+ setDefault(maxIter -> 100)
+
+ /**
+ * Set the convergence tolerance of iterations.
+ * A smaller value will lead to higher accuracy at the cost of more iterations.
+ * Default is 1E-6.
+ * @group setParam
+ */
+ def setTol(value: Double): this.type = set(tol, value)
+ setDefault(tol -> 1E-6)
override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
- // Extract columns from data. If dataset is persisted, do not persist oldDataset.
- val oldDataset = extractLabeledPoints(dataset, paramMap)
+ // Extract columns from data. If dataset is persisted, do not persist instances.
+ val instances = extractLabeledPoints(dataset, paramMap).map {
+ case LabeledPoint(label: Double, features: Vector) => (label, features)
+ }
val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) {
- oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+ instances.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ val (summarizer, statCounter) = instances.treeAggregate(
+ (new MultivariateOnlineSummarizer, new StatCounter))( {
+ case ((summarizer: MultivariateOnlineSummarizer, statCounter: StatCounter),
+ (label: Double, features: Vector)) =>
+ (summarizer.add(features), statCounter.merge(label))
+ }, {
+ case ((summarizer1: MultivariateOnlineSummarizer, statCounter1: StatCounter),
+ (summarizer2: MultivariateOnlineSummarizer, statCounter2: StatCounter)) =>
+ (summarizer1.merge(summarizer2), statCounter1.merge(statCounter2))
+ })
+
+ val numFeatures = summarizer.mean.size
+ val yMean = statCounter.mean
+ val yStd = math.sqrt(statCounter.variance)
+
+ val featuresMean = summarizer.mean.toArray
+ val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+
+ // Since we implicitly do the feature scaling when we compute the cost function
+ // to improve the convergence, the effective regParam will be changed.
+ val effectiveRegParam = paramMap(regParam) / yStd
+ val effectiveL1RegParam = paramMap(elasticNetParam) * effectiveRegParam
+ val effectiveL2RegParam = (1.0 - paramMap(elasticNetParam)) * effectiveRegParam
+
+ val costFun = new LeastSquaresCostFun(instances, yStd, yMean,
+ featuresStd, featuresMean, effectiveL2RegParam)
+
+ val optimizer = if (paramMap(elasticNetParam) == 0.0 || effectiveRegParam == 0.0) {
+ new BreezeLBFGS[BDV[Double]](paramMap(maxIter), 10, paramMap(tol))
+ } else {
+ new BreezeOWLQN[Int, BDV[Double]](paramMap(maxIter), 10, effectiveL1RegParam, paramMap(tol))
+ }
+
+ val initialWeights = Vectors.zeros(numFeatures)
+ val states =
+ optimizer.iterations(new CachedDiffFunction(costFun), initialWeights.toBreeze.toDenseVector)
+
+ var state = states.next()
+ val lossHistory = mutable.ArrayBuilder.make[Double]
+
+ while (states.hasNext) {
+ lossHistory += state.value
+ state = states.next()
+ }
+ lossHistory += state.value
+
+ // TODO: Based on the sparsity of weights, we may convert the weights to a sparse vector.
+ // The weights are trained in the scaled space; we're converting them back to
+ // the original space.
+ val weights = {
+ val rawWeights = state.x.toArray.clone()
+ var i = 0
+ while (i < rawWeights.length) {
+ rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
+ i += 1
+ }
+ Vectors.dense(rawWeights)
}
- // Train model
- val lr = new LinearRegressionWithSGD()
- lr.optimizer
- .setRegParam(paramMap(regParam))
- .setNumIterations(paramMap(maxIter))
- val model = lr.run(oldDataset)
- val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
+ // The intercept in R's GLMNET is computed using closed form after the coefficients are
+ // converged. See the following discussion for detail.
+ // http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
+ val intercept = yMean - dot(weights, Vectors.dense(featuresMean))
if (handlePersistence) {
- oldDataset.unpersist()
+ instances.unpersist()
}
- lrm
+ new LinearRegressionModel(this, paramMap, weights, intercept)
}
}
@@ -88,7 +181,7 @@ class LinearRegressionModel private[ml] (
with LinearRegressionParams {
override protected def predict(features: Vector): Double = {
- BLAS.dot(features, weights) + intercept
+ dot(features, weights) + intercept
}
override protected def copy(): LinearRegressionModel = {
@@ -97,3 +190,168 @@ class LinearRegressionModel private[ml] (
m
}
}
+
+/**
+ * LeastSquaresAggregator computes the gradient and loss for the least-squares loss function,
+ * as used in linear regression, for samples in sparse or dense vector format, in an online fashion.
+ *
+ * Two LeastSquaresAggregators can be merged together to obtain the loss and gradient of
+ * the corresponding joint dataset.
+ *
+
+ * This is correct for the averaged least-squares loss function (mean squared error):
+ *   L = 1/(2n) ||A weights - y||^2
+ *
+ * @param weights weights/coefficients corresponding to features
+ *
+ */
+private class LeastSquaresAggregator(
+ weights: Vector,
+ labelStd: Double,
+ labelMean: Double,
+ featuresStd: Array[Double],
+ featuresMean: Array[Double]) extends Serializable {
+
+ private var totalCnt: Long = 0
+ private var lossSum = 0.0
+ private var diffSum = 0.0
+
+ private val (effectiveWeightsArray: Array[Double], offset: Double, dim: Int) = {
+ val weightsArray = weights.toArray.clone()
+ var sum = 0.0
+ var i = 0
+ while (i < weightsArray.length) {
+ if (featuresStd(i) != 0.0) {
+ weightsArray(i) /= featuresStd(i)
+ sum += weightsArray(i) * featuresMean(i)
+ } else {
+ weightsArray(i) = 0.0
+ }
+ i += 1
+ }
+ (weightsArray, -sum + labelMean / labelStd, weightsArray.length)
+ }
+ private val effectiveWeightsVector = Vectors.dense(effectiveWeightsArray)
+
+ private val gradientSumArray: Array[Double] = Array.ofDim[Double](dim)
+
+ /**
+ * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient
+ * of the objective function.
+ *
+ * @param label The label for this data point.
+ * @param data The features for one data point in dense/sparse vector format to be added
+ * into this aggregator.
+ * @return This LeastSquaresAggregator object.
+ */
+ def add(label: Double, data: Vector): this.type = {
+ require(dim == data.size, s"Dimensions mismatch when adding new sample." +
+ s" Expecting $dim but got ${data.size}.")
+
+ val diff = dot(data, effectiveWeightsVector) - label / labelStd + offset
+
+ if (diff != 0) {
+ val localGradientSumArray = gradientSumArray
+ data.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ localGradientSumArray(index) += diff * value / featuresStd(index)
+ }
+ }
+ lossSum += diff * diff / 2.0
+ diffSum += diff
+ }
+
+ totalCnt += 1
+ this
+ }
+
+ /**
+ * Merge another LeastSquaresAggregator, and update the loss and gradient
+ * of the objective function.
+ * (Note that it's in place merging; as a result, `this` object will be modified.)
+ *
+ * @param other The other LeastSquaresAggregator to be merged.
+ * @return This LeastSquaresAggregator object.
+ */
+ def merge(other: LeastSquaresAggregator): this.type = {
+ require(dim == other.dim, s"Dimensions mismatch when merging with another " +
+ s"LeastSquaresAggregator. Expecting $dim but got ${other.dim}.")
+
+ if (other.totalCnt != 0) {
+ totalCnt += other.totalCnt
+ lossSum += other.lossSum
+ diffSum += other.diffSum
+
+ var i = 0
+ val localThisGradientSumArray = this.gradientSumArray
+ val localOtherGradientSumArray = other.gradientSumArray
+ while (i < dim) {
+ localThisGradientSumArray(i) += localOtherGradientSumArray(i)
+ i += 1
+ }
+ }
+ this
+ }
+
+ def count: Long = totalCnt
+
+ def loss: Double = lossSum / totalCnt
+
+ def gradient: Vector = {
+ val result = Vectors.dense(gradientSumArray.clone())
+
+ val correction = {
+ val temp = effectiveWeightsArray.clone()
+ var i = 0
+ while (i < temp.length) {
+ temp(i) *= featuresMean(i)
+ i += 1
+ }
+ Vectors.dense(temp)
+ }
+
+ axpy(-diffSum, correction, result)
+ scal(1.0 / totalCnt, result)
+ result
+ }
+}
+
+/**
+ * LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost.
+ * It returns the loss and gradient with L2 regularization at a particular point (weights).
+ * It's used in Breeze's convex optimization routines.
+ */
+private class LeastSquaresCostFun(
+ data: RDD[(Double, Vector)],
+ labelStd: Double,
+ labelMean: Double,
+ featuresStd: Array[Double],
+ featuresMean: Array[Double],
+ effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
+
+ override def calculate(weights: BDV[Double]): (Double, BDV[Double]) = {
+ val w = Vectors.fromBreeze(weights)
+
+ val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(w, labelStd,
+ labelMean, featuresStd, featuresMean))(
+ seqOp = (c, v) => (c, v) match {
+ case (aggregator, (label, features)) => aggregator.add(label, features)
+ },
+ combOp = (c1, c2) => (c1, c2) match {
+ case (aggregator1, aggregator2) => aggregator1.merge(aggregator2)
+ })
+
+ // regVal is the L2 regularization term: 0.5 * effectiveL2regParam * ||weights||^2
+ val norm = brzNorm(weights, 2.0)
+ val regVal = 0.5 * effectiveL2regParam * norm * norm
+
+ val loss = leastSquaresAggregator.loss + regVal
+ val gradient = leastSquaresAggregator.gradient
+ axpy(effectiveL2regParam, w, gradient)
+
+ (loss, gradient.toBreeze.asInstanceOf[BDV[Double]])
+ }
+}
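
For reference, the objective that LeastSquaresCostFun together with the L-BFGS/OWLQN setup above minimizes can be written out as follows. This is a reconstruction from the code (a sketch, not an authoritative statement of the patch), with lambda = regParam, alpha = elasticNetParam, and the hatted sigmas the label and feature standard deviations:

\[
  L(w) \;=\; \frac{1}{2n}\sum_{i=1}^{n}
    \Bigl(\sum_{j} w_j \frac{x_{ij}-\bar{x}_j}{\hat\sigma_j}
      \;-\; \frac{y_i-\bar{y}}{\hat\sigma_y}\Bigr)^{2}
  \;+\; \frac{\lambda}{\hat\sigma_y}
    \Bigl(\alpha\,\lVert w\rVert_1 \;+\; \frac{1-\alpha}{2}\,\lVert w\rVert_2^{2}\Bigr)
\]

The L2 part of the penalty is added inside LeastSquaresCostFun, while the L1 part is handled by OWLQN itself via effectiveL1RegParam. After convergence the weights are rescaled back to the original feature space (rawWeights(i) *= yStd / featuresStd(i), or zero when featuresStd(i) == 0), and the intercept is recovered in closed form as yMean - dot(weights, Vectors.dense(featuresMean)), following the glmnet convention referenced above.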
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 166c00cff6..af0cfe22ca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -85,7 +85,7 @@ sealed trait Vector extends Serializable {
/**
* Converts the instance to a breeze vector.
*/
- private[mllib] def toBreeze: BV[Double]
+ private[spark] def toBreeze: BV[Double]
/**
* Gets the value of the ith element.
@@ -284,7 +284,7 @@ object Vectors {
/**
* Creates a vector instance from a breeze vector.
*/
- private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = {
+ private[spark] def fromBreeze(breezeVector: BV[Double]): Vector = {
breezeVector match {
case v: BDV[Double] =>
if (v.offset == 0 && v.stride == 1 && v.length == v.data.length) {
@@ -483,7 +483,7 @@ class DenseVector(val values: Array[Double]) extends Vector {
override def toArray: Array[Double] = values
- private[mllib] override def toBreeze: BV[Double] = new BDV[Double](values)
+ private[spark] override def toBreeze: BV[Double] = new BDV[Double](values)
override def apply(i: Int): Double = values(i)
@@ -543,7 +543,7 @@ class SparseVector(
new SparseVector(size, indices.clone(), values.clone())
}
- private[mllib] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
+ private[spark] override def toBreeze: BV[Double] = new BSV[Double](indices, values, size)
private[spark] override def foreachActive(f: (Int, Double) => Unit) = {
var i = 0
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index 8bfa0d2b64..240baeb5a1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -37,7 +37,11 @@ abstract class Gradient extends Serializable {
*
* @return (gradient: Vector, loss: Double)
*/
- def compute(data: Vector, label: Double, weights: Vector): (Vector, Double)
+ def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
+ val gradient = Vectors.zeros(weights.size)
+ val loss = compute(data, label, weights, gradient)
+ (gradient, loss)
+ }
/**
* Compute the gradient and loss given the features of a single data point,
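
The practical effect of the new default implementation is that a Gradient subclass only needs to provide the in-place overload; the allocating two-argument form is derived from it. A small hedged usage sketch (LeastSquaresGradient is an existing MLlib gradient; the inputs are made up):

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.optimization.LeastSquaresGradient

val gradient = new LeastSquaresGradient()
// Uses the new default compute: it allocates a zero vector, delegates to the
// in-place overload, and returns (gradient, loss).
val (grad, loss) = gradient.compute(
  Vectors.dense(1.0, 2.0),  // features
  3.0,                      // label
  Vectors.dense(0.5, 0.5))  // current weights
println(s"gradient = $grad, loss = $loss")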
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index ef6eccd907..efedc112d3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -17,6 +17,7 @@
package org.apache.spark.mllib.optimization
+import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer
import breeze.linalg.{DenseVector => BDV}
@@ -164,7 +165,7 @@ object LBFGS extends Logging {
regParam: Double,
initialWeights: Vector): (Vector, Array[Double]) = {
- val lossHistory = new ArrayBuffer[Double](maxNumIterations)
+ val lossHistory = mutable.ArrayBuilder.make[Double]
val numExamples = data.count()
@@ -181,17 +182,19 @@ object LBFGS extends Logging {
* and regVal is the regularization value computed in the previous iteration as well.
*/
var state = states.next()
- while(states.hasNext) {
- lossHistory.append(state.value)
+ while (states.hasNext) {
+ lossHistory += state.value
state = states.next()
}
- lossHistory.append(state.value)
+ lossHistory += state.value
val weights = Vectors.fromBreeze(state.x)
+ val lossHistoryArray = lossHistory.result()
+
logInfo("LBFGS.runLBFGS finished. Last 10 losses %s".format(
- lossHistory.takeRight(10).mkString(", ")))
+ lossHistoryArray.takeRight(10).mkString(", ")))
- (weights, lossHistory.toArray)
+ (weights, lossHistoryArray)
}
/**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
index c9d33787b0..d7bb943e84 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
@@ -56,6 +56,10 @@ object LinearDataGenerator {
}
/**
+ * For compatibility, when the mean and variance are not specified, the generated data
+ * will have zero mean and a variance of (1.0/3.0), since the original output range is
+ * [-1, 1] with a uniform distribution, and the variance of a uniform distribution
+ * is (b - a)^2^ / 12, which works out to (1.0/3.0).
*
* @param intercept Data intercept
* @param weights Weights to be applied.
@@ -70,10 +74,47 @@ object LinearDataGenerator {
nPoints: Int,
seed: Int,
eps: Double = 0.1): Seq[LabeledPoint] = {
+ generateLinearInput(intercept, weights,
+ Array.fill[Double](weights.size)(0.0),
+ Array.fill[Double](weights.size)(1.0 / 3.0),
+ nPoints, seed, eps)}
+
+ /**
+ * Generates input for a linear model with the given feature means and variances.
+ *
+ * @param intercept Data intercept
+ * @param weights Weights to be applied.
+ * @param xMean the mean of the generated features. Often, if the features are not properly
+ * standardized, a poorly implemented algorithm will have difficulty
+ * converging.
+ * @param xVariance the variance of the generated features.
+ * @param nPoints Number of points in sample.
+ * @param seed Random seed
+ * @param eps Epsilon scaling factor.
+ * @return Seq of input.
+ */
+ def generateLinearInput(
+ intercept: Double,
+ weights: Array[Double],
+ xMean: Array[Double],
+ xVariance: Array[Double],
+ nPoints: Int,
+ seed: Int,
+ eps: Double): Seq[LabeledPoint] = {
val rnd = new Random(seed)
val x = Array.fill[Array[Double]](nPoints)(
- Array.fill[Double](weights.length)(2 * rnd.nextDouble - 1.0))
+ Array.fill[Double](weights.length)(rnd.nextDouble))
+
+ x.map(vector => {
+ // This doesn't work if `vector` is a sparse vector.
+ val vectorArray = vector.toArray
+ var i = 0
+ while (i < vectorArray.size) {
+ vectorArray(i) = (vectorArray(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+ i += 1
+ }
+ })
+
val y = x.map { xi =>
blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian()
}
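
A quick sanity check of the transform used above (a sketch, not part of the patch): since a Uniform(0, 1) draw has variance 1/12, (u - 0.5) * sqrt(12.0 * xVariance(i)) + xMean(i) produces samples with the requested mean and variance:

import scala.util.Random

val rnd = new Random(42)
val (mean, variance) = (0.9, 0.7)
val n = 1000000
val samples = Array.fill(n)((rnd.nextDouble() - 0.5) * math.sqrt(12.0 * variance) + mean)
val empMean = samples.sum / n
val empVar = samples.map(x => (x - empMean) * (x - empMean)).sum / n
// Both printed values should be close to the requested 0.9 and 0.7.
println(f"mean = $empMean%.3f, variance = $empVar%.3f")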
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index bbb44c3e2d..80323ef520 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -19,47 +19,149 @@ package org.apache.spark.ml.regression
import org.scalatest.FunSuite
-import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{DataFrame, SQLContext}
+import org.apache.spark.mllib.linalg.DenseVector
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{Row, SQLContext, DataFrame}
class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
@transient var sqlContext: SQLContext = _
@transient var dataset: DataFrame = _
+ /**
+ * In `LinearRegressionSuite`, we make sure that the model trained by Spark ML
+ * matches the one trained by R's glmnet package. The following instructions
+ * describe how to generate and save the training data so it can be loaded into R.
+ *
+ * import org.apache.spark.mllib.util.LinearDataGenerator
+ * val data =
+ * sc.parallelize(LinearDataGenerator.generateLinearInput(6.3, Array(4.7, 7.2), 10000, 42), 2)
+ * data.map(x=> x.label + ", " + x.features(0) + ", " + x.features(1)).saveAsTextFile("path")
+ */
override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
dataset = sqlContext.createDataFrame(
- sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
+ sc.parallelize(LinearDataGenerator.generateLinearInput(
+ 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 10000, 42, 0.1), 2))
}
- test("linear regression: default params") {
- val lr = new LinearRegression
- assert(lr.getLabelCol == "label")
- val model = lr.fit(dataset)
- model.transform(dataset)
- .select("label", "prediction")
- .collect()
- // Check defaults
- assert(model.getFeaturesCol == "features")
- assert(model.getPredictionCol == "prediction")
+ test("linear regression with intercept without regularization") {
+ val trainer = new LinearRegression
+ val model = trainer.fit(dataset)
+
+ /**
+ * Using the following R code to load the data and train the model using glmnet package.
+ *
+ * library("glmnet")
+ * data <- read.csv("path", header=FALSE, stringsAsFactors=FALSE)
+ * features <- as.matrix(data.frame(as.numeric(data$V2), as.numeric(data$V3)))
+ * label <- as.numeric(data$V1)
+ * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0, lambda = 0))
+ * > weights
+ * 3 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * (Intercept) 6.300528
+ * as.numeric.data.V2. 4.701024
+ * as.numeric.data.V3. 7.198257
+ */
+ val interceptR = 6.298698
+ val weightsR = Array(4.700706, 7.199082)
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+
+ test("linear regression with intercept with L1 regularization") {
+ val trainer = (new LinearRegression).setElasticNetParam(1.0).setRegParam(0.57)
+ val model = trainer.fit(dataset)
+
+ /**
+ * weights <- coef(glmnet(features, label, family="gaussian", alpha = 1.0, lambda = 0.57))
+ * > weights
+ * 3 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * (Intercept) 6.311546
+ * as.numeric.data.V2. 2.123522
+ * as.numeric.data.V3. 4.605651
+ */
+ val interceptR = 6.243000
+ val weightsR = Array(4.024821, 6.679841)
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
}
- test("linear regression with setters") {
- // Set params, train, and check as many as we can.
- val lr = new LinearRegression()
- .setMaxIter(10)
- .setRegParam(1.0)
- val model = lr.fit(dataset)
- assert(model.fittingParamMap.get(lr.maxIter).get === 10)
- assert(model.fittingParamMap.get(lr.regParam).get === 1.0)
-
- // Call fit() with new params, and check as many as we can.
- val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred")
- assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
- assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
- assert(model2.getPredictionCol == "thePred")
+ test("linear regression with intercept with L2 regularization") {
+ val trainer = (new LinearRegression).setElasticNetParam(0.0).setRegParam(2.3)
+ val model = trainer.fit(dataset)
+
+ /**
+ * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.0, lambda = 2.3))
+ * > weights
+ * 3 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * (Intercept) 6.328062
+ * as.numeric.data.V2. 3.222034
+ * as.numeric.data.V3. 4.926260
+ */
+ val interceptR = 5.269376
+ val weightsR = Array(3.736216, 5.712356)
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
+ }
+
+ test("linear regression with intercept with ElasticNet regularization") {
+ val trainer = (new LinearRegression).setElasticNetParam(0.3).setRegParam(1.6)
+ val model = trainer.fit(dataset)
+
+ /**
+ * weights <- coef(glmnet(features, label, family="gaussian", alpha = 0.3, lambda = 1.6))
+ * > weights
+ * 3 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * (Intercept) 6.324108
+ * as.numeric.data.V2. 3.168435
+ * as.numeric.data.V3. 5.200403
+ */
+ val interceptR = 5.696056
+ val weightsR = Array(3.670489, 6.001122)
+
+ assert(model.intercept ~== interceptR relTol 1E-3)
+ assert(model.weights(0) ~== weightsR(0) relTol 1E-3)
+ assert(model.weights(1) ~== weightsR(1) relTol 1E-3)
+
+ model.transform(dataset).select("features", "prediction").collect().foreach {
+ case Row(features: DenseVector, prediction1: Double) =>
+ val prediction2 =
+ features(0) * model.weights(0) + features(1) * model.weights(1) + model.intercept
+ assert(prediction1 ~== prediction2 relTol 1E-5)
+ }
}
}