aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorshivaram <shivaram.venkataraman@gmail.com>2013-07-07 17:42:25 -0700
committershivaram <shivaram.venkataraman@gmail.com>2013-07-07 17:42:25 -0700
commit744da8eefda3ae66f3471a12cc02b29cf5441dbc (patch)
tree7f5e94861bf38f3488020ad1114fc3e69382da60
parent3cc6818f138371c279277e4bf733402a43cc40f6 (diff)
parentbe123aa6ef90480ad61663eed6e8ea479b047fad (diff)
downloadspark-744da8eefda3ae66f3471a12cc02b29cf5441dbc.tar.gz
spark-744da8eefda3ae66f3471a12cc02b29cf5441dbc.tar.bz2
spark-744da8eefda3ae66f3471a12cc02b29cf5441dbc.zip
Merge pull request #679 from ryanlecompte/master
Make binSearch method tail-recursive for RidgeRegression
-rw-r--r--mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala43
1 files changed, 27 insertions, 16 deletions
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index a6ececbeb6..f66025bc0b 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -1,12 +1,14 @@
package spark.mllib.regression
import spark.{Logging, RDD, SparkContext}
-import spark.SparkContext._
import spark.mllib.util.MLUtils
import org.jblas.DoubleMatrix
import org.jblas.Solve
+import scala.annotation.tailrec
+import scala.collection.mutable
+
/**
* Ridge Regression from Joseph Gonzalez's implementation in MLBase
*/
@@ -14,7 +16,7 @@ class RidgeRegressionModel(
val weights: DoubleMatrix,
val intercept: Double,
val lambdaOpt: Double,
- val lambdas: List[(Double, Double, DoubleMatrix)])
+ val lambdas: Seq[(Double, Double, DoubleMatrix)])
extends RegressionModel {
override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
@@ -98,21 +100,29 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
}
// Binary search for the best assignment to lambda.
- def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = {
- val mid = (high - low) / 2 + low
- val lowValue = crossValidate((mid - low) / 2 + low)
- val highValue = crossValidate((high - mid) / 2 + mid)
- val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
- (low, mid + (high-low)/4)
- } else {
- (mid - (high-low)/4, high)
- }
- if (newHigh - newLow > 1.0E-7) {
- // :: is list prepend in Scala.
- lowValue :: highValue :: binSearch(newLow, newHigh)
- } else {
- List(lowValue, highValue)
+ def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
+ val buffer = mutable.ListBuffer.empty[(Double, Double, DoubleMatrix)]
+
+ @tailrec
+ def loop(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
+ val mid = (high - low) / 2 + low
+ val lowValue = crossValidate((mid - low) / 2 + low)
+ val highValue = crossValidate((high - mid) / 2 + mid)
+ val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
+ (low, mid + (high-low)/4)
+ } else {
+ (mid - (high-low)/4, high)
+ }
+ if (newHigh - newLow > 1.0E-7) {
+ buffer += lowValue += highValue
+ loop(newLow, newHigh)
+ } else {
+ buffer += lowValue += highValue
+ buffer.result()
+ }
}
+
+ loop(low, high)
}
// Actually compute the best lambda
@@ -134,6 +144,7 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
model
}
}
+
/**
* Top-level methods for calling Ridge Regression.
*/