aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-04-02 14:01:12 -0700
committerMatei Zaharia <matei@databricks.com>2014-04-02 14:01:12 -0700
commit9c65fa76f9d413e311a80f29d35d3ff7722e9476 (patch)
treeb98bac526f3d1bb5954c187745a7e4112b0fbf05 /mllib/src/test
parented730c95026d322f4b24d3d9fe92050ffa74cf4a (diff)
downloadspark-9c65fa76f9d413e311a80f29d35d3ff7722e9476.tar.gz
spark-9c65fa76f9d413e311a80f29d35d3ff7722e9476.tar.bz2
spark-9c65fa76f9d413e311a80f29d35d3ff7722e9476.zip
[SPARK-1212, Part II] Support sparse data in MLlib
In PR https://github.com/apache/spark/pull/117, we added dense/sparse vector data model and updated KMeans to support sparse input. This PR is to replace all other `Array[Double]` usage by `Vector` in generalized linear models (GLMs) and Naive Bayes. Major changes: 1. `LabeledPoint` becomes `LabeledPoint(Double, Vector)`. 2. Methods that accept `RDD[Array[Double]]` now accept `RDD[Vector]`. We cannot support both in an elegant way because of type erasure. 3. Mark 'createModel' and 'predictPoint' protected because they are not for end users. 4. Add libSVMFile to MLContext. 5. NaiveBayes can accept arbitrary labels (introducing a breaking change to Python's `NaiveBayesModel`). 6. Gradient computation no longer creates temp vectors. 7. Column normalization and centering are removed from Lasso and Ridge because the operation will densify the data. Simple feature transformation can be done before training. TODO: 1. ~~Use axpy when possible.~~ 2. ~~Optimize Naive Bayes.~~ Author: Xiangrui Meng <meng@databricks.com> Closes #245 from mengxr/vector and squashes the following commits: eb6e793 [Xiangrui Meng] move libSVMFile to MLUtils and rename to loadLibSVMData c26c4fc [Xiangrui Meng] update DecisionTree to use RDD[Vector] 11999c7 [Xiangrui Meng] Merge branch 'master' into vector f7da54b [Xiangrui Meng] add minSplits to libSVMFile da25e24 [Xiangrui Meng] revert the change to default addIntercept because it might change the behavior of existing code without warning 493f26f [Xiangrui Meng] Merge branch 'master' into vector 7c1bc01 [Xiangrui Meng] add a TODO to NB b9b7ef7 [Xiangrui Meng] change default value of addIntercept to false b01df54 [Xiangrui Meng] allow to change or clear threshold in LR and SVM 4addc50 [Xiangrui Meng] merge master 4ca5b1b [Xiangrui Meng] remove normalization from Lasso and update tests f04fe8a [Xiangrui Meng] remove normalization from RidgeRegression and update tests d088552 [Xiangrui Meng] use static constructor for MLContext 6f59eed [Xiangrui Meng] update libSVMFile to determine number of features automatically 3432e84 [Xiangrui Meng] update NaiveBayes to support sparse data 0f8759b [Xiangrui Meng] minor updates to NB b11659c [Xiangrui Meng] style update 78c4671 [Xiangrui Meng] add libSVMFile to MLContext f0fe616 [Xiangrui Meng] add a test for sparse linear regression 44733e1 [Xiangrui Meng] use in-place gradient computation e981396 [Xiangrui Meng] use axpy in Updater db808a1 [Xiangrui Meng] update JavaLR example befa592 [Xiangrui Meng] passed scala/java tests 75c83a4 [Xiangrui Meng] passed test compile 1859701 [Xiangrui Meng] passed compile 834ada2 [Xiangrui Meng] optimized MLUtils.computeStats update some ml algorithms to use Vector (cont.) 135ab72 [Xiangrui Meng] merge glm 0e57aa4 [Xiangrui Meng] update Lasso and RidgeRegression to parse the weights correctly from GLM mark createModel protected mark predictPoint protected d7f629f [Xiangrui Meng] fix a bug in GLM when intercept is not used 3f346ba [Xiangrui Meng] update some ml algorithms to use Vector
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java13
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java3
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java6
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java4
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java38
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala7
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala14
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala51
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala54
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala27
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala9
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala59
14 files changed, 200 insertions, 99 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
index 073ded6f36..c80b1134ed 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaNaiveBayesSuite.java
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.classification;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.junit.After;
import org.junit.Assert;
@@ -45,12 +46,12 @@ public class JavaNaiveBayesSuite implements Serializable {
}
private static final List<LabeledPoint> POINTS = Arrays.asList(
- new LabeledPoint(0, new double[] {1.0, 0.0, 0.0}),
- new LabeledPoint(0, new double[] {2.0, 0.0, 0.0}),
- new LabeledPoint(1, new double[] {0.0, 1.0, 0.0}),
- new LabeledPoint(1, new double[] {0.0, 2.0, 0.0}),
- new LabeledPoint(2, new double[] {0.0, 0.0, 1.0}),
- new LabeledPoint(2, new double[] {0.0, 0.0, 2.0})
+ new LabeledPoint(0, Vectors.dense(1.0, 0.0, 0.0)),
+ new LabeledPoint(0, Vectors.dense(2.0, 0.0, 0.0)),
+ new LabeledPoint(1, Vectors.dense(0.0, 1.0, 0.0)),
+ new LabeledPoint(1, Vectors.dense(0.0, 2.0, 0.0)),
+ new LabeledPoint(2, Vectors.dense(0.0, 0.0, 1.0)),
+ new LabeledPoint(2, Vectors.dense(0.0, 0.0, 2.0))
);
private int validatePrediction(List<LabeledPoint> points, NaiveBayesModel model) {
diff --git a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
index 117e5eaa8b..4701a5e545 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/classification/JavaSVMSuite.java
@@ -17,7 +17,6 @@
package org.apache.spark.mllib.classification;
-
import java.io.Serializable;
import java.util.List;
@@ -28,7 +27,6 @@ import org.junit.Test;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-
import org.apache.spark.mllib.regression.LabeledPoint;
public class JavaSVMSuite implements Serializable {
@@ -94,5 +92,4 @@ public class JavaSVMSuite implements Serializable {
int numAccurate = validatePrediction(validationData, model);
Assert.assertTrue(numAccurate > nPoints * 4.0 / 5.0);
}
-
}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
index 2c4d795f96..c6d8425ffc 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/linalg/JavaVectorsSuite.java
@@ -19,10 +19,10 @@ package org.apache.spark.mllib.linalg;
import java.io.Serializable;
-import com.google.common.collect.Lists;
-
import scala.Tuple2;
+import com.google.common.collect.Lists;
+
import org.junit.Test;
import static org.junit.Assert.*;
@@ -36,7 +36,7 @@ public class JavaVectorsSuite implements Serializable {
@Test
public void sparseArrayConstruction() {
- Vector v = Vectors.sparse(3, Lists.newArrayList(
+ Vector v = Vectors.sparse(3, Lists.<Tuple2<Integer, Double>>newArrayList(
new Tuple2<Integer, Double>(0, 2.0),
new Tuple2<Integer, Double>(2, 3.0)));
assertArrayEquals(new double[]{2.0, 0.0, 3.0}, v.toArray(), 0.0);
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
index f44b25cd44..f725924a2d 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaLassoSuite.java
@@ -59,7 +59,7 @@ public class JavaLassoSuite implements Serializable {
@Test
public void runLassoUsingConstructor() {
int nPoints = 10000;
- double A = 2.0;
+ double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
@@ -80,7 +80,7 @@ public class JavaLassoSuite implements Serializable {
@Test
public void runLassoUsingStaticMethods() {
int nPoints = 10000;
- double A = 2.0;
+ double A = 0.0;
double[] weights = {-1.5, 1.0e-2};
JavaRDD<LabeledPoint> testRDD = sc.parallelize(LinearDataGenerator.generateLinearInputAsList(A,
diff --git a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
index 2fdd5fc8fd..03714ae7e4 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/regression/JavaRidgeRegressionSuite.java
@@ -55,30 +55,27 @@ public class JavaRidgeRegressionSuite implements Serializable {
return errorSum / validationData.size();
}
- List<LabeledPoint> generateRidgeData(int numPoints, int nfeatures, double eps) {
+ List<LabeledPoint> generateRidgeData(int numPoints, int numFeatures, double std) {
org.jblas.util.Random.seed(42);
// Pick weights as random values distributed uniformly in [-0.5, 0.5]
- DoubleMatrix w = DoubleMatrix.rand(nfeatures, 1).subi(0.5);
- // Set first two weights to eps
- w.put(0, 0, eps);
- w.put(1, 0, eps);
- return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, eps);
+ DoubleMatrix w = DoubleMatrix.rand(numFeatures, 1).subi(0.5);
+ return LinearDataGenerator.generateLinearInputAsList(0.0, w.data, numPoints, 42, std);
}
@Test
public void runRidgeRegressionUsingConstructor() {
- int nexamples = 200;
- int nfeatures = 20;
- double eps = 10.0;
- List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+ int numExamples = 50;
+ int numFeatures = 20;
+ List<LabeledPoint> data = generateRidgeData(2*numExamples, numFeatures, 10.0);
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
- List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+ List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionWithSGD ridgeSGDImpl = new RidgeRegressionWithSGD();
- ridgeSGDImpl.optimizer().setStepSize(1.0)
- .setRegParam(0.0)
- .setNumIterations(200);
+ ridgeSGDImpl.optimizer()
+ .setStepSize(1.0)
+ .setRegParam(0.0)
+ .setNumIterations(200);
RidgeRegressionModel model = ridgeSGDImpl.run(testRDD.rdd());
double unRegularizedErr = predictionError(validationData, model);
@@ -91,13 +88,12 @@ public class JavaRidgeRegressionSuite implements Serializable {
@Test
public void runRidgeRegressionUsingStaticMethods() {
- int nexamples = 200;
- int nfeatures = 20;
- double eps = 10.0;
- List<LabeledPoint> data = generateRidgeData(2*nexamples, nfeatures, eps);
+ int numExamples = 50;
+ int numFeatures = 20;
+ List<LabeledPoint> data = generateRidgeData(2 * numExamples, numFeatures, 10.0);
- JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, nexamples));
- List<LabeledPoint> validationData = data.subList(nexamples, 2*nexamples);
+ JavaRDD<LabeledPoint> testRDD = sc.parallelize(data.subList(0, numExamples));
+ List<LabeledPoint> validationData = data.subList(numExamples, 2 * numExamples);
RidgeRegressionModel model = RidgeRegressionWithSGD.train(testRDD.rdd(), 200, 1.0, 0.0);
double unRegularizedErr = predictionError(validationData, model);
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 05322b024d..1e03c9df82 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -20,11 +20,10 @@ package org.apache.spark.mllib.classification
import scala.util.Random
import scala.collection.JavaConversions._
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.apache.spark.SparkContext
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
@@ -61,7 +60,7 @@ object LogisticRegressionSuite {
if (yVal > 0) 1 else 0
}
- val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i))))
+ val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(Array(x1(i)))))
testData
}
@@ -113,7 +112,7 @@ class LogisticRegressionSuite extends FunSuite with LocalSparkContext with Shoul
val testData = LogisticRegressionSuite.generateLogisticInput(A, B, nPoints, 42)
val initialB = -1.0
- val initialWeights = Array(initialB)
+ val initialWeights = Vectors.dense(initialB)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 9dd6c79ee6..516895d042 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,9 +19,9 @@ package org.apache.spark.mllib.classification
import scala.util.Random
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.LocalSparkContext
@@ -54,7 +54,7 @@ object NaiveBayesSuite {
if (rnd.nextDouble() < _theta(y)(j)) 1 else 0
}
- LabeledPoint(y, xi)
+ LabeledPoint(y, Vectors.dense(xi))
}
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
index bc7abb568a..dfacbfeee6 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/SVMSuite.scala
@@ -20,7 +20,6 @@ package org.apache.spark.mllib.classification
import scala.util.Random
import scala.collection.JavaConversions._
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.jblas.DoubleMatrix
@@ -28,6 +27,7 @@ import org.jblas.DoubleMatrix
import org.apache.spark.SparkException
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.linalg.Vectors
object SVMSuite {
@@ -54,7 +54,7 @@ object SVMSuite {
intercept + 0.01 * rnd.nextGaussian()
if (yD < 0) 0.0 else 1.0
}
- y.zip(x).map(p => LabeledPoint(p._1, p._2))
+ y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
}
@@ -110,7 +110,7 @@ class SVMSuite extends FunSuite with LocalSparkContext {
val initialB = -1.0
val initialC = -1.0
- val initialWeights = Array(initialB,initialC)
+ val initialWeights = Vectors.dense(initialB, initialC)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
@@ -150,10 +150,10 @@ class SVMSuite extends FunSuite with LocalSparkContext {
}
intercept[SparkException] {
- val model = SVMWithSGD.train(testRDDInvalid, 100)
+ SVMWithSGD.train(testRDDInvalid, 100)
}
// Turning off data validation should not throw an exception
- val noValidationModel = new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
+ new SVMWithSGD().setValidateData(false).run(testRDDInvalid)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
index 631d0e2ad9..c4b433499a 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/optimization/GradientDescentSuite.scala
@@ -20,13 +20,12 @@ package org.apache.spark.mllib.optimization
import scala.util.Random
import scala.collection.JavaConversions._
-import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite
import org.scalatest.matchers.ShouldMatchers
-import org.apache.spark.SparkContext
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.mllib.linalg.Vectors
object GradientDescentSuite {
@@ -58,8 +57,7 @@ object GradientDescentSuite {
if (yVal > 0) 1 else 0
}
- val testData = (0 until nPoints).map(i => LabeledPoint(y(i), Array(x1(i))))
- testData
+ (0 until nPoints).map(i => LabeledPoint(y(i), Vectors.dense(x1(i))))
}
}
@@ -83,11 +81,11 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa
// Add a extra variable consisting of all 1.0's for the intercept.
val testData = GradientDescentSuite.generateGDInput(A, B, nPoints, 42)
val data = testData.map { case LabeledPoint(label, features) =>
- label -> Array(1.0, features: _*)
+ label -> Vectors.dense(1.0, features.toArray: _*)
}
val dataRDD = sc.parallelize(data, 2).cache()
- val initialWeightsWithIntercept = Array(1.0, initialWeights: _*)
+ val initialWeightsWithIntercept = Vectors.dense(1.0, initialWeights: _*)
val (_, loss) = GradientDescent.runMiniBatchSGD(
dataRDD,
@@ -113,13 +111,13 @@ class GradientDescentSuite extends FunSuite with LocalSparkContext with ShouldMa
// Add a extra variable consisting of all 1.0's for the intercept.
val testData = GradientDescentSuite.generateGDInput(2.0, -1.5, 10000, 42)
val data = testData.map { case LabeledPoint(label, features) =>
- label -> Array(1.0, features: _*)
+ label -> Vectors.dense(1.0, features.toArray: _*)
}
val dataRDD = sc.parallelize(data, 2).cache()
// Prepare non-zero weights
- val initialWeightsWithIntercept = Array(1.0, 0.5)
+ val initialWeightsWithIntercept = Vectors.dense(1.0, 0.5)
val regParam0 = 0
val (newWeights0, loss0) = GradientDescent.runMiniBatchSGD(
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
index 2cebac943e..6aad9eb84e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LassoSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
class LassoSuite extends FunSuite with LocalSparkContext {
@@ -33,29 +34,33 @@ class LassoSuite extends FunSuite with LocalSparkContext {
}
test("Lasso local random SGD") {
- val nPoints = 10000
+ val nPoints = 1000
val A = 2.0
val B = -1.5
val C = 1.0e-2
- val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)
-
- val testRDD = sc.parallelize(testData, 2)
- testRDD.cache()
+ val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
+ val testRDD = sc.parallelize(testData, 2).cache()
val ls = new LassoWithSGD()
- ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
+ ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
val model = ls.run(testRDD)
-
val weight0 = model.weights(0)
val weight1 = model.weights(1)
- assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
- assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
- assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
+ val weight2 = model.weights(2)
+ assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
+ assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
+ assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.
@@ -66,33 +71,39 @@ class LassoSuite extends FunSuite with LocalSparkContext {
}
test("Lasso local random SGD with initial weights") {
- val nPoints = 10000
+ val nPoints = 1000
val A = 2.0
val B = -1.5
val C = 1.0e-2
- val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 42)
+ val testData = LinearDataGenerator.generateLinearInput(A, Array[Double](B, C), nPoints, 42)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
+ val initialA = -1.0
val initialB = -1.0
val initialC = -1.0
- val initialWeights = Array(initialB,initialC)
+ val initialWeights = Vectors.dense(initialA, initialB, initialC)
- val testRDD = sc.parallelize(testData, 2)
- testRDD.cache()
+ val testRDD = sc.parallelize(testData, 2).cache()
val ls = new LassoWithSGD()
- ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(20)
+ ls.optimizer.setStepSize(1.0).setRegParam(0.01).setNumIterations(40)
val model = ls.run(testRDD, initialWeights)
-
val weight0 = model.weights(0)
val weight1 = model.weights(1)
- assert(model.intercept >= 1.9 && model.intercept <= 2.1, model.intercept + " not in [1.9, 2.1]")
- assert(weight0 >= -1.60 && weight0 <= -1.40, weight0 + " not in [-1.6, -1.4]")
- assert(weight1 >= -1.0e-3 && weight1 <= 1.0e-3, weight1 + " not in [-0.001, 0.001]")
+ val weight2 = model.weights(2)
+ assert(weight0 >= 1.9 && weight0 <= 2.1, weight0 + " not in [1.9, 2.1]")
+ assert(weight1 >= -1.60 && weight1 <= -1.40, weight1 + " not in [-1.6, -1.4]")
+ assert(weight2 >= -1.0e-3 && weight2 <= 1.0e-3, weight2 + " not in [-0.001, 0.001]")
val validationData = LinearDataGenerator.generateLinearInput(A, Array[Double](B,C), nPoints, 17)
+ .map { case LabeledPoint(label, features) =>
+ LabeledPoint(label, Vectors.dense(1.0 +: features.toArray))
+ }
val validationRDD = sc.parallelize(validationData,2)
// Test prediction on RDD.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
index 5d251bcbf3..2f7d30708c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/LinearRegressionSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.regression
import org.scalatest.FunSuite
+import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
class LinearRegressionSuite extends FunSuite with LocalSparkContext {
@@ -40,11 +41,12 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
val model = linReg.run(testRDD)
-
assert(model.intercept >= 2.5 && model.intercept <= 3.5)
- assert(model.weights.length === 2)
- assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
- assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
+
+ val weights = model.weights
+ assert(weights.size === 2)
+ assert(weights(0) >= 9.0 && weights(0) <= 11.0)
+ assert(weights(1) >= 9.0 && weights(1) <= 11.0)
val validationData = LinearDataGenerator.generateLinearInput(
3.0, Array(10.0, 10.0), 100, 17)
@@ -67,9 +69,11 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
val model = linReg.run(testRDD)
assert(model.intercept === 0.0)
- assert(model.weights.length === 2)
- assert(model.weights(0) >= 9.0 && model.weights(0) <= 11.0)
- assert(model.weights(1) >= 9.0 && model.weights(1) <= 11.0)
+
+ val weights = model.weights
+ assert(weights.size === 2)
+ assert(weights(0) >= 9.0 && weights(0) <= 11.0)
+ assert(weights(1) >= 9.0 && weights(1) <= 11.0)
val validationData = LinearDataGenerator.generateLinearInput(
0.0, Array(10.0, 10.0), 100, 17)
@@ -81,4 +85,40 @@ class LinearRegressionSuite extends FunSuite with LocalSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
}
+
+ // Test if we can correctly learn Y = 10*X1 + 10*X10000
+ test("sparse linear regression without intercept") {
+ val denseRDD = sc.parallelize(
+ LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 42), 2)
+ val sparseRDD = denseRDD.map { case LabeledPoint(label, v) =>
+ val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1))))
+ LabeledPoint(label, sv)
+ }.cache()
+ val linReg = new LinearRegressionWithSGD().setIntercept(false)
+ linReg.optimizer.setNumIterations(1000).setStepSize(1.0)
+
+ val model = linReg.run(sparseRDD)
+
+ assert(model.intercept === 0.0)
+
+ val weights = model.weights
+ assert(weights.size === 10000)
+ assert(weights(0) >= 9.0 && weights(0) <= 11.0)
+ assert(weights(9999) >= 9.0 && weights(9999) <= 11.0)
+
+ val validationData = LinearDataGenerator.generateLinearInput(0.0, Array(10.0, 10.0), 100, 17)
+ val sparseValidationData = validationData.map { case LabeledPoint(label, v) =>
+ val sv = Vectors.sparse(10000, Seq((0, v(0)), (9999, v(1))))
+ LabeledPoint(label, sv)
+ }
+ val sparseValidationRDD = sc.parallelize(sparseValidationData, 2)
+
+ // Test prediction on RDD.
+ validatePrediction(
+ model.predict(sparseValidationRDD.map(_.features)).collect(), sparseValidationData)
+
+ // Test prediction on Array.
+ validatePrediction(
+ sparseValidationData.map(row => model.predict(row.features)), sparseValidationData)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
index b2044ed0d8..f66fc6ea6c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/regression/RidgeRegressionSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.mllib.regression
-import org.jblas.DoubleMatrix
import org.scalatest.FunSuite
+import org.jblas.DoubleMatrix
+
import org.apache.spark.mllib.util.{LinearDataGenerator, LocalSparkContext}
class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
@@ -30,22 +31,22 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
}.reduceLeft(_ + _) / predictions.size
}
- test("regularization with skewed weights") {
- val nexamples = 200
- val nfeatures = 20
- val eps = 10
+ test("ridge regression can help avoid overfitting") {
+
+ // For small number of examples and large variance of error distribution,
+ // ridge regression should give smaller generalization error that linear regression.
+
+ val numExamples = 50
+ val numFeatures = 20
org.jblas.util.Random.seed(42)
// Pick weights as random values distributed uniformly in [-0.5, 0.5]
- val w = DoubleMatrix.rand(nfeatures, 1).subi(0.5)
- // Set first two weights to eps
- w.put(0, 0, eps)
- w.put(1, 0, eps)
+ val w = DoubleMatrix.rand(numFeatures, 1).subi(0.5)
// Use half of data for training and other half for validation
- val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2*nexamples, 42, eps)
- val testData = data.take(nexamples)
- val validationData = data.takeRight(nexamples)
+ val data = LinearDataGenerator.generateLinearInput(3.0, w.toArray, 2 * numExamples, 42, 10.0)
+ val testData = data.take(numExamples)
+ val validationData = data.takeRight(numExamples)
val testRDD = sc.parallelize(testData, 2).cache()
val validationRDD = sc.parallelize(validationData, 2).cache()
@@ -67,7 +68,7 @@ class RidgeRegressionSuite extends FunSuite with LocalSparkContext {
val ridgeErr = predictionError(
ridgeModel.predict(validationRDD.map(_.features)).collect(), validationData)
- // Ridge CV-error should be lower than linear regression
+ // Ridge validation error should be lower than linear regression.
assert(ridgeErr < linearErr,
"ridgeError (" + ridgeErr + ") was not less than linearError(" + linearErr + ")")
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 4349c7000a..350130c914 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.mllib.tree.model.Filter
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.linalg.Vectors
class DecisionTreeSuite extends FunSuite with BeforeAndAfterAll {
@@ -396,7 +397,7 @@ object DecisionTreeSuite {
def generateOrderedLabeledPointsWithLabel0(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
- val lp = new LabeledPoint(0.0,Array(i.toDouble,1000.0-i))
+ val lp = new LabeledPoint(0.0, Vectors.dense(i.toDouble, 1000.0 - i))
arr(i) = lp
}
arr
@@ -405,7 +406,7 @@ object DecisionTreeSuite {
def generateOrderedLabeledPointsWithLabel1(): Array[LabeledPoint] = {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
- val lp = new LabeledPoint(1.0,Array(i.toDouble,999.0-i))
+ val lp = new LabeledPoint(1.0, Vectors.dense(i.toDouble, 999.0 - i))
arr(i) = lp
}
arr
@@ -415,9 +416,9 @@ object DecisionTreeSuite {
val arr = new Array[LabeledPoint](1000)
for (i <- 0 until 1000){
if (i < 600){
- arr(i) = new LabeledPoint(1.0,Array(0.0,1.0))
+ arr(i) = new LabeledPoint(1.0, Vectors.dense(0.0, 1.0))
} else {
- arr(i) = new LabeledPoint(0.0,Array(1.0,0.0))
+ arr(i) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0))
}
}
arr
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
index 60f053b381..27d41c7869 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/util/MLUtilsSuite.scala
@@ -17,14 +17,20 @@
package org.apache.spark.mllib.util
+import java.io.File
+
import org.scalatest.FunSuite
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, norm => breezeNorm,
squaredDistance => breezeSquaredDistance}
+import com.google.common.base.Charsets
+import com.google.common.io.Files
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLUtils._
-class MLUtilsSuite extends FunSuite {
+class MLUtilsSuite extends FunSuite with LocalSparkContext {
test("epsilon computation") {
assert(1.0 + EPSILON > 1.0, s"EPSILON is too small: $EPSILON.")
@@ -49,4 +55,55 @@ class MLUtilsSuite extends FunSuite {
assert((fastSquaredDist2 - squaredDist) <= precision * squaredDist, s"failed with m = $m")
}
}
+
+ test("compute stats") {
+ val data = Seq.fill(3)(Seq(
+ LabeledPoint(1.0, Vectors.dense(1.0, 2.0, 3.0)),
+ LabeledPoint(0.0, Vectors.dense(3.0, 4.0, 5.0))
+ )).flatten
+ val rdd = sc.parallelize(data, 2)
+ val (meanLabel, mean, std) = MLUtils.computeStats(rdd, 3, 6)
+ assert(meanLabel === 0.5)
+ assert(mean === Vectors.dense(2.0, 3.0, 4.0))
+ assert(std === Vectors.dense(1.0, 1.0, 1.0))
+ }
+
+ test("loadLibSVMData") {
+ val lines =
+ """
+ |+1 1:1.0 3:2.0 5:3.0
+ |-1
+ |-1 2:4.0 4:5.0 6:6.0
+ """.stripMargin
+ val tempDir = Files.createTempDir()
+ val file = new File(tempDir.getPath, "part-00000")
+ Files.write(lines, file, Charsets.US_ASCII)
+ val path = tempDir.toURI.toString
+
+ val pointsWithNumFeatures = MLUtils.loadLibSVMData(sc, path, 6).collect()
+ val pointsWithoutNumFeatures = MLUtils.loadLibSVMData(sc, path).collect()
+
+ for (points <- Seq(pointsWithNumFeatures, pointsWithoutNumFeatures)) {
+ assert(points.length === 3)
+ assert(points(0).label === 1.0)
+ assert(points(0).features === Vectors.sparse(6, Seq((0, 1.0), (2, 2.0), (4, 3.0))))
+ assert(points(1).label == 0.0)
+ assert(points(1).features == Vectors.sparse(6, Seq()))
+ assert(points(2).label === 0.0)
+ assert(points(2).features === Vectors.sparse(6, Seq((1, 4.0), (3, 5.0), (5, 6.0))))
+ }
+
+ val multiclassPoints = MLUtils.loadLibSVMData(sc, path, MLUtils.multiclassLabelParser).collect()
+ assert(multiclassPoints.length === 3)
+ assert(multiclassPoints(0).label === 1.0)
+ assert(multiclassPoints(1).label === -1.0)
+ assert(multiclassPoints(2).label === -1.0)
+
+ try {
+ file.delete()
+ tempDir.delete()
+ } catch {
+ case t: Throwable =>
+ }
+ }
}