author     Matei Zaharia <matei.zaharia@gmail.com>    2013-08-09 20:41:13 -0700
committer  Matei Zaharia <matei.zaharia@gmail.com>    2013-08-09 20:41:13 -0700
commit     cd247ba5bb54afa332519826028ab68a4f73849e (patch)
tree       14a5862c20ed86c019ec727914ac0e1af4732264 /examples
parent     b09d4b79e83330c96c161ea4eb9af284f0a835e6 (diff)
parent     e1a209f791a29225c7c75861aa4a18b14739fcc4 (diff)
Merge pull request #786 from shivaram/mllib-java
Java fixes, tests and examples for ALS, KMeans
Diffstat (limited to 'examples')
-rw-r--r--  examples/pom.xml                                     |  12
-rw-r--r--  examples/src/main/java/spark/mllib/JavaALS.java      |  87
-rw-r--r--  examples/src/main/java/spark/mllib/JavaKMeans.java   |  81
3 files changed, 180 insertions, 0 deletions
diff --git a/examples/pom.xml b/examples/pom.xml
index 7a8d08fade..ad615b68ff 100644
--- a/examples/pom.xml
+++ b/examples/pom.xml
@@ -119,6 +119,12 @@
<classifier>hadoop1</classifier>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-mllib</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop1</classifier>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
@@ -157,6 +163,12 @@
<classifier>hadoop2</classifier>
</dependency>
<dependency>
+ <groupId>org.spark-project</groupId>
+ <artifactId>spark-mllib</artifactId>
+ <version>${project.version}</version>
+ <classifier>hadoop2</classifier>
+ </dependency>
+ <dependency>
<groupId>org.apache.hadoop</groupId>
<artifactId>hadoop-core</artifactId>
<scope>provided</scope>
diff --git a/examples/src/main/java/spark/mllib/JavaALS.java b/examples/src/main/java/spark/mllib/JavaALS.java
new file mode 100644
index 0000000000..b48f459cb7
--- /dev/null
+++ b/examples/src/main/java/spark/mllib/JavaALS.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.examples;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+
+import spark.mllib.recommendation.ALS;
+import spark.mllib.recommendation.MatrixFactorizationModel;
+import spark.mllib.recommendation.Rating;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.StringTokenizer;
+
+import scala.Tuple2;
+
+/**
+ * Example using MLLib ALS from Java.
+ */
+public class JavaALS {
+
+ static class ParseRating extends Function<String, Rating> {
+ public Rating call(String line) {
+ StringTokenizer tok = new StringTokenizer(line, ",");
+ int x = Integer.parseInt(tok.nextToken());
+ int y = Integer.parseInt(tok.nextToken());
+ double rating = Double.parseDouble(tok.nextToken());
+ return new Rating(x, y, rating);
+ }
+ }
+
+ static class FeaturesToString extends Function<Tuple2<Object, double[]>, String> {
+ public String call(Tuple2<Object, double[]> element) {
+ return element._1().toString() + "," + Arrays.toString(element._2());
+ }
+ }
+
+ public static void main(String[] args) {
+
+ if (args.length != 5 && args.length != 6) {
+ System.err.println(
+ "Usage: JavaALS <master> <ratings_file> <rank> <iterations> <output_dir> [<blocks>]");
+ System.exit(1);
+ }
+
+ int rank = Integer.parseInt(args[2]);
+ int iterations = Integer.parseInt(args[3]);
+ String outputDir = args[4];
+ int blocks = -1;
+ if (args.length == 6) {
+ blocks = Integer.parseInt(args[5]);
+ }
+
+ JavaSparkContext sc = new JavaSparkContext(args[0], "JavaALS",
+ System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+ JavaRDD<String> lines = sc.textFile(args[1]);
+
+ JavaRDD<Rating> ratings = lines.map(new ParseRating());
+
+ MatrixFactorizationModel model = ALS.train(ratings.rdd(), rank, iterations, 0.01, blocks);
+
+ model.userFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
+ outputDir + "/userFeatures");
+ model.productFeatures().toJavaRDD().map(new FeaturesToString()).saveAsTextFile(
+ outputDir + "/productFeatures");
+ System.out.println("Final user/product features written to " + outputDir);
+
+ System.exit(0);
+ }
+}
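
Usage note (not part of the patch): ParseRating above expects the ratings file to contain one comma-separated user,product,rating triple per line, ALS.train is called with the rank, iteration count, a fixed value of 0.01 (the regularization parameter in this version of the ALS API) and the optional block count, and FeaturesToString writes each learned feature vector back out as "id,[f1, f2, ...]". A minimal standalone sketch of that input parsing, with a hypothetical class name and made-up sample lines:

// Hypothetical sketch mirroring ParseRating's expected input format of one
// "user,product,rating" triple per line; the sample values are illustrative only.
import java.util.StringTokenizer;

public class ParseRatingSketch {
  public static void main(String[] args) {
    String[] sampleLines = {"1,10,4.0", "2,10,3.5"};  // illustrative ratings lines
    for (String line : sampleLines) {
      StringTokenizer tok = new StringTokenizer(line, ",");
      int user = Integer.parseInt(tok.nextToken());
      int product = Integer.parseInt(tok.nextToken());
      double rating = Double.parseDouble(tok.nextToken());
      System.out.println("user=" + user + " product=" + product + " rating=" + rating);
    }
  }
}
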
diff --git a/examples/src/main/java/spark/mllib/JavaKMeans.java b/examples/src/main/java/spark/mllib/JavaKMeans.java
new file mode 100644
index 0000000000..02f40438b8
--- /dev/null
+++ b/examples/src/main/java/spark/mllib/JavaKMeans.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package spark.mllib.examples;
+
+import spark.api.java.JavaRDD;
+import spark.api.java.JavaSparkContext;
+import spark.api.java.function.Function;
+
+import spark.mllib.clustering.KMeans;
+import spark.mllib.clustering.KMeansModel;
+
+import java.util.Arrays;
+import java.util.StringTokenizer;
+
+/**
+ * Example using MLLib KMeans from Java.
+ */
+public class JavaKMeans {
+
+ static class ParsePoint extends Function<String, double[]> {
+ public double[] call(String line) {
+ StringTokenizer tok = new StringTokenizer(line, " ");
+ int numTokens = tok.countTokens();
+ double[] point = new double[numTokens];
+ for (int i = 0; i < numTokens; ++i) {
+ point[i] = Double.parseDouble(tok.nextToken());
+ }
+ return point;
+ }
+ }
+
+ public static void main(String[] args) {
+
+ if (args.length < 4) {
+ System.err.println(
+ "Usage: JavaKMeans <master> <input_file> <k> <max_iterations> [<runs>]");
+ System.exit(1);
+ }
+
+ String inputFile = args[1];
+ int k = Integer.parseInt(args[2]);
+ int iterations = Integer.parseInt(args[3]);
+ int runs = 1;
+
+ if (args.length >= 5) {
+ runs = Integer.parseInt(args[4]);
+ }
+
+ JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
+ System.getenv("SPARK_HOME"), System.getenv("SPARK_EXAMPLES_JAR"));
+ JavaRDD<String> lines = sc.textFile(inputFile);
+
+ JavaRDD<double[]> points = lines.map(new ParsePoint());
+
+ KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs);
+
+ System.out.println("Cluster centers:");
+ for (double[] center : model.clusterCenters()) {
+ System.out.println(" " + Arrays.toString(center));
+ }
+ double cost = model.computeCost(points.rdd());
+ System.out.println("Cost: " + cost);
+
+ System.exit(0);
+ }
+}
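
Usage note (not part of the patch): ParsePoint above reads one whitespace-separated vector of doubles per line. A hypothetical helper (class name, output file name and cluster layout are illustrative, not from this patch) that writes a tiny input file in that format for trying JavaKMeans locally:

// Hypothetical helper that writes a small space-separated points file in the
// format ParsePoint expects: two loose 2-D clusters, around (0,0) and (10,10).
import java.io.FileWriter;
import java.io.IOException;
import java.util.Random;

public class MakeKMeansInput {
  public static void main(String[] args) throws IOException {
    Random rnd = new Random(42);
    try (FileWriter out = new FileWriter("kmeans_data.txt")) {
      for (int i = 0; i < 100; i++) {
        double base = (i % 2 == 0) ? 0.0 : 10.0;  // pick which cluster this point belongs to
        out.write((base + rnd.nextGaussian()) + " " + (base + rnd.nextGaussian()) + "\n");
      }
    }
  }
}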