From 460da878fcefc861e77b40719c6329cc2a960de8 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Sun, 22 Jul 2012 14:23:38 -0700
Subject: Improve Java API examples

- Replace JavaLR example with JavaHdfsLR example.
- Use anonymous classes in JavaWordCount; add options.
- Remove @Override annotations.
---
 .../src/main/java/spark/examples/JavaHdfsLR.java   | 122 ++++++++++++++++++++
 examples/src/main/java/spark/examples/JavaLR.java  | 127 ---------------------
 examples/src/main/java/spark/examples/JavaTC.java  |   2 +-
 .../src/main/java/spark/examples/JavaTest.java     |  38 ------
 .../main/java/spark/examples/JavaWordCount.java    |  52 ++++-----
 5 files changed, 143 insertions(+), 198 deletions(-)
 create mode 100644 examples/src/main/java/spark/examples/JavaHdfsLR.java
 delete mode 100644 examples/src/main/java/spark/examples/JavaLR.java
 delete mode 100644 examples/src/main/java/spark/examples/JavaTest.java

(limited to 'examples/src/main/java')

diff --git a/examples/src/main/java/spark/examples/JavaHdfsLR.java b/examples/src/main/java/spark/examples/JavaHdfsLR.java
new file mode 100644
index 0000000000..c7a6b4405a
--- /dev/null
+++ b/examples/src/main/java/spark/examples/JavaHdfsLR.java
@@ -0,0 +1,122 @@
+package spark.examples;
+
+import scala.util.Random;
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+import spark.api.java.function.Function2;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.StringTokenizer;
+
+public class JavaHdfsLR {
+
+  static int D = 10;   // Number of dimensions
+  static Random rand = new Random(42);
+
+  static class DataPoint implements Serializable {
+    public DataPoint(double[] x, double y) {
+      this.x = x;
+      this.y = y;
+    }
+
+    double[] x;
+    double y;
+  }
+
+  static class ParsePoint extends Function<String, DataPoint> {
+
+    public DataPoint apply(String line) {
+      StringTokenizer tok = new StringTokenizer(line, " ");
+      double y = Double.parseDouble(tok.nextToken());
+      double[] x = new double[D];
+      int i = 0;
+      while (i < D) {
+        x[i] = Double.parseDouble(tok.nextToken());
+        i += 1;
+      }
+      return new DataPoint(x, y);
+    }
+  }
+
+  static class VectorSum extends Function2<double[], double[], double[]> {
+
+    public double[] apply(double[] a, double[] b) {
+      double[] result = new double[D];
+      for (int j = 0; j < D; j++) {
+        result[j] = a[j] + b[j];
+      }
+      return result;
+    }
+  }
+
+  static class ComputeGradient extends Function<DataPoint, double[]> {
+
+    double[] weights;
+
+    public ComputeGradient(double[] weights) {
+      this.weights = weights;
+    }
+
+    public double[] apply(DataPoint p) {
+      double[] gradient = new double[D];
+      for (int i = 0; i < D; i++) {
+        double dot = dot(weights, p.x);
+        gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i];
+      }
+      return gradient;
+    }
+  }
+
+  public static double dot(double[] a, double[] b) {
+    double x = 0;
+    for (int i = 0; i < D; i++) {
+      x += a[i] * b[i];
+    }
+    return x;
+  }
+
+  public static void printWeights(double[] a) {
+    System.out.println(Arrays.toString(a));
+  }
+
+  public static void main(String[] args) {
+
+    if (args.length < 3) {
+      System.err.println("Usage: JavaHdfsLR <master> <file> <iters>");
+      System.exit(1);
+    }
+
+    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaHdfsLR");
+    JavaRDD<String> lines = sc.textFile(args[1]);
+    JavaRDD<DataPoint> points = lines.map(new ParsePoint()).cache();
+    int ITERATIONS = Integer.parseInt(args[2]);
+
+    // Initialize w to a random value
+    double[] w = new double[D];
+    for (int i = 0; i < D; i++) {
+      w[i] = 2 * rand.nextDouble() - 1;
+    }
+
+    System.out.print("Initial w: ");
+    printWeights(w);
+
+    for (int i = 1; i <= ITERATIONS; i++) {
+      System.out.println("On iteration " + i);
+
+      double[] gradient = points.map(
+        new ComputeGradient(w)
+      ).reduce(new VectorSum());
+
+      for (int j = 0; j < D; j++) {
+        w[j] -= gradient[j];
+      }
+
+    }
+
+    System.out.print("Final w: ");
+    printWeights(w);
+    System.exit(0);
+  }
}
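A note on the math in JavaHdfsLR: the dense line in ComputeGradient.apply is the per-point gradient of the logistic loss l(w; x, y) = ln(1 + e^{-y w.x}); in LaTeX, with w the weight vector and (x, y) one parsed DataPoint:

    \[
      \nabla_w \ell(w; x, y) = \left( \frac{1}{1 + e^{-y\,w^{\top}x}} - 1 \right) y\,x
    \]

VectorSum reduces these per-point gradients to their sum, and each iteration takes a unit step, w[j] -= gradient[j]; there is no explicit learning rate. ParsePoint expects each input line to hold the label y followed by D = 10 space-separated feature values.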
diff --git a/examples/src/main/java/spark/examples/JavaLR.java b/examples/src/main/java/spark/examples/JavaLR.java
deleted file mode 100644
index cb6abfad5b..0000000000
--- a/examples/src/main/java/spark/examples/JavaLR.java
+++ /dev/null
@@ -1,127 +0,0 @@
-package spark.examples;
-
-import scala.util.Random;
-import spark.api.java.JavaSparkContext;
-import spark.api.java.function.Function;
-import spark.api.java.function.Function2;
-
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
-
-public class JavaLR {
-
-  static int N = 10000;  // Number of data points
-  static int D = 10;     // Number of dimensions
-  static double R = 0.7; // Scaling factor
-  static int ITERATIONS = 5;
-  static Random rand = new Random(42);
-
-  static class DataPoint implements Serializable {
-    public DataPoint(double[] x, int y) {
-      this.x = x;
-      this.y = y;
-    }
-    double[] x;
-    int y;
-  }
-
-  static DataPoint generatePoint(int i) {
-    int y = (i % 2 == 0) ? -1 : 1;
-    double[] x = new double[D];
-    for (int j = 0; j < D; j++) {
-      x[j] = rand.nextGaussian() + y * R;
-    }
-    return new DataPoint(x, y);
-  }
-
-  static List<DataPoint> generateData() {
-    List<DataPoint> points = new ArrayList<DataPoint>(N);
-    for (int i = 0; i < N; i++) {
-      points.add(generatePoint(i));
-    }
-    return points;
-  }
-
-  static class VectorSum extends Function2<double[], double[], double[]> {
-
-    @Override
-    public double[] apply(double[] a, double[] b) {
-      double[] result = new double[D];
-      for (int j = 0; j < D; j++) {
-        result[j] = a[j] + b[j];
-      }
-      return result;
-    }
-  }
-
-  static class ComputeGradient extends Function<DataPoint, double[]> {
-
-    double[] weights;
-
-    public ComputeGradient(double[] weights) {
-      this.weights = weights;
-    }
-
-    @Override
-    public double[] apply(DataPoint p) {
-      double[] gradient = new double[D];
-      for (int i = 0; i < D; i++) {
-        double dot = dot(weights, p.x);
-        gradient[i] = (1 / (1 + Math.exp(-p.y * dot)) - 1) * p.y * p.x[i];
-      }
-      return gradient;
-    }
-  }
-
-  public static double dot(double[] a, double[] b) {
-    double x = 0;
-    for (int i = 0; i < D; i++) {
-      x += a[i] * b[i];
-    }
-    return x;
-  }
-
-  public static void printWeights(double[] a) {
-    System.out.println(Arrays.toString(a));
-  }
-
-  public static void main(String[] args) {
-
-    if (args.length == 0) {
-      System.err.println("Usage: JavaLR <host> [<slices>]");
-      System.exit(1);
-    }
-
-    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaLR");
-    Integer numSlices = (args.length > 1) ? Integer.parseInt(args[1]): 2;
-    List<DataPoint> data = generateData();
-
-    // Initialize w to a random value
-    double[] w = new double[D];
-    for (int i = 0; i < D; i++) {
-      w[i] = 2 * rand.nextDouble() - 1;
-    }
-
-    System.out.print("Initial w: ");
-    printWeights(w);
-
-    for (int i = 1; i <= ITERATIONS; i++) {
-      System.out.println("On iteration " + i);
-
-      double[] gradient = sc.parallelize(data, numSlices).map(
-        new ComputeGradient(w)
-      ).reduce(new VectorSum());
-
-      for (int j = 0; j < D; j++) {
-        w[j] -= gradient[j];
-      }
-
-    }
-
-    System.out.print("Final w: ");
-    printWeights(w);
-    System.exit(0);
-  }
-}
diff --git a/examples/src/main/java/spark/examples/JavaTC.java b/examples/src/main/java/spark/examples/JavaTC.java
index 7ee1c3e49c..d76bcbbe85 100644
--- a/examples/src/main/java/spark/examples/JavaTC.java
+++ b/examples/src/main/java/spark/examples/JavaTC.java
@@ -31,7 +31,7 @@ public class JavaTC {
   static class ProjectFn extends PairFunction<Tuple2<Integer, Tuple2<Integer, Integer>>, Integer, Integer> {
     static ProjectFn INSTANCE = new ProjectFn();
 
-    @Override
+
     public Tuple2<Integer, Integer> apply(Tuple2<Integer, Tuple2<Integer, Integer>> triple) {
       return new Tuple2(triple._2()._2(), triple._2()._1());
     }
diff --git a/examples/src/main/java/spark/examples/JavaTest.java b/examples/src/main/java/spark/examples/JavaTest.java
deleted file mode 100644
index d45795a8e3..0000000000
--- a/examples/src/main/java/spark/examples/JavaTest.java
+++ /dev/null
@@ -1,38 +0,0 @@
-package spark.examples;
-
-import spark.api.java.JavaDoubleRDD;
-import spark.api.java.JavaRDD;
-import spark.api.java.JavaSparkContext;
-import spark.api.java.function.DoubleFunction;
-
-import java.util.List;
-
-public class JavaTest {
-
-  public static class MapFunction extends DoubleFunction<String> {
-    @Override
-    public Double apply(String s) {
-      return java.lang.Double.parseDouble(s);
-    }
-  }
-
-  public static void main(String[] args) throws Exception {
-
-    JavaSparkContext ctx = new JavaSparkContext("local", "JavaTest");
-    JavaRDD<String> lines = ctx.textFile("numbers.txt", 1).cache();
-    List<String> lineArr = lines.collect();
-
-    for (String line : lineArr) {
-      System.out.println(line);
-    }
-
-    JavaDoubleRDD data = lines.map(new MapFunction()).cache();
-
-    System.out.println("output");
-    List<Double> output = data.collect();
-    for (Double num : output) {
-      System.out.println(num);
-    }
-    System.exit(0);
-  }
-}
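The JavaWordCount diff below replaces the named Function subclasses with anonymous classes defined at the call site, per the commit message. A minimal sketch of the pattern, reusing the Function2 base class from spark.api.java.function (Sum and pairs are hypothetical names, not part of this patch):

    // Before: a named top-level class, declared away from its only call site.
    static class Sum extends Function2<Integer, Integer, Integer> {
      public Integer apply(Integer a, Integer b) { return a + b; }
    }
    JavaPairRDD<String, Integer> counts = pairs.reduceByKey(new Sum());

    // After: the same logic inlined as an anonymous class where it is used.
    JavaPairRDD<String, Integer> counts = pairs.reduceByKey(
        new Function2<Integer, Integer, Integer>() {
          public Integer apply(Integer a, Integer b) { return a + b; }
        });

Inlining keeps each short function next to the transformation that uses it, at the cost of restating the generic type parameters at every call site.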
diff --git a/examples/src/main/java/spark/examples/JavaWordCount.java b/examples/src/main/java/spark/examples/JavaWordCount.java
index b7901d2921..5164dfdd1d 100644
--- a/examples/src/main/java/spark/examples/JavaWordCount.java
+++ b/examples/src/main/java/spark/examples/JavaWordCount.java
@@ -14,43 +14,31 @@ import java.util.List;
 
 public class JavaWordCount {
 
-  public static class SplitFunction extends FlatMapFunction<String, String> {
-    @Override
-    public Iterable<String> apply(String s) {
-      StringOps op = new StringOps(s);
-      return Arrays.asList(op.split(' '));
-    }
-  }
-
-  public static class MapFunction extends PairFunction<String, String, Integer> {
-    @Override
-    public Tuple2<String, Integer> apply(String s) {
-      return new Tuple2(s, 1);
-    }
-  }
-
-  public static class ReduceFunction extends Function2<Integer, Integer, Integer> {
-    @Override
-    public Integer apply(Integer i1, Integer i2) {
-      return i1 + i2;
-    }
-  }
 
   public static void main(String[] args) throws Exception {
-    JavaSparkContext ctx = new JavaSparkContext("local", "JavaWordCount");
-    JavaRDD<String> lines = ctx.textFile("numbers.txt", 1).cache();
-    List<String> lineArr = lines.collect();
-    for (String line : lineArr) {
-      System.out.println(line);
+    if (args.length < 2) {
+      System.err.println("Usage: JavaWordCount <master> <file>");
+      System.exit(1);
     }
-
-    JavaRDD<String> words = lines.flatMap(new SplitFunction());
-
-    JavaPairRDD<String, Integer> splits = words.map(new MapFunction());
+    JavaSparkContext ctx = new JavaSparkContext(args[0], "JavaWordCount");
+    JavaRDD<String> lines = ctx.textFile(args[1], 1);
+
+    JavaPairRDD<String, Integer> counts = lines.flatMap(new FlatMapFunction<String, String>() {
+      public Iterable<String> apply(String s) {
+        StringOps op = new StringOps(s);
+        return Arrays.asList(op.split(' '));
+      }
+    }).map(new PairFunction<String, String, Integer>() {
+      public Tuple2<String, Integer> apply(String s) {
+        return new Tuple2(s, 1);
+      }
+    }).reduceByKey(new Function2<Integer, Integer, Integer>() {
+      public Integer apply(Integer i1, Integer i2) {
+        return i1 + i2;
+      }
+    });
 
-    JavaPairRDD<String, Integer> counts = splits.reduceByKey(new ReduceFunction());
-
-    System.out.println("output");
     List<Tuple2<String, Integer>> output = counts.collect();
     for (Tuple2 tuple : output) {
       System.out.print(tuple._1 + ": ");
       System.out.println(tuple._2);
     }
     System.exit(0);
   }
 }
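To try the updated word count (a sketch, assuming the run launcher script that early Spark shipped at the repository root; local[2] and input.txt are placeholder arguments):

    ./run spark.examples.JavaWordCount local[2] input.txt

Each collected pair is printed as the word followed by its count, one per line.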