aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorFlytxtRnD <meethu.mathew@flytxt.com>2015-02-02 23:04:55 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-02 23:04:55 -0800
commit50a1a874e1d087a6c79835b1936d0009622a97b1 (patch)
tree81381fdb41d6bf9e3cbf59291f200fbc5ddab3d1 /mllib/src
parentc31c36c4a76bd3449696383321332ec95bff7fed (diff)
downloadspark-50a1a874e1d087a6c79835b1936d0009622a97b1.tar.gz
spark-50a1a874e1d087a6c79835b1936d0009622a97b1.tar.bz2
spark-50a1a874e1d087a6c79835b1936d0009622a97b1.zip
[SPARK-5012][MLLib][PySpark]Python API for Gaussian Mixture Model
Python API for the Gaussian Mixture Model clustering algorithm in MLLib. Author: FlytxtRnD <meethu.mathew@flytxt.com> Closes #4059 from FlytxtRnD/PythonGmmWrapper and squashes the following commits: c973ab3 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper 339b09c [FlytxtRnD] Added MultivariateGaussian namedtuple and Arraybuffer in trainGaussianMixture fa0a142 [FlytxtRnD] New line added d5b36ab [FlytxtRnD] Changed argument names to lowercase ac134f1 [FlytxtRnD] Merge branch 'PythonGmmWrapper' of https://github.com/FlytxtRnD/spark into PythonGmmWrapper 6671ea1 [FlytxtRnD] Added mllib/stat/distribution.py 3aee84b [FlytxtRnD] Fixed style issues 2e9f12a [FlytxtRnD] Added mllib/stat/distribution.py and fixed style issues b22532c [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper 2e14d82 [FlytxtRnD] Incorporate MultivariateGaussian instances in GaussianMixtureModel 05767c7 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper 3464d19 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper c1d4c71 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'origin/PythonGmmWrapper' into PythonGmmWrapper 426d130 [FlytxtRnD] Added random seed parameter 332bad1 [FlytxtRnD] Merge branch 'PythonGmmWrapper', remote-tracking branch 'upstream/master' into PythonGmmWrapper f82750b [FlytxtRnD] Fixed style issues 5c83825 [FlytxtRnD] Split input file with space delimiter fda60f3 [FlytxtRnD] Python API for Gaussian Mixture Model
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala56
1 files changed, 55 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index a66d6f0cf2..980980593d 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -22,6 +22,7 @@ import java.nio.{ByteBuffer, ByteOrder}
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
import scala.language.existentials
import scala.reflect.ClassTag
@@ -40,6 +41,7 @@ import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
+import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
import org.apache.spark.mllib.stat.test.ChiSqTestResult
import org.apache.spark.mllib.tree.{GradientBoostedTrees, RandomForest, DecisionTree}
import org.apache.spark.mllib.tree.configuration.{BoostingStrategy, Algo, Strategy}
@@ -260,7 +262,7 @@ class PythonMLLibAPI extends Serializable {
}
/**
- * Java stub for Python mllib KMeans.train()
+ * Java stub for Python mllib KMeans.run()
*/
def trainKMeansModel(
data: JavaRDD[Vector],
@@ -285,6 +287,58 @@ class PythonMLLibAPI extends Serializable {
}
/**
+ * Java stub for Python mllib GaussianMixture.run()
+ * Returns a list containing weights, mean and covariance of each mixture component.
+ */
+ def trainGaussianMixture(
+ data: JavaRDD[Vector],
+ k: Int,
+ convergenceTol: Double,
+ maxIterations: Int,
+ seed: Long): JList[Object] = {
+ val gmmAlg = new GaussianMixture()
+ .setK(k)
+ .setConvergenceTol(convergenceTol)
+ .setMaxIterations(maxIterations)
+
+ if (seed != null) gmmAlg.setSeed(seed)
+
+ try {
+ val model = gmmAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
+ var wt = ArrayBuffer.empty[Double]
+ var mu = ArrayBuffer.empty[Vector]
+ var sigma = ArrayBuffer.empty[Matrix]
+ for (i <- 0 until model.k) {
+ wt += model.weights(i)
+ mu += model.gaussians(i).mu
+ sigma += model.gaussians(i).sigma
+ }
+ List(wt.toArray, mu.toArray, sigma.toArray).map(_.asInstanceOf[Object]).asJava
+ } finally {
+ data.rdd.unpersist(blocking = false)
+ }
+ }
+
+ /**
+ * Java stub for Python mllib GaussianMixtureModel.predictSoft()
+ */
+ def predictSoftGMM(
+ data: JavaRDD[Vector],
+ wt: Object,
+ mu: Array[Object],
+ si: Array[Object]): RDD[Array[Double]] = {
+
+ val weight = wt.asInstanceOf[Array[Double]]
+ val mean = mu.map(_.asInstanceOf[DenseVector])
+ val sigma = si.map(_.asInstanceOf[DenseMatrix])
+ val gaussians = Array.tabulate(weight.length){
+ i => new MultivariateGaussian(mean(i), sigma(i))
+ }
+ val model = new GaussianMixtureModel(weight, gaussians)
+ model.predictSoft(data)
+ }
+
+ /**
* A Wrapper of MatrixFactorizationModel to provide helpfer method for Python
*/
private[python] class MatrixFactorizationModelWrapper(model: MatrixFactorizationModel)