diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-03-23 17:34:02 -0700 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2014-03-23 17:34:02 -0700 |
commit | 80c29689ae3b589254a571da3ddb5f9c866ae534 (patch) | |
tree | 1c60763332b65c974ca042ea3306c896e8cc88e5 /examples/src | |
parent | 8265dc7739caccc59bc2456b2df055ca96337fe4 (diff) | |
download | spark-80c29689ae3b589254a571da3ddb5f9c866ae534.tar.gz spark-80c29689ae3b589254a571da3ddb5f9c866ae534.tar.bz2 spark-80c29689ae3b589254a571da3ddb5f9c866ae534.zip |
[SPARK-1212] Adding sparse data support and update KMeans
Continue our discussions from https://github.com/apache/incubator-spark/pull/575
This PR is WIP because it depends on a SNAPSHOT version of breeze.
Per previous discussions and benchmarks, I switched to breeze for linear algebra operations. @dlwh and I made some improvements to breeze to keep its performance comparable to the bare-bone implementation, including norm computation and squared distance. This is why this PR needs to depend on a SNAPSHOT version of breeze.
@fommil , please find the notice of using netlib-core in `NOTICE`. This is following Apache's instructions on appropriate labeling.
I'm going to update this PR to include:
1. Fast distance computation: using `\|a\|_2^2 + \|b\|_2^2 - 2 a^T b` when it doesn't introduce too much numerical error. The squared norms are pre-computed. Otherwise, computing the distance between the center (dense) and a point (possibly sparse) always takes O(n) time.
2. Some numbers about the performance.
3. A released version of breeze. @dlwh, a minor release of breeze will help this PR get merged early. Do you mind sharing breeze's release plan? Thanks!
Author: Xiangrui Meng <meng@databricks.com>
Closes #117 from mengxr/sparse-kmeans and squashes the following commits:
67b368d [Xiangrui Meng] fix SparseVector.toArray
5eda0de [Xiangrui Meng] update NOTICE
67abe31 [Xiangrui Meng] move ArrayRDDs to mllib.rdd
1da1033 [Xiangrui Meng] remove dependency on commons-math3 and compute EPSILON directly
9bb1b31 [Xiangrui Meng] optimize SparseVector.toArray
226d2cd [Xiangrui Meng] update Java friendly methods in Vectors
238ba34 [Xiangrui Meng] add VectorRDDs with a converter from RDD[Array[Double]]
b28ba2f [Xiangrui Meng] add toArray to Vector
e69b10c [Xiangrui Meng] remove examples/JavaKMeans.java, which is replaced by mllib/examples/JavaKMeans.java
72bde33 [Xiangrui Meng] clean up code for distance computation
712cb88 [Xiangrui Meng] make Vectors.sparse Java friendly
27858e4 [Xiangrui Meng] update breeze version to 0.7
07c3cf2 [Xiangrui Meng] change Mahout to breeze in doc use a simple lower bound to avoid unnecessary distance computation
6f5cdde [Xiangrui Meng] fix a bug in filtering finished runs
42512f2 [Xiangrui Meng] Merge branch 'master' into sparse-kmeans
d6e6c07 [Xiangrui Meng] add predict(RDD[Vector]) to KMeansModel
42b4e50 [Xiangrui Meng] line feed at the end
a4ace73 [Xiangrui Meng] Merge branch 'fast-dist' into sparse-kmeans
3ed1a24 [Xiangrui Meng] add doc to BreezeVectorWithSquaredNorm
0107e19 [Xiangrui Meng] update NOTICE
87bc755 [Xiangrui Meng] tuned the KMeans code: changed some for loops to while, use view to avoid copying arrays
0ff8046 [Xiangrui Meng] update KMeans to use fastSquaredDistance
f355411 [Xiangrui Meng] add BreezeVectorWithSquaredNorm case class
ab74f67 [Xiangrui Meng] add fastSquaredDistance for KMeans
4e7d5ca [Xiangrui Meng] minor style update
07ffaf2 [Xiangrui Meng] add dense/sparse vector data models and conversions to/from breeze vectors use breeze to implement KMeans in order to support both dense and sparse data
Diffstat (limited to 'examples/src')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/JavaKMeans.java | 138 | ||||
-rw-r--r-- | examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java | 23 |
2 files changed, 12 insertions, 149 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java deleted file mode 100644 index 2d797279d5..0000000000 --- a/examples/src/main/java/org/apache/spark/examples/JavaKMeans.java +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.examples; - -import scala.Tuple2; -import org.apache.spark.api.java.JavaPairRDD; -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.JavaSparkContext; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.api.java.function.PairFunction; -import org.apache.spark.util.Vector; - -import java.util.List; -import java.util.Map; -import java.util.regex.Pattern; - -/** - * K-means clustering using Java API. - */ -public final class JavaKMeans { - - private static final Pattern SPACE = Pattern.compile(" "); - - /** Parses numbers split by whitespace to a vector */ - static Vector parseVector(String line) { - String[] splits = SPACE.split(line); - double[] data = new double[splits.length]; - int i = 0; - for (String s : splits) { - data[i] = Double.parseDouble(s); - i++; - } - return new Vector(data); - } - - /** Computes the vector to which the input vector is closest using squared distance */ - static int closestPoint(Vector p, List<Vector> centers) { - int bestIndex = 0; - double closest = Double.POSITIVE_INFINITY; - for (int i = 0; i < centers.size(); i++) { - double tempDist = p.squaredDist(centers.get(i)); - if (tempDist < closest) { - closest = tempDist; - bestIndex = i; - } - } - return bestIndex; - } - - /** Computes the mean across all vectors in the input set of vectors */ - static Vector average(List<Vector> ps) { - int numVectors = ps.size(); - Vector out = new Vector(ps.get(0).elements()); - // start from i = 1 since we already copied index 0 above - for (int i = 1; i < numVectors; i++) { - out.addInPlace(ps.get(i)); - } - return out.divide(numVectors); - } - - public static void main(String[] args) throws Exception { - if (args.length < 4) { - System.err.println("Usage: JavaKMeans <master> <file> <k> <convergeDist>"); - System.exit(1); - } - JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", - System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaKMeans.class)); - String path = args[1]; - int K = Integer.parseInt(args[2]); - double convergeDist = Double.parseDouble(args[3]); - - JavaRDD<Vector> data = sc.textFile(path).map( - new Function<String, Vector>() { - @Override - public Vector call(String line) { - return parseVector(line); - } - } - ).cache(); - - final List<Vector> centroids = data.takeSample(false, K, 42); - - double tempDist; - do { - // allocate each vector to closest centroid - JavaPairRDD<Integer, Vector> closest = data.mapToPair( - new PairFunction<Vector, Integer, Vector>() { - @Override - public Tuple2<Integer, Vector> call(Vector vector) { - return new Tuple2<Integer, Vector>( - closestPoint(vector, centroids), vector); - } - } - ); - - // group by cluster id and average the vectors within each cluster to compute centroids - JavaPairRDD<Integer, List<Vector>> pointsGroup = closest.groupByKey(); - Map<Integer, Vector> newCentroids = pointsGroup.mapValues( - new Function<List<Vector>, Vector>() { - @Override - public Vector call(List<Vector> ps) { - return average(ps); - } - }).collectAsMap(); - tempDist = 0.0; - for (int i = 0; i < K; i++) { - tempDist += centroids.get(i).squaredDist(newCentroids.get(i)); - } - for (Map.Entry<Integer, Vector> t: newCentroids.entrySet()) { - centroids.set(t.getKey(), t.getValue()); - } - System.out.println("Finished iteration (delta = " + tempDist + ")"); - } while (tempDist > convergeDist); - - System.out.println("Final centers:"); - for (Vector c : centroids) { - System.out.println(c); - } - - System.exit(0); - - } -} diff --git a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java index 76ebdccfd6..7b0ec36424 100644 --- a/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java +++ b/examples/src/main/java/org/apache/spark/mllib/examples/JavaKMeans.java @@ -17,32 +17,33 @@ package org.apache.spark.mllib.examples; +import java.util.regex.Pattern; + import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; import org.apache.spark.mllib.clustering.KMeans; import org.apache.spark.mllib.clustering.KMeansModel; - -import java.util.Arrays; -import java.util.regex.Pattern; +import org.apache.spark.mllib.linalg.Vector; +import org.apache.spark.mllib.linalg.Vectors; /** * Example using MLLib KMeans from Java. */ public final class JavaKMeans { - static class ParsePoint implements Function<String, double[]> { + private static class ParsePoint implements Function<String, Vector> { private static final Pattern SPACE = Pattern.compile(" "); @Override - public double[] call(String line) { + public Vector call(String line) { String[] tok = SPACE.split(line); double[] point = new double[tok.length]; for (int i = 0; i < tok.length; ++i) { point[i] = Double.parseDouble(tok[i]); } - return point; + return Vectors.dense(point); } } @@ -65,15 +66,15 @@ public final class JavaKMeans { JavaSparkContext sc = new JavaSparkContext(args[0], "JavaKMeans", System.getenv("SPARK_HOME"), JavaSparkContext.jarOfClass(JavaKMeans.class)); - JavaRDD<String> lines = sc.textFile(args[1]); + JavaRDD<String> lines = sc.textFile(inputFile); - JavaRDD<double[]> points = lines.map(new ParsePoint()); + JavaRDD<Vector> points = lines.map(new ParsePoint()); - KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs); + KMeansModel model = KMeans.train(points.rdd(), k, iterations, runs, KMeans.K_MEANS_PARALLEL()); System.out.println("Cluster centers:"); - for (double[] center : model.clusterCenters()) { - System.out.println(" " + Arrays.toString(center)); + for (Vector center : model.clusterCenters()) { + System.out.println(" " + center); } double cost = model.computeCost(points.rdd()); System.out.println("Cost: " + cost); |