author    Zheng RuiFeng <ruifengz@foxmail.com>    2016-05-11 10:01:43 +0200
committer Nick Pentreath <nickp@za.ibm.com>       2016-05-11 10:01:43 +0200
commit    8beae59144827d81491eed385dc2aa6aedd6a7b4 (patch)
tree      1905c4caa10c9f432262272e120a948772a2846f /examples/src/main/java/org
parent    cef73b563864d5f8aa1b26e31e3b9af6f0a08a5d (diff)
[SPARK-15149][EXAMPLE][DOC] update kmeans example
## What changes were proposed in this pull request?

A Python example for ml.KMeans already exists, but it is not included in the user guide.

1. Small changes, such as adding the `example_on` / `example_off` markers.
2. Add the example to the user guide.
3. Update the examples to read the data file directly.

## How was this patch tested?

Manual tests:
`./bin/spark-submit examples/src/main/python/ml/kmeans_example.py`

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #12925 from zhengruifeng/km_pe.
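For reference, here is a minimal sketch of what `JavaKMeansExample.java` reads like after this patch, reconstructed from the diff below. The print statement inside the loop, the closing braces, and the `spark.stop()` call are not visible in the diff hunk and are assumptions.

```java
package org.apache.spark.examples.ml;

// $example on$
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
// $example off$
import org.apache.spark.sql.SparkSession;

/**
 * An example demonstrating k-means clustering.
 * Run with: bin/run-example ml.JavaKMeansExample
 */
public class JavaKMeansExample {

  public static void main(String[] args) {
    // Create a SparkSession.
    SparkSession spark = SparkSession
      .builder()
      .appName("JavaKMeansExample")
      .getOrCreate();

    // $example on$
    // Loads data directly from the bundled libsvm data file, replacing the old
    // manual text parsing and schema construction.
    Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");

    // Trains a k-means model with a fixed k and seed instead of command-line arguments.
    KMeans kmeans = new KMeans().setK(2).setSeed(1L);
    KMeansModel model = kmeans.fit(dataset);

    // Evaluate clustering by computing Within Set Sum of Squared Errors.
    double WSSSE = model.computeCost(dataset);
    System.out.println("Within Set Sum of Squared Errors = " + WSSSE);

    // Shows the result.
    Vector[] centers = model.clusterCenters();
    System.out.println("Cluster Centers: ");
    for (Vector center : centers) {
      System.out.println(center);  // assumed loop body; not shown in the diff hunk
    }
    // $example off$

    spark.stop();  // assumed; not shown in the diff hunk
  }
}
```

As noted in the updated javadoc, the example now takes no arguments and can be run with `bin/run-example ml.JavaKMeansExample`; the corresponding Python example was tested with the `spark-submit` command above.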
Diffstat (limited to 'examples/src/main/java/org')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java | 60
1 file changed, 14 insertions(+), 46 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
index 65e29ade29..2489a9b80b 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaKMeansExample.java
@@ -17,77 +17,45 @@
package org.apache.spark.examples.ml;
-import java.util.regex.Pattern;
-
-import org.apache.spark.api.java.JavaRDD;
-import org.apache.spark.api.java.function.Function;
-import org.apache.spark.sql.Dataset;
-import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.catalyst.expressions.GenericRow;
// $example on$
import org.apache.spark.ml.clustering.KMeansModel;
import org.apache.spark.ml.clustering.KMeans;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
// $example off$
+import org.apache.spark.sql.SparkSession;
/**
- * An example demonstrating a k-means clustering.
+ * An example demonstrating k-means clustering.
* Run with
* <pre>
- * bin/run-example ml.JavaKMeansExample <file> <k>
+ * bin/run-example ml.JavaKMeansExample
* </pre>
*/
public class JavaKMeansExample {
- private static class ParsePoint implements Function<String, Row> {
- private static final Pattern separator = Pattern.compile(" ");
-
- @Override
- public Row call(String line) {
- String[] tok = separator.split(line);
- double[] point = new double[tok.length];
- for (int i = 0; i < tok.length; ++i) {
- point[i] = Double.parseDouble(tok[i]);
- }
- Vector[] points = {Vectors.dense(point)};
- return new GenericRow(points);
- }
- }
-
public static void main(String[] args) {
- if (args.length != 2) {
- System.err.println("Usage: ml.JavaKMeansExample <file> <k>");
- System.exit(1);
- }
- String inputFile = args[0];
- int k = Integer.parseInt(args[1]);
-
- // Parses the arguments
+ // Create a SparkSession.
SparkSession spark = SparkSession
.builder()
.appName("JavaKMeansExample")
.getOrCreate();
// $example on$
- // Loads data
- JavaRDD<Row> points = spark.read().text(inputFile).javaRDD().map(new ParsePoint());
- StructField[] fields = {new StructField("features", new VectorUDT(), false, Metadata.empty())};
- StructType schema = new StructType(fields);
- Dataset<Row> dataset = spark.createDataFrame(points, schema);
+ // Loads data.
+ Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");
- // Trains a k-means model
- KMeans kmeans = new KMeans()
- .setK(k);
+ // Trains a k-means model.
+ KMeans kmeans = new KMeans().setK(2).setSeed(1L);
KMeansModel model = kmeans.fit(dataset);
- // Shows the result
+ // Evaluate clustering by computing Within Set Sum of Squared Errors.
+ double WSSSE = model.computeCost(dataset);
+ System.out.println("Within Set Sum of Squared Errors = " + WSSSE);
+
+ // Shows the result.
Vector[] centers = model.clusterCenters();
System.out.println("Cluster Centers: ");
for (Vector center: centers) {