diff options
3 files changed, 100 insertions, 28 deletions
diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md index a59f7e3005..440c455cd0 100644 --- a/docs/ml-clustering.md +++ b/docs/ml-clustering.md @@ -11,6 +11,77 @@ In this section, we introduce the pipeline API for [clustering in mllib](mllib-c * This will become a table of contents (this text will be scraped). {:toc} +## K-means + +[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the +most commonly used clustering algorithms that clusters the data points into a +predefined number of clusters. The MLlib implementation includes a parallelized +variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method +called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf). + +`KMeans` is implemented as an `Estimator` and generates a `KMeansModel` as the base model. + +### Input Columns + +<table class="table"> + <thead> + <tr> + <th align="left">Param name</th> + <th align="left">Type(s)</th> + <th align="left">Default</th> + <th align="left">Description</th> + </tr> + </thead> + <tbody> + <tr> + <td>featuresCol</td> + <td>Vector</td> + <td>"features"</td> + <td>Feature vector</td> + </tr> + </tbody> +</table> + +### Output Columns + +<table class="table"> + <thead> + <tr> + <th align="left">Param name</th> + <th align="left">Type(s)</th> + <th align="left">Default</th> + <th align="left">Description</th> + </tr> + </thead> + <tbody> + <tr> + <td>predictionCol</td> + <td>Int</td> + <td>"prediction"</td> + <td>Predicted cluster center</td> + </tr> + </tbody> +</table> + +### Example + +<div class="codetabs"> + +<div data-lang="scala" markdown="1"> +Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.KMeans) for more details. + +{% include_example scala/org/apache/spark/examples/ml/KMeansExample.scala %} +</div> + +<div data-lang="java" markdown="1"> +Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html) for more details. + +{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %} +</div> + +</div> + + ## Latent Dirichlet allocation (LDA) `LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`, diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index 47665ff2b1..96481d882a 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -23,6 +23,9 @@ import org.apache.spark.SparkConf; import org.apache.spark.api.java.JavaRDD; import org.apache.spark.api.java.JavaSparkContext; import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.SQLContext; +import org.apache.spark.sql.catalyst.expressions.GenericRow; +// $example on$ import org.apache.spark.ml.clustering.KMeansModel; import org.apache.spark.ml.clustering.KMeans; import org.apache.spark.mllib.linalg.Vector; @@ -30,11 +33,10 @@ import org.apache.spark.mllib.linalg.VectorUDT; import org.apache.spark.mllib.linalg.Vectors; import org.apache.spark.sql.DataFrame; import org.apache.spark.sql.Row; -import org.apache.spark.sql.SQLContext; -import org.apache.spark.sql.catalyst.expressions.GenericRow; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +// $example off$ /** @@ -74,6 +76,7 @@ public class JavaKMeansExample { JavaSparkContext jsc = new JavaSparkContext(conf); SQLContext sqlContext = new SQLContext(jsc); + // $example on$ // Loads data JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParsePoint()); StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; @@ -91,6 +94,7 @@ public class JavaKMeansExample { for (Vector center: centers) { System.out.println(center); } + // $example off$ jsc.stop(); } diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala index 5ce38462d1..af90652b55 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala @@ -17,57 +17,54 @@ package org.apache.spark.examples.ml -import org.apache.spark.{SparkContext, SparkConf} -import org.apache.spark.mllib.linalg.{VectorUDT, Vectors} -import org.apache.spark.ml.clustering.KMeans -import org.apache.spark.sql.{Row, SQLContext} -import org.apache.spark.sql.types.{StructField, StructType} +// scalastyle:off println +import org.apache.spark.{SparkConf, SparkContext} +// $example on$ +import org.apache.spark.ml.clustering.KMeans +import org.apache.spark.mllib.linalg.Vectors +// $example off$ +import org.apache.spark.sql.{DataFrame, SQLContext} /** * An example demonstrating a k-means clustering. * Run with * {{{ - * bin/run-example ml.KMeansExample <file> <k> + * bin/run-example ml.KMeansExample * }}} */ object KMeansExample { - final val FEATURES_COL = "features" - def main(args: Array[String]): Unit = { - if (args.length != 2) { - // scalastyle:off println - System.err.println("Usage: ml.KMeansExample <file> <k>") - // scalastyle:on println - System.exit(1) - } - val input = args(0) - val k = args(1).toInt - // Creates a Spark context and a SQL context val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}") val sc = new SparkContext(conf) val sqlContext = new SQLContext(sc) - // Loads data - val rowRDD = sc.textFile(input).filter(_.nonEmpty) - .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_)) - val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false))) - val dataset = sqlContext.createDataFrame(rowRDD, schema) + // $example on$ + // Crates a DataFrame + val dataset: DataFrame = sqlContext.createDataFrame(Seq( + (1, Vectors.dense(0.0, 0.0, 0.0)), + (2, Vectors.dense(0.1, 0.1, 0.1)), + (3, Vectors.dense(0.2, 0.2, 0.2)), + (4, Vectors.dense(9.0, 9.0, 9.0)), + (5, Vectors.dense(9.1, 9.1, 9.1)), + (6, Vectors.dense(9.2, 9.2, 9.2)) + )).toDF("id", "features") // Trains a k-means model val kmeans = new KMeans() - .setK(k) - .setFeaturesCol(FEATURES_COL) + .setK(2) + .setFeaturesCol("features") + .setPredictionCol("prediction") val model = kmeans.fit(dataset) // Shows the result - // scalastyle:off println println("Final Centers: ") model.clusterCenters.foreach(println) - // scalastyle:on println + // $example off$ sc.stop() } } +// scalastyle:on println |