about | summary | refs | log | tree | commit | diff
path: root/examples
diff options
context:
space:
mode:
author Josh Rosen <rosenville@gmail.com> 2012-08-26 15:24:43 -0700
committer Josh Rosen <rosenville@gmail.com> 2012-08-26 15:24:43 -0700
commit 566feafe1dd3f7fe36ffdb70dc4981e979824caf (patch)
tree dd74fbaf66590bf5a4b3d92589a8c6560832d3cf /examples
parent 2c16ae36d72653fdbebee57b93b9aead070a395c (diff)
download spark-566feafe1dd3f7fe36ffdb70dc4981e979824caf.tar.gz
spark-566feafe1dd3f7fe36ffdb70dc4981e979824caf.tar.bz2
spark-566feafe1dd3f7fe36ffdb70dc4981e979824caf.zip
Cache points in SparkLR example.
Diffstat (limited to 'examples')
-rw-r--r-- examples/src/main/scala/spark/examples/SparkLR.scala 4
1 files changed, 2 insertions, 2 deletions
diff --git a/examples/src/main/scala/spark/examples/SparkLR.scala b/examples/src/main/scala/spark/examples/SparkLR.scala
index 19123db738..6777b471fb 100644
--- a/examples/src/main/scala/spark/examples/SparkLR.scala
+++ b/examples/src/main/scala/spark/examples/SparkLR.scala
@@ -30,7 +30,7 @@ object SparkLR {
}
val sc = new SparkContext(args(0), "SparkLR")
val numSlices = if (args.length > 1) args(1).toInt else 2
- val data = generateData
+ val points = sc.parallelize(generateData, numSlices).cache()
// Initialize w to a random value
var w = Vector(D, _ => 2 * rand.nextDouble - 1)
@@ -38,7 +38,7 @@ object SparkLR {
for (i <- 1 to ITERATIONS) {
println("On iteration " + i)
- val gradient = sc.parallelize(data, numSlices).map { p =>
+ val gradient = points.map { p =>
(1 / (1 + exp(-p.y * (w dot p.x))) - 1) * p.y * p.x
}.reduce(_ + _)
w -= gradient