aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
authorLewuathe <lewuathe@me.com>2015-10-30 02:59:05 -0700
committerDB Tsai <dbt@netflix.com>2015-10-30 02:59:05 -0700
commit86d65265fcab7edab88a7bdb10acba47da95bcb3 (patch)
treebd634fc275041e52bb056a5c58b77117c7ccc7b8 /mllib/src/main/scala/org
parenteb59b94c450fe6391d24d44ff7ea9bd4c6893af8 (diff)
downloadspark-86d65265fcab7edab88a7bdb10acba47da95bcb3.tar.gz
spark-86d65265fcab7edab88a7bdb10acba47da95bcb3.tar.bz2
spark-86d65265fcab7edab88a7bdb10acba47da95bcb3.zip
[SPARK-11207] [ML] Add test cases for solver selection of LinearRegres…
…sion as followup. This is the follow-up work of SPARK-10668. * Fix minor style issues. * Add test case for checking whether solver is selected properly. Author: Lewuathe <lewuathe@me.com> Author: lewuathe <lewuathe@me.com> Closes #9180 from Lewuathe/SPARK-11207.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala54
1 files changed, 47 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
index d0ba454f37..6ff07eed6c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala
@@ -77,13 +77,11 @@ object LinearDataGenerator {
nPoints: Int,
seed: Int,
eps: Double = 0.1): Seq[LabeledPoint] = {
- generateLinearInput(intercept, weights,
- Array.fill[Double](weights.length)(0.0),
- Array.fill[Double](weights.length)(1.0 / 3.0),
- nPoints, seed, eps)}
+ generateLinearInput(intercept, weights, Array.fill[Double](weights.length)(0.0),
+ Array.fill[Double](weights.length)(1.0 / 3.0), nPoints, seed, eps)
+ }
/**
- *
* @param intercept Data intercept
* @param weights Weights to be applied.
* @param xMean the mean of the generated features. Lots of time, if the features are not properly
@@ -104,16 +102,49 @@ object LinearDataGenerator {
nPoints: Int,
seed: Int,
eps: Double): Seq[LabeledPoint] = {
+ generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps, 0.0)
+ }
+
+ /**
+ * @param intercept Data intercept
+ * @param weights Weights to be applied.
+ * @param xMean the mean of the generated features. Often, if the features are not properly
+ * standardized, a poorly implemented algorithm will have difficulty
+ * converging.
+ * @param xVariance the variance of the generated features.
+ * @param nPoints Number of points in sample.
+ * @param seed Random seed
+ * @param eps Epsilon scaling factor.
+ * @param sparsity The ratio of zero elements; must be in [0.0, 1.0]. If it is 0.0, LabeledPoints
+ * with DenseVector are returned.
+ * @return Seq of input.
+ */
+ @Since("1.6.0")
+ def generateLinearInput(
+ intercept: Double,
+ weights: Array[Double],
+ xMean: Array[Double],
+ xVariance: Array[Double],
+ nPoints: Int,
+ seed: Int,
+ eps: Double,
+ sparsity: Double): Seq[LabeledPoint] = {
+ require(0.0 <= sparsity && sparsity <= 1.0)
val rnd = new Random(seed)
val x = Array.fill[Array[Double]](nPoints)(
Array.fill[Double](weights.length)(rnd.nextDouble()))
+ val sparseRnd = new Random(seed)
x.foreach { v =>
var i = 0
val len = v.length
while (i < len) {
- v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+ if (sparseRnd.nextDouble() < sparsity) {
+ v(i) = 0.0
+ } else {
+ v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i)
+ }
i += 1
}
}
@@ -121,7 +152,16 @@ object LinearDataGenerator {
val y = x.map { xi =>
blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian()
}
- y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
+
+ y.zip(x).map { p =>
+ if (sparsity == 0.0) {
+ // Return LabeledPoints with DenseVector
+ LabeledPoint(p._1, Vectors.dense(p._2))
+ } else {
+ // Return LabeledPoints with SparseVector
+ LabeledPoint(p._1, Vectors.dense(p._2).toSparse)
+ }
+ }
}
/**