diff options
author | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-07-17 16:03:29 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-07-17 16:03:29 -0700 |
commit | 45f3c855181539306d5610c5aa265f24b431c142 (patch) | |
tree | f9cb04c429ff11b2dc23b292dffb2d1335343335 /mllib/src | |
parent | 3bf989713654129ad35a80309d1b354ca5ddd06c (diff) | |
download | spark-45f3c855181539306d5610c5aa265f24b431c142.tar.gz spark-45f3c855181539306d5610c5aa265f24b431c142.tar.bz2 spark-45f3c855181539306d5610c5aa265f24b431c142.zip |
Change weights to be Array[Double] in LR model.
Also ensure weights are initialized to a column vector.
Diffstat (limited to 'mllib/src')
3 files changed, 15 insertions, 11 deletions
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala index 77f5a7ae24..2c5038757b 100644 --- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala @@ -57,9 +57,8 @@ object GradientDescent { val nexamples: Long = data.count() val miniBatchSize = nexamples * miniBatchFraction - // Initialize weights as a column matrix - var weights = new DoubleMatrix(1, initialWeights.length, - initialWeights:_*) + // Initialize weights as a column vector + var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*) var reg_val = 0.0 for (i <- 1 to numIters) { diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala index 664baf33a3..ab865af0c6 100644 --- a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala @@ -28,20 +28,23 @@ import org.jblas.DoubleMatrix * Based on Matlab code written by John Duchi. */ class LogisticRegressionModel( - val weights: DoubleMatrix, + val weights: Array[Double], val intercept: Double, val stochasticLosses: Array[Double]) extends RegressionModel { + // Create a column vector that can be used for predictions + private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) + override def predict(testData: spark.RDD[Array[Double]]) = { testData.map { x => - val margin = new DoubleMatrix(1, x.length, x:_*).mmul(this.weights).get(0) + this.intercept + val margin = new DoubleMatrix(1, x.length, x:_*).mmul(weightsMatrix).get(0) + this.intercept 1.0/ (1.0 + math.exp(margin * -1)) } } override def predict(testData: Array[Double]): Double = { val dataMat = new DoubleMatrix(1, testData.length, testData:_*) - val margin = dataMat.mmul(this.weights).get(0) + this.intercept + val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept 1.0/ (1.0 + math.exp(margin * -1)) } } @@ -123,12 +126,14 @@ class LogisticRegression private (var stepSize: Double, var miniBatchFraction: D initalWeightsWithIntercept, miniBatchFraction) - val weightsScaled = weights.getRange(1, weights.length) - val intercept = weights.get(0) + val weightsArray = weights.toArray() + + val intercept = weightsArray(0) + val weightsScaled = weightsArray.tail val model = new LogisticRegressionModel(weightsScaled, intercept, stochasticLosses) - logInfo("Final model weights " + model.weights) + logInfo("Final model weights " + model.weights.mkString(",")) logInfo("Final model intercept " + model.intercept) logInfo("Last 10 stochastic losses " + model.stochasticLosses.takeRight(10).mkString(", ")) model diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala index 2ff248d256..47191d9a5a 100644 --- a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala @@ -75,7 +75,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { val model = lr.train(testRDD) - val weight0 = model.weights.get(0) + val weight0 = model.weights(0) assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") } @@ -99,7 +99,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { val model = lr.train(testRDD, initialWeights) - val weight0 = model.weights.get(0) + val weight0 = model.weights(0) assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") } |