author    DB Tsai <dbtsai@alpinenow.com>    2015-02-02 15:59:15 -0800
committer Xiangrui Meng <meng@databricks.com>    2015-02-02 15:59:15 -0800
commit    b1aa8fe988301b924048039529234278aeb0298a (patch)
tree      c30fb67f38a2c288213e80b4b5876edf11d00ed3 /mllib
parent    46d50f151c02c6892fc84a37fdf2a521dc774d1c (diff)
[SPARK-2309][MLlib] Multinomial Logistic Regression
#1379 is automatically closed by asfgit, and GitHub cannot reopen it once it's closed, so this will be the new PR.

Binary Logistic Regression can be extended to Multinomial Logistic Regression by running K-1 independent Binary Logistic Regression models. The following formula is implemented. http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297/25

Author: DB Tsai <dbtsai@alpinenow.com>

Closes #3833 from dbtsai/mlor and squashes the following commits:

4e2f354 [DB Tsai] triger jenkins
697b7c9 [DB Tsai] address some feedback
4ce4d33 [DB Tsai] refactoring
ff843b3 [DB Tsai] rebase
f114135 [DB Tsai] refactoring
4348426 [DB Tsai] Addressed feedback from Sean Owen
a252197 [DB Tsai] first commit
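For orientation, here is a minimal usage sketch of the API this patch introduces (the setNumClasses setter on LogisticRegressionWithLBFGS, shown in the diff below). The SparkContext setup and input path are illustrative assumptions, not part of the patch.

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.util.MLUtils

object MultinomialLRExample {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("MLOR example"))
    // Hypothetical LIBSVM file whose labels are in {0, 1, ..., k - 1}.
    val data = MLUtils.loadLibSVMFile(sc, "data/multiclass.txt").cache()

    // k = 3 possible outcomes; internally k - 1 binary models are trained against class 0.
    val lr = new LogisticRegressionWithLBFGS().setNumClasses(3)
    val model = lr.run(data)

    val predictions = model.predict(data.map(_.features))
    predictions.take(5).foreach(println)
    sc.stop()
  }
}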
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala       128
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala                   200
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala   101
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala                      18
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala  179
5 files changed, 565 insertions, 61 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 94d757bc31..282fb3ff28 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -18,30 +18,41 @@
package org.apache.spark.mllib.classification
import org.apache.spark.annotation.Experimental
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.BLAS.dot
+import org.apache.spark.mllib.linalg.{DenseVector, Vector}
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
-import org.apache.spark.mllib.util.DataValidators
+import org.apache.spark.mllib.util.{DataValidators, MLUtils}
import org.apache.spark.rdd.RDD
/**
- * Classification model trained using Logistic Regression.
+ * Classification model trained using Multinomial/Binary Logistic Regression.
*
* @param weights Weights computed for every feature.
- * @param intercept Intercept computed for this model.
+ * @param intercept Intercept computed for this model. (Only used in Binary Logistic Regression.
+ * In Multinomial Logistic Regression, the intercept is not a single value,
+ * so the intercepts are folded into the weights.)
+ * @param numFeatures the dimension of the features.
+ * @param numClasses the number of possible outcomes for a k-class classification problem in
+ * Multinomial Logistic Regression. By default it is binary logistic regression,
+ * so numClasses is set to 2.
*/
class LogisticRegressionModel (
override val weights: Vector,
- override val intercept: Double)
+ override val intercept: Double,
+ val numFeatures: Int,
+ val numClasses: Int)
extends GeneralizedLinearModel(weights, intercept) with ClassificationModel with Serializable {
+ def this(weights: Vector, intercept: Double) = this(weights, intercept, weights.size, 2)
+
private var threshold: Option[Double] = Some(0.5)
/**
* :: Experimental ::
- * Sets the threshold that separates positive predictions from negative predictions. An example
- * with prediction score greater than or equal to this threshold is identified as an positive,
- * and negative otherwise. The default value is 0.5.
+ * Sets the threshold that separates positive predictions from negative predictions
+ * in Binary Logistic Regression. An example with prediction score greater than or equal to
+ * this threshold is identified as positive, and negative otherwise. The default value is 0.5.
*/
@Experimental
def setThreshold(threshold: Double): this.type = {
@@ -61,20 +72,68 @@ class LogisticRegressionModel (
override protected def predictPoint(dataMatrix: Vector, weightMatrix: Vector,
intercept: Double) = {
- val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
- val score = 1.0 / (1.0 + math.exp(-margin))
- threshold match {
- case Some(t) => if (score > t) 1.0 else 0.0
- case None => score
+ require(dataMatrix.size == numFeatures)
+
+ // If dataMatrix and weightMatrix have the same dimension, it's binary logistic regression.
+ if (numClasses == 2) {
+ require(numFeatures == weightMatrix.size)
+ val margin = dot(weights, dataMatrix) + intercept
+ val score = 1.0 / (1.0 + math.exp(-margin))
+ threshold match {
+ case Some(t) => if (score > t) 1.0 else 0.0
+ case None => score
+ }
+ } else {
+ val dataWithBiasSize = weightMatrix.size / (numClasses - 1)
+
+ val weightsArray = weights match {
+ case dv: DenseVector => dv.values
+ case _ =>
+ throw new IllegalArgumentException(
+ s"weights only supports dense vector but got type ${weights.getClass}.")
+ }
+
+ val margins = (0 until numClasses - 1).map { i =>
+ var margin = 0.0
+ dataMatrix.foreachActive { (index, value) =>
+ if (value != 0.0) margin += value * weightsArray((i * dataWithBiasSize) + index)
+ }
+ // Intercept is required to be added into margin.
+ if (dataMatrix.size + 1 == dataWithBiasSize) {
+ margin += weightsArray((i * dataWithBiasSize) + dataMatrix.size)
+ }
+ margin
+ }
+
+ /**
+ * Find the class with the maximum margin. If maxMargin is negative, then the prediction
+ * result will be the first class (class 0).
+ *
+ * Note: to compute the probabilities for each outcome instead of only the outcome
+ * with the maximum probability, remember to subtract maxMargin from the margins when
+ * maxMargin is positive, to prevent overflow.
+ */
+ var bestClass = 0
+ var maxMargin = 0.0
+ var i = 0
+ while (i < margins.size) {
+ if (margins(i) > maxMargin) {
+ maxMargin = margins(i)
+ bestClass = i + 1
+ }
+ i += 1
+ }
+ bestClass.toDouble
}
}
}
/**
- * Train a classification model for Logistic Regression using Stochastic Gradient Descent. By
- * default L2 regularization is used, which can be changed via
- * [[LogisticRegressionWithSGD.optimizer]].
- * NOTE: Labels used in Logistic Regression should be {0, 1}.
+ * Train a classification model for Binary Logistic Regression
+ * using Stochastic Gradient Descent. By default L2 regularization is used,
+ * which can be changed via [[LogisticRegressionWithSGD.optimizer]].
+ * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
+ * for a k-class classification problem.
* Using [[LogisticRegressionWithLBFGS]] is recommended over this.
*/
class LogisticRegressionWithSGD private (
@@ -194,9 +253,10 @@ object LogisticRegressionWithSGD {
}
/**
- * Train a classification model for Logistic Regression using Limited-memory BFGS.
- * Standard feature scaling and L2 regularization are used by default.
- * NOTE: Labels used in Logistic Regression should be {0, 1}
+ * Train a classification model for Multinomial/Binary Logistic Regression using
+ * Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default.
+ * NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
+ * for a k-class classification problem.
*/
class LogisticRegressionWithLBFGS
extends GeneralizedLinearAlgorithm[LogisticRegressionModel] with Serializable {
@@ -205,9 +265,33 @@ class LogisticRegressionWithLBFGS
override val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)
- override protected val validators = List(DataValidators.binaryLabelValidator)
+ override protected val validators = List(multiLabelValidator)
+
+ private def multiLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
+ if (numOfLinearPredictor > 1) {
+ DataValidators.multiLabelValidator(numOfLinearPredictor + 1)(data)
+ } else {
+ DataValidators.binaryLabelValidator(data)
+ }
+ }
+
+ /**
+ * :: Experimental ::
+ * Set the number of possible outcomes for a k-class classification problem in
+ * Multinomial Logistic Regression.
+ * By default it is binary logistic regression, so k is set to 2.
+ */
+ @Experimental
+ def setNumClasses(numClasses: Int): this.type = {
+ require(numClasses > 1)
+ numOfLinearPredictor = numClasses - 1
+ if (numClasses > 2) {
+ optimizer.setGradient(new LogisticGradient(numClasses))
+ }
+ this
+ }
override protected def createModel(weights: Vector, intercept: Double) = {
- new LogisticRegressionModel(weights, intercept)
+ new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
}
}
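To make the flattened-weights layout in the new predictPoint branch concrete, here is a small standalone sketch using plain arrays instead of MLlib vectors (the numbers are made up): the weights hold numClasses - 1 rows, each of length numFeatures (plus one when the intercept is appended), and the prediction is class 0 unless some margin is positive, in which case it is the argmax plus one.

object FlattenedWeightsDemo {
  def predict(weights: Array[Double], x: Array[Double], numClasses: Int): Int = {
    val rowSize = weights.length / (numClasses - 1)   // dataWithBiasSize in the patch
    val withIntercept = rowSize == x.length + 1

    var bestClass = 0
    var maxMargin = 0.0
    var i = 0
    while (i < numClasses - 1) {
      var margin = 0.0
      var j = 0
      while (j < x.length) {
        margin += x(j) * weights(i * rowSize + j)
        j += 1
      }
      // The intercept of predictor i, if present, sits at the end of its row.
      if (withIntercept) margin += weights(i * rowSize + x.length)
      if (margin > maxMargin) {
        maxMargin = margin
        bestClass = i + 1
      }
      i += 1
    }
    bestClass
  }

  def main(args: Array[String]): Unit = {
    // Two features, three classes, intercepts appended: 2 rows of 3 entries each.
    val weights = Array(1.0, -2.0, 0.5, -1.0, 3.0, -0.5)
    println(predict(weights, Array(2.0, 1.0), numClasses = 3)) // prints 1
  }
}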
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
index 1ca0f36c6a..0acdab797e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/Gradient.scala
@@ -18,7 +18,7 @@
package org.apache.spark.mllib.optimization
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.{axpy, dot, scal}
import org.apache.spark.mllib.util.MLUtils
@@ -55,24 +55,86 @@ abstract class Gradient extends Serializable {
/**
* :: DeveloperApi ::
- * Compute gradient and loss for a logistic loss function, as used in binary classification.
- * See also the documentation for the precise formulation.
+ * Compute gradient and loss for a multinomial logistic loss function, as used
+ * in multi-class classification (it is also used in binary logistic regression).
+ *
+ * In `The Elements of Statistical Learning: Data Mining, Inference, and Prediction, 2nd Edition`
+ * by Trevor Hastie, Robert Tibshirani, and Jerome Friedman, which can be downloaded from
+ * http://statweb.stanford.edu/~tibs/ElemStatLearn/ , Eq. (4.17) on page 119 gives the formula of
+ * multinomial logistic regression model. A simple calculation shows that
+ *
+ * P(y=0|x, w) = 1 / (1 + \sum_i^{K-1} \exp(x w_i))
+ * P(y=1|x, w) = exp(x w_1) / (1 + \sum_i^{K-1} \exp(x w_i))
+ * ...
+ * P(y=K-1|x, w) = exp(x w_{K-1}) / (1 + \sum_i^{K-1} \exp(x w_i))
+ *
+ * for K classes multiclass classification problem.
+ *
+ * The model weights w = (w_1, w_2, ..., w_{K-1})^T becomes a matrix which has dimension of
+ * (K-1) * (N+1) if the intercepts are added. If the intercepts are not added, the dimension
+ * will be (K-1) * N.
+ *
+ * As a result, the loss of objective function for a single instance of data can be written as
+ * l(w, x) = -log P(y|x, w) = -\alpha(y) log P(y=0|x, w) - (1-\alpha(y)) log P(y|x, w)
+ * = log(1 + \sum_i^{K-1}\exp(x w_i)) - (1-\alpha(y)) x w_{y-1}
+ * = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1}
+ *
+ * where \alpha(i) = 1 if i == 0, and
+ * \alpha(i) = 0 if i != 0,
+ * margins_i = x w_i.
+ *
+ * For optimization, we have to calculate the first derivative of the loss function, and
+ * a simple calculation shows that
+ *
+ * \frac{\partial l(w, x)}{\partial w_{ij}}
+ * = (\exp(x w_i) / (1 + \sum_k^{K-1} \exp(x w_k)) - (1-\alpha(y)) \delta_{y, i+1}) * x_j
+ * = multiplier_i * x_j
+ *
+ * where \delta_{i, j} = 1 if i == j,
+ * \delta_{i, j} = 0 if i != j, and
+ * multiplier
+ * = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_k)) - (1-\alpha(y)) \delta_{y, i+1}
+ *
+ * If any margin is larger than 709.78, the numerical computation of the multiplier and the loss
+ * function will suffer from arithmetic overflow. This issue occurs when there are outliers in the
+ * data which are far away from the hyperplane, and it will cause training to fail once
+ * infinity / infinity is introduced. Note that this is only a concern when max(margins) > 0.
+ *
+ * Fortunately, when max(margins) = maxMargin > 0, the loss function and the multiplier can be
+ * easily rewritten into the following equivalent numerically stable formula.
+ *
+ * l(w, x) = log(1 + \sum_i^{K-1}\exp(margins_i)) - (1-\alpha(y)) margins_{y-1}
+ * = log(\exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin)) + maxMargin
+ * - (1-\alpha(y)) margins_{y-1}
+ * = log(1 + sum) + maxMargin - (1-\alpha(y)) margins_{y-1}
+ *
+ * where sum = \exp(-maxMargin) + \sum_i^{K-1}\exp(margins_i - maxMargin) - 1.
+ *
+ * Note that each term (margins_i - maxMargin) inside \exp is no larger than zero; as a result,
+ * overflow will not happen with this formula.
+ *
+ * For the multiplier, a similar trick can be applied:
+ *
+ * multiplier = \exp(margins_i) / (1 + \sum_k^{K-1} \exp(margins_k)) - (1-\alpha(y)) \delta_{y, i+1}
+ * = \exp(margins_i - maxMargin) / (1 + sum) - (1-\alpha(y)) \delta_{y, i+1}
+ *
+ * where each term in \exp is again no larger than zero, so overflow is not a concern.
+ *
+ * For the detailed mathematical derivation, see the reference at
+ * http://www.slideshare.net/dbtsai/2014-0620-mlor-36132297
+ *
+ * @param numClasses the number of possible outcomes for a k-class classification problem in
+ * Multinomial Logistic Regression. By default it is binary logistic regression,
+ * so numClasses is set to 2.
*/
@DeveloperApi
-class LogisticGradient extends Gradient {
- override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
- val margin = -1.0 * dot(data, weights)
- val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
- val gradient = data.copy
- scal(gradientMultiplier, gradient)
- val loss =
- if (label > 0) {
- // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
- MLUtils.log1pExp(margin)
- } else {
- MLUtils.log1pExp(margin) - margin
- }
+class LogisticGradient(numClasses: Int) extends Gradient {
+ def this() = this(2)
+
+ override def compute(data: Vector, label: Double, weights: Vector): (Vector, Double) = {
+ val gradient = Vectors.zeros(weights.size)
+ val loss = compute(data, label, weights, gradient)
(gradient, loss)
}
@@ -81,14 +143,104 @@ class LogisticGradient extends Gradient {
label: Double,
weights: Vector,
cumGradient: Vector): Double = {
- val margin = -1.0 * dot(data, weights)
- val gradientMultiplier = (1.0 / (1.0 + math.exp(margin))) - label
- axpy(gradientMultiplier, data, cumGradient)
- if (label > 0) {
- // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
- MLUtils.log1pExp(margin)
- } else {
- MLUtils.log1pExp(margin) - margin
+ val dataSize = data.size
+
+ // (weights.size / dataSize + 1) is number of classes
+ require(weights.size % dataSize == 0 && numClasses == weights.size / dataSize + 1)
+ numClasses match {
+ case 2 =>
+ /**
+ * For Binary Logistic Regression.
+ *
+ * Although the multinomial loss and gradient calculation is more general and also
+ * covers the binary case, we still implement a specialized binary version
+ * for performance reasons.
+ */
+ val margin = -1.0 * dot(data, weights)
+ val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
+ axpy(multiplier, data, cumGradient)
+ if (label > 0) {
+ // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+ MLUtils.log1pExp(margin)
+ } else {
+ MLUtils.log1pExp(margin) - margin
+ }
+ case _ =>
+ /**
+ * For Multinomial Logistic Regression.
+ */
+ val weightsArray = weights match {
+ case dv: DenseVector => dv.values
+ case _ =>
+ throw new IllegalArgumentException(
+ s"weights only supports dense vector but got type ${weights.getClass}.")
+ }
+ val cumGradientArray = cumGradient match {
+ case dv: DenseVector => dv.values
+ case _ =>
+ throw new IllegalArgumentException(
+ s"cumGradient only supports dense vector but got type ${cumGradient.getClass}.")
+ }
+
+ // marginY is margins(label - 1) in the formula.
+ var marginY = 0.0
+ var maxMargin = Double.NegativeInfinity
+ var maxMarginIndex = 0
+
+ val margins = Array.tabulate(numClasses - 1) { i =>
+ var margin = 0.0
+ data.foreachActive { (index, value) =>
+ if (value != 0.0) margin += value * weightsArray((i * dataSize) + index)
+ }
+ if (i == label.toInt - 1) marginY = margin
+ if (margin > maxMargin) {
+ maxMargin = margin
+ maxMarginIndex = i
+ }
+ margin
+ }
+
+ /**
+ * When maxMargin > 0, the original formula will cause overflow, as discussed
+ * in the class comment above.
+ * We address this by subtracting maxMargin from all the margins, which guarantees
+ * that all of the new margins are no larger than zero, preventing arithmetic overflow.
+ */
+ val sum = {
+ var temp = 0.0
+ if (maxMargin > 0) {
+ for (i <- 0 until numClasses - 1) {
+ margins(i) -= maxMargin
+ if (i == maxMarginIndex) {
+ temp += math.exp(-maxMargin)
+ } else {
+ temp += math.exp(margins(i))
+ }
+ }
+ } else {
+ for (i <- 0 until numClasses - 1) {
+ temp += math.exp(margins(i))
+ }
+ }
+ temp
+ }
+
+ for (i <- 0 until numClasses - 1) {
+ val multiplier = math.exp(margins(i)) / (sum + 1.0) - {
+ if (label != 0.0 && label == i + 1) 1.0 else 0.0
+ }
+ data.foreachActive { (index, value) =>
+ if (value != 0.0) cumGradientArray(i * dataSize + index) += multiplier * value
+ }
+ }
+
+ val loss = if (label > 0.0) math.log1p(sum) - marginY else math.log1p(sum)
+
+ if (maxMargin > 0) {
+ loss + maxMargin
+ } else {
+ loss
+ }
}
}
}
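The overflow-safe rewrite documented in the LogisticGradient comment above can be demonstrated on raw doubles; a minimal standalone sketch (not part of the patch, numbers made up):

object StableLogSumExpDemo {
  // Naive evaluation of log(1 + \sum_i exp(m_i)); overflows once any m_i > ~709.78.
  def naiveLoss(margins: Array[Double]): Double =
    math.log1p(margins.map(math.exp).sum)

  // Stable evaluation: log(1 + sum) + maxMargin, with
  // sum = exp(-maxMargin) + \sum_i exp(m_i - maxMargin) - 1.
  def stableLoss(margins: Array[Double]): Double = {
    val maxMargin = margins.max
    if (maxMargin <= 0.0) {
      naiveLoss(margins)
    } else {
      val sum = math.exp(-maxMargin) + margins.map(m => math.exp(m - maxMargin)).sum - 1.0
      math.log1p(sum) + maxMargin
    }
  }

  def main(args: Array[String]): Unit = {
    val margins = Array(800.0, 3.0)
    println(naiveLoss(margins))   // Infinity
    println(stableLoss(margins))  // ~800.0, finite and equal to the true value
  }
}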
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 0287f04e2c..17de215b97 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -99,6 +99,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
protected var validateData: Boolean = true
/**
+ * In `GeneralizedLinearModel`, only a single linear predictor is allowed for both weights
+ * and intercept. However, for multinomial logistic regression with K possible outcomes,
+ * we are training K-1 independent binary logistic regression models, which requires K-1 sets
+ * of linear predictors.
+ *
+ * As a result, the workaround here is that if more than one set of linear predictors is needed,
+ * we construct a bigger `weights` vector which holds both the weights and the intercepts.
+ * If the intercepts are added, the dimension of `weights` will be
+ * numOfLinearPredictor * (numFeatures + 1). If the intercepts are not added,
+ * the dimension of `weights` will be numOfLinearPredictor * numFeatures.
+ *
+ * Thus, the intercepts are encapsulated in the weights, and the value of intercept
+ * in GeneralizedLinearModel is left as zero.
+ */
+ protected var numOfLinearPredictor: Int = 1
+
+ /**
* Whether to perform feature scaling before model training to reduce the condition numbers
* which can significantly help the optimizer converging faster. The scaling correction will be
* translated back to resulting model weights, so it's transparent to users.
@@ -107,6 +124,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
private var useFeatureScaling = false
/**
+ * The dimension of training features.
+ */
+ protected var numFeatures: Int = 0
+
+ /**
* Set if the algorithm should use feature scaling to improve the convergence during optimization.
*/
private[mllib] def setFeatureScaling(useFeatureScaling: Boolean): this.type = {
@@ -141,8 +163,28 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* RDD of LabeledPoint entries.
*/
def run(input: RDD[LabeledPoint]): M = {
- val numFeatures: Int = input.first().features.size
- val initialWeights = Vectors.dense(new Array[Double](numFeatures))
+ numFeatures = input.first().features.size
+
+ /**
+ * When `numOfLinearPredictor > 1`, the intercepts are encapsulated into weights,
+ * so the `weights` will include the intercepts. When `numOfLinearPredictor == 1`,
+ * the intercept will be stored as a separate value in `GeneralizedLinearModel`.
+ * This will result in different behaviors since when `numOfLinearPredictor == 1`,
+ * users have no way to set the initial intercept, while in the other case, users
+ * can set the intercepts as part of weights.
+ *
+ * TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always
+ * have the intercept as part of weights to have consistent design.
+ */
+ val initialWeights = {
+ if (numOfLinearPredictor == 1) {
+ Vectors.dense(new Array[Double](numFeatures))
+ } else if (addIntercept) {
+ Vectors.dense(new Array[Double]((numFeatures + 1) * numOfLinearPredictor))
+ } else {
+ Vectors.dense(new Array[Double](numFeatures * numOfLinearPredictor))
+ }
+ }
run(input, initialWeights)
}
@@ -151,6 +193,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* of LabeledPoint entries starting from the initial weights provided.
*/
def run(input: RDD[LabeledPoint], initialWeights: Vector): M = {
+ numFeatures = input.first().features.size
if (input.getStorageLevel == StorageLevel.NONE) {
logWarning("The input data is not directly cached, which may hurt performance if its"
@@ -182,14 +225,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* Currently, it's only enabled in LogisticRegressionWithLBFGS
*/
val scaler = if (useFeatureScaling) {
- (new StandardScaler).fit(input.map(x => x.features))
+ (new StandardScaler(withStd = true, withMean = false)).fit(input.map(x => x.features))
} else {
null
}
// Prepend an extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
- if(useFeatureScaling) {
+ if (useFeatureScaling) {
input.map(labeledPoint =>
(labeledPoint.label, appendBias(scaler.transform(labeledPoint.features))))
} else {
@@ -203,21 +246,31 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}
}
- val initialWeightsWithIntercept = if (addIntercept) {
+ /**
+ * TODO: For better convergence, in logistic regression, the intercepts should be computed
+ * from the prior probability distribution of the outcomes; for linear regression,
+ * the intercept should be set to the average of the responses.
+ */
+ val initialWeightsWithIntercept = if (addIntercept && numOfLinearPredictor == 1) {
appendBias(initialWeights)
} else {
+ /** If `numOfLinearPredictor > 1`, initialWeights already contains intercepts. */
initialWeights
}
val weightsWithIntercept = optimizer.optimize(data, initialWeightsWithIntercept)
- val intercept = if (addIntercept) weightsWithIntercept(weightsWithIntercept.size - 1) else 0.0
- var weights =
- if (addIntercept) {
- Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
- } else {
- weightsWithIntercept
- }
+ val intercept = if (addIntercept && numOfLinearPredictor == 1) {
+ weightsWithIntercept(weightsWithIntercept.size - 1)
+ } else {
+ 0.0
+ }
+
+ var weights = if (addIntercept && numOfLinearPredictor == 1) {
+ Vectors.dense(weightsWithIntercept.toArray.slice(0, weightsWithIntercept.size - 1))
+ } else {
+ weightsWithIntercept
+ }
/**
* The weights and intercept are trained in the scaled space; we're converting them back to
@@ -228,7 +281,29 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* is the coefficient in the original space, and v_i is the variance of the column i.
*/
if (useFeatureScaling) {
- weights = scaler.transform(weights)
+ if (numOfLinearPredictor == 1) {
+ weights = scaler.transform(weights)
+ } else {
+ /**
+ * For `numOfLinearPredictor > 1`, we have to transform the weights back to the original
+ * scale for each set of linear predictors. Note that the intercepts have to be explicitly
+ * excluded when `addIntercept == true`, since the intercepts are now part of the weights.
+ */
+ var i = 0
+ val n = weights.size / numOfLinearPredictor
+ val weightsArray = weights.toArray
+ while (i < numOfLinearPredictor) {
+ val start = i * n
+ val end = (i + 1) * n - { if (addIntercept) 1 else 0 }
+
+ val partialWeightsArray = scaler.transform(
+ Vectors.dense(weightsArray.slice(start, end))).toArray
+
+ System.arraycopy(partialWeightsArray, 0, weightsArray, start, partialWeightsArray.size)
+ i += 1
+ }
+ weights = Vectors.dense(weightsArray)
+ }
}
// Warn at the end of the run as well, for increased visibility.
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
index 45f95482a1..be335a1aca 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/DataValidators.scala
@@ -34,11 +34,27 @@ object DataValidators extends Logging {
*
* @return True if labels are all zero or one, false otherwise.
*/
- val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
+ val binaryLabelValidator: RDD[LabeledPoint] => Boolean = { data =>
val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count()
if (numInvalid != 0) {
logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels")
}
numInvalid == 0
}
+
+ /**
+ * Function to check if labels used for k-class classification are
+ * in the set {0, 1, ..., k - 1}.
+ *
+ * @return True if labels are all in the range of {0, 1, ..., k-1}, false otherwise.
+ */
+ def multiLabelValidator(k: Int): RDD[LabeledPoint] => Boolean = { data =>
+ val numInvalid = data.filter(x =>
+ x.label - x.label.toInt != 0.0 || x.label < 0 || x.label > k - 1).count()
+ if (numInvalid != 0) {
+ logError("Classification labels should be in {0 to " + (k - 1) + "}. " +
+ "Found " + numInvalid + " invalid labels")
+ }
+ numInvalid == 0
+ }
}
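For illustration, the label check added above can be sketched on a plain Scala collection (no RDD, made-up labels) to show exactly what multiLabelValidator accepts for k classes:

object MultiLabelValidatorDemo {
  def main(args: Array[String]): Unit = {
    val k = 3
    val labels = Seq(0.0, 1.0, 2.0, 2.5, -1.0, 3.0)
    // Same predicate as the validator: labels must be integral and in {0, 1, ..., k - 1}.
    val invalid = labels.filter(l => l - l.toInt != 0.0 || l < 0 || l > k - 1)
    println(invalid)            // List(2.5, -1.0, 3.0)
    println(invalid.isEmpty)    // false; the validator would reject this dataset
  }
}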
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 94b0e00f37..3fb45938f7 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -17,13 +17,14 @@
package org.apache.spark.mllib.classification
+import scala.util.control.Breaks._
import scala.util.Random
import scala.collection.JavaConversions._
import org.scalatest.FunSuite
import org.scalatest.Matchers
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
@@ -55,6 +56,97 @@ object LogisticRegressionSuite {
val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i)))))
testData
}
+
+ /**
+ * Generates `k`-class multinomial synthetic logistic input in `n`-dimensional space given the
+ * model weights and mean/variance of the features. The synthetic data will be drawn from
+ * the probability distribution constructed by weights using the following formula.
+ *
+ * P(y = 0 | x) = 1 / norm
+ * P(y = 1 | x) = exp(x * w_1) / norm
+ * P(y = 2 | x) = exp(x * w_2) / norm
+ * ...
+ * P(y = k-1 | x) = exp(x * w_{k-1}) / norm
+ * where norm = 1 + exp(x * w_1) + exp(x * w_2) + ... + exp(x * w_{k-1})
+ *
+ * @param weights the weight matrix flattened into a vector; as a result, the dimension of the
+ * weights vector will be (k - 1) * (n + 1) if `addIntercept == true`, and
+ * (k - 1) * n if `addIntercept != true`.
+ * @param xMean the mean of the generated features. Often, if the features are not properly
+ * standardized, a poorly implemented algorithm will have difficulty
+ * converging.
+ * @param xVariance the variance of the generated features.
+ * @param addIntercept whether to add intercept.
+ * @param nPoints the number of generated data instances.
+ * @param seed the seed for the random generator; fixed for consistent test results.
+ */
+ def generateMultinomialLogisticInput(
+ weights: Array[Double],
+ xMean: Array[Double],
+ xVariance: Array[Double],
+ addIntercept: Boolean,
+ nPoints: Int,
+ seed: Int): Seq[LabeledPoint] = {
+ val rnd = new Random(seed)
+
+ val xDim = xMean.size
+ val xWithInterceptsDim = if (addIntercept) xDim + 1 else xDim
+ val nClasses = weights.size / xWithInterceptsDim + 1
+
+ val x = Array.fill[Vector](nPoints)(Vectors.dense(Array.fill[Double](xDim)(rnd.nextGaussian())))
+
+ x.map(vector => {
+ // This doesn't work if `vector` is a sparse vector.
+ val vectorArray = vector.toArray
+ var i = 0
+ while (i < vectorArray.size) {
+ vectorArray(i) = vectorArray(i) * math.sqrt(xVariance(i)) + xMean(i)
+ i += 1
+ }
+ })
+
+ val y = (0 until nPoints).map { idx =>
+ val xArray = x(idx).toArray
+ val margins = Array.ofDim[Double](nClasses)
+ val probs = Array.ofDim[Double](nClasses)
+
+ for (i <- 0 until nClasses - 1) {
+ for (j <- 0 until xDim) margins(i + 1) += weights(i * xWithInterceptsDim + j) * xArray(j)
+ if (addIntercept) margins(i + 1) += weights((i + 1) * xWithInterceptsDim - 1)
+ }
+ // Preventing the overflow when we compute the probability
+ val maxMargin = margins.max
+ if (maxMargin > 0) for (i <- 0 until nClasses) margins(i) -= maxMargin
+
+ // Computing the probabilities for each class from the margins.
+ val norm = {
+ var temp = 0.0
+ for (i <- 0 until nClasses) {
+ probs(i) = math.exp(margins(i))
+ temp += probs(i)
+ }
+ temp
+ }
+ for (i <- 0 until nClasses) probs(i) /= norm
+
+ // Compute the cumulative probability so we can generate a random number and assign a label.
+ for (i <- 1 until nClasses) probs(i) += probs(i - 1)
+ val p = rnd.nextDouble()
+ var y = 0
+ breakable {
+ for (i <- 0 until nClasses) {
+ if (p < probs(i)) {
+ y = i
+ break
+ }
+ }
+ }
+ y
+ }
+
+ val testData = (0 until nPoints).map(i => LabeledPoint(y(i), x(i)))
+ testData
+ }
}
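A small standalone sketch (made-up numbers, not part of the patch) of the "pivot" reparameterization that both the data generator above and the later glmnet comparison rely on: a full softmax model with K weight rows gives the same class probabilities as the K - 1 rows obtained by subtracting the row of pivot class 0, since the margins only shift by a constant.

object PivotDemo {
  def softmax(margins: Array[Double]): Array[Double] = {
    val exps = margins.map(math.exp)
    exps.map(_ / exps.sum)
  }

  def main(args: Array[String]): Unit = {
    val x = Array(1.0, 2.0)
    val fullRows = Array(Array(0.3, -0.1), Array(1.2, 0.4), Array(-0.5, 0.8)) // K = 3 rows
    // Subtract the pivot row (class 0) from every row; row 0 becomes all zeros.
    val pivoted = fullRows.map(row => row.zip(fullRows(0)).map { case (a, b) => a - b })

    def margins(rows: Array[Array[Double]]): Array[Double] =
      rows.map(row => row.zip(x).map { case (w, v) => w * v }.sum)

    println(softmax(margins(fullRows)).mkString(", "))
    println(softmax(margins(pivoted)).mkString(", "))  // identical probabilities
  }
}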
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with Matchers {
@@ -285,6 +377,91 @@ class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext with M
assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1)
}
+ test("multinomial logistic regression with LBFGS") {
+ val nPoints = 10000
+
+ /**
+ * The following weights and xMean/xVariance are computed from the iris dataset with lambda = 0.2.
+ * As a result, we are actually drawing samples from the probability distribution of the fitted model.
+ */
+ val weights = Array(
+ -0.57997, 0.912083, -0.371077, -0.819866, 2.688191,
+ -0.16624, -0.84355, -0.048509, -0.301789, 4.170682)
+
+ val xMean = Array(5.843, 3.057, 3.758, 1.199)
+ val xVariance = Array(0.6856, 0.1899, 3.116, 0.581)
+
+ val testData = LogisticRegressionSuite.generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 42)
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ val lr = new LogisticRegressionWithLBFGS().setIntercept(true).setNumClasses(3)
+ lr.optimizer.setConvergenceTol(1E-15).setNumIterations(200)
+
+ val model = lr.run(testRDD)
+
+ /**
+ * The following are the instructions to reproduce the model using R's glmnet package.
+ *
+ * First, use the following Scala code to save the data into `path`.
+ *
+ * testRDD.map(x => x.label+ ", " + x.features(0) + ", " + x.features(1) + ", " +
+ * x.features(2) + ", " + x.features(3)).saveAsTextFile("path")
+ *
+ * Then use the following R code to load the data and train the model with the glmnet package.
+ *
+ * library("glmnet")
+ * data <- read.csv("path", header=FALSE)
+ * label = factor(data$V1)
+ * features = as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+ * weights = coef(glmnet(features,label, family="multinomial", alpha = 0, lambda = 0))
+ *
+ * The model weights of multinomial logistic regression in R have `K` sets of linear predictors
+ * for a `K`-class classification problem; however, only `K-1` sets are required if the first
+ * outcome is chosen as a "pivot" and the other `K-1` outcomes are separately regressed against
+ * the pivot outcome. This can be done by subtracting the first set of weights from the other
+ * `K-1` sets of weights. The mathematical discussion and proof can be found here:
+ * http://en.wikipedia.org/wiki/Multinomial_logistic_regression
+ *
+ * weights1 = weights$`1` - weights$`0`
+ * weights2 = weights$`2` - weights$`0`
+ *
+ * > weights1
+ * 5 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * 2.6228269
+ * data.V2 -0.5837166
+ * data.V3 0.9285260
+ * data.V4 -0.3783612
+ * data.V5 -0.8123411
+ * > weights2
+ * 5 x 1 sparse Matrix of class "dgCMatrix"
+ * s0
+ * 4.11197445
+ * data.V2 -0.16918650
+ * data.V3 -0.81104784
+ * data.V4 -0.06463799
+ * data.V5 -0.29198337
+ */
+
+ val weightsR = Vectors.dense(Array(
+ -0.5837166, 0.9285260, -0.3783612, -0.8123411, 2.6228269,
+ -0.1691865, -0.811048, -0.0646380, -0.2919834, 4.1119745))
+
+ assert(model.weights ~== weightsR relTol 0.05)
+
+ val validationData = LogisticRegressionSuite.generateMultinomialLogisticInput(
+ weights, xMean, xVariance, true, nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
+ // The validation accuracy is not great since this model (even with the original weights) does
+ // not have a very steep logistic curve, so when we draw samples from the distribution it is
+ // easy for an instance to be assigned a different label. However, the results are consistent with R.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData, 0.47)
+
+ }
+
}
class LogisticRegressionClusterSuite extends FunSuite with LocalClusterSparkContext {