From 471fbadd0c8cb8d310e3e1dd0e694e357ff1233e Mon Sep 17 00:00:00 2001
From: Shivaram Venkataraman
Date: Tue, 6 Aug 2013 15:43:46 -0700
Subject: Java examples, tests for KMeans and ALS

- Changes ALS to accept RDD[Rating] instead of (Int, Int, Double), making it
  easier to call from Java
- Renames class methods from `train` to `run` to enable static methods to be
  called from Java
- Adds unit tests which check that both static and class methods can be called
- Also adds examples which port the main() function in ALS and KMeans to the
  examples project

A couple of minor changes to existing code:
- Add a toJavaRDD method in RDD to convert a Scala RDD to a Java RDD easily
- Work around a bug where using double[] from Java leads to a class cast
  exception in KMeans init
---
 examples/pom.xml                                   | 12 +++
 examples/src/main/java/spark/mllib/JavaALS.java    | 87 ++++++++++++++++++++++
 examples/src/main/java/spark/mllib/JavaKMeans.java | 81 ++++++++++++++++++++
 3 files changed, 180 insertions(+)
 create mode 100644 examples/src/main/java/spark/mllib/JavaALS.java
 create mode 100644 examples/src/main/java/spark/mllib/JavaKMeans.java

diff --git a/examples/pom.xml b/examples/pom.xml
index 7a8d08fade..ad615b68ff 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -118,6 +118,12 @@
       <version>${project.version}</version>
       <classifier>hadoop1</classifier>
     </dependency>
+    <dependency>
+      <groupId>org.spark-project</groupId>
+      <artifactId>spark-mllib</artifactId>
+      <version>${project.version}</version>
+      <classifier>hadoop1</classifier>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-core</artifactId>
@@ -156,6 +162,12 @@
       <version>${project.version}</version>
       <classifier>hadoop2</classifier>
     </dependency>
+    <dependency>
+      <groupId>org.spark-project</groupId>
+      <artifactId>spark-mllib</artifactId>
+      <version>${project.version}</version>
+      <classifier>hadoop2</classifier>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-core</artifactId>
diff --git a/examples/src/main/java/spark/mllib/JavaALS.java b/examples/src/main/java/spark/mllib/JavaALS.java
new file mode 100644
index 0000000000..8be079ad39
--- /dev/null
+++ b/examples/src/main/java/spark/mllib/JavaALS.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.examples;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+
+import spark.mllib.recommendation.ALS;
+import spark.mllib.recommendation.MatrixFactorizationModel;
+import spark.mllib.recommendation.Rating;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.StringTokenizer;
+
+import scala.Tuple2;
+
+/**
+ * Example using MLLib ALS from Java.
+ */
+public class JavaALS {
+
+  static class ParseRating extends Function<String, Rating> {
+    public Rating call(String line) {
+      StringTokenizer tok = new StringTokenizer(line, ",");
+      Integer x = Integer.parseInt(tok.nextToken());
+      Integer y = Integer.parseInt(tok.nextToken());
+      Double rating = Double.parseDouble(tok.nextToken());
+      return new Rating(x, y, rating);
+    }
+  }
+
+  static class FeaturesToString extends Function<Tuple2<Object, double[]>, String> {
+    public String call(Tuple2<Object, double[]> element) {
+      return element._1().toString() + "," + Arrays.toString(element._2());
+    }
+  }
+
+  public static void main(String[] args) {
+
+    if (args.length != 5 && args.length != 6) {
+      System.err.println(
+          "Usage: JavaALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]");
+      System.exit(1);
+    }
+
+    int rank = Integer.parseInt(args[2]);
+    int iterations = Integer.parseInt(args[3]);
+    String outputDir = args[4];
+    int blocks = -1;
+    if (args.length == 6) {
+      blocks = Integer.parseInt(args[5]);
+    }
+
+    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS",
+        System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+    JavaRDD<String> lines = sc.textFile(args[1]);
+
+    JavaRDD<Rating> ratings = lines.map(new ParseRating());
+
+    MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks);
+
+    model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
+        outputDir + "/userFeatures");
+    model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
+        outputDir + "/productFeatures");
+    System.out.println("Final user/product features written to " + outputDir);
+
+    System.exit(0);
+  }
+}
diff --git a/examples/src/main/java/spark/mllib/JavaKMeans.java b/examples/src/main/java/spark/mllib/JavaKMeans.java
new file mode 100644
index 0000000000..02f40438b8
--- /dev/null
+++ b/examples/src/main/java/spark/mllib/JavaKMeans.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.examples;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+
+import spark.mllib.clustering.KMeans;
+import spark.mllib.clustering.KMeansModel;
+
+import java.util.Arrays;
+import java.util.StringTokenizer;
+
+/**
+ * Example using MLLib KMeans from Java.
+ */
+public class JavaKMeans {
+
+  static class ParsePoint extends Function<String, double[]> {
+    public double[] call(String line) {
+      StringTokenizer tok = new StringTokenizer(line, " ");
+      int numTokens = tok.countTokens();
+      double[] point = new double[numTokens];
+      for (int i = 0; i < numTokens; ++i) {
+        point[i] = Double.parseDouble(tok.nextToken());
+      }
+      return point;
+    }
+  }
+
+  public static void main(String[] args) {
+
+    if (args.length < 4) {
+      System.err.println(
+          "Usage: JavaKMeans <master> <input_file> <k> <max_iterations> [<runs>]");
+      System.exit(1);
+    }
+
+    String inputFile = args[1];
+    int k = Integer.parseInt(args[2]);
+    int iterations = Integer.parseInt(args[3]);
+    int runs = 1;
+
+    if (args.length >= 5) {
+      runs = Integer.parseInt(args[4]);
+    }
+
+    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
+        System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+    JavaRDD<String> lines = sc.textFile(args[1]);
+
+    JavaRDD<double[]> points = lines.map(new ParsePoint());
+
+    KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs);
+
+    System.out.println("Cluster centers:");
+    for (double[] center : model.clusterCenters()) {
+      System.out.println(" " + Arrays.toString(center));
+    }
+    double cost = model.computeCost(points.rdd());
+    System.out.println("Cost: " + cost);
+
+    System.exit(0);
+  }
+}
--
cgit v1.2.3
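
Editor's note on the `train` -> `run` rename described in the commit message: scalac generally skips
generating a static forwarder for a companion object's method when the companion class declares a
member of the same name, so keeping `train` on both the ALS/KMeans classes and their companion
objects would leave Java callers without usable static `train` methods. Below is a minimal,
hypothetical Scala sketch of that pattern; the `Trainer` class, its `run` method, and the `train`
helpers are invented for illustration and are not MLlib source.

    class Trainer(val iterations: Int) {
      // Class-level entry point, analogous to the renamed `run` methods:
      // configure an instance, then hand it the data.
      def run(data: Seq[Double]): Double =
        (1 to iterations).foldLeft(0.0)((acc, _) => acc + data.sum / data.size)
    }

    object Trainer {
      // Companion-object helpers. Because the class has no member named `train`,
      // scalac emits static forwarders, so Java code can call Trainer.train(...) directly.
      def train(data: Seq[Double], iterations: Int): Double =
        new Trainer(iterations).run(data)

      def train(data: Seq[Double]): Double = train(data, 10)
    }

    object TrainerDemo {
      def main(args: Array[String]) {
        // Both call paths produce the same result; the unit tests added in this change
        // exercise the analogous static and class entry points for ALS and KMeans.
        println(Trainer.train(Seq(1.0, 2.0, 3.0), 5))
        println(new Trainer(5).run(Seq(1.0, 2.0, 3.0)))
      }
    }

From Java, the first call path would appear as a plain static call (Trainer.train(data, 5)),
mirroring how JavaALS.java and JavaKMeans.java above invoke ALS.train(...) and KMeans.train(...).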