aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/java/org/apache
diff options
context:
space:
mode:
authorYu ISHIKAWA <yuu.ishikawa@gmail.com>2015-11-09 14:56:36 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-09 14:56:36 -0800
commit8a2336893a7ff610a6c4629dd567b85078730616 (patch)
treed671674126c1a00ea22b19d912c86d17135f9a1d /mllib/src/test/java/org/apache
parenta3a7c9103e136035d65a5564f9eb0fa04727c4f3 (diff)
downloadspark-8a2336893a7ff610a6c4629dd567b85078730616.tar.gz
spark-8a2336893a7ff610a6c4629dd567b85078730616.tar.bz2
spark-8a2336893a7ff610a6c4629dd567b85078730616.zip
[SPARK-6517][MLLIB] Implement the Algorithm of Hierarchical Clustering
I implemented a hierarchical clustering algorithm again. This PR doesn't include examples, documentation or spark.ml APIs. I am going to send further PRs for those later. https://issues.apache.org/jira/browse/SPARK-6517 - This implementation is based on bisecting K-means clustering. - It derives from freeman-lab's implementation. - The basic idea is not changed from the previous version. (#2906) - However, it is 1000x faster than the previous version through parallel processing. Thank you for your great cooperation, RJ Nowling(rnowling), Jeremy Freeman(freeman-lab), Xiangrui Meng(mengxr) and Sean Owen(srowen). Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Author: Yu ISHIKAWA <yu-iskw@users.noreply.github.com> Closes #5267 from yu-iskw/new-hierarchical-clustering.
Diffstat (limited to 'mllib/src/test/java/org/apache')
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java73
1 files changed, 73 insertions, 0 deletions
diff --git a/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
new file mode 100644
index 0000000000..a714620ff7
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/clustering/JavaBisectingKMeansSuite.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.clustering;
+
+import java.io.Serializable;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+
+public class JavaBisectingKMeansSuite implements Serializable { // JUnit suite driving BisectingKMeans from Java code.
+ private transient JavaSparkContext sc; // transient: left out if this Serializable suite is ever serialized
+ // Before each test: start a fresh context on the "local" master, named after this class.
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", this.getClass().getSimpleName());
+ }
+ // After each test: stop the context and drop the reference to it.
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+ // Clusters three 2-D points and checks the cluster count, the root center and each child of the root.
+ @Test
+ public void twoDimensionalData() {
+ JavaRDD<Vector> points = sc.parallelize(Lists.newArrayList(
+ Vectors.dense(4, -1),
+ Vectors.dense(4, 1),
+ Vectors.sparse(2, new int[] {0}, new double[] {1.0}) // sparse encoding of the point (1, 0)
+ ), 2);
+ // k = 4 is requested, yet the test expects k() == 3 below -- presumably because only three points exist.
+ BisectingKMeans bkm = new BisectingKMeans()
+ .setK(4)
+ .setMaxIterations(2)
+ .setSeed(1L);
+ BisectingKMeansModel model = bkm.run(points);
+ Assert.assertEquals(3, model.k());
+ Assert.assertArrayEquals(new double[] {3.0, 0.0}, model.root().center().toArray(), 1e-12); // (3, 0) is the mean of the three points
+ for (ClusteringTreeNode child: model.root().children()) {
+ double[] center = child.center().toArray();
+ if (center[0] > 2) { // expected: the child with the two x = 4 points, centered at (4, 0)
+ Assert.assertEquals(2, child.size());
+ Assert.assertArrayEquals(new double[] {4.0, 0.0}, center, 1e-12);
+ } else { // expected: the child with the single point (1, 0)
+ Assert.assertEquals(1, child.size());
+ Assert.assertArrayEquals(new double[] {1.0, 0.0}, center, 1e-12);
+ }
+ }
+ }
+}