aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorAlexander Ulanov <nashb@yandex.ru>2016-03-31 23:48:36 -0700
committerXiangrui Meng <meng@databricks.com>2016-03-31 23:48:36 -0700
commit26867ebc67edab97376c5d8fee76df294359e461 (patch)
tree7af6e342de372beb025537f45013f5153776eeb5 /mllib
parent1b070637fa03ab4966f76427b15e433050eaa956 (diff)
downloadspark-26867ebc67edab97376c5d8fee76df294359e461.tar.gz
spark-26867ebc67edab97376c5d8fee76df294359e461.tar.bz2
spark-26867ebc67edab97376c5d8fee76df294359e461.zip
[SPARK-11262][ML] Unit test for gradient, loss layers, memory management for multilayer perceptron
1.Implement LossFunction trait and implement squared error and cross entropy loss with it 2.Implement unit test for gradient and loss 3.Implement InPlace trait and in-place layer evaluation 4.Refactor interface for ActivationFunction 5.Update of Layer and LayerModel interfaces 6.Fix random weights assignment 7.Implement memory allocation by MLP model instead of individual layers These features decreased the memory usage and increased flexibility of internal API. Author: Alexander Ulanov <nashb@yandex.ru> Author: avulanov <avulanov@gmail.com> Closes #9229 from avulanov/mlp-refactoring.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala662
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala124
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala82
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala76
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala26
7 files changed, 595 insertions, 386 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index 2cd94fa8f5..a5b84116e6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -17,9 +17,9 @@
package org.apache.spark.ml.ann
-import breeze.linalg.{*, axpy => Baxpy, sum => Bsum, DenseMatrix => BDM, DenseVector => BDV,
- Vector => BV}
-import breeze.numerics.{log => Blog, sigmoid => Bsigmoid}
+import java.util.Random
+
+import breeze.linalg.{*, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.optimization._
@@ -32,20 +32,46 @@ import org.apache.spark.util.random.XORShiftRandom
*
*/
private[ann] trait Layer extends Serializable {
+
/**
- * Returns the instance of the layer based on weights provided
- * @param weights vector with layer weights
- * @param position position of weights in the vector
- * @return the layer model
+ * Number of weights that is used to allocate memory for the weights vector
+ */
+ val weightSize: Int
+
+ /**
+ * Returns the output size given the input size (not counting the stack size).
+ * Output size is used to allocate memory for the output.
+ *
+ * @param inputSize input size
+ * @return output size
*/
- def getInstance(weights: Vector, position: Int): LayerModel
+ def getOutputSize(inputSize: Int): Int
/**
+ * If true, the memory is not allocated for the output of this layer.
+ * The memory allocated to the previous layer is used to write the output of this layer.
+ * Developer can set this to true if computing delta of a previous layer
+ * does not involve its output, so the current layer can write there.
+ * This also mean that both layers have the same number of outputs.
+ */
+ val inPlace: Boolean
+
+ /**
+ * Returns the instance of the layer based on weights provided.
+ * Size of weights must be equal to weightSize
+ *
+ * @param initialWeights vector with layer weights
+ * @return the layer model
+ */
+ def createModel(initialWeights: BDV[Double]): LayerModel
+ /**
* Returns the instance of the layer with random generated weights
- * @param seed seed
+ *
+ * @param weights vector for weights initialization, must be equal to weightSize
+ * @param random random number generator
* @return the layer model
*/
- def getInstance(seed: Long): LayerModel
+ def initModel(weights: BDV[Double], random: Random): LayerModel
}
/**
@@ -54,92 +80,102 @@ private[ann] trait Layer extends Serializable {
* Can return weights in Vector format.
*/
private[ann] trait LayerModel extends Serializable {
- /**
- * number of weights
- */
- val size: Int
+ val weights: BDV[Double]
/**
* Evaluates the data (process the data through the layer)
+ * Output is allocated based on the size provided by the
+ * LayerModel implementation and the stack (batch) size
+ * Developer is responsible for checking the size of output
+ * when writing to it
+ *
* @param data data
- * @return processed data
+ * @param output output (modified in place)
*/
- def eval(data: BDM[Double]): BDM[Double]
+ def eval(data: BDM[Double], output: BDM[Double]): Unit
/**
* Computes the delta for back propagation
- * @param nextDelta delta of the next layer
- * @param input input data
- * @return delta
+ * Delta is allocated based on the size provided by the
+ * LayerModel implementation and the stack (batch) size
+ * Developer is responsible for checking the size of
+ * prevDelta when writing to it
+ *
+ * @param delta delta of this layer
+ * @param output output of this layer
+ * @param prevDelta the previous delta (modified in place)
*/
- def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double]
+ def computePrevDelta(delta: BDM[Double], output: BDM[Double], prevDelta: BDM[Double]): Unit
/**
* Computes the gradient
+ * cumGrad is a wrapper on the part of the weight vector
+ * size of cumGrad is based on weightSize provided by
+ * implementation of LayerModel
+ *
* @param delta delta for this layer
* @param input input data
- * @return gradient
+ * @param cumGrad cumulative gradient (modified in place)
*/
- def grad(delta: BDM[Double], input: BDM[Double]): Array[Double]
-
- /**
- * Returns weights for the layer in a single vector
- * @return layer weights
- */
- def weights(): Vector
+ def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit
}
/**
* Layer properties of affine transformations, that is y=A*x+b
+ *
* @param numIn number of inputs
* @param numOut number of outputs
*/
private[ann] class AffineLayer(val numIn: Int, val numOut: Int) extends Layer {
- override def getInstance(weights: Vector, position: Int): LayerModel = {
- AffineLayerModel(this, weights, position)
- }
+ override val weightSize = numIn * numOut + numOut
- override def getInstance(seed: Long = 11L): LayerModel = {
- AffineLayerModel(this, seed)
- }
+ override def getOutputSize(inputSize: Int): Int = numOut
+
+ override val inPlace = false
+
+ override def createModel(weights: BDV[Double]): LayerModel = new AffineLayerModel(weights, this)
+
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ AffineLayerModel(this, weights, random)
}
/**
- * Model of Affine layer y=A*x+b
- * @param w weights (matrix A)
- * @param b bias (vector b)
+ * Model of Affine layer
+ *
+ * @param weights weights
+ * @param layer layer properties
*/
-private[ann] class AffineLayerModel private(w: BDM[Double], b: BDV[Double]) extends LayerModel {
- val size = w.size + b.length
- val gwb = new Array[Double](size)
- private lazy val gw: BDM[Double] = new BDM[Double](w.rows, w.cols, gwb)
- private lazy val gb: BDV[Double] = new BDV[Double](gwb, w.size)
- private var z: BDM[Double] = null
- private var d: BDM[Double] = null
+private[ann] class AffineLayerModel private[ann] (
+ val weights: BDV[Double],
+ val layer: AffineLayer) extends LayerModel {
+ val w = new BDM[Double](layer.numOut, layer.numIn, weights.data, weights.offset)
+ val b =
+ new BDV[Double](weights.data, weights.offset + (layer.numOut * layer.numIn), 1, layer.numOut)
+
private var ones: BDV[Double] = null
- override def eval(data: BDM[Double]): BDM[Double] = {
- if (z == null || z.cols != data.cols) z = new BDM[Double](w.rows, data.cols)
- z(::, *) := b
- BreezeUtil.dgemm(1.0, w, data, 1.0, z)
- z
+ override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+ output(::, *) := b
+ BreezeUtil.dgemm(1.0, w, data, 1.0, output)
}
- override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
- if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](w.cols, nextDelta.cols)
- BreezeUtil.dgemm(1.0, w.t, nextDelta, 0.0, d)
- d
+ override def computePrevDelta(
+ delta: BDM[Double],
+ output: BDM[Double],
+ prevDelta: BDM[Double]): Unit = {
+ BreezeUtil.dgemm(1.0, w.t, delta, 0.0, prevDelta)
}
- override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = {
- BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 0.0, gw)
+ override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {
+ // compute gradient of weights
+ val cumGradientOfWeights = new BDM[Double](w.rows, w.cols, cumGrad.data, cumGrad.offset)
+ BreezeUtil.dgemm(1.0 / input.cols, delta, input.t, 1.0, cumGradientOfWeights)
if (ones == null || ones.length != delta.cols) ones = BDV.ones[Double](delta.cols)
- BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 0.0, gb)
- gwb
+ // compute gradient of bias
+ val cumGradientOfBias = new BDV[Double](cumGrad.data, cumGrad.offset + w.size, 1, b.length)
+ BreezeUtil.dgemv(1.0 / input.cols, delta, ones, 1.0, cumGradientOfBias)
}
-
- override def weights(): Vector = AffineLayerModel.roll(w, b)
}
/**
@@ -149,73 +185,40 @@ private[ann] object AffineLayerModel {
/**
* Creates a model of Affine layer
+ *
* @param layer layer properties
- * @param weights vector with weights
- * @param position position of weights in the vector
- * @return model of Affine layer
- */
- def apply(layer: AffineLayer, weights: Vector, position: Int): AffineLayerModel = {
- val (w, b) = unroll(weights, position, layer.numIn, layer.numOut)
- new AffineLayerModel(w, b)
- }
-
- /**
- * Creates a model of Affine layer
- * @param layer layer properties
- * @param seed seed
+ * @param weights vector for weights initialization
+ * @param random random number generator
* @return model of Affine layer
*/
- def apply(layer: AffineLayer, seed: Long): AffineLayerModel = {
- val (w, b) = randomWeights(layer.numIn, layer.numOut, seed)
- new AffineLayerModel(w, b)
- }
-
- /**
- * Unrolls the weights from the vector
- * @param weights vector with weights
- * @param position position of weights for this layer
- * @param numIn number of layer inputs
- * @param numOut number of layer outputs
- * @return matrix A and vector b
- */
- def unroll(
- weights: Vector,
- position: Int,
- numIn: Int,
- numOut: Int): (BDM[Double], BDV[Double]) = {
- val weightsCopy = weights.toArray
- // TODO: the array is not copied to BDMs, make sure this is OK!
- val a = new BDM[Double](numOut, numIn, weightsCopy, position)
- val b = new BDV[Double](weightsCopy, position + (numOut * numIn), 1, numOut)
- (a, b)
- }
-
- /**
- * Roll the layer weights into a vector
- * @param a matrix A
- * @param b vector b
- * @return vector of weights
- */
- def roll(a: BDM[Double], b: BDV[Double]): Vector = {
- val result = new Array[Double](a.size + b.length)
- // TODO: make sure that we need to copy!
- System.arraycopy(a.toArray, 0, result, 0, a.size)
- System.arraycopy(b.toArray, 0, result, a.size, b.length)
- Vectors.dense(result)
+ def apply(layer: AffineLayer, weights: BDV[Double], random: Random): AffineLayerModel = {
+ randomWeights(layer.numIn, layer.numOut, weights, random)
+ new AffineLayerModel(weights, layer)
}
/**
- * Generate random weights for the layer
- * @param numIn number of inputs
+ * Initialize weights randomly in the interval
+ * Uses [Bottou-88] heuristic [-a/sqrt(in); a/sqrt(in)]
+ * where a is chosen in a such way that the weight variance corresponds
+ * to the points to the maximal curvature of the activation function
+ * (which is approximately 2.38 for a standard sigmoid)
+ *
+ * @param numIn number of inputs
* @param numOut number of outputs
- * @param seed seed
- * @return (matrix A, vector b)
+ * @param weights vector for weights initialization
+ * @param random random number generator
*/
- def randomWeights(numIn: Int, numOut: Int, seed: Long = 11L): (BDM[Double], BDV[Double]) = {
- val rand: XORShiftRandom = new XORShiftRandom(seed)
- val weights = BDM.fill[Double](numOut, numIn) { (rand.nextDouble * 4.8 - 2.4) / numIn }
- val bias = BDV.fill[Double](numOut) { (rand.nextDouble * 4.8 - 2.4) / numIn }
- (weights, bias)
+ def randomWeights(
+ numIn: Int,
+ numOut: Int,
+ weights: BDV[Double],
+ random: Random): Unit = {
+ var i = 0
+ val sqrtIn = math.sqrt(numIn)
+ while (i < weights.length) {
+ weights(i) = (random.nextDouble * 4.8 - 2.4) / sqrtIn
+ i += 1
+ }
}
}
@@ -226,44 +229,21 @@ private[ann] trait ActivationFunction extends Serializable {
/**
* Implements a function
- * @param x input data
- * @param y output data
*/
- def eval(x: BDM[Double], y: BDM[Double]): Unit
+ def eval: Double => Double
/**
* Implements a derivative of a function (needed for the back propagation)
- * @param x input data
- * @param y output data
*/
- def derivative(x: BDM[Double], y: BDM[Double]): Unit
-
- /**
- * Implements a cross entropy error of a function.
- * Needed if the functional layer that contains this function is the output layer
- * of the network.
- * @param target target output
- * @param output computed output
- * @param result intermediate result
- * @return cross-entropy
- */
- def crossEntropy(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double
-
- /**
- * Implements a mean squared error of a function
- * @param target target output
- * @param output computed output
- * @param result intermediate result
- * @return mean squared error
- */
- def squared(target: BDM[Double], output: BDM[Double], result: BDM[Double]): Double
+ def derivative: Double => Double
}
/**
- * Implements in-place application of functions
+ * Implements in-place application of functions in the arrays
*/
-private[ann] object ActivationFunction {
+private[ann] object ApplyInPlace {
+ // TODO: use Breeze UFunc
def apply(x: BDM[Double], y: BDM[Double], func: Double => Double): Unit = {
var i = 0
while (i < x.rows) {
@@ -276,6 +256,7 @@ private[ann] object ActivationFunction {
}
}
+ // TODO: use Breeze UFunc
def apply(
x1: BDM[Double],
x2: BDM[Double],
@@ -294,179 +275,86 @@ private[ann] object ActivationFunction {
}
/**
- * Implements SoftMax activation function
- */
-private[ann] class SoftmaxFunction extends ActivationFunction {
- override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
- var j = 0
- // find max value to make sure later that exponent is computable
- while (j < x.cols) {
- var i = 0
- var max = Double.MinValue
- while (i < x.rows) {
- if (x(i, j) > max) {
- max = x(i, j)
- }
- i += 1
- }
- var sum = 0.0
- i = 0
- while (i < x.rows) {
- val res = Math.exp(x(i, j) - max)
- y(i, j) = res
- sum += res
- i += 1
- }
- i = 0
- while (i < x.rows) {
- y(i, j) /= sum
- i += 1
- }
- j += 1
- }
- }
-
- override def crossEntropy(
- output: BDM[Double],
- target: BDM[Double],
- result: BDM[Double]): Double = {
- def m(o: Double, t: Double): Double = o - t
- ActivationFunction(output, target, result, m)
- -Bsum( target :* Blog(output)) / output.cols
- }
-
- override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
- def sd(z: Double): Double = (1 - z) * z
- ActivationFunction(x, y, sd)
- }
-
- override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
- throw new UnsupportedOperationException("Sorry, squared error is not defined for SoftMax.")
- }
-}
-
-/**
* Implements Sigmoid activation function
*/
private[ann] class SigmoidFunction extends ActivationFunction {
- override def eval(x: BDM[Double], y: BDM[Double]): Unit = {
- def s(z: Double): Double = Bsigmoid(z)
- ActivationFunction(x, y, s)
- }
-
- override def crossEntropy(
- output: BDM[Double],
- target: BDM[Double],
- result: BDM[Double]): Double = {
- def m(o: Double, t: Double): Double = o - t
- ActivationFunction(output, target, result, m)
- -Bsum(target :* Blog(output)) / output.cols
- }
- override def derivative(x: BDM[Double], y: BDM[Double]): Unit = {
- def sd(z: Double): Double = (1 - z) * z
- ActivationFunction(x, y, sd)
- }
+ override def eval: (Double) => Double = x => 1.0 / (1 + math.exp(-x))
- override def squared(output: BDM[Double], target: BDM[Double], result: BDM[Double]): Double = {
- // TODO: make it readable
- def m(o: Double, t: Double): Double = (o - t)
- ActivationFunction(output, target, result, m)
- val e = Bsum(result :* result) / 2 / output.cols
- def m2(x: Double, o: Double) = x * (o - o * o)
- ActivationFunction(result, output, result, m2)
- e
- }
+ override def derivative: (Double) => Double = z => (1 - z) * z
}
/**
* Functional layer properties, y = f(x)
+ *
* @param activationFunction activation function
*/
private[ann] class FunctionalLayer (val activationFunction: ActivationFunction) extends Layer {
- override def getInstance(weights: Vector, position: Int): LayerModel = getInstance(0L)
- override def getInstance(seed: Long): LayerModel =
- FunctionalLayerModel(this)
+ override val weightSize = 0
+
+ override def getOutputSize(inputSize: Int): Int = inputSize
+
+ override val inPlace = true
+
+ override def createModel(weights: BDV[Double]): LayerModel = new FunctionalLayerModel(this)
+
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ createModel(weights)
}
/**
* Functional layer model. Holds no weights.
- * @param activationFunction activation function
+ *
+ * @param layer functiona layer
*/
-private[ann] class FunctionalLayerModel private (val activationFunction: ActivationFunction)
+private[ann] class FunctionalLayerModel private[ann] (val layer: FunctionalLayer)
extends LayerModel {
- val size = 0
- // matrices for in-place computations
- // outputs
- private var f: BDM[Double] = null
- // delta
- private var d: BDM[Double] = null
- // matrix for error computation
- private var e: BDM[Double] = null
- // delta gradient
- private lazy val dg = new Array[Double](0)
- override def eval(data: BDM[Double]): BDM[Double] = {
- if (f == null || f.cols != data.cols) f = new BDM[Double](data.rows, data.cols)
- activationFunction.eval(data, f)
- f
- }
+ // empty weights
+ val weights = new BDV[Double](0)
- override def prevDelta(nextDelta: BDM[Double], input: BDM[Double]): BDM[Double] = {
- if (d == null || d.cols != nextDelta.cols) d = new BDM[Double](nextDelta.rows, nextDelta.cols)
- activationFunction.derivative(input, d)
- d :*= nextDelta
- d
+ override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+ ApplyInPlace(data, output, layer.activationFunction.eval)
}
- override def grad(delta: BDM[Double], input: BDM[Double]): Array[Double] = dg
-
- override def weights(): Vector = Vectors.dense(new Array[Double](0))
-
- def crossEntropy(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
- if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
- val error = activationFunction.crossEntropy(output, target, e)
- (e, error)
+ override def computePrevDelta(
+ nextDelta: BDM[Double],
+ input: BDM[Double],
+ delta: BDM[Double]): Unit = {
+ ApplyInPlace(input, delta, layer.activationFunction.derivative)
+ delta :*= nextDelta
}
- def squared(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
- if (e == null || e.cols != output.cols) e = new BDM[Double](output.rows, output.cols)
- val error = activationFunction.squared(output, target, e)
- (e, error)
- }
-
- def error(output: BDM[Double], target: BDM[Double]): (BDM[Double], Double) = {
- // TODO: allow user pick error
- activationFunction match {
- case sigmoid: SigmoidFunction => squared(output, target)
- case softmax: SoftmaxFunction => crossEntropy(output, target)
- }
- }
-}
-
-/**
- * Fabric of functional layer models
- */
-private[ann] object FunctionalLayerModel {
- def apply(layer: FunctionalLayer): FunctionalLayerModel =
- new FunctionalLayerModel(layer.activationFunction)
+ override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {}
}
/**
* Trait for the artificial neural network (ANN) topology properties
*/
-private[ann] trait Topology extends Serializable{
- def getInstance(weights: Vector): TopologyModel
- def getInstance(seed: Long): TopologyModel
+private[ann] trait Topology extends Serializable {
+ def model(weights: Vector): TopologyModel
+ def model(seed: Long): TopologyModel
}
/**
* Trait for ANN topology model
*/
-private[ann] trait TopologyModel extends Serializable{
+private[ann] trait TopologyModel extends Serializable {
+
+ val weights: Vector
+ /**
+ * Array of layers
+ */
+ val layers: Array[Layer]
+
+ /**
+ * Array of layer models
+ */
+ val layerModels: Array[LayerModel]
/**
* Forward propagation
+ *
* @param data input data
* @return array of outputs for each of the layers
*/
@@ -474,6 +362,7 @@ private[ann] trait TopologyModel extends Serializable{
/**
* Prediction of the model
+ *
* @param data input data
* @return prediction
*/
@@ -481,6 +370,7 @@ private[ann] trait TopologyModel extends Serializable{
/**
* Computes gradient for the network
+ *
* @param data input data
* @param target target output
* @param cumGradient cumulative gradient
@@ -489,22 +379,17 @@ private[ann] trait TopologyModel extends Serializable{
*/
def computeGradient(data: BDM[Double], target: BDM[Double], cumGradient: Vector,
blockSize: Int): Double
-
- /**
- * Returns the weights of the ANN
- * @return weights
- */
- def weights(): Vector
}
/**
* Feed forward ANN
+ *
* @param layers
*/
private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends Topology {
- override def getInstance(weights: Vector): TopologyModel = FeedForwardModel(this, weights)
+ override def model(weights: Vector): TopologyModel = FeedForwardModel(this, weights)
- override def getInstance(seed: Long): TopologyModel = FeedForwardModel(this, seed)
+ override def model(seed: Long): TopologyModel = FeedForwardModel(this, seed)
}
/**
@@ -513,6 +398,7 @@ private[ann] class FeedForwardTopology private(val layers: Array[Layer]) extends
private[ml] object FeedForwardTopology {
/**
* Creates a feed forward topology from the array of layers
+ *
* @param layers array of layers
* @return feed forward topology
*/
@@ -522,18 +408,26 @@ private[ml] object FeedForwardTopology {
/**
* Creates a multi-layer perceptron
+ *
* @param layerSizes sizes of layers including input and output size
- * @param softmax whether to use SoftMax or Sigmoid function for an output layer.
+ * @param softmaxOnTop wether to use SoftMax or Sigmoid function for an output layer.
* Softmax is default
* @return multilayer perceptron topology
*/
- def multiLayerPerceptron(layerSizes: Array[Int], softmax: Boolean = true): FeedForwardTopology = {
+ def multiLayerPerceptron(
+ layerSizes: Array[Int],
+ softmaxOnTop: Boolean = true): FeedForwardTopology = {
val layers = new Array[Layer]((layerSizes.length - 1) * 2)
- for(i <- 0 until layerSizes.length - 1) {
+ for (i <- 0 until layerSizes.length - 1) {
layers(i * 2) = new AffineLayer(layerSizes(i), layerSizes(i + 1))
layers(i * 2 + 1) =
- if (softmax && i == layerSizes.length - 2) {
- new FunctionalLayer(new SoftmaxFunction())
+ if (i == layerSizes.length - 2) {
+ if (softmaxOnTop) {
+ new SoftmaxLayerWithCrossEntropyLoss()
+ } else {
+ // TODO: squared error is more natural but converges slower
+ new SigmoidLayerWithSquaredError()
+ }
} else {
new FunctionalLayer(new SigmoidFunction())
}
@@ -545,17 +439,45 @@ private[ml] object FeedForwardTopology {
/**
* Model of Feed Forward Neural Network.
* Implements forward, gradient computation and can return weights in vector format.
- * @param layerModels models of layers
- * @param topology topology of the network
+ *
+ * @param weights network weights
+ * @param topology network topology
*/
private[ml] class FeedForwardModel private(
- val layerModels: Array[LayerModel],
+ val weights: Vector,
val topology: FeedForwardTopology) extends TopologyModel {
+
+ val layers = topology.layers
+ val layerModels = new Array[LayerModel](layers.length)
+ private var offset = 0
+ for (i <- 0 until layers.length) {
+ layerModels(i) = layers(i).createModel(
+ new BDV[Double](weights.toArray, offset, 1, layers(i).weightSize))
+ offset += layers(i).weightSize
+ }
+ private var outputs: Array[BDM[Double]] = null
+ private var deltas: Array[BDM[Double]] = null
+
override def forward(data: BDM[Double]): Array[BDM[Double]] = {
- val outputs = new Array[BDM[Double]](layerModels.length)
- outputs(0) = layerModels(0).eval(data)
+ // Initialize output arrays for all layers. Special treatment for InPlace
+ val currentBatchSize = data.cols
+ // TODO: allocate outputs as one big array and then create BDMs from it
+ if (outputs == null || outputs(0).cols != currentBatchSize) {
+ outputs = new Array[BDM[Double]](layers.length)
+ var inputSize = data.rows
+ for (i <- 0 until layers.length) {
+ if (layers(i).inPlace) {
+ outputs(i) = outputs(i - 1)
+ } else {
+ val outputSize = layers(i).getOutputSize(inputSize)
+ outputs(i) = new BDM[Double](outputSize, currentBatchSize)
+ inputSize = outputSize
+ }
+ }
+ }
+ layerModels(0).eval(data, outputs(0))
for (i <- 1 until layerModels.length) {
- outputs(i) = layerModels(i).eval(outputs(i-1))
+ layerModels(i).eval(outputs(i - 1), outputs(i))
}
outputs
}
@@ -566,54 +488,36 @@ private[ml] class FeedForwardModel private(
cumGradient: Vector,
realBatchSize: Int): Double = {
val outputs = forward(data)
- val deltas = new Array[BDM[Double]](layerModels.length)
+ val currentBatchSize = data.cols
+ // TODO: allocate deltas as one big array and then create BDMs from it
+ if (deltas == null || deltas(0).cols != currentBatchSize) {
+ deltas = new Array[BDM[Double]](layerModels.length)
+ var inputSize = data.rows
+ for (i <- 0 until layerModels.length - 1) {
+ val outputSize = layers(i).getOutputSize(inputSize)
+ deltas(i) = new BDM[Double](outputSize, currentBatchSize)
+ inputSize = outputSize
+ }
+ }
val L = layerModels.length - 1
- val (newE, newError) = layerModels.last match {
- case flm: FunctionalLayerModel => flm.error(outputs.last, target)
+ // TODO: explain why delta of top layer is null (because it might contain loss+layer)
+ val loss = layerModels.last match {
+ case levelWithError: LossFunction => levelWithError.loss(outputs.last, target, deltas(L - 1))
case _ =>
- throw new UnsupportedOperationException("Non-functional layer not supported at the top")
+ throw new UnsupportedOperationException("Top layer is required to have objective.")
}
- deltas(L) = new BDM[Double](0, 0)
- deltas(L - 1) = newE
for (i <- (L - 2) to (0, -1)) {
- deltas(i) = layerModels(i + 1).prevDelta(deltas(i + 1), outputs(i + 1))
- }
- val grads = new Array[Array[Double]](layerModels.length)
- for (i <- 0 until layerModels.length) {
- val input = if (i==0) data else outputs(i - 1)
- grads(i) = layerModels(i).grad(deltas(i), input)
+ layerModels(i + 1).computePrevDelta(deltas(i + 1), outputs(i + 1), deltas(i))
}
- // update cumGradient
val cumGradientArray = cumGradient.toArray
var offset = 0
- // TODO: extract roll
- for (i <- 0 until grads.length) {
- val gradArray = grads(i)
- var k = 0
- while (k < gradArray.length) {
- cumGradientArray(offset + k) += gradArray(k)
- k += 1
- }
- offset += gradArray.length
- }
- newError
- }
-
- // TODO: do we really need to copy the weights? they should be read-only
- override def weights(): Vector = {
- // TODO: extract roll
- var size = 0
- for (i <- 0 until layerModels.length) {
- size += layerModels(i).size
- }
- val array = new Array[Double](size)
- var offset = 0
for (i <- 0 until layerModels.length) {
- val layerWeights = layerModels(i).weights().toArray
- System.arraycopy(layerWeights, 0, array, offset, layerWeights.length)
- offset += layerWeights.length
+ val input = if (i == 0) data else outputs(i - 1)
+ layerModels(i).grad(deltas(i), input,
+ new BDV[Double](cumGradientArray, offset, 1, layers(i).weightSize))
+ offset += layers(i).weightSize
}
- Vectors.dense(array)
+ loss
}
override def predict(data: Vector): Vector = {
@@ -630,23 +534,19 @@ private[ann] object FeedForwardModel {
/**
* Creates a model from a topology and weights
+ *
* @param topology topology
* @param weights weights
* @return model
*/
def apply(topology: FeedForwardTopology, weights: Vector): FeedForwardModel = {
- val layers = topology.layers
- val layerModels = new Array[LayerModel](layers.length)
- var offset = 0
- for (i <- 0 until layers.length) {
- layerModels(i) = layers(i).getInstance(weights, offset)
- offset += layerModels(i).size
- }
- new FeedForwardModel(layerModels, topology)
+ // TODO: check that weights size is equal to sum of layers sizes
+ new FeedForwardModel(weights, topology)
}
/**
* Creates a model given a topology and seed
+ *
* @param topology topology
* @param seed seed for generating the weights
* @return model
@@ -654,17 +554,25 @@ private[ann] object FeedForwardModel {
def apply(topology: FeedForwardTopology, seed: Long = 11L): FeedForwardModel = {
val layers = topology.layers
val layerModels = new Array[LayerModel](layers.length)
+ var totalSize = 0
+ for (i <- 0 until topology.layers.length) {
+ totalSize += topology.layers(i).weightSize
+ }
+ val weights = BDV.zeros[Double](totalSize)
var offset = 0
- for(i <- 0 until layers.length) {
- layerModels(i) = layers(i).getInstance(seed)
- offset += layerModels(i).size
+ val random = new XORShiftRandom(seed)
+ for (i <- 0 until layers.length) {
+ layerModels(i) = layers(i).
+ initModel(new BDV[Double](weights.data, offset, 1, layers(i).weightSize), random)
+ offset += layers(i).weightSize
}
- new FeedForwardModel(layerModels, topology)
+ new FeedForwardModel(Vectors.fromBreeze(weights), topology)
}
}
/**
* Neural network gradient. Does nothing but calling Model's gradient
+ *
* @param topology topology
* @param dataStacker data stacker
*/
@@ -682,7 +590,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext
weights: Vector,
cumGradient: Vector): Double = {
val (input, target, realBatchSize) = dataStacker.unstack(data)
- val model = topology.getInstance(weights)
+ val model = topology.model(weights)
model.computeGradient(input, target, cumGradient, realBatchSize)
}
}
@@ -692,6 +600,7 @@ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) ext
* through Optimizer/Gradient interfaces. If stackSize is more than one, makes blocks
* or matrices of inputs and outputs and then stack them in one vector.
* This can be used for further batch computations after unstacking.
+ *
* @param stackSize stack size
* @param inputSize size of the input vectors
* @param outputSize size of the output vectors
@@ -701,6 +610,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int)
/**
* Stacks the data
+ *
* @param data RDD of vector pairs
* @return RDD of double (always zero) and vector that contains the stacked vectors
*/
@@ -733,6 +643,7 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int)
/**
* Unstack the stacked vectors into matrices for batch operations
+ *
* @param data stacked vector
* @return pair of matrices holding input and output data and the real stack size
*/
@@ -765,6 +676,7 @@ private[ann] class ANNUpdater extends Updater {
/**
* MLlib-style trainer class that trains a network given the data and topology
+ *
* @param topology topology of ANN
* @param inputSize input size
* @param outputSize output size
@@ -774,8 +686,8 @@ private[ml] class FeedForwardTrainer(
val inputSize: Int,
val outputSize: Int) extends Serializable {
- // TODO: what if we need to pass random seed?
- private var _weights = topology.getInstance(11L).weights()
+ private var _seed = this.getClass.getName.hashCode.toLong
+ private var _weights: Vector = null
private var _stackSize = 128
private var dataStacker = new DataStacker(_stackSize, inputSize, outputSize)
private var _gradient: Gradient = new ANNGradient(topology, dataStacker)
@@ -783,27 +695,41 @@ private[ml] class FeedForwardTrainer(
private var optimizer: Optimizer = LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(100)
/**
+ * Returns seed
+ */
+ def getSeed: Long = _seed
+
+ /**
+ * Sets seed
+ */
+ def setSeed(value: Long): this.type = {
+ _seed = value
+ this
+ }
+
+ /**
* Returns weights
- * @return weights
*/
def getWeights: Vector = _weights
/**
* Sets weights
+ *
* @param value weights
* @return trainer
*/
- def setWeights(value: Vector): FeedForwardTrainer = {
+ def setWeights(value: Vector): this.type = {
_weights = value
this
}
/**
* Sets the stack size
+ *
* @param value stack size
* @return trainer
*/
- def setStackSize(value: Int): FeedForwardTrainer = {
+ def setStackSize(value: Int): this.type = {
_stackSize = value
dataStacker = new DataStacker(value, inputSize, outputSize)
this
@@ -811,6 +737,7 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the SGD optimizer
+ *
* @return SGD optimizer
*/
def SGDOptimizer: GradientDescent = {
@@ -821,6 +748,7 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the LBFGS optimizer
+ *
* @return LBGS optimizer
*/
def LBFGSOptimizer: LBFGS = {
@@ -831,10 +759,11 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the updater
+ *
* @param value updater
* @return trainer
*/
- def setUpdater(value: Updater): FeedForwardTrainer = {
+ def setUpdater(value: Updater): this.type = {
_updater = value
updateUpdater(value)
this
@@ -842,10 +771,11 @@ private[ml] class FeedForwardTrainer(
/**
* Sets the gradient
+ *
* @param value gradient
* @return trainer
*/
- def setGradient(value: Gradient): FeedForwardTrainer = {
+ def setGradient(value: Gradient): this.type = {
_gradient = value
updateGradient(value)
this
@@ -871,12 +801,20 @@ private[ml] class FeedForwardTrainer(
/**
* Trains the ANN
+ *
* @param data RDD of input and output vector pairs
* @return model
*/
def train(data: RDD[(Vector, Vector)]): TopologyModel = {
- val newWeights = optimizer.optimize(dataStacker.stack(data), getWeights)
- topology.getInstance(newWeights)
+ val w = if (getWeights == null) {
+ // TODO: will make a copy if vector is a subvector of BDV (see Vectors code)
+ topology.model(_seed).weights
+ } else {
+ getWeights
+ }
+ // TODO: deprecate standard optimizer because it needs Vector
+ val newWeights = optimizer.optimize(dataStacker.stack(data), w)
+ topology.model(newWeights)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala
new file mode 100644
index 0000000000..32d78e9b22
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/LossFunction.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import java.util.Random
+
+import breeze.linalg.{sum => Bsum, DenseMatrix => BDM, DenseVector => BDV}
+import breeze.numerics.{log => brzlog}
+
+/**
+ * Trait for loss function
+ */
+private[ann] trait LossFunction {
+ /**
+ * Returns the value of loss function.
+ * Computes loss based on target and output.
+ * Writes delta (error) to delta in place.
+ * Delta is allocated based on the outputSize
+ * of model implementation.
+ *
+ * @param output actual output
+ * @param target target output
+ * @param delta delta (updated in place)
+ * @return loss
+ */
+ def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double
+}
+
+private[ann] class SigmoidLayerWithSquaredError extends Layer {
+ override val weightSize = 0
+ override val inPlace = true
+
+ override def getOutputSize(inputSize: Int): Int = inputSize
+ override def createModel(weights: BDV[Double]): LayerModel =
+ new SigmoidLayerModelWithSquaredError()
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ new SigmoidLayerModelWithSquaredError()
+}
+
+private[ann] class SigmoidLayerModelWithSquaredError
+ extends FunctionalLayerModel(new FunctionalLayer(new SigmoidFunction)) with LossFunction {
+ override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = {
+ ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)
+ val error = Bsum(delta :* delta) / 2 / output.cols
+ ApplyInPlace(delta, output, delta, (x: Double, o: Double) => x * (o - o * o))
+ error
+ }
+}
+
+private[ann] class SoftmaxLayerWithCrossEntropyLoss extends Layer {
+ override val weightSize = 0
+ override val inPlace = true
+
+ override def getOutputSize(inputSize: Int): Int = inputSize
+ override def createModel(weights: BDV[Double]): LayerModel =
+ new SoftmaxLayerModelWithCrossEntropyLoss()
+ override def initModel(weights: BDV[Double], random: Random): LayerModel =
+ new SoftmaxLayerModelWithCrossEntropyLoss()
+}
+
+private[ann] class SoftmaxLayerModelWithCrossEntropyLoss extends LayerModel with LossFunction {
+
+ // loss layer models do not have weights
+ val weights = new BDV[Double](0)
+
+ override def eval(data: BDM[Double], output: BDM[Double]): Unit = {
+ var j = 0
+ // find max value to make sure later that exponent is computable
+ while (j < data.cols) {
+ var i = 0
+ var max = Double.MinValue
+ while (i < data.rows) {
+ if (data(i, j) > max) {
+ max = data(i, j)
+ }
+ i += 1
+ }
+ var sum = 0.0
+ i = 0
+ while (i < data.rows) {
+ val res = math.exp(data(i, j) - max)
+ output(i, j) = res
+ sum += res
+ i += 1
+ }
+ i = 0
+ while (i < data.rows) {
+ output(i, j) /= sum
+ i += 1
+ }
+ j += 1
+ }
+ }
+ override def computePrevDelta(
+ nextDelta: BDM[Double],
+ input: BDM[Double],
+ delta: BDM[Double]): Unit = {
+ /* loss layer model computes delta in loss function */
+ }
+
+ override def grad(delta: BDM[Double], input: BDM[Double], cumGrad: BDV[Double]): Unit = {
+ /* loss layer model does not have weights */
+ }
+
+ override def loss(output: BDM[Double], target: BDM[Double], delta: BDM[Double]): Double = {
+ ApplyInPlace(output, target, delta, (o: Double, t: Double) => o - t)
+ -Bsum( target :* brzlog(output)) / output.cols
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
index 7ce3ec68da..79bb2a8855 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifier.scala
@@ -24,8 +24,8 @@ import org.apache.hadoop.fs.Path
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.{PredictionModel, Predictor, PredictorParams}
import org.apache.spark.ml.ann.{FeedForwardTopology, FeedForwardTrainer}
-import org.apache.spark.ml.param.{IntArrayParam, IntParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasTol}
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasStepSize, HasTol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
@@ -33,11 +33,12 @@ import org.apache.spark.sql.DataFrame
/** Params for Multilayer Perceptron. */
private[ml] trait MultilayerPerceptronParams extends PredictorParams
- with HasSeed with HasMaxIter with HasTol {
+ with HasSeed with HasMaxIter with HasTol with HasStepSize {
/**
* Layer sizes including input size and output size.
* Default: Array(1, 1)
- * @group param
+ *
+ * @group param
*/
final val layers: IntArrayParam = new IntArrayParam(this, "layers",
"Sizes of layers from input layer to output layer" +
@@ -55,7 +56,8 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
* a partition then it is adjusted to the size of this data.
* Recommended size is between 10 and 1000.
* Default: 128
- * @group expertParam
+ *
+ * @group expertParam
*/
final val blockSize: IntParam = new IntParam(this, "blockSize",
"Block size for stacking input data in matrices. Data is stacked within partitions." +
@@ -66,7 +68,33 @@ private[ml] trait MultilayerPerceptronParams extends PredictorParams
/** @group getParam */
final def getBlockSize: Int = $(blockSize)
- setDefault(maxIter -> 100, tol -> 1e-4, layers -> Array(1, 1), blockSize -> 128)
+ /**
+ * Allows setting the solver: minibatch gradient descent (gd) or l-bfgs.
+ * l-bfgs is the default one.
+ *
+ * @group expertParam
+ */
+ final val solver: Param[String] = new Param[String](this, "solver",
+ " Allows setting the solver: minibatch gradient descent (gd) or l-bfgs. " +
+ " l-bfgs is the default one.",
+ ParamValidators.inArray[String](Array("gd", "l-bfgs")))
+
+ /** @group getParam */
+ final def getOptimizer: String = $(solver)
+
+ /**
+ * Model weights. Can be returned either after training or after explicit setting
+ *
+ * @group expertParam
+ */
+ final val weights: Param[Vector] = new Param[Vector](this, "weights",
+ " Sets the weights of the model ")
+
+ /** @group getParam */
+ final def getWeights: Vector = $(weights)
+
+
+ setDefault(maxIter -> 100, tol -> 1e-4, blockSize -> 128, solver -> "l-bfgs", stepSize -> 0.03)
}
/** Label to vector converter. */
@@ -105,6 +133,7 @@ private object LabelConverter {
* Each layer has sigmoid activation function, output layer has softmax.
* Number of inputs has to be equal to the size of feature vectors.
* Number of outputs has to be equal to the total number of labels.
+ *
*/
@Since("1.5.0")
@Experimental
@@ -127,7 +156,8 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
/**
* Set the maximum number of iterations.
* Default is 100.
- * @group setParam
+ *
+ * @group setParam
*/
@Since("1.5.0")
def setMaxIter(value: Int): this.type = set(maxIter, value)
@@ -136,18 +166,28 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
* Set the convergence tolerance of iterations.
* Smaller value will lead to higher accuracy with the cost of more iterations.
* Default is 1E-4.
- * @group setParam
+ *
+ * @group setParam
*/
@Since("1.5.0")
def setTol(value: Double): this.type = set(tol, value)
/**
- * Set the seed for weights initialization.
- * @group setParam
+ * Set the seed for weights initialization if weights are not set
+ *
+ * @group setParam
*/
@Since("1.5.0")
def setSeed(value: Long): this.type = set(seed, value)
+ /**
+ * Sets the model weights.
+ *
+ * @group expertParam
+ */
+ @Since("2.0.0")
+ def setWeights(value: Vector): this.type = set(weights, value)
+
@Since("1.5.0")
override def copy(extra: ParamMap): MultilayerPerceptronClassifier = defaultCopy(extra)
@@ -165,11 +205,18 @@ class MultilayerPerceptronClassifier @Since("1.5.0") (
val lpData = extractLabeledPoints(dataset)
val data = lpData.map(lp => LabelConverter.encodeLabeledPoint(lp, labels))
val topology = FeedForwardTopology.multiLayerPerceptron(myLayers, true)
- val FeedForwardTrainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last)
- FeedForwardTrainer.LBFGSOptimizer.setConvergenceTol($(tol)).setNumIterations($(maxIter))
- FeedForwardTrainer.setStackSize($(blockSize))
- val mlpModel = FeedForwardTrainer.train(data)
- new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights())
+ val trainer = new FeedForwardTrainer(topology, myLayers(0), myLayers.last)
+ if (isDefined(weights)) {
+ trainer.setWeights($(weights))
+ } else {
+ trainer.setSeed($(seed))
+ }
+ trainer.LBFGSOptimizer
+ .setConvergenceTol($(tol))
+ .setNumIterations($(maxIter))
+ trainer.setStackSize($(blockSize))
+ val mlpModel = trainer.train(data)
+ new MultilayerPerceptronClassificationModel(uid, myLayers, mlpModel.weights)
}
}
@@ -185,7 +232,8 @@ object MultilayerPerceptronClassifier
* :: Experimental ::
* Classification model based on the Multilayer Perceptron.
* Each layer has sigmoid activation function, output layer has softmax.
- * @param uid uid
+ *
+ * @param uid uid
* @param layers array of layer sizes including input and output layers
* @param weights vector of initial weights for the model that consists of the weights of layers
* @return prediction model
@@ -202,7 +250,7 @@ class MultilayerPerceptronClassificationModel private[ml] (
@Since("1.6.0")
override val numFeatures: Int = layers.head
- private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).getInstance(weights)
+ private val mlpModel = FeedForwardTopology.multiLayerPerceptron(layers, true).model(weights)
/**
* Returns layers in a Java List.
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
index d499d363f1..bc955f3cf6 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaMultilayerPerceptronClassifierSuite.java
@@ -63,7 +63,7 @@ public class JavaMultilayerPerceptronClassifierSuite implements Serializable {
MultilayerPerceptronClassifier mlpc = new MultilayerPerceptronClassifier()
.setLayers(new int[] {2, 5, 2})
.setBlockSize(1)
- .setSeed(11L)
+ .setSeed(123L)
.setMaxIter(100);
MultilayerPerceptronClassificationModel model = mlpc.fit(dataFrame);
Dataset<Row> result = model.transform(dataFrame);
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
index 1292e57d7c..dc91fc5f9e 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/ANNSuite.scala
@@ -42,7 +42,7 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
val dataSample = rddData.first()
val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
- val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val initialWeights = FeedForwardModel(topology, 23124).weights
val trainer = new FeedForwardTrainer(topology, 2, 1)
trainer.setWeights(initialWeights)
trainer.LBFGSOptimizer.setNumIterations(20)
@@ -76,10 +76,11 @@ class ANNSuite extends SparkFunSuite with MLlibTestSparkContext {
val dataSample = rddData.first()
val layerSizes = dataSample._1.size +: hiddenLayersTopology :+ dataSample._2.size
val topology = FeedForwardTopology.multiLayerPerceptron(layerSizes, false)
- val initialWeights = FeedForwardModel(topology, 23124).weights()
+ val initialWeights = FeedForwardModel(topology, 23124).weights
val trainer = new FeedForwardTrainer(topology, 2, 2)
- trainer.SGDOptimizer.setNumIterations(2000)
- trainer.setWeights(initialWeights)
+ // TODO: add a test for SGD
+ trainer.LBFGSOptimizer.setConvergenceTol(1e-4).setNumIterations(20)
+ trainer.setWeights(initialWeights).setStackSize(1)
val model = trainer.train(rddData)
val predictionAndLabels = rddData.map { case (input, label) =>
(model.predict(input), label)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala
new file mode 100644
index 0000000000..04cc426c40
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/ann/GradientSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.ann
+
+import breeze.linalg.{DenseMatrix => BDM}
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class GradientSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("Gradient computation against numerical differentiation") {
+ val input = new BDM[Double](3, 1, Array(1.0, 1.0, 1.0))
+ // output must contain zeros and one 1 for SoftMax
+ val target = new BDM[Double](2, 1, Array(0.0, 1.0))
+ val topology = FeedForwardTopology.multiLayerPerceptron(Array(3, 4, 2), softmaxOnTop = false)
+ val layersWithErrors = Seq(
+ new SigmoidLayerWithSquaredError(),
+ new SoftmaxLayerWithCrossEntropyLoss()
+ )
+ // check all layers that provide loss computation
+ // 1) compute loss and gradient given the model and initial weights
+ // 2) modify weights with small number epsilon (per dimension i)
+ // 3) compute new loss
+ // 4) ((newLoss - loss) / epsilon) should be close to the i-th component of the gradient
+ for (layerWithError <- layersWithErrors) {
+ topology.layers(topology.layers.length - 1) = layerWithError
+ val model = topology.model(seed = 12L)
+ val weights = model.weights.toArray
+ val numWeights = weights.size
+ val gradient = Vectors.dense(Array.fill[Double](numWeights)(0.0))
+ val loss = model.computeGradient(input, target, gradient, 1)
+ val eps = 1e-4
+ var i = 0
+ val tol = 1e-4
+ while (i < numWeights) {
+ val originalValue = weights(i)
+ weights(i) += eps
+ val newModel = topology.model(Vectors.dense(weights))
+ val newLoss = computeLoss(input, target, newModel)
+ val derivativeEstimate = (newLoss - loss) / eps
+ assert(math.abs(gradient(i) - derivativeEstimate) < tol, "Layer failed gradient check: " +
+ layerWithError.getClass)
+ weights(i) = originalValue
+ i += 1
+ }
+ }
+ }
+
+ private def computeLoss(input: BDM[Double], target: BDM[Double], model: TopologyModel): Double = {
+ val outputs = model.forward(input)
+ model.layerModels.last match {
+ case layerWithLoss: LossFunction =>
+ layerWithLoss.loss(outputs.last, target, new BDM[Double](target.rows, target.cols))
+ case _ =>
+ throw new UnsupportedOperationException("Top layer is required to have loss." +
+ " Failed layer:" + model.layerModels.last.getClass)
+ }
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
index 53c7a559e3..43781385db 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala
@@ -65,7 +65,7 @@ class MultilayerPerceptronClassifierSuite
val trainer = new MultilayerPerceptronClassifier()
.setLayers(layers)
.setBlockSize(1)
- .setSeed(11L)
+ .setSeed(123L)
.setMaxIter(100)
val model = trainer.fit(dataset)
val result = model.transform(dataset)
@@ -75,7 +75,29 @@ class MultilayerPerceptronClassifierSuite
}
}
- // TODO: implement a more rigorous test
+ test("Test setWeights by training restart") {
+ val dataFrame = sqlContext.createDataFrame(Seq(
+ (Vectors.dense(0.0, 0.0), 0.0),
+ (Vectors.dense(0.0, 1.0), 1.0),
+ (Vectors.dense(1.0, 0.0), 1.0),
+ (Vectors.dense(1.0, 1.0), 0.0))
+ ).toDF("features", "label")
+ val layers = Array[Int](2, 5, 2)
+ val trainer = new MultilayerPerceptronClassifier()
+ .setLayers(layers)
+ .setBlockSize(1)
+ .setSeed(12L)
+ .setMaxIter(1)
+ .setTol(1e-6)
+ val initialWeights = trainer.fit(dataFrame).weights
+ trainer.setWeights(initialWeights.copy)
+ val weights1 = trainer.fit(dataFrame).weights
+ trainer.setWeights(initialWeights.copy)
+ val weights2 = trainer.fit(dataFrame).weights
+ assert(weights1 ~== weights2 absTol 10e-5,
+ "Training should produce the same weights given equal initial weights and number of steps")
+ }
+
test("3 class classification with 2 hidden layers") {
val nPoints = 1000