author     Andrew Tulloch <andrew@tullo.ch>   2014-05-13 17:31:27 -0700
committer  Reynold Xin <rxin@apache.org>      2014-05-13 17:31:27 -0700
commit     d1e487473fd509f28daf28dcda856f3c2f1194ec (patch)
tree       838c727f8ebd528df1f27207e946ed22ce93b861
parent     16ffadcc4af21430b5079dc555bcd9d8cf1fa1fa (diff)
SPARK-1791 - SVM implementation does not use threshold parameter
Summary:
https://issues.apache.org/jira/browse/SPARK-1791
A simple, backward-compatible fix, since
- anyone who set the threshold was getting completely wrong answers, and
- anyone who did not set the threshold was using the default value of 0.0, which matches the old hard-coded behavior anyway.
A before/after sketch of the decision rule follows this list.
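As a minimal sketch of the bug (not the actual SVMModel code; predictBuggy and predictFixed are hypothetical names, and margin stands for the w . x + intercept value computed in SVM.scala):

// Before the fix: the user-supplied threshold t was silently ignored.
def predictBuggy(margin: Double, t: Double): Double =
  if (margin < 0) 0.0 else 1.0

// After the fix: the threshold is actually applied.
def predictFixed(margin: Double, t: Double): Double =
  if (margin < t) 0.0 else 1.0

// With the default threshold of 0.0 both rules agree, which is why the
// fix is backward compatible for anyone who never called setThreshold.
assert(predictBuggy(0.7, 0.0) == predictFixed(0.7, 0.0)) // both 1.0
// With a custom threshold they diverge: the old code ignored t entirely.
assert(predictBuggy(0.7, 2.0) == 1.0 && predictFixed(0.7, 2.0) == 0.0)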
Test Plan:
A unit test was added and verified to fail under the old implementation and pass under the new one.
Author: Andrew Tulloch <andrew@tullo.ch>
Closes #725 from ajtulloch/SPARK-1791-SVM and squashes the following commits:
770f55d [Andrew Tulloch] SPARK-1791 - SVM implementation does not use threshold parameter
 mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala      |  2 +-
 mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala | 37 ++++++++++++++++++
 2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index e05213536e..316ecd713b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -65,7 +65,7 @@ class SVMModel private[mllib] (
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
     threshold match {
-      case Some(t) => if (margin < 0) 0.0 else 1.0
+      case Some(t) => if (margin < t) 0.0 else 1.0
       case None => margin
     }
   }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 77d6f04b32..886c71dde3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -69,6 +69,43 @@ class SVMSuite extends FunSuite with LocalSparkContext {
     assert(numOffPredictions < input.length / 5)
   }
 
+  test("SVM with threshold") {
+    val nPoints = 10000
+
+    // NOTE: Intercept should be small for generating equal 0s and 1s
+    val A = 0.01
+    val B = -1.5
+    val C = 1.0
+
+    val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
+
+    val testRDD = sc.parallelize(testData, 2)
+    testRDD.cache()
+
+    val svm = new SVMWithSGD().setIntercept(true)
+    svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+
+    val model = svm.run(testRDD)
+
+    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
+    val validationRDD = sc.parallelize(validationData, 2)
+
+    // Test prediction on RDD.
+
+    var predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) != predictions.length)
+
+    // High threshold makes all the predictions 0.0
+    model.setThreshold(10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) == predictions.length)
+
+    // Low threshold makes all the predictions 1.0
+    model.setThreshold(-10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 1.0) == predictions.length)
+  }
+
   test("SVM using local random SGD") {
     val nPoints = 10000
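For context, a hypothetical driver-side sketch of the API this patch repairs, assuming a live SparkContext and two existing RDD[LabeledPoint] values named training and test (assumptions for illustration, not part of this patch):

import org.apache.spark.mllib.classification.SVMWithSGD

// `training` and `test` are assumed to be existing RDD[LabeledPoint] values.
val model = SVMWithSGD.train(training, numIterations = 100)

// Default threshold is 0.0: predict 1.0 whenever the margin w . x + b >= 0.
val labels = model.predict(test.map(_.features))

// After this fix, a user-supplied threshold actually moves the cutoff:
model.setThreshold(0.5)  // predict 1.0 only when the margin >= 0.5
model.clearThreshold()   // subsequent predict calls return raw margins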