aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala')
-rw-r--r--mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala6
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
index 029f262660..ced52093f5 100644
--- a/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
+++ b/mllib/src/main/scala/spark/mllib/util/SVMDataGenerator.scala
@@ -8,6 +8,8 @@ import org.jblas.DoubleMatrix
import spark.{RDD, SparkContext}
import spark.mllib.util.MLUtils
+import org.jblas.DoubleMatrix
+
object SVMGenerator {
def main(args: Array[String]) {
@@ -27,7 +29,7 @@ object SVMGenerator {
val sc = new SparkContext(sparkMaster, "SVMGenerator")
val globalRnd = new Random(94720)
- val trueWeights = Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }
+ val trueWeights = new DoubleMatrix(1, nfeatures+1, Array.fill[Double](nfeatures + 1) { globalRnd.nextGaussian() }:_*)
val data: RDD[(Double, Array[Double])] = sc.parallelize(0 until nexamples, parts).map { idx =>
val rnd = new Random(42 + idx)
@@ -35,7 +37,7 @@ object SVMGenerator {
val x = Array.fill[Double](nfeatures) {
rnd.nextDouble() * 2.0 - 1.0
}
- val y = signum(((1.0 +: x) zip trueWeights).map{wx => wx._1 * wx._2}.reduceLeft(_+_) + rnd.nextGaussian() * 0.1)
+ val y = signum((new DoubleMatrix(1, x.length, x:_*)).dot(trueWeights) + rnd.nextGaussian() * 0.1)
(y, x)
}