Diffstat (limited to 'examples/src/main/python/ml/kmeans_example.py')
-rw-r--r--  examples/src/main/python/ml/kmeans_example.py  46
1 file changed, 18 insertions(+), 28 deletions(-)
diff --git a/examples/src/main/python/ml/kmeans_example.py b/examples/src/main/python/ml/kmeans_example.py
index 7382396955..4b8b7291f9 100644
--- a/examples/src/main/python/ml/kmeans_example.py
+++ b/examples/src/main/python/ml/kmeans_example.py
@@ -17,55 +17,45 @@
from __future__ import print_function
-import sys
+# $example on$
+from pyspark.ml.clustering import KMeans
+# $example off$
-import numpy as np
-from pyspark.ml.clustering import KMeans, KMeansModel
-from pyspark.mllib.linalg import VectorUDT, _convert_to_vector
from pyspark.sql import SparkSession
-from pyspark.sql.types import Row, StructField, StructType
"""
-A simple example demonstrating a k-means clustering.
+An example demonstrating k-means clustering.
Run with:
-  bin/spark-submit examples/src/main/python/ml/kmeans_example.py <input> <k>
+  bin/spark-submit examples/src/main/python/ml/kmeans_example.py
This example requires NumPy (http://www.numpy.org/).
"""
-def parseVector(row):
-    array = np.array([float(x) for x in row.value.split(' ')])
-    return _convert_to_vector(array)
-
-
if __name__ == "__main__":
-    FEATURES_COL = "features"
-
-    if len(sys.argv) != 3:
-        print("Usage: kmeans_example.py <file> <k>", file=sys.stderr)
-        exit(-1)
-    path = sys.argv[1]
-    k = sys.argv[2]
-
    spark = SparkSession\
        .builder\
        .appName("PythonKMeansExample")\
        .getOrCreate()
-    lines = spark.read.text(path).rdd
-    data = lines.map(parseVector)
-    row_rdd = data.map(lambda x: Row(x))
-    schema = StructType([StructField(FEATURES_COL, VectorUDT(), False)])
-    df = spark.createDataFrame(row_rdd, schema)
+    # $example on$
+    # Loads data.
+    dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")
-    kmeans = KMeans().setK(2).setSeed(1).setFeaturesCol(FEATURES_COL)
-    model = kmeans.fit(df)
-    centers = model.clusterCenters()
+    # Trains a k-means model.
+    kmeans = KMeans().setK(2).setSeed(1)
+    model = kmeans.fit(dataset)
+
+    # Evaluate clustering by computing Within Set Sum of Squared Errors.
+    wssse = model.computeCost(dataset)
+    print("Within Set Sum of Squared Errors = " + str(wssse))
+    # Shows the result.
+    centers = model.clusterCenters()
    print("Cluster Centers: ")
    for center in centers:
        print(center)
+    # $example off$
    spark.stop()
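
For reference, the example as it reads after this patch, reassembled from the context and `+` lines above. This is a sketch rather than a byte-exact copy: the rendered diff drops leading whitespace, so indentation is restored by hand, and the Apache license header at the top of the file is elided.

from __future__ import print_function

# $example on$
from pyspark.ml.clustering import KMeans
# $example off$

from pyspark.sql import SparkSession

"""
An example demonstrating k-means clustering.
Run with:
  bin/spark-submit examples/src/main/python/ml/kmeans_example.py

This example requires NumPy (http://www.numpy.org/).
"""

if __name__ == "__main__":
    spark = SparkSession\
        .builder\
        .appName("PythonKMeansExample")\
        .getOrCreate()

    # $example on$
    # Loads data.
    dataset = spark.read.format("libsvm").load("data/mllib/sample_kmeans_data.txt")

    # Trains a k-means model.
    kmeans = KMeans().setK(2).setSeed(1)
    model = kmeans.fit(dataset)

    # Evaluate clustering by computing Within Set Sum of Squared Errors.
    wssse = model.computeCost(dataset)
    print("Within Set Sum of Squared Errors = " + str(wssse))

    # Shows the result.
    centers = model.clusterCenters()
    print("Cluster Centers: ")
    for center in centers:
        print(center)
    # $example off$

    spark.stop()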
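
The rewritten example loads its input through Spark's `libsvm` data source instead of parsing whitespace-separated text by hand. That format holds one point per line as `label index1:value1 index2:value2 ...` with 1-based, ascending indices, and the reader produces a DataFrame with `label` and `features` (vector) columns. As an illustration, the bundled `data/mllib/sample_kmeans_data.txt` contains six three-dimensional points in two tight groups, roughly as follows (reproduced from memory, so treat the exact values as an assumption):

0 1:0.0 2:0.0 3:0.0
1 1:0.1 2:0.1 3:0.1
2 1:0.2 2:0.2 3:0.2
3 1:9.0 2:9.0 3:9.0
4 1:9.1 2:9.1 3:9.1
5 1:9.2 2:9.2 3:9.2

With k=2 on data shaped like this, the fitted centers land near [0.1, 0.1, 0.1] and [9.1, 9.1, 9.1], and the printed WSSSE is small since every point sits close to its center.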
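
One caveat on `computeCost`: it returns the within-set sum of squared errors used above, but later Spark releases deprecate it (from 2.4) in favor of `pyspark.ml.evaluation.ClusteringEvaluator`. A minimal sketch of that alternative, assuming Spark 2.3+ and the `model` and `dataset` names from this example:

from pyspark.ml.evaluation import ClusteringEvaluator

# Assign each point to a cluster; transform() adds a "prediction" column.
predictions = model.transform(dataset)

# ClusteringEvaluator's default metric is the silhouette score
# (squared Euclidean distance); values near 1.0 indicate tight clusters.
evaluator = ClusteringEvaluator()
silhouette = evaluator.evaluate(predictions)
print("Silhouette = " + str(silhouette))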