aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorYu ISHIKAWA <yuu.ishikawa@gmail.com>2016-01-20 10:48:10 -0800
committerXiangrui Meng <meng@databricks.com>2016-01-20 10:48:10 -0800
commit9376ae723e4ec0515120c488541617a0538f8879 (patch)
tree55453b88f46862d5d43fd02d2f53e126d6444b4f /mllib/src
parent8e4f894e986ccd943df9ddf55fc853eb0558886f (diff)
downloadspark-9376ae723e4ec0515120c488541617a0538f8879.tar.gz
spark-9376ae723e4ec0515120c488541617a0538f8879.tar.bz2
spark-9376ae723e4ec0515120c488541617a0538f8879.zip
[SPARK-6519][ML] Add spark.ml API for bisecting k-means
Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com> Closes #9604 from yu-iskw/SPARK-6519.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala196
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala85
2 files changed, 281 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
new file mode 100644
index 0000000000..0b47cbbac8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/BisectingKMeans.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param.{IntParam, Param, ParamMap, Params}
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.mllib.clustering.
+ {BisectingKMeans => MLlibBisectingKMeans, BisectingKMeansModel => MLlibBisectingKMeansModel}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+
+
+/**
+ * Common params for BisectingKMeans and BisectingKMeansModel.
+ */
+private[clustering] trait BisectingKMeansParams extends Params
+ with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
+
+ /**
+ * The number of clusters to create (k). Must be > 1. Default (set by BisectingKMeans): 4.
+ * @group param
+ */
+ @Since("2.0.0")
+ final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1)
+
+ /** @group getParam */
+ @Since("2.0.0")
+ def getK: Int = $(k)
+
+ /** @group expertParam */
+ @Since("2.0.0")
+ final val minDivisibleClusterSize = new Param[Double](
+ this,
+ "minDivisibleClusterSize",
+ "the minimum number of points (if >= 1.0) or the minimum proportion",
+ (value: Double) => value > 0) // validator: only strictly positive values are accepted
+
+ /** @group expertGetParam */
+ @Since("2.0.0")
+ def getMinDivisibleClusterSize: Double = $(minDivisibleClusterSize)
+
+ /**
+ * Checks that the features column has the vector type and appends the prediction column.
+ * @param schema input schema
+ * @return the input schema extended with the prediction column (IntegerType)
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by BisectingKMeans; predictions and costs are delegated to the wrapped mllib model.
+ *
+ * @param parentModel a model trained by spark.mllib.clustering.BisectingKMeans.
+ */
+@Since("2.0.0")
+@Experimental
+class BisectingKMeansModel private[ml] (
+    @Since("2.0.0") override val uid: String,
+    private val parentModel: MLlibBisectingKMeansModel
+  ) extends Model[BisectingKMeansModel] with BisectingKMeansParams {
+
+  @Since("2.0.0")
+  override def copy(extra: ParamMap): BisectingKMeansModel = {
+    val copied = new BisectingKMeansModel(uid, parentModel)
+    copyValues(copied, extra).setParent(parent) // a copy must keep the link to its estimator
+  }
+
+  @Since("2.0.0")
+  override def transform(dataset: DataFrame): DataFrame = {
+    val predictUDF = udf((vector: Vector) => predict(vector)) // per-row cluster assignment
+    dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+  }
+
+  @Since("2.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+
+  private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+
+  @Since("2.0.0")
+  def clusterCenters: Array[Vector] = parentModel.clusterCenters
+
+  /**
+   * Computes the sum of squared distances between the input points and their corresponding cluster
+   * centers. The features column of `dataset` must have the vector type.
+   */
+  @Since("2.0.0")
+  def computeCost(dataset: DataFrame): Double = {
+    SchemaUtils.checkColumnType(dataset.schema, $(featuresCol), new VectorUDT)
+    val data = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
+    parentModel.computeCost(data)
+  }
+}
+
+/**
+ * :: Experimental ::
+ *
+ * A bisecting k-means algorithm based on the paper "A comparison of document clustering techniques"
+ * by Steinbach, Karypis, and Kumar, with modification to fit Spark.
+ * The algorithm starts from a single cluster that contains all points.
+ * Iteratively it finds divisible clusters on the bottom level and bisects each of them using
+ * k-means, until there are `k` leaf clusters in total or no leaf clusters are divisible.
+ * The bisecting steps of clusters on the same level are grouped together to increase parallelism.
+ * If bisecting all divisible clusters on the bottom level would result in more than `k` leaf
+ * clusters, larger clusters get higher priority.
+ *
+ * @see [[http://glaros.dtc.umn.edu/gkhome/fetch/papers/docclusterKDDTMW00.pdf
+ *     Steinbach, Karypis, and Kumar, A comparison of document clustering techniques,
+ *     KDD Workshop on Text Mining, 2000.]]
+ */
+@Since("2.0.0")
+@Experimental
+class BisectingKMeans @Since("2.0.0") (
+    @Since("2.0.0") override val uid: String)
+  extends Estimator[BisectingKMeansModel] with BisectingKMeansParams {
+
+  setDefault(
+    k -> 4,
+    maxIter -> 20,
+    minDivisibleClusterSize -> 1.0)
+
+  @Since("2.0.0")
+  override def copy(extra: ParamMap): BisectingKMeans = defaultCopy(extra)
+
+  @Since("2.0.0")
+  def this() = this(Identifiable.randomUID("bisecting k-means"))
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setK(value: Int): this.type = set(k, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+  /** @group setParam */
+  @Since("2.0.0")
+  def setSeed(value: Long): this.type = set(seed, value)
+
+  /** @group expertSetParam */
+  @Since("2.0.0")
+  def setMinDivisibleClusterSize(value: Double): this.type = set(minDivisibleClusterSize, value)
+
+  @Since("2.0.0")
+  override def fit(dataset: DataFrame): BisectingKMeansModel = {
+    val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
+    // Training is delegated to the spark.mllib implementation, configured from this estimator.
+    val bkm = new MLlibBisectingKMeans()
+      .setK($(k))
+      .setMaxIterations($(maxIter))
+      .setMinDivisibleClusterSize($(minDivisibleClusterSize))
+      .setSeed($(seed))
+    val parentModel = bkm.run(rdd)
+    val model = new BisectingKMeansModel(uid, parentModel)
+    copyValues(model.setParent(this)) // record this estimator as the fitted model's parent
+  }
+
+  @Since("2.0.0")
+  override def transformSchema(schema: StructType): StructType = {
+    validateAndTransformSchema(schema)
+  }
+}
+
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
new file mode 100644
index 0000000000..b26571eb9f
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.DataFrame
+
+class BisectingKMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  final val k = 5
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
+  }
+
+  test("default parameters") {
+    val estimator = new BisectingKMeans()
+    // A freshly constructed estimator must report its built-in defaults.
+    assert(estimator.getFeaturesCol === "features")
+    assert(estimator.getPredictionCol === "prediction")
+    assert(estimator.getK === 4)
+    assert(estimator.getMaxIter === 20)
+    assert(estimator.getMinDivisibleClusterSize === 1.0)
+  }
+
+  test("setter/getter") {
+    val configured = new BisectingKMeans()
+      .setFeaturesCol("test_feature")
+      .setPredictionCol("test_prediction")
+      .setK(9)
+      .setMaxIter(33)
+      .setSeed(123)
+      .setMinDivisibleClusterSize(2.0)
+    // Every getter must echo the value handed to the matching setter.
+    assert(configured.getFeaturesCol === "test_feature")
+    assert(configured.getPredictionCol === "test_prediction")
+    assert(configured.getK === 9)
+    assert(configured.getMaxIter === 33)
+    assert(configured.getSeed === 123)
+    assert(configured.getMinDivisibleClusterSize === 2.0)
+    // k = 1 must be rejected by the param validator.
+    intercept[IllegalArgumentException] {
+      new BisectingKMeans().setK(1)
+    }
+    // A minDivisibleClusterSize of zero must be rejected as well.
+    intercept[IllegalArgumentException] {
+      new BisectingKMeans().setMinDivisibleClusterSize(0)
+    }
+  }
+
+  test("fit & transform") {
+    val predictionColName = "bisecting_kmeans_prediction"
+    val estimator = new BisectingKMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
+    val model = estimator.fit(dataset)
+    assert(model.clusterCenters.length === k)
+    // The transformed frame must contain both the features and the prediction column.
+    val transformed = model.transform(dataset)
+    val outputColumns = transformed.columns
+    Seq("features", predictionColName).foreach { name =>
+      assert(outputColumns.contains(name))
+    }
+    val labels = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
+    assert(labels.size === k)
+    assert(labels === (0 until k).toSet)
+    assert(model.computeCost(dataset) < 0.1)
+  }
+}