diff options
author | Lewuathe <lewuathe@me.com> | 2015-10-30 02:59:05 -0700 |
---|---|---|
committer | DB Tsai <dbt@netflix.com> | 2015-10-30 02:59:05 -0700 |
commit | 86d65265fcab7edab88a7bdb10acba47da95bcb3 (patch) | |
tree | bd634fc275041e52bb056a5c58b77117c7ccc7b8 /mllib/src/main | |
parent | eb59b94c450fe6391d24d44ff7ea9bd4c6893af8 (diff) | |
download | spark-86d65265fcab7edab88a7bdb10acba47da95bcb3.tar.gz spark-86d65265fcab7edab88a7bdb10acba47da95bcb3.tar.bz2 spark-86d65265fcab7edab88a7bdb10acba47da95bcb3.zip |
[SPARK-11207] [ML] Add test cases for solver selection of LinearRegres…
…sion as followup. This is the follow up work of SPARK-10668.
* Fix miner style issues.
* Add test case for checking whether solver is selected properly.
Author: Lewuathe <lewuathe@me.com>
Author: lewuathe <lewuathe@me.com>
Closes #9180 from Lewuathe/SPARK-11207.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala | 54 |
1 files changed, 47 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala index d0ba454f37..6ff07eed6c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/util/LinearDataGenerator.scala @@ -77,13 +77,11 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double = 0.1): Seq[LabeledPoint] = { - generateLinearInput(intercept, weights, - Array.fill[Double](weights.length)(0.0), - Array.fill[Double](weights.length)(1.0 / 3.0), - nPoints, seed, eps)} + generateLinearInput(intercept, weights, Array.fill[Double](weights.length)(0.0), + Array.fill[Double](weights.length)(1.0 / 3.0), nPoints, seed, eps) + } /** - * * @param intercept Data intercept * @param weights Weights to be applied. * @param xMean the mean of the generated features. Lots of time, if the features are not properly @@ -104,16 +102,49 @@ object LinearDataGenerator { nPoints: Int, seed: Int, eps: Double): Seq[LabeledPoint] = { + generateLinearInput(intercept, weights, xMean, xVariance, nPoints, seed, eps, 0.0) + } + + /** + * @param intercept Data intercept + * @param weights Weights to be applied. + * @param xMean the mean of the generated features. Lots of time, if the features are not properly + * standardized, the algorithm with poor implementation will have difficulty + * to converge. + * @param xVariance the variance of the generated features. + * @param nPoints Number of points in sample. + * @param seed Random seed + * @param eps Epsilon scaling factor. + * @param sparsity The ratio of zero elements. If it is 0.0, LabeledPoints with + * DenseVector is returned. + * @return Seq of input. + */ + @Since("1.6.0") + def generateLinearInput( + intercept: Double, + weights: Array[Double], + xMean: Array[Double], + xVariance: Array[Double], + nPoints: Int, + seed: Int, + eps: Double, + sparsity: Double): Seq[LabeledPoint] = { + require(0.0 <= sparsity && sparsity <= 1.0) val rnd = new Random(seed) val x = Array.fill[Array[Double]](nPoints)( Array.fill[Double](weights.length)(rnd.nextDouble())) + val sparseRnd = new Random(seed) x.foreach { v => var i = 0 val len = v.length while (i < len) { - v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + if (sparseRnd.nextDouble() < sparsity) { + v(i) = 0.0 + } else { + v(i) = (v(i) - 0.5) * math.sqrt(12.0 * xVariance(i)) + xMean(i) + } i += 1 } } @@ -121,7 +152,16 @@ object LinearDataGenerator { val y = x.map { xi => blas.ddot(weights.length, xi, 1, weights, 1) + intercept + eps * rnd.nextGaussian() } - y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) + + y.zip(x).map { p => + if (sparsity == 0.0) { + // Return LabeledPoints with DenseVector + LabeledPoint(p._1, Vectors.dense(p._2)) + } else { + // Return LabeledPoints with SparseVector + LabeledPoint(p._1, Vectors.dense(p._2).toSparse) + } + } } /** |