aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorryanlecompte <lecompte@gmail.com>2013-07-07 15:35:06 -0700
committerryanlecompte <lecompte@gmail.com>2013-07-07 15:35:06 -0700
commitbe123aa6ef90480ad61663eed6e8ea479b047fad (patch)
treeb9af00d4f227e322784f07f7d050ca7bfb4bf645 /mllib
parentf78f8d0b416ef4d88883d8f32382661f4c2ac52d (diff)
downloadspark-be123aa6ef90480ad61663eed6e8ea479b047fad.tar.gz
spark-be123aa6ef90480ad61663eed6e8ea479b047fad.tar.bz2
spark-be123aa6ef90480ad61663eed6e8ea479b047fad.zip
update to use ListBuffer, faster than Vector for append operations
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala14
1 files changed, 9 insertions, 5 deletions
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index 36cda721dd..f66025bc0b 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -7,6 +7,7 @@ import org.jblas.DoubleMatrix
import org.jblas.Solve
import scala.annotation.tailrec
+import scala.collection.mutable
/**
* Ridge Regression from Joseph Gonzalez's implementation in MLBase
@@ -100,9 +101,10 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
// Binary search for the best assignment to lambda.
def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
+ val buffer = mutable.ListBuffer.empty[(Double, Double, DoubleMatrix)]
+
@tailrec
- def loop(low: Double, high: Double, acc: Seq[(Double, Double, DoubleMatrix)])
- : Seq[(Double, Double, DoubleMatrix)] = {
+ def loop(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
val mid = (high - low) / 2 + low
val lowValue = crossValidate((mid - low) / 2 + low)
val highValue = crossValidate((high - mid) / 2 + mid)
@@ -112,13 +114,15 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
(mid - (high-low)/4, high)
}
if (newHigh - newLow > 1.0E-7) {
- loop(newLow, newHigh, acc :+ lowValue :+ highValue)
+ buffer += lowValue += highValue
+ loop(newLow, newHigh)
} else {
- acc :+ lowValue :+ highValue
+ buffer += lowValue += highValue
+ buffer.result()
}
}
- loop(low, high, Vector.empty)
+ loop(low, high)
}
// Actually compute the best lambda