diff options
author | Dongjoon Hyun <dongjoon@apache.org> | 2016-03-31 09:39:15 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-31 09:39:15 -0700 |
commit | 208fff3ac87f200fd4e6f0407d70bf81cf8c556f (patch) | |
tree | e67c4e7e77cdab6bac2311b34b0edd36aed66378 /mllib/src/test/scala/org/apache | |
parent | a9b93e07391faede77dde4c0b3c21c9b3f97f8eb (diff) | |
download | spark-208fff3ac87f200fd4e6f0407d70bf81cf8c556f.tar.gz spark-208fff3ac87f200fd4e6f0407d70bf81cf8c556f.tar.bz2 spark-208fff3ac87f200fd4e6f0407d70bf81cf8c556f.zip |
[SPARK-14164][MLLIB] Improve input layer validation of MultilayerPerceptronClassifier
## What changes were proposed in this pull request?
This issue improves an input layer validation and adds related testcases to MultilayerPerceptronClassifier.
```scala
- // TODO: how to check ALSO that all elements are greater than 0?
- ParamValidators.arrayLengthGt(1)
+ (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1
```
## How was this patch tested?
Pass the Jenkins tests including the new testcases.
Author: Dongjoon Hyun <dongjoon@apache.org>
Closes #11964 from dongjoon-hyun/SPARK-14164.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala | 17 |
1 files changed, 17 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 5df8e6a847..53c7a559e3 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -43,6 +43,23 @@ class MultilayerPerceptronClassifierSuite ).toDF("features", "label") } + test("Input Validation") { + val mlpc = new MultilayerPerceptronClassifier() + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int]()) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](0, 1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1, 0)) + } + mlpc.setLayers(Array[Int](1, 1)) + } + test("XOR function learning as binary classification problem with two outputs.") { val layers = Array[Int](2, 5, 2) val trainer = new MultilayerPerceptronClassifier() |