diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala | 22 |
1 files changed, 13 insertions, 9 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index a27ee51874..0a569c4917 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -21,7 +21,9 @@ import java.util.Random import breeze.linalg.{*, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV} -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors} +import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.optimization._ import org.apache.spark.rdd.RDD import org.apache.spark.util.random.XORShiftRandom @@ -580,10 +582,10 @@ private[ann] object FeedForwardModel { */ private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient { override def compute( - data: Vector, + data: OldVector, label: Double, - weights: Vector, - cumGradient: Vector): Double = { + weights: OldVector, + cumGradient: OldVector): Double = { val (input, target, realBatchSize) = dataStacker.unstack(data) val model = topology.model(weights) model.computeGradient(input, target, cumGradient, realBatchSize) @@ -657,15 +659,15 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int) private[ann] class ANNUpdater extends Updater { override def compute( - weightsOld: Vector, - gradient: Vector, + weightsOld: OldVector, + gradient: OldVector, stepSize: Double, iter: Int, - regParam: Double): (Vector, Double) = { + regParam: Double): (OldVector, Double) = { val thisIterStepSize = stepSize val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights) - (Vectors.fromBreeze(brzWeights), 0) + (OldVectors.fromBreeze(brzWeights), 0) } } @@ -808,7 +810,9 @@ private[ml] class FeedForwardTrainer( getWeights } // TODO: deprecate standard optimizer because it needs Vector - val newWeights = optimizer.optimize(dataStacker.stack(data), w) + val newWeights = optimizer.optimize(dataStacker.stack(data).map { v => + (v._1, OldVectors.fromML(v._2)) + }, w) topology.model(newWeights) } |