about summary refs log tree commit diff
diff options
context:
space:
mode:
authorReynold Xin <rxin@cs.berkeley.edu>2013-06-04 16:27:02 -0700
committerMatei Zaharia <matei@eecs.berkeley.edu>2013-07-05 11:13:45 -0700
commit6a9a9a364ce3b158c4162e401f90eb4d305104e8 (patch)
treeb74641bc7d4f2641f45cae9a28f195c12a50af0b
parent729e463f649332c5480d2d175d42d4ba0dd3cb74 (diff)
downloadspark-6a9a9a364ce3b158c4162e401f90eb4d305104e8.tar.gz
spark-6a9a9a364ce3b158c4162e401f90eb4d305104e8.tar.bz2
spark-6a9a9a364ce3b158c4162e401f90eb4d305104e8.zip
Minor clean up of the RidgeRegression code. I am not even sure why I did
this :s.
-rw-r--r-- ml/src/main/scala/spark/ml/RidgeRegression.scala          | 38
-rw-r--r-- ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala | 38
-rw-r--r-- project/SparkBuild.scala                                  |  2
3 files changed, 39 insertions(+), 39 deletions(-)
diff --git a/ml/src/main/scala/spark/ml/RidgeRegression.scala b/ml/src/main/scala/spark/ml/RidgeRegression.scala
index 7896873d44..b8b632e111 100644
--- a/ml/src/main/scala/spark/ml/RidgeRegression.scala
+++ b/ml/src/main/scala/spark/ml/RidgeRegression.scala
@@ -1,9 +1,7 @@
package spark.ml
-import spark._
-import spark.SparkContext._
+import spark.{Logging, RDD, SparkContext}
-import org.apache.commons.math3.distribution.NormalDistribution
import org.jblas.DoubleMatrix
import org.jblas.Solve
@@ -23,39 +21,36 @@ class RidgeRegressionModel(
object RidgeRegression extends Logging {
- def train(data: spark.RDD[(Double, Array[Double])],
- lambdaLow: Double = 0.0,
+ def train(data: RDD[(Double, Array[Double])],
+ lambdaLow: Double = 0.0,
lambdaHigh: Double = 10000.0) = {
data.cache()
- val nfeatures = data.take(1)(0)._2.length
- val nexamples = data.count
+ val nfeatures: Int = data.take(1)(0)._2.length
+ val nexamples: Long = data.count()
// Compute XtX - Size of XtX is nfeatures by nfeatures
- val XtX = data.map {
- case (y, features) =>
- val x = new DoubleMatrix(1, features.length, features:_*)
- x.transpose().mmul(x)
+ val XtX: DoubleMatrix = data.map { case (y, features) =>
+ val x = new DoubleMatrix(1, features.length, features:_*)
+ x.transpose().mmul(x)
}.reduce(_.add(_))
// Compute Xt*y - Size of Xty is nfeatures by 1
- val Xty = data.map {
- case (y, features) =>
- new DoubleMatrix(features.length, 1, features:_*).mul(y)
+ val Xty: DoubleMatrix = data.map { case (y, features) =>
+ new DoubleMatrix(features.length, 1, features:_*).mul(y)
}.reduce(_.add(_))
// Define a function to compute the leave one out cross validation error
// for a single example
- def crossValidate(lambda: Double) = {
- // Compute the MLE ridge regression parameter value
+ def crossValidate(lambda: Double): (Double, Double, DoubleMatrix) = {
+ // Compute the MLE ridge regression parameter value
// Ridge Regression parameter = inv(XtX + \lambda*I) * Xty
val XtXlambda = DoubleMatrix.eye(nfeatures).muli(lambda).addi(XtX)
val w = Solve.solveSymmetric(XtXlambda, Xty)
- val invXtX = Solve.solveSymmetric(XtXlambda,
- DoubleMatrix.eye(nfeatures))
-
+ val invXtX = Solve.solveSymmetric(XtXlambda, DoubleMatrix.eye(nfeatures))
+
// compute the leave one out cross validation score
val cvError = data.map {
case (y, features) =>
@@ -74,11 +69,12 @@ object RidgeRegression extends Logging {
val lowValue = crossValidate((mid - low) / 2 + low)
val highValue = crossValidate((high - mid) / 2 + mid)
val (newLow, newHigh) = if (lowValue._2 < highValue._2) {
- (low, mid + (high-low)/4)
+ (low, mid + (high-low)/4)
} else {
(mid - (high-low)/4, high)
}
if (newHigh - newLow > 1.0E-7) {
+ // :: is list prepend in Scala.
lowValue :: highValue :: binSearch(newLow, newHigh)
} else {
List(lowValue, highValue)
@@ -88,7 +84,7 @@ object RidgeRegression extends Logging {
// Actually compute the best lambda
val lambdas = binSearch(lambdaLow, lambdaHigh).sortBy(_._1)
- // Find the best parameter set
+ // Find the best parameter set by taking the lowest cverror.
val (lambdaOpt, cverror, wOpt) = lambdas.reduce((a, b) => if (a._2 < b._2) a else b)
logInfo("RidgeRegression: optimal lambda " + lambdaOpt)
diff --git a/ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala b/ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala
index 22a1e4613b..ff8640bb50 100644
--- a/ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala
+++ b/ml/src/main/scala/spark/ml/RidgeRegressionGenerator.scala
@@ -1,11 +1,11 @@
package spark.ml
-import spark._
-import spark.SparkContext._
+import spark.{RDD, SparkContext}
import org.apache.commons.math3.distribution.NormalDistribution
import org.jblas.DoubleMatrix
+
object RidgeRegressionGenerator {
// Helper methods to load and save data used for RidgeRegression
@@ -23,30 +23,34 @@ object RidgeRegressionGenerator {
data
}
- def saveData(data: RDD[(Double, Array[Double])], dir: String) {
+ private def saveData(data: RDD[(Double, Array[Double])], dir: String) {
val dataStr = data.map(x => x._1 + "," + x._2.mkString(" "))
dataStr.saveAsTextFile(dir)
}
def main(args: Array[String]) {
if (args.length != 2) {
- println("Usage: RidgeRegressionGenerator <master> <output_dir>")
+ println("Usage: RidgeRegressionGenerator " +
+ "<master> <output_dir> <num_examples> <num_features> <num_partitions>")
System.exit(1)
}
- org.jblas.util.Random.seed(42)
- val sc = new SparkContext(args(0), "RidgeRegressionGenerator")
- val nexamples = 1000
- val nfeatures = 100
+ val sparkMaster: String = args(0)
+ val outputPath: String = args(1)
+ val nexamples: Int = if (args.length > 2) args(2).toInt else 1000
+ val nfeatures: Int = if (args.length > 3) args(3).toInt else 100
+ val parts: Int = if (args.length > 4) args(4).toInt else 2
val eps = 10
- val parts = 2
+
+ org.jblas.util.Random.seed(42)
+ val sc = new SparkContext(sparkMaster, "RidgeRegressionGenerator")
// Random values distributed uniformly in [-0.5, 0.5]
val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
w.put(0, 0, 10)
w.put(1, 0, 10)
- val data = sc.parallelize(0 until parts, parts).flatMap { p =>
+ val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until parts, parts).flatMap { p =>
org.jblas.util.Random.seed(42 + p)
val examplesInPartition = nexamples / parts
@@ -56,15 +60,15 @@ object RidgeRegressionGenerator {
val rnd = new NormalDistribution(0, eps)
rnd.reseedRandomGenerator(42 + p)
- val normalValues = (0 until examplesInPartition).map(_ => rnd.sample())
- val yObs = new DoubleMatrix(examplesInPartition, 1, normalValues:_*).addi(y)
-
- (0 until examplesInPartition).map(i =>
+ val normalValues = Array.fill[Double](examplesInPartition)(rnd.sample())
+ val yObs = new DoubleMatrix(normalValues).addi(y)
+
+ Iterator.tabulate(examplesInPartition) { i =>
(yObs.get(i, 0), X.getRow(i).toArray)
- )
+ }
}
- saveData(data, args(1))
- System.exit(0)
+ saveData(data, outputPath)
+ sc.stop()
}
}
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 731671c23b..aa877ad4a7 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -221,7 +221,7 @@ object SparkBuild extends Build {
def bagelSettings = sharedSettings ++ Seq(name := "spark-bagel")
- def mlSettings = examplesSettings ++ Seq(
+ def mlSettings = sharedSettings ++ Seq(
name := "spark-ml",
libraryDependencies ++= Seq(
"org.jblas" % "jblas" % "1.2.3",