path: root/mllib
author    Zheng RuiFeng <ruifengz@foxmail.com>    2016-04-11 09:33:52 -0700
committer Xiangrui Meng <meng@databricks.com>    2016-04-11 09:33:52 -0700
commit    643b4e2257c56338b192f8554e2fe5523bea4bdf (patch)
tree      1dc64ce331757cebb1f0e1eabd5d11892533a0d6 /mllib
parent    1c751fcf488189e5176546fe0d00f560ffcf1cec (diff)
[SPARK-14510][MLLIB] Add args-checking for LDA and StreamingKMeans
## What changes were proposed in this pull request?

Add argument checking for LDA and StreamingKMeans.

## How was this patch tested?

Manual tests.

Author: Zheng RuiFeng <ruifengz@foxmail.com>

Closes #12062 from zhengruifeng/initmodel.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala             | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala | 10
2 files changed, 17 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index 12813fd412..d999b9be8e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -130,7 +130,8 @@ class LDA private (
*/
@Since("1.5.0")
def setDocConcentration(docConcentration: Vector): this.type = {
- require(docConcentration.size > 0, "docConcentration must have > 0 elements")
+ require(docConcentration.size == 1 || docConcentration.size == k,
+ s"Size of docConcentration must be 1 or ${k} but got ${docConcentration.size}")
this.docConcentration = docConcentration
this
}
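With this check, an asymmetric prior of the wrong length now fails fast at the setter instead of deep inside training. A minimal sketch of the behavior (the k value and variable names are illustrative, not part of the patch):

```scala
import org.apache.spark.mllib.clustering.LDA
import org.apache.spark.mllib.linalg.Vectors

val lda = new LDA().setK(5)
lda.setDocConcentration(Vectors.dense(1.1))                 // OK: size 1, applied to all topics
lda.setDocConcentration(Vectors.dense(Array.fill(5)(1.1)))  // OK: one value per topic (size == k)
lda.setDocConcentration(Vectors.dense(1.1, 1.1, 1.1))       // throws IllegalArgumentException: size 3 is neither 1 nor k = 5
```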
@@ -260,15 +261,18 @@ class LDA private (
def getCheckpointInterval: Int = checkpointInterval
/**
- * Period (in iterations) between checkpoints (default = 10). Checkpointing helps with recovery
+ * Parameter for setting the checkpoint interval (>= 1) or disabling checkpointing (-1). E.g., 10 means that
+ * the cache will get checkpointed every 10 iterations. Checkpointing helps with recovery
* (when nodes fail). It also helps with eliminating temporary shuffle files on disk, which can be
* important when LDA is run for many iterations. If the checkpoint directory is not set in
- * [[org.apache.spark.SparkContext]], this setting is ignored.
+ * [[org.apache.spark.SparkContext]], this setting is ignored. (default = 10)
*
* @see [[org.apache.spark.SparkContext#setCheckpointDir]]
*/
@Since("1.3.0")
def setCheckpointInterval(checkpointInterval: Int): this.type = {
+ require(checkpointInterval == -1 || checkpointInterval > 0,
+ s"Period between checkpoints must be -1 or positive but got ${checkpointInterval}")
this.checkpointInterval = checkpointInterval
this
}
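The setter now rejects nonsensical intervals such as 0 up front. A hedged usage sketch (values chosen only for illustration):

```scala
import org.apache.spark.mllib.clustering.LDA

val lda = new LDA()
lda.setCheckpointInterval(10)  // OK: checkpoint the cache every 10 iterations
lda.setCheckpointInterval(-1)  // OK: disable checkpointing
lda.setCheckpointInterval(0)   // throws IllegalArgumentException: 0 is neither -1 nor positive
```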
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
index 4eb8fc049e..24e1cff0dc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/StreamingKMeans.scala
@@ -218,6 +218,12 @@ class StreamingKMeans @Since("1.2.0") (
*/
@Since("1.2.0")
def setInitialCenters(centers: Array[Vector], weights: Array[Double]): this.type = {
+ require(centers.size == weights.size,
+ "Number of initial centers must be equal to number of weights")
+ require(centers.size == k,
+ s"Number of initial centers must be ${k} but got ${centers.size}")
+ require(weights.forall(_ >= 0),
+ s"Weight for each initial center must be nonnegative but got [${weights.mkString(" ")}]")
model = new StreamingKMeansModel(centers, weights)
this
}
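Together these checks catch mismatched or invalid initial state before a model is constructed. A small sketch of how they surface (k, the center coordinates, and the weights are made up for the example):

```scala
import org.apache.spark.mllib.clustering.StreamingKMeans
import org.apache.spark.mllib.linalg.Vectors

val skm = new StreamingKMeans().setK(2)
val centers = Array(Vectors.dense(0.0, 0.0), Vectors.dense(1.0, 1.0))
skm.setInitialCenters(centers, Array(1.0, 1.0))                   // OK: 2 centers, 2 nonnegative weights
skm.setInitialCenters(Array(Vectors.dense(0.0, 0.0)), Array(1.0)) // throws: 1 center supplied but k = 2
skm.setInitialCenters(centers, Array(1.0, -1.0))                  // throws: negative weight
```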
@@ -231,6 +237,10 @@ class StreamingKMeans @Since("1.2.0") (
*/
@Since("1.2.0")
def setRandomCenters(dim: Int, weight: Double, seed: Long = Utils.random.nextLong): this.type = {
+ require(dim > 0,
+ s"Number of dimensions must be positive but got ${dim}")
+ require(weight >= 0,
+ s"Weight for each center must be nonnegative but got ${weight}")
val random = new XORShiftRandom(seed)
val centers = Array.fill(k)(Vectors.dense(Array.fill(dim)(random.nextGaussian())))
val weights = Array.fill(k)(weight)
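The random-initialization path gets the same treatment for its scalar arguments. An illustrative sketch (the k, dim, and weight values are assumptions for the example, not from the patch):

```scala
import org.apache.spark.mllib.clustering.StreamingKMeans

val skm = new StreamingKMeans().setK(3)
skm.setRandomCenters(dim = 4, weight = 0.0)  // OK: 3 random Gaussian centers in 4 dimensions
skm.setRandomCenters(dim = 0, weight = 0.0)  // throws IllegalArgumentException: dim must be positive
skm.setRandomCenters(dim = 4, weight = -1.0) // throws IllegalArgumentException: negative weight
```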