diff options
author | WeichenXu <WeichenXu123@outlook.com> | 2016-08-15 06:38:30 -0700 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-08-15 06:38:30 -0700 |
commit | 3d8bfe7a39015c84cf95561fe17eb2808ce44084 (patch) | |
tree | 057c4093d3bee7483c1968e0319d1ed2f9b7abf0 | |
parent | ddf0d1e3fe18bcd01e1447feea1b76ce86087b3b (diff) | |
download | spark-3d8bfe7a39015c84cf95561fe17eb2808ce44084.tar.gz spark-3d8bfe7a39015c84cf95561fe17eb2808ce44084.tar.bz2 spark-3d8bfe7a39015c84cf95561fe17eb2808ce44084.zip |
[SPARK-16934][ML][MLLIB] Update LogisticCostAggregator serialization code to make it consistent with LinearRegression
## What changes were proposed in this pull request?
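Update the serialization code of `LogisticAggregator` (used by `LogisticCostFun`) to make it consistent with LinearRegression (#14109): the coefficients and the feature standard deviations are now handed to the aggregator as broadcast variables instead of being passed as arguments on every `add` call.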
## How was this patch tested?
MLlib 2.0:
![image](https://cloud.githubusercontent.com/assets/19235986/17649601/5e2a79ac-61ee-11e6-833c-3bd8b5250470.png)
After this PR:
![image](https://cloud.githubusercontent.com/assets/19235986/17649599/52b002ae-61ee-11e6-9402-9feb3439880f.png)
Author: WeichenXu <WeichenXu123@outlook.com>
Closes #14520 from WeichenXu123/improve_logistic_regression_costfun.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 36 |
1 file changed, 20 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 88d1b4575f..fce3935d39 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -25,6 +25,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.SparkException import org.apache.spark.annotation.{Experimental, Since} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.ml.feature.Instance import org.apache.spark.ml.linalg._ @@ -346,8 +347,9 @@ class LogisticRegression @Since("1.2.0") ( val regParamL1 = $(elasticNetParam) * $(regParam) val regParamL2 = (1.0 - $(elasticNetParam)) * $(regParam) + val bcFeaturesStd = instances.context.broadcast(featuresStd) val costFun = new LogisticCostFun(instances, numClasses, $(fitIntercept), - $(standardization), featuresStd, featuresMean, regParamL2) + $(standardization), bcFeaturesStd, regParamL2) val optimizer = if ($(elasticNetParam) == 0.0 || $(regParam) == 0.0) { new BreezeLBFGS[BDV[Double]]($(maxIter), 10, $(tol)) @@ -442,6 +444,7 @@ class LogisticRegression @Since("1.2.0") ( rawCoefficients(i) *= { if (featuresStd(i) != 0.0) 1.0 / featuresStd(i) else 0.0 } i += 1 } + bcFeaturesStd.destroy(blocking = false) if ($(fitIntercept)) { (Vectors.dense(rawCoefficients.dropRight(1)).compressed, rawCoefficients.last, @@ -938,11 +941,15 @@ class BinaryLogisticRegressionSummary private[classification] ( * Two LogisticAggregator can be merged together to have a summary of loss and gradient of * the corresponding joint dataset. * + * @param bcCoefficients The broadcast coefficients corresponding to the features. + * @param bcFeaturesStd The broadcast standard deviation values of the features. 
* @param numClasses the number of possible outcomes for k classes classification problem in * Multinomial Logistic Regression. * @param fitIntercept Whether to fit an intercept term. */ private class LogisticAggregator( + val bcCoefficients: Broadcast[Vector], + val bcFeaturesStd: Broadcast[Array[Double]], private val numFeatures: Int, numClasses: Int, fitIntercept: Boolean) extends Serializable { @@ -958,14 +965,9 @@ private class LogisticAggregator( * of the objective function. * * @param instance The instance of data point to be added. - * @param coefficients The coefficients corresponding to the features. - * @param featuresStd The standard deviation values of the features. * @return This LogisticAggregator object. */ - def add( - instance: Instance, - coefficients: Vector, - featuresStd: Array[Double]): this.type = { + def add(instance: Instance): this.type = { instance match { case Instance(label, weight, features) => require(numFeatures == features.size, s"Dimensions mismatch when adding new instance." + s" Expecting $numFeatures but got ${features.size}.") @@ -973,14 +975,16 @@ private class LogisticAggregator( if (weight == 0.0) return this - val coefficientsArray = coefficients match { + val coefficientsArray = bcCoefficients.value match { case dv: DenseVector => dv.values case _ => throw new IllegalArgumentException( - s"coefficients only supports dense vector but got type ${coefficients.getClass}.") + "coefficients only supports dense vector" + + s"but got type ${bcCoefficients.value.getClass}.") } val localGradientSumArray = gradientSumArray + val featuresStd = bcFeaturesStd.value numClasses match { case 2 => // For Binary Logistic Regression. 
@@ -1077,24 +1081,23 @@ private class LogisticCostFun( numClasses: Int, fitIntercept: Boolean, standardization: Boolean, - featuresStd: Array[Double], - featuresMean: Array[Double], + bcFeaturesStd: Broadcast[Array[Double]], regParamL2: Double) extends DiffFunction[BDV[Double]] { + val featuresStd = bcFeaturesStd.value + override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = { val numFeatures = featuresStd.length val coeffs = Vectors.fromBreeze(coefficients) + val bcCoeffs = instances.context.broadcast(coeffs) val n = coeffs.size - val localFeaturesStd = featuresStd - val logisticAggregator = { - val seqOp = (c: LogisticAggregator, instance: Instance) => - c.add(instance, coeffs, localFeaturesStd) + val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance) val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2) instances.treeAggregate( - new LogisticAggregator(numFeatures, numClasses, fitIntercept) + new LogisticAggregator(bcCoeffs, bcFeaturesStd, numFeatures, numClasses, fitIntercept) )(seqOp, combOp) } @@ -1134,6 +1137,7 @@ private class LogisticCostFun( } 0.5 * regParamL2 * sum } + bcCoeffs.destroy(blocking = false) (logisticAggregator.loss + regVal, new BDV(totalGradientArray)) } |