-rw-r--r--  docs/ml-clustering.md                                                      |  5
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java | 60
-rw-r--r--  examples/src/main/python/ml/kmeans_example.py                              | 46
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala   | 33
4 files changed, 50 insertions(+), 94 deletions(-)
diff --git a/docs/ml-clustering.md b/docs/ml-clustering.md
index 1245b8bbc8..876a280c4c 100644
--- a/docs/ml-clustering.md
+++ b/docs/ml-clustering.md
@@ -79,6 +79,11 @@ Refer to the [Java API docs](api/java/org/apache/spark/ml/clustering/KMeans.html
{% include_example java/org/apache/spark/examples/ml/JavaKMeansExample.java %}
</div>
+<div data-lang="python" markdown="1">
+Refer to the [Python API docs](api/python/pyspark.ml.html#pyspark.ml.clustering.KMeans) for more details.
+
+{% include_example python/ml/kmeans_example.py %}
+</div>
</div>
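
Note: the `{% include_example %}` directive added above pulls only the code between the `$example on$` and `$example off$` markers of the referenced file into the rendered docs page, which is why those markers move around in the example hunks below. A minimal sketch of the convention (these marker names are the ones used throughout Spark's examples):

// $example on$
// ...only this region is extracted into docs/ml-clustering.md...
// $example off$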
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
index 65e29ade29..2489a9b80b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
@@ -17,77 +17,45 @@
package org.apache.spark.examples.ml;
-import java.util.regex.Pattern;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
// $example on$
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
// $example off$
+import org.apache.spark.sql.SparkSession;
/**
- * An example demonstrating a k-means clustering.
+ * An example demonstrating k-means clustering.
* Run with
* <pre>
- * bin/run-example ml.JavaKMeansExample <file> <k>
+ * bin/run-example ml.JavaKMeansExample
* </pre>
*/
public class JavaKMeansExample {
- private static class ParsePoint implements Function<String, Row> {
- private static final Pattern separator = Pattern.compile(" ");
-
- @Override
- public Row call(String line) {
- String[] tok = separator.split(line);
- double[] point = new double[tok.length];
- for (int i = 0; i < tok.length; ++i) {
- point[i] = Double.parseDouble(tok[i]);
- }
- Vector[] points = {Vectors.dense(point)};
- return new GenericRow(points);
- }
- }
-
public static void main(String[] args) {
- if (args.length != 2) {
- System.err.println("Usage: ml.JavaKMeansExample <file> <k>");
- System.exit(1);
- }
- String inputFile = args[0];
- int k = Integer.parseInt(args[1]);
-
- // Parses the arguments
+ // Create a SparkSession.
SparkSession spark = SparkSession
.builder()
.appName("JavaKMeansExample")
.getOrCreate();
// $example on$
- // Loads data
- JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParsePoint());
- StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
- StructType schema = new StructType(fields);
- Dataset<Row> dataset = spark.createDataFrame(points, schema);
+ // Loads data.
+ Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");
- // Trains a k-means model
- KMeans kmeans = new KMeans()
- .setK(k);
+ // Trains a k-means model.
+ KMeans kmeans = new KMeans().setK(2).setSeed(1L);
KMeansModel model = kmeans.fit(dataset);
- // Shows the result
+ // Evaluate clustering by computing Within Set Sum of Squared Errors.
+ double WSSSE = model.computeCost(dataset);
+ System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
+
+ // Shows the result.
Vector[] centers = model.clusterCenters();
System.out.println("Cluster Centers: ");
for (Vector center: centers) {
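
Note: all three rewritten examples now load `data/mllib/sample_kmeans_data.txt` via the `libsvm` data source instead of parsing a user-supplied whitespace-separated file. A libsvm-format file stores one labeled, sparse feature vector per line as `label index:value index:value ...` with 1-based indices, so the sample file is expected to look roughly like the following (an illustrative excerpt, not quoted from this commit):

0 1:0.0 2:0.0 3:0.0
1 1:0.1 2:0.1 3:0.1
3 1:9.0 2:9.0 3:9.0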
diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py
index 7382396955..4b8b7291f9 100644
--- a/examples/src/main/python/ml/kmeans_example.py
+++ b/examples/src/main/python/ml/kmeans_example.py
@@ -17,55 +17,45 @@
from __future__ import print_function
-import sys
+# $example on$
+from pyspark.ml.clustering import KMeans
+# $example off$
-import numpy as np
-from pyspark.ml.clustering import KMeans, KMeansModel
-from pyspark.mllib.linalg import VectorUDT, _convert_to_vector
from pyspark.sql import SparkSession
-from pyspark.sql.types import Row, StructField, StructType
"""
-A simple example demonstrating a k-means clustering.
+An example demonstrating k-means clustering.
Run with:
- bin/spark-submit examples/src/main/python/ml/kmeans_example.py <input> <k>
+ bin/spark-submit examples/src/main/python/ml/kmeans_example.py
This example requires NumPy (http://www.numpy.org/).
"""
-def parseVector(row):
- array = np.array([float(x) for x in row.value.split(' ')])
- return _convert_to_vector(array)
-
-
if __name__ == "__main__":
- FEATURES_COL = "features"
-
- if len(sys.argv) != 3:
- print("Usage: kmeans_example.py <file> <k>", file=sys.stderr)
- exit(-1)
- path = sys.argv[1]
- k = sys.argv[2]
-
spark = SparkSession\
.builder\
.appName("PythonKMeansExample")\
.getOrCreate()
- lines = spark.read.text(path).rdd
- data = lines.map(parseVector)
- row_rdd = data.map(lambda x: Row(x))
- schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)])
- df = spark.createDataFrame(row_rdd, schema)
+ # $example on$
+ # Loads data.
+ dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
- kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL)
- model = kmeans.fit(df)
- centers = model.clusterCenters()
+ # Trains a k-means model.
+ kmeans = KMeans().setK(2).setSeed(1)
+ model = kmeans.fit(dataset)
+
+ # Evaluate clustering by computing Within Set Sum of Squared Errors.
+ wssse = model.computeCost(dataset)
+ print("Within Set Sum of Squared Errors = " + str(wssse))
+ # Shows the result.
+ centers = model.clusterCenters()
print("Cluster Centers: ")
for center in centers:
print(center)
+ # $example off$
spark.stop()
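
Note: the rewritten examples stop at printing cluster centers; to actually assign points to clusters, the fitted model's `transform` method appends a `prediction` column, which is standard `ml` Pipeline behavior. A minimal sketch in Scala (the object name is ours, not part of this commit; data path and model settings match the examples above):

import org.apache.spark.ml.clustering.KMeans
import org.apache.spark.sql.SparkSession

object KMeansPredictSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder.appName("KMeansPredictSketch").getOrCreate()

    // Same data and model settings as the examples in this commit.
    val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
    val model = new KMeans().setK(2).setSeed(1L).fit(dataset)

    // transform() adds a "prediction" column holding each row's cluster id.
    model.transform(dataset).select("features", "prediction").show(false)

    spark.stop()
  }
}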
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
index 2abd588c6f..2341b36db2 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/KMeansExample.scala
@@ -21,12 +21,11 @@ package org.apache.spark.examples.ml
// $example on$
import org.apache.spark.ml.clustering.KMeans
-import org.apache.spark.mllib.linalg.Vectors
-import org.apache.spark.sql.{DataFrame, SparkSession}
// $example off$
+import org.apache.spark.sql.SparkSession
/**
- * An example demonstrating a k-means clustering.
+ * An example demonstrating k-means clustering.
* Run with
* {{{
* bin/run-example ml.KMeansExample
@@ -35,32 +34,26 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
object KMeansExample {
def main(args: Array[String]): Unit = {
- // Creates a Spark context and a SQL context
+ // Creates a SparkSession.
val spark = SparkSession
.builder
.appName(s"${this.getClass.getSimpleName}")
.getOrCreate()
// $example on$
- // Crates a DataFrame
- val dataset: DataFrame = spark.createDataFrame(Seq(
- (1, Vectors.dense(0.0, 0.0, 0.0)),
- (2, Vectors.dense(0.1, 0.1, 0.1)),
- (3, Vectors.dense(0.2, 0.2, 0.2)),
- (4, Vectors.dense(9.0, 9.0, 9.0)),
- (5, Vectors.dense(9.1, 9.1, 9.1)),
- (6, Vectors.dense(9.2, 9.2, 9.2))
- )).toDF("id", "features")
+ // Loads data.
+ val dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
- // Trains a k-means model
- val kmeans = new KMeans()
- .setK(2)
- .setFeaturesCol("features")
- .setPredictionCol("prediction")
+ // Trains a k-means model.
+ val kmeans = new KMeans().setK(2).setSeed(1L)
val model = kmeans.fit(dataset)
- // Shows the result
- println("Final Centers: ")
+ // Evaluate clustering by computing Within Set Sum of Squared Errors.
+ val WSSSE = model.computeCost(dataset)
+ println(s"Within Set Sum of Squared Errors = $WSSSE")
+
+ // Shows the result.
+ println("Cluster Centers: ")
model.clusterCenters.foreach(println)
// $example off$
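
Note: `computeCost`, called in all three examples, returns the Within Set Sum of Squared Errors, i.e. the total squared Euclidean distance from each point to its nearest cluster center (the standard k-means objective):

\mathrm{WSSSE} = \sum_{i=1}^{n} \min_{c \in \mathrm{centers}} \lVert x_i - c \rVert^2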