aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorAlexander Ulanov <nashb@yandex.ru>2016-03-31 23:48:36 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-31 23:48:36 -0700
commit26867ebc67edab97376c5d8fee76df294359e461 (patch)
tree7af6e342de372beb025537f45013f5153776eeb5 /mllib/src/test
parent1b070637fa03ab4966f76427b15e433050eaa956 (diff)
downloadspark-26867ebc67edab97376c5d8fee76df294359e461.tar.gz
spark-26867ebc67edab97376c5d8fee76df294359e461.tar.bz2
spark-26867ebc67edab97376c5d8fee76df294359e461.zip
[SPARK-11262][ML] Unit test for gradient, loss layers, memory management for multilayer perceptron
1.Implement LossFunction trait and implement squared error and cross entropy loss with it 2.Implement unit test for gradient and loss 3.Implement InPlace trait and in-place layer evaluation 4.Refactor interface for ActivationFunction 5.Update of Layer and LayerModel interfaces 6.Fix random weights assignment 7.Implement memory allocation by MLP model instead of individual layers These features decreased the memory usage and increased flexibility of internal API. Author: Alexander Ulanov <nashb@yandex.ru> Author: avulanov <avulanov@gmail.com> Closes #9229 from avulanov/mlp-refactoring.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala76
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala26
4 files changed, 106 insertions, 7 deletions
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index d499d363f1..bc955f3cf6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -63,7 +63,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
.setLayers(new int[] {2, 5, 2})
.setBlockSize(1)
- .setSeed(11L)
+ .setSeed(123L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
Dataset<Row> result = model.transform(dataFrame);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
index 1292e57d7c..dc91fc5f9e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
@@ -42,7 +42,7 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
val dataSample = rddData.first()
val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
- val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val initialWeights = FeedForwardModel(topology, 23124).weights
val trainer = new FeedForwardTrainer(topology, 2, 1)
trainer.setWeights(initialWeights)
trainer.LBFGSOptimizer.setNumIterations(20)
@@ -76,10 +76,11 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
val dataSample = rddData.first()
val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
- val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val initialWeights = FeedForwardModel(topology, 23124).weights
val trainer = new FeedForwardTrainer(topology, 2, 2)
- trainer.SGDOptimizer.setNumIterations(2000)
- trainer.setWeights(initialWeights)
+ // TODO: add a test for SGD
+ trainer.LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(20)
+ trainer.setWeights(initialWeights).setStackSize(1)
val model = trainer.train(rddData)
val predictionAndLabels = rddData.map { case (input, label) =>
(model.predict(input), label)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala
new file mode 100644
index 0000000000..04cc426c40
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import breeze.linalg.{DenseMatrix => BDM}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class GradientSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("Gradient computation against numerical differentiation") {
+ val input = new BDM[Double](3, 1, Array(1.0, 1.0, 1.0))
+ // output must contain zeros and one 1 for SoftMax
+ val target = new BDM[Double](2, 1, Array(0.0, 1.0))
+ val topology = FeedForwardTopology.multiLayerPerceptron(Array(3, 4, 2), softmaxOnTop = false)
+ val layersWithErrors = Seq(
+ new SigmoidLayerWithSquaredError(),
+ new SoftmaxLayerWithCrossEntropyLoss()
+ )
+ // check all layers that provide loss computation
+ // 1) compute loss and gradient given the model and initial weights
+ // 2) modify weights with small number epsilon (per dimension i)
+ // 3) compute new loss
+ // 4) ((newLoss - loss) / epsilon) should be close to the i-th component of the gradient
+ for (layerWithError <- layersWithErrors) {
+ topology.layers(topology.layers.length - 1) = layerWithError
+ val model = topology.model(seed = 12L)
+ val weights = model.weights.toArray
+ val numWeights = weights.size
+ val gradient = Vectors.dense(Array.fill[Double](numWeights)(0.0))
+ val loss = model.computeGradient(input, target, gradient, 1)
+ val eps = 1e-4
+ var i = 0
+ val tol = 1e-4
+ while (i < numWeights) {
+ val originalValue = weights(i)
+ weights(i) += eps
+ val newModel = topology.model(Vectors.dense(weights))
+ val newLoss = computeLoss(input, target, newModel)
+ val derivativeEstimate = (newLoss - loss) / eps
+ assert(math.abs(gradient(i) - derivativeEstimate) < tol, "Layer failed gradient check: " +
+ layerWithError.getClass)
+ weights(i) = originalValue
+ i += 1
+ }
+ }
+ }
+
+ private def computeLoss(input: BDM[Double], target: BDM[Double], model: TopologyModel): Double = {
+ val outputs = model.forward(input)
+ model.layerModels.last match {
+ case layerWithLoss: LossFunction =>
+ layerWithLoss.loss(outputs.last, target, new BDM[Double](target.rows, target.cols))
+ case _ =>
+ throw new UnsupportedOperationException("Top layer is required to have loss." +
+ " Failed layer:" + model.layerModels.last.getClass)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 53c7a559e3..43781385db 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -65,7 +65,7 @@ class MultilayerPerceptronClassifierSuite
val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(1)
- .setSeed(11L)
+ .setSeed(123L)
.setMaxIter(100)
val model = trainer.fit(dataset)
val result = model.transform(dataset)
@@ -75,7 +75,29 @@ class MultilayerPerceptronClassifierSuite
}
}
- // TODO: implement a more rigorous test
+ test("Test setWeights by training restart") {
+ val dataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(0.0, 0.0), 0.0),
+ (Vectors.dense(0.0, 1.0), 1.0),
+ (Vectors.dense(1.0, 0.0), 1.0),
+ (Vectors.dense(1.0, 1.0), 0.0))
+ ).toDF("features", "label")
+ val layers = Array[Int](2, 5, 2)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(12L)
+ .setMaxIter(1)
+ .setTol(1e-6)
+ val initialWeights = trainer.fit(dataFrame).weights
+ trainer.setWeights(initialWeights.copy)
+ val weights1 = trainer.fit(dataFrame).weights
+ trainer.setWeights(initialWeights.copy)
+ val weights2 = trainer.fit(dataFrame).weights
+ assert(weights1 ~== weights2 absTol 10e-5,
+ "Training should produce the same weights given equal initial weights and number of steps")
+ }
+
test("3 class classification with 2 hidden layers") {
val nPoints = 1000