about | summary | refs | log | tree | commit | diff
path: root/examples/src/main/java
diff options
context:
space:
mode:
author Josh Rosen <rosenville@gmail.com> 2012-07-22 14:23:38 -0700
committer Josh Rosen <rosenville@gmail.com> 2012-07-22 14:40:39 -0700
commit 460da878fcefc861e77b40719c6329cc2a960de8 (patch)
tree 2538639d36da07e3afcfd8cc3afdcba03414becb /examples/src/main/java
parent 01dce3f569e0085dae2d0e4bc5c9b2bef5bd3120 (diff)
download spark-460da878fcefc861e77b40719c6329cc2a960de8.tar.gz
spark-460da878fcefc861e77b40719c6329cc2a960de8.tar.bz2
spark-460da878fcefc861e77b40719c6329cc2a960de8.zip
Improve Java API examples
- Replace JavaLR example with JavaHdfsLR example.
- Use anonymous classes in JavaWordCount; add options.
- Remove @Override annotations.
Diffstat (limited to 'examples/src/main/java')
-rw-r--r-- examples/src/main/java/spark/examples/JavaHdfsLR.java 122
-rw-r--r-- examples/src/main/java/spark/examples/JavaLR.java 127
-rw-r--r-- examples/src/main/java/spark/examples/JavaTC.java 2
-rw-r--r-- examples/src/main/java/spark/examples/JavaTest.java 38
-rw-r--r-- examples/src/main/java/spark/examples/JavaWordCount.java 52
5 files changed, 143 insertions, 198 deletions
diff --git a/examples/src/main/java/spark/examples/JavaHdfsLR.java b/examples/src/main/java/spark/examples/JavaHdfsLR.java
new file mode 100644
index 0000000000..c7a6b4405a
--- /dev/null
+++ b/examples/src/main/java/spark/examples/JavaHdfsLR.java
@@ -0,0 +1,122 @@
+package spark.examples;
+
+import scala.util.Random;
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+import spark.api.java.function.Function2;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.StringTokenizer;
+
+public class JavaHdfsLR {
+
+ static int D = 10; // Number of dimensions
+ static Random rand = new Random(42);
+
+ static class DataPoint implements Serializable {
+ public DataPoint(double[] x, double y) {
+ this.x = x;
+ this.y = y;
+ }
+
+ double[] x;
+ double y;
+ }
+
+ static class ParsePoint extends Function<String, DataPoint> {
+
+ public DataPoint apply(String line) {
+ StringTokenizer tok = new StringTokenizer(line, " ");
+ double y = Double.parseDouble(tok.nextToken());
+ double[] x = new double[D];
+ int i = 0;
+ while (i < D) {
+ x[i] = Double.parseDouble(tok.nextToken());
+ i += 1;
+ }
+ return new DataPoint(x, y);
+ }
+ }
+
+ static class VectorSum extends Function2<double[], double[], double[]> {
+
+ public double[] apply(double[] a, double[] b) {
+ double[] result = new double[D];
+ for (int j = 0; j < D; j++) {
+ result[j] = a[j] + b[j];
+ }
+ return result;
+ }
+ }
+
+ static class ComputeGradient extends Function<DataPoint, double[]> {
+
+ double[] weights;
+
+ public ComputeGradient(double[] weights) {
+ this.weights = weights;
+ }
+
+ public double[] apply(DataPoint p) {
+ double[] gradient = new double[D];
+ for (int i = 0; i < D; i++) {
+ double dot = dot(weights, p.x);
+ gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i];
+ }
+ return gradient;
+ }
+ }
+
+ public static double dot(double[] a, double[] b) {
+ double x = 0;
+ for (int i = 0; i < D; i++) {
+ x += a[i] * b[i];
+ }
+ return x;
+ }
+
+ public static void printWeights(double[] a) {
+ System.out.println(Arrays.toString(a));
+ }
+
+ public static void main(String[] args) {
+
+ if (args.length < 3) {
+ System.err.println("Usage: JavaHdfsLR <master> <file> <iters>");
+ System.exit(1);
+ }
+
+ JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR");
+ JavaRDD<String> lines = sc.textFile(args[1]);
+ JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
+ int ITERATIONS = Integer.parseInt(args[2]);
+
+ // Initialize w to a random value
+ double[] w = new double[D];
+ for (int i = 0; i < D; i++) {
+ w[i] = 2 * rand.nextDouble() - 1;
+ }
+
+ System.out.print("Initial w: ");
+ printWeights(w);
+
+ for (int i = 1; i <= ITERATIONS; i++) {
+ System.out.println("On iteration " + i);
+
+ double[] gradient = points.map(
+ new ComputeGradient(w)
+ ).reduce(new VectorSum());
+
+ for (int j = 0; j < D; j++) {
+ w[j] -= gradient[j];
+ }
+
+ }
+
+ System.out.print("Final w: ");
+ printWeights(w);
+ System.exit(0);
+ }
+}
diff --git a/examples/src/main/java/spark/examples/JavaLR.java b/examples/src/main/java/spark/examples/JavaLR.java
deleted file mode 100644
index cb6abfad5b..0000000000
--- a/examples/src/main/java/spark/examples/JavaLR.java
+++ /dev/null
@@ -1,127 +0,0 @@
-package spark.examples;
-
-import scala.util.Random;
-import spark.api.java.JavaSparkContext;
-import spark.api.java.function.Function;
-import spark.api.java.function.Function2;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-public class JavaLR {
-
- static int N = 10000; // Number of data points
- static int D = 10; // Number of dimensions
- static double R = 0.7; // Scaling factor
- static int ITERATIONS = 5;
- static Random rand = new Random(42);
-
- static class DataPoint implements Serializable {
- public DataPoint(double[] x, int y) {
- this.x = x;
- this.y = y;
- }
- double[] x;
- int y;
- }
-
- static DataPoint generatePoint(int i) {
- int y = (i % 2 == 0) ? -1 : 1;
- double[] x = new double[D];
- for (int j = 0; j < D; j++) {
- x[j] = rand.nextGaussian() + y * R;
- }
- return new DataPoint(x, y);
- }
-
- static List<DataPoint> generateData() {
- List<DataPoint> points = new ArrayList<DataPoint>(N);
- for (int i = 0; i < N; i++) {
- points.add(generatePoint(i));
- }
- return points;
- }
-
- static class VectorSum extends Function2<double[], double[], double[]> {
-
- @Override
- public double[] apply(double[] a, double[] b) {
- double[] result = new double[D];
- for (int j = 0; j < D; j++) {
- result[j] = a[j] + b[j];
- }
- return result;
- }
- }
-
- static class ComputeGradient extends Function<DataPoint, double[]> {
-
- double[] weights;
-
- public ComputeGradient(double[] weights) {
- this.weights = weights;
- }
-
- @Override
- public double[] apply(DataPoint p) {
- double[] gradient = new double[D];
- for (int i = 0; i < D; i++) {
- double dot = dot(weights, p.x);
- gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i];
- }
- return gradient;
- }
- }
-
- public static double dot(double[] a, double[] b) {
- double x = 0;
- for (int i = 0; i < D; i++) {
- x += a[i] * b[i];
- }
- return x;
- }
-
- public static void printWeights(double[] a) {
- System.out.println(Arrays.toString(a));
- }
-
- public static void main(String[] args) {
-
- if (args.length == 0) {
- System.err.println("Usage: JavaLR <host> [<slices>]");
- System.exit(1);
- }
-
- JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR");
- Integer numSlices = (args.length > 1) ? Integer.parseInt(args[1]): 2;
- List<DataPoint> data = generateData();
-
- // Initialize w to a random value
- double[] w = new double[D];
- for (int i = 0; i < D; i++) {
- w[i] = 2 * rand.nextDouble() - 1;
- }
-
- System.out.print("Initial w: ");
- printWeights(w);
-
- for (int i = 1; i <= ITERATIONS; i++) {
- System.out.println("On iteration " + i);
-
- double[] gradient = sc.parallelize(data, numSlices).map(
- new ComputeGradient(w)
- ).reduce(new VectorSum());
-
- for (int j = 0; j < D; j++) {
- w[j] -= gradient[j];
- }
-
- }
-
- System.out.print("Final w: ");
- printWeights(w);
- System.exit(0);
- }
-}
diff --git a/examples/src/main/java/spark/examples/JavaTC.java b/examples/src/main/java/spark/examples/JavaTC.java
index 7ee1c3e49c..d76bcbbe85 100644
--- a/examples/src/main/java/spark/examples/JavaTC.java
+++ b/examples/src/main/java/spark/examples/JavaTC.java
@@ -31,7 +31,7 @@ public class JavaTC {
static class ProjectFn extends PairFunction<Tuple2<Integer, Tuple2<Integer, Integer>>,
Integer, Integer> {
static ProjectFn INSTANCE = new ProjectFn();
- @Override
+
public Tuple2<Integer, Integer> apply(Tuple2<Integer, Tuple2<Integer, Integer>> triple) {
return new Tuple2<Integer, Integer>(triple._2()._2(), triple._2()._1());
}
diff --git a/examples/src/main/java/spark/examples/JavaTest.java b/examples/src/main/java/spark/examples/JavaTest.java
deleted file mode 100644
index d45795a8e3..0000000000
--- a/examples/src/main/java/spark/examples/JavaTest.java
+++ /dev/null
@@ -1,38 +0,0 @@
-package spark.examples;
-
-import spark.api.java.JavaDoubleRDD;
-import spark.api.java.JavaRDD;
-import spark.api.java.JavaSparkContext;
-import spark.api.java.function.DoubleFunction;
-
-import java.util.List;
-
-public class JavaTest {
-
- public static class MapFunction extends DoubleFunction<String> {
- @Override
- public Double apply(String s) {
- return java.lang.Double.parseDouble(s);
- }
- }
-
- public static void main(String[] args) throws Exception {
-
- JavaSparkContext ctx = new JavaSparkContext("local", "JavaTest");
- JavaRDD<String> lines = ctx.textFile("numbers.txt", 1).cache();
- List<String> lineArr = lines.collect();
-
- for (String line : lineArr) {
- System.out.println(line);
- }
-
- JavaDoubleRDD data = lines.map(new MapFunction()).cache();
-
- System.out.println("output");
- List<Double> output = data.collect();
- for (Double num : output) {
- System.out.println(num);
- }
- System.exit(0);
- }
-}
diff --git a/examples/src/main/java/spark/examples/JavaWordCount.java b/examples/src/main/java/spark/examples/JavaWordCount.java
index b7901d2921..5164dfdd1d 100644
--- a/examples/src/main/java/spark/examples/JavaWordCount.java
+++ b/examples/src/main/java/spark/examples/JavaWordCount.java
@@ -14,43 +14,31 @@ import java.util.List;
public class JavaWordCount {
- public static class SplitFunction extends FlatMapFunction<String, String> {
- @Override
- public Iterable<String> apply(String s) {
- StringOps op = new StringOps(s);
- return Arrays.asList(op.split(' '));
- }
- }
-
- public static class MapFunction extends PairFunction<String, String, Integer> {
- @Override
- public Tuple2<String, Integer> apply(String s) {
- return new Tuple2(s, 1);
- }
- }
-
- public static class ReduceFunction extends Function2<Integer, Integer, Integer> {
- @Override
- public Integer apply(Integer i1, Integer i2) {
- return i1 + i2;
- }
- }
public static void main(String[] args) throws Exception {
- JavaSparkContext ctx = new JavaSparkContext("local", "JavaWordCount");
- JavaRDD<String> lines = ctx.textFile("numbers.txt", 1).cache();
- List<String> lineArr = lines.collect();
- for (String line : lineArr) {
- System.out.println(line);
+ if (args.length < 2) {
+ System.err.println("Usage: JavaWordCount <master> <file>");
+ System.exit(1);
}
- JavaRDD<String> words = lines.flatMap(new SplitFunction());
-
- JavaPairRDD<String, Integer> splits = words.map(new MapFunction());
+ JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaWordCount");
+ JavaRDD<String> lines = ctx.textFile(args[1], 1);
+
+ JavaPairRDD<String, Integer> counts = lines.flatMap(new FlatMapFunction<String, String>() {
+ public Iterable<String> apply(String s) {
+ StringOps op = new StringOps(s);
+ return Arrays.asList(op.split(' '));
+ }
+ }).map(new PairFunction<String, String, Integer>() {
+ public Tuple2<String, Integer> apply(String s) {
+ return new Tuple2(s, 1);
+ }
+ }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+ public Integer apply(Integer i1, Integer i2) {
+ return i1 + i2;
+ }
+ });
- JavaPairRDD<String, Integer> counts = splits.reduceByKey(new ReduceFunction());
-
- System.out.println("output");
List<Tuple2<String, Integer>> output = counts.collect();
for (Tuple2 tuple : output) {
System.out.print(tuple._1 + ": ");