aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
author Shivaram Venkataraman <shivaram@eecs.berkeley.edu> 2013-08-06 15:43:46 -0700
committer Shivaram Venkataraman <shivaram@eecs.berkeley.edu> 2013-08-06 15:43:46 -0700
commit 471fbadd0c8cb8d310e3e1dd0e694e357ff1233e (patch)
tree c8710554a0ce04e87873308540757eadfcbbd244 /examples
parent d2b0f0c23d9ccd5e8a23450e421503d3201f3450 (diff)
download spark-471fbadd0c8cb8d310e3e1dd0e694e357ff1233e.tar.gz
spark-471fbadd0c8cb8d310e3e1dd0e694e357ff1233e.tar.bz2
spark-471fbadd0c8cb8d310e3e1dd0e694e357ff1233e.zip
Java examples, tests for KMeans and ALS
- Changes ALS to accept RDD[Rating] instead of (Int, Int, Double), making it easier to call from Java - Renames class methods from `train` to `run` to enable static methods to be called from Java. - Add unit tests which check if both static / class methods can be called. - Also add examples which port the main() function in ALS, KMeans to the examples project. Couple of minor changes to existing code: - Add a toJavaRDD method in RDD to convert scala RDD to java RDD easily - Work around a bug where using double[] from Java leads to a class cast exception in KMeans init
Diffstat (limited to 'examples')
-rw-r--r-- examples/pom.xml 12
-rw-r--r-- examples/src/main/java/spark/mllib/JavaALS.java 87
-rw-r--r-- examples/src/main/java/spark/mllib/JavaKMeans.java 81
3 files changed, 180 insertions, 0 deletions
diff --git a/examples/pom.xml b/examples/pom.xml
index 7a8d08fade..ad615b68ff 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -119,6 +119,12 @@
<classifier>hadoop1</classifier>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-mllib</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop1</classifier>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -157,6 +163,12 @@
<classifier>hadoop2</classifier>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-mllib</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2</classifier>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
diff --git a/examples/src/main/java/spark/mllib/JavaALS.java b/examples/src/main/java/spark/mllib/JavaALS.java
new file mode 100644
index 0000000000..8be079ad39
--- /dev/null
+++ b/examples/src/main/java/spark/mllib/JavaALS.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.examples;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+
+import spark.mllib.recommendation.ALS;
+import spark.mllib.recommendation.MatrixFactorizationModel;
+import spark.mllib.recommendation.Rating;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.StringTokenizer;
+
+import scala.Tuple2;
+
+/**
+ * Example using MLLib ALS from Java: trains on a ratings file and saves the model's features.
+ */
+public class JavaALS {
+  /** Parses a comma-separated line of two integers and a double into a Rating. */
+  static class ParseRating extends Function<String, Rating> {
+    public Rating call(String line) {
+      StringTokenizer tok = new StringTokenizer(line, ",");
+      int x = Integer.parseInt(tok.nextToken());
+      int y = Integer.parseInt(tok.nextToken());
+      double rating = Double.parseDouble(tok.nextToken());
+      return new Rating(x, y, rating);
+    }
+  }
+  /** Renders a (key, double[]) pair as the key, a comma, and Arrays.toString of the array. */
+  static class FeaturesToString extends Function<Tuple2<Object, double[]>, String> {
+    public String call(Tuple2<Object, double[]> element) {
+      return element._1().toString() + "," + Arrays.toString(element._2());
+    }
+  }
+  /** Entry point; see the usage message printed below for the expected arguments. */
+  public static void main(String[] args) {
+
+    if (args.length != 5 && args.length != 6) {
+      System.err.println(
+          "Usage: JavaALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]");
+      System.exit(1);
+    }
+
+    int rank = Integer.parseInt(args[2]);
+    int iterations = Integer.parseInt(args[3]);
+    String outputDir = args[4];
+    int blocks = -1;  // value handed to ALS.train when <blocks> is not given
+    if (args.length == 6) {
+      blocks = Integer.parseInt(args[5]);
+    }
+
+    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS",
+        System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+    JavaRDD<String> lines = sc.textFile(args[1]);
+
+    JavaRDD<Rating> ratings = lines.map(new ParseRating());
+
+    MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks);
+
+    model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
+        outputDir + "/userFeatures");
+    model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
+        outputDir + "/productFeatures");
+    System.out.println("Final user/product features written to " + outputDir);
+
+    System.exit(0);
+  }
+}
diff --git a/examples/src/main/java/spark/mllib/JavaKMeans.java b/examples/src/main/java/spark/mllib/JavaKMeans.java
new file mode 100644
index 0000000000..02f40438b8
--- /dev/null
+++ b/examples/src/main/java/spark/mllib/JavaKMeans.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.examples;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+
+import spark.mllib.clustering.KMeans;
+import spark.mllib.clustering.KMeansModel;
+
+import java.util.Arrays;
+import java.util.StringTokenizer;
+
+/**
+ * Example using MLLib KMeans from Java: clusters points from a text file and prints the result.
+ */
+public class JavaKMeans {
+  /** Parses a line of space-separated numbers into a point, one coordinate per token. */
+  static class ParsePoint extends Function<String, double[]> {
+    public double[] call(String line) {
+      StringTokenizer tok = new StringTokenizer(line, " ");
+      int numTokens = tok.countTokens();
+      double[] point = new double[numTokens];
+      for (int i = 0; i < numTokens; ++i) {
+        point[i] = Double.parseDouble(tok.nextToken());
+      }
+      return point;
+    }
+  }
+  /** Entry point; see the usage message printed below for the expected arguments. */
+  public static void main(String[] args) {
+
+    if (args.length < 4) {
+      System.err.println(
+          "Usage: JavaKMeans <master> <input_file> <k> <max_iterations> [<runs>]");
+      System.exit(1);
+    }
+
+    String inputFile = args[1];
+    int k = Integer.parseInt(args[2]);
+    int iterations = Integer.parseInt(args[3]);
+    int runs = 1;
+
+    if (args.length >= 5) {
+      runs = Integer.parseInt(args[4]);
+    }
+
+    JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
+        System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+    JavaRDD<String> lines = sc.textFile(inputFile);
+    // Cached because the parsed points are read by both train and computeCost below.
+    JavaRDD<double[]> points = lines.map(new ParsePoint()).cache();
+
+    KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs);
+
+    System.out.println("Cluster centers:");
+    for (double[] center : model.clusterCenters()) {
+      System.out.println(" " + Arrays.toString(center));
+    }
+    double cost = model.computeCost(points.rdd());
+    System.out.println("Cost: " + cost);
+
+    System.exit(0);
+  }
+}