author    Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-08-13 13:57:06 -0700
committer Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-08-13 13:57:06 -0700
commit    0ab6ff4c3252e7cb9ea573e09d9188da1fcb87cc (patch)
tree      48429251ceab84a3eb83127a8e254455a445f150 /mllib
parent    654087194d232221dfb64ba646c8a8e12649f961 (diff)
Fix the SVM model and its unit test to work with {0, 1} labels.
Also rename validateFuncs to validators.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala     |  2 +-
-rw-r--r--  mllib/src/main/scala/spark/mllib/classification/SVM.scala                    |  5 +++--
-rw-r--r--  mllib/src/main/scala/spark/mllib/optimization/Gradient.scala                 |  8 +++++---
-rw-r--r--  mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala |  4 ++--
-rw-r--r--  mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala               | 11 +++++++----
5 files changed, 18 insertions, 12 deletions
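
The change is needed because signum returns values in {-1.0, 0.0, +1.0}, so predictions could never equal {0, 1} labels during evaluation. The following minimal sketch in plain Scala (not part of this patch; the object and method names are illustrative only) contrasts the old and new decision rules.

// Minimal sketch, not part of this patch: contrast the old signum-based
// prediction with the new {0, 1} threshold. Names here are illustrative only.
object LabelConventionSketch {
  // Old rule: sign of the margin, yielding values in {-1.0, 0.0, +1.0}.
  def predictOld(margin: Double): Double = math.signum(margin)

  // New rule from this commit: threshold the margin into {0.0, 1.0}.
  def predictNew(margin: Double): Double = if (margin < 0) 0.0 else 1.0

  def main(args: Array[String]): Unit = {
    val margins = Seq(-2.0, -0.1, 0.3, 4.0)
    val labels  = Seq(0.0, 0.0, 1.0, 1.0)  // {0, 1} labels
    val oldHits = margins.map(predictOld).zip(labels).count { case (p, y) => p == y }
    val newHits = margins.map(predictNew).zip(labels).count { case (p, y) => p == y }
    println(s"matches with signum: $oldHits, with 0/1 threshold: $newHits")  // 2 vs 4
  }
}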
diff --git a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
index 24f9f4e76b..7f0b1ba841 100644
--- a/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/LogisticRegression.scala
@@ -65,7 +65,7 @@ class LogisticRegressionWithSGD private (
.setNumIterations(numIterations)
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
- override val validateFuncs = List(DataValidators.classificationLabels)
+ override val validators = List(DataValidators.classificationLabels)
/**
* Construct a LogisticRegression object with default parameters
diff --git a/mllib/src/main/scala/spark/mllib/classification/SVM.scala b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
index d2b50f4987..b680d81e86 100644
--- a/mllib/src/main/scala/spark/mllib/classification/SVM.scala
+++ b/mllib/src/main/scala/spark/mllib/classification/SVM.scala
@@ -41,7 +41,8 @@ class SVMModel(
override def predictPoint(dataMatrix: DoubleMatrix, weightMatrix: DoubleMatrix,
intercept: Double) = {
- signum(dataMatrix.dot(weightMatrix) + intercept)
+ val margin = dataMatrix.dot(weightMatrix) + intercept
+ if (margin < 0) 0.0 else 1.0
}
}
@@ -65,7 +66,7 @@ class SVMWithSGD private (
.setRegParam(regParam)
.setMiniBatchFraction(miniBatchFraction)
- override val validateFuncs = List(DataValidators.classificationLabels)
+ override val validators = List(DataValidators.classificationLabels)
/**
* Construct a SVM object with default parameters
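
The SVMModel change above replaces signum with an explicit threshold on the raw margin. Below is a standalone sketch of the same rule using jblas's DoubleMatrix (the same type name that appears in the diff; the org.jblas import is assumed here). It is illustrative only, not MLlib's public API.

import org.jblas.DoubleMatrix

// Sketch only: the same decision rule as the patched predictPoint, exercised
// outside MLlib. Compute the raw margin w.x + b, then map it to {0.0, 1.0}.
object PredictPointSketch {
  def predict(features: Array[Double], weights: Array[Double], intercept: Double): Double = {
    val dataMatrix   = new DoubleMatrix(1, features.length, features: _*)
    val weightMatrix = new DoubleMatrix(1, weights.length, weights: _*)
    val margin = dataMatrix.dot(weightMatrix) + intercept
    if (margin < 0) 0.0 else 1.0
  }

  def main(args: Array[String]): Unit = {
    println(predict(Array(1.0, -2.0), Array(0.5, 0.5), 0.0))  // margin = -0.5 -> 0.0
    println(predict(Array(2.0,  1.0), Array(0.5, 0.5), 0.0))  // margin =  1.5 -> 1.0
  }
}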
diff --git a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
index 58bfe3f37b..05568f55af 100644
--- a/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
+++ b/mllib/src/main/scala/spark/mllib/optimization/Gradient.scala
@@ -84,13 +84,15 @@ class HingeGradient extends Gradient {
(DoubleMatrix, Double) = {
val dotProduct = data.dot(weights)
- val labelScaled = 2*label - 1.0
// Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) (f_w(x)))
// Therefore the gradient is -(2y - 1)*x
- if (1.0 > labelScaled * dotProduct)
+ val labelScaled = 2 * label - 1.0
+
+ if (1.0 > labelScaled * dotProduct) {
(data.mul(-labelScaled), 1.0 - labelScaled * dotProduct)
- else
+ } else {
(DoubleMatrix.zeros(1, weights.length), 0.0)
+ }
}
}
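
The HingeGradient comment above states the convention: with labels in {0, 1}, rescale to 2y - 1 in {-1, +1}, use loss max(0, 1 - (2y - 1) * w.x), and gradient -(2y - 1) * x when the margin is violated, zero otherwise. The sketch below repeats the same computation on plain arrays (illustrative only; the real class works on DoubleMatrix as shown in the diff).

// Sketch of the hinge gradient with {0, 1} labels, on plain arrays rather
// than the DoubleMatrix-based Gradient API shown above. With labelScaled =
// 2 * label - 1.0 in {-1, +1}, the loss is max(0, 1 - labelScaled * w.x) and
// its (sub)gradient is -labelScaled * x when the margin constraint is violated.
object HingeGradientSketch {
  def hinge(data: Array[Double], label: Double, weights: Array[Double]): (Array[Double], Double) = {
    val dotProduct  = data.zip(weights).map { case (x, w) => x * w }.sum
    val labelScaled = 2 * label - 1.0
    if (1.0 > labelScaled * dotProduct) {
      (data.map(_ * -labelScaled), 1.0 - labelScaled * dotProduct)
    } else {
      (Array.fill(weights.length)(0.0), 0.0)
    }
  }

  def main(args: Array[String]): Unit = {
    val (grad, loss) = hinge(Array(1.0, 2.0), label = 0.0, weights = Array(0.1, 0.1))
    println(s"gradient = ${grad.mkString(",")}, loss = $loss")  // gradient = 1.0,2.0, loss = 1.3
  }
}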
diff --git a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index 55edb3def5..03f991df39 100644
--- a/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -83,7 +83,7 @@ abstract class GeneralizedLinearModel(val weights: Array[Double], val intercept:
abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
extends Logging with Serializable {
- protected val validateFuncs: Seq[RDD[LabeledPoint] => Boolean] = List()
+ protected val validators: Seq[RDD[LabeledPoint] => Boolean] = List()
val optimizer: Optimizer
@@ -119,7 +119,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
def run(input: RDD[LabeledPoint], initialWeights: Array[Double]) : M = {
// Check the data properties before running the optimizer
- if (!validateFuncs.forall(func => func(input))) {
+ if (!validators.forall(func => func(input))) {
throw new SparkException("Input validation failed.")
}
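
The renamed validators field holds functions of type RDD[LabeledPoint] => Boolean that must all pass before the optimizer runs. The sketch below shows what a {0, 1}-label validator of that shape could look like; the body of the real DataValidators.classificationLabels is not part of this diff, so both the implementation and the import paths (spark.RDD, spark.mllib.regression.LabeledPoint) are assumptions based on the package names visible above.

import spark.RDD
import spark.mllib.regression.LabeledPoint

// Hypothetical validator in the shape `validators` expects
// (RDD[LabeledPoint] => Boolean); the real DataValidators.classificationLabels
// implementation is not shown in this diff.
object ValidatorSketch {
  // Accept the input only if every label is exactly 0.0 or 1.0.
  val binaryLabels: RDD[LabeledPoint] => Boolean = { data =>
    data.filter(p => p.label != 0.0 && p.label != 1.0).count() == 0
  }
}

A concrete algorithm would then plug such a function in with something like override val validators = List(ValidatorSketch.binaryLabels), mirroring the one-line overrides changed in LogisticRegression.scala and SVM.scala above.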
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index f392efa405..8fa9e4639b 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -48,7 +48,7 @@ object SVMSuite {
val rnd = new Random(seed)
val weightsMat = new DoubleMatrix(1, weights.length, weights:_*)
val x = Array.fill[Array[Double]](nPoints)(
- Array.fill[Double](weights.length)(rnd.nextGaussian()))
+ Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0))
val y = x.map { xi =>
val yD = (new DoubleMatrix(1, xi.length, xi:_*)).dot(weightsMat) +
intercept + 0.01 * rnd.nextGaussian()
@@ -83,7 +83,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
test("SVM using local random SGD") {
val nPoints = 10000
- val A = 2.0
+ // NOTE: Intercept should be small for generating equal 0s and 1s
+ val A = 0.01
val B = -1.5
val C = 1.0
@@ -110,7 +111,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
test("SVM local random SGD with initial weights") {
val nPoints = 10000
- val A = 2.0
+ // NOTE: Intercept should be small for generating equal 0s and 1s
+ val A = 0.01
val B = -1.5
val C = 1.0
@@ -141,7 +143,8 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
test("SVM with invalid labels") {
val nPoints = 10000
- val A = 2.0
+ // NOTE: Intercept should be small for generating equal 0s and 1s
+ val A = 0.01
val B = -1.5
val C = 1.0
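
The test changes switch the feature distribution from Gaussian to uniform on [-1, 1) and shrink the intercept A so the generated labels come out roughly balanced between 0 and 1. Below is a sketch of such a generator; the final thresholding of yD into a {0, 1} label is not visible in these hunks, so if (yD < 0) 0.0 else 1.0, along with the function name and the scala.util.Random import, is an assumption.

import scala.util.Random
import org.jblas.DoubleMatrix

// Sketch of the kind of data generation the updated test relies on. The
// thresholding of yD into a {0, 1} label is assumed, not shown in the diff.
object SVMInputSketch {
  def generate(intercept: Double, weights: Array[Double], nPoints: Int, seed: Int)
    : Seq[(Double, Array[Double])] = {
    val rnd = new Random(seed)
    val weightsMat = new DoubleMatrix(1, weights.length, weights: _*)
    // Features uniform in [-1, 1) rather than Gaussian, as in the updated test.
    val x = Array.fill(nPoints)(Array.fill(weights.length)(rnd.nextDouble() * 2.0 - 1.0))
    val y = x.map { xi =>
      val yD = new DoubleMatrix(1, xi.length, xi: _*).dot(weightsMat) +
        intercept + 0.01 * rnd.nextGaussian()
      if (yD < 0) 0.0 else 1.0
    }
    y.zip(x)
  }
}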