aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorZheng RuiFeng <ruifengz@foxmail.com>2016-11-24 05:46:05 -0800
committerYanbo Liang <ybliang8@gmail.com>2016-11-24 05:46:05 -0800
commit2dfabec38c24174e7f747c27c7144f7738483ec1 (patch)
tree739af68a3d8d9b8347d7d78b1efd55610ddff05b /mllib
parent223fa218e1f637f0d62332785a3bee225b65b990 (diff)
downloadspark-2dfabec38c24174e7f747c27c7144f7738483ec1.tar.gz
spark-2dfabec38c24174e7f747c27c7144f7738483ec1.tar.bz2
spark-2dfabec38c24174e7f747c27c7144f7738483ec1.zip
[SPARK-18520][ML] Add missing setXXXCol methods for BisectingKMeansModel and GaussianMixtureModel
## What changes were proposed in this pull request? add `setFeaturesCol` and `setPredictionCol` for BiKModel and GMModel add `setProbabilityCol` for GMModel ## How was this patch tested? existing tests Author: Zheng RuiFeng <ruifengz@foxmail.com> Closes #15957 from zhengruifeng/bikm_set.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala8
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala12
2 files changed, 20 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
index e6ca3aedff..cf11ba37ab 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -98,6 +98,14 @@ class BisectingKMeansModel private[ml] (
copied.setSummary(trainingSummary).setParent(this.parent)
}
+ /** @group setParam */
+ @Since("2.1.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("2.1.0")
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
index 92d0b7d085..19998ca44b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/GaussianMixture.scala
@@ -87,6 +87,18 @@ class GaussianMixtureModel private[ml] (
@Since("2.0.0") val gaussians: Array[MultivariateGaussian])
extends Model[GaussianMixtureModel] with GaussianMixtureParams with MLWritable {
+ /** @group setParam */
+ @Since("2.1.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("2.1.0")
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ @Since("2.1.0")
+ def setProbabilityCol(value: String): this.type = set(probabilityCol, value)
+
@Since("2.0.0")
override def copy(extra: ParamMap): GaussianMixtureModel = {
val copied = copyValues(new GaussianMixtureModel(uid, weights, gaussians), extra)