aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/clustering.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/clustering.py')
-rw-r--r--python/pyspark/mllib/clustering.py207
1 files changed, 203 insertions, 4 deletions
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index b55583f822..c38229864d 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -21,16 +21,20 @@ import array as pyarray
if sys.version > '3':
xrange = range
-from numpy import array
+from math import exp, log
+
+from numpy import array, random, tile
-from pyspark import RDD
from pyspark import SparkContext
+from pyspark.rdd import RDD, ignore_unicode_prefix
from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _py2java, _java2py
-from pyspark.mllib.linalg import SparseVector, _convert_to_vector
+from pyspark.mllib.linalg import SparseVector, _convert_to_vector, DenseVector
from pyspark.mllib.stat.distribution import MultivariateGaussian
from pyspark.mllib.util import Saveable, Loader, inherit_doc
+from pyspark.streaming import DStream
-__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture']
+__all__ = ['KMeansModel', 'KMeans', 'GaussianMixtureModel', 'GaussianMixture',
+ 'StreamingKMeans', 'StreamingKMeansModel']
@inherit_doc
@@ -98,6 +102,9 @@ class KMeansModel(Saveable, Loader):
"""Find the cluster to which x belongs in this model."""
best = 0
best_distance = float("inf")
+ if isinstance(x, RDD):
+ return x.map(self.predict)
+
x = _convert_to_vector(x)
for i in xrange(len(self.centers)):
distance = x.squared_distance(self.centers[i])
@@ -264,6 +271,198 @@ class GaussianMixture(object):
return GaussianMixtureModel(weight, mvg_obj)
+class StreamingKMeansModel(KMeansModel):
+ """
+ .. note:: Experimental
+
+ Clustering model which can perform an online update of the centroids.
+
+ The update formula for each centroid is given by
+
+ * c_t+1 = ((c_t * n_t * a) + (x_t * m_t)) / (n_t + m_t)
+ * n_t+1 = n_t * a + m_t
+
+ where
+
+ * c_t: Centroid at the n_th iteration.
+ * n_t: Number of samples (or) weights associated with the centroid
+ at the n_th iteration.
+ * x_t: Centroid of the new data closest to c_t.
+ * m_t: Number of samples (or) weights of the new data closest to c_t
+ * c_t+1: New centroid.
+ * n_t+1: New number of weights.
+ * a: Decay Factor, which gives the forgetfulness.
+
+ Note that if a is set to 1, it is the weighted mean of the previous
+ and new data. If it set to zero, the old centroids are completely
+ forgotten.
+
+ :param clusterCenters: Initial cluster centers.
+ :param clusterWeights: List of weights assigned to each cluster.
+
+ >>> initCenters = [[0.0, 0.0], [1.0, 1.0]]
+ >>> initWeights = [1.0, 1.0]
+ >>> stkm = StreamingKMeansModel(initCenters, initWeights)
+ >>> data = sc.parallelize([[-0.1, -0.1], [0.1, 0.1],
+ ... [0.9, 0.9], [1.1, 1.1]])
+ >>> stkm = stkm.update(data, 1.0, u"batches")
+ >>> stkm.centers
+ array([[ 0., 0.],
+ [ 1., 1.]])
+ >>> stkm.predict([-0.1, -0.1])
+ 0
+ >>> stkm.predict([0.9, 0.9])
+ 1
+ >>> stkm.clusterWeights
+ [3.0, 3.0]
+ >>> decayFactor = 0.0
+ >>> data = sc.parallelize([DenseVector([1.5, 1.5]), DenseVector([0.2, 0.2])])
+ >>> stkm = stkm.update(data, 0.0, u"batches")
+ >>> stkm.centers
+ array([[ 0.2, 0.2],
+ [ 1.5, 1.5]])
+ >>> stkm.clusterWeights
+ [1.0, 1.0]
+ >>> stkm.predict([0.2, 0.2])
+ 0
+ >>> stkm.predict([1.5, 1.5])
+ 1
+ """
+ def __init__(self, clusterCenters, clusterWeights):
+ super(StreamingKMeansModel, self).__init__(centers=clusterCenters)
+ self._clusterWeights = list(clusterWeights)
+
+ @property
+ def clusterWeights(self):
+ """Return the cluster weights."""
+ return self._clusterWeights
+
+ @ignore_unicode_prefix
+ def update(self, data, decayFactor, timeUnit):
+ """Update the centroids, according to data
+
+ :param data: Should be a RDD that represents the new data.
+ :param decayFactor: forgetfulness of the previous centroids.
+ :param timeUnit: Can be "batches" or "points". If points, then the
+ decay factor is raised to the power of number of new
+ points and if batches, it is used as it is.
+ """
+ if not isinstance(data, RDD):
+ raise TypeError("Data should be of an RDD, got %s." % type(data))
+ data = data.map(_convert_to_vector)
+ decayFactor = float(decayFactor)
+ if timeUnit not in ["batches", "points"]:
+ raise ValueError(
+ "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
+ vectorCenters = [_convert_to_vector(center) for center in self.centers]
+ updatedModel = callMLlibFunc(
+ "updateStreamingKMeansModel", vectorCenters, self._clusterWeights,
+ data, decayFactor, timeUnit)
+ self.centers = array(updatedModel[0])
+ self._clusterWeights = list(updatedModel[1])
+ return self
+
+
+class StreamingKMeans(object):
+ """
+ .. note:: Experimental
+
+ Provides methods to set k, decayFactor, timeUnit to configure the
+ KMeans algorithm for fitting and predicting on incoming dstreams.
+ More details on how the centroids are updated are provided under the
+ docs of StreamingKMeansModel.
+
+ :param k: int, number of clusters
+ :param decayFactor: float, forgetfulness of the previous centroids.
+ :param timeUnit: can be "batches" or "points". If points, then the
+ decayfactor is raised to the power of no. of new points.
+ """
+ def __init__(self, k=2, decayFactor=1.0, timeUnit="batches"):
+ self._k = k
+ self._decayFactor = decayFactor
+ if timeUnit not in ["batches", "points"]:
+ raise ValueError(
+ "timeUnit should be 'batches' or 'points', got %s." % timeUnit)
+ self._timeUnit = timeUnit
+ self._model = None
+
+ def latestModel(self):
+ """Return the latest model"""
+ return self._model
+
+ def _validate(self, dstream):
+ if self._model is None:
+ raise ValueError(
+ "Initial centers should be set either by setInitialCenters "
+ "or setRandomCenters.")
+ if not isinstance(dstream, DStream):
+ raise TypeError(
+ "Expected dstream to be of type DStream, "
+ "got type %s" % type(dstream))
+
+ def setK(self, k):
+ """Set number of clusters."""
+ self._k = k
+ return self
+
+ def setDecayFactor(self, decayFactor):
+ """Set decay factor."""
+ self._decayFactor = decayFactor
+ return self
+
+ def setHalfLife(self, halfLife, timeUnit):
+ """
+ Set number of batches after which the centroids of that
+ particular batch has half the weightage.
+ """
+ self._timeUnit = timeUnit
+ self._decayFactor = exp(log(0.5) / halfLife)
+ return self
+
+ def setInitialCenters(self, centers, weights):
+ """
+ Set initial centers. Should be set before calling trainOn.
+ """
+ self._model = StreamingKMeansModel(centers, weights)
+ return self
+
+ def setRandomCenters(self, dim, weight, seed):
+ """
+ Set the initial centres to be random samples from
+ a gaussian population with constant weights.
+ """
+ rng = random.RandomState(seed)
+ clusterCenters = rng.randn(self._k, dim)
+ clusterWeights = tile(weight, self._k)
+ self._model = StreamingKMeansModel(clusterCenters, clusterWeights)
+ return self
+
+ def trainOn(self, dstream):
+ """Train the model on the incoming dstream."""
+ self._validate(dstream)
+
+ def update(rdd):
+ self._model.update(rdd, self._decayFactor, self._timeUnit)
+
+ dstream.foreachRDD(update)
+
+ def predictOn(self, dstream):
+ """
+ Make predictions on a dstream.
+ Returns a transformed dstream object
+ """
+ self._validate(dstream)
+ return dstream.map(lambda x: self._model.predict(x))
+
+ def predictOnValues(self, dstream):
+ """
+ Make predictions on a keyed dstream.
+ Returns a transformed dstream object.
+ """
+ self._validate(dstream)
+ return dstream.mapValues(lambda x: self._model.predict(x))
+
+
def _test():
import doctest
globs = globals().copy()