aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala14
1 files changed, 9 insertions, 5 deletions
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index 36cda721dd..f66025bc0b 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -7,6 +7,7 @@ import org.jblas.DoubleMatrix
import org.jblas.Solve
import scala.annotation.tailrec
+import scala.collection.mutable
/**
* Ridge Regression from Joseph Gonzalez's implementation in MLBase
@@ -100,9 +101,10 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
// Binary search for the best assignment to lambda.
def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
+ val buffer = mutable.ListBuffer.empty[(Double, Double, DoubleMatrix)]
+
@tailrec
- def loop(low: Double, high: Double, acc: Seq[(Double, Double, DoubleMatrix)])
- : Seq[(Double, Double, DoubleMatrix)] = {
+ def loop(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
val mid = (high - low) / 2 + low
val lowValue = crossValidate((mid - low) / 2 + low)
val highValue = crossValidate((high - mid) / 2 + mid)
@@ -112,13 +114,15 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
(mid - (high-low)/4, high)
}
if (newHigh - newLow > 1.0E-7) {
- loop(newLow, newHigh, acc :+ lowValue :+ highValue)
+ buffer += lowValue += highValue
+ loop(newLow, newHigh)
} else {
- acc :+ lowValue :+ highValue
+ buffer += lowValue += highValue
+ buffer.result()
}
}
- loop(low, high, Vector.empty)
+ loop(low, high)
}
// Actually compute the best lambda