aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--examples/src/main/scala/spark/examples/SparkLR.scala4
1 files changed, 2 insertions, 2 deletions
diff --git a/examples/src/main/scala/spark/examples/SparkLR.scala b/examples/src/main/scala/spark/examples/SparkLR.scala
index 19123db738..6777b471fb 100644
--- a/examples/src/main/scala/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/spark/examples/SparkLR.scala
@@ -30,7 +30,7 @@ object SparkLR {
}
val sc = new SparkContext(args(0), "SparkLR")
val numSlices = if (args.length > 1) args(1).toInt else 2
- val data = generateData
+ val points = sc.parallelize(generateData, numSlices).cache()
// Initialize w to a random value
var w = Vector(D, _ => 2 * rand.nextDouble - 1)
@@ -38,7 +38,7 @@ object SparkLR {
for (i <- 1 to ITERATIONS) {
println("On iteration " + i)
- val gradient = sc.parallelize(data, numSlices).map { p =>
+ val gradient = points.map { p =>
(1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x
}.reduce(_ + _)
w -= gradient