diff options
author | ryanlecompte <lecompte@gmail.com> | 2013-07-06 16:46:53 -0700 |
---|---|---|
committer | ryanlecompte <lecompte@gmail.com> | 2013-07-06 16:46:53 -0700 |
commit | f78f8d0b416ef4d88883d8f32382661f4c2ac52d (patch) | |
tree | 3be8950efd1498afae35d453e5b592c8eb065fbb /mllib | |
parent | 757e56dfc7bd900d5b3f3f145eabe8198bfbe7cc (diff) | |
download | spark-f78f8d0b416ef4d88883d8f32382661f4c2ac52d.tar.gz spark-f78f8d0b416ef4d88883d8f32382661f4c2ac52d.tar.bz2 spark-f78f8d0b416ef4d88883d8f32382661f4c2ac52d.zip |
fix formatting and use Vector instead of List to maintain order
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala | 18 |
1 files changed, 8 insertions, 10 deletions
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala index 8343f28139..36cda721dd 100644 --- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala @@ -15,7 +15,7 @@ class RidgeRegressionModel( val weights: DoubleMatrix, val intercept: Double, val lambdaOpt: Double, - val lambdas: List[(Double, Double, DoubleMatrix)]) + val lambdas: Seq[(Double, Double, DoubleMatrix)]) extends RegressionModel { override def predict(testData: RDD[Array[Double]]): RDD[Double] = { @@ -99,12 +99,10 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) } // Binary search for the best assignment to lambda. - def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = { + def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = { @tailrec - def loop( - low: Double, - high: Double, - acc: List[(Double, Double, DoubleMatrix)]): List[(Double, Double, DoubleMatrix)] = { + def loop(low: Double, high: Double, acc: Seq[(Double, Double, DoubleMatrix)]) + : Seq[(Double, Double, DoubleMatrix)] = { val mid = (high - low) / 2 + low val lowValue = crossValidate((mid - low) / 2 + low) val highValue = crossValidate((high - mid) / 2 + mid) @@ -114,14 +112,13 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) (mid - (high-low)/4, high) } if (newHigh - newLow > 1.0E-7) { - // :: is list prepend in Scala. - loop(newLow, newHigh, lowValue :: highValue :: acc) + loop(newLow, newHigh, acc :+ lowValue :+ highValue) } else { - lowValue :: highValue :: acc + acc :+ lowValue :+ highValue } } - loop(low, high, Nil) + loop(low, high, Vector.empty) } // Actually compute the best lambda @@ -143,6 +140,7 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double) model } } + /** * Top-level methods for calling Ridge Regression. */ |