aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorwm624@hotmail.com <wm624@hotmail.com>2016-04-25 10:48:15 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-25 10:48:15 -0700
commitb50e2eca93e6b0c391901f1e6d30628f8b6ebaa5 (patch)
tree615ec12aa6fb245f70953516335754afce73afbd /mllib
parenta680562a6f87a03a00f71bad1c424267ae75c641 (diff)
downloadspark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.tar.gz
spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.tar.bz2
spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.zip
[SPARK-14433][PYSPARK][ML] PySpark ml GaussianMixture
## What changes were proposed in this pull request? Add Python API in ML for GaussianMixture ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Add doctest and test cases are the same as mllib Python tests ./dev/lint-python PEP8 checks passed. rm -rf _build/* pydoc checks passed. ./python/run-tests --python-executables=python2.7 --modules=pyspark-ml Running PySpark tests. Output is in /Users/mwang/spark_ws_0904/python/unit-tests.log Will test against the following Python executables: ['python2.7'] Will test the following Python modules: ['pyspark-ml'] Finished test(python2.7): pyspark.ml.evaluation (18s) Finished test(python2.7): pyspark.ml.clustering (40s) Finished test(python2.7): pyspark.ml.classification (49s) Finished test(python2.7): pyspark.ml.recommendation (44s) Finished test(python2.7): pyspark.ml.feature (64s) Finished test(python2.7): pyspark.ml.regression (45s) Finished test(python2.7): pyspark.ml.tuning (30s) Finished test(python2.7): pyspark.ml.tests (56s) Tests passed in 106 seconds Author: wm624@hotmail.com <wm624@hotmail.com> Closes #12402 from wangmiao1981/gmm.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala27
1 files changed, 25 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index ead8ad7806..dfbc8b612c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering
import org.apache.hadoop.fs.Path
+import org.apache.spark.SparkContext
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param.{IntParam, ParamMap, Params}
@@ -27,7 +28,7 @@ import org.apache.spark.ml.util._
import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel}
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.stat.distribution.MultivariateGaussian
-import org.apache.spark.sql.{DataFrame, Dataset, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext}
import org.apache.spark.sql.functions.{col, udf}
import org.apache.spark.sql.types.{IntegerType, StructType}
@@ -104,6 +105,27 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0")
def gaussians: Array[MultivariateGaussian] = parentModel.gaussians
+ /**
+ * Retrieve Gaussian distributions as a DataFrame.
+ * Each row represents a Gaussian Distribution.
+ * Two columns are defined: mean and cov.
+ * Schema:
+ * {{{
+ * root
+ * |-- mean: vector (nullable = true)
+ * |-- cov: matrix (nullable = true)
+ * }}}
+ */
+ @Since("2.0.0")
+ def gaussiansDF: DataFrame = {
+ val modelGaussians = gaussians.map { gaussian =>
+ (gaussian.mu, gaussian.sigma)
+ }
+ val sc = SparkContext.getOrCreate()
+ val sqlContext = SQLContext.getOrCreate(sc)
+ sqlContext.createDataFrame(modelGaussians).toDF("mean", "cov")
+ }
+
@Since("2.0.0")
override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this)
@@ -247,7 +269,8 @@ class GaussianMixture @Since("2.0.0") (
.setSeed($(seed))
.setConvergenceTol($(tol))
val parentModel = algo.run(rdd)
- val model = copyValues(new GaussianMixtureModel(uid, parentModel).setParent(this))
+ val model = copyValues(new GaussianMixtureModel(uid, parentModel)
+ .setParent(this))
val summary = new GaussianMixtureSummary(model.transform(dataset),
$(predictionCol), $(probabilityCol), $(featuresCol), $(k))
model.setSummary(summary)