aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala10
1 files changed, 10 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 4eb8fc049e..24e1cff0dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -218,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") (
*/
@Since("1.2.0")
def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+ require(centers.length == weights.length, // every center needs exactly one weight
+ "Number of initial centers must be equal to number of weights")
+ require(centers.length == k, // the supplied centers must match the configured cluster count
+ s"Number of initial centers must be ${k} but got ${centers.length}")
+ require(weights.forall(_ >= 0), // NaN fails this comparison too, so it is rejected as well
+ s"Weight for each initial center must be nonnegative but got [${weights.mkString(" ")}]")
model = new StreamingKMeansModel(centers, weights)
this
}
@@ -231,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") (
*/
@Since("1.2.0")
def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+ require(dim > 0,
+ s"Number of dimensions must be positive but got ${dim}")
+ require(weight >= 0,
+ s"Weight for each center must be nonnegative but got ${weight}")
val random = new XORShiftRandom(seed)
val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
val weights = Array.fill(k)(weight)