author      WeichenXu <WeichenXu123@outlook.com>    2016-09-15 09:30:15 +0100
committer   Sean Owen <sowen@cloudera.com>          2016-09-15 09:30:15 +0100
commit      d15b4f90e64f7ec5cf14c7c57d2cb4234c3ce677 (patch)
tree        041c284a99f4388d830db10da409cace2fa844a6 /mllib
parent      6a6adb1673775df63a62270879eac70f5f8d7d75 (diff)
[SPARK-17507][ML][MLLIB] check weight vector size in ANN
## What changes were proposed in this pull request?

As the TODO described, check the weight vector size and throw an exception if it does not match.

## How was this patch tested?

Existing tests.

Author: WeichenXu <WeichenXu123@outlook.com>

Closes #15060 from WeichenXu123/check_input_weight_size_of_ann.
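The change follows the usual `require`-based validation pattern: sum each layer's `weightSize` to get the expected total, then fail fast with a descriptive message if the supplied flat weight vector has a different length. Below is a minimal, self-contained sketch of that pattern; `LayerSpec` and `validateWeights` are hypothetical stand-ins for the private `org.apache.spark.ml.ann` classes and are not part of the actual Spark API.

```scala
// Minimal sketch of the size check added in this commit, under the assumption
// that each layer exposes the number of weights it expects (weightSize).
object WeightSizeCheckSketch {

  // Hypothetical per-layer description, standing in for ann.Layer.
  final case class LayerSpec(weightSize: Int)

  // Throw IllegalArgumentException if the flat weight vector does not match
  // the total size implied by the layer specs.
  def validateWeights(layers: Seq[LayerSpec], weights: Array[Double]): Unit = {
    val expectedWeightSize = layers.map(_.weightSize).sum
    require(weights.length == expectedWeightSize,
      s"Expected weight vector of size $expectedWeightSize but got size ${weights.length}.")
  }

  def main(args: Array[String]): Unit = {
    val layers = Seq(LayerSpec(6), LayerSpec(4))    // 10 weights expected in total
    validateWeights(layers, Array.fill(10)(0.0))    // passes
    // validateWeights(layers, Array.fill(8)(0.0))  // would throw IllegalArgumentException
  }
}
```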
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala | 10
1 file changed, 4 insertions, 6 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index 88909a9fb9..e7e0dae0b5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -545,7 +545,9 @@ private[ann] object FeedForwardModel {
* @return model
*/
def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = {
- // TODO: check that weights size is equal to sum of layers sizes
+ val expectedWeightSize = topology.layers.map(_.weightSize).sum
+ require(weights.size == expectedWeightSize,
+ s"Expected weight vector of size ${expectedWeightSize} but got size ${weights.size}.")
new FeedForwardModel(weights, topology)
}
@@ -559,11 +561,7 @@ private[ann] object FeedForwardModel {
def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = {
val layers = topology.layers
val layerModels = new Array[LayerModel](layers.length)
- var totalSize = 0
- for (i <- 0 until topology.layers.length) {
- totalSize += topology.layers(i).weightSize
- }
- val weights = BDV.zeros[Double](totalSize)
+ val weights = BDV.zeros[Double](topology.layers.map(_.weightSize).sum)
var offset = 0
val random = new XORShiftRandom(seed)
for (i <- 0 until layers.length) {