diff options
author | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-07-17 16:03:29 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-07-17 16:03:29 -0700 |
commit | 45f3c855181539306d5610c5aa265f24b431c142 (patch) | |
tree | f9cb04c429ff11b2dc23b292dffb2d1335343335 /mllib/src/main | |
parent | 3bf989713654129ad35a80309d1b354ca5ddd06c (diff) | |
download | spark-45f3c855181539306d5610c5aa265f24b431c142.tar.gz spark-45f3c855181539306d5610c5aa265f24b431c142.tar.bz2 spark-45f3c855181539306d5610c5aa265f24b431c142.zip |
Change weights to be Array[Double] in LR model.
Also ensure weights are initialized to a column vector.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala | 5 | ||||
-rw-r--r-- | mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala | 17 |
2 files changed, 13 insertions, 9 deletions
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala index 77f5a7ae24..2c5038757b 100644 --- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala @@ -57,9 +57,8 @@ object GradientDescent { val nexamples: Long = data.count() val miniBatchSize = nexamples * miniBatchFraction - // Initialize weights as a column matrix - var weights = new DoubleMatrix(1, initialWeights.length, - initialWeights:_*) + // Initialize weights as a column vector + var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) var reg_val = 0.0 for (i <- 1 to numIters) { diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala index 664baf33a3..ab865af0c6 100644 --- a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala @@ -28,20 +28,23 @@ import org.jblas.DoubleMatrix * Based on Matlab code written by John Duchi. */ class LogisticRegressionModel( - val weights: DoubleMatrix, + val weights: Array[Double], val intercept: Double, val stochasticLosses: Array[Double]) extends RegressionModel { + // Create a column vector that can be used for predictions + private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) + override def predict(testData: spark.RDD[Array[Double]]) = { testData.map { x => - val margin = new DoubleMatrix(1, x.length, x:_*).mmul(this.weights).get(0) + this.intercept + val margin = new DoubleMatrix(1, x.length, x:_*).mmul(weightsMatrix).get(0) + this.intercept 1.0/ (1.0 + math.exp(margin * -1)) } } override def predict(testData: Array[Double]): Double = { val dataMat = new DoubleMatrix(1, testData.length, testData:_*) - val margin = dataMat.mmul(this.weights).get(0) + this.intercept + val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept 1.0/ (1.0 + math.exp(margin * -1)) } } @@ -123,12 +126,14 @@ class LogisticRegression private (var stepSize: Double, var miniBatchFraction: D initalWeightsWithIntercept, miniBatchFraction) - val weightsScaled = weights.getRange(1, weights.length) - val intercept = weights.get(0) + val weightsArray = weights.toArray() + + val intercept = weightsArray(0) + val weightsScaled = weightsArray.tail val model = new LogisticRegressionModel(weightsScaled, intercept, stochasticLosses) - logInfo("Final model weights " + model.weights) + logInfo("Final model weights " + model.weights.mkString(",")) logInfo("Final model intercept " + model.intercept) logInfo("Last 10 stochastic losses " + model.stochasticLosses.takeRight(10).mkString(", ")) model |