aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala59
1 files changed, 55 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 5df8e6a847..80547fad6a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -19,18 +19,19 @@ package org.apache.spark.ml.classification
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
class MultilayerPerceptronClassifierSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -43,12 +44,29 @@ class MultilayerPerceptronClassifierSuite
).toDF("features", "label")
}
+ test("Input Validation") {
+ val mlpc = new MultilayerPerceptronClassifier()
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int]())
+ }
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int](1))
+ }
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int](0, 1))
+ }
+ intercept[IllegalArgumentException] {
+ mlpc.setLayers(Array[Int](1, 0))
+ }
+ mlpc.setLayers(Array[Int](1, 1))
+ }
+
test("XOR function learning as binary classification problem with two outputs.") {
val layers = Array[Int](2, 5, 2)
val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(1)
- .setSeed(11L)
+ .setSeed(123L)
.setMaxIter(100)
val model = trainer.fit(dataset)
val result = model.transform(dataset)
@@ -58,7 +76,29 @@ class MultilayerPerceptronClassifierSuite
}
}
- // TODO: implement a more rigorous test
+ test("Test setWeights by training restart") {
+ val dataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(0.0, 0.0), 0.0),
+ (Vectors.dense(0.0, 1.0), 1.0),
+ (Vectors.dense(1.0, 0.0), 1.0),
+ (Vectors.dense(1.0, 1.0), 0.0))
+ ).toDF("features", "label")
+ val layers = Array[Int](2, 5, 2)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(12L)
+ .setMaxIter(1)
+ .setTol(1e-6)
+ val initialWeights = trainer.fit(dataFrame).weights
+ trainer.setWeights(initialWeights.copy)
+ val weights1 = trainer.fit(dataFrame).weights
+ trainer.setWeights(initialWeights.copy)
+ val weights2 = trainer.fit(dataFrame).weights
+ assert(weights1 ~== weights2 absTol 10e-5,
+ "Training should produce the same weights given equal initial weights and number of steps")
+ }
+
test("3 class classification with 2 hidden layers") {
val nPoints = 1000
@@ -123,4 +163,15 @@ class MultilayerPerceptronClassifierSuite
assert(newMlpModel.layers === mlpModel.layers)
assert(newMlpModel.weights === mlpModel.weights)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val layers = Array(3, 2)
+ val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1)
+ MLTestingUtils.checkNumericTypes[
+ MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier](
+ mpc, isClassification = true, sqlContext) { (expected, actual) =>
+ assert(expected.layers === actual.layers)
+ assert(expected.weights === actual.weights)
+ }
+ }
}