diff options
author | Reynold Xin <reynoldx@gmail.com> | 2013-07-23 12:13:27 -0700 |
---|---|---|
committer | Reynold Xin <reynoldx@gmail.com> | 2013-07-23 12:13:27 -0700 |
commit | 87a9dd898ff51fd110799edae087d59f6b714211 (patch) | |
tree | 72fd98fc0e78c7f584d2453da60ea5afbaa38fd8 | |
parent | 401aac8b189aa6b72ad020ba894ca57b948c53a1 (diff) | |
download | spark-87a9dd898ff51fd110799edae087d59f6b714211.tar.gz spark-87a9dd898ff51fd110799edae087d59f6b714211.tar.bz2 spark-87a9dd898ff51fd110799edae087d59f6b714211.zip |
Made RegressionModel serializable and added unit tests to make sure predict methods would work.
6 files changed, 42 insertions, 16 deletions
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala index 4c996c0903..185a2a24f6 100644 --- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala +++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala @@ -39,9 +39,9 @@ object GradientDescent { * @param miniBatchFraction - fraction of the input data set that should be used for * one iteration of SGD. Default value 1.0. * - * @return weights - Column matrix containing weights for every feature. - * @return stochasticLossHistory - Array containing the stochastic loss computed for - * every iteration. + * @return A tuple containing two elements. The first element is a column matrix containing + * weights for every feature, and the second element is an array containing the stochastic + * loss computed for every iteration. */ def runMiniBatchSGD( data: RDD[(Double, Array[Double])], diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala index b864fd4634..18cb5f3a95 100644 --- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala +++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala @@ -23,13 +23,13 @@ abstract class Updater extends Serializable { /** * Compute an updated value for weights given the gradient, stepSize and iteration number. * - * @param weightsOld - Column matrix of size nx1 where n is the number of features. + * @param weightsOlds - Column matrix of size nx1 where n is the number of features. * @param gradient - Column matrix of size nx1 where n is the number of features. * @param stepSize - step size across iterations * @param iter - Iteration number * - * @return weightsNew - Column matrix containing updated weights - * @return reg_val - regularization value + * @return A tuple of 2 elements. The first element is a column matrix containing updated weights, + * and the second element is the regularization value. */ def compute(weightsOlds: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int): (DoubleMatrix, Double) diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala index 711e205c39..4b22546017 100644 --- a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala @@ -36,8 +36,12 @@ class LogisticRegressionModel( private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*) override def predict(testData: spark.RDD[Array[Double]]) = { + // A small optimization to avoid serializing the entire model. Only the weightsMatrix + // and intercept is needed. + val localWeights = weightsMatrix + val localIntercept = intercept testData.map { x => - val margin = new DoubleMatrix(1, x.length, x:_*).mmul(weightsMatrix).get(0) + this.intercept + val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept 1.0/ (1.0 + math.exp(margin * -1)) } } diff --git a/mllib/src/main/scala/spark/mllib/regression/Regression.scala b/mllib/src/main/scala/spark/mllib/regression/Regression.scala index 645204ddf3..b845ba1a89 100644 --- a/mllib/src/main/scala/spark/mllib/regression/Regression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/Regression.scala @@ -19,7 +19,7 @@ package spark.mllib.regression import spark.RDD -trait RegressionModel { +trait RegressionModel extends Serializable { /** * Predict values for the given data set using the model trained. * diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala index f724edd732..6ba141e8fb 100644 --- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala +++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala @@ -37,8 +37,11 @@ class RidgeRegressionModel( extends RegressionModel { override def predict(testData: RDD[Array[Double]]): RDD[Double] = { + // A small optimization to avoid serializing the entire model. + val localIntercept = this.intercept + val localWeights = this.weights testData.map { x => - (new DoubleMatrix(1, x.length, x:_*).mmul(this.weights)).get(0) + this.intercept + (new DoubleMatrix(1, x.length, x:_*).mmul(localWeights)).get(0) + localIntercept } } diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala index 47191d9a5a..6a8098b59d 100644 --- a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala @@ -23,7 +23,6 @@ import org.scalatest.BeforeAndAfterAll import org.scalatest.FunSuite import spark.SparkContext -import spark.SparkContext._ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { @@ -51,15 +50,24 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { // y <- A + B*x + rLogis() // y <- as.numeric(y > 0) - val y = (0 until nPoints).map { i => + val y: Seq[Double] = (0 until nPoints).map { i => val yVal = offset + scale * x1(i) + rLogis(i) if (yVal > 0) 1.0 else 0.0 } - val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i)))) + val testData = (0 until nPoints).map(i => (y(i), Array(x1(i)))) testData } + def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) { + val offPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) => + // A prediction is off if the prediction is more than 0.5 away from expected value. + math.abs(prediction - expected) > 0.5 + }.size + // At least 80% of the predictions should be on. + assert(offPredictions < input.length / 5) + } + // Test if we can correctly learn A, B where Y = logistic(A + B*X) test("logistic regression") { val nPoints = 10000 @@ -70,14 +78,20 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val lr = new LogisticRegression().setStepSize(10.0) - .setNumIterations(20) + val lr = new LogisticRegression().setStepSize(10.0).setNumIterations(20) val model = lr.train(testRDD) + // Test the weights val weight0 = model.weights(0) assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + + // Test prediction on RDD. + validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData) + + // Test prediction on Array. + validatePrediction(testData.map(row => model.predict(row._2)), testData) } test("logistic regression with initial weights") { @@ -94,13 +108,18 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll { testRDD.cache() // Use half as many iterations as the previous test. - val lr = new LogisticRegression().setStepSize(10.0) - .setNumIterations(10) + val lr = new LogisticRegression().setStepSize(10.0).setNumIterations(10) val model = lr.train(testRDD, initialWeights) val weight0 = model.weights(0) assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]") assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]") + + // Test prediction on RDD. + validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData) + + // Test prediction on Array. + validatePrediction(testData.map(row => model.predict(row._2)), testData) } } |