diff options
author | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-06 15:43:46 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@eecs.berkeley.edu> | 2013-08-06 15:43:46 -0700 |
commit | 471fbadd0c8cb8d310e3e1dd0e694e357ff1233e (patch) | |
tree | c8710554a0ce04e87873308540757eadfcbbd244 /examples | |
parent | d2b0f0c23d9ccd5e8a23450e421503d3201f3450 (diff) | |
download | spark-471fbadd0c8cb8d310e3e1dd0e694e357ff1233e.tar.gz spark-471fbadd0c8cb8d310e3e1dd0e694e357ff1233e.tar.bz2 spark-471fbadd0c8cb8d310e3e1dd0e694e357ff1233e.zip |
Java examples, tests for KMeans and ALS
- Changes ALS to accept RDD[Rating] instead of (Int, Int, Double) making it
easier to call from Java
- Renames class methods from `train` to `run` to enable static methods to be
called from Java.
- Add unit tests which check if both static / class methods can be called.
- Also add examples which port the main() function in ALS, KMeans to the
examples project.
Couple of minor changes to existing code:
- Add a toJavaRDD method in RDD to convert scala RDD to java RDD easily
- Workaround a bug where using double[] from Java leads to class cast exception in
KMeans init
Diffstat (limited to 'examples')
-rw-r--r-- | examples/pom.xml | 12 | ||||
-rw-r--r-- | examples/src/main/java/spark/mllib/JavaALS.java | 87 | ||||
-rw-r--r-- | examples/src/main/java/spark/mllib/JavaKMeans.java | 81 |
3 files changed, 180 insertions, 0 deletions
diff --git a/examples/pom.xml b/examples/pom.xml index 7a8d08fade..ad615b68ff 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -119,6 +119,12 @@ <classifier>hadoop1</classifier> </dependency> <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-mllib</artifactId> + <version>${project.version}</version> + <classifier>hadoop1</classifier> + </dependency> + <dependency> <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-core</artifactId> <scope>provided</scope> @@ -157,6 +163,12 @@ <classifier>hadoop2</classifier> </dependency> <dependency> + <groupId>org.spark-project</groupId> + <artifactId>spark-mllib</artifactId> + <version>${project.version}</version> + <classifier>hadoop2</classifier> + </dependency> + <dependency> <groupId>org.apache.hadoop</groupId> <artifactId>hadoop-core</artifactId> <scope>provided</scope> diff --git a/examples/src/main/java/spark/mllib/JavaALS.java b/examples/src/main/java/spark/mllib/JavaALS.java new file mode 100644 index 0000000000..8be079ad39 --- /dev/null +++ b/examples/src/main/java/spark/mllib/JavaALS.java @@ -0,0 +1,87 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.mllib.examples; + +import spark.api.java.JavaRDD; +import spark.api.java.JavaSparkContext; +import spark.api.java.function.Function; + +import spark.mllib.recommendation.ALS; +import spark.mllib.recommendation.MatrixFactorizationModel; +import spark.mllib.recommendation.Rating; + +import java.io.Serializable; +import java.util.Arrays; +import java.util.StringTokenizer; + +import scala.Tuple2; + +/** + * Example using MLLib ALS from Java. + */ +public class JavaALS { + + static class ParseRating extends Function<String, Rating> { + public Rating call(String line) { + StringTokenizer tok = new StringTokenizer(line, ","); + Integer x = Integer.parseInt(tok.nextToken()); + Integer y = Integer.parseInt(tok.nextToken()); + Double rating = Double.parseDouble(tok.nextToken()); + return new Rating(x, y, rating); + } + } + + static class FeaturesToString extends Function<Tuple2<Object, double[]>, String> { + public String call(Tuple2<Object, double[]> element) { + return element._1().toString() + "," + Arrays.toString(element._2()); + } + } + + public static void main(String[] args) { + + if (args.length != 5 && args.length != 6) { + System.err.println( + "Usage: JavaALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]"); + System.exit(1); + } + + int rank = Integer.parseInt(args[2]); + int iterations = Integer.parseInt(args[3]); + String outputDir = args[4]; + int blocks = -1; + if (args.length == 6) { + blocks = Integer.parseInt(args[5]); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD<String> lines = sc.textFile(args[1]); + + JavaRDD<Rating> ratings = lines.map(new ParseRating()); + + MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks); + + model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( + outputDir + "/userFeatures"); + model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile( + outputDir + "/productFeatures"); + System.out.println("Final user/product features written to " + outputDir); + + System.exit(0); + } +} diff --git a/examples/src/main/java/spark/mllib/JavaKMeans.java b/examples/src/main/java/spark/mllib/JavaKMeans.java new file mode 100644 index 0000000000..02f40438b8 --- /dev/null +++ b/examples/src/main/java/spark/mllib/JavaKMeans.java @@ -0,0 +1,81 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package spark.mllib.examples; + +import spark.api.java.JavaRDD; +import spark.api.java.JavaSparkContext; +import spark.api.java.function.Function; + +import spark.mllib.clustering.KMeans; +import spark.mllib.clustering.KMeansModel; + +import java.util.Arrays; +import java.util.StringTokenizer; + +/** + * Example using MLLib KMeans from Java. + */ +public class JavaKMeans { + + static class ParsePoint extends Function<String, double[]> { + public double[] call(String line) { + StringTokenizer tok = new StringTokenizer(line, " "); + int numTokens = tok.countTokens(); + double[] point = new double[numTokens]; + for (int i = 0; i < numTokens; ++i) { + point[i] = Double.parseDouble(tok.nextToken()); + } + return point; + } + } + + public static void main(String[] args) { + + if (args.length < 4) { + System.err.println( + "Usage: JavaKMeans <master> <input_file> <k> <max_iterations> [<runs>]"); + System.exit(1); + } + + String inputFile = args[1]; + int k = Integer.parseInt(args[2]); + int iterations = Integer.parseInt(args[3]); + int runs = 1; + + if (args.length >= 5) { + runs = Integer.parseInt(args[4]); + } + + JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", + System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR")); + JavaRDD<String> lines = sc.textFile(args[1]); + + JavaRDD<double[]> points = lines.map(new ParsePoint()); + + KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs); + + System.out.println("Cluster centers:"); + for (double[] center : model.clusterCenters()) { + System.out.println(" " + Arrays.toString(center)); + } + double cost = model.computeCost(points.rdd()); + System.out.println("Cost: " + cost); + + System.exit(0); + } +} |