aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYu ISHIKAWA <yuu.ishikawa@gmail.com>2015-07-17 18:30:04 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-17 18:30:04 -0700
commit34a889db857f8752a0a78dcedec75ac6cd6cd48d (patch)
treed3d059330619ae63f0fc794706ebbfc927049b0b /mllib
parent529a2c2d92fef062e0078a8608fa3a8ae848c139 (diff)
downloadspark-34a889db857f8752a0a78dcedec75ac6cd6cd48d.tar.gz
spark-34a889db857f8752a0a78dcedec75ac6cd6cd48d.tar.bz2
spark-34a889db857f8752a0a78dcedec75ac6cd6cd48d.zip
[SPARK-7879] [MLLIB] KMeans API for spark.ml Pipelines
I Implemented the KMeans API for spark.ml Pipelines. But it doesn't include clustering abstractions for spark.ml (SPARK-7610). It would fit for another issues. And I'll try it later, since we are trying to add the hierarchical clustering algorithms in another issue. Thanks. [SPARK-7879] KMeans API for spark.ml Pipelines - ASF JIRA https://issues.apache.org/jira/browse/SPARK-7879 Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com> Closes #6756 from yu-iskw/SPARK-7879 and squashes the following commits: be752de [Yu ISHIKAWA] Add assertions a14939b [Yu ISHIKAWA] Fix the dashed line's length in pyspark.ml.rst 4c61693 [Yu ISHIKAWA] Remove the test about whether "features" and "prediction" columns exist or not in Python fb2417c [Yu ISHIKAWA] Use getInt, instead of get f397be4 [Yu ISHIKAWA] Switch the comparisons. ca78b7d [Yu ISHIKAWA] Add the Scala docs about the constraints of each parameter. effc650 [Yu ISHIKAWA] Using expertSetParam and expertGetParam c8dc6e6 [Yu ISHIKAWA] Remove an unnecessary test 19a9d63 [Yu ISHIKAWA] Include spark.ml.clustering to python tests 1abb19c [Yu ISHIKAWA] Add the statements about spark.ml.clustering into pyspark.ml.rst f8338bc [Yu ISHIKAWA] Add the placeholders in Python 4a03003 [Yu ISHIKAWA] Test for contains in Python 6566c8b [Yu ISHIKAWA] Use `get`, instead of `apply` 288e8d5 [Yu ISHIKAWA] Using `contains` to check the column names 5a7d574 [Yu ISHIKAWA] Renamce `validateInitializationMode` to `validateInitMode` and remove throwing exception 97cfae3 [Yu ISHIKAWA] Fix the type of return value of `KMeans.copy` e933723 [Yu ISHIKAWA] Remove the default value of seed from the Model class 978ee2c [Yu ISHIKAWA] Modify the docs of KMeans, according to mllib's KMeans 2ec80bc [Yu ISHIKAWA] Fit on 1 line e186be1 [Yu ISHIKAWA] Make a few variables, setters and getters be expert ones b2c205c [Yu ISHIKAWA] Rename the method `getInitializationSteps` to `getInitSteps` and `setInitializationSteps` to `setInitSteps` in Scala and Python f43f5b4 [Yu ISHIKAWA] Rename the method `getInitializationMode` to `getInitMode` and `setInitializationMode` to `setInitMode` in Scala and Python 3cb5ba4 [Yu ISHIKAWA] Modify the description about epsilon and the validation 4fa409b [Yu ISHIKAWA] Add a comment about the default value of epsilon 2f392e1 [Yu ISHIKAWA] Make some variables `final` and Use `IntParam` and `DoubleParam` 19326f8 [Yu ISHIKAWA] Use `udf`, instead of callUDF 4d2ad1e [Yu ISHIKAWA] Modify the indentations 0ae422f [Yu ISHIKAWA] Add a test for `setParams` 4ff7913 [Yu ISHIKAWA] Add "ml.clustering" to `javacOptions` in SparkBuild.scala 11ffdf1 [Yu ISHIKAWA] Use `===` and the variable 220a176 [Yu ISHIKAWA] Set a random seed in the unit testing 92c3efc [Yu ISHIKAWA] Make the points for a test be fewer c758692 [Yu ISHIKAWA] Modify the parameters of KMeans in Python 6aca147 [Yu ISHIKAWA] Add some unit testings to validate the setter methods 687cacc [Yu ISHIKAWA] Alias mllib.KMeans as MLlibKMeans in KMeansSuite.scala a4dfbef [Yu ISHIKAWA] Modify the last brace and indentations 5bedc51 [Yu ISHIKAWA] Remve an extra new line 444c289 [Yu ISHIKAWA] Add the validation for `runs` e41989c [Yu ISHIKAWA] Modify how to validate `initStep` 7ea133a [Yu ISHIKAWA] Change how to validate `initMode` 7991e15 [Yu ISHIKAWA] Add a validation for `k` c2df35d [Yu ISHIKAWA] Make `predict` private 93aa2ff [Yu ISHIKAWA] Use `withColumn` in `transform` d3a79f7 [Yu ISHIKAWA] Remove the inhefited docs e9532e1 [Yu ISHIKAWA] make `parentModel` of KMeansModel private 8559772 [Yu ISHIKAWA] Remove the `paramMap` parameter of KMeans 6684850 [Yu ISHIKAWA] Rename `initializationSteps` to `initSteps` 99b1b96 [Yu ISHIKAWA] Rename `initializationMode` to `initMode` 79ea82b [Yu ISHIKAWA] Modify the parameters of KMeans docs 6569bcd [Yu ISHIKAWA] Change how to set the default values with `setDefault` 20a795a [Yu ISHIKAWA] Change how to set the default values with `setDefault` 11c2a12 [Yu ISHIKAWA] Limit the imports badb481 [Yu ISHIKAWA] Alias spark.mllib.{KMeans, KMeansModel} f80319a [Yu ISHIKAWA] Rebase mater branch and add copy methods 85d92b1 [Yu ISHIKAWA] Add `KMeans.setPredictionCol` aa9469d [Yu ISHIKAWA] Fix a python test suite error caused by python 3.x c2d6bcb [Yu ISHIKAWA] ADD Java test suites of the KMeans API for spark.ml Pipeline 598ed2e [Yu ISHIKAWA] Implement the KMeans API for spark.ml Pipelines in Python 63ad785 [Yu ISHIKAWA] Implement the KMeans API for spark.ml Pipelines in Scala
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala205
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala12
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java72
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala114
4 files changed, 400 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
new file mode 100644
index 0000000000..dc192add6c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/KMeans.scala
@@ -0,0 +1,205 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.param.{Param, Params, IntParam, DoubleParam, ParamMap}
+import org.apache.spark.ml.param.shared.{HasFeaturesCol, HasMaxIter, HasPredictionCol, HasSeed}
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans, KMeansModel => MLlibKMeansModel}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.functions.{col, udf}
+import org.apache.spark.sql.types.{IntegerType, StructType}
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.util.Utils
+
+
+/**
+ * Common params for KMeans and KMeansModel
+ */
+private[clustering] trait KMeansParams
+ extends Params with HasMaxIter with HasFeaturesCol with HasSeed with HasPredictionCol {
+
+ /**
+ * Set the number of clusters to create (k). Must be > 1. Default: 2.
+ * @group param
+ */
+ final val k = new IntParam(this, "k", "number of clusters to create", (x: Int) => x > 1)
+
+ /** @group getParam */
+ def getK: Int = $(k)
+
+ /**
+ * Param the number of runs of the algorithm to execute in parallel. We initialize the algorithm
+ * this many times with random starting conditions (configured by the initialization mode), then
+ * return the best clustering found over any run. Must be >= 1. Default: 1.
+ * @group param
+ */
+ final val runs = new IntParam(this, "runs",
+ "number of runs of the algorithm to execute in parallel", (value: Int) => value >= 1)
+
+ /** @group getParam */
+ def getRuns: Int = $(runs)
+
+ /**
+ * Param the distance threshold within which we've consider centers to have converged.
+ * If all centers move less than this Euclidean distance, we stop iterating one run.
+ * Must be >= 0.0. Default: 1e-4
+ * @group param
+ */
+ final val epsilon = new DoubleParam(this, "epsilon",
+ "distance threshold within which we've consider centers to have converge",
+ (value: Double) => value >= 0.0)
+
+ /** @group getParam */
+ def getEpsilon: Double = $(epsilon)
+
+ /**
+ * Param for the initialization algorithm. This can be either "random" to choose random points as
+ * initial cluster centers, or "k-means||" to use a parallel variant of k-means++
+ * (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
+ * @group expertParam
+ */
+ final val initMode = new Param[String](this, "initMode", "initialization algorithm",
+ (value: String) => MLlibKMeans.validateInitMode(value))
+
+ /** @group expertGetParam */
+ def getInitMode: String = $(initMode)
+
+ /**
+ * Param for the number of steps for the k-means|| initialization mode. This is an advanced
+ * setting -- the default of 5 is almost always enough. Must be > 0. Default: 5.
+ * @group expertParam
+ */
+ final val initSteps = new IntParam(this, "initSteps", "number of steps for k-means||",
+ (value: Int) => value > 0)
+
+ /** @group expertGetParam */
+ def getInitSteps: Int = $(initSteps)
+
+ /**
+ * Validates and transforms the input schema.
+ * @param schema input schema
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(featuresCol), new VectorUDT)
+ SchemaUtils.appendColumn(schema, $(predictionCol), IntegerType)
+ }
+}
+
+/**
+ * :: Experimental ::
+ * Model fitted by KMeans.
+ *
+ * @param parentModel a model trained by spark.mllib.clustering.KMeans.
+ */
+@Experimental
+class KMeansModel private[ml] (
+ override val uid: String,
+ private val parentModel: MLlibKMeansModel) extends Model[KMeansModel] with KMeansParams {
+
+ override def copy(extra: ParamMap): KMeansModel = {
+ val copied = new KMeansModel(uid, parentModel)
+ copyValues(copied, extra)
+ }
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ val predictUDF = udf((vector: Vector) => predict(vector))
+ dataset.withColumn($(predictionCol), predictUDF(col($(featuresCol))))
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ private[clustering] def predict(features: Vector): Int = parentModel.predict(features)
+
+ def clusterCenters: Array[Vector] = parentModel.clusterCenters
+}
+
+/**
+ * :: Experimental ::
+ * K-means clustering with support for multiple parallel runs and a k-means++ like initialization
+ * mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
+ * they are executed together with joint passes over the data for efficiency.
+ */
+@Experimental
+class KMeans(override val uid: String) extends Estimator[KMeansModel] with KMeansParams {
+
+ setDefault(
+ k -> 2,
+ maxIter -> 20,
+ runs -> 1,
+ initMode -> MLlibKMeans.K_MEANS_PARALLEL,
+ initSteps -> 5,
+ epsilon -> 1e-4)
+
+ override def copy(extra: ParamMap): KMeans = defaultCopy(extra)
+
+ def this() = this(Identifiable.randomUID("kmeans"))
+
+ /** @group setParam */
+ def setFeaturesCol(value: String): this.type = set(featuresCol, value)
+
+ /** @group setParam */
+ def setPredictionCol(value: String): this.type = set(predictionCol, value)
+
+ /** @group setParam */
+ def setK(value: Int): this.type = set(k, value)
+
+ /** @group expertSetParam */
+ def setInitMode(value: String): this.type = set(initMode, value)
+
+ /** @group expertSetParam */
+ def setInitSteps(value: Int): this.type = set(initSteps, value)
+
+ /** @group setParam */
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ def setRuns(value: Int): this.type = set(runs, value)
+
+ /** @group setParam */
+ def setEpsilon(value: Double): this.type = set(epsilon, value)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ override def fit(dataset: DataFrame): KMeansModel = {
+ val rdd = dataset.select(col($(featuresCol))).map { case Row(point: Vector) => point }
+
+ val algo = new MLlibKMeans()
+ .setK($(k))
+ .setInitializationMode($(initMode))
+ .setInitializationSteps($(initSteps))
+ .setMaxIterations($(maxIter))
+ .setSeed($(seed))
+ .setEpsilon($(epsilon))
+ .setRuns($(runs))
+ val parentModel = algo.run(rdd)
+ val model = new KMeansModel(uid, parentModel)
+ copyValues(model)
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+}
+
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index 68297130a7..0a65403f4e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -85,9 +85,7 @@ class KMeans private (
* (Bahmani et al., Scalable K-Means++, VLDB 2012). Default: k-means||.
*/
def setInitializationMode(initializationMode: String): this.type = {
- if (initializationMode != KMeans.RANDOM && initializationMode != KMeans.K_MEANS_PARALLEL) {
- throw new IllegalArgumentException("Invalid initialization mode: " + initializationMode)
- }
+ KMeans.validateInitMode(initializationMode)
this.initializationMode = initializationMode
this
}
@@ -550,6 +548,14 @@ object KMeans {
v2: VectorWithNorm): Double = {
MLUtils.fastSquaredDistance(v1.vector, v1.norm, v2.vector, v2.norm)
}
+
+ private[spark] def validateInitMode(initMode: String): Boolean = {
+ initMode match {
+ case KMeans.RANDOM => true
+ case KMeans.K_MEANS_PARALLEL => true
+ case _ => false
+ }
+ }
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
new file mode 100644
index 0000000000..d09fa7fd56
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/clustering/JavaKMeansSuite.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+public class JavaKMeansSuite implements Serializable {
+
+ private transient int k = 5;
+ private transient JavaSparkContext sc;
+ private transient DataFrame dataset;
+ private transient SQLContext sql;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaKMeansSuite");
+ sql = new SQLContext(sc);
+
+ dataset = KMeansSuite.generateKMeansData(sql, 50, 3, k);
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void fitAndTransform() {
+ KMeans kmeans = new KMeans().setK(k).setSeed(1);
+ KMeansModel model = kmeans.fit(dataset);
+
+ Vector[] centers = model.clusterCenters();
+ assertEquals(k, centers.length);
+
+ DataFrame transformed = model.transform(dataset);
+ List<String> columns = Arrays.asList(transformed.columns());
+ List<String> expectedColumns = Arrays.asList("features", "prediction");
+ for (String column: expectedColumns) {
+ assertTrue(columns.contains(column));
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
new file mode 100644
index 0000000000..1f15ac02f4
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.clustering.{KMeans => MLlibKMeans}
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+private[clustering] case class TestRow(features: Vector)
+
+object KMeansSuite {
+ def generateKMeansData(sql: SQLContext, rows: Int, dim: Int, k: Int): DataFrame = {
+ val sc = sql.sparkContext
+ val rdd = sc.parallelize(1 to rows).map(i => Vectors.dense(Array.fill(dim)((i % k).toDouble)))
+ .map(v => new TestRow(v))
+ sql.createDataFrame(rdd)
+ }
+}
+
+class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ final val k = 5
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+
+ dataset = KMeansSuite.generateKMeansData(sqlContext, 50, 3, k)
+ }
+
+ test("default parameters") {
+ val kmeans = new KMeans()
+
+ assert(kmeans.getK === 2)
+ assert(kmeans.getFeaturesCol === "features")
+ assert(kmeans.getPredictionCol === "prediction")
+ assert(kmeans.getMaxIter === 20)
+ assert(kmeans.getRuns === 1)
+ assert(kmeans.getInitMode === MLlibKMeans.K_MEANS_PARALLEL)
+ assert(kmeans.getInitSteps === 5)
+ assert(kmeans.getEpsilon === 1e-4)
+ }
+
+ test("set parameters") {
+ val kmeans = new KMeans()
+ .setK(9)
+ .setFeaturesCol("test_feature")
+ .setPredictionCol("test_prediction")
+ .setMaxIter(33)
+ .setRuns(7)
+ .setInitMode(MLlibKMeans.RANDOM)
+ .setInitSteps(3)
+ .setSeed(123)
+ .setEpsilon(1e-3)
+
+ assert(kmeans.getK === 9)
+ assert(kmeans.getFeaturesCol === "test_feature")
+ assert(kmeans.getPredictionCol === "test_prediction")
+ assert(kmeans.getMaxIter === 33)
+ assert(kmeans.getRuns === 7)
+ assert(kmeans.getInitMode === MLlibKMeans.RANDOM)
+ assert(kmeans.getInitSteps === 3)
+ assert(kmeans.getSeed === 123)
+ assert(kmeans.getEpsilon === 1e-3)
+ }
+
+ test("parameters validation") {
+ intercept[IllegalArgumentException] {
+ new KMeans().setK(1)
+ }
+ intercept[IllegalArgumentException] {
+ new KMeans().setInitMode("no_such_a_mode")
+ }
+ intercept[IllegalArgumentException] {
+ new KMeans().setInitSteps(0)
+ }
+ intercept[IllegalArgumentException] {
+ new KMeans().setRuns(0)
+ }
+ }
+
+ test("fit & transform") {
+ val predictionColName = "kmeans_prediction"
+ val kmeans = new KMeans().setK(k).setPredictionCol(predictionColName).setSeed(1)
+ val model = kmeans.fit(dataset)
+ assert(model.clusterCenters.length === k)
+
+ val transformed = model.transform(dataset)
+ val expectedColumns = Array("features", predictionColName)
+ expectedColumns.foreach { column =>
+ assert(transformed.columns.contains(column))
+ }
+ val clusters = transformed.select(predictionColName).map(_.getInt(0)).distinct().collect().toSet
+ assert(clusters.size === k)
+ assert(clusters === Set(0, 1, 2, 3, 4))
+ }
+}