aboutsummaryrefslogtreecommitdiff
path: root/examples/src
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-03-23 17:34:02 -0700
committerMatei Zaharia <matei@databricks.com>2014-03-23 17:34:02 -0700
commit80c29689ae3b589254a571da3ddb5f9c866ae534 (patch)
tree1c60763332b65c974ca042ea3306c896e8cc88e5 /examples/src
parent8265dc7739caccc59bc2456b2df055ca96337fe4 (diff)
downloadspark-80c29689ae3b589254a571da3ddb5f9c866ae534.tar.gz
spark-80c29689ae3b589254a571da3ddb5f9c866ae534.tar.bz2
spark-80c29689ae3b589254a571da3ddb5f9c866ae534.zip
[SPARK-1212] Adding sparse data support and update KMeans
Continue our discussions from https://github.com/apache/incubator-spark/pull/575 This PR is WIP because it depends on a SNAPSHOT version of breeze. Per previous discussions and benchmarks, I switched to breeze for linear algebra operations. @dlwh and I made some improvements to breeze to keep its performance comparable to the bare-bone implementation, including norm computation and squared distance. This is why this PR needs to depend on a SNAPSHOT version of breeze. @fommil , please find the notice of using netlib-core in `NOTICE`. This is following Apache's instructions on appropriate labeling. I'm going to update this PR to include: 1. Fast distance computation: using `\|a\|_2^2 + \|b\|_2^2 - 2 a^T b` when it doesn't introduce too much numerical error. The squared norms are pre-computed. Otherwise, computing the distance between the center (dense) and a point (possibly sparse) always takes O(n) time. 2. Some numbers about the performance. 3. A released version of breeze. @dlwh, a minor release of breeze will help this PR get merged early. Do you mind sharing breeze's release plan? Thanks! Author: Xiangrui Meng <meng@databricks.com> Closes #117 from mengxr/sparse-kmeans and squashes the following commits: 67b368d [Xiangrui Meng] fix SparseVector.toArray 5eda0de [Xiangrui Meng] update NOTICE 67abe31 [Xiangrui Meng] move ArrayRDDs to mllib.rdd 1da1033 [Xiangrui Meng] remove dependency on commons-math3 and compute EPSILON directly 9bb1b31 [Xiangrui Meng] optimize SparseVector.toArray 226d2cd [Xiangrui Meng] update Java friendly methods in Vectors 238ba34 [Xiangrui Meng] add VectorRDDs with a converter from RDD[Array[Double]] b28ba2f [Xiangrui Meng] add toArray to Vector e69b10c [Xiangrui Meng] remove examples/JavaKMeans.java, which is replaced by mllib/examples/JavaKMeans.java 72bde33 [Xiangrui Meng] clean up code for distance computation 712cb88 [Xiangrui Meng] make Vectors.sparse Java friendly 27858e4 [Xiangrui Meng] update breeze version to 0.7 07c3cf2 [Xiangrui Meng] change Mahout to breeze in doc use a simple lower bound to avoid unnecessary distance computation 6f5cdde [Xiangrui Meng] fix a bug in filtering finished runs 42512f2 [Xiangrui Meng] Merge branch 'master' into sparse-kmeans d6e6c07 [Xiangrui Meng] add predict(RDD[Vector]) to KMeansModel 42b4e50 [Xiangrui Meng] line feed at the end a4ace73 [Xiangrui Meng] Merge branch 'fast-dist' into sparse-kmeans 3ed1a24 [Xiangrui Meng] add doc to BreezeVectorWithSquaredNorm 0107e19 [Xiangrui Meng] update NOTICE 87bc755 [Xiangrui Meng] tuned the KMeans code: changed some for loops to while, use view to avoid copying arrays 0ff8046 [Xiangrui Meng] update KMeans to use fastSquaredDistance f355411 [Xiangrui Meng] add BreezeVectorWithSquaredNorm case class ab74f67 [Xiangrui Meng] add fastSquaredDistance for KMeans 4e7d5ca [Xiangrui Meng] minor style update 07ffaf2 [Xiangrui Meng] add dense/sparse vector data models and conversions to/from breeze vectors use breeze to implement KMeans in order to support both dense and sparse data
Diffstat (limited to 'examples/src')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/JavaKMeans.java138
-rw-r--r--examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java23
2 files changed, 12 insertions, 149 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java
deleted file mode 100644
index 2d797279d5..0000000000
--- a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java
+++ /dev/null
@@ -1,138 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.examples;
-
-import scala.Tuple2;
-import org.apache.spark.api.java.JavaPairRDD;
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.api.java.function.PairFunction;
-import org.apache.spark.util.Vector;
-
-import java.util.List;
-import java.util.Map;
-import java.util.regex.Pattern;
-
-/**
- * K-means clustering using Java API.
- */
-public final class JavaKMeans {
-
- private static final Pattern SPACE = Pattern.compile(" ");
-
- /** Parses numbers split by whitespace to a vector */
- static Vector parseVector(String line) {
- String[] splits = SPACE.split(line);
- double[] data = new double[splits.length];
- int i = 0;
- for (String s : splits) {
- data[i] = Double.parseDouble(s);
- i++;
- }
- return new Vector(data);
- }
-
- /** Computes the vector to which the input vector is closest using squared distance */
- static int closestPoint(Vector p, List<Vector> centers) {
- int bestIndex = 0;
- double closest = Double.POSITIVE_INFINITY;
- for (int i = 0; i < centers.size(); i++) {
- double tempDist = p.squaredDist(centers.get(i));
- if (tempDist < closest) {
- closest = tempDist;
- bestIndex = i;
- }
- }
- return bestIndex;
- }
-
- /** Computes the mean across all vectors in the input set of vectors */
- static Vector average(List<Vector> ps) {
- int numVectors = ps.size();
- Vector out = new Vector(ps.get(0).elements());
- // start from i = 1 since we already copied index 0 above
- for (int i = 1; i < numVectors; i++) {
- out.addInPlace(ps.get(i));
- }
- return out.divide(numVectors);
- }
-
- public static void main(String[] args) throws Exception {
- if (args.length < 4) {
- System.err.println("Usage: JavaKMeans <master> <file> <k> <convergeDist>");
- System.exit(1);
- }
- JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
- System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaKMeans.class));
- String path = args[1];
- int K = Integer.parseInt(args[2]);
- double convergeDist = Double.parseDouble(args[3]);
-
- JavaRDD<Vector> data = sc.textFile(path).map(
- new Function<String, Vector>() {
- @Override
- public Vector call(String line) {
- return parseVector(line);
- }
- }
- ).cache();
-
- final List<Vector> centroids = data.takeSample(false, K, 42);
-
- double tempDist;
- do {
- // allocate each vector to closest centroid
- JavaPairRDD<Integer, Vector> closest = data.mapToPair(
- new PairFunction<Vector, Integer, Vector>() {
- @Override
- public Tuple2<Integer, Vector> call(Vector vector) {
- return new Tuple2<Integer, Vector>(
- closestPoint(vector, centroids), vector);
- }
- }
- );
-
- // group by cluster id and average the vectors within each cluster to compute centroids
- JavaPairRDD<Integer, List<Vector>> pointsGroup = closest.groupByKey();
- Map<Integer, Vector> newCentroids = pointsGroup.mapValues(
- new Function<List<Vector>, Vector>() {
- @Override
- public Vector call(List<Vector> ps) {
- return average(ps);
- }
- }).collectAsMap();
- tempDist = 0.0;
- for (int i = 0; i < K; i++) {
- tempDist += centroids.get(i).squaredDist(newCentroids.get(i));
- }
- for (Map.Entry<Integer, Vector> t: newCentroids.entrySet()) {
- centroids.set(t.getKey(), t.getValue());
- }
- System.out.println("Finished iteration (delta = " + tempDist + ")");
- } while (tempDist > convergeDist);
-
- System.out.println("Final centers:");
- for (Vector c : centroids) {
- System.out.println(c);
- }
-
- System.exit(0);
-
- }
-}
diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java
index 76ebdccfd6..7b0ec36424 100644
--- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java
+++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java
@@ -17,32 +17,33 @@
package org.apache.spark.mllib.examples;
+import java.util.regex.Pattern;
+
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.clustering.KMeans;
import org.apache.spark.mllib.clustering.KMeansModel;
-
-import java.util.Arrays;
-import java.util.regex.Pattern;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
/**
* Example using MLLib KMeans from Java.
*/
public final class JavaKMeans {
- static class ParsePoint implements Function<String, double[]> {
+ private static class ParsePoint implements Function<String, Vector> {
private static final Pattern SPACE = Pattern.compile(" ");
@Override
- public double[] call(String line) {
+ public Vector call(String line) {
String[] tok = SPACE.split(line);
double[] point = new double[tok.length];
for (int i = 0; i < tok.length; ++i) {
point[i] = Double.parseDouble(tok[i]);
}
- return point;
+ return Vectors.dense(point);
}
}
@@ -65,15 +66,15 @@ public final class JavaKMeans {
JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans",
System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaKMeans.class));
- JavaRDD<String> lines = sc.textFile(args[1]);
+ JavaRDD<String> lines = sc.textFile(inputFile);
- JavaRDD<double[]> points = lines.map(new ParsePoint());
+ JavaRDD<Vector> points = lines.map(new ParsePoint());
- KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs);
+ KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL());
System.out.println("Cluster centers:");
- for (double[] center : model.clusterCenters()) {
- System.out.println(" " + Arrays.toString(center));
+ for (Vector center : model.clusterCenters()) {
+ System.out.println(" " + center);
}
double cost = model.computeCost(points.rdd());
System.out.println("Cost: " + cost);