 mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 36 ++++++++++++++++++++----------------
 1 file changed, 20 insertions(+), 16 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 88d1b4575f..fce3935d39 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.SparkException
import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.linalg._
@@ -346,8 +347,9 @@ class LogisticRegression @Since("1.2.0") (
val regParamL1 = $(elasticNetParam) * $(regParam)
val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam)
+ val bcFeaturesStd = instances.context.broadcast(featuresStd)
val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept),
- $(standardization), featuresStd, featuresMean, regParamL2)
+ $(standardization), bcFeaturesStd, regParamL2)
val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) {
new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol))
@@ -442,6 +444,7 @@ class LogisticRegression @Since("1.2.0") (
rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 }
i += 1
}
+ bcFeaturesStd.destroy(blocking = false)
if ($(fitIntercept)) {
(Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last,
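The two hunks above follow the standard broadcast lifecycle: publish the read-only standardization array once through instances.context.broadcast, let every task read it via the handle, and tear it down after the coefficients have been unscaled. A minimal standalone sketch of that lifecycle, assuming a local-mode SparkSession; the object name and sample data are illustrative and not part of this patch:

import org.apache.spark.sql.SparkSession

object BroadcastLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("broadcast-sketch").getOrCreate()
    val sc = spark.sparkContext

    // Read-only per-feature scales, playing the role of featuresStd.
    val featuresStd = Array(1.0, 2.5, 0.0, 4.0)
    val bcFeaturesStd = sc.broadcast(featuresStd)

    val rows = sc.parallelize(Seq(Array(1.0, 2.0, 3.0, 4.0), Array(0.5, 1.0, 1.5, 2.0)))

    // Tasks read through the handle; the array is shipped once per executor rather than
    // being serialized into every task closure.
    val scaled = rows.map { row =>
      val std = bcFeaturesStd.value
      row.indices.map(i => if (std(i) != 0.0) row(i) / std(i) else 0.0).toArray
    }.collect()

    // Release the broadcast once no further jobs need it (public, blocking form).
    bcFeaturesStd.destroy()

    scaled.foreach(r => println(r.mkString(", ")))
    spark.stop()
  }
}

destroy() removes the broadcast data from the driver and every executor and makes the handle unusable, whereas unpersist() only drops the executor-side copies and lets the value be re-shipped later. The non-blocking destroy(blocking = false) overload used in the patch is only visible inside org.apache.spark; application code normally goes through the public destroy() or unpersist().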
@@ -938,11 +941,15 @@ class BinaryLogisticRegressionSummary private[classification] (
 * Two LogisticAggregators can be merged together to produce a summary of the loss and gradient
 * of the corresponding joint dataset.
*
+ * @param bcCoefficients The broadcast coefficients corresponding to the features.
+ * @param bcFeaturesStd The broadcast standard deviation values of the features.
 * @param numClasses the number of possible outcomes for the k-class classification problem in
 *                   Multinomial Logistic Regression.
* @param fitIntercept Whether to fit an intercept term.
*/
private class LogisticAggregator(
+ val bcCoefficients: Broadcast[Vector],
+ val bcFeaturesStd: Broadcast[Array[Double]],
private val numFeatures: Int,
numClasses: Int,
fitIntercept: Boolean) extends Serializable {
@@ -958,14 +965,9 @@ private class LogisticAggregator(
* of the objective function.
*
* @param instance The instance of data point to be added.
- * @param coefficients The coefficients corresponding to the features.
- * @param featuresStd The standard deviation values of the features.
* @return This LogisticAggregator object.
*/
- def add(
- instance: Instance,
- coefficients: Vector,
- featuresStd: Array[Double]): this.type = {
+ def add(instance: Instance): this.type = {
instance match { case Instance(label, weight, features) =>
require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." +
s" Expecting $numFeatures but got ${features.size}.")
@@ -973,14 +975,16 @@ private class LogisticAggregator(
if (weight == 0.0) return this
- val coefficientsArray = coefficients match {
+ val coefficientsArray = bcCoefficients.value match {
case dv: DenseVector => dv.values
case _ =>
throw new IllegalArgumentException(
- s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
+            "coefficients only supports dense vector " +
+ s"but got type ${bcCoefficients.value.getClass}.")
}
val localGradientSumArray = gradientSumArray
+ val featuresStd = bcFeaturesStd.value
numClasses match {
case 2 =>
// For Binary Logistic Regression.
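With the coefficients and standard deviations captured as Broadcast fields on the aggregator, add() no longer threads them through as arguments; each task dereferences the handles with .value inside the closure. A heavily cut-down sketch of that shape, using a toy weighted squared-error update instead of the real binary/multinomial gradient code; ToyInstance and ToyAggregator are illustrative names introduced for the sketch (the real Instance case class is private[ml]):

import org.apache.spark.broadcast.Broadcast
import org.apache.spark.ml.linalg.Vector

// A labeled, weighted data point standing in for the private[ml] Instance class.
case class ToyInstance(label: Double, weight: Double, features: Vector)

// Holds Broadcast handles rather than the raw coefficient and std arrays, so each
// serialized task ships only the lightweight handles.
class ToyAggregator(
    bcCoefficients: Broadcast[Vector],
    bcFeaturesStd: Broadcast[Array[Double]],
    numFeatures: Int) extends Serializable {

  private var weightSum = 0.0
  private var lossSum = 0.0

  def add(instance: ToyInstance): this.type = {
    require(instance.features.size == numFeatures,
      s"Dimension mismatch: expected $numFeatures but got ${instance.features.size}.")
    if (instance.weight == 0.0) return this
    val coefficients = bcCoefficients.value
    val featuresStd = bcFeaturesStd.value
    // Stand-in for the real gradient update: a weighted squared error on standardized features.
    var margin = 0.0
    instance.features.foreachActive { (i, v) =>
      if (featuresStd(i) != 0.0) margin += coefficients(i) * (v / featuresStd(i))
    }
    val err = margin - instance.label
    lossSum += instance.weight * err * err / 2.0
    weightSum += instance.weight
    this
  }

  def merge(other: ToyAggregator): this.type = {
    weightSum += other.weightSum
    lossSum += other.lossSum
    this
  }

  def loss: Double = if (weightSum == 0.0) 0.0 else lossSum / weightSum
}

Because only the Broadcast handles are captured, serializing one aggregator per task costs a few bytes of metadata; the coefficient and std arrays reach each executor once via the broadcast machinery instead of once per task.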
@@ -1077,24 +1081,23 @@ private class LogisticCostFun(
numClasses: Int,
fitIntercept: Boolean,
standardization: Boolean,
- featuresStd: Array[Double],
- featuresMean: Array[Double],
+ bcFeaturesStd: Broadcast[Array[Double]],
regParamL2: Double) extends DiffFunction[BDV[Double]] {
+ val featuresStd = bcFeaturesStd.value
+
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val numFeatures = featuresStd.length
val coeffs = Vectors.fromBreeze(coefficients)
+ val bcCoeffs = instances.context.broadcast(coeffs)
val n = coeffs.size
- val localFeaturesStd = featuresStd
-
val logisticAggregator = {
- val seqOp = (c: LogisticAggregator, instance: Instance) =>
- c.add(instance, coeffs, localFeaturesStd)
+ val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
instances.treeAggregate(
- new LogisticAggregator(numFeatures, numClasses, fitIntercept)
+ new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept)
)(seqOp, combOp)
}
@@ -1134,6 +1137,7 @@ private class LogisticCostFun(
}
0.5 * regParamL2 * sum
}
+ bcCoeffs.destroy(blocking = false)
(logisticAggregator.loss + regVal, new BDV(totalGradientArray))
}