diff options
author | Zheng RuiFeng <ruifengz@foxmail.com> | 2016-05-11 10:01:43 +0200 |
---|---|---|
committer | Nick Pentreath <nickp@za.ibm.com> | 2016-05-11 10:01:43 +0200 |
commit | 8beae59144827d81491eed385dc2aa6aedd6a7b4 (patch) | |
tree | 1905c4caa10c9f432262272e120a948772a2846f /examples/src/main/java/org | |
parent | cef73b563864d5f8aa1b26e31e3b9af6f0a08a5d (diff) | |
download | spark-8beae59144827d81491eed385dc2aa6aedd6a7b4.tar.gz spark-8beae59144827d81491eed385dc2aa6aedd6a7b4.tar.bz2 spark-8beae59144827d81491eed385dc2aa6aedd6a7b4.zip |
[SPARK-15149][EXAMPLE][DOC] update kmeans example
## What changes were proposed in this pull request?
Python example for ml.kmeans already exists, but not included in user guide.
1,small changes like: `example_on` `example_off`
2,add it to user guide
3,update examples to directly read datafile
## How was this patch tested?
manual tests
`./bin/spark-submit examples/src/main/python/ml/kmeans_example.py
Author: Zheng RuiFeng <ruifengz@foxmail.com>
Closes #12925 from zhengruifeng/km_pe.
Diffstat (limited to 'examples/src/main/java/org')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java | 60 |
1 files changed, 14 insertions, 46 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java index 65e29ade29..2489a9b80b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java @@ -17,77 +17,45 @@ package org.apache.spark.examples.ml; -import java.util.regex.Pattern; - -import org.apache.spark.api.java.JavaRDD; -import org.apache.spark.api.java.function.Function; -import org.apache.spark.sql.Dataset; -import org.apache.spark.sql.SparkSession; -import org.apache.spark.sql.catalyst.expressions.GenericRow; // $example on$ import org.apache.spark.ml.clustering.KMeansModel; import org.apache.spark.ml.clustering.KMeans; import org.apache.spark.mllib.linalg.Vector; -import org.apache.spark.mllib.linalg.VectorUDT; -import org.apache.spark.mllib.linalg.Vectors; +import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; -import org.apache.spark.sql.types.Metadata; -import org.apache.spark.sql.types.StructField; -import org.apache.spark.sql.types.StructType; // $example off$ +import org.apache.spark.sql.SparkSession; /** - * An example demonstrating a k-means clustering. + * An example demonstrating k-means clustering. * Run with * <pre> - * bin/run-example ml.JavaKMeansExample <file> <k> + * bin/run-example ml.JavaKMeansExample * </pre> */ public class JavaKMeansExample { - private static class ParsePoint implements Function<String, Row> { - private static final Pattern separator = Pattern.compile(" "); - - @Override - public Row call(String line) { - String[] tok = separator.split(line); - double[] point = new double[tok.length]; - for (int i = 0; i < tok.length; ++i) { - point[i] = Double.parseDouble(tok[i]); - } - Vector[] points = {Vectors.dense(point)}; - return new GenericRow(points); - } - } - public static void main(String[] args) { - if (args.length != 2) { - System.err.println("Usage: ml.JavaKMeansExample <file> <k>"); - System.exit(1); - } - String inputFile = args[0]; - int k = Integer.parseInt(args[1]); - - // Parses the arguments + // Create a SparkSession. SparkSession spark = SparkSession .builder() .appName("JavaKMeansExample") .getOrCreate(); // $example on$ - // Loads data - JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParsePoint()); - StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())}; - StructType schema = new StructType(fields); - Dataset<Row> dataset = spark.createDataFrame(points, schema); + // Loads data. + Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt"); - // Trains a k-means model - KMeans kmeans = new KMeans() - .setK(k); + // Trains a k-means model. + KMeans kmeans = new KMeans().setK(2).setSeed(1L); KMeansModel model = kmeans.fit(dataset); - // Shows the result + // Evaluate clustering by computing Within Set Sum of Squared Errors. + double WSSSE = model.computeCost(dataset); + System.out.println("Within Set Sum of Squared Errors = " + WSSSE); + + // Shows the result. Vector[] centers = model.clusterCenters(); System.out.println("Cluster Centers: "); for (Vector center: centers) { |