author    Xinghao <pxinghao@gmail.com> 2013-07-24 15:32:50 -0700
committer Xinghao <pxinghao@gmail.com> 2013-07-24 15:32:50 -0700
commit    eef678703eed96544224209b1555618968b2eb3f (patch)
tree      0c4c3ccf6eeafb8e31deb6c466a454699c09729a /mllib/src
parent    e3d3e6f0ab34f0fe083ef9feb31b9e9fd257519f (diff)
Adding SVM and Lasso, moving LogisticRegression to classification from regression
Also, add regularization parameter to SGD
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/spark/mllib/classification/Classification.scala 21
-rw-r--r-- mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala (renamed from mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala) 28
-rw-r--r-- mllib/src/main/scala/spark/mllib/classification/LogisticRegressionGenerator.scala (renamed from mllib/src/main/scala/spark/mllib/regression/LogisticRegressionGenerator.scala) 2
-rw-r--r-- mllib/src/main/scala/spark/mllib/classification/SVM.scala 170
-rw-r--r-- mllib/src/main/scala/spark/mllib/classification/SVMGenerator.scala 45
-rw-r--r-- mllib/src/main/scala/spark/mllib/optimization/Gradient.scala 28
-rw-r--r-- mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala 10
-rw-r--r-- mllib/src/main/scala/spark/mllib/optimization/Updater.scala 31
-rw-r--r-- mllib/src/main/scala/spark/mllib/regression/Lasso.scala 167
-rw-r--r-- mllib/src/main/scala/spark/mllib/regression/LassoGenerator.scala 44
-rw-r--r-- mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala (renamed from mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala) 2
-rw-r--r-- mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala 61
-rw-r--r-- mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala 51
13 files changed, 642 insertions, 18 deletions
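For orientation, each of the updated top-level objects exposes static train methods taking an RDD of (label, features) pairs together with the iteration count, step size, the new regularization parameter, and the mini-batch fraction. A minimal driver sketch against the post-patch signatures (the object name and input path below are illustrative only, not part of this change):

    import spark.SparkContext
    import spark.mllib.classification.{LogisticRegression, SVM}
    import spark.mllib.regression.Lasso
    import spark.mllib.util.MLUtils

    // Hypothetical driver showing the (input, numIterations, stepSize, regParam, miniBatchFraction) signatures.
    object TrainExample {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "TrainExample")
        val data = MLUtils.loadLabeledData(sc, "/tmp/labeled.dat") // illustrative path

        val lrModel    = LogisticRegression.train(data, 100, 1.0, 1.0, 1.0)
        val svmModel   = SVM.train(data, 100, 1.0, 1.0, 1.0)
        val lassoModel = Lasso.train(data, 100, 1.0, 0.10, 1.0)

        println("SVM intercept: " + svmModel.intercept)
        sc.stop()
      }
    }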
diff --git a/mllib/src/main/scala/spark/mllib/classification/Classification.scala b/mllib/src/main/scala/spark/mllib/classification/Classification.scala
new file mode 100644
index 0000000000..7f1eb21079
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/classification/Classification.scala
@@ -0,0 +1,21 @@
+package spark.mllib.classification
+
+import spark.RDD
+
+trait ClassificationModel {
+ /**
+ * Predict values for the given data set using the model trained.
+ *
+ * @param testData RDD representing data points to be predicted
+ * @return RDD[Double] where each entry contains the corresponding prediction
+ */
+ def predict(testData: RDD[Array[Double]]): RDD[Double]
+
+ /**
+ * Predict values for a single data point using the model trained.
+ *
+ * @param testData array representing a single data point
+ * @return Double prediction from the trained model
+ */
+ def predict(testData: Array[Double]): Double
+}
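A model mixing in this trait can be scored either on a whole RDD of points or on a single point. A short sketch of the two overloads, assuming model is any trained ClassificationModel from this patch (e.g. an SVMModel) and points is an RDD[Array[Double]]:

    import spark.RDD
    import spark.mllib.classification.ClassificationModel

    // Illustrative helpers exercising both predict overloads of the trait.
    def scoreBatch(model: ClassificationModel, points: RDD[Array[Double]]): RDD[Double] =
      model.predict(points) // one prediction per input row

    def scoreOne(model: ClassificationModel, point: Array[Double]): Double =
      model.predict(point) // prediction for a single data point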
diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index e4db7bb9b7..f39c1ec52e 100644
--- a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -1,4 +1,4 @@
-package spark.mllib.regression
+package spark.mllib.classification
import spark.{Logging, RDD, SparkContext}
import spark.mllib.optimization._
@@ -13,7 +13,7 @@ import org.jblas.DoubleMatrix
class LogisticRegressionModel(
val weights: DoubleMatrix,
val intercept: Double,
- val losses: Array[Double]) extends RegressionModel {
+ val losses: Array[Double]) extends ClassificationModel {
override def predict(testData: spark.RDD[Array[Double]]) = {
testData.map { x =>
@@ -29,14 +29,14 @@ class LogisticRegressionModel(
}
}
-class LogisticRegression private (var stepSize: Double, var miniBatchFraction: Double,
+class LogisticRegression private (var stepSize: Double, var regParam: Double, var miniBatchFraction: Double,
var numIters: Int)
extends Logging {
/**
* Construct a LogisticRegression object with default parameters
*/
- def this() = this(1.0, 1.0, 100)
+ def this() = this(1.0, 1.0, 1.0, 100)
/**
* Set the step size per-iteration of SGD. Default 1.0.
@@ -69,7 +69,7 @@ class LogisticRegression private (var stepSize: Double, var miniBatchFraction: D
}
val (weights, losses) = GradientDescent.runMiniBatchSGD(
- data, new LogisticGradient(), new SimpleUpdater(), stepSize, numIters, miniBatchFraction)
+ data, new LogisticGradient(), new SimpleUpdater(), stepSize, numIters, regParam, miniBatchFraction)
val weightsScaled = weights.getRange(1, weights.length)
val intercept = weights.get(0)
@@ -96,16 +96,18 @@ object LogisticRegression {
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
* @param stepSize Step size to be used for each iteration of gradient descent.
+ * @param regParam Regularization parameter.
* @param miniBatchFraction Fraction of data to be used per iteration.
*/
def train(
input: RDD[(Double, Array[Double])],
numIterations: Int,
stepSize: Double,
+ regParam: Double,
miniBatchFraction: Double)
: LogisticRegressionModel =
{
- new LogisticRegression(stepSize, miniBatchFraction, numIterations).train(input)
+ new LogisticRegression(stepSize, regParam, miniBatchFraction, numIterations).train(input)
}
/**
@@ -115,16 +117,18 @@ object LogisticRegression {
*
* @param input RDD of (label, array of features) pairs.
* @param stepSize Step size to be used for each iteration of Gradient Descent.
+ * @param regParam Regularization parameter.
* @param numIterations Number of iterations of gradient descent to run.
* @return a LogisticRegressionModel which has the weights and offset from training.
*/
def train(
input: RDD[(Double, Array[Double])],
numIterations: Int,
- stepSize: Double)
+ stepSize: Double,
+ regParam: Double)
: LogisticRegressionModel =
{
- train(input, numIterations, stepSize, 1.0)
+ train(input, numIterations, stepSize, regParam, 1.0)
}
/**
@@ -141,17 +145,17 @@ object LogisticRegression {
numIterations: Int)
: LogisticRegressionModel =
{
- train(input, numIterations, 1.0, 1.0)
+ train(input, numIterations, 1.0, 1.0, 1.0)
}
def main(args: Array[String]) {
- if (args.length != 4) {
- println("Usage: LogisticRegression <master> <input_dir> <step_size> <niters>")
+ if (args.length != 5) {
+ println("Usage: LogisticRegression <master> <input_dir> <step_size> <regularization_parameter> <niters>")
System.exit(1)
}
val sc = new SparkContext(args(0), "LogisticRegression")
val data = MLUtils.loadLabeledData(sc, args(1))
- val model = LogisticRegression.train(data, args(3).toInt, args(2).toDouble)
+ val model = LogisticRegression.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
sc.stop()
}
diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegressionGenerator.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegressionGenerator.scala
index 6e7c023bac..cde1148adf 100644
--- a/mllib/src/main/scala/spark/mllib/regression/LogisticRegressionGenerator.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegressionGenerator.scala
@@ -1,4 +1,4 @@
-package spark.mllib.regression
+package spark.mllib.classification
import scala.util.Random
diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
new file mode 100644
index 0000000000..aceb903f1d
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
@@ -0,0 +1,170 @@
+package spark.mllib.classification
+
+import scala.math.signum
+import spark.{Logging, RDD, SparkContext}
+import spark.mllib.optimization._
+import spark.mllib.util.MLUtils
+
+import org.jblas.DoubleMatrix
+
+/**
+ * SVM using Stochastic Gradient Descent.
+ */
+class SVMModel(
+ val weights: DoubleMatrix,
+ val intercept: Double,
+ val losses: Array[Double]) extends ClassificationModel {
+
+ override def predict(testData: spark.RDD[Array[Double]]) = {
+ testData.map { x => {
+ println("Predicting " + x)
+ signum(new DoubleMatrix(1, x.length, x:_*).dot(this.weights) + this.intercept)
+ }
+ }
+ }
+
+ override def predict(testData: Array[Double]): Double = {
+ val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
+ signum(dataMat.dot(this.weights) + this.intercept)
+ }
+}
+
+class SVM private (var stepSize: Double, var regParam: Double, var miniBatchFraction: Double,
+ var numIters: Int)
+ extends Logging {
+
+ /**
+ * Construct an SVM object with default parameters
+ */
+ def this() = this(1.0, 1.0, 1.0, 100)
+
+ /**
+ * Set the step size per-iteration of SGD. Default 1.0.
+ */
+ def setStepSize(step: Double) = {
+ this.stepSize = step
+ this
+ }
+
+ /**
+ * Set the regularization parameter. Default 1.0.
+ */
+ def setRegParam(param: Double) = {
+ this.regParam = param
+ this
+ }
+
+ /**
+ * Set fraction of data to be used for each SGD iteration. Default 1.0.
+ */
+ def setMiniBatchFraction(fraction: Double) = {
+ this.miniBatchFraction = fraction
+ this
+ }
+
+ /**
+ * Set the number of iterations for SGD. Default 100.
+ */
+ def setNumIterations(iters: Int) = {
+ this.numIters = iters
+ this
+ }
+
+ def train(input: RDD[(Double, Array[Double])]): SVMModel = {
+ // Add an extra variable consisting of all 1.0's for the intercept.
+ val data = input.map { case (y, features) =>
+ (y, Array(1.0, features:_*))
+ }
+
+ val (weights, losses) = GradientDescent.runMiniBatchSGD(
+ data, new HingeGradient(), new SquaredL2Updater(), stepSize, numIters, regParam, miniBatchFraction)
+
+ val weightsScaled = weights.getRange(1, weights.length)
+ val intercept = weights.get(0)
+
+ val model = new SVMModel(weightsScaled, intercept, losses)
+
+ logInfo("Final model weights " + model.weights)
+ logInfo("Final model intercept " + model.intercept)
+ logInfo("Last 10 losses " + model.losses.takeRight(10).mkString(", "))
+ model
+ }
+}
+
+/**
+ * Top-level methods for calling SVM.
+ */
+object SVM {
+
+ /**
+ * Train an SVM model given an RDD of (label, features) pairs. We run a fixed number
+ * of iterations of gradient descent using the specified step size. Each iteration uses
+ * `miniBatchFraction` fraction of the data to calculate the gradient.
+ *
+ * @param input RDD of (label, array of features) pairs.
+ * @param numIterations Number of iterations of gradient descent to run.
+ * @param stepSize Step size to be used for each iteration of gradient descent.
+ * @param regParam Regularization parameter.
+ * @param miniBatchFraction Fraction of data to be used per iteration.
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double,
+ miniBatchFraction: Double)
+ : SVMModel =
+ {
+ new SVM(stepSize, regParam, miniBatchFraction, numIterations).train(input)
+ }
+
+ /**
+ * Train an SVM model given an RDD of (label, features) pairs. We run a fixed number
+ * of iterations of gradient descent using the specified step size. We use the entire data set to update
+ * the gradient in each iteration.
+ *
+ * @param input RDD of (label, array of features) pairs.
+ * @param stepSize Step size to be used for each iteration of Gradient Descent.
+ * @param regParam Regularization parameter.
+ * @param numIterations Number of iterations of gradient descent to run.
+ * @return an SVMModel which has the weights and offset from training.
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double)
+ : SVMModel =
+ {
+ train(input, numIterations, stepSize, regParam, 1.0)
+ }
+
+ /**
+ * Train an SVM model given an RDD of (label, features) pairs. We run a fixed number
+ * of iterations of gradient descent using a step size of 1.0. We use the entire data set to update
+ * the gradient in each iteration.
+ *
+ * @param input RDD of (label, array of features) pairs.
+ * @param numIterations Number of iterations of gradient descent to run.
+ * @return an SVMModel which has the weights and offset from training.
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ numIterations: Int)
+ : SVMModel =
+ {
+ train(input, numIterations, 1.0, 0.10, 1.0)
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ println("Usage: SVM <master> <input_dir> <step_size> <regularization_parameter> <niters>")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "SVM")
+ val data = MLUtils.loadLabeledData(sc, args(1))
+ val model = SVM.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+
+ sc.stop()
+ }
+}
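In terms of the objective, SVMModel.predict applies the usual linear decision rule, and the SGD loop (HingeGradient plus SquaredL2Updater) approximately minimizes hinge loss with a squared-L2 penalty. One reading of the code, with the intercept folded into the weight vector via the prepended 1.0 column and \lambda denoting regParam:

    \hat y(x) = \operatorname{sign}\bigl(w^\top x + b\bigr), \qquad
    \min_{w}\ \frac{1}{n} \sum_{i=1}^{n} \max\bigl(0,\ 1 - y_i\, w^\top \tilde x_i\bigr) + \lambda\, \lVert w \rVert_2^2,
    \qquad \tilde x_i = (1,\ x_i)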
diff --git a/mllib/src/main/scala/spark/mllib/classification/SVMGenerator.scala b/mllib/src/main/scala/spark/mllib/classification/SVMGenerator.scala
new file mode 100644
index 0000000000..a5e2837343
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/classification/SVMGenerator.scala
@@ -0,0 +1,45 @@
+package spark.mllib.classification
+
+import scala.util.Random
+import scala.math.signum
+
+import org.jblas.DoubleMatrix
+
+import spark.{RDD, SparkContext}
+import spark.mllib.util.MLUtils
+
+object SVMGenerator {
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ println("Usage: LassoGenerator " +
+ "<master> <output_dir> <num_examples> <num_features> <num_partitions>")
+ System.exit(1)
+ }
+
+ val sparkMaster: String = args(0)
+ val outputPath: String = args(1)
+ val nexamples: Int = if (args.length > 2) args(2).toInt else 1000
+ val nfeatures: Int = if (args.length > 3) args(3).toInt else 2
+ val parts: Int = if (args.length > 4) args(4).toInt else 2
+ val eps = 3
+
+ val sc = new SparkContext(sparkMaster, "SVMGenerator")
+
+ val globalRnd = new Random(94720)
+ val trueWeights = Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }
+
+ val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx =>
+ val rnd = new Random(42 + idx)
+
+ val x = Array.fill[Double](nfeatures) {
+ rnd.nextDouble() * 2.0 - 1.0
+ }
+ val y = signum(((1.0 +: x) zip trueWeights).map{wx => wx._1 * wx._2}.reduceLeft(_+_) + rnd.nextGaussian() * 0.1)
+ (y, x)
+ }
+
+ MLUtils.saveLabeledData(data, outputPath)
+ sc.stop()
+ }
+}
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
index 90b0999a5e..6ffc3b128b 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
@@ -31,3 +31,31 @@ class LogisticGradient extends Gradient {
(gradient, loss)
}
}
+
+
+class SquaredGradient extends Gradient {
+ override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
+ (DoubleMatrix, Double) = {
+ val diff: Double = data.dot(weights) - label
+
+ val loss = 0.5 * diff * diff
+ val gradient = data.mul(diff)
+
+ (gradient, loss)
+ }
+}
+
+
+class HingeGradient extends Gradient {
+ override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
+ (DoubleMatrix, Double) = {
+
+ val dotProduct = data.dot(weights)
+
+ if (1.0 > label * dotProduct)
+ (data.mul(-label), 1.0 - label * dotProduct)
+ else
+ (DoubleMatrix.zeros(1,weights.length), 0.0)
+ }
+}
+
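Written out, the two new gradients correspond to the squared loss and the hinge loss; HingeGradient returns a subgradient that is zero whenever the margin condition is already satisfied:

    L_{\mathrm{sq}}(w; x, y) = \tfrac{1}{2}\,(w^\top x - y)^2, \qquad
    \nabla_w L_{\mathrm{sq}} = (w^\top x - y)\, x

    L_{\mathrm{hinge}}(w; x, y) = \max\bigl(0,\ 1 - y\, w^\top x\bigr), \qquad
    \nabla_w L_{\mathrm{hinge}} =
    \begin{cases}
      -\,y\, x & \text{if } y\, w^\top x < 1 \\
      0 & \text{otherwise}
    \end{cases}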
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
index eff853f379..bd8489c386 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -19,6 +19,7 @@ object GradientDescent {
* @param updater - Updater object that will be used to update the model.
* @param stepSize - stepSize to be used during update.
* @param numIters - number of iterations that SGD should be run.
+ * @param regParam - regularization parameter
* @param miniBatchFraction - fraction of the input data set that should be used for
* one iteration of SGD. Default value 1.0.
*
@@ -31,6 +32,7 @@ object GradientDescent {
updater: Updater,
stepSize: Double,
numIters: Int,
+ regParam: Double,
miniBatchFraction: Double=1.0) : (DoubleMatrix, Array[Double]) = {
val lossHistory = new ArrayBuffer[Double](numIters)
@@ -51,10 +53,14 @@ object GradientDescent {
(grad, loss)
}.reduce((a, b) => (a._1.addi(b._1), a._2 + b._2))
- lossHistory.append(lossSum / miniBatchSize + reg_val)
- val update = updater.compute(weights, gradientSum.div(miniBatchSize), stepSize, i)
+ val update = updater.compute(weights, gradientSum.div(miniBatchSize), stepSize, i, regParam)
weights = update._1
reg_val = update._2
+ lossHistory.append(lossSum / miniBatchSize + reg_val)
+ /***
+ Xinghao: The loss recorded here is the sum of lossSum (computed with the weights before the updater is applied)
+ and reg_val (computed with the weights after the updater is applied)
+ ***/
}
(weights, lossHistory.toArray)
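Putting the loop together: each iteration samples a miniBatchFraction of the input, averages the per-point gradients, hands regParam through to the updater, and records a loss that mixes the pre-update data loss with the post-update regularization value (as the inline comment above notes). With \alpha = stepSize, \lambda = regParam, and B_t the sampled mini-batch:

    g_t = \frac{1}{|B_t|} \sum_{(y,\,x) \in B_t} \nabla \ell(w_t; x, y), \qquad
    (w_{t+1},\ r_{t+1}) = \mathrm{updater.compute}(w_t,\ g_t,\ \alpha,\ t,\ \lambda)

    \mathrm{lossHistory}(t) = \frac{1}{|B_t|} \sum_{(y,\,x) \in B_t} \ell(w_t; x, y) \;+\; r_{t+1}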
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
index ea80bfcbfd..64c54dfb0d 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
@@ -1,5 +1,6 @@
package spark.mllib.optimization
+import scala.math._
import org.jblas.DoubleMatrix
abstract class Updater extends Serializable {
@@ -10,18 +11,44 @@ abstract class Updater extends Serializable {
* @param gradient - Column matrix of size nx1 where n is the number of features.
* @param stepSize - step size across iterations
* @param iter - Iteration number
+ * @param regParam - Regularization parameter
*
* @return weightsNew - Column matrix containing updated weights
* @return reg_val - regularization value
*/
- def compute(weightsOlds: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int):
+ def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int, regParam: Double):
(DoubleMatrix, Double)
}
class SimpleUpdater extends Updater {
override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
- stepSize: Double, iter: Int): (DoubleMatrix, Double) = {
+ stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
val normGradient = gradient.mul(stepSize / math.sqrt(iter))
(weightsOld.sub(normGradient), 0)
}
}
+
+class L1Updater extends Updater {
+ override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
+ stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
+ val thisIterStepSize = stepSize / math.sqrt(iter)
+ val normGradient = gradient.mul(thisIterStepSize)
+ val newWeights = weightsOld.sub(normGradient)
+ (0 until newWeights.length).foreach(i => {
+ val wi = newWeights.get(i)
+ newWeights.put(i, signum(wi) * max(0.0, abs(wi) - regParam * thisIterStepSize))
+ })
+ (newWeights, newWeights.norm1 * regParam)
+ }
+}
+
+class SquaredL2Updater extends Updater {
+ override def compute(weightsOld: DoubleMatrix, gradient: DoubleMatrix,
+ stepSize: Double, iter: Int, regParam: Double): (DoubleMatrix, Double) = {
+ val thisIterStepSize = stepSize / math.sqrt(iter)
+ val normGradient = gradient.mul(thisIterStepSize)
+ val newWeights = weightsOld.sub(normGradient).div(2.0 * thisIterStepSize * regParam + 1.0)
+ (newWeights, pow(newWeights.norm2,2.0) * regParam)
+ }
+}
+
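Concretely, with \eta_t = stepSize / \sqrt{t} as the per-iteration step and \lambda = regParam, the two new updaters implement a soft-thresholding (proximal L1) step and a squared-L2 shrinkage step:

    \text{L1Updater:}\quad w' = w_{\mathrm{old}} - \eta_t\, g, \qquad
    w'_i \leftarrow \operatorname{sign}(w'_i)\, \max\bigl(0,\ |w'_i| - \lambda\, \eta_t\bigr), \qquad
    \mathrm{reg\_val} = \lambda\, \lVert w' \rVert_1

    \text{SquaredL2Updater:}\quad w' = \frac{w_{\mathrm{old}} - \eta_t\, g}{1 + 2\, \lambda\, \eta_t}, \qquad
    \mathrm{reg\_val} = \lambda\, \lVert w' \rVert_2^2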
diff --git a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
new file mode 100644
index 0000000000..de410711a2
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
@@ -0,0 +1,167 @@
+package spark.mllib.regression
+
+import spark.{Logging, RDD, SparkContext}
+import spark.mllib.optimization._
+import spark.mllib.util.MLUtils
+
+import org.jblas.DoubleMatrix
+
+/**
+ * Lasso using Stochastic Gradient Descent.
+ */
+class LassoModel(
+ val weights: DoubleMatrix,
+ val intercept: Double,
+ val losses: Array[Double]) extends RegressionModel {
+
+ override def predict(testData: spark.RDD[Array[Double]]) = {
+ testData.map { x =>
+ new DoubleMatrix(1, x.length, x:_*).dot(this.weights) + this.intercept
+ }
+ }
+
+ override def predict(testData: Array[Double]): Double = {
+ val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
+ dataMat.dot(this.weights) + this.intercept
+ }
+}
+
+class Lasso private (var stepSize: Double, var regParam: Double, var miniBatchFraction: Double,
+ var numIters: Int)
+ extends Logging {
+
+ /**
+ * Construct a Lasso object with default parameters
+ */
+ def this() = this(1.0, 1.0, 1.0, 100)
+
+ /**
+ * Set the step size per-iteration of SGD. Default 1.0.
+ */
+ def setStepSize(step: Double) = {
+ this.stepSize = step
+ this
+ }
+
+ /**
+ * Set the regularization parameter. Default 1.0.
+ */
+ def setRegParam(param: Double) = {
+ this.regParam = param
+ this
+ }
+
+ /**
+ * Set fraction of data to be used for each SGD iteration. Default 1.0.
+ */
+ def setMiniBatchFraction(fraction: Double) = {
+ this.miniBatchFraction = fraction
+ this
+ }
+
+ /**
+ * Set the number of iterations for SGD. Default 100.
+ */
+ def setNumIterations(iters: Int) = {
+ this.numIters = iters
+ this
+ }
+
+ def train(input: RDD[(Double, Array[Double])]): LassoModel = {
+ // Add an extra variable consisting of all 1.0's for the intercept.
+ val data = input.map { case (y, features) =>
+ (y, Array(1.0, features:_*))
+ }
+
+ val (weights, losses) = GradientDescent.runMiniBatchSGD(
+ data, new SquaredGradient(), new L1Updater(), stepSize, numIters, regParam, miniBatchFraction)
+
+ val weightsScaled = weights.getRange(1, weights.length)
+ val intercept = weights.get(0)
+
+ val model = new LassoModel(weightsScaled, intercept, losses)
+
+ logInfo("Final model weights " + model.weights)
+ logInfo("Final model intercept " + model.intercept)
+ logInfo("Last 10 losses " + model.losses.takeRight(10).mkString(", "))
+ model
+ }
+}
+
+/**
+ * Top-level methods for calling Lasso.
+ */
+object Lasso {
+
+ /**
+ * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
+ * of iterations of gradient descent using the specified step size. Each iteration uses
+ * `miniBatchFraction` fraction of the data to calculate the gradient.
+ *
+ * @param input RDD of (label, array of features) pairs.
+ * @param numIterations Number of iterations of gradient descent to run.
+ * @param stepSize Step size to be used for each iteration of gradient descent.
+ * @param regParam Regularization parameter.
+ * @param miniBatchFraction Fraction of data to be used per iteration.
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double,
+ miniBatchFraction: Double)
+ : LassoModel =
+ {
+ new Lasso(stepSize, regParam, miniBatchFraction, numIterations).train(input)
+ }
+
+ /**
+ * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
+ * of iterations of gradient descent using the specified step size. We use the entire data set to update
+ * the gradient in each iteration.
+ *
+ * @param input RDD of (label, array of features) pairs.
+ * @param stepSize Step size to be used for each iteration of Gradient Descent.
+ * @param regParam Regularization parameter.
+ * @param numIterations Number of iterations of gradient descent to run.
+ * @return a LassoModel which has the weights and offset from training.
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ numIterations: Int,
+ stepSize: Double,
+ regParam: Double)
+ : LassoModel =
+ {
+ train(input, numIterations, stepSize, regParam, 1.0)
+ }
+
+ /**
+ * Train a Lasso model given an RDD of (label, features) pairs. We run a fixed number
+ * of iterations of gradient descent using a step size of 1.0. We use the entire data set to update
+ * the gradient in each iteration.
+ *
+ * @param input RDD of (label, array of features) pairs.
+ * @param numIterations Number of iterations of gradient descent to run.
+ * @return a LassoModel which has the weights and offset from training.
+ */
+ def train(
+ input: RDD[(Double, Array[Double])],
+ numIterations: Int)
+ : LassoModel =
+ {
+ train(input, numIterations, 1.0, 0.10, 1.0)
+ }
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ println("Usage: Lasso <master> <input_dir> <step_size> <regularization_parameter> <niters>")
+ System.exit(1)
+ }
+ val sc = new SparkContext(args(0), "Lasso")
+ val data = MLUtils.loadLabeledData(sc, args(1))
+ val model = Lasso.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+
+ sc.stop()
+ }
+}
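Besides the static train helpers, Lasso (like SVM above) can be configured fluently through its setters, which is how the new test suites drive it. A short sketch, assuming trainRDD is an already-cached RDD[(Double, Array[Double])]:

    import spark.mllib.regression.Lasso

    // Builder-style configuration mirroring LassoSuite.
    val lasso = new Lasso().setStepSize(1.0)
                           .setRegParam(0.01)
                           .setNumIterations(20)
    val model = lasso.train(trainRDD)

    // Inspect the fitted coefficients and intercept (jblas DoubleMatrix accessors).
    println("weight0   = " + model.weights.get(0))
    println("intercept = " + model.intercept)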
diff --git a/mllib/src/main/scala/spark/mllib/regression/LassoGenerator.scala b/mllib/src/main/scala/spark/mllib/regression/LassoGenerator.scala
new file mode 100644
index 0000000000..d2d3bb33c7
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/regression/LassoGenerator.scala
@@ -0,0 +1,44 @@
+package spark.mllib.regression
+
+import scala.util.Random
+
+import org.jblas.DoubleMatrix
+
+import spark.{RDD, SparkContext}
+import spark.mllib.util.MLUtils
+
+object LassoGenerator {
+
+ def main(args: Array[String]) {
+ if (args.length != 5) {
+ println("Usage: LassoGenerator " +
+ "<master> <output_dir> <num_examples> <num_features> <num_partitions>")
+ System.exit(1)
+ }
+
+ val sparkMaster: String = args(0)
+ val outputPath: String = args(1)
+ val nexamples: Int = if (args.length > 2) args(2).toInt else 1000
+ val nfeatures: Int = if (args.length > 3) args(3).toInt else 2
+ val parts: Int = if (args.length > 4) args(4).toInt else 2
+ val eps = 3
+
+ val sc = new SparkContext(sparkMaster, "LassoGenerator")
+
+ val globalRnd = new Random(94720)
+ val trueWeights = Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }
+
+ val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx =>
+ val rnd = new Random(42 + idx)
+
+ val x = Array.fill[Double](nfeatures) {
+ rnd.nextDouble() * 2.0 - 1.0
+ }
+ val y = ((1.0 +: x) zip trueWeights).map{wx => wx._1 * wx._2}.reduceLeft(_+_) + rnd.nextGaussian() * 0.1
+ (y, x)
+ }
+
+ MLUtils.saveLabeledData(data, outputPath)
+ sc.stop()
+ }
+}
diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
index 04d3400cb4..13612e9a4a 100644
--- a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -1,4 +1,4 @@
-package spark.mllib.regression
+package spark.mllib.classification
import scala.util.Random
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
new file mode 100644
index 0000000000..e3a6681ab2
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -0,0 +1,61 @@
+package spark.mllib.classification
+
+import scala.util.Random
+import scala.math.signum
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import spark.SparkContext
+import spark.SparkContext._
+
+import java.io._
+
+class SVMSuite extends FunSuite with BeforeAndAfterAll {
+ val sc = new SparkContext("local", "test")
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ test("SVM") {
+ val nPoints = 10000
+ val rnd = new Random(42)
+
+ val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
+ val x2 = Array.fill[Double](nPoints)(rnd.nextGaussian())
+
+ val A = 2.0
+ val B = -1.5
+ val C = 1.0
+
+ val y = (0 until nPoints).map { i =>
+ signum(A + B * x1(i) + C * x2(i) + 0.0*rnd.nextGaussian())
+ }
+
+ val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i),x2(i)))).toArray
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ val writer_data = new PrintWriter(new File("svmtest.dat"))
+ testData.foreach(yx => {
+ writer_data.write(yx._1 + "")
+ yx._2.foreach(xi => writer_data.write("\t" + xi))
+ writer_data.write("\n")})
+ writer_data.close()
+
+ val svm = new SVM().setStepSize(1.0)
+ .setRegParam(1.0)
+ .setNumIterations(100)
+
+ val model = svm.train(testRDD)
+
+ val yPredict = (0 until nPoints).map(i => model.predict(Array(x1(i),x2(i))))
+
+ val accuracy = ((y zip yPredict).map(yy => if (yy._1==yy._2) 1 else 0).reduceLeft(_+_).toDouble / nPoints.toDouble)
+
+ assert(accuracy >= 0.90, "Accuracy (" + accuracy + ") too low")
+ }
+}
diff --git a/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala
new file mode 100644
index 0000000000..90fedb3e84
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala
@@ -0,0 +1,51 @@
+package spark.mllib.regression
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import spark.SparkContext
+import spark.SparkContext._
+
+
+class LassoSuite extends FunSuite with BeforeAndAfterAll {
+ val sc = new SparkContext("local", "test")
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ test("Lasso") {
+ val nPoints = 10000
+ val rnd = new Random(42)
+
+ val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
+ val x2 = Array.fill[Double](nPoints)(rnd.nextGaussian())
+
+ val A = 2.0
+ val B = -1.5
+ val C = 1.0e-2
+
+ val y = (0 until nPoints).map { i =>
+ A + B * x1(i) + C * x2(i) + 0.1*rnd.nextGaussian()
+ }
+
+ val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i),x2(i)))).toArray
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+ val ls = new Lasso().setStepSize(1.0)
+ .setRegParam(0.01)
+ .setNumIterations(20)
+
+ val model = ls.train(testRDD)
+
+ val weight0 = model.weights.get(0)
+ val weight1 = model.weights.get(1)
+ assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
+ assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
+ assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+ }
+}