author    Andrew Tulloch <andrew@tullo.ch>  2014-05-13 17:31:27 -0700
committer Reynold Xin <rxin@apache.org>     2014-05-13 17:31:27 -0700
commit    d1e487473fd509f28daf28dcda856f3c2f1194ec (patch)
tree      838c727f8ebd528df1f27207e946ed22ce93b861 /mllib
parent    16ffadcc4af21430b5079dc555bcd9d8cf1fa1fa (diff)
SPARK-1791 - SVM implementation does not use threshold parameter
Summary: https://issues.apache.org/jira/browse/SPARK-1791

Simple fix, and backward compatible, since:

- anyone who set the threshold was getting completely wrong answers.
- anyone who did not set the threshold had the default 0.0 value for the threshold anyway.

Test Plan: Unit test added that is verified to fail under the old implementation, and pass under the new implementation.

Author: Andrew Tulloch <andrew@tullo.ch>

Closes #725 from ajtulloch/SPARK-1791-SVM and squashes the following commits:

770f55d [Andrew Tulloch] SPARK-1791 - SVM implementation does not use threshold parameter
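The one-character nature of the bug is easiest to see in isolation. Below is a minimal, self-contained sketch of the decision rule; `ThresholdSketch` and its `predict` are illustrative names, not Spark's API (the real logic lives in SVMModel.predictPoint, shown in the diff below). Before this patch the `Some(t)` branch compared the margin against a hard-coded 0, so any configured threshold was silently ignored:

// A minimal sketch of the fixed decision rule (illustrative names,
// not Spark's API).
object ThresholdSketch {
  def predict(margin: Double, threshold: Option[Double]): Double =
    threshold match {
      // Before the fix this read `if (margin < 0)`, ignoring `t`.
      case Some(t) => if (margin < t) 0.0 else 1.0
      case None    => margin // no threshold set: return the raw margin
    }

  def main(args: Array[String]): Unit = {
    assert(predict(margin = 2.0, threshold = Some(5.0)) == 0.0) // 2.0 < 5.0 => class 0
    assert(predict(margin = 2.0, threshold = Some(0.0)) == 1.0) // default 0.0 cutoff
    assert(predict(margin = 2.0, threshold = None) == 2.0)      // raw margin passthrough
  }
}

Under the old code, the first assertion would fail: the margin 2.0 was compared against 0 rather than the requested 5.0, yielding class 1.0.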
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala      |  2 +-
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala | 37 ++++++++++++++
2 files changed, 38 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index e05213536e..316ecd713b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -65,7 +65,7 @@ class SVMModel private[mllib] (
       intercept: Double) = {
     val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
     threshold match {
-      case Some(t) => if (margin < 0) 0.0 else 1.0
+      case Some(t) => if (margin < t) 0.0 else 1.0
       case None => margin
     }
   }
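For context, a hedged usage sketch of the repaired behavior on toy data, assuming a live SparkContext `sc`; the data and parameter values are illustrative, not from this patch. Before the fix, the `setThreshold(0.5)` call below had no effect and the margin was always compared against 0.0:

import org.apache.spark.mllib.classification.SVMWithSGD
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Toy, linearly separable training set (illustrative values).
val training = sc.parallelize(Seq(
  LabeledPoint(1.0, Vectors.dense(1.0, 2.0)),
  LabeledPoint(0.0, Vectors.dense(-1.0, -2.0))
))

val model = SVMWithSGD.train(training, 100) // 100 SGD iterations

model.setThreshold(0.5)  // margins below 0.5 now map to class 0.0, as of this fix
val label = model.predict(Vectors.dense(1.0, 2.0))

model.clearThreshold()   // companion setter: with no threshold, predict returns the raw margin
val margin = model.predict(Vectors.dense(1.0, 2.0))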
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 77d6f04b32..886c71dde3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -69,6 +69,43 @@ class SVMSuite extends FunSuite with LocalSparkContext {
     assert(numOffPredictions < input.length / 5)
   }
 
+  test("SVM with threshold") {
+    val nPoints = 10000
+
+    // NOTE: Intercept should be small for generating equal 0s and 1s
+    val A = 0.01
+    val B = -1.5
+    val C = 1.0
+
+    val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
+
+    val testRDD = sc.parallelize(testData, 2)
+    testRDD.cache()
+
+    val svm = new SVMWithSGD().setIntercept(true)
+    svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+
+    val model = svm.run(testRDD)
+
+    val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
+    val validationRDD = sc.parallelize(validationData, 2)
+
+    // Test prediction on RDD.
+
+    var predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) != predictions.length)
+
+    // High threshold makes all the predictions 0.0
+    model.setThreshold(10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 0.0) == predictions.length)
+
+    // Low threshold makes all the predictions 1.0
+    model.setThreshold(-10000.0)
+    predictions = model.predict(validationRDD.map(_.features)).collect()
+    assert(predictions.count(_ == 1.0) == predictions.length)
+  }
+
   test("SVM using local random SGD") {
     val nPoints = 10000