diff options
author | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-13 11:43:49 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-13 11:44:47 -0700 |
commit | 654087194d232221dfb64ba646c8a8e12649f961 (patch) | |
tree | 90cddc6673f93d683668b802a0570aba812fb66c /mllib | |
parent | 622f83ce1ce522ea0058665cbf43c64a73b44439 (diff) | |
download | spark-654087194d232221dfb64ba646c8a8e12649f961.tar.gz spark-654087194d232221dfb64ba646c8a8e12649f961.tar.bz2 spark-654087194d232221dfb64ba646c8a8e12649f961.zip |
Change SVM to use {0,1} labels.
Also add a data validation check to make sure classification labels
are always 0 or 1 and add an appropriate test case.
Diffstat (limited to 'mllib')
7 files changed, 116 insertions, 26 deletions
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala index 30ee0ab0ff..24f9f4e76b 100644 --- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala @@ -17,12 +17,13 @@ package spark.mllib.classification +import scala.math.round + import spark.{Logging, RDD, SparkContext} import spark.mllib.optimization._ import spark.mllib.regression._ import spark.mllib.util.MLUtils - -import scala.math.round +import spark.mllib.util.DataValidators import org.jblas.DoubleMatrix @@ -59,10 +60,13 @@ class LogisticRegressionWithSGD private ( val gradient = new LogisticGradient() val updater = new SimpleUpdater() - val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) + override val optimizer = new GradientDescent(gradient, updater) + .setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + override val validateFuncs = List(DataValidators.classificationLabels) + /** * Construct a LogisticRegression object with default parameters */ diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala index f799cb2829..d2b50f4987 100644 --- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala @@ -18,10 +18,12 @@ package spark.mllib.classification import scala.math.signum + import spark.{Logging, RDD, SparkContext} import spark.mllib.optimization._ import spark.mllib.regression._ import spark.mllib.util.MLUtils +import spark.mllib.util.DataValidators import org.jblas.DoubleMatrix @@ -45,6 +47,7 @@ class SVMModel( /** * Train an SVM using Stochastic Gradient 
Descent. + * NOTE: Labels used in SVM should be {0, 1} */ class SVMWithSGD private ( var stepSize: Double, @@ -56,10 +59,14 @@ class SVMWithSGD private ( val gradient = new HingeGradient() val updater = new SquaredL2Updater() - val optimizer = new GradientDescent(gradient, updater).setStepSize(stepSize) - .setNumIterations(numIterations) - .setRegParam(regParam) - .setMiniBatchFraction(miniBatchFraction) + override val optimizer = new GradientDescent(gradient, updater) + .setStepSize(stepSize) + .setNumIterations(numIterations) + .setRegParam(regParam) + .setMiniBatchFraction(miniBatchFraction) + + override val validateFuncs = List(DataValidators.classificationLabels) + /** * Construct a SVM object with default parameters */ @@ -71,7 +78,7 @@ class SVMWithSGD private ( } /** - * Top-level methods for calling SVM. + * Top-level methods for calling SVM. NOTE: Labels used in SVM should be {0, 1} */ object SVMWithSGD { @@ -80,6 +87,7 @@ object SVMWithSGD { * of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. The weights used in * gradient descent are initialized using the initial weights provided. + * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -106,6 +114,7 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using the specified step size. Each iteration uses * `miniBatchFraction` fraction of the data to calculate the gradient. + * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. @@ -128,6 +137,7 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. 
We run a fixed number * of iterations of gradient descent using the specified step size. We use the entire data set to * update the gradient in each iteration. + * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param stepSize Step size to be used for each iteration of Gradient Descent. @@ -149,6 +159,7 @@ object SVMWithSGD { * Train a SVM model given an RDD of (label, features) pairs. We run a fixed number * of iterations of gradient descent using a step size of 1.0. We use the entire data set to * update the gradient in each iteration. + * NOTE: Labels used in SVM should be {0, 1} * * @param input RDD of (label, array of features) pairs. * @param numIterations Number of iterations of gradient descent to run. diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala index e72b8b3a92..58bfe3f37b 100644 --- a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala @@ -77,16 +77,20 @@ class SquaredGradient extends Gradient { /** * Compute gradient and loss for a Hinge loss function. 
+ * NOTE: This assumes that the labels are {0,1} */ class HingeGradient extends Gradient { - override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): + override def compute(data: DoubleMatrix, label: Double, weights: DoubleMatrix): (DoubleMatrix, Double) = { val dotProduct = data.dot(weights) + val labelScaled = 2*label - 1.0 - if (1.0 > label * dotProduct) - (data.mul(-label), 1.0 - label * dotProduct) + // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) + // Therefore the gradient is -(2y - 1)*x + if (1.0 > labelScaled * dotProduct) + (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct) else - (DoubleMatrix.zeros(1,weights.length), 0.0) + (DoubleMatrix.zeros(1, weights.length), 0.0) } } diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 4ecafff08b..55edb3def5 100644 --- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -17,7 +17,7 @@ package spark.mllib.regression -import spark.{Logging, RDD} +import spark.{Logging, RDD, SparkException} import spark.mllib.optimization._ import org.jblas.DoubleMatrix @@ -83,6 +83,8 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] extends Logging with Serializable { + protected val validateFuncs: Seq[RDD[LabeledPoint] => Boolean] = List() + val optimizer: Optimizer /** @@ -116,6 +118,11 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] */ def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = { + // Check the data properties before running the optimizer + if (!validateFuncs.forall(func => func(input))) { + throw new SparkException("Input validation failed.") + } + // Add a extra variable consisting of all 
1.0's for the intercept. val data = if (addIntercept) { input.map(labeledPoint => (labeledPoint.label, Array(1.0, labeledPoint.features:_*))) diff --git a/mllib/src/main/scala/spark/mllib/util/DataValidators.scala b/mllib/src/main/scala/spark/mllib/util/DataValidators.scala new file mode 100644 index 0000000000..57553accf1 --- /dev/null +++ b/mllib/src/main/scala/spark/mllib/util/DataValidators.scala @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.mllib.util + +import spark.{RDD, Logging} +import spark.mllib.regression.LabeledPoint + +/** + * A collection of methods used to validate data before applying ML algorithms. + */ +object DataValidators extends Logging { + + /** + * Function to check if labels used for classification are either zero or one. + * + * @param data - input data set that needs to be checked + * + * @return True if labels are all zero or one, false otherwise. + */ + val classificationLabels: RDD[LabeledPoint] => Boolean = { data => + val numInvalid = data.filter(x => x.label != 1.0 && x.label != 0.0).count() + if (numInvalid != 0) { + logError("Classification labels should be 0 or 1. 
Found " + numInvalid + " invalid labels") + } + numInvalid == 0 + } +} diff --git a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala index e02bd190f6..eff456cad6 100644 --- a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala +++ b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala @@ -1,7 +1,6 @@ package spark.mllib.util import scala.util.Random -import scala.math.signum import spark.{RDD, SparkContext} @@ -30,8 +29,8 @@ object SVMDataGenerator { val sc = new SparkContext(sparkMaster, "SVMGenerator") val globalRnd = new Random(94720) - val trueWeights = new DoubleMatrix(1, nfeatures+1, - Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }:_*) + val trueWeights = new DoubleMatrix(1, nfeatures + 1, + Array.fill[Double](nfeatures + 1)(globalRnd.nextGaussian()):_*) val data: RDD[LabeledPoint] = sc.parallelize(0 until nexamples, parts).map { idx => val rnd = new Random(42 + idx) @@ -39,11 +38,13 @@ object SVMDataGenerator { val x = Array.fill[Double](nfeatures) { rnd.nextDouble() * 2.0 - 1.0 } - val y = signum((new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1) + val yD = (new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1 + val y = if (yD < 0) 0.0 else 1.0 LabeledPoint(y, x) } MLUtils.saveLabeledData(data, outputPath) + sc.stop() } } diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala index 04f631d80f..f392efa405 100644 --- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala @@ -50,11 +50,9 @@ object SVMSuite { val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextGaussian())) val y = x.map { xi => - signum( - (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + - intercept + - 0.1 * 
rnd.nextGaussian() - ).toInt + val yD = (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + + intercept + 0.01 * rnd.nextGaussian() + if (yD < 0) 0.0 else 1.0 } y.zip(x).map(p => LabeledPoint(p._1, p._2)) } @@ -100,7 +98,7 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { val model = svm.run(testRDD) val validationData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 17) - val validationRDD = sc.parallelize(validationData,2) + val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) @@ -139,4 +137,27 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { // Test prediction on Array. validatePrediction(validationData.map(row => model.predict(row.features)), validationData) } + + test("SVM with invalid labels") { + val nPoints = 10000 + + val A = 2.0 + val B = -1.5 + val C = 1.0 + + val testData = SVMSuite.generateSVMInput(A, Array[Double](B,C), nPoints, 42) + val testRDD = sc.parallelize(testData, 2) + + val testRDDInvalid = testRDD.map { lp => + if (lp.label == 0.0) { + LabeledPoint(-1.0, lp.features) + } else { + lp + } + } + + intercept[spark.SparkException] { + val model = SVMWithSGD.train(testRDDInvalid, 100) + } + } } |