-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala   4
-rw-r--r--  python/pyspark/mllib/clustering.py                                            17
2 files changed, 18 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 69ce7f5070..21e55938fa 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -336,7 +336,8 @@ private[python] class PythonMLLibAPI extends Serializable {
       initializationMode: String,
       seed: java.lang.Long,
       initializationSteps: Int,
-      epsilon: Double): KMeansModel = {
+      epsilon: Double,
+      initialModel: java.util.ArrayList[Vector]): KMeansModel = {
     val kMeansAlg = new KMeans()
       .setK(k)
       .setMaxIterations(maxIterations)
@@ -346,6 +347,7 @@ private[python] class PythonMLLibAPI extends Serializable {
       .setEpsilon(epsilon)
     if (seed != null) kMeansAlg.setSeed(seed)
+    if (!initialModel.isEmpty()) kMeansAlg.setInitialModel(new KMeansModel(initialModel))
     try {
       kMeansAlg.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK))
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index 900ade248c..6964a45db2 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -90,6 +90,12 @@ class KMeansModel(Saveable, Loader):
     ...     rmtree(path)
     ... except OSError:
     ...     pass
+
+    >>> data = array([-383.1,-382.9, 28.7,31.2, 366.2,367.3]).reshape(3, 2)
+    >>> model = KMeans.train(sc.parallelize(data), 3, maxIterations=0,
+    ...     initialModel = KMeansModel([(-1000.0,-1000.0),(5.0,5.0),(1000.0,1000.0)]))
+    >>> model.clusterCenters
+    [array([-1000., -1000.]), array([ 5.,  5.]), array([ 1000.,  1000.])]
     """
 
     def __init__(self, centers):
@@ -144,10 +150,17 @@ class KMeans(object):
     @classmethod
     def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||",
-              seed=None, initializationSteps=5, epsilon=1e-4):
+              seed=None, initializationSteps=5, epsilon=1e-4, initialModel=None):
         """Train a k-means clustering model."""
+        clusterInitialModel = []
+        if initialModel is not None:
+            if not isinstance(initialModel, KMeansModel):
+                raise Exception("initialModel is of "+str(type(initialModel))+". It needs "
+                                "to be of <type 'KMeansModel'>")
+            clusterInitialModel = [_convert_to_vector(c) for c in initialModel.clusterCenters]
         model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
-                              runs, initializationMode, seed, initializationSteps, epsilon)
+                              runs, initializationMode, seed, initializationSteps, epsilon,
+                              clusterInitialModel)
         centers = callJavaFunc(rdd.context, model.clusterCenters)
         return KMeansModel([c.toArray() for c in centers])
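
The following is a minimal, self-contained PySpark sketch (not part of the commit) of how the new initialModel argument is meant to be used, mirroring the doctest added above. The SparkContext setup and variable names are illustrative assumptions; only KMeans.train(..., initialModel=...) and the isinstance check on the passed model come from the patch.

# Sketch only: assumes a Spark build that includes this patch.
from numpy import array

from pyspark import SparkContext
from pyspark.mllib.clustering import KMeans, KMeansModel

sc = SparkContext(appName="KMeansInitialModelSketch")

# Three well-separated 2-D points, as in the new doctest.
data = sc.parallelize(array([-383.1, -382.9, 28.7, 31.2, 366.2, 367.3]).reshape(3, 2))

# Seed training with an existing model; with maxIterations=0 the returned
# centers are exactly the supplied initial centers, as the doctest asserts.
initial = KMeansModel([(-1000.0, -1000.0), (5.0, 5.0), (1000.0, 1000.0)])
model = KMeans.train(data, 3, maxIterations=0, initialModel=initial)
print(model.clusterCenters)

# Anything other than a KMeansModel is rejected by the isinstance check
# added in train().
try:
    KMeans.train(data, 3, initialModel=[(0.0, 0.0)])
except Exception as e:
    print(e)

sc.stop()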