aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorsboeschhuawei <stephen.boesch@huawei.com>2015-01-30 14:09:49 -0800
committerXiangrui Meng <meng@databricks.com>2015-01-30 14:09:49 -0800
commitf377431a578f621b599b538f069adca6accaf7a9 (patch)
tree3a49fda07b0539f1a02e076f0a1c862835e114ad /mllib
parent6ee8338b377de9dc0adb5b26d9ea9e8519eb58ab (diff)
downloadspark-f377431a578f621b599b538f069adca6accaf7a9.tar.gz
spark-f377431a578f621b599b538f069adca6accaf7a9.tar.bz2
spark-f377431a578f621b599b538f069adca6accaf7a9.zip
[SPARK-4259][MLlib]: Add Power Iteration Clustering Algorithm with Gaussian Similarity Function
Add single pseudo-eigenvector PIC Including documentations and updated pom.xml with the following codes: mllib/src/main/scala/org/apache/spark/mllib/clustering/PIClustering.scala mllib/src/test/scala/org/apache/spark/mllib/clustering/PIClusteringSuite.scala Author: sboeschhuawei <stephen.boesch@huawei.com> Author: Fan Jiang <fanjiang.sc@huawei.com> Author: Jiang Fan <fjiang6@gmail.com> Author: Stephen Boesch <stephen.boesch@huawei.com> Author: Xiangrui Meng <meng@databricks.com> Closes #4254 from fjiang6/PIC and squashes the following commits: 4550850 [sboeschhuawei] Removed pic test data f292f31 [Stephen Boesch] Merge pull request #44 from mengxr/SPARK-4259 4b78aaf [Xiangrui Meng] refactor PIC 24fbf52 [sboeschhuawei] Updated API to be similar to KMeans plus other changes requested by Xiangrui on the PR c12dfc8 [sboeschhuawei] Removed examples files and added pic_data.txt. Revamped testcases yet to come 92d4752 [sboeschhuawei] Move the Guassian/ Affinity matrix calcs out of PIC. Presently in the test suite 7ebd149 [sboeschhuawei] Incorporate Xiangrui's first set of PR comments except restructure PIC.run to take Graph but do not remove Gaussian 121e4d5 [sboeschhuawei] Remove unused testing data files 1c3a62e [sboeschhuawei] removed matplot.py and reordered all private methods to bottom of PIC 218a49d [sboeschhuawei] Applied Xiangrui's comments - especially removing RDD/PICLinalg classes and making noncritical methods private 43ab10b [sboeschhuawei] Change last two println's to log4j logger 88aacc8 [sboeschhuawei] Add assert to testcase on cluster sizes 24f438e [sboeschhuawei] fixed incorrect markdown in clustering doc 060e6bf [sboeschhuawei] Added link to PIC doc from the main clustering md doc be659e3 [sboeschhuawei] Added mllib specific log4j 90e7fa4 [sboeschhuawei] Converted from custom Linalg routines to Breeze: added JavaDoc comments; added Markdown documentation bea48ea [sboeschhuawei] Converted custom Linear Algebra datatypes/routines to use Breeze. b29c0db [Fan Jiang] Update PIClustering.scala ace9749 [Fan Jiang] Update PIClustering.scala a112f38 [sboeschhuawei] Added graphx main and test jars as dependencies to mllib/pom.xml f656c34 [sboeschhuawei] Added iris dataset b7dbcbe [sboeschhuawei] Added axes and combined into single plot for matplotlib a2b1e57 [sboeschhuawei] Revert inadvertent update to KMeans 9294263 [sboeschhuawei] Added visualization/plotting of input/output data e5df2b8 [sboeschhuawei] First end to end working PIC 0700335 [sboeschhuawei] First end to end working version: but has bad performance issue 32a90dc [sboeschhuawei] Update circles test data values 0ef163f [sboeschhuawei] Added ConcentricCircles data generation and KMeans clustering 3fd5bc8 [sboeschhuawei] PIClustering is running in new branch (up to the pseudo-eigenvector convergence step) d5aae20 [Jiang Fan] Adding Power Iteration Clustering and Suite test a3c5fbe [Jiang Fan] Adding Power Iteration Clustering
Diffstat (limited to 'mllib')
-rw-r--r--mllib/pom.xml5
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala206
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala103
3 files changed, 314 insertions, 0 deletions
diff --git a/mllib/pom.xml b/mllib/pom.xml
index fc2b2cc09c..a8cee3d51a 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -51,6 +51,11 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>org.apache.spark</groupId>
+ <artifactId>spark-graphx_${scala.binary.version}</artifactId>
+ <version>${project.version}</version>
+ </dependency>
+ <dependency>
<groupId>org.jblas</groupId>
<artifactId>jblas</artifactId>
<version>${jblas.version}</version>
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
new file mode 100644
index 0000000000..fcb9a3643c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.apache.spark.{Logging, SparkException}
+import org.apache.spark.graphx._
+import org.apache.spark.graphx.impl.GraphImpl
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.random.XORShiftRandom
+
+/**
+ * Model produced by [[PowerIterationClustering]].
+ *
+ * @param k number of clusters
+ * @param assignments an RDD of (vertexID, clusterID) pairs
+ */
+class PowerIterationClusteringModel(
+ val k: Int,
+ val assignments: RDD[(Long, Int)]) extends Serializable
+
+/**
+ * Power Iteration Clustering (PIC), a scalable graph clustering algorithm developed by Lin and
+ * Cohen (see http://www.icml2010.org/papers/387.pdf). From the abstract: PIC finds a very
+ * low-dimensional embedding of a dataset using truncated power iteration on a normalized pair-wise
+ * similarity matrix of the data.
+ *
+ * @param k Number of clusters.
+ * @param maxIterations Maximum number of iterations of the PIC algorithm.
+ */
+class PowerIterationClustering private[clustering] (
+ private var k: Int,
+ private var maxIterations: Int) extends Serializable {
+
+ import org.apache.spark.mllib.clustering.PowerIterationClustering._
+
+ /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100}. */
+ def this() = this(k = 2, maxIterations = 100)
+
+ /**
+ * Set the number of clusters.
+ */
+ def setK(k: Int): this.type = {
+ this.k = k
+ this
+ }
+
+ /**
+ * Set maximum number of iterations of the power iteration loop
+ */
+ def setMaxIterations(maxIterations: Int): this.type = {
+ this.maxIterations = maxIterations
+ this
+ }
+
+ /**
+ * Run the PIC algorithm.
+ *
+ * @param similarities an RDD of (i, j, s_ij_) tuples representing the affinity matrix, which is
+ * the matrix A in the PIC paper. The similarity s_ij_ must be nonnegative.
+ * This is a symmetric matrix and hence s_ij_ = s_ji_. For any (i, j) with
+ * nonzero similarity, there should be either (i, j, s_ij_) or (j, i, s_ji_)
+ * in the input. Tuples with i = j are ignored, because we assume s_ij_ = 0.0.
+ *
+ * @return a [[PowerIterationClusteringModel]] that contains the clustering result
+ */
+ def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = {
+ val w = normalize(similarities)
+ val w0 = randomInit(w)
+ pic(w0)
+ }
+
+ /**
+ * Runs the PIC algorithm.
+ *
+ * @param w The normalized affinity matrix, which is the matrix W in the PIC paper with
+ * w_ij_ = a_ij_ / d_ii_ as its edge properties and the initial vector of the power
+ * iteration as its vertex properties.
+ */
+ private def pic(w: Graph[Double, Double]): PowerIterationClusteringModel = {
+ val v = powerIter(w, maxIterations)
+ val assignments = kMeans(v, k)
+ new PowerIterationClusteringModel(k, assignments)
+ }
+}
+
+private[clustering] object PowerIterationClustering extends Logging {
+ /**
+ * Normalizes the affinity matrix (A) by row sums and returns the normalized affinity matrix (W).
+ */
+ def normalize(similarities: RDD[(Long, Long, Double)]): Graph[Double, Double] = {
+ val edges = similarities.flatMap { case (i, j, s) =>
+ if (s < 0.0) {
+ throw new SparkException("Similarity must be nonnegative but found s($i, $j) = $s.")
+ }
+ if (i != j) {
+ Seq(Edge(i, j, s), Edge(j, i, s))
+ } else {
+ None
+ }
+ }
+ val gA = Graph.fromEdges(edges, 0.0)
+ val vD = gA.aggregateMessages[Double](
+ sendMsg = ctx => {
+ ctx.sendToSrc(ctx.attr)
+ },
+ mergeMsg = _ + _,
+ TripletFields.EdgeOnly)
+ GraphImpl.fromExistingRDDs(vD, gA.edges)
+ .mapTriplets(
+ e => e.attr / math.max(e.srcAttr, MLUtils.EPSILON),
+ TripletFields.Src)
+ }
+
+ /**
+ * Generates random vertex properties (v0) to start power iteration.
+ *
+ * @param g a graph representing the normalized affinity matrix (W)
+ * @return a graph with edges representing W and vertices representing a random vector
+ * with unit 1-norm
+ */
+ def randomInit(g: Graph[Double, Double]): Graph[Double, Double] = {
+ val r = g.vertices.mapPartitionsWithIndex(
+ (part, iter) => {
+ val random = new XORShiftRandom(part)
+ iter.map { case (id, _) =>
+ (id, random.nextGaussian())
+ }
+ }, preservesPartitioning = true).cache()
+ val sum = r.values.map(math.abs).sum()
+ val v0 = r.mapValues(x => x / sum)
+ GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges)
+ }
+
+ /**
+ * Runs power iteration.
+ * @param g input graph with edges representing the normalized affinity matrix (W) and vertices
+ * representing the initial vector of the power iterations.
+ * @param maxIterations maximum number of iterations
+ * @return a [[VertexRDD]] representing the pseudo-eigenvector
+ */
+ def powerIter(
+ g: Graph[Double, Double],
+ maxIterations: Int): VertexRDD[Double] = {
+ // the default tolerance used in the PIC paper, with a lower bound 1e-8
+ val tol = math.max(1e-5 / g.vertices.count(), 1e-8)
+ var prevDelta = Double.MaxValue
+ var diffDelta = Double.MaxValue
+ var curG = g
+ for (iter <- 0 until maxIterations if math.abs(diffDelta) > tol) {
+ val msgPrefix = s"Iteration $iter"
+ // multiply W by vt
+ val v = curG.aggregateMessages[Double](
+ sendMsg = ctx => ctx.sendToSrc(ctx.attr * ctx.dstAttr),
+ mergeMsg = _ + _,
+ TripletFields.Dst).cache()
+ // normalize v
+ val norm = v.values.map(math.abs).sum()
+ logInfo(s"$msgPrefix: norm(v) = $norm.")
+ val v1 = v.mapValues(x => x / norm)
+ // compare difference
+ val delta = curG.joinVertices(v1) { case (_, x, y) =>
+ math.abs(x - y)
+ }.vertices.values.sum()
+ logInfo(s"$msgPrefix: delta = $delta.")
+ diffDelta = math.abs(delta - prevDelta)
+ logInfo(s"$msgPrefix: diff(delta) = $diffDelta.")
+ // update v
+ curG = GraphImpl.fromExistingRDDs(VertexRDD(v1), g.edges)
+ prevDelta = delta
+ }
+ curG.vertices
+ }
+
+ /**
+ * Runs k-means clustering.
+ * @param v a [[VertexRDD]] representing the pseudo-eigenvector
+ * @param k number of clusters
+ * @return a [[VertexRDD]] representing the clustering assignments
+ */
+ def kMeans(v: VertexRDD[Double], k: Int): VertexRDD[Int] = {
+ val points = v.mapValues(x => Vectors.dense(x)).cache()
+ val model = new KMeans()
+ .setK(k)
+ .setRuns(5)
+ .setSeed(0L)
+ .run(points.values)
+ points.mapValues(p => model.predict(p)).cache()
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
new file mode 100644
index 0000000000..2bae465d39
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import scala.collection.mutable
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.graphx.{Edge, Graph}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext {
+
+ import org.apache.spark.mllib.clustering.PowerIterationClustering._
+
+ test("power iteration clustering") {
+ /*
+ We use the following graph to test PIC. All edges are assigned similarity 1.0 except 0.1 for
+ edge (3, 4).
+
+ 15-14 -13 -12
+ | |
+ 4 . 3 - 2 11
+ | | x | |
+ 5 0 - 1 10
+ | |
+ 6 - 7 - 8 - 9
+ */
+
+ val similarities = Seq[(Long, Long, Double)]((0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0),
+ (1, 3, 1.0), (2, 3, 1.0), (3, 4, 0.1), // (3, 4) is a weak edge
+ (4, 5, 1.0), (4, 15, 1.0), (5, 6, 1.0), (6, 7, 1.0), (7, 8, 1.0), (8, 9, 1.0), (9, 10, 1.0),
+ (10, 11, 1.0), (11, 12, 1.0), (12, 13, 1.0), (13, 14, 1.0), (14, 15, 1.0))
+ val model = new PowerIterationClustering()
+ .setK(2)
+ .run(sc.parallelize(similarities, 2))
+ val predictions = Array.fill(2)(mutable.Set.empty[Long])
+ model.assignments.collect().foreach { case (i, c) =>
+ predictions(c) += i
+ }
+ assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
+ }
+
+ test("normalize and powerIter") {
+ /*
+ Test normalize() with the following graph:
+
+ 0 - 3
+ | \ |
+ 1 - 2
+
+ The affinity matrix (A) is
+
+ 0 1 1 1
+ 1 0 1 0
+ 1 1 0 1
+ 1 0 1 0
+
+ D is diag(3, 2, 3, 2) and hence W is
+
+ 0 1/3 1/3 1/3
+ 1/2 0 1/2 0
+ 1/3 1/3 0 1/3
+ 1/2 0 1/2 0
+ */
+ val similarities = Seq[(Long, Long, Double)](
+ (0, 1, 1.0), (0, 2, 1.0), (0, 3, 1.0), (1, 2, 1.0), (2, 3, 1.0))
+ val expected = Array(
+ Array(0.0, 1.0/3.0, 1.0/3.0, 1.0/3.0),
+ Array(1.0/2.0, 0.0, 1.0/2.0, 0.0),
+ Array(1.0/3.0, 1.0/3.0, 0.0, 1.0/3.0),
+ Array(1.0/2.0, 0.0, 1.0/2.0, 0.0))
+ val w = normalize(sc.parallelize(similarities, 2))
+ w.edges.collect().foreach { case Edge(i, j, x) =>
+ assert(x ~== expected(i.toInt)(j.toInt) absTol 1e-14)
+ }
+ val v0 = sc.parallelize(Seq[(Long, Double)]((0, 0.1), (1, 0.2), (2, 0.3), (3, 0.4)), 2)
+ val w0 = Graph(v0, w.edges)
+ val v1 = powerIter(w0, maxIterations = 1).collect()
+ val u = Array(0.3, 0.2, 0.7/3.0, 0.2)
+ val norm = u.sum
+ val u1 = u.map(x => x / norm)
+ v1.foreach { case (i, x) =>
+ assert(x ~== u1(i.toInt) absTol 1e-14)
+ }
+ }
+}