aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorryanlecompte <lecompte@gmail.com>2013-07-06 16:46:53 -0700
committerryanlecompte <lecompte@gmail.com>2013-07-06 16:46:53 -0700
commitf78f8d0b416ef4d88883d8f32382661f4c2ac52d (patch)
tree3be8950efd1498afae35d453e5b592c8eb065fbb /mllib
parent757e56dfc7bd900d5b3f3f145eabe8198bfbe7cc (diff)
downloadspark-f78f8d0b416ef4d88883d8f32382661f4c2ac52d.tar.gz
spark-f78f8d0b416ef4d88883d8f32382661f4c2ac52d.tar.bz2
spark-f78f8d0b416ef4d88883d8f32382661f4c2ac52d.zip
fix formatting and use Vector instead of List to maintain order
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala18
1 files changed, 8 insertions, 10 deletions
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index 8343f28139..36cda721dd 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -15,7 +15,7 @@ class RidgeRegressionModel(
val weights: DoubleMatrix,
val intercept: Double,
val lambdaOpt: Double,
- val lambdas: List[(Double, Double, DoubleMatrix)])
+ val lambdas: Seq[(Double, Double, DoubleMatrix)])
extends RegressionModel {
override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
@@ -99,12 +99,10 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
}
// Binary search for the best assignment to lambda.
- def binSearch(low: Double, high: Double): List[(Double, Double, DoubleMatrix)] = {
+ def binSearch(low: Double, high: Double): Seq[(Double, Double, DoubleMatrix)] = {
@tailrec
- def loop(
- low: Double,
- high: Double,
- acc: List[(Double, Double, DoubleMatrix)]): List[(Double, Double, DoubleMatrix)] = {
+ def loop(low: Double, high: Double, acc: Seq[(Double, Double, DoubleMatrix)])
+ : Seq[(Double, Double, DoubleMatrix)] = {
val mid = (high - low) / 2 + low
val lowValue = crossValidate((mid - low) / 2 + low)
val highValue = crossValidate((high - mid) / 2 + mid)
@@ -114,14 +112,13 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
(mid - (high-low)/4, high)
}
if (newHigh - newLow > 1.0E-7) {
- // :: is list prepend in Scala.
- loop(newLow, newHigh, lowValue :: highValue :: acc)
+ loop(newLow, newHigh, acc :+ lowValue :+ highValue)
} else {
- lowValue :: highValue :: acc
+ acc :+ lowValue :+ highValue
}
}
- loop(low, high, Nil)
+ loop(low, high, Vector.empty)
}
// Actually compute the best lambda
@@ -143,6 +140,7 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
model
}
}
+
/**
* Top-level methods for calling Ridge Regression.
*/