author     wm624@hotmail.com <wm624@hotmail.com>    2017-01-26 21:01:59 -0800
committer  Felix Cheung <felixcheung@apache.org>    2017-01-26 21:01:59 -0800
commit     c0ba284300e494354f5bb205a10a12ac7daa2b5e
tree       43592dfb0b4576ca018414721f809ecfa1972739 /mllib/src
parent     1191fe267d2faad2a99a83f3375ce2d9d382cfa0
[SPARK-18821][SPARKR] Bisecting k-means wrapper in SparkR
## What changes were proposed in this pull request?

Add an R wrapper for bisecting k-means. As JIRA is down, I will update the title to link to the corresponding JIRA later.

## How was this patch tested?

Added new unit tests.

Author: wm624@hotmail.com <wm624@hotmail.com>

Closes #16566 from wangmiao1981/bk.
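For context, the entry point the R side reaches is the `fit` method defined in the diff below; SparkR forwards the model formula and hyperparameters to it through the R backend. A minimal sketch of that call from Scala, assuming it runs inside the `org.apache.spark.ml.r` package (the wrapper is `private[r]`); the data, column names, and parameter values here are illustrative, not part of this patch:

```scala
// Sketch only: must live in package org.apache.spark.ml.r, since the
// wrapper is private[r]. Data and parameter values are made up.
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").getOrCreate()
val df = spark.createDataFrame(Seq(
  (5.1, 3.5), (4.9, 3.0), (6.7, 3.1), (6.3, 2.5)
)).toDF("sepalLength", "sepalWidth")

// Mirrors what SparkR passes through the backend: an RFormula string,
// k, maxIter, the seed as a (possibly empty) string, and
// minDivisibleClusterSize.
val wrapper = BisectingKMeansWrapper.fit(
  df, "~ sepalLength + sepalWidth", k = 2, maxIter = 20,
  seed = "", minDivisibleClusterSize = 1.0)

wrapper.fitted("classes").show()   // one cluster assignment per row
println(wrapper.k)                 // number of leaf clusters
```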
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala | 143
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala              |   2
2 files changed, 145 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala
new file mode 100644
index 0000000000..71712c1c5e
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/BisectingKMeansWrapper.scala
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.ml.{Pipeline, PipelineModel}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.clustering.{BisectingKMeans, BisectingKMeansModel}
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+private[r] class BisectingKMeansWrapper private (
+    val pipeline: PipelineModel,
+    val features: Array[String],
+    val size: Array[Long],
+    val isLoaded: Boolean = false) extends MLWritable {
+  private val bisectingKmeansModel: BisectingKMeansModel =
+    pipeline.stages.last.asInstanceOf[BisectingKMeansModel]
+
+  lazy val coefficients: Array[Double] = bisectingKmeansModel.clusterCenters.flatMap(_.toArray)
+
+  lazy val k: Int = bisectingKmeansModel.getK
+
+  // If the model is loaded from a saved model, cluster is NULL; this is checked on the R side.
+  lazy val cluster: DataFrame = bisectingKmeansModel.summary.cluster
+
+  def fitted(method: String): DataFrame = {
+    if (method == "centers") {
+      bisectingKmeansModel.summary.predictions.drop(bisectingKmeansModel.getFeaturesCol)
+    } else if (method == "classes") {
+      bisectingKmeansModel.summary.cluster
+    } else {
+      throw new UnsupportedOperationException(
+        s"Method (centers or classes) required but $method found.")
+    }
+  }
+
+  def transform(dataset: Dataset[_]): DataFrame = {
+    pipeline.transform(dataset).drop(bisectingKmeansModel.getFeaturesCol)
+  }
+
+  override def write: MLWriter = new BisectingKMeansWrapper.BisectingKMeansWrapperWriter(this)
+}
+
+private[r] object BisectingKMeansWrapper extends MLReadable[BisectingKMeansWrapper] {
+
+  def fit(
+      data: DataFrame,
+      formula: String,
+      k: Int,
+      maxIter: Int,
+      seed: String,
+      minDivisibleClusterSize: Double
+    ): BisectingKMeansWrapper = {
+
+    val rFormula = new RFormula()
+      .setFormula(formula)
+      .setFeaturesCol("features")
+    RWrapperUtils.checkDataColumns(rFormula, data)
+    val rFormulaModel = rFormula.fit(data)
+
+    // get feature names from output schema
+    val schema = rFormulaModel.transform(data).schema
+    val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
+      .attributes.get
+    val features = featureAttrs.map(_.name.get)
+
+    val bisectingKmeans = new BisectingKMeans()
+      .setK(k)
+      .setMaxIter(maxIter)
+      .setMinDivisibleClusterSize(minDivisibleClusterSize)
+      .setFeaturesCol(rFormula.getFeaturesCol)
+
+    if (seed != null && seed.length > 0) bisectingKmeans.setSeed(seed.toInt)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(rFormulaModel, bisectingKmeans))
+      .fit(data)
+
+    val bisectingKmeansModel: BisectingKMeansModel =
+      pipeline.stages.last.asInstanceOf[BisectingKMeansModel]
+    val size: Array[Long] = bisectingKmeansModel.summary.clusterSizes
+
+    new BisectingKMeansWrapper(pipeline, features, size)
+  }
+
+  override def read: MLReader[BisectingKMeansWrapper] = new BisectingKMeansWrapperReader
+
+  override def load(path: String): BisectingKMeansWrapper = super.load(path)
+
+  class BisectingKMeansWrapperWriter(instance: BisectingKMeansWrapper) extends MLWriter {
+
+    override protected def saveImpl(path: String): Unit = {
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+
+      val rMetadata = ("class" -> instance.getClass.getName) ~
+        ("features" -> instance.features.toSeq) ~
+        ("size" -> instance.size.toSeq)
+      val rMetadataJson: String = compact(render(rMetadata))
+
+      sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+      instance.pipeline.save(pipelinePath)
+    }
+  }
+
+  class BisectingKMeansWrapperReader extends MLReader[BisectingKMeansWrapper] {
+
+    override def load(path: String): BisectingKMeansWrapper = {
+      implicit val format = DefaultFormats
+      val rMetadataPath = new Path(path, "rMetadata").toString
+      val pipelinePath = new Path(path, "pipeline").toString
+      val pipeline = PipelineModel.load(pipelinePath)
+
+      val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+      val rMetadata = parse(rMetadataStr)
+      val features = (rMetadata \ "features").extract[Array[String]]
+      val size = (rMetadata \ "size").extract[Array[Long]]
+      new BisectingKMeansWrapper(pipeline, features, size, isLoaded = true)
+    }
+  }
+
+}
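The writer and reader above persist the wrapper as two artifacts under one directory: a single-line JSON rMetadata file (class name, feature names, cluster sizes) and the fitted PipelineModel. A minimal round-trip sketch, assuming the `wrapper` instance from the earlier fit sketch and an arbitrary writable path:

```scala
// Sketch: save, then restore through the wrapper's own reader.
// The path /tmp/bkm-model is illustrative.
wrapper.write.overwrite().save("/tmp/bkm-model")
// Layout on disk:
//   /tmp/bkm-model/rMetadata  - JSON: {"class": ..., "features": [...], "size": [...]}
//   /tmp/bkm-model/pipeline   - the saved PipelineModel
val restored = BisectingKMeansWrapper.load("/tmp/bkm-model")
assert(restored.isLoaded)
// Note: `restored.cluster` touches the training summary, which is not
// persisted; the R side guards against this when isLoaded is true.
```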
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index b59fe29234..c44179281b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -64,6 +64,8 @@ private[r] object RWrappers extends MLReader[Object] {
         GBTRegressorWrapper.load(path)
       case "org.apache.spark.ml.r.GBTClassifierWrapper" =>
         GBTClassifierWrapper.load(path)
+      case "org.apache.spark.ml.r.BisectingKMeansWrapper" =>
+        BisectingKMeansWrapper.load(path)
       case _ =>
         throw new SparkException(s"SparkR read.ml does not support load $className")
     }
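The two-line change above registers the new wrapper with the generic RWrappers reader, which backs SparkR's read.ml: it reads the "class" field from the saved rMetadata and dispatches to the matching wrapper's load. A hedged sketch of that dispatch, reusing the illustrative path from the round-trip sketch (RWrappers is also private[r], so this too would run inside the package):

```scala
// Sketch: read.ml on the R side ultimately resolves to this dispatch.
val obj = RWrappers.load("/tmp/bkm-model")
val model = obj.asInstanceOf[BisectingKMeansWrapper]
println(model.features.mkString(", "))   // feature names restored from rMetadata
```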