author     Yanbo Liang <ybliang8@gmail.com>       2016-03-31 23:49:58 -0700
committer  Xiangrui Meng <meng@databricks.com>    2016-03-31 23:49:58 -0700
commit     22249afb4a932a82ff1f7a3befea9fda5a60a3f4
tree       107b6166b9f3e1ec51c5d8681c10af7ec57bc836
parent     26867ebc67edab97376c5d8fee76df294359e461
[SPARK-14303][ML][SPARKR] Define and use KMeansWrapper for SparkR::kmeans
## What changes were proposed in this pull request?

Define and use `KMeansWrapper` for `SparkR::kmeans`. This is purely a code refactor of the original `KMeans` wrapper.

## How was this patch tested?

Existing tests.

cc mengxr

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #12039 from yanboliang/spark-14059.
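For context, a minimal sketch of how the refactored wrapper is driven from the JVM side. It is illustrative only: the data, the column names, and the `sqlContext` value are hypothetical (think of a spark-shell session of this era), and it pretends to sit inside the `org.apache.spark.ml.r` package, since `KMeansWrapper` is `private[r]`.

```scala
import org.apache.spark.ml.r.KMeansWrapper

import sqlContext.implicits._

// Hypothetical two-column numeric data set.
val df = Seq((1.0, 1.0), (1.5, 2.0), (9.0, 9.0), (8.0, 9.5)).toDF("x", "y")

// k and maxIter are typed as Double because SparkR hands numeric scalars
// to the JVM as doubles; the wrapper narrows them with .toInt.
val wrapper = KMeansWrapper.fit(df, k = 2.0, maxIter = 10.0,
  initMode = "k-means||", columns = Array("x", "y"))

wrapper.k                        // 2
wrapper.features                 // Array("x", "y")
wrapper.coefficients             // flattened centers, length k * numFeatures
wrapper.fitted("centers").show() // predictions with the vector column dropped
```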
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala   | 85
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala  | 52
2 files changed, 86 insertions(+), 51 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
new file mode 100644
index 0000000000..d3a0df4063
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
+import org.apache.spark.ml.feature.VectorAssembler
+import org.apache.spark.sql.DataFrame
+
+private[r] class KMeansWrapper private (
+ pipeline: PipelineModel) {
+
+ private val kMeansModel: KMeansModel = pipeline.stages(1).asInstanceOf[KMeansModel]
+
+ lazy val coefficients: Array[Double] = kMeansModel.clusterCenters.flatMap(_.toArray)
+
+ private lazy val attrs = AttributeGroup.fromStructField(
+ kMeansModel.summary.predictions.schema(kMeansModel.getFeaturesCol))
+
+ lazy val features: Array[String] = attrs.attributes.get.map(_.name.get)
+
+ lazy val k: Int = kMeansModel.getK
+
+ lazy val size: Array[Int] = kMeansModel.summary.size
+
+ lazy val cluster: DataFrame = kMeansModel.summary.cluster
+
+ def fitted(method: String): DataFrame = {
+ if (method == "centers") {
+ kMeansModel.summary.predictions.drop(kMeansModel.getFeaturesCol)
+ } else if (method == "classes") {
+ kMeansModel.summary.cluster
+ } else {
+ throw new UnsupportedOperationException(
+ s"Method (centers or classes) required but $method found.")
+ }
+ }
+
+ def transform(dataset: DataFrame): DataFrame = {
+ pipeline.transform(dataset).drop(kMeansModel.getFeaturesCol)
+ }
+
+}
+
+private[r] object KMeansWrapper {
+
+ def fit(
+ data: DataFrame,
+ k: Double,
+ maxIter: Double,
+ initMode: String,
+ columns: Array[String]): KMeansWrapper = {
+
+ val assembler = new VectorAssembler()
+ .setInputCols(columns)
+ .setOutputCol("features")
+
+ val kMeans = new KMeans()
+ .setK(k.toInt)
+ .setMaxIter(maxIter.toInt)
+ .setInitMode(initMode)
+
+ val pipeline = new Pipeline()
+ .setStages(Array(assembler, kMeans))
+ .fit(data)
+
+ new KMeansWrapper(pipeline)
+ }
+}
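The `features` field above recovers the original column names from the ML attribute metadata that `VectorAssembler` writes into the schema of its output column. A standalone sketch of that mechanism, assuming the same hypothetical `df` with numeric columns `x` and `y`:

```scala
import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.feature.VectorAssembler

val assembled = new VectorAssembler()
  .setInputCols(Array("x", "y"))
  .setOutputCol("features")
  .transform(df)

// VectorAssembler records one attribute per vector slot, named after its
// source column; AttributeGroup reads that metadata back from the field.
val attrs = AttributeGroup.fromStructField(assembled.schema("features"))
attrs.attributes.get.map(_.name.get) // Array("x", "y")
```

This is also why the wrapper can safely assume `pipeline.stages(1)` is the `KMeansModel`: `fit` always builds the pipeline as `Array(assembler, kMeans)`, so the assembler sits at index 0 and the fitted k-means model at index 1.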
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
index d23e4fc9d1..551e75dc0a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala
@@ -20,8 +20,7 @@ package org.apache.spark.ml.api.r
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.attribute._
import org.apache.spark.ml.classification.{LogisticRegression, LogisticRegressionModel}
-import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
-import org.apache.spark.ml.feature.{RFormula, VectorAssembler}
+import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.sql.DataFrame
@@ -52,22 +51,6 @@ private[r] object SparkRWrappers {
pipeline.fit(df)
}
- def fitKMeans(
- df: DataFrame,
- initMode: String,
- maxIter: Double,
- k: Double,
- columns: Array[String]): PipelineModel = {
- val assembler = new VectorAssembler().setInputCols(columns)
- val kMeans = new KMeans()
- .setInitMode(initMode)
- .setMaxIter(maxIter.toInt)
- .setK(k.toInt)
- .setFeaturesCol(assembler.getOutputCol)
- val pipeline = new Pipeline().setStages(Array(assembler, kMeans))
- pipeline.fit(df)
- }
-
def getModelCoefficients(model: PipelineModel): Array[Double] = {
model.stages.last match {
case m: LinearRegressionModel => {
@@ -89,8 +72,6 @@ private[r] object SparkRWrappers {
m.coefficients.toArray
}
}
- case m: KMeansModel =>
- m.clusterCenters.flatMap(_.toArray)
}
}
@@ -104,31 +85,6 @@ private[r] object SparkRWrappers {
}
}
- def getKMeansModelSize(model: PipelineModel): Array[Int] = {
- model.stages.last match {
- case m: KMeansModel => Array(m.getK) ++ m.summary.size
- case other => throw new UnsupportedOperationException(
- s"KMeansModel required but ${other.getClass.getSimpleName} found.")
- }
- }
-
- def getKMeansCluster(model: PipelineModel, method: String): DataFrame = {
- model.stages.last match {
- case m: KMeansModel =>
- if (method == "centers") {
- // Drop the assembled vector for easy-print to R side.
- m.summary.predictions.drop(m.summary.featuresCol)
- } else if (method == "classes") {
- m.summary.cluster
- } else {
- throw new UnsupportedOperationException(
- s"Method (centers or classes) required but $method found.")
- }
- case other => throw new UnsupportedOperationException(
- s"KMeansModel required but ${other.getClass.getSimpleName} found.")
- }
- }
-
def getModelFeatures(model: PipelineModel): Array[String] = {
model.stages.last match {
case m: LinearRegressionModel =>
@@ -147,10 +103,6 @@ private[r] object SparkRWrappers {
} else {
attrs.attributes.get.map(_.name.get)
}
- case m: KMeansModel =>
- val attrs = AttributeGroup.fromStructField(
- m.summary.predictions.schema(m.summary.featuresCol))
- attrs.attributes.get.map(_.name.get)
}
}
@@ -160,8 +112,6 @@ private[r] object SparkRWrappers {
"LinearRegressionModel"
case m: LogisticRegressionModel =>
"LogisticRegressionModel"
- case m: KMeansModel =>
- "KMeansModel"
}
}
}
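Taken together, the two hunks replace the monolithic `SparkRWrappers` entry points, which pattern-matched on the pipeline's last stage, with a self-describing wrapper object per algorithm. A schematic before/after comparison (not compilable as-is, since both objects are `private[r]`; `df` and `cols` are hypothetical):

```scala
// Before: generic helpers dispatching on model.stages.last.
val model  = SparkRWrappers.fitKMeans(df, "k-means||", 10.0, 2.0, cols)
val sizes  = SparkRWrappers.getKMeansModelSize(model) // Array(k) ++ cluster sizes
val fitted = SparkRWrappers.getKMeansCluster(model, "centers")

// After: one dedicated wrapper per algorithm; no runtime type dispatch.
val wrapper = KMeansWrapper.fit(df, 2.0, 10.0, "k-means||", cols)
val k       = wrapper.k
val size    = wrapper.size
val fitted2 = wrapper.fitted("centers")
```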