about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
author Reynold Xin <reynoldx@gmail.com> 2013-07-23 12:13:27 -0700
committer Reynold Xin <reynoldx@gmail.com> 2013-07-23 12:13:27 -0700
commit 87a9dd898ff51fd110799edae087d59f6b714211 (patch)
tree 72fd98fc0e78c7f584d2453da60ea5afbaa38fd8 /mllib
parent 401aac8b189aa6b72ad020ba894ca57b948c53a1 (diff)
download spark-87a9dd898ff51fd110799edae087d59f6b714211.tar.gz
spark-87a9dd898ff51fd110799edae087d59f6b714211.tar.bz2
spark-87a9dd898ff51fd110799edae087d59f6b714211.zip
Made RegressionModel serializable and added unit tests to make sure predict methods would work.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala 6
-rw-r--r-- mllib/src/main/scala/spark/mllib/optimization/Updater.scala 6
-rw-r--r-- mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala 6
-rw-r--r-- mllib/src/main/scala/spark/mllib/regression/Regression.scala 2
-rw-r--r-- mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala 5
-rw-r--r-- mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala 33
6 files changed, 42 insertions, 16 deletions
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
index 4c996c0903..185a2a24f6 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -39,9 +39,9 @@ object GradientDescent {
* @param miniBatchFraction - fraction of the input data set that should be used for
* one iteration of SGD. Default value 1.0.
*
- * @return weights - Column matrix containing weights for every feature.
- * @return stochasticLossHistory - Array containing the stochastic loss computed for
- * every iteration.
+ * @return A tuple containing two elements. The first element is a column matrix containing
+ * weights for every feature, and the second element is an array containing the stochastic
+ * loss computed for every iteration.
*/
def runMiniBatchSGD(
data: RDD[(Double, Array[Double])],
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
index b864fd4634..18cb5f3a95 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Updater.scala
@@ -23,13 +23,13 @@ abstract class Updater extends Serializable {
/**
* Compute an updated value for weights given the gradient, stepSize and iteration number.
*
- * @param weightsOld - Column matrix of size nx1 where n is the number of features.
+ * @param weightsOlds - Column matrix of size nx1 where n is the number of features.
* @param gradient - Column matrix of size nx1 where n is the number of features.
* @param stepSize - step size across iterations
* @param iter - Iteration number
*
- * @return weightsNew - Column matrix containing updated weights
- * @return reg_val - regularization value
+ * @return A tuple of 2 elements. The first element is a column matrix containing updated weights,
+ * and the second element is the regularization value.
*/
def compute(weightsOlds: DoubleMatrix, gradient: DoubleMatrix, stepSize: Double, iter: Int):
(DoubleMatrix, Double)
diff --git a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
index 711e205c39..4b22546017 100644
--- a/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/LogisticRegression.scala
@@ -36,8 +36,12 @@ class LogisticRegressionModel(
private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
override def predict(testData: spark.RDD[Array[Double]]) = {
+ // A small optimization to avoid serializing the entire model. Only the weightsMatrix
+ // and intercept are needed.
+ val localWeights = weightsMatrix
+ val localIntercept = intercept
testData.map { x =>
- val margin = new DoubleMatrix(1, x.length, x:_*).mmul(weightsMatrix).get(0) + this.intercept
+ val margin = new DoubleMatrix(1, x.length, x:_*).mmul(localWeights).get(0) + localIntercept
1.0/ (1.0 + math.exp(margin * -1))
}
}
diff --git a/mllib/src/main/scala/spark/mllib/regression/Regression.scala b/mllib/src/main/scala/spark/mllib/regression/Regression.scala
index 645204ddf3..b845ba1a89 100644
--- a/mllib/src/main/scala/spark/mllib/regression/Regression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/Regression.scala
@@ -19,7 +19,7 @@ package spark.mllib.regression
import spark.RDD
-trait RegressionModel {
+trait RegressionModel extends Serializable {
/**
* Predict values for the given data set using the model trained.
*
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index f724edd732..6ba141e8fb 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -37,8 +37,11 @@ class RidgeRegressionModel(
extends RegressionModel {
override def predict(testData: RDD[Array[Double]]): RDD[Double] = {
+ // A small optimization to avoid serializing the entire model.
+ val localIntercept = this.intercept
+ val localWeights = this.weights
testData.map { x =>
- (new DoubleMatrix(1, x.length, x:_*).mmul(this.weights)).get(0) + this.intercept
+ (new DoubleMatrix(1, x.length, x:_*).mmul(localWeights)).get(0) + localIntercept
}
}
diff --git a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
index 47191d9a5a..6a8098b59d 100644
--- a/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/regression/LogisticRegressionSuite.scala
@@ -23,7 +23,6 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
-import spark.SparkContext._
class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
@@ -51,15 +50,24 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
// y <- A + B*x + rLogis()
// y <- as.numeric(y > 0)
- val y = (0 until nPoints).map { i =>
+ val y: Seq[Double] = (0 until nPoints).map { i =>
val yVal = offset + scale * x1(i) + rLogis(i)
if (yVal > 0) 1.0 else 0.0
}
- val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i))))
+ val testData = (0 until nPoints).map(i => (y(i), Array(x1(i))))
testData
}
+ def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
+ val offPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
+ // A prediction is off if it is more than 0.5 away from the expected value.
+ math.abs(prediction - expected) > 0.5
+ }.size
+ // Fewer than 20% of the predictions may be off (i.e. more than 80% must be on).
+ assert(offPredictions < input.length / 5)
+ }
+
// Test if we can correctly learn A, B where Y = logistic(A + B*X)
test("logistic regression") {
val nPoints = 10000
@@ -70,14 +78,20 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val lr = new LogisticRegression().setStepSize(10.0)
- .setNumIterations(20)
+ val lr = new LogisticRegression().setStepSize(10.0).setNumIterations(20)
val model = lr.train(testRDD)
+ // Test the weights
val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData)
+
+ // Test prediction on Array.
+ validatePrediction(testData.map(row => model.predict(row._2)), testData)
}
test("logistic regression with initial weights") {
@@ -94,13 +108,18 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll {
testRDD.cache()
// Use half as many iterations as the previous test.
- val lr = new LogisticRegression().setStepSize(10.0)
- .setNumIterations(10)
+ val lr = new LogisticRegression().setStepSize(10.0).setNumIterations(10)
val model = lr.train(testRDD, initialWeights)
val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(testRDD.map(_._2)).collect(), testData)
+
+ // Test prediction on Array.
+ validatePrediction(testData.map(row => model.predict(row._2)), testData)
}
}