aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorTravis Galoppo <tjg2107@columbia.edu>2014-12-29 15:29:15 -0800
committerXiangrui Meng <meng@databricks.com>2014-12-29 15:29:15 -0800
commit6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464 (patch)
treea9d9bcd5af2c93f1bb89cb63edc60278ba4124c2 /examples
parent9bc0df6804f241aff24520d9c6ec54d9b11f5785 (diff)
downloadspark-6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464.tar.gz
spark-6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464.tar.bz2
spark-6cf6fdf3ff5d1cf33c2dc28f039adc4d7c0f0464.zip
SPARK-4156 [MLLIB] EM algorithm for GMMs
Implementation of Expectation-Maximization for Gaussian Mixture Models. This is my maiden contribution to Apache Spark, so I apologize now if I have done anything incorrectly; having said that, this work is my own, and I offer it to the project under the project's open source license. Author: Travis Galoppo <tjg2107@columbia.edu> Author: Travis Galoppo <travis@localhost.localdomain> Author: tgaloppo <tjg2107@columbia.edu> Author: FlytxtRnD <meethu.mathew@flytxt.com> Closes #3022 from tgaloppo/master and squashes the following commits: aaa8f25 [Travis Galoppo] MLUtils: changed privacy of EPSILON from [util] to [mllib] 709e4bf [Travis Galoppo] fixed usage line to include optional maxIterations parameter acf1fba [Travis Galoppo] Fixed parameter comment in GaussianMixtureModel Made maximum iterations an optional parameter to DenseGmmEM 9b2fc2a [Travis Galoppo] Style improvements Changed ExpectationSum to a private class b97fe00 [Travis Galoppo] Minor fixes and tweaks. 1de73f3 [Travis Galoppo] Removed redundant array from array creation 578c2d1 [Travis Galoppo] Removed unused import 227ad66 [Travis Galoppo] Moved prediction methods into model class. 308c8ad [Travis Galoppo] Numerous changes to improve code cff73e0 [Travis Galoppo] Replaced accumulators with RDD.aggregate 20ebca1 [Travis Galoppo] Removed unusued code 42b2142 [Travis Galoppo] Added functionality to allow setting of GMM starting point. Added two cluster test to testing suite. 8b633f3 [Travis Galoppo] Style issue 9be2534 [Travis Galoppo] Style issue d695034 [Travis Galoppo] Fixed style issues c3b8ce0 [Travis Galoppo] Merge branch 'master' of https://github.com/tgaloppo/spark Adds predict() method 2df336b [Travis Galoppo] Fixed style issue b99ecc4 [tgaloppo] Merge pull request #1 from FlytxtRnD/predictBranch f407b4c [FlytxtRnD] Added predict() to return the cluster labels and membership values 97044cf [Travis Galoppo] Fixed style issues dc9c742 [Travis Galoppo] Moved MultivariateGaussian utility class e7d413b [Travis Galoppo] Moved multivariate Gaussian utility class to mllib/stat/impl Improved comments 9770261 [Travis Galoppo] Corrected a variety of style and naming issues. 8aaa17d [Travis Galoppo] Added additional train() method to companion object for cluster count and tolerance parameters. 676e523 [Travis Galoppo] Fixed to no longer ignore delta value provided on command line e6ea805 [Travis Galoppo] Merged with master branch; update test suite with latest context changes. Improved cluster initialization strategy. 86fb382 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' 719d8cc [Travis Galoppo] Added scala test suite with basic test c1a8e16 [Travis Galoppo] Made GaussianMixtureModel class serializable Modified sum function for better performance 5c96c57 [Travis Galoppo] Merge remote-tracking branch 'upstream/master' c15405c [Travis Galoppo] SPARK-4156
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala67
1 files changed, 67 insertions, 0 deletions
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
new file mode 100644
index 0000000000..948c350953
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DenseGmmEM.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.mllib
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.mllib.clustering.GaussianMixtureEM
+import org.apache.spark.mllib.linalg.Vectors
+
+/**
+ * An example Gaussian Mixture Model EM app. Run with
+ * {{{
+ * ./bin/run-example org.apache.spark.examples.mllib.DenseGmmEM <input> <k> <covergenceTol>
+ * }}}
+ * If you use it as a template to create your own app, please use `spark-submit` to submit your app.
+ */
+object DenseGmmEM {
+ def main(args: Array[String]): Unit = {
+ if (args.length < 3) {
+ println("usage: DenseGmmEM <input file> <k> <convergenceTol> [maxIterations]")
+ } else {
+ val maxIterations = if (args.length > 3) args(3).toInt else 100
+ run(args(0), args(1).toInt, args(2).toDouble, maxIterations)
+ }
+ }
+
+ private def run(inputFile: String, k: Int, convergenceTol: Double, maxIterations: Int) {
+ val conf = new SparkConf().setAppName("Gaussian Mixture Model EM example")
+ val ctx = new SparkContext(conf)
+
+ val data = ctx.textFile(inputFile).map { line =>
+ Vectors.dense(line.trim.split(' ').map(_.toDouble))
+ }.cache()
+
+ val clusters = new GaussianMixtureEM()
+ .setK(k)
+ .setConvergenceTol(convergenceTol)
+ .setMaxIterations(maxIterations)
+ .run(data)
+
+ for (i <- 0 until clusters.k) {
+ println("weight=%f\nmu=%s\nsigma=\n%s\n" format
+ (clusters.weight(i), clusters.mu(i), clusters.sigma(i)))
+ }
+
+ println("Cluster labels (first <= 100):")
+ val clusterLabels = clusters.predict(data)
+ clusterLabels.take(100).foreach { x =>
+ print(" " + x)
+ }
+ println()
+ }
+}