aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--docs/ml-clustering.md71
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java8
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala49
3 files changed, 100 insertions, 28 deletions
diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md
index a59f7e3005..440c455cd0 100644
--- a/docs/ml-clustering.md
+++ b/docs/ml-clustering.md
@@ -11,6 +11,77 @@ In this section, we introduce the pipeline API for [clustering in mllib](mllib-c
* This will become a table of contents (this text will be scraped).
{:toc}
+## K-means
+
+[k-means](http://en.wikipedia.org/wiki/K-means_clustering) is one of the
+most commonly used clustering algorithms that clusters the data points into a
+predefined number of clusters. The MLlib implementation includes a parallelized
+variant of the [k-means++](http://en.wikipedia.org/wiki/K-means%2B%2B) method
+called [kmeans||](http://theory.stanford.edu/~sergei/papers/vldb12-kmpar.pdf).
+
+`KMeans` is implemented as an `Estimator` and generates a `KMeansModel` as the base model.
+
+### Input Columns
+
+<table class="table">
+ <thead>
+ <tr>
+ <th align="left">Param name</th>
+ <th align="left">Type(s)</th>
+ <th align="left">Default</th>
+ <th align="left">Description</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td>featuresCol</td>
+ <td>Vector</td>
+ <td>"features"</td>
+ <td>Feature vector</td>
+ </tr>
+ </tbody>
+</table>
+
+### Output Columns
+
+<table class="table">
+ <thead>
+ <tr>
+ <th align="left">Param name</th>
+ <th align="left">Type(s)</th>
+ <th align="left">Default</th>
+ <th align="left">Description</th>
+ </tr>
+ </thead>
+ <tbody>
+ <tr>
+ <td>predictionCol</td>
+ <td>Int</td>
+ <td>"prediction"</td>
+ <td>Predicted cluster center</td>
+ </tr>
+ </tbody>
+</table>
+
+### Example
+
+<div class="codetabs">
+
+<div data-lang="scala" markdown="1">
+Refer to the [Scala API docs](api/scala/index.html#org.apache.spark.ml.clustering.KMeans) for more details.
+
+{% include_example scala/org/apache/spark/examples/ml/KMeansExample.scala %}
+</div>
+
+<div data-lang="java" markdown="1">
+Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html) for more details.
+
+{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %}
+</div>
+
+</div>
+
+
## Latent Dirichlet allocation (LDA)
`LDA` is implemented as an `Estimator` that supports both `EMLDAOptimizer` and `OnlineLDAOptimizer`,
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
index 47665ff2b1..96481d882a 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
@@ -23,6 +23,9 @@ import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+// $example on$
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.mllib.linalg.Vector;
@@ -30,11 +33,10 @@ import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
+// $example off$
/**
@@ -74,6 +76,7 @@ public class JavaKMeansExample {
JavaSparkContext jsc = new JavaSparkContext(conf);
SQLContext sqlContext = new SQLContext(jsc);
+ // $example on$
// Loads data
JavaRDD<Row> points = jsc.textFile(inputFile).map(new ParsePoint());
StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
@@ -91,6 +94,7 @@ public class JavaKMeansExample {
for (Vector center: centers) {
System.out.println(center);
}
+ // $example off$
jsc.stop();
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
index 5ce38462d1..af90652b55 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
@@ -17,57 +17,54 @@
package org.apache.spark.examples.ml
-import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
-import org.apache.spark.ml.clustering.KMeans
-import org.apache.spark.sql.{Row, SQLContext}
-import org.apache.spark.sql.types.{StructField, StructType}
+// scalastyle:off println
+import org.apache.spark.{SparkConf, SparkContext}
+// $example on$
+import org.apache.spark.ml.clustering.KMeans
+import org.apache.spark.mllib.linalg.Vectors
+// $example off$
+import org.apache.spark.sql.{DataFrame, SQLContext}
/**
* An example demonstrating a k-means clustering.
* Run with
* {{{
- * bin/run-example ml.KMeansExample <file> <k>
+ * bin/run-example ml.KMeansExample
* }}}
*/
object KMeansExample {
- final val FEATURES_COL = "features"
-
def main(args: Array[String]): Unit = {
- if (args.length != 2) {
- // scalastyle:off println
- System.err.println("Usage: ml.KMeansExample <file> <k>")
- // scalastyle:on println
- System.exit(1)
- }
- val input = args(0)
- val k = args(1).toInt
-
// Creates a Spark context and a SQL context
val conf = new SparkConf().setAppName(s"${this.getClass.getSimpleName}")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
- // Loads data
- val rowRDD = sc.textFile(input).filter(_.nonEmpty)
- .map(_.split(" ").map(_.toDouble)).map(Vectors.dense).map(Row(_))
- val schema = StructType(Array(StructField(FEATURES_COL, new VectorUDT, false)))
- val dataset = sqlContext.createDataFrame(rowRDD, schema)
+ // $example on$
+ // Crates a DataFrame
+ val dataset: DataFrame = sqlContext.createDataFrame(Seq(
+ (1, Vectors.dense(0.0, 0.0, 0.0)),
+ (2, Vectors.dense(0.1, 0.1, 0.1)),
+ (3, Vectors.dense(0.2, 0.2, 0.2)),
+ (4, Vectors.dense(9.0, 9.0, 9.0)),
+ (5, Vectors.dense(9.1, 9.1, 9.1)),
+ (6, Vectors.dense(9.2, 9.2, 9.2))
+ )).toDF("id", "features")
// Trains a k-means model
val kmeans = new KMeans()
- .setK(k)
- .setFeaturesCol(FEATURES_COL)
+ .setK(2)
+ .setFeaturesCol("features")
+ .setPredictionCol("prediction")
val model = kmeans.fit(dataset)
// Shows the result
- // scalastyle:off println
println("Final Centers: ")
model.clusterCenters.foreach(println)
- // scalastyle:on println
+ // $example off$
sc.stop()
}
}
+// scalastyle:on println