aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala17
2 files changed, 18 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index f6de5f2df4..7ce3ec68da 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -43,8 +43,7 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
"Sizes of layers from input layer to output layer" +
" E.g., Array(780, 100, 10) means 780 inputs, " +
"one hidden layer with 100 neurons and output layer of 10 neurons.",
- // TODO: how to check ALSO that all elements are greater than 0?
- ParamValidators.arrayLengthGt(1)
+ (t: Array[Int]) => t.forall(ParamValidators.gt(0)) && t.length > 1
)
/** @group getParam */
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 5df8e6a847..53c7a559e3 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -43,6 +43,23 @@ class MultilayerPerceptronClassifierSuite
).toDF("features", "label")
}
+ test("Input Validation") {
+ val mlpc = new MultilayerPerceptronClassifier()
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int]())
+ }
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int](1))
+ }
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int](0, 1))
+ }
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int](1, 0))
+ }
+ mlpc.setLayers(Array[Int](1, 1))
+ }
+
test("XOR function learning as binary classification problem with two outputs.") {
val layers = Array[Int](2, 5, 2)
val trainer = new MultilayerPerceptronClassifier()