aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala37
2 files changed, 38 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
index e05213536e..316ecd713b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/SVM.scala
@@ -65,7 +65,7 @@ class SVMModel private[mllib] (
intercept: Double) = {
val margin = weightMatrix.toBreeze.dot(dataMatrix.toBreeze) + intercept
threshold match {
- case Some(t) => if (margin < 0) 0.0 else 1.0
+ case Some(t) => if (margin < t) 0.0 else 1.0
case None => margin
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index 77d6f04b32..886c71dde3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -69,6 +69,43 @@ class SVMSuite extends FunSuite with LocalSparkContext {
assert(numOffPredictions < input.length / 5)
}
+ test("SVM with threshold") {
+ val nPoints = 10000
+
+ // NOTE: Intercept should be small for generating equal 0s and 1s
+ val A = 0.01
+ val B = -1.5
+ val C = 1.0
+
+ val testData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 42)
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ val svm = new SVMWithSGD().setIntercept(true)
+ svm.optimizer.setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+
+ val model = svm.run(testRDD)
+
+ val validationData = SVMSuite.generateSVMInput(A, Array[Double](B, C), nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
+
+ // Test prediction on RDD.
+
+ var predictions = model.predict(validationRDD.map(_.features)).collect()
+ assert(predictions.count(_ == 0.0) != predictions.length)
+
+ // High threshold makes all the predictions 0.0
+ model.setThreshold(10000.0)
+ predictions = model.predict(validationRDD.map(_.features)).collect()
+ assert(predictions.count(_ == 0.0) == predictions.length)
+
+ // Low threshold makes all the predictions 1.0
+ model.setThreshold(-10000.0)
+ predictions = model.predict(validationRDD.map(_.features)).collect()
+ assert(predictions.count(_ == 1.0) == predictions.length)
+ }
+
test("SVM using local random SGD") {
val nPoints = 10000