author    Yu ISHIKAWA <yuu.ishikawa@gmail.com>    2015-11-09 14:56:36 -0800
committer Xiangrui Meng <meng@databricks.com>    2015-11-09 14:56:36 -0800
commit    8a2336893a7ff610a6c4629dd567b85078730616 (patch)
tree      d671674126c1a00ea22b19d912c86d17135f9a1d /mllib
parent    a3a7c9103e136035d65a5564f9eb0fa04727c4f3 (diff)
[SPARK-6517][MLLIB] Implement the Algorithm of Hierarchical Clustering
I implemented a hierarchical clustering algorithm again. This PR does not include examples, documentation, or spark.ml APIs; I will send those in follow-up PRs. https://issues.apache.org/jira/browse/SPARK-6517
- This implementation is based on bisecting k-means clustering.
- It derives from freeman-lab's implementation.
- The basic idea is unchanged from the previous version (#2906).
- However, it is 1000x faster than the previous version thanks to parallel processing.
Thank you for your great cooperation, RJ Nowling (rnowling), Jeremy Freeman (freeman-lab), Xiangrui Meng (mengxr) and Sean Owen (srowen).

Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>
Author: Yu ISHIKAWA <yu-iskw@users.noreply.github.com>

Closes #5267 from yu-iskw/new-hierarchical-clustering.
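For reference, a minimal usage sketch of the new API (not part of this commit). It assumes an existing SparkContext named `sc` and uses only methods added by this patch (setK, setMaxIterations, setSeed, run, predict, computeCost):

    import org.apache.spark.mllib.clustering.BisectingKMeans
    import org.apache.spark.mllib.linalg.Vectors

    // Toy 1D data; any RDD[Vector] works. Caching the input is recommended
    // (run() logs a warning when the input RDD is not cached).
    val data = sc.parallelize((0 until 8).map(i => Vectors.dense(i.toDouble))).cache()

    val bkm = new BisectingKMeans()
      .setK(4)               // desired number of leaf clusters
      .setMaxIterations(20)  // k-means iterations per split
      .setSeed(1L)
    val model = bkm.run(data)

    println(model.k)                            // actual number of leaf clusters (may be < 4)
    println(model.predict(Vectors.dense(6.0)))  // leaf cluster index of one point
    println(model.computeCost(data))            // sum of squared distances to assigned centers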
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala | 491
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala | 95
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java | 73
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala | 182
4 files changed, 841 insertions, 0 deletions
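The implementation below tracks clusters with binary-tree indices: the root cluster has index 1, the children of cluster i are 2*i and 2*i + 1, and the parent of i is i / 2. This indexing is what lets all divisible clusters on a level be bisected in a single pass. A small standalone illustration of that bookkeeping (the helper names mirror the private methods in BisectingKMeans.scala; the snippet itself is not part of the commit):

    // Illustration only: mirrors leftChildIndex / rightChildIndex / parentIndex below.
    def leftChildIndex(index: Long): Long = 2 * index
    def rightChildIndex(index: Long): Long = 2 * index + 1
    def parentIndex(index: Long): Long = index / 2

    // Bisecting the root (index 1) produces clusters 2 and 3;
    // bisecting cluster 3 produces clusters 6 and 7, and so on.
    assert(leftChildIndex(1L) == 2L && rightChildIndex(1L) == 3L)
    assert(parentIndex(6L) == 3L && parentIndex(7L) == 3L)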
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
new file mode 100644
index 0000000000..29a7aa0bb6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeans.scala
@@ -0,0 +1,491 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import java.util.Random
+
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques"
+ * by Steinbach, Karypis, and Kumar, with modifications to fit Spark.
+ * The algorithm starts from a single cluster that contains all points.
+ * Iteratively it finds divisible clusters on the bottom level and bisects each of them using
+ * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible.
+ * The bisecting steps of clusters on the same level are grouped together to increase parallelism.
+ * If bisecting all divisible clusters on the bottom level would result in more than `k` leaf clusters,
+ * larger clusters get higher priority.
+ *
+ * @param k the desired number of leaf clusters (default: 4). The actual number could be smaller if
+ * there are no divisible leaf clusters.
+ * @param maxIterations the max number of k-means iterations to split clusters (default: 20)
+ * @param minDivisibleClusterSize the minimum number of points (if >= 1.0) or the minimum proportion
+ * of points (if < 1.0) of a divisible cluster (default: 1)
+ * @param seed a random seed (default: hash value of the class name)
+ *
+ * @see [[http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf
+ * Steinbach, Karypis, and Kumar, A comparison of document clustering techniques,
+ * KDD Workshop on Text Mining, 2000.]]
+ */
+@Since("1.6.0")
+@Experimental
+class BisectingKMeans private (
+ private var k: Int,
+ private var maxIterations: Int,
+ private var minDivisibleClusterSize: Double,
+ private var seed: Long) extends Logging {
+
+ import BisectingKMeans._
+
+ /**
+ * Constructs a BisectingKMeans instance with the default configuration.
+ */
+ @Since("1.6.0")
+ def this() = this(4, 20, 1.0, classOf[BisectingKMeans].getName.##)
+
+ /**
+ * Sets the desired number of leaf clusters (default: 4).
+ * The actual number could be smaller if there are no divisible leaf clusters.
+ */
+ @Since("1.6.0")
+ def setK(k: Int): this.type = {
+ require(k > 0, s"k must be positive but got $k.")
+ this.k = k
+ this
+ }
+
+ /**
+ * Gets the desired number of leaf clusters.
+ */
+ @Since("1.6.0")
+ def getK: Int = this.k
+
+ /**
+ * Sets the max number of k-means iterations to split clusters (default: 20).
+ */
+ @Since("1.6.0")
+ def setMaxIterations(maxIterations: Int): this.type = {
+ require(maxIterations > 0, s"maxIterations must be positive but got $maxIterations.")
+ this.maxIterations = maxIterations
+ this
+ }
+
+ /**
+ * Gets the max number of k-means iterations to split clusters.
+ */
+ @Since("1.6.0")
+ def getMaxIterations: Int = this.maxIterations
+
+ /**
+ * Sets the minimum number of points (if >= `1.0`) or the minimum proportion of points
+ * (if < `1.0`) of a divisible cluster (default: 1).
+ */
+ @Since("1.6.0")
+ def setMinDivisibleClusterSize(minDivisibleClusterSize: Double): this.type = {
+ require(minDivisibleClusterSize > 0.0,
+ s"minDivisibleClusterSize must be positive but got $minDivisibleClusterSize.")
+ this.minDivisibleClusterSize = minDivisibleClusterSize
+ this
+ }
+
+ /**
+ * Gets the minimum number of points (if >= `1.0`) or the minimum proportion of points
+ * (if < `1.0`) of a divisible cluster.
+ */
+ @Since("1.6.0")
+ def getMinDivisibleClusterSize: Double = minDivisibleClusterSize
+
+ /**
+ * Sets the random seed (default: hash value of the class name).
+ */
+ @Since("1.6.0")
+ def setSeed(seed: Long): this.type = {
+ this.seed = seed
+ this
+ }
+
+ /**
+ * Gets the random seed.
+ */
+ @Since("1.6.0")
+ def getSeed: Long = this.seed
+
+ /**
+ * Runs the bisecting k-means algorithm.
+ * @param input RDD of vectors
+ * @return model for the bisecting k-means
+ */
+ @Since("1.6.0")
+ def run(input: RDD[Vector]): BisectingKMeansModel = {
+ if (input.getStorageLevel == StorageLevel.NONE) {
+ logWarning(s"The input RDD ${input.id} is not directly cached, which may hurt performance if"
+ + " its parent RDDs are also not cached.")
+ }
+ val d = input.map(_.size).first()
+ logInfo(s"Feature dimension: $d.")
+ // Compute and cache vector norms for fast distance computation.
+ val norms = input.map(v => Vectors.norm(v, 2.0)).persist(StorageLevel.MEMORY_AND_DISK)
+ val vectors = input.zip(norms).map { case (x, norm) => new VectorWithNorm(x, norm) }
+ var assignments = vectors.map(v => (ROOT_INDEX, v))
+ var activeClusters = summarize(d, assignments)
+ val rootSummary = activeClusters(ROOT_INDEX)
+ val n = rootSummary.size
+ logInfo(s"Number of points: $n.")
+ logInfo(s"Initial cost: ${rootSummary.cost}.")
+ val minSize = if (minDivisibleClusterSize >= 1.0) {
+ math.ceil(minDivisibleClusterSize).toLong
+ } else {
+ math.ceil(minDivisibleClusterSize * n).toLong
+ }
+ logInfo(s"The minimum number of points of a divisible cluster is $minSize.")
+ var inactiveClusters = mutable.Seq.empty[(Long, ClusterSummary)]
+ val random = new Random(seed)
+ var numLeafClustersNeeded = k - 1
+ var level = 1
+ while (activeClusters.nonEmpty && numLeafClustersNeeded > 0 && level < LEVEL_LIMIT) {
+ // Divisible clusters are sufficiently large and have non-trivial cost.
+ var divisibleClusters = activeClusters.filter { case (_, summary) =>
+ (summary.size >= minSize) && (summary.cost > MLUtils.EPSILON * summary.size)
+ }
+ // If we don't need all divisible clusters, take the larger ones.
+ if (divisibleClusters.size > numLeafClustersNeeded) {
+ divisibleClusters = divisibleClusters.toSeq.sortBy { case (_, summary) =>
+ -summary.size
+ }.take(numLeafClustersNeeded)
+ .toMap
+ }
+ if (divisibleClusters.nonEmpty) {
+ val divisibleIndices = divisibleClusters.keys.toSet
+ logInfo(s"Dividing ${divisibleIndices.size} clusters on level $level.")
+ var newClusterCenters = divisibleClusters.flatMap { case (index, summary) =>
+ val (left, right) = splitCenter(summary.center, random)
+ Iterator((leftChildIndex(index), left), (rightChildIndex(index), right))
+ }.map(identity) // workaround for a Scala bug (SI-7005) that produces a non-serializable map
+ var newClusters: Map[Long, ClusterSummary] = null
+ var newAssignments: RDD[(Long, VectorWithNorm)] = null
+ for (iter <- 0 until maxIterations) {
+ newAssignments = updateAssignments(assignments, divisibleIndices, newClusterCenters)
+ .filter { case (index, _) =>
+ divisibleIndices.contains(parentIndex(index))
+ }
+ newClusters = summarize(d, newAssignments)
+ newClusterCenters = newClusters.mapValues(_.center).map(identity)
+ }
+ // TODO: Unpersist old indices.
+ val indices = updateAssignments(assignments, divisibleIndices, newClusterCenters).keys
+ .persist(StorageLevel.MEMORY_AND_DISK)
+ assignments = indices.zip(vectors)
+ inactiveClusters ++= activeClusters
+ activeClusters = newClusters
+ numLeafClustersNeeded -= divisibleClusters.size
+ } else {
+ logInfo(s"None active and divisible clusters left on level $level. Stop iterations.")
+ inactiveClusters ++= activeClusters
+ activeClusters = Map.empty
+ }
+ level += 1
+ }
+ val clusters = activeClusters ++ inactiveClusters
+ val root = buildTree(clusters)
+ new BisectingKMeansModel(root)
+ }
+
+ /**
+ * Java-friendly version of [[run(RDD[Vector])*]]
+ */
+ def run(data: JavaRDD[Vector]): BisectingKMeansModel = run(data.rdd)
+}
+
+private object BisectingKMeans extends Serializable {
+
+ /** The index of the root node of a tree. */
+ private val ROOT_INDEX: Long = 1
+
+ private val MAX_DIVISIBLE_CLUSTER_INDEX: Long = Long.MaxValue / 2
+
+ private val LEVEL_LIMIT = math.log10(Long.MaxValue) / math.log10(2)
+
+ /** Returns the left child index of the given node index. */
+ private def leftChildIndex(index: Long): Long = {
+ require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index.")
+ 2 * index
+ }
+
+ /** Returns the right child index of the given node index. */
+ private def rightChildIndex(index: Long): Long = {
+ require(index <= MAX_DIVISIBLE_CLUSTER_INDEX, s"Child index out of bound: 2 * $index + 1.")
+ 2 * index + 1
+ }
+
+ /** Returns the parent index of the given node index, or 0 if the input is 1 (root). */
+ private def parentIndex(index: Long): Long = {
+ index / 2
+ }
+
+ /**
+ * Summarizes data by each cluster as a Map.
+ * @param d feature dimension
+ * @param assignments pairs of cluster index and point
+ * @return a map from cluster indices to corresponding cluster summaries
+ */
+ private def summarize(
+ d: Int,
+ assignments: RDD[(Long, VectorWithNorm)]): Map[Long, ClusterSummary] = {
+ assignments.aggregateByKey(new ClusterSummaryAggregator(d))(
+ seqOp = (agg, v) => agg.add(v),
+ combOp = (agg1, agg2) => agg1.merge(agg2)
+ ).mapValues(_.summary)
+ .collect().toMap
+ }
+
+ /**
+ * Cluster summary aggregator.
+ * @param d feature dimension
+ */
+ private class ClusterSummaryAggregator(val d: Int) extends Serializable {
+ private var n: Long = 0L
+ private val sum: Vector = Vectors.zeros(d)
+ private var sumSq: Double = 0.0
+
+ /** Adds a point. */
+ def add(v: VectorWithNorm): this.type = {
+ n += 1L
+ // TODO: use a numerically stable approach to estimate cost
+ sumSq += v.norm * v.norm
+ BLAS.axpy(1.0, v.vector, sum)
+ this
+ }
+
+ /** Merges another aggregator. */
+ def merge(other: ClusterSummaryAggregator): this.type = {
+ n += other.n
+ sumSq += other.sumSq
+ BLAS.axpy(1.0, other.sum, sum)
+ this
+ }
+
+ /** Returns the summary. */
+ def summary: ClusterSummary = {
+ val mean = sum.copy
+ if (n > 0L) {
+ BLAS.scal(1.0 / n, mean)
+ }
+ val center = new VectorWithNorm(mean)
+ val cost = math.max(sumSq - n * center.norm * center.norm, 0.0)
+ new ClusterSummary(n, center, cost)
+ }
+ }
+
+ /**
+ * Bisects a cluster center.
+ *
+ * @param center current cluster center
+ * @param random a random number generator
+ * @return initial centers
+ */
+ private def splitCenter(
+ center: VectorWithNorm,
+ random: Random): (VectorWithNorm, VectorWithNorm) = {
+ val d = center.vector.size
+ val norm = center.norm
+ val level = 1e-4 * norm
+ val noise = Vectors.dense(Array.fill(d)(random.nextDouble()))
+ val left = center.vector.copy
+ BLAS.axpy(-level, noise, left)
+ val right = center.vector.copy
+ BLAS.axpy(level, noise, right)
+ (new VectorWithNorm(left), new VectorWithNorm(right))
+ }
+
+ /**
+ * Updates assignments.
+ * @param assignments current assignments
+ * @param divisibleIndices divisible cluster indices
+ * @param newClusterCenters new cluster centers
+ * @return new assignments
+ */
+ private def updateAssignments(
+ assignments: RDD[(Long, VectorWithNorm)],
+ divisibleIndices: Set[Long],
+ newClusterCenters: Map[Long, VectorWithNorm]): RDD[(Long, VectorWithNorm)] = {
+ assignments.map { case (index, v) =>
+ if (divisibleIndices.contains(index)) {
+ val children = Seq(leftChildIndex(index), rightChildIndex(index))
+ val selected = children.minBy { child =>
+ KMeans.fastSquaredDistance(newClusterCenters(child), v)
+ }
+ (selected, v)
+ } else {
+ (index, v)
+ }
+ }
+ }
+
+ /**
+ * Builds a clustering tree by re-indexing internal and leaf clusters.
+ * @param clusters a map from cluster indices to corresponding cluster summaries
+ * @return the root node of the clustering tree
+ */
+ private def buildTree(clusters: Map[Long, ClusterSummary]): ClusteringTreeNode = {
+ var leafIndex = 0
+ var internalIndex = -1
+
+ /**
+ * Builds a subtree from the given node index.
+ */
+ def buildSubTree(rawIndex: Long): ClusteringTreeNode = {
+ val cluster = clusters(rawIndex)
+ val size = cluster.size
+ val center = cluster.center
+ val cost = cluster.cost
+ val isInternal = clusters.contains(leftChildIndex(rawIndex))
+ if (isInternal) {
+ val index = internalIndex
+ internalIndex -= 1
+ val leftIndex = leftChildIndex(rawIndex)
+ val rightIndex = rightChildIndex(rawIndex)
+ val height = math.sqrt(Seq(leftIndex, rightIndex).map { childIndex =>
+ KMeans.fastSquaredDistance(center, clusters(childIndex).center)
+ }.max)
+ val left = buildSubTree(leftIndex)
+ val right = buildSubTree(rightIndex)
+ new ClusteringTreeNode(index, size, center, cost, height, Array(left, right))
+ } else {
+ val index = leafIndex
+ leafIndex += 1
+ val height = 0.0
+ new ClusteringTreeNode(index, size, center, cost, height, Array.empty)
+ }
+ }
+
+ buildSubTree(ROOT_INDEX)
+ }
+
+ /**
+ * Summary of a cluster.
+ *
+ * @param size the number of points within this cluster
+ * @param center the center of the points within this cluster
+ * @param cost the sum of squared distances to the center
+ */
+ private case class ClusterSummary(size: Long, center: VectorWithNorm, cost: Double)
+}
+
+/**
+ * Represents a node in a clustering tree.
+ *
+ * @param index node index, negative for internal nodes and non-negative for leaf nodes
+ * @param size size of the cluster
+ * @param centerWithNorm cluster center with norm
+ * @param cost cost of the cluster, i.e., the sum of squared distances to the center
+ * @param height height of the node in the dendrogram. Currently this is defined as the max distance
+ * from the center to the centers of its children, but subject to change.
+ * @param children children nodes
+ */
+@Since("1.6.0")
+@Experimental
+class ClusteringTreeNode private[clustering] (
+ val index: Int,
+ val size: Long,
+ private val centerWithNorm: VectorWithNorm,
+ val cost: Double,
+ val height: Double,
+ val children: Array[ClusteringTreeNode]) extends Serializable {
+
+ /** Whether this is a leaf node. */
+ val isLeaf: Boolean = children.isEmpty
+
+ require((isLeaf && index >= 0) || (!isLeaf && index < 0))
+
+ /** Cluster center. */
+ def center: Vector = centerWithNorm.vector
+
+ /** Predicts the leaf cluster node index that the input point belongs to. */
+ def predict(point: Vector): Int = {
+ val (index, _) = predict(new VectorWithNorm(point))
+ index
+ }
+
+ /** Returns the full prediction path from root to leaf. */
+ def predictPath(point: Vector): Array[ClusteringTreeNode] = {
+ predictPath(new VectorWithNorm(point)).toArray
+ }
+
+ /** Returns the full prediction path from root to leaf. */
+ private def predictPath(pointWithNorm: VectorWithNorm): List[ClusteringTreeNode] = {
+ if (isLeaf) {
+ this :: Nil
+ } else {
+ val selected = children.minBy { child =>
+ KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm)
+ }
+ selected :: selected.predictPath(pointWithNorm)
+ }
+ }
+
+ /**
+ * Computes the cost (squared distance to the predicted leaf cluster center) of the input point.
+ */
+ def computeCost(point: Vector): Double = {
+ val (_, cost) = predict(new VectorWithNorm(point))
+ cost
+ }
+
+ /**
+ * Predicts the cluster index and the cost of the input point.
+ */
+ private def predict(pointWithNorm: VectorWithNorm): (Int, Double) = {
+ predict(pointWithNorm, KMeans.fastSquaredDistance(centerWithNorm, pointWithNorm))
+ }
+
+ /**
+ * Predicts the cluster index and the cost of the input point.
+ * @param pointWithNorm input point
+ * @param cost the cost to the current center
+ * @return (predicted leaf cluster index, cost)
+ */
+ private def predict(pointWithNorm: VectorWithNorm, cost: Double): (Int, Double) = {
+ if (isLeaf) {
+ (index, cost)
+ } else {
+ val (selectedChild, minCost) = children.map { child =>
+ (child, KMeans.fastSquaredDistance(child.centerWithNorm, pointWithNorm))
+ }.minBy(_._2)
+ selectedChild.predict(pointWithNorm, minCost)
+ }
+ }
+
+ /**
+ * Returns all leaf nodes from this node.
+ */
+ def leafNodes: Array[ClusteringTreeNode] = {
+ if (isLeaf) {
+ Array(this)
+ } else {
+ children.flatMap(_.leafNodes)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
new file mode 100644
index 0000000000..5015f1540d
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/BisectingKMeansModel.scala
@@ -0,0 +1,95 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.apache.spark.Logging
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.rdd.RDD
+
+/**
+ * Clustering model produced by [[BisectingKMeans]].
+ * Prediction is done level by level from the root node to a leaf node; at each node, the child
+ * closest to the input point is selected.
+ *
+ * @param root the root node of the clustering tree
+ */
+@Since("1.6.0")
+@Experimental
+class BisectingKMeansModel @Since("1.6.0") (
+ @Since("1.6.0") val root: ClusteringTreeNode
+ ) extends Serializable with Logging {
+
+ /**
+ * Leaf cluster centers.
+ */
+ @Since("1.6.0")
+ def clusterCenters: Array[Vector] = root.leafNodes.map(_.center)
+
+ /**
+ * Number of leaf clusters.
+ */
+ lazy val k: Int = clusterCenters.length
+
+ /**
+ * Predicts the index of the cluster that the input point belongs to.
+ */
+ @Since("1.6.0")
+ def predict(point: Vector): Int = {
+ root.predict(point)
+ }
+
+ /**
+ * Predicts the indices of the clusters that the input points belong to.
+ */
+ @Since("1.6.0")
+ def predict(points: RDD[Vector]): RDD[Int] = {
+ points.map { p => root.predict(p) }
+ }
+
+ /**
+ * Java-friendly version of [[predict(RDD[Vector])*]]
+ */
+ @Since("1.6.0")
+ def predict(points: JavaRDD[Vector]): JavaRDD[java.lang.Integer] =
+ predict(points.rdd).toJavaRDD().asInstanceOf[JavaRDD[java.lang.Integer]]
+
+ /**
+ * Computes the squared distance between the input point and the cluster center it belongs to.
+ */
+ @Since("1.6.0")
+ def computeCost(point: Vector): Double = {
+ root.computeCost(point)
+ }
+
+ /**
+ * Computes the sum of squared distances between the input points and their corresponding cluster
+ * centers.
+ */
+ @Since("1.6.0")
+ def computeCost(data: RDD[Vector]): Double = {
+ data.map(root.computeCost).sum()
+ }
+
+ /**
+ * Java-friendly version of [[computeCost(RDD[Vector])*]].
+ */
+ @Since("1.6.0")
+ def computeCost(data: JavaRDD[Vector]): Double = this.computeCost(data.rdd)
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
new file mode 100644
index 0000000000..a714620ff7
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+public class JavaBisectingKMeansSuite implements Serializable {
+ private transient JavaSparkContext sc;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", this.getClass().getSimpleName());
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void twoDimensionalData() {
+ JavaRDD<Vector> points = sc.parallelize(Lists.newArrayList(
+ Vectors.dense(4, -1),
+ Vectors.dense(4, 1),
+ Vectors.sparse(2, new int[] {0}, new double[] {1.0})
+ ), 2);
+
+ BisectingKMeans bkm = new BisectingKMeans()
+ .setK(4)
+ .setMaxIterations(2)
+ .setSeed(1L);
+ BisectingKMeansModel model = bkm.run(points);
+ Assert.assertEquals(3, model.k());
+ Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12);
+ for (ClusteringTreeNode child: model.root().children()) {
+ double[] center = child.center().toArray();
+ if (center[0] > 2) {
+ Assert.assertEquals(2, child.size());
+ Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12);
+ } else {
+ Assert.assertEquals(1, child.size());
+ Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12);
+ }
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
new file mode 100644
index 0000000000..41b9d5c0d9
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/clustering/BisectingKMeansSuite.scala
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+
+class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("default values") {
+ val bkm0 = new BisectingKMeans()
+ assert(bkm0.getK === 4)
+ assert(bkm0.getMaxIterations === 20)
+ assert(bkm0.getMinDivisibleClusterSize === 1.0)
+ val bkm1 = new BisectingKMeans()
+ assert(bkm0.getSeed === bkm1.getSeed, "The default seed should be constant.")
+ }
+
+ test("setter/getter") {
+ val bkm = new BisectingKMeans()
+
+ val k = 10
+ assert(bkm.getK !== k)
+ assert(bkm.setK(k).getK === k)
+ val maxIter = 100
+ assert(bkm.getMaxIterations !== maxIter)
+ assert(bkm.setMaxIterations(maxIter).getMaxIterations === maxIter)
+ val minSize = 2.0
+ assert(bkm.getMinDivisibleClusterSize !== minSize)
+ assert(bkm.setMinDivisibleClusterSize(minSize).getMinDivisibleClusterSize === minSize)
+ val seed = 10L
+ assert(bkm.getSeed !== seed)
+ assert(bkm.setSeed(seed).getSeed === seed)
+
+ intercept[IllegalArgumentException] {
+ bkm.setK(0)
+ }
+ intercept[IllegalArgumentException] {
+ bkm.setMaxIterations(0)
+ }
+ intercept[IllegalArgumentException] {
+ bkm.setMinDivisibleClusterSize(0.0)
+ }
+ }
+
+ test("1D data") {
+ val points = Vectors.sparse(1, Array.empty, Array.empty) +:
+ (1 until 8).map(i => Vectors.dense(i))
+ val data = sc.parallelize(points, 2)
+ val bkm = new BisectingKMeans()
+ .setK(4)
+ .setMaxIterations(1)
+ .setSeed(1L)
+ // The clusters should be
+ // (0, 1, 2, 3, 4, 5, 6, 7)
+ // - (0, 1, 2, 3)
+ // - (0, 1)
+ // - (2, 3)
+ // - (4, 5, 6, 7)
+ // - (4, 5)
+ // - (6, 7)
+ val model = bkm.run(data)
+ assert(model.k === 4)
+ // The total cost should be 8 * 0.5 * 0.5 = 2.0.
+ assert(model.computeCost(data) ~== 2.0 relTol 1e-12)
+ val predictions = data.map(v => (v(0), model.predict(v))).collectAsMap()
+ Range(0, 8, 2).foreach { i =>
+ assert(predictions(i) === predictions(i + 1),
+ s"$i and ${i + 1} should belong to the same cluster.")
+ }
+ val root = model.root
+ assert(root.center(0) ~== 3.5 relTol 1e-12)
+ assert(root.height ~== 2.0 relTol 1e-12)
+ assert(root.children.length === 2)
+ assert(root.children(0).height ~== 1.0 relTol 1e-12)
+ assert(root.children(1).height ~== 1.0 relTol 1e-12)
+ }
+
+ test("points are the same") {
+ val data = sc.parallelize(Seq.fill(8)(Vectors.dense(1.0, 1.0)), 2)
+ val bkm = new BisectingKMeans()
+ .setK(2)
+ .setMaxIterations(1)
+ .setSeed(1L)
+ val model = bkm.run(data)
+ assert(model.k === 1)
+ }
+
+ test("more desired clusters than points") {
+ val data = sc.parallelize(Seq.tabulate(4)(i => Vectors.dense(i)), 2)
+ val bkm = new BisectingKMeans()
+ .setK(8)
+ .setMaxIterations(2)
+ .setSeed(1L)
+ val model = bkm.run(data)
+ assert(model.k === 4)
+ }
+
+ test("min divisible cluster") {
+ val data = sc.parallelize(
+ Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)),
+ 2)
+ val bkm = new BisectingKMeans()
+ .setK(4)
+ .setMinDivisibleClusterSize(10)
+ .setMaxIterations(1)
+ .setSeed(1L)
+ val model = bkm.run(data)
+ assert(model.k === 3)
+ assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97)))
+ assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8)))
+
+ bkm.setMinDivisibleClusterSize(0.5)
+ val sameModel = bkm.run(data)
+ assert(sameModel.k === 3)
+ }
+
+ test("larger clusters get selected first") {
+ val data = sc.parallelize(
+ Seq.tabulate(16)(i => Vectors.dense(i)) ++ Seq.tabulate(4)(i => Vectors.dense(-100.0 - i)),
+ 2)
+ val bkm = new BisectingKMeans()
+ .setK(3)
+ .setMaxIterations(1)
+ .setSeed(1L)
+ val model = bkm.run(data)
+ assert(model.k === 3)
+ assert(model.predict(Vectors.dense(-100)) === model.predict(Vectors.dense(-97)))
+ assert(model.predict(Vectors.dense(7)) !== model.predict(Vectors.dense(8)))
+ }
+
+ test("2D data") {
+ val points = Seq(
+ (11, 10), (9, 10), (10, 9), (10, 11),
+ (11, -10), (9, -10), (10, -9), (10, -11),
+ (0, 1), (0, -1)
+ ).map { case (x, y) =>
+ if (x == 0) {
+ Vectors.sparse(2, Array(1), Array(y))
+ } else {
+ Vectors.dense(x, y)
+ }
+ }
+ val data = sc.parallelize(points, 2)
+ val bkm = new BisectingKMeans()
+ .setK(3)
+ .setMaxIterations(4)
+ .setSeed(1L)
+ val model = bkm.run(data)
+ assert(model.k === 3)
+ assert(model.root.center ~== Vectors.dense(8, 0) relTol 1e-12)
+ model.root.leafNodes.foreach { node =>
+ if (node.center(0) < 5) {
+ assert(node.size === 2)
+ assert(node.center ~== Vectors.dense(0, 0) relTol 1e-12)
+ } else if (node.center(1) > 0) {
+ assert(node.size === 4)
+ assert(node.center ~== Vectors.dense(10, 10) relTol 1e-12)
+ } else {
+ assert(node.size === 4)
+ assert(node.center ~== Vectors.dense(10, -10) relTol 1e-12)
+ }
+ }
+ }
+}