diff options
author | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-13 13:57:06 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-13 13:57:06 -0700 |
commit | 0ab6ff4c3252e7cb9ea573e09d9188da1fcb87cc (patch) | |
tree | 48429251ceab84a3eb83127a8e254455a445f150 | |
parent | 654087194d232221dfb64ba646c8a8e12649f961 (diff) | |
download | spark-0ab6ff4c3252e7cb9ea573e09d9188da1fcb87cc.tar.gz spark-0ab6ff4c3252e7cb9ea573e09d9188da1fcb87cc.tar.bz2 spark-0ab6ff4c3252e7cb9ea573e09d9188da1fcb87cc.zip |
Fix SVM model and unit test to work with {0,1}.
Also rename validateFuncs to validators.
5 files changed, 18 insertions, 12 deletions
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala index 24f9f4e76b..7f0b1ba841 100644 --- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala @@ -65,7 +65,7 @@ class LogisticRegressionWithSGD private ( .setNumIterations(numIterations) .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - override val validateFuncs = List(DataValidators.classificationLabels) + override val validators = List(DataValidators.classificationLabels) /** * Construct a LogisticRegression object with default parameters diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala index d2b50f4987..b680d81e86 100644 --- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala +++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala @@ -41,7 +41,8 @@ class SVMModel( override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix, intercept: Double) = { - signum(dataMatrix.dot(weightMatrix) + intercept) + val margin = dataMatrix.dot(weightMatrix) + intercept + if (margin < 0) 0.0 else 1.0 } } @@ -65,7 +66,7 @@ class SVMWithSGD private ( .setRegParam(regParam) .setMiniBatchFraction(miniBatchFraction) - override val validateFuncs = List(DataValidators.classificationLabels) + override val validators = List(DataValidators.classificationLabels) /** * Construct a SVM object with default parameters diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala index 58bfe3f37b..05568f55af 100644 --- a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala +++ b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala @@ -84,13 +84,15 @@ class HingeGradient extends Gradient { (DoubleMatrix, Double) = { val 
dotProduct = data.dot(weights) - val labelScaled = 2*label - 1.0 // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x))) // Therefore the gradient is -(2y - 1)*x - if (1.0 > labelScaled * dotProduct) + val labelScaled = 2 * label - 1.0 + + if (1.0 > labelScaled * dotProduct) { (data.mul(-labelScaled), 1.0 - labelScaled * dotProduct) - else + } else { (DoubleMatrix.zeros(1, weights.length), 0.0) + } } } diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala index 55edb3def5..03f991df39 100644 --- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala +++ b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala @@ -83,7 +83,7 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept: abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] extends Logging with Serializable { - protected val validateFuncs: Seq[RDD[LabeledPoint] => Boolean] = List() + protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List() val optimizer: Optimizer @@ -119,7 +119,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel] def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = { // Check the data properties before running the optimizer - if (!validateFuncs.forall(func => func(input))) { + if (!validators.forall(func => func(input))) { throw new SparkException("Input validation failed.") } diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala index f392efa405..8fa9e4639b 100644 --- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala +++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala @@ -48,7 +48,7 @@ object SVMSuite { val rnd = new Random(seed) val weightsMat = new DoubleMatrix(1, weights.length, weights:_*) val x = 
Array.fill[Array[Double]](nPoints)( - Array.fill[Double](weights.length)(rnd.nextGaussian())) + Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0)) val y = x.map { xi => val yD = (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) + intercept + 0.01 * rnd.nextGaussian() @@ -83,7 +83,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { test("SVM using local random SGD") { val nPoints = 10000 - val A = 2.0 + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 val B = -1.5 val C = 1.0 @@ -110,7 +111,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { test("SVM local random SGD with initial weights") { val nPoints = 10000 - val A = 2.0 + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 val B = -1.5 val C = 1.0 @@ -141,7 +143,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll { test("SVM with invalid labels") { val nPoints = 10000 - val A = 2.0 + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 val B = -1.5 val C = 1.0 |