diff options
author | wm624@hotmail.com <wm624@hotmail.com> | 2016-04-25 10:48:15 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-25 10:48:15 -0700 |
commit | b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5 (patch) | |
tree | 615ec12aa6fb245f70953516335754afce73afbd /mllib/src/main/scala/org/apache | |
parent | a680562a6f87a03a00f71bad1c424267ae75c641 (diff) | |
download | spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.tar.gz spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.tar.bz2 spark-b50e2eca93e6b0c391901f1e6d30628f8b6ebaa5.zip |
[SPARK-14433][PYSPARK][ML] PySpark ml GaussianMixture
## What changes were proposed in this pull request?
Add Python API in ML for GaussianMixture
## How was this patch tested?
(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)
Add doctest and test cases are the same as mllib Python tests
./dev/lint-python
PEP8 checks passed.
rm -rf _build/*
pydoc checks passed.
./python/run-tests --python-executables=python2.7 --modules=pyspark-ml
Running PySpark tests. Output is in /Users/mwang/spark_ws_0904/python/unit-tests.log
Will test against the following Python executables: ['python2.7']
Will test the following Python modules: ['pyspark-ml']
Finished test(python2.7): pyspark.ml.evaluation (18s)
Finished test(python2.7): pyspark.ml.clustering (40s)
Finished test(python2.7): pyspark.ml.classification (49s)
Finished test(python2.7): pyspark.ml.recommendation (44s)
Finished test(python2.7): pyspark.ml.feature (64s)
Finished test(python2.7): pyspark.ml.regression (45s)
Finished test(python2.7): pyspark.ml.tuning (30s)
Finished test(python2.7): pyspark.ml.tests (56s)
Tests passed in 106 seconds
Author: wm624@hotmail.com <wm624@hotmail.com>
Closes #12402 from wangmiao1981/gmm.
Diffstat (limited to 'mllib/src/main/scala/org/apache')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala | 27 |
1 files changed, 25 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala index ead8ad7806..dfbc8b612c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala @@ -19,6 +19,7 @@ package org.apache.spark.ml.clustering import org.apache.hadoop.fs.Path +import org.apache.spark.SparkContext import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param.{IntParam, ParamMap, Params} @@ -27,7 +28,7 @@ import org.apache.spark.ml.util._ import org.apache.spark.mllib.clustering.{GaussianMixture => MLlibGM, GaussianMixtureModel => MLlibGMModel} import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.stat.distribution.MultivariateGaussian -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row, SQLContext} import org.apache.spark.sql.functions.{col, udf} import org.apache.spark.sql.types.{IntegerType, StructType} @@ -104,6 +105,27 @@ class GaussianMixtureModel private[ml] ( @Since("2.0.0") def gaussians: Array[MultivariateGaussian] = parentModel.gaussians + /** + * Retrieve Gaussian distributions as a DataFrame. + * Each row represents a Gaussian Distribution. + * Two columns are defined: mean and cov. + * Schema: + * {{{ + * root + * |-- mean: vector (nullable = true) + * |-- cov: matrix (nullable = true) + * }}} + */ + @Since("2.0.0") + def gaussiansDF: DataFrame = { + val modelGaussians = gaussians.map { gaussian => + (gaussian.mu, gaussian.sigma) + } + val sc = SparkContext.getOrCreate() + val sqlContext = SQLContext.getOrCreate(sc) + sqlContext.createDataFrame(modelGaussians).toDF("mean", "cov") + } + @Since("2.0.0") override def write: MLWriter = new GaussianMixtureModel.GaussianMixtureModelWriter(this) @@ -247,7 +269,8 @@ class GaussianMixture @Since("2.0.0") ( .setSeed($(seed)) .setConvergenceTol($(tol)) val parentModel = algo.run(rdd) - val model = copyValues(new GaussianMixtureModel(uid, parentModel).setParent(this)) + val model = copyValues(new GaussianMixtureModel(uid, parentModel) + .setParent(this)) val summary = new GaussianMixtureSummary(model.transform(dataset), $(predictionCol), $(probabilityCol), $(featuresCol), $(k)) model.setSummary(summary) |