aboutsummaryrefslogtreecommitdiff
path: root/examples/src/main/java
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main/java')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java49
1 files changed, 18 insertions, 31 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java
index 810ad905c5..62871448e3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaBisectingKMeansExample.java
@@ -17,27 +17,22 @@
package org.apache.spark.examples.ml;
-import java.util.Arrays;
-import java.util.List;
-
-import org.apache.spark.sql.RowFactory;
-import org.apache.spark.sql.SparkSession;
// $example on$
import org.apache.spark.ml.clustering.BisectingKMeans;
import org.apache.spark.ml.clustering.BisectingKMeansModel;
import org.apache.spark.mllib.linalg.Vector;
-import org.apache.spark.mllib.linalg.VectorUDT;
-import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
-import org.apache.spark.sql.types.Metadata;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
// $example off$
+import org.apache.spark.sql.SparkSession;
/**
- * An example demonstrating a bisecting k-means clustering.
+ * An example demonstrating bisecting k-means clustering.
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaBisectingKMeansExample
+ * </pre>
*/
public class JavaBisectingKMeansExample {
@@ -48,30 +43,22 @@ public class JavaBisectingKMeansExample {
.getOrCreate();
// $example on$
- List<Row> data = Arrays.asList(
- RowFactory.create(Vectors.dense(0.1, 0.1, 0.1)),
- RowFactory.create(Vectors.dense(0.3, 0.3, 0.25)),
- RowFactory.create(Vectors.dense(0.1, 0.1, -0.1)),
- RowFactory.create(Vectors.dense(20.3, 20.1, 19.9)),
- RowFactory.create(Vectors.dense(20.2, 20.1, 19.7)),
- RowFactory.create(Vectors.dense(18.9, 20.0, 19.7))
- );
-
- StructType schema = new StructType(new StructField[]{
- new StructField("features", new VectorUDT(), false, Metadata.empty()),
- });
-
- Dataset<Row> dataset = spark.createDataFrame(data, schema);
+ // Loads data.
+ Dataset<Row> dataset = spark.read().format("libsvm").load("data/mllib/sample_kmeans_data.txt");
- BisectingKMeans bkm = new BisectingKMeans().setK(2);
+ // Trains a bisecting k-means model.
+ BisectingKMeans bkm = new BisectingKMeans().setK(2).setSeed(1);
BisectingKMeansModel model = bkm.fit(dataset);
- System.out.println("Compute Cost: " + model.computeCost(dataset));
+ // Evaluate clustering.
+ double cost = model.computeCost(dataset);
+ System.out.println("Within Set Sum of Squared Errors = " + cost);
- Vector[] clusterCenters = model.clusterCenters();
- for (int i = 0; i < clusterCenters.length; i++) {
- Vector clusterCenter = clusterCenters[i];
- System.out.println("Cluster Center " + i + ": " + clusterCenter);
+ // Shows the result.
+ System.out.println("Cluster Centers: ");
+ Vector[] centers = model.clusterCenters();
+ for (Vector center : centers) {
+ System.out.println(center);
}
// $example off$