diff options
Diffstat (limited to 'examples/src')
-rw-r--r-- | examples/src/main/java/spark/examples/JavaPageRank.java | 17 | ||||
-rw-r--r-- | examples/src/main/java/spark/mllib/JavaALS.java | 87 | ||||
-rw-r--r-- | examples/src/main/java/spark/mllib/JavaKMeans.java | 81 |
3 files changed, 176 insertions, 9 deletions
diff --git a/examples/src/main/java/spark/examples/JavaPageRank.java b/examples/src/main/java/spark/examples/JavaPageRank.java index 9d90ef9174..75df1af2e3 100644 --- a/examples/src/main/java/spark/examples/JavaPageRank.java +++ b/examples/src/main/java/spark/examples/JavaPageRank.java @@ -23,6 +23,7 @@ import spark.api.java.JavaRDD; import spark.api.java.JavaSparkContext; import spark.api.java.function.FlatMapFunction; import spark.api.java.function.Function; +import spark.api.java.function.Function2; import spark.api.java.function.PairFlatMapFunction; import spark.api.java.function.PairFunction; @@ -39,12 +40,11 @@ import java.util.ArrayList; * where URL and their neighbors are separated by space(s). */ public class JavaPageRank { - private static double sum(List<Double> numbers) { - double out = 0.0; - for (double number : numbers) { - out += number; + private static class Sum extends Function2<Double, Double, Double> { + @Override + public Double call(Double a, Double b) { + return a + b; } - return out; } public static void main(String[] args) throws Exception { @@ -91,16 +91,15 @@ public class JavaPageRank { for (String n : s._1) { results.add(new Tuple2<String, Double>(n, s._2 / s._1.size())); } - return results; } }); // Re-calculates URL ranks based on neighbor contributions. - ranks = contribs.groupByKey().mapValues(new Function<List<Double>, Double>() { + ranks = contribs.reduceByKey(new Sum()).mapValues(new Function<Double, Double>() { @Override - public Double call(List<Double> cs) throws Exception { - return 0.15 + sum(cs) * 0.85; + public Double call(Double sum) throws Exception { + return 0.15 + sum * 0.85; } }); } diff --git a/examples/src/main/java/spark/mllib/JavaALS.java b/examples/src/main/java/spark/mllib/JavaALS.java new file mode 100644 index 0000000000..b48f459cb7 --- /dev/null +++ b/examples/src/main/java/spark/mllib/JavaALS.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.mllib.examples; + +import spark.api.java.JavaRDD; +import spark.api.java.JavaSparkContext; +import spark.api.java.function.Function; + +import spark.mllib.recommendation.ALS; +import spark.mllib.recommendation.MatrixFactorizationModel; +import spark.mllib.recommendation.Rating; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.StringTokenizer; + +import scala.Tuple2; + +/** + * Example using MLLib ALS from Java. + */ +public class JavaALS { + + static class ParseRating extends Function<String, Rating> { + public Rating call(String line) { + StringTokenizer tok = new StringTokenizer(line, ","); + int x = Integer.parseInt(tok.nextToken()); + int y = Integer.parseInt(tok.nextToken()); + double rating = Double.parseDouble(tok.nextToken()); + return new Rating(x, y, rating); + } + } + + static class FeaturesToString extends Function<Tuple2<Object, double[]>, String> { + public String call(Tuple2<Object, double[]> element) { + return element._1().toString() + "," + Arrays.toString(element._2()); + } + } + + public static void main(String[] args) { + + if (args.length != 5 && args.length != 6) { + System.err.println( + "Usage: JavaALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]"); + System.exit(1); + } + + int rank = Integer.parseInt(args[2]); + int iterations = Integer.parseInt(args[3]); + String outputDir = args[4]; + int blocks = -1; + if (args.length == 6) { + blocks = Integer.parseInt(args[5]); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD<String> lines = sc.textFile(args[1]); + + JavaRDD<Rating> ratings = lines.map(new ParseRating()); + + MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks); + + model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( + outputDir + "/userFeatures"); + model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( + outputDir + "/productFeatures"); + System.out.println("Final user/product features written to " + outputDir); + + System.exit(0); + } +} diff --git a/examples/src/main/java/spark/mllib/JavaKMeans.java b/examples/src/main/java/spark/mllib/JavaKMeans.java new file mode 100644 index 0000000000..02f40438b8 --- /dev/null +++ b/examples/src/main/java/spark/mllib/JavaKMeans.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.mllib.examples; + +import spark.api.java.JavaRDD; +import spark.api.java.JavaSparkContext; +import spark.api.java.function.Function; + +import spark.mllib.clustering.KMeans; +import spark.mllib.clustering.KMeansModel; + +import java.util.Arrays; +import java.util.StringTokenizer; + +/** + * Example using MLLib KMeans from Java. + */ +public class JavaKMeans { + + static class ParsePoint extends Function<String, double[]> { + public double[] call(String line) { + StringTokenizer tok = new StringTokenizer(line, " "); + int numTokens = tok.countTokens(); + double[] point = new double[numTokens]; + for (int i = 0; i < numTokens; ++i) { + point[i] = Double.parseDouble(tok.nextToken()); + } + return point; + } + } + + public static void main(String[] args) { + + if (args.length < 4) { + System.err.println( + "Usage: JavaKMeans <master> <input_file> <k> <max_iterations> [<runs>]"); + System.exit(1); + } + + String inputFile = args[1]; + int k = Integer.parseInt(args[2]); + int iterations = Integer.parseInt(args[3]); + int runs = 1; + + if (args.length >= 5) { + runs = Integer.parseInt(args[4]); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD<String> lines = sc.textFile(args[1]); + + JavaRDD<double[]> points = lines.map(new ParsePoint()); + + KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs); + + System.out.println("Cluster centers:"); + for (double[] center : model.clusterCenters()) { + System.out.println(" " + Arrays.toString(center)); + } + double cost = model.computeCost(points.rdd()); + System.out.println("Cost: " + cost); + + System.exit(0); + } +} |