author    Evan Sparks <evan.sparks@gmail.com>    2013-08-29 22:13:15 -0700
committer Evan Sparks <evan.sparks@gmail.com>    2013-08-29 22:13:15 -0700
commit    852d81078743c1bcf81031f0a55dec7889281d77 (patch)
tree      a26fb9fed3c65692d18f22b096e77fe8decf828a /mllib
parent    ca716209507e4870fbbf55d96ecd57c218d547ac (diff)
parent    dc06b528790c69b2e6de85cba84266fea81dd4f4 (diff)
Merge pull request #819 from shivaram/sgd-cleanup
Change SVM to use {0,1} labels
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala | 31
-rw-r--r-- mllib/src/main/scala/spark/mllib/classification/SVM.scala | 33
-rw-r--r-- mllib/src/main/scala/spark/mllib/optimization/Gradient.scala | 16
-rw-r--r-- mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala | 23
-rw-r--r-- mllib/src/main/scala/spark/mllib/regression/Lasso.scala | 9
-rw-r--r-- mllib/src/main/scala/spark/mllib/util/DataValidators.scala | 42
-rw-r--r-- mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala | 9
-rw-r--r-- mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala | 45
8 files changed, 160 insertions(+), 48 deletions(-)
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index 30ee0ab0ff..482e4a6745 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -17,12 +17,13 @@
package spark.mllib.classification
+import scala.math.round
+
import spark.{Logging, RDD, SparkContext}
import spark.mllib.optimization._
import spark.mllib.regression._
import spark.mllib.util.MLUtils
-
-import scala.math.round
+import spark.mllib.util.DataValidators
import org.jblas.DoubleMatrix
@@ -47,26 +48,29 @@ class LogisticRegressionModel(
/**
* Train a classification model for Logistic Regression using Stochastic Gradient Descent.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*/
class LogisticRegressionWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LogisticRegressionModel]
with Serializable {
val gradient = new LogisticGradient()
val updater = new SimpleUpdater()
- val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
- .setNumIterations(numIterations)
- .setRegParam(regParam)
- .setMiniBatchFraction(miniBatchFraction)
+ override val optimizer = new GradientDescent(gradient, updater)
+ .setStepSize(stepSize)
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
+ override val validators = List(DataValidators.classificationLabels)
+
/**
* Construct a LogisticRegression object with default parameters
*/
- def this() = this(1.0, 100, 0.0, 1.0, true)
+ def this() = this(1.0, 100, 0.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new LogisticRegressionModel(weights, intercept)
@@ -75,6 +79,7 @@ class LogisticRegressionWithSGD private (
/**
* Top-level methods for calling Logistic Regression.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*/
object LogisticRegressionWithSGD {
// NOTE(shivaram): We use multiple train methods instead of default arguments to support
@@ -85,6 +90,7 @@ object LogisticRegressionWithSGD {
* number of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
* gradient descent are initialized using the initial weights provided.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -101,7 +107,7 @@ object LogisticRegressionWithSGD {
initialWeights: Array[Double])
: LogisticRegressionModel =
{
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input, initialWeights)
}
@@ -109,6 +115,7 @@ object LogisticRegressionWithSGD {
* Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
* number of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -123,7 +130,7 @@ object LogisticRegressionWithSGD {
miniBatchFraction: Double)
: LogisticRegressionModel =
{
- new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction, true).run(
+ new LogisticRegressionWithSGD(stepSize, numIterations, 0.0, miniBatchFraction).run(
input)
}
@@ -131,6 +138,7 @@ object LogisticRegressionWithSGD {
* Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
* number of iterations of gradient descent using the specified step size. We use the entire data
* set to update the gradient in each iteration.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param stepSize Step size to be used for each iteration of Gradient Descent.
@@ -151,6 +159,7 @@ object LogisticRegressionWithSGD {
* Train a logistic regression model given an RDD of (label, features) pairs. We run a fixed
* number of iterations of gradient descent using a step size of 1.0. We use the entire data set
* to update the gradient in each iteration.
+ * NOTE: Labels used in Logistic Regression should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
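
With this change, callers no longer pass addIntercept through the constructor; that flag now lives on the GeneralizedLinearAlgorithm base class. A minimal usage sketch against this 0.8-era API (data values are illustrative; assumes a live SparkContext sc):

    import spark.mllib.regression.LabeledPoint
    import spark.mllib.classification.LogisticRegressionWithSGD

    // Labels must now be 0.0 or 1.0, per the NOTEs added in this file.
    val training = sc.parallelize(Seq(
      LabeledPoint(0.0, Array(0.1, 1.2)),
      LabeledPoint(1.0, Array(2.3, 0.4))))
    // train(input, numIterations): step size 1.0, full data set each iteration.
    val model = LogisticRegressionWithSGD.train(training, 100)
    val p = model.predict(Array(1.0, 0.5))   // returns 0.0 or 1.0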
diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
index f799cb2829..69393cd7b0 100644
--- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
@@ -18,10 +18,12 @@
package spark.mllib.classification
import scala.math.signum
+
import spark.{Logging, RDD, SparkContext}
import spark.mllib.optimization._
import spark.mllib.regression._
import spark.mllib.util.MLUtils
+import spark.mllib.util.DataValidators
import org.jblas.DoubleMatrix
@@ -39,31 +41,36 @@ class SVMModel(
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
- signum(dataMatrix.dot(weightMatrix) + intercept)
+ val margin = dataMatrix.dot(weightMatrix) + intercept
+ if (margin < 0) 0.0 else 1.0
}
}
/**
* Train an SVM using Stochastic Gradient Descent.
+ * NOTE: Labels used in SVM should be {0, 1}
*/
class SVMWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[SVMModel] with Serializable {
val gradient = new HingeGradient()
val updater = new SquaredL2Updater()
- val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize)
- .setNumIterations(numIterations)
- .setRegParam(regParam)
- .setMiniBatchFraction(miniBatchFraction)
+ override val optimizer = new GradientDescent(gradient, updater)
+ .setStepSize(stepSize)
+ .setNumIterations(numIterations)
+ .setRegParam(regParam)
+ .setMiniBatchFraction(miniBatchFraction)
+
+ override val validators = List(DataValidators.classificationLabels)
+
/**
* Construct a SVM object with default parameters
*/
- def this() = this(1.0, 100, 1.0, 1.0, true)
+ def this() = this(1.0, 100, 1.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new SVMModel(weights, intercept)
@@ -71,7 +78,7 @@ class SVMWithSGD private (
}
/**
- * Top-level methods for calling SVM.
+ * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1}
*/
object SVMWithSGD {
@@ -80,6 +87,7 @@ object SVMWithSGD {
* of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in
* gradient descent are initialized using the initial weights provided.
+ * NOTE: Labels used in SVM should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -98,7 +106,7 @@ object SVMWithSGD {
initialWeights: Array[Double])
: SVMModel =
{
- new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
+ new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}
@@ -106,6 +114,7 @@ object SVMWithSGD {
* Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. Each iteration uses
* `miniBatchFraction` fraction of the data to calculate the gradient.
+ * NOTE: Labels used in SVM should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
@@ -121,13 +130,14 @@ object SVMWithSGD {
miniBatchFraction: Double)
: SVMModel =
{
- new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
+ new SVMWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
/**
* Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using the specified step size. We use the entire data set to
* update the gradient in each iteration.
+ * NOTE: Labels used in SVM should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param stepSize Step size to be used for each iteration of Gradient Descent.
@@ -149,6 +159,7 @@ object SVMWithSGD {
* Train a SVM model given an RDD of (label, features) pairs. We run a fixed number
* of iterations of gradient descent using a step size of 1.0. We use the entire data set to
* update the gradient in each iteration.
+ * NOTE: Labels used in SVM should be {0, 1}
*
* @param input RDD of (label, array of features) pairs.
* @param numIterations Number of iterations of gradient descent to run.
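
The predictPoint change above swaps signum (which yields -1.0/+1.0) for a threshold at zero, so predictions match the new {0, 1} label contract. The same rule in isolation, as a standalone sketch using jblas directly (predict01 is an illustrative name, not part of the patch):

    import org.jblas.DoubleMatrix

    def predict01(features: Array[Double], weights: Array[Double], intercept: Double): Double = {
      val margin = new DoubleMatrix(1, features.length, features: _*)
        .dot(new DoubleMatrix(1, weights.length, weights: _*)) + intercept
      if (margin < 0) 0.0 else 1.0   // mirrors SVMModel.predictPoint
    }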
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
index e72b8b3a92..05568f55af 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
@@ -77,16 +77,22 @@ class SquaredGradient extends Gradient {
/**
* Compute gradient and loss for a Hinge loss function.
+ * NOTE: This assumes that the labels are {0,1}
*/
class HingeGradient extends Gradient {
- override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
+ override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix):
(DoubleMatrix, Double) = {
val dotProduct = data.dot(weights)
- if (1.0 > label * dotProduct)
- (data.mul(-label), 1.0 - label * dotProduct)
- else
- (DoubleMatrix.zeros(1,weights.length), 0.0)
+ // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
+ // Therefore the gradient is -(2y - 1)*x
+ val labelScaled = 2 * label - 1.0
+
+ if (1.0 > labelScaled * dotProduct) {
+ (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct)
+ } else {
+ (DoubleMatrix.zeros(1, weights.length), 0.0)
+ }
}
}
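
With labels y in {0, 1}, the patch first maps y to 2y - 1 in {-1, +1}; the loss is then max(0, 1 - (2y - 1) * (w . x)) and the gradient is -(2y - 1) * x whenever the margin is violated. A self-contained numeric check of that rule (a sketch using jblas only):

    import org.jblas.DoubleMatrix

    val x = new DoubleMatrix(1, 2, 1.0, 2.0)    // features
    val w = new DoubleMatrix(1, 2, 0.5, 0.5)    // weights
    val label = 0.0                             // a {0, 1} label
    val labelScaled = 2 * label - 1.0           // -1.0
    val margin = labelScaled * x.dot(w)         // -1.5 < 1.0: margin violated
    // loss = 1 - margin = 2.5; gradient = -labelScaled * x = +x = (1.0, 2.0)
    val gradient = x.mul(-labelScaled)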
diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 4ecafff08b..d164d415d6 100644
--- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -17,7 +17,7 @@
package spark.mllib.regression
-import spark.{Logging, RDD}
+import spark.{Logging, RDD, SparkException}
import spark.mllib.optimization._
import org.jblas.DoubleMatrix
@@ -83,15 +83,19 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept:
abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
extends Logging with Serializable {
+ protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List()
+
val optimizer: Optimizer
+ protected var addIntercept: Boolean = true
+
+ protected var validateData: Boolean = true
+
/**
* Create a model given the weights and intercept
*/
protected def createModel(weights: Array[Double], intercept: Double): M
- protected var addIntercept: Boolean
-
/**
* Set if the algorithm should add an intercept. Default true.
*/
@@ -101,6 +105,14 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}
/**
+ * Set if the algorithm should validate data before training. Default true.
+ */
+ def setValidateData(validateData: Boolean): this.type = {
+ this.validateData = validateData
+ this
+ }
+
+ /**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
*/
@@ -116,6 +128,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
*/
def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = {
+ // Check the data properties before running the optimizer
+ if (validateData && !validators.forall(func => func(input))) {
+ throw new SparkException("Input validation failed.")
+ }
+
// Add a extra variable consisting of all 1.0's for the intercept.
val data = if (addIntercept) {
input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*)))
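
The hook added here is just a sequence of predicates over the input RDD: run() throws a SparkException unless every predicate passes, and setValidateData(false) bypasses the check. A sketch of wiring in an extra validator (nonEmpty is hypothetical, written in the same shape as the one this patch ships):

    import spark.RDD
    import spark.mllib.regression.LabeledPoint

    // A hypothetical extra check, same shape as DataValidators.classificationLabels:
    val nonEmpty: RDD[LabeledPoint] => Boolean = data => data.count() > 0

    // A subclass exposes its checks the way SVMWithSGD does above:
    //   override val validators = List(DataValidators.classificationLabels, nonEmpty)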
diff --git a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
index 6bbc990a5a..89f791e85a 100644
--- a/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/Lasso.scala
@@ -48,8 +48,7 @@ class LassoWithSGD private (
var stepSize: Double,
var numIterations: Int,
var regParam: Double,
- var miniBatchFraction: Double,
- var addIntercept: Boolean)
+ var miniBatchFraction: Double)
extends GeneralizedLinearAlgorithm[LassoModel]
with Serializable {
@@ -63,7 +62,7 @@ class LassoWithSGD private (
/**
* Construct a Lasso object with default parameters
*/
- def this() = this(1.0, 100, 1.0, 1.0, true)
+ def this() = this(1.0, 100, 1.0, 1.0)
def createModel(weights: Array[Double], intercept: Double) = {
new LassoModel(weights, intercept)
@@ -98,7 +97,7 @@ object LassoWithSGD {
initialWeights: Array[Double])
: LassoModel =
{
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input,
+ new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input,
initialWeights)
}
@@ -121,7 +120,7 @@ object LassoWithSGD {
miniBatchFraction: Double)
: LassoModel =
{
- new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction, true).run(input)
+ new LassoWithSGD(stepSize, numIterations, regParam, miniBatchFraction).run(input)
}
/**
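
With addIntercept gone from the constructors, intercept fitting is toggled through the base-class setter instead. A sketch under that assumption (trainingData: RDD[LabeledPoint] is assumed, and the setter name setIntercept comes from the base class rather than this hunk):

    import spark.mllib.regression.LassoWithSGD

    val lasso = new LassoWithSGD().setIntercept(false)   // default is true
    val model = lasso.run(trainingData)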
diff --git a/mllib/src/main/scala/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/spark/mllib/util/DataValidators.scala
new file mode 100644
index 0000000000..57553accf1
--- /dev/null
+++ b/mllib/src/main/scala/spark/mllib/util/DataValidators.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.util
+
+import spark.{RDD, Logging}
+import spark.mllib.regression.LabeledPoint
+
+/**
+ * A collection of methods used to validate data before applying ML algorithms.
+ */
+object DataValidators extends Logging {
+
+ /**
+ * Function to check if labels used for classification are either zero or one.
+ *
+ * @param data - input data set that needs to be checked
+ *
+ * @return True if labels are all zero or one, false otherwise.
+ */
+ val classificationLabels: RDD[LabeledPoint] => Boolean = { data =>
+ val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count()
+ if (numInvalid != 0) {
+ logError("Classification labels should be 0 or 1. Found " + numInvalid + " invalid labels")
+ }
+ numInvalid == 0
+ }
+}
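
Because classificationLabels is a plain function value, it can also be invoked directly before training. A usage sketch (assumes a SparkContext sc):

    import spark.mllib.regression.LabeledPoint
    import spark.mllib.util.DataValidators

    val data = sc.parallelize(Seq(
      LabeledPoint(1.0, Array(0.0)),
      LabeledPoint(-1.0, Array(1.0))))   // -1.0 is invalid under the {0, 1} contract
    val ok = DataValidators.classificationLabels(data)   // false; also logs an error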
diff --git a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
index e02bd190f6..eff456cad6 100644
--- a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
+++ b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
@@ -1,7 +1,6 @@
package spark.mllib.util
import scala.util.Random
-import scala.math.signum
import spark.{RDD, SparkContext}
@@ -30,8 +29,8 @@ object SVMDataGenerator {
val sc = new SparkContext(sparkMaster, "SVMGenerator")
val globalRnd = new Random(94720)
- val trueWeights = new DoubleMatrix(1, nfeatures+1,
- Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }:_*)
+ val trueWeights = new DoubleMatrix(1, nfeatures + 1,
+ Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()):_*)
val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx =>
val rnd = new Random(42 + idx)
@@ -39,11 +38,13 @@ object SVMDataGenerator {
val x = Array.fill[Double](nfeatures) {
rnd.nextDouble() * 2.0 - 1.0
}
- val y = signum((new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1)
+ val yD = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1
+ val y = if (yD < 0) 0.0 else 1.0
LabeledPoint(y, x)
}
MLUtils.saveLabeledData(data, outputPath)
+
sc.stop()
}
}
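
The generator now thresholds the noisy linear score at zero instead of passing it through signum, so synthetic labels land in {0, 1}. One generated point following the patched logic (a sketch trimmed to 3 features):

    import scala.util.Random
    import org.jblas.DoubleMatrix

    val rnd = new Random(42)
    val trueWeights = new DoubleMatrix(1, 3, Array.fill(3)(rnd.nextGaussian()): _*)
    val x = Array.fill(3)(rnd.nextDouble() * 2.0 - 1.0)
    val score = new DoubleMatrix(1, x.length, x: _*).dot(trueWeights) + rnd.nextGaussian() * 0.1
    val y = if (score < 0) 0.0 else 1.0   // was signum(score) before this patch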
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index 04f631d80f..894ae458ad 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -48,13 +48,11 @@ object SVMSuite {
val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(
- Array.fill[Double](weights.length)(rnd.nextGaussian()))
+ Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0))
val y = x.map { xi =>
- signum(
- (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) +
- intercept +
- 0.1 * rnd.nextGaussian()
- ).toInt
+ val yD = (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) +
+ intercept + 0.01 * rnd.nextGaussian()
+ if (yD < 0) 0.0 else 1.0
}
y.zip(x).map(p => LabeledPoint(p._1, p._2))
}
@@ -85,7 +83,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
test("SVM using local random SGD") {
val nPoints = 10000
- val A = 2.0
+ // NOTE: Intercept should be small for generating equal 0s and 1s
+ val A = 0.01
val B = -1.5
val C = 1.0
@@ -100,7 +99,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
val model = svm.run(testRDD)
val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17)
- val validationRDD = sc.parallelize(validationData,2)
+ val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
@@ -112,7 +111,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
test("SVM local random SGD with initial weights") {
val nPoints = 10000
- val A = 2.0
+ // NOTE: Intercept should be small for generating equal 0s and 1s
+ val A = 0.01
val B = -1.5
val C = 1.0
@@ -139,4 +139,31 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ test("SVM with invalid labels") {
+ val nPoints = 10000
+
+ // NOTE: Intercept should be small for generating equal 0s and 1s
+ val A = 0.01
+ val B = -1.5
+ val C = 1.0
+
+ val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42)
+ val testRDD = sc.parallelize(testData, 2)
+
+ val testRDDInvalid = testRDD.map { lp =>
+ if (lp.label == 0.0) {
+ LabeledPoint(-1.0, lp.features)
+ } else {
+ lp
+ }
+ }
+
+ intercept[spark.SparkException] {
+ val model = SVMWithSGD.train(testRDDInvalid, 100)
+ }
+
+ // Turning off data validation should not throw an exception
+ val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
+ }
}
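
For callers migrating existing {-1, +1} datasets, remapping labels up front avoids the validation failure this test exercises. A sketch reusing testRDDInvalid from above:

    val remapped = testRDDInvalid.map { lp =>
      if (lp.label < 0) LabeledPoint(0.0, lp.features) else lp
    }
    val model = SVMWithSGD.train(remapped, 100)   // now passes validation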