aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
author: Shivaram Venkataraman <shivaram@eecs.berkeley.edu> 2013-07-17 16:03:29 -0700
committer: Shivaram Venkataraman <shivaram@eecs.berkeley.edu> 2013-07-17 16:03:29 -0700
commit: 45f3c855181539306d5610c5aa265f24b431c142 (patch)
tree: f9cb04c429ff11b2dc23b292dffb2d1335343335 /mllib/src
parent: 3bf989713654129ad35a80309d1b354ca5ddd06c (diff)
download: spark-45f3c855181539306d5610c5aa265f24b431c142.tar.gz
spark-45f3c855181539306d5610c5aa265f24b431c142.tar.bz2
spark-45f3c855181539306d5610c5aa265f24b431c142.zip
Change weights to be Array[Double] in LR model.
Also ensure weights are initialized to a column vector.
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala 5
-rw-r--r-- mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala 17
-rw-r--r-- mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala 4
3 files changed, 15 insertions, 11 deletions
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
index 77f5a7ae24..2c5038757b 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -57,9 +57,8 @@ object GradientDescent {
val nexamples: Long = data.count()
val miniBatchSize = nexamples * miniBatchFraction
- // Initialize weights as a column matrix
- var weights = new DoubleMatrix(1, initialWeights.length,
- initialWeights:_*)
+ // Initialize weights as a column vector
+ var weights = new DoubleMatrix(initialWeights.length, 1, initialWeights:_*)
var reg_val = 0.0
for (i <- 1 to numIters) {
diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
index 664baf33a3..ab865af0c6 100644
--- a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
@@ -28,20 +28,23 @@ import org.jblas.DoubleMatrix
* Based on Matlab code written by John Duchi.
*/
class LogisticRegressionModel(
- val weights: DoubleMatrix,
+ val weights: Array[Double],
val intercept: Double,
val stochasticLosses: Array[Double]) extends RegressionModel {
+ // Create a column vector that can be used for predictions
+ private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
+
override def predict(testData: spark.RDD[Array[Double]]) = {
testData.map { x =>
- val margin = new DoubleMatrix(1, x.length, x:_*).mmul(this.weights).get(0) + this.intercept
+ val margin = new DoubleMatrix(1, x.length, x:_*).mmul(weightsMatrix).get(0) + this.intercept
1.0/ (1.0 + math.exp(margin * -1))
}
}
override def predict(testData: Array[Double]): Double = {
val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
- val margin = dataMat.mmul(this.weights).get(0) + this.intercept
+ val margin = dataMat.mmul(weightsMatrix).get(0) + this.intercept
1.0/ (1.0 + math.exp(margin * -1))
}
}
@@ -123,12 +126,14 @@ class LogisticRegression private (var stepSize: Double, var miniBatchFraction: D
initalWeightsWithIntercept,
miniBatchFraction)
- val weightsScaled = weights.getRange(1, weights.length)
- val intercept = weights.get(0)
+ val weightsArray = weights.toArray()
+
+ val intercept = weightsArray(0)
+ val weightsScaled = weightsArray.tail
val model = new LogisticRegressionModel(weightsScaled, intercept, stochasticLosses)
- logInfo("Final model weights " + model.weights)
+ logInfo("Final model weights " + model.weights.mkString(","))
logInfo("Final model intercept " + model.intercept)
logInfo("Last 10 stochastic losses " + model.stochasticLosses.takeRight(10).mkString(", "))
model
diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
index 2ff248d256..47191d9a5a 100644
--- a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
@@ -75,7 +75,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
val model = lr.train(testRDD)
- val weight0 = model.weights.get(0)
+ val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
}
@@ -99,7 +99,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
val model = lr.train(testRDD, initialWeights)
- val weight0 = model.weights.get(0)
+ val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
}