about summary refs log tree commit diff
path: root/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala 32
1 file changed, 18 insertions(+), 14 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index f67760d3ca..4d4c303fc8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -25,7 +25,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -65,28 +65,32 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
def fit(
data: DataFrame,
- k: Double,
- maxIter: Double,
- initMode: String,
- columns: Array[String]): KMeansWrapper = {
+ formula: String,
+ k: Int,
+ maxIter: Int,
+ initMode: String): KMeansWrapper = {
+
+ val rFormulaModel = new RFormula()
+ .setFormula(formula)
+ .setFeaturesCol("features")
+ .fit(data)
- val assembler = new VectorAssembler()
- .setInputCols(columns)
- .setOutputCol("features")
+ // get feature names from output schema
+ val schema = rFormulaModel.transform(data).schema
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+ .attributes.get
+ val features = featureAttrs.map(_.name.get)
val kMeans = new KMeans()
- .setK(k.toInt)
- .setMaxIter(maxIter.toInt)
+ .setK(k)
+ .setMaxIter(maxIter)
.setInitMode(initMode)
val pipeline = new Pipeline()
- .setStages(Array(assembler, kMeans))
+ .setStages(Array(rFormulaModel, kMeans))
.fit(data)
val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
- val attrs = AttributeGroup.fromStructField(
- kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
- val features: Array[String] = attrs.attributes.get.map(_.name.get)
val size: Array[Long] = kMeansModel.summary.clusterSizes
new KMeansWrapper(pipeline, features, size)