about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Dominik Jastrzębski <dominik.jastrzebski@codilime.com> 2016-05-04 14:25:51 +0200
committer: Nick Pentreath <nickp@za.ibm.com> 2016-05-04 14:25:51 +0200
commit: abecbcd5e9598471b705a2f701731af1adc9d48b (patch)
tree: 9c44332ad0c6a909464c70f38a6558ac01214ffd
parent: f152fae306dc75565cb4648ee1211416d7c0bb23 (diff)
download: spark-abecbcd5e9598471b705a2f701731af1adc9d48b.tar.gz
spark-abecbcd5e9598471b705a2f701731af1adc9d48b.tar.bz2
spark-abecbcd5e9598471b705a2f701731af1adc9d48b.zip
[SPARK-14844][ML] Add setFeaturesCol and setPredictionCol to KMeansModel
## What changes were proposed in this pull request?

Introduction of setFeaturesCol and setPredictionCol methods to KMeansModel in ML library.

## How was this patch tested?

By running KMeansSuite.

Author: Dominik Jastrzębski <dominik.jastrzebski@codilime.com>

Closes #12609 from dominik-jastrzebski/master.
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala 8
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala 15
2 files changed, 23 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
index 7c9ac02521..42a25396ad 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -105,6 +105,14 @@ class KMeansModel private[ml] (
copyValues(copied, extra)
}
+ /** @group setParam */
+ @Since("2.0.0")
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ @Since("2.0.0")
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
@Since("2.0.0")
override def transform(dataset: Dataset[_]): DataFrame = {
val predictUDF = udf((vector: Vector) => predict(vector))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
index 2ca386e422..241d21961f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -117,6 +117,21 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR
assert(clusterSizes.forall(_ >= 0))
}
+ test("KMeansModel transform with non-default feature and prediction cols") {
+ val featuresColName = "kmeans_model_features"
+ val predictionColName = "kmeans_model_prediction"
+
+ val model = new KMeans().setK(k).setSeed(1).fit(dataset)
+ model.setFeaturesCol(featuresColName).setPredictionCol(predictionColName)
+
+ val transformed = model.transform(dataset.withColumnRenamed("features", featuresColName))
+ Seq(featuresColName, predictionColName).foreach { column =>
+ assert(transformed.columns.contains(column))
+ }
+ assert(model.getFeaturesCol == featuresColName)
+ assert(model.getPredictionCol == predictionColName)
+ }
+
test("read/write") {
def checkModelData(model: KMeansModel, model2: KMeansModel): Unit = {
assert(model.clusterCenters === model2.clusterCenters)