author     Xinghao <pxinghao@gmail.com>  2013-07-28 21:09:56 -0700
committer  Xinghao <pxinghao@gmail.com>  2013-07-28 21:09:56 -0700
commit     67de051bbb81096dc37ea6f92a82a9224b4af61e (patch)
tree       712a9a11dfc1398515e5128397e83fc41cac2fc3 /mllib
parent     29e042940ac79e42e2f8818ceda6a962a76948ac (diff)
SVMSuite and LassoSuite rewritten to closely follow LogisticRegressionSuite
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala  99
-rw-r--r--  mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala    97
2 files changed, 161 insertions(+), 35 deletions(-)
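
Both rewritten suites share the same structure: generate seeded synthetic data, train on a cached RDD, then check predictions on a separately seeded validation set, both through the RDD API and point by point. As a plain-Scala illustration of the shared validation rule (a standalone sketch, not part of the patch): a prediction counts as "off" when it differs from the expected label by more than 0.5, and the assertion tolerates at most 20% off predictions.

    object ValidatePredictionSketch {
      // Mirrors the validatePrediction helper added in both suites below.
      def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
        val numOffPredictions = predictions.zip(input).count { case (prediction, (expected, _)) =>
          // A prediction is off if it is more than 0.5 away from the expected value.
          math.abs(prediction - expected) > 0.5
        }
        // At least 80% of the predictions should be on.
        assert(numOffPredictions < input.length / 5,
          numOffPredictions + " of " + input.length + " predictions were off")
      }

      def main(args: Array[String]) {
        // Hypothetical toy data: ten points whose predictions match their labels exactly.
        val input = Seq.fill(10)((1.0, Array(0.0)))
        validatePrediction(input.map(_._1), input)
        println("validation passed")
      }
    }
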
diff --git a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
index 0d781c310c..2a23825acc 100644
--- a/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/classification/SVMSuite.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package spark.mllib.classification
import scala.util.Random
@@ -7,7 +24,6 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
-import spark.SparkContext._
import java.io._
@@ -19,43 +35,82 @@ class SVMSuite extends FunSuite with BeforeAndAfterAll {
System.clearProperty("spark.driver.port")
}
+ // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
+ def generateSVMInput(
+ intercept: Double,
+ weights: Array[Double],
+ nPoints: Int,
+ seed: Int): Seq[(Double, Array[Double])] = {
+ val rnd = new Random(seed)
+ val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian()))
+ val y = x.map(xi =>
+ signum((xi zip weights).map(xw => xw._1*xw._2).reduce(_+_) + intercept + 0.1 * rnd.nextGaussian())
+ )
+ y zip x
+ }
+
+ def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
+ val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
+ // A prediction is off if the prediction is more than 0.5 away from expected value.
+ math.abs(prediction - expected) > 0.5
+ }.size
+ // At least 80% of the predictions should be on.
+ assert(numOffPredictions < input.length / 5)
+ }
+
test("SVMLocalRandomSGD") {
val nPoints = 10000
- val rnd = new Random(42)
- val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
- val x2 = Array.fill[Double](nPoints)(rnd.nextGaussian())
+ val A = 2.0
+ val B = -1.5
+ val C = 1.0
+
+ val testData = generateSVMInput(A, Array[Double](B,C), nPoints, 42)
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ val svm = new SVMLocalRandomSGD().setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
+
+ val model = svm.train(testRDD)
+
+ val validationData = generateSVMInput(A, Array[Double](B,C), nPoints, 17)
+ val validationRDD = sc.parallelize(validationData,2)
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
+ }
+
+ test("SVMLocalRandomSGD with initial weights") {
+ val nPoints = 10000
val A = 2.0
val B = -1.5
val C = 1.0
- val y = (0 until nPoints).map { i =>
- signum(A + B * x1(i) + C * x2(i) + 0.0*rnd.nextGaussian())
- }
+ val testData = generateSVMInput(A, Array[Double](B,C), nPoints, 42)
- val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i),x2(i)))).toArray
+ val initialB = -1.0
+ val initialC = -1.0
+ val initialWeights = Array(initialB,initialC)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val writer_data = new PrintWriter(new File("svmtest.dat"))
- testData.foreach(yx => {
- writer_data.write(yx._1 + "")
- yx._2.foreach(xi => writer_data.write("\t" + xi))
- writer_data.write("\n")})
- writer_data.close()
+ val svm = new SVMLocalRandomSGD().setStepSize(1.0).setRegParam(1.0).setNumIterations(100)
- val svm = new SVMLocalRandomSGD().setStepSize(1.0)
- .setRegParam(1.0)
- .setNumIterations(100)
-
- val model = svm.train(testRDD)
+ val model = svm.train(testRDD, initialWeights)
- val yPredict = (0 until nPoints).map(i => model.predict(Array(x1(i),x2(i))))
+ val validationData = generateSVMInput(A, Array[Double](B,C), nPoints, 17)
+ val validationRDD = sc.parallelize(validationData,2)
- val accuracy = ((y zip yPredict).map(yy => if (yy._1==yy._2) 1 else 0).reduceLeft(_+_).toDouble / nPoints.toDouble)
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
- assert(accuracy >= 0.90, "Accuracy (" + accuracy + ") too low")
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
}
}
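
For reference, the new generateSVMInput helper can be exercised on its own without Spark. The following standalone sketch (assuming only the Scala standard library; names copied from the patch) shows how each point gets a Gaussian feature vector and a label of signum(x . weights + intercept + 0.1 * Gaussian noise):

    import scala.math.signum
    import scala.util.Random

    object GenerateSVMInputSketch {
      // Same logic as the helper added to SVMSuite above.
      def generateSVMInput(intercept: Double, weights: Array[Double],
          nPoints: Int, seed: Int): Seq[(Double, Array[Double])] = {
        val rnd = new Random(seed)
        val x = Array.fill(nPoints)(Array.fill(weights.length)(rnd.nextGaussian()))
        val y = x.map { xi =>
          signum(xi.zip(weights).map(xw => xw._1 * xw._2).sum + intercept + 0.1 * rnd.nextGaussian())
        }
        y zip x
      }

      def main(args: Array[String]) {
        // Same parameters as the test: intercept A = 2.0, weights (B, C) = (-1.5, 1.0).
        generateSVMInput(2.0, Array(-1.5, 1.0), 5, 42).foreach { case (label, features) =>
          println(label + "\t" + features.mkString(", "))
        }
      }
    }
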
diff --git a/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala
index 0c39e1e09b..33e87dfd9f 100644
--- a/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/spark/mllib/regression/LassoSuite.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package spark.mllib.regression
import scala.util.Random
@@ -6,7 +23,6 @@ import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import spark.SparkContext
-import spark.SparkContext._
class LassoSuite extends FunSuite with BeforeAndAfterAll {
@@ -17,35 +33,90 @@ class LassoSuite extends FunSuite with BeforeAndAfterAll {
System.clearProperty("spark.driver.port")
}
+ // Generate noisy input of the form Y = x.dot(weights) + intercept + noise
+ def generateLassoInput(
+ intercept: Double,
+ weights: Array[Double],
+ nPoints: Int,
+ seed: Int): Seq[(Double, Array[Double])] = {
+ val rnd = new Random(seed)
+ val x = Array.fill[Array[Double]](nPoints)(Array.fill[Double](weights.length)(rnd.nextGaussian()))
+ val y = x.map(xi => (xi zip weights).map(xw => xw._1*xw._2).reduce(_+_) + intercept + 0.1 * rnd.nextGaussian())
+ y zip x
+ }
+
+ def validatePrediction(predictions: Seq[Double], input: Seq[(Double, Array[Double])]) {
+ val numOffPredictions = predictions.zip(input).filter { case (prediction, (expected, _)) =>
+ // A prediction is off if the prediction is more than 0.5 away from expected value.
+ math.abs(prediction - expected) > 0.5
+ }.size
+ // At least 80% of the predictions should be on.
+ assert(numOffPredictions < input.length / 5)
+ }
+
test("LassoLocalRandomSGD") {
val nPoints = 10000
- val rnd = new Random(42)
-
- val x1 = Array.fill[Double](nPoints)(rnd.nextGaussian())
- val x2 = Array.fill[Double](nPoints)(rnd.nextGaussian())
val A = 2.0
val B = -1.5
val C = 1.0e-2
- val y = (0 until nPoints).map { i =>
- A + B * x1(i) + C * x2(i) + 0.1*rnd.nextGaussian()
- }
-
- val testData = (0 until nPoints).map(i => (y(i).toDouble, Array(x1(i),x2(i)))).toArray
+ val testData = generateLassoInput(A, Array[Double](B,C), nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val ls = new LassoLocalRandomSGD().setStepSize(1.0)
- .setRegParam(0.01)
- .setNumIterations(20)
+ val ls = new LassoLocalRandomSGD().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
val model = ls.train(testRDD)
val weight0 = model.weights(0)
val weight1 = model.weights(1)
+ assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
+
+ val validationData = generateLassoInput(A, Array[Double](B,C), nPoints, 17)
+ val validationRDD = sc.parallelize(validationData,2)
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
+ }
+
+ test("LassoLocalRandomSGD with initial weights") {
+ val nPoints = 10000
+
+ val A = 2.0
+ val B = -1.5
+ val C = 1.0e-2
+
+ val testData = generateLassoInput(A, Array[Double](B,C), nPoints, 42)
+
+ val initialB = -1.0
+ val initialC = -1.0
+ val initialWeights = Array(initialB,initialC)
+
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+ val ls = new LassoLocalRandomSGD().setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
+
+ val model = ls.train(testRDD, initialWeights)
+
+ val weight0 = model.weights(0)
+ val weight1 = model.weights(1)
assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
+ assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
+ assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
+
+ val validationData = generateLassoInput(A, Array[Double](B,C), nPoints, 17)
+ val validationRDD = sc.parallelize(validationData,2)
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_._2)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row._2)), validationData)
}
}
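
Putting it together, here is a hedged end-to-end sketch of the pattern both suites now follow, assuming the Spark 0.7-era API visible in this patch (spark.SparkContext with a "local" master, and LassoLocalRandomSGD.train as used above): generate seeded data, train on a cached RDD, and inspect the fitted parameters. With regParam = 0.01, the assertions above expect the L1 penalty to recover the large weight B = -1.5 while shrinking the tiny weight C = 0.01 to within [-0.001, 0.001].

    import scala.util.Random
    import spark.SparkContext
    import spark.mllib.regression.LassoLocalRandomSGD

    object LassoPatternSketch {
      // Same generator as the helper added to LassoSuite above.
      def generateLassoInput(intercept: Double, weights: Array[Double],
          nPoints: Int, seed: Int): Seq[(Double, Array[Double])] = {
        val rnd = new Random(seed)
        val x = Array.fill(nPoints)(Array.fill(weights.length)(rnd.nextGaussian()))
        val y = x.map(xi =>
          xi.zip(weights).map(xw => xw._1 * xw._2).sum + intercept + 0.1 * rnd.nextGaussian())
        y zip x
      }

      def main(args: Array[String]) {
        val sc = new SparkContext("local", "LassoPatternSketch")
        val testData = generateLassoInput(2.0, Array(-1.5, 1.0e-2), 10000, 42)
        val testRDD = sc.parallelize(testData, 2).cache()
        val model = new LassoLocalRandomSGD()
          .setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
          .train(testRDD)
        // Expect intercept near 2.0, weights(0) near -1.5, weights(1) near 0.
        println("intercept = " + model.intercept + ", weights = " + model.weights.mkString(", "))
        sc.stop()
      }
    }
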