author     Shivaram Venkataraman <shivaram@eecs.berkeley.edu>   2013-08-06 17:23:22 -0700
committer  Shivaram Venkataraman <shivaram@eecs.berkeley.edu>   2013-08-06 17:23:22 -0700
commit     7db69d56f2d050842ecf6e465d2d4f1abf3314d7 (patch)
tree       9916df08d3ed49a78c557c46682b5511a801d16f /mllib
parent     7388e27668800b2c958b75e13d24f0d2baebe23d (diff)
download   spark-7db69d56f2d050842ecf6e465d2d4f1abf3314d7.tar.gz
           spark-7db69d56f2d050842ecf6e465d2d4f1abf3314d7.tar.bz2
           spark-7db69d56f2d050842ecf6e465d2d4f1abf3314d7.zip
Refactor GLM algorithms and add Java tests
This change adds Java examples and unit tests for all GLM algorithms to make sure the MLLib interface works from Java. Changes include:
- Introduce LabeledPoint and avoid using Doubles in train arguments
- Rename train to run in class methods
- Make the optimizer a member variable of GLM to make sure the builder pattern works
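The refactored API, as exercised by the new Scala and Java tests in this patch, follows the pattern below. This is a minimal Scala sketch; the training points and parameter values are illustrative, not taken from the tests:

```scala
import spark.SparkContext
import spark.mllib.classification.LogisticRegressionWithSGD
import spark.mllib.regression.LabeledPoint

val sc = new SparkContext("local", "GLMExample")

// Training data is now an RDD[LabeledPoint] rather than RDD[(Int, Array[Double])].
val training = sc.parallelize(Seq(
  LabeledPoint(1.0, Array(0.5)),
  LabeledPoint(0.0, Array(-0.5))
))

// The optimizer is a member of the GLM, so its parameters are set via the builder pattern.
val lr = new LogisticRegressionWithSGD()
lr.optimizer.setStepSize(10.0).setNumIterations(20)

// Class methods are named run instead of train.
val model = lr.run(training)
val prediction = model.predict(Array(0.5))   // predictions are now Double

sc.stop()
```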
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala      |  4
-rw-r--r--  mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala       | 34
-rw-r--r--  mllib/src/main/scala/spark/mllib/classification/SVM.scala                      | 25
-rw-r--r--  mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala            | 13
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala   | 60
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala                 | 32
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/Lasso.scala                        | 22
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala              |  7
-rw-r--r--  mllib/src/main/scala/spark/mllib/util/LassoDataGenerator.scala                 |  4
-rw-r--r--  mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala    |  5
-rw-r--r--  mllib/src/main/scala/spark/mllib/util/MLUtils.scala                            |  9
-rw-r--r--  mllib/src/main/scala/spark/mllib/util/RidgeRegressionDataGenerator.scala       |  7
-rw-r--r--  mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala                   |  5
-rw-r--r--  mllib/src/test/scala/spark/mllib/classification/JavaLogisticRegressionSuite.java | 98
-rw-r--r--  mllib/src/test/scala/spark/mllib/classification/JavaSVMSuite.java               | 98
-rw-r--r--  mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala  | 62
-rw-r--r--  mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala                 | 62
-rw-r--r--  mllib/src/test/scala/spark/mllib/regression/JavaLassoSuite.java                | 96
-rw-r--r--  mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala                   | 64
-rw-r--r--  mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala         |  2
20 files changed, 540 insertions(+), 169 deletions(-)
diff --git a/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala b/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala
index d6154b66ae..70fae8c15a 100644
--- a/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/ClassificationModel.scala
@@ -9,7 +9,7 @@ trait ClassificationModel extends Serializable {
* @param testData RDD representing data points to be predicted
* @return RDD[Int] where each entry contains the corresponding prediction
*/
- def predict(testData: RDD[Array[Double]]): RDD[Int]
+ def predict(testData: RDD[Array[Double]]): RDD[Double]
/**
* Predict values for a single data point using the model trained.
@@ -17,5 +17,5 @@ trait ClassificationModel extends Serializable {
* @param testData array representing a single data point
* @return Int prediction from the trained model
*/
- def predict(testData: Array[Double]): Int
+ def predict(testData: Array[Double]): Double
}
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index 0af99c616d..73949b0103 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -33,13 +33,13 @@ import org.jblas.DoubleMatrix
class LogisticRegressionModel(
override val weights: Array[Double],
override val intercept: Double)
- extends GeneralizedLinearModel[Int](weights, intercept)
+ extends GeneralizedLinearModel(weights, intercept)
with ClassificationModel with Serializable {
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
val margin = dataMatrix.mmul(weightMatrix).get(0) + intercept
- round(1.0/ (1.0 + math.exp(margin * -1))).toInt
+ round(1.0/ (1.0 + math.exp(margin * -1)))
}
}
@@ -49,12 +49,15 @@ class LogisticRegressionWithSGD (
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
- extends GeneralizedLinearAlgorithm[Int, LogisticRegressionModel]
- with GradientDescent with Serializable {
+ extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
+ with Serializable {
val gradient = new LogisticGradient()
val updater = new SimpleUpdater()
-
+ val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
/**
* Construct a LogisticRegression object with default parameters
*/
@@ -86,14 +89,14 @@ object LogisticRegressionWithSGD {
* the number of features in the data.
*/
def train(
- input: RDD[(Int, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
initialWeights: Array[Double])
: LogisticRegressionModel =
{
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).train(
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
input, initialWeights)
}
@@ -109,13 +112,13 @@ object LogisticRegressionWithSGD {
* @param miniBatchFraction Fraction of data to be used per iteration.
*/
def train(
- input: RDD[(Int, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double)
: LogisticRegressionModel =
{
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).train(
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
input)
}
@@ -131,7 +134,7 @@ object LogisticRegressionWithSGD {
* @return a LogisticRegressionModel which has the weights and offset from training.
*/
def train(
- input: RDD[(Int, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double)
: LogisticRegressionModel =
@@ -149,7 +152,7 @@ object LogisticRegressionWithSGD {
* @return a LogisticRegressionModel which has the weights and offset from training.
*/
def train(
- input: RDD[(Int, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int)
: LogisticRegressionModel =
{
@@ -157,15 +160,14 @@ object LogisticRegressionWithSGD {
}
def main(args: Array[String]) {
- if (args.length != 5) {
+ if (args.length != 4) {
println("Usage: LogisticRegression <master> <input_dir> <step_size> " +
- "<regularization_parameter> <niters>")
+ "<niters>")
System.exit(1)
}
val sc = new SparkContext(args(0), "LogisticRegression")
- val data = MLUtils.loadLabeledData(sc, args(1)).map(yx => (yx._1.toInt, yx._2))
- val model = LogisticRegressionWithSGD.train(
- data, args(4).toInt, args(2).toDouble, args(3).toDouble)
+ val data = MLUtils.loadLabeledData(sc, args(1))
+ val model = LogisticRegressionWithSGD.train(data, args(3).toInt, args(2).toDouble)
sc.stop()
}
diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
index caf9e3cb93..fa9d5a9471 100644
--- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
@@ -31,12 +31,12 @@ import org.jblas.DoubleMatrix
class SVMModel(
override val weights: Array[Double],
override val intercept: Double)
- extends GeneralizedLinearModel[Int](weights, intercept)
+ extends GeneralizedLinearModel(weights, intercept)
with ClassificationModel with Serializable {
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
- signum(dataMatrix.dot(weightMatrix) + intercept).toInt
+ signum(dataMatrix.dot(weightMatrix) + intercept)
}
}
@@ -46,11 +46,14 @@ class SVMWithSGD private (
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
- extends GeneralizedLinearAlgorithm[Int, SVMModel] with GradientDescent with Serializable {
+ extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {
val gradient = new HingeGradient()
val updater = new SquaredL2Updater()
-
+ val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
/**
* Construct a SVM object with default parameters
*/
@@ -81,7 +84,7 @@ object SVMWithSGD {
* the number of features in the data.
*/
def train(
- input: RDD[(Int, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double,
@@ -89,7 +92,7 @@ object SVMWithSGD {
initialWeights: Array[Double])
: SVMModel =
{
- new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).train(input,
+ new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
initialWeights)
}
@@ -105,14 +108,14 @@ object SVMWithSGD {
* @param miniBatchFraction Fraction of data to be used per iteration.
*/
def train(
- input: RDD[(Int, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double,
miniBatchFraction: Double)
: SVMModel =
{
- new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).train(input)
+ new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
}
/**
@@ -127,7 +130,7 @@ object SVMWithSGD {
* @return a SVMModel which has the weights and offset from training.
*/
def train(
- input: RDD[(Int, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double)
@@ -146,7 +149,7 @@ object SVMWithSGD {
* @return a SVMModel which has the weights and offset from training.
*/
def train(
- input: RDD[(Int, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int)
: SVMModel =
{
@@ -159,7 +162,7 @@ object SVMWithSGD {
System.exit(1)
}
val sc = new SparkContext(args(0), "SVM")
- val data = MLUtils.loadLabeledData(sc, args(1)).map(yx => (yx._1.toInt, yx._2))
+ val data = MLUtils.loadLabeledData(sc, args(1))
val model = SVMWithSGD.train(data, args(4).toInt, args(2).toDouble, args(3).toDouble)
sc.stop()
diff --git a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
index f7d09a2bd3..54793ca74d 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/GradientDescent.scala
@@ -24,15 +24,12 @@ import org.jblas.DoubleMatrix
import scala.collection.mutable.ArrayBuffer
-trait GradientDescent extends Optimizer {
+class GradientDescent(gradient: Gradient, updater: Updater) extends Optimizer {
- val gradient: Gradient
- val updater: Updater
-
- var stepSize: Double
- var numIterations: Int
- var regParam: Double
- var miniBatchFraction: Double
+ var stepSize: Double = 1.0
+ var numIterations: Int = 100
+ var regParam: Double = 0.0
+ var miniBatchFraction: Double = 1.0
/**
* Set the step size per-iteration of SGD. Default 1.0.
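With this change GradientDescent is a concrete class constructed from a Gradient and an Updater, with chained setters over the defaults shown above. A rough sketch of how the GLMs in this patch wire it up (the defaults shown are those from the diff; LogisticGradient and SimpleUpdater are assumed to live in spark.mllib.optimization, as the imports in LogisticRegression.scala suggest):

```scala
import spark.mllib.optimization.{GradientDescent, LogisticGradient, SimpleUpdater}

// Mirrors the optimizer member built inside LogisticRegressionWithSGD:
// the optimizer is constructed once and exposed so callers can tune it.
val optimizer = new GradientDescent(new LogisticGradient(), new SimpleUpdater())
  .setStepSize(1.0)
  .setNumIterations(100)
  .setRegParam(0.0)
  .setMiniBatchFraction(1.0)
```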
diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 7e80737773..03a7755541 100644
--- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -17,11 +17,8 @@
package spark.mllib.regression
-import spark.{Logging, RDD, SparkContext, SparkException}
+import spark.{Logging, RDD}
import spark.mllib.optimization._
-import spark.mllib.util.MLUtils
-
-import scala.math.round
import org.jblas.DoubleMatrix
@@ -30,18 +27,23 @@ import org.jblas.DoubleMatrix
* GeneralizedLinearAlgorithm. GLMs consist of a weight vector,
* an intercept.
*/
-abstract class GeneralizedLinearModel[T: ClassManifest](
- val weights: Array[Double],
- val intercept: Double)
+abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: Double)
extends Serializable {
// Create a column vector that can be used for predictions
private val weightsMatrix = new DoubleMatrix(weights.length, 1, weights:_*)
+ /**
+ * Predict the result given a data point and the weights learned.
+ *
+ * @param dataMatrix Row vector containing the features for this data point
+ * @param weightMatrix Column vector containing the weights of the model
+ * @param intercept Intercept of the model.
+ */
def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
- intercept: Double): T
+ intercept: Double): Double
- def predict(testData: spark.RDD[Array[Double]]): RDD[T] = {
+ def predict(testData: spark.RDD[Array[Double]]): RDD[Double] = {
// A small optimization to avoid serializing the entire model. Only the weightsMatrix
// and intercept is needed.
val localWeights = weightsMatrix
@@ -53,7 +55,7 @@ abstract class GeneralizedLinearModel[T: ClassManifest](
}
}
- def predict(testData: Array[Double]): T = {
+ def predict(testData: Array[Double]): Double = {
val dataMat = new DoubleMatrix(1, testData.length, testData:_*)
predictPoint(dataMat, weightsMatrix, intercept)
}
@@ -61,24 +63,18 @@ abstract class GeneralizedLinearModel[T: ClassManifest](
/**
* GeneralizedLinearAlgorithm abstracts out the training for all GLMs.
- * This class should be mixed in with an Optimizer to create a new GLM.
- *
- * NOTE(shivaram): This is an abstract class rather than a trait as we use
- * a view bound to convert labels to Double.
+ * This class should be extended with an Optimizer to create a new GLM.
*/
-abstract class GeneralizedLinearAlgorithm[T, M](implicit
- t: T => Double,
- tManifest: Manifest[T],
- methodEv: M <:< GeneralizedLinearModel[T])
+abstract class GeneralizedLinearAlgorithm[M](implicit
+ methodEv: M <:< GeneralizedLinearModel)
extends Logging with Serializable {
- // We need an optimizer mixin to solve the GLM
- self : Optimizer =>
-
- var addIntercept: Boolean
+ val optimizer: Optimizer
def createModel(weights: Array[Double], intercept: Double): M
+ var addIntercept: Boolean
+
/**
* Set if the algorithm should add an intercept. Default true.
*/
@@ -87,26 +83,22 @@ abstract class GeneralizedLinearAlgorithm[T, M](implicit
this
}
- def train(input: RDD[(T, Array[Double])]) : M = {
- val nfeatures: Int = input.first()._2.length
+ def run(input: RDD[LabeledPoint]) : M = {
+ val nfeatures: Int = input.first().features.length
val initialWeights = Array.fill(nfeatures)(1.0)
- train(input, initialWeights)
+ run(input, initialWeights)
}
- def train(
- input: RDD[(T, Array[Double])],
+ def run(
+ input: RDD[LabeledPoint],
initialWeights: Array[Double])
: M = {
// Add a extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
- input.map { case (y, features) =>
- (y.toDouble, Array(1.0, features:_*))
- }
+ input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*)))
} else {
- input.map { case (y, features) =>
- (y.toDouble, features)
- }
+ input.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
}
val initialWeightsWithIntercept = if (addIntercept) {
@@ -115,7 +107,7 @@ abstract class GeneralizedLinearAlgorithm[T, M](implicit
initialWeights
}
- val weights = optimize(data, initialWeightsWithIntercept)
+ val weights = optimizer.optimize(data, initialWeightsWithIntercept)
val intercept = weights(0)
val weightsScaled = weights.tail
diff --git a/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala b/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala
new file mode 100644
index 0000000000..592f0b5414
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/regression/LabeledPoint.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.regression
+
+/**
+ * Class that represents the features and labels of a data point.
+ *
+ * @param label Label for this data point.
+ * @param features List of features for this data point.
+ */
+case class LabeledPoint(val label: Double, val features: Array[Double]) {
+
+ /**
+ * Construct a labeled point using java.lang.Double.
+ */
+ def this(label: java.lang.Double, features: Array[Double]) = this(label.doubleValue(), features)
+}
diff --git a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
index f8b15033aa..989e5ded58 100644
--- a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
@@ -30,7 +30,7 @@ import org.jblas.DoubleMatrix
class LassoModel(
override val weights: Array[Double],
override val intercept: Double)
- extends GeneralizedLinearModel[Double](weights, intercept)
+ extends GeneralizedLinearModel(weights, intercept)
with RegressionModel with Serializable {
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
@@ -46,11 +46,15 @@ class LassoWithSGD (
var regParam: Double,
var miniBatchFraction: Double,
var addIntercept: Boolean)
- extends GeneralizedLinearAlgorithm[Double, LassoModel]
- with GradientDescent with Serializable {
+ extends GeneralizedLinearAlgorithm[LassoModel]
+ with Serializable {
val gradient = new SquaredGradient()
val updater = new L1Updater()
+ val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
/**
* Construct a Lasso object with default parameters
@@ -82,7 +86,7 @@ object LassoWithSGD {
* the number of features in the data.
*/
def train(
- input: RDD[(Double, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double,
@@ -90,7 +94,7 @@ object LassoWithSGD {
initialWeights: Array[Double])
: LassoModel =
{
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).train(input,
+ new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
initialWeights)
}
@@ -106,14 +110,14 @@ object LassoWithSGD {
* @param miniBatchFraction Fraction of data to be used per iteration.
*/
def train(
- input: RDD[(Double, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double,
miniBatchFraction: Double)
: LassoModel =
{
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).train(input)
+ new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
}
/**
@@ -128,7 +132,7 @@ object LassoWithSGD {
* @return a LassoModel which has the weights and offset from training.
*/
def train(
- input: RDD[(Double, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double)
@@ -147,7 +151,7 @@ object LassoWithSGD {
* @return a LassoModel which has the weights and offset from training.
*/
def train(
- input: RDD[(Double, Array[Double])],
+ input: RDD[LabeledPoint],
numIterations: Int)
: LassoModel =
{
diff --git a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
index 6ba141e8fb..de790dde51 100644
--- a/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/RidgeRegression.scala
@@ -71,7 +71,8 @@ class RidgeRegression private (var lambdaLow: Double, var lambdaHigh: Double)
this
}
- def train(input: RDD[(Double, Array[Double])]): RidgeRegressionModel = {
+ def train(inputLabeled: RDD[LabeledPoint]): RidgeRegressionModel = {
+ val input = inputLabeled.map(labeledPoint => (labeledPoint.label, labeledPoint.features))
val nfeatures: Int = input.take(1)(0)._2.length
val nexamples: Long = input.count()
@@ -183,7 +184,7 @@ object RidgeRegression {
* @param lambdaHigh upper bound used in binary search for lambda
*/
def train(
- input: RDD[(Double, Array[Double])],
+ input: RDD[LabeledPoint],
lambdaLow: Double,
lambdaHigh: Double)
: RidgeRegressionModel =
@@ -199,7 +200,7 @@ object RidgeRegression {
*
* @param input RDD of (response, array of features) pairs.
*/
- def train(input: RDD[(Double, Array[Double])]) : RidgeRegressionModel = {
+ def train(input: RDD[LabeledPoint]) : RidgeRegressionModel = {
train(input, 0.0, 100.0)
}
diff --git a/mllib/src/main/scala/spark/mllib/util/LassoDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/LassoDataGenerator.scala
index ef4f42a494..1f185c9de7 100644
--- a/mllib/src/main/scala/spark/mllib/util/LassoDataGenerator.scala
+++ b/mllib/src/main/scala/spark/mllib/util/LassoDataGenerator.scala
@@ -29,14 +29,14 @@ object LassoGenerator {
val trueWeights = new DoubleMatrix(1, nfeatures+1,
Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }:_*)
- val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx =>
+ val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx =>
val rnd = new Random(42 + idx)
val x = Array.fill[Double](nfeatures) {
rnd.nextDouble() * 2.0 - 1.0
}
val y = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1
- (y, x)
+ LabeledPoint(y, x)
}
MLUtils.saveLabeledData(data, outputPath)
diff --git a/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala
index 8d659cd97c..4fa19c3c23 100644
--- a/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala
+++ b/mllib/src/main/scala/spark/mllib/util/LogisticRegressionDataGenerator.scala
@@ -20,6 +20,7 @@ package spark.mllib.util
import scala.util.Random
import spark.{RDD, SparkContext}
+import spark.mllib.regression.LabeledPoint
object LogisticRegressionDataGenerator {
@@ -40,7 +41,7 @@ object LogisticRegressionDataGenerator {
nfeatures: Int,
eps: Double,
nparts: Int = 2,
- probOne: Double = 0.5): RDD[(Double, Array[Double])] = {
+ probOne: Double = 0.5): RDD[LabeledPoint] = {
val data = sc.parallelize(0 until nexamples, nparts).map { idx =>
val rnd = new Random(42 + idx)
@@ -48,7 +49,7 @@ object LogisticRegressionDataGenerator {
val x = Array.fill[Double](nfeatures) {
rnd.nextGaussian() + (y * eps)
}
- (y, x)
+ LabeledPoint(y, x)
}
data
}
diff --git a/mllib/src/main/scala/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/spark/mllib/util/MLUtils.scala
index b5e564df6d..e45eda2c99 100644
--- a/mllib/src/main/scala/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/spark/mllib/util/MLUtils.scala
@@ -21,6 +21,7 @@ import spark.{RDD, SparkContext}
import spark.SparkContext._
import org.jblas.DoubleMatrix
+import spark.mllib.regression.LabeledPoint
/**
* Helper methods to load and save data
@@ -36,17 +37,17 @@ object MLUtils {
* @return An RDD of tuples. For each tuple, the first element is the label, and the second
* element represents the feature values (an array of Double).
*/
- def loadLabeledData(sc: SparkContext, dir: String): RDD[(Double, Array[Double])] = {
+ def loadLabeledData(sc: SparkContext, dir: String): RDD[LabeledPoint] = {
sc.textFile(dir).map { line =>
val parts = line.split(",")
val label = parts(0).toDouble
val features = parts(1).trim().split(" ").map(_.toDouble)
- (label, features)
+ LabeledPoint(label, features)
}
}
- def saveLabeledData(data: RDD[(Double, Array[Double])], dir: String) {
- val dataStr = data.map(x => x._1 + "," + x._2.mkString(" "))
+ def saveLabeledData(data: RDD[LabeledPoint], dir: String) {
+ val dataStr = data.map(x => x.label + "," + x.features.mkString(" "))
dataStr.saveAsTextFile(dir)
}
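Given the parsing and serialization logic above, the on-disk labeled data format is one point per line: the label, a comma, then space-separated feature values. A small sketch of loading and saving (the paths are placeholders for illustration):

```scala
import spark.SparkContext
import spark.mllib.util.MLUtils

val sc = new SparkContext("local", "MLUtilsExample")

// Each input line is "label,feature1 feature2 ...", e.g.:
//   1.0,2.5 -0.3 4.1
//   0.0,1.2 0.7 -2.2
val points = MLUtils.loadLabeledData(sc, "/path/to/labeled_points.txt")   // RDD[LabeledPoint]
MLUtils.saveLabeledData(points, "/path/to/labeled_points_out")

sc.stop()
```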
diff --git a/mllib/src/main/scala/spark/mllib/util/RidgeRegressionDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/RidgeRegressionDataGenerator.scala
index c5b8a29942..c4d65c3f9a 100644
--- a/mllib/src/main/scala/spark/mllib/util/RidgeRegressionDataGenerator.scala
+++ b/mllib/src/main/scala/spark/mllib/util/RidgeRegressionDataGenerator.scala
@@ -22,6 +22,7 @@ import scala.util.Random
import org.jblas.DoubleMatrix
import spark.{RDD, SparkContext}
+import spark.mllib.regression.LabeledPoint
object RidgeRegressionDataGenerator {
@@ -41,14 +42,14 @@ object RidgeRegressionDataGenerator {
nexamples: Int,
nfeatures: Int,
eps: Double,
- nparts: Int = 2) : RDD[(Double, Array[Double])] = {
+ nparts: Int = 2) : RDD[LabeledPoint] = {
org.jblas.util.Random.seed(42)
// Random values distributed uniformly in [-0.5, 0.5]
val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
w.put(0, 0, 10)
w.put(1, 0, 10)
- val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nparts, nparts).flatMap { p =>
+ val data: RDD[LabeledPoint] = sc.parallelize(0 until nparts, nparts).flatMap { p =>
org.jblas.util.Random.seed(42 + p)
val examplesInPartition = nexamples / nparts
@@ -61,7 +62,7 @@ object RidgeRegressionDataGenerator {
val yObs = new DoubleMatrix(normalValues).addi(y)
Iterator.tabulate(examplesInPartition) { i =>
- (yObs.get(i, 0), X.getRow(i).toArray)
+ LabeledPoint(yObs.get(i, 0), X.getRow(i).toArray)
}
}
data
diff --git a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
index 00a54d9a70..a37f6eb3b3 100644
--- a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
+++ b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
@@ -9,6 +9,7 @@ import spark.{RDD, SparkContext}
import spark.mllib.util.MLUtils
import org.jblas.DoubleMatrix
+import spark.mllib.regression.LabeledPoint
object SVMGenerator {
@@ -32,14 +33,14 @@ object SVMGenerator {
val trueWeights = new DoubleMatrix(1, nfeatures+1,
Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }:_*)
- val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx =>
+ val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx =>
val rnd = new Random(42 + idx)
val x = Array.fill[Double](nfeatures) {
rnd.nextDouble() * 2.0 - 1.0
}
val y = signum((new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1)
- (y, x)
+ LabeledPoint(y, x)
}
MLUtils.saveLabeledData(data, outputPath)
diff --git a/mllib/src/test/scala/spark/mllib/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/scala/spark/mllib/classification/JavaLogisticRegressionSuite.java
new file mode 100644
index 0000000000..e0ebd45cd8
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/classification/JavaLogisticRegressionSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.classification;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+
+import spark.mllib.regression.LabeledPoint;
+
+public class JavaLogisticRegressionSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ int validatePrediction(List<LabeledPoint> validationData, LogisticRegressionModel model) {
+ int numAccurate = 0;
+ for (LabeledPoint point: validationData) {
+ Double prediction = model.predict(point.features());
+ if (prediction == point.label()) {
+ numAccurate++;
+ }
+ }
+ return numAccurate;
+ }
+
+ @Test
+ public void runLRUsingConstructor() {
+ int nPoints = 10000;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ List<LabeledPoint> validationData =
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
+
+ LogisticRegressionWithSGD lrImpl = new LogisticRegressionWithSGD();
+ lrImpl.optimizer().setStepSize(1.0)
+ .setRegParam(1.0)
+ .setNumIterations(100);
+ LogisticRegressionModel model = lrImpl.run(testRDD.rdd());
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
+
+ @Test
+ public void runLRUsingStaticMethods() {
+ int nPoints = 10000;
+ double A = 2.0;
+ double B = -1.5;
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 42), 2).cache();
+ List<LabeledPoint> validationData =
+ LogisticRegressionSuite.generateLogisticInputAsList(A, B, nPoints, 17);
+
+ LogisticRegressionModel model = LogisticRegressionWithSGD.train(
+ testRDD.rdd(), 100, 1.0, 1.0);
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
+
+}
diff --git a/mllib/src/test/scala/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/scala/spark/mllib/classification/JavaSVMSuite.java
new file mode 100644
index 0000000000..7881b3c38f
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/classification/JavaSVMSuite.java
@@ -0,0 +1,98 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.classification;
+
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+
+import spark.mllib.regression.LabeledPoint;
+
+public class JavaSVMSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaSVMSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ int validatePrediction(List<LabeledPoint> validationData, SVMModel model) {
+ int numAccurate = 0;
+ for (LabeledPoint point: validationData) {
+ Double prediction = model.predict(point.features());
+ if (prediction == point.label()) {
+ numAccurate++;
+ }
+ }
+ return numAccurate;
+ }
+
+ @Test
+ public void runSVMUsingConstructor() {
+ int nPoints = 10000;
+ double A = 2.0;
+ double[] weights = {-1.5, 1.0};
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
+ weights, nPoints, 42), 2).cache();
+ List<LabeledPoint> validationData =
+ SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
+
+ SVMWithSGD svmSGDImpl = new SVMWithSGD();
+ svmSGDImpl.optimizer().setStepSize(1.0)
+ .setRegParam(1.0)
+ .setNumIterations(100);
+ SVMModel model = svmSGDImpl.run(testRDD.rdd());
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
+
+ @Test
+ public void runSVMUsingStaticMethods() {
+ int nPoints = 10000;
+ double A = 2.0;
+ double[] weights = {-1.5, 1.0};
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(SVMSuite.generateSVMInputAsList(A,
+ weights, nPoints, 42), 2).cache();
+ List<LabeledPoint> validationData =
+ SVMSuite.generateSVMInputAsList(A, weights, nPoints, 17);
+
+ SVMModel model = SVMWithSGD.train(testRDD.rdd(), 100, 1.0, 1.0, 1.0);
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
+
+}
diff --git a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
index ee38486212..16bd2c6b38 100644
--- a/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -18,21 +18,23 @@
package spark.mllib.classification
import scala.util.Random
+import scala.collection.JavaConversions._
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
import spark.SparkContext
-import spark.mllib.optimization._
+import spark.mllib.regression._
+object LogisticRegressionSuite {
-class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
- val sc = new SparkContext("local", "test")
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
+ def generateLogisticInputAsList(
+ offset: Double,
+ scale: Double,
+ nPoints: Int,
+ seed: Int): java.util.List[LabeledPoint] = {
+ seqAsJavaList(generateLogisticInput(offset, scale, nPoints, seed))
}
// Generate input of the form Y = logistic(offset + scale*X)
@@ -40,7 +42,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Shoul
offset: Double,
scale: Double,
nPoints: Int,
- seed: Int): Seq[(Int, Array[Double])] = {
+ seed: Int): Seq[LabeledPoint] = {
val rnd = new Random(seed)
val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
@@ -58,13 +60,23 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Shoul
if (yVal > 0) 1 else 0
}
- val testData = (0 until nPoints).map(i => (y(i), Array(x1(i))))
+ val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i))))
testData
}
- def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) {
- val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
- (prediction != expected)
+}
+
+class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with ShouldMatchers {
+ val sc = new SparkContext("local", "test")
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) =>
+ (prediction != expected.label)
}.size
// At least 83% of the predictions should be on.
((input.length - numOffPredictions).toDouble / input.length) should be > 0.83
@@ -76,26 +88,27 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Shoul
val A = 2.0
val B = -1.5
- val testData = generateLogisticInput(A, B, nPoints, 42)
+ val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val lr = new LogisticRegressionWithSGD().setStepSize(10.0).setNumIterations(20)
+ val lr = new LogisticRegressionWithSGD()
+ lr.optimizer.setStepSize(10.0).setNumIterations(20)
- val model = lr.train(testRDD)
+ val model = lr.run(testRDD)
// Test the weights
val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
- val validationData = generateLogisticInput(A, B, nPoints, 17)
+ val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
- validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
// Test prediction on Array.
- validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
test("logistic regression with initial weights") {
@@ -103,7 +116,7 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Shoul
val A = 2.0
val B = -1.5
- val testData = generateLogisticInput(A, B, nPoints, 42)
+ val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
val initialB = -1.0
val initialWeights = Array(initialB)
@@ -112,20 +125,21 @@ class LogisticRegressionSuite extends FunSuite with BeforeAndAfterAll with Shoul
testRDD.cache()
// Use half as many iterations as the previous test.
- val lr = new LogisticRegressionWithSGD().setStepSize(10.0).setNumIterations(10)
+ val lr = new LogisticRegressionWithSGD()
+ lr.optimizer.setStepSize(10.0).setNumIterations(10)
- val model = lr.train(testRDD, initialWeights)
+ val model = lr.run(testRDD, initialWeights)
val weight0 = model.weights(0)
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
- val validationData = generateLogisticInput(A, B, nPoints, 17)
+ val validationData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
- validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
// Test prediction on Array.
- validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index 1eef9387e3..9e0970812d 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -19,21 +19,24 @@ package spark.mllib.classification
import scala.util.Random
import scala.math.signum
+import scala.collection.JavaConversions._
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
-import spark.mllib.optimization._
+import spark.mllib.regression._
import org.jblas.DoubleMatrix
-class SVMSuite extends FunSuite with BeforeAndAfterAll {
- val sc = new SparkContext("local", "test")
+object SVMSuite {
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
+ def generateSVMInputAsList(
+ intercept: Double,
+ weights: Array[Double],
+ nPoints: Int,
+ seed: Int): java.util.List[LabeledPoint] = {
+ seqAsJavaList(generateSVMInput(intercept, weights, nPoints, seed))
}
// Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
@@ -41,7 +44,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
intercept: Double,
weights: Array[Double],
nPoints: Int,
- seed: Int): Seq[(Int, Array[Double])] = {
+ seed: Int): Seq[LabeledPoint] = {
val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(
@@ -53,17 +56,28 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
0.1 * rnd.nextGaussian()
).toInt
}
- y.zip(x)
+ y.zip(x).map(p => LabeledPoint(p._1, p._2))
}
- def validatePrediction(predictions: Seq[Int], input: Seq[(Int, Array[Double])]) {
- val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
- (prediction != expected)
+}
+
+class SVMSuite extends FunSuite with BeforeAndAfterAll {
+ val sc = new SparkContext("local", "test")
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) =>
+ (prediction != expected.label)
}.size
// At least 80% of the predictions should be on.
assert(numOffPredictions < input.length / 5)
}
+
test("SVM using local random SGD") {
val nPoints = 10000
@@ -71,23 +85,24 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
val B = -1.5
val C = 1.0
- val testData = generateSVMInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val svm = new SVMWithSGD().setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+ val svm = new SVMWithSGD()
+ svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
- val model = svm.train(testRDD)
+ val model = svm.run(testRDD)
- val validationData = generateSVMInput(A, Array[Double](B,C), nPoints, 17)
+ val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17)
val validationRDD = sc.parallelize(validationData,2)
// Test prediction on RDD.
- validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
// Test prediction on Array.
- validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
test("SVM local random SGD with initial weights") {
@@ -97,7 +112,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
val B = -1.5
val C = 1.0
- val testData = generateSVMInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
val initialB = -1.0
val initialC = -1.0
@@ -106,17 +121,18 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val svm = new SVMWithSGD().setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+ val svm = new SVMWithSGD()
+ svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
- val model = svm.train(testRDD, initialWeights)
+ val model = svm.run(testRDD, initialWeights)
- val validationData = generateSVMInput(A, Array[Double](B,C), nPoints, 17)
+ val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17)
val validationRDD = sc.parallelize(validationData,2)
// Test prediction on RDD.
- validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
// Test prediction on Array.
- validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
diff --git a/mllib/src/test/scala/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/scala/spark/mllib/regression/JavaLassoSuite.java
new file mode 100644
index 0000000000..e26d7b385c
--- /dev/null
+++ b/mllib/src/test/scala/spark/mllib/regression/JavaLassoSuite.java
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.regression;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+
+public class JavaLassoSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaLassoSuite");
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ System.clearProperty("spark.driver.port");
+ }
+
+ int validatePrediction(List<LabeledPoint> validationData, LassoModel model) {
+ int numAccurate = 0;
+ for (LabeledPoint point: validationData) {
+ Double prediction = model.predict(point.features());
+ // A prediction is off if the prediction is more than 0.5 away from expected value.
+ if (Math.abs(prediction - point.label()) <= 0.5) {
+ numAccurate++;
+ }
+ }
+ return numAccurate;
+ }
+
+ @Test
+ public void runLassoUsingConstructor() {
+ int nPoints = 10000;
+ double A = 2.0;
+ double[] weights = {-1.5, 1.0e-2};
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(LassoSuite.generateLassoInputAsList(A,
+ weights, nPoints, 42), 2).cache();
+ List<LabeledPoint> validationData =
+ LassoSuite.generateLassoInputAsList(A, weights, nPoints, 17);
+
+ LassoWithSGD svmSGDImpl = new LassoWithSGD();
+ svmSGDImpl.optimizer().setStepSize(1.0)
+ .setRegParam(0.01)
+ .setNumIterations(20);
+ LassoModel model = svmSGDImpl.run(testRDD.rdd());
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
+
+ @Test
+ public void runLassoUsingStaticMethods() {
+ int nPoints = 10000;
+ double A = 2.0;
+ double[] weights = {-1.5, 1.0e-2};
+
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(LassoSuite.generateLassoInputAsList(A,
+ weights, nPoints, 42), 2).cache();
+ List<LabeledPoint> validationData =
+ LassoSuite.generateLassoInputAsList(A, weights, nPoints, 17);
+
+ LassoModel model = LassoWithSGD.train(testRDD.rdd(), 100, 1.0, 0.01, 1.0);
+
+ int numAccurate = validatePrediction(validationData, model);
+ Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
+ }
+
+}
diff --git a/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala
index ab1d07b879..b9ada2b1ec 100644
--- a/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala
@@ -17,31 +17,33 @@
package spark.mllib.regression
+import scala.collection.JavaConversions._
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
-import spark.mllib.optimization._
import org.jblas.DoubleMatrix
+object LassoSuite {
-class LassoSuite extends FunSuite with BeforeAndAfterAll {
- val sc = new SparkContext("local", "test")
-
- override def afterAll() {
- sc.stop()
- System.clearProperty("spark.driver.port")
+ def generateLassoInputAsList(
+ intercept: Double,
+ weights: Array[Double],
+ nPoints: Int,
+ seed: Int): java.util.List[LabeledPoint] = {
+ seqAsJavaList(generateLassoInput(intercept, weights, nPoints, seed))
}
+
// Generate noisy input of the form Y = x.dot(weights) + intercept + noise
def generateLassoInput(
intercept: Double,
weights: Array[Double],
nPoints: Int,
- seed: Int): Seq[(Double, Array[Double])] = {
+ seed: Int): Seq[LabeledPoint] = {
val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(
@@ -49,13 +51,23 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
val y = x.map(xi =>
(new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + 0.1 * rnd.nextGaussian()
)
- y zip x
+ y.zip(x).map(p => LabeledPoint(p._1, p._2))
}
- def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
- val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
+}
+
+class LassoSuite extends FunSuite with BeforeAndAfterAll {
+ val sc = new SparkContext("local", "test")
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ val numOffPredictions = predictions.zip(input).filter { case (prediction, expected) =>
// A prediction is off if the prediction is more than 0.5 away from expected value.
- math.abs(prediction - expected) > 0.5
+ math.abs(prediction - expected.label) > 0.5
}.size
// At least 80% of the predictions should be on.
assert(numOffPredictions < input.length / 5)
@@ -68,14 +80,15 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
val B = -1.5
val C = 1.0e-2
- val testData = generateLassoInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = LassoSuite.generateLassoInput(A, Array[Double](B,C), nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val ls = new LassoWithSGD().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
+ val ls = new LassoWithSGD()
+ ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
- val model = ls.train(testRDD)
+ val model = ls.run(testRDD)
val weight0 = model.weights(0)
val weight1 = model.weights(1)
@@ -83,14 +96,14 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
- val validationData = generateLassoInput(A, Array[Double](B,C), nPoints, 17)
- val validationRDD = sc.parallelize(validationData,2)
+ val validationData = LassoSuite.generateLassoInput(A, Array[Double](B,C), nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
- validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
// Test prediction on Array.
- validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
test("Lasso local random SGD with initial weights") {
@@ -100,7 +113,7 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
val B = -1.5
val C = 1.0e-2
- val testData = generateLassoInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = LassoSuite.generateLassoInput(A, Array[Double](B,C), nPoints, 42)
val initialB = -1.0
val initialC = -1.0
@@ -109,9 +122,10 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val ls = new LassoWithSGD().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
+ val ls = new LassoWithSGD()
+ ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
- val model = ls.train(testRDD, initialWeights)
+ val model = ls.run(testRDD, initialWeights)
val weight0 = model.weights(0)
val weight1 = model.weights(1)
@@ -119,13 +133,13 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
- val validationData = generateLassoInput(A, Array[Double](B,C), nPoints, 17)
+ val validationData = LassoSuite.generateLassoInput(A, Array[Double](B,C), nPoints, 17)
val validationRDD = sc.parallelize(validationData,2)
// Test prediction on RDD.
- validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
// Test prediction on Array.
- validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
}
diff --git a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala
index 3c588c6162..4c4900658f 100644
--- a/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -47,7 +47,7 @@ class RidgeRegressionSuite extends FunSuite with BeforeAndAfterAll {
val xMat = (0 until 20).map(i => Array(x1(i), x2(i))).toArray
val y = xMat.map(i => 3 + i(0) + i(1))
- val testData = (0 until 20).map(i => (y(i), xMat(i))).toArray
+ val testData = (0 until 20).map(i => LabeledPoint(y(i), xMat(i))).toArray
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()