From 6a9a9a364ce3b158c4162e401f90eb4d305104e8 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 4 Jun 2013 16:27:02 -0700 Subject: Minor clean up of the RidgeRegression code. I am not even sure why I did this :s. --- ml/src/main/scala/spark/ml/RidgeRegression.scala | 38 ++++++++++------------ .../scala/spark/ml/RidgeRegressionGenerator.scala | 38 ++++++++++++---------- project/SparkBuild.scala | 2 +- 3 files changed, 39 insertions(+), 39 deletions(-) diff --git a/ml/src/main/scala/spark/ml/RidgeRegression.scala b/ml/src/main/scala/spark/ml/RidgeRegression.scala index 7896873d44..b8b632e111 100644 --- a/ml/src/main/scala/spark/ml/RidgeRegression.scala +++ b/ml/src/main/scala/spark/ml/RidgeRegression.scala @@ -1,9 +1,7 @@ package spark.ml -import spark._ -import spark.SparkContext._ +import spark.{Logging, RDD, SparkContext} -import org.apache.commons.math3.distribution.NormalDistribution import org.jblas.DoubleMatrix import org.jblas.Solve @@ -23,39 +21,36 @@ class RidgeRegressionModel( object RidgeRegression extends Logging { - def train(data: spark.RDD[(Double, Array[Double])], - lambdaLow: Double = 0.0, + def train(data: RDD[(Double, Array[Double])], + lambdaLow: Double = 0.0, lambdaHigh: Double = 10000.0) = { data.cache() - val nfeatures = data.take(1)(0)._2.length - val nexamples = data.count + val nfeatures: Int = data.take(1)(0)._2.length + val nexamples: Long = data.count() // Compute XtX - Size of XtX is nfeatures by nfeatures - val XtX = data.map { - case (y, features) => - val x = new DoubleMatrix(1, features.length, features:_*) - x.transpose().mmul(x) + val XtX: DoubleMatrix = data.map { case (y, features) => + val x = new DoubleMatrix(1, features.length, features:_*) + x.transpose().mmul(x) }.reduce(_.add(_)) // Compute Xt*y - Size of Xty is nfeatures by 1 - val Xty = data.map { - case (y, features) => - new DoubleMatrix(features.length, 1, features:_*).mul(y) + val Xty: DoubleMatrix = data.map { case (y, features) => + new DoubleMatrix(features.length, 1, features:_*).mul(y) }.reduce(_.add(_)) // Define a function to compute the leave one out cross validation error // for a single example - def crossValidate(lambda: Double) = { - // Compute the MLE ridge regression parameter value + def crossValidate(lambda: Double): (Double, Double, DoubleMatrix) = { + // Compute the MLE ridge regression parameter value // Ridge Regression parameter = inv(XtX + \lambda*I) * Xty val XtXlambda = DoubleMatrix.eye(nfeatures).muli(lambda).addi(XtX) val w = Solve.solveSymmetric(XtXlambda, Xty) - val invXtX = Solve.solveSymmetric(XtXlambda, - DoubleMatrix.eye(nfeatures)) - + val invXtX = Solve.solveSymmetric(XtXlambda, DoubleMatrix.eye(nfeatures)) + // compute the leave one out cross validation score val cvError = data.map { case (y, features) => @@ -74,11 +69,12 @@ object RidgeRegression extends Logging { val lowValue = crossValidate((mid - low) / 2 + low) val highValue = crossValidate((high - mid) / 2 + mid) val (newLow, newHigh) = if (lowValue._2 < highValue._2) { - (low, mid + (high-low)/4) + (low, mid + (high-low)/4) } else { (mid - (high-low)/4, high) } if (newHigh - newLow > 1.0E-7) { + // :: is list prepend in Scala. lowValue :: highValue :: binSearch(newLow, newHigh) } else { List(lowValue, highValue) @@ -88,7 +84,7 @@ object RidgeRegression extends Logging { // Actually compute the best lambda val lambdas = binSearch(lambdaLow, lambdaHigh).sortBy(_._1) - // Find the best parameter set + // Find the best parameter set by taking the lowest cverror. val (lambdaOpt, cverror, wOpt) = lambdas.reduce((a, b) => if (a._2 < b._2) a else b) logInfo("RidgeRegression: optimal lambda " + lambdaOpt) diff --git a/ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala b/ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala index 22a1e4613b..ff8640bb50 100644 --- a/ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala +++ b/ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala @@ -1,11 +1,11 @@ package spark.ml -import spark._ -import spark.SparkContext._ +import spark.{RDD, SparkContext} import org.apache.commons.math3.distribution.NormalDistribution import org.jblas.DoubleMatrix + object RidgeRegressionGenerator { // Helper methods to load and save data used for RidgeRegression @@ -23,30 +23,34 @@ object RidgeRegressionGenerator { data } - def saveData(data: RDD[(Double, Array[Double])], dir: String) { + private def saveData(data: RDD[(Double, Array[Double])], dir: String) { val dataStr = data.map(x => x._1 + "," + x._2.mkString(" ")) dataStr.saveAsTextFile(dir) } def main(args: Array[String]) { if (args.length != 2) { - println("Usage: RidgeRegressionGenerator ") + println("Usage: RidgeRegressionGenerator " + + " ") System.exit(1) } - org.jblas.util.Random.seed(42) - val sc = new SparkContext(args(0), "RidgeRegressionGenerator") - val nexamples = 1000 - val nfeatures = 100 + val sparkMaster: String = args(0) + val outputPath: String = args(1) + val nexamples: Int = if (args.length > 2) args(2).toInt else 1000 + val nfeatures: Int = if (args.length > 3) args(3).toInt else 100 + val parts: Int = if (args.length > 4) args(4).toInt else 2 val eps = 10 - val parts = 2 + + org.jblas.util.Random.seed(42) + val sc = new SparkContext(sparkMaster, "RidgeRegressionGenerator") // Random values distributed uniformly in [-0.5, 0.5] val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5) w.put(0, 0, 10) w.put(1, 0, 10) - val data = sc.parallelize(0 until parts, parts).flatMap { p => + val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until parts, parts).flatMap { p => org.jblas.util.Random.seed(42 + p) val examplesInPartition = nexamples / parts @@ -56,15 +60,15 @@ object RidgeRegressionGenerator { val rnd = new NormalDistribution(0, eps) rnd.reseedRandomGenerator(42 + p) - val normalValues = (0 until examplesInPartition).map(_ => rnd.sample()) - val yObs = new DoubleMatrix(examplesInPartition, 1, normalValues:_*).addi(y) - - (0 until examplesInPartition).map(i => + val normalValues = Array.fill[Double](examplesInPartition)(rnd.sample()) + val yObs = new DoubleMatrix(normalValues).addi(y) + + Iterator.tabulate(examplesInPartition) { i => (yObs.get(i, 0), X.getRow(i).toArray) - ) + } } - saveData(data, args(1)) - System.exit(0) + saveData(data, outputPath) + sc.stop() } } diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 731671c23b..aa877ad4a7 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -221,7 +221,7 @@ object SparkBuild extends Build { def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel") - def mlSettings = examplesSettings ++ Seq( + def mlSettings = sharedSettings ++ Seq( name := "spark-ml", libraryDependencies ++= Seq( "org.jblas" % "jblas" % "1.2.3", -- cgit v1.2.3