about | summary | refs | log | tree | commit | diff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala | 41
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala | 10
2 files changed, 47 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
index fcb9a3643c..9b5c155b0a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/PowerIterationClustering.scala
@@ -43,15 +43,19 @@ class PowerIterationClusteringModel(
*
* @param k Number of clusters.
* @param maxIterations Maximum number of iterations of the PIC algorithm.
+ * @param initMode Initialization mode.
*/
class PowerIterationClustering private[clustering] (
private var k: Int,
- private var maxIterations: Int) extends Serializable {
+ private var maxIterations: Int,
+ private var initMode: String) extends Serializable {
import org.apache.spark.mllib.clustering.PowerIterationClustering._
- /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100}. */
- def this() = this(k = 2, maxIterations = 100)
+ /** Constructs a PIC instance with default parameters: {k: 2, maxIterations: 100,
+ * initMode: "random"}.
+ */
+ def this() = this(k = 2, maxIterations = 100, initMode = "random")
/**
* Set the number of clusters.
@@ -70,6 +74,18 @@ class PowerIterationClustering private[clustering] (
}
/**
+ * Set the initialization mode. This can be either "random" to use a random vector as the
+ * vertex properties, or "degree" to use the normalized sums of similarities. Default: "random".
+ */
+ def setInitializationMode(mode: String): this.type = {
+ this.initMode = mode match {
+ case "random" | "degree" => mode
+ case _ => throw new IllegalArgumentException("Invalid initialization mode: " + mode)
+ }
+ this
+ }
+
+ /**
* Run the PIC algorithm.
*
* @param similarities an RDD of (i, j, s_ij_) tuples representing the affinity matrix, which is
@@ -82,7 +98,10 @@ class PowerIterationClustering private[clustering] (
*/
def run(similarities: RDD[(Long, Long, Double)]): PowerIterationClusteringModel = {
val w = normalize(similarities)
- val w0 = randomInit(w)
+ val w0 = initMode match {
+ case "random" => randomInit(w)
+ case "degree" => initDegreeVector(w)
+ }
pic(w0)
}
@@ -149,6 +168,20 @@ private[clustering] object PowerIterationClustering extends Logging {
}
/**
+ * Generates the degree vector as the vertex properties (v0) to start power iteration.
+ * These are not exactly the node degrees but the normalized sums of similarities. It is
+ * called the degree vector because that is the name used in the PIC paper.
+ *
+ * @param g a graph representing the normalized affinity matrix (W)
+ * @return a graph with edges representing W and vertices representing the degree vector
+ */
+ def initDegreeVector(g: Graph[Double, Double]): Graph[Double, Double] = {
+ val sum = g.vertices.values.sum()
+ val v0 = g.vertices.mapValues(_ / sum)
+ GraphImpl.fromExistingRDDs(VertexRDD(v0), g.edges)
+ }
+
+ /**
* Runs power iteration.
* @param g input graph with edges representing the normalized affinity matrix (W) and vertices
* representing the initial vector of the power iterations.
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
index 2bae465d39..03ecd9ca73 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/PowerIterationClusteringSuite.scala
@@ -55,6 +55,16 @@ class PowerIterationClusteringSuite extends FunSuite with MLlibTestSparkContext
predictions(c) += i
}
assert(predictions.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
+
+ val model2 = new PowerIterationClustering()
+ .setK(2)
+ .setInitializationMode("degree")
+ .run(sc.parallelize(similarities, 2))
+ val predictions2 = Array.fill(2)(mutable.Set.empty[Long])
+ model2.assignments.collect().foreach { case (i, c) =>
+ predictions2(c) += i
+ }
+ assert(predictions2.toSet == Set((0 to 3).toSet, (4 to 15).toSet))
}
test("normalize and powerIter") {