aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorryanlecompte <lecompte@gmail.com>2013-07-05 19:54:28 -0700
committerryanlecompte <lecompte@gmail.com>2013-07-05 19:54:28 -0700
commit757e56dfc7bd900d5b3f3f145eabe8198bfbe7cc (patch)
treeceee3f504da982167830bd0176b18ef5ea8515f0 /mllib
parentbf1311e6d2af2daa61c6e6ec0eb9417f44e1c37c (diff)
downloadspark-757e56dfc7bd900d5b3f3f145eabe8198bfbe7cc.tar.gz
spark-757e56dfc7bd900d5b3f3f145eabe8198bfbe7cc.tar.bz2
spark-757e56dfc7bd900d5b3f3f145eabe8198bfbe7cc.zip
make binSearch a tail-recursive method
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala37
1 files changed, 23 insertions, 14 deletions
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index a6ececbeb6..8343f28139 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -1,12 +1,13 @@
package spark.mllib.regression
import spark.{Logging, RDD, SparkContext}
-import spark.SparkContext._
import spark.mllib.util.MLUtils
import org.jblas.DoubleMatrix
import org.jblas.Solve
+import scala.annotation.tailrec
+
/**
* Ridge Regression from Joseph Gonzalez's implementation in MLBase
*/
@@ -99,20 +100,28 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
// Binary search for the best assignment to lambda.
def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = {
- val mid = (high - low) / 2 + low
- val lowValue = crossValidate((mid - low) / 2 + low)
- val highValue = crossValidate((high - mid) / 2 + mid)
- val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
- (low, mid + (high-low)/4)
- } else {
- (mid - (high-low)/4, high)
- }
- if (newHigh - newLow > 1.0E-7) {
- // :: is list prepend in Scala.
- lowValue :: highValue :: binSearch(newLow, newHigh)
- } else {
- List(lowValue, highValue)
+ @tailrec
+ def loop(
+ low: Double,
+ high: Double,
+ acc: List[(Double, Double, DoubleMatrix)]): List[(Double, Double, DoubleMatrix)] = {
+ val mid = (high - low) / 2 + low
+ val lowValue = crossValidate((mid - low) / 2 + low)
+ val highValue = crossValidate((high - mid) / 2 + mid)
+ val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
+ (low, mid + (high-low)/4)
+ } else {
+ (mid - (high-low)/4, high)
+ }
+ if (newHigh - newLow > 1.0E-7) {
+ // :: is list prepend in Scala.
+ loop(newLow, newHigh, lowValue :: highValue :: acc)
+ } else {
+ lowValue :: highValue :: acc
+ }
}
+
+ loop(low, high, Nil)
}
// Actually compute the best lambda