diff options
author | ryanlecompte <lecompte@gmail.com> | 2013-07-07 15:35:06 -0700 |
---|---|---|
committer | ryanlecompte <lecompte@gmail.com> | 2013-07-07 15:35:06 -0700 |
commit | be123aa6ef90480ad61663eed6e8ea479b047fad (patch) | |
tree | b9af00d4f227e322784f07f7d050ca7bfb4bf645 /mllib | |
parent | f78f8d0b416ef4d88883d8f32382661f4c2ac52d (diff) | |
download | spark-be123aa6ef90480ad61663eed6e8ea479b047fad.tar.gz spark-be123aa6ef90480ad61663eed6e8ea479b047fad.tar.bz2 spark-be123aa6ef90480ad61663eed6e8ea479b047fad.zip |
update to use ListBuffer, faster than Vector for append operations
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala | 14 |
1 files changed, 9 insertions, 5 deletions
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index 36cda721dd..f66025bc0b 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -7,6 +7,7 @@ import org.jblas.DoubleMatrix
 import org.jblas.Solve
 
 import scala.annotation.tailrec
+import scala.collection.mutable
 
 /**
  * Ridge Regression from Joseph Gonzalez's implementation in MLBase
@@ -100,9 +101,10 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
 
     // Binary search for the best assignment to lambda.
     def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
+      val buffer = mutable.ListBuffer.empty[(Double, Double, DoubleMatrix)]
+
       @tailrec
-      def loop(low: Double, high: Double, acc: Seq[(Double, Double, DoubleMatrix)])
-        : Seq[(Double, Double, DoubleMatrix)] = {
+      def loop(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
         val mid = (high - low) / 2 + low
         val lowValue = crossValidate((mid - low) / 2 + low)
         val highValue = crossValidate((high - mid) / 2 + mid)
@@ -112,13 +114,15 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
           (mid - (high-low)/4, high)
         }
         if (newHigh - newLow > 1.0E-7) {
-          loop(newLow, newHigh, acc :+ lowValue :+ highValue)
+          buffer += lowValue += highValue
+          loop(newLow, newHigh)
         } else {
-          acc :+ lowValue :+ highValue
+          buffer += lowValue += highValue
+          buffer.result()
         }
       }
 
-      loop(low, high, Vector.empty)
+      loop(low, high)
     }
 
     // Actually compute the best lambda