diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala | 59 |
1 files changed, 55 insertions, 4 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala index 5df8e6a847..80547fad6a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala @@ -19,18 +19,19 @@ package org.apache.spark.ml.classification import org.apache.spark.SparkFunSuite import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.MLTestingUtils import org.apache.spark.mllib.classification.LogisticRegressionSuite._ import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS import org.apache.spark.mllib.evaluation.MulticlassMetrics import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} class MultilayerPerceptronClassifierSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -43,12 +44,29 @@ class MultilayerPerceptronClassifierSuite ).toDF("features", "label") } + test("Input Validation") { + val mlpc = new MultilayerPerceptronClassifier() + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int]()) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](0, 1)) + } + intercept[IllegalArgumentException] { + mlpc.setLayers(Array[Int](1, 0)) + } + mlpc.setLayers(Array[Int](1, 1)) + } + test("XOR function learning as binary classification problem with two outputs.") { val layers = Array[Int](2, 5, 2) val trainer = new MultilayerPerceptronClassifier() .setLayers(layers) .setBlockSize(1) - .setSeed(11L) + .setSeed(123L) .setMaxIter(100) val model = trainer.fit(dataset) val result = model.transform(dataset) @@ -58,7 +76,29 @@ class MultilayerPerceptronClassifierSuite } } - // TODO: implement a more rigorous test + test("Test setWeights by training restart") { + val dataFrame = sqlContext.createDataFrame(Seq( + (Vectors.dense(0.0, 0.0), 0.0), + (Vectors.dense(0.0, 1.0), 1.0), + (Vectors.dense(1.0, 0.0), 1.0), + (Vectors.dense(1.0, 1.0), 0.0)) + ).toDF("features", "label") + val layers = Array[Int](2, 5, 2) + val trainer = new MultilayerPerceptronClassifier() + .setLayers(layers) + .setBlockSize(1) + .setSeed(12L) + .setMaxIter(1) + .setTol(1e-6) + val initialWeights = trainer.fit(dataFrame).weights + trainer.setWeights(initialWeights.copy) + val weights1 = trainer.fit(dataFrame).weights + trainer.setWeights(initialWeights.copy) + val weights2 = trainer.fit(dataFrame).weights + assert(weights1 ~== weights2 absTol 10e-5, + "Training should produce the same weights given equal initial weights and number of steps") + } + test("3 class classification with 2 hidden layers") { val nPoints = 1000 @@ -123,4 +163,15 @@ class MultilayerPerceptronClassifierSuite assert(newMlpModel.layers === mlpModel.layers) assert(newMlpModel.weights === mlpModel.weights) } + + test("should support all NumericType labels and not support other types") { + val layers = Array(3, 2) + val mpc = new MultilayerPerceptronClassifier().setLayers(layers).setMaxIter(1) + MLTestingUtils.checkNumericTypes[ + MultilayerPerceptronClassificationModel, MultilayerPerceptronClassifier]( + mpc, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.layers === actual.layers) + assert(expected.weights === actual.weights) + } + } } |