Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala | 22
1 file changed, 13 insertions(+), 9 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index a27ee51874..0a569c4917 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -21,7 +21,9 @@ import java.util.Random
 
 import breeze.linalg.{*, axpy => Baxpy, DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
 
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
+import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.util.random.XORShiftRandom
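
The import hunk above is the heart of the change: the new org.apache.spark.ml.linalg types take over the plain Vector/Vectors names, the legacy mllib.linalg types stay reachable through the OldVector/OldVectors aliases, and VectorImplicits._ brings Spark-internal implicit conversions between the two families into scope. Below is a minimal standalone sketch of the same aliasing-and-bridging pattern, assuming Spark 2.0+ on the classpath; AliasDemo is a hypothetical object, not part of Layer.scala.

// Hypothetical demo of the renamed-import pattern used in Layer.scala.
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}

object AliasDemo {
  def main(args: Array[String]): Unit = {
    val v: Vector = Vectors.dense(1.0, 2.0)      // new ml.linalg vector
    val old: OldVector = OldVectors.fromML(v)    // explicit bridge to the legacy type
    val back: Vector = old.asML                  // and back again
    println(s"$v -> $old -> $back")
  }
}
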
@@ -580,10 +582,10 @@ private[ann] object FeedForwardModel {
  */
 private[ann] class ANNGradient(topology: Topology, dataStacker: DataStacker) extends Gradient {
   override def compute(
-      data: Vector,
+      data: OldVector,
       label: Double,
-      weights: Vector,
-      cumGradient: Vector): Double = {
+      weights: OldVector,
+      cumGradient: OldVector): Double = {
     val (input, target, realBatchSize) = dataStacker.unstack(data)
     val model = topology.model(weights)
     model.computeGradient(input, target, cumGradient, realBatchSize)
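
ANNGradient still implements the legacy mllib.optimization.Gradient contract, only with aliased parameter types: the four-argument compute must add this example's gradient into cumGradient in place and return the loss, which Layer.scala delegates to model.computeGradient after unstacking the mini-batch. The toy implementation below illustrates that contract; SquaredErrorGradient is hypothetical, and it assumes cumGradient arrives dense, which is how mllib's GradientDescent supplies it.

// Hypothetical toy Gradient: accumulate into cumGradient IN PLACE, return the loss.
import org.apache.spark.mllib.linalg.{DenseVector, Vector => OldVector}
import org.apache.spark.mllib.optimization.Gradient

class SquaredErrorGradient extends Gradient {
  override def compute(
      data: OldVector,
      label: Double,
      weights: OldVector,
      cumGradient: OldVector): Double = {
    // Linear prediction w . x over the active entries of x.
    var pred = 0.0
    data.foreachActive { (i, x) => pred += weights(i) * x }
    val diff = pred - label
    // Accumulate diff * x into the shared gradient buffer (assumed dense).
    val buf = cumGradient.asInstanceOf[DenseVector].values
    data.foreachActive { (i, x) => buf(i) += diff * x }
    0.5 * diff * diff
  }
}
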
@@ -657,15 +659,15 @@ private[ann] class DataStacker(stackSize: Int, inputSize: Int, outputSize: Int)
 
 private[ann] class ANNUpdater extends Updater {
 
   override def compute(
-      weightsOld: Vector,
-      gradient: Vector,
+      weightsOld: OldVector,
+      gradient: OldVector,
       stepSize: Double,
       iter: Int,
-      regParam: Double): (Vector, Double) = {
+      regParam: Double): (OldVector, Double) = {
     val thisIterStepSize = stepSize
     val brzWeights: BV[Double] = weightsOld.toBreeze.toDenseVector
     Baxpy(-thisIterStepSize, gradient.toBreeze, brzWeights)
-    (Vectors.fromBreeze(brzWeights), 0)
+    (OldVectors.fromBreeze(brzWeights), 0)
   }
 }
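
ANNUpdater is a plain gradient-descent step with no regularization, w := w - stepSize * gradient, done in place through Breeze's axpy; that is why the second element of the returned tuple, the regularization value, is 0. A Breeze-only sketch of the same step (SgdStepDemo is a hypothetical name):

// Hypothetical demo of the update rule in ANNUpdater.compute.
import breeze.linalg.{axpy, DenseVector => BDV}

object SgdStepDemo {
  def step(weights: BDV[Double], gradient: BDV[Double], stepSize: Double): BDV[Double] = {
    val w = weights.copy               // leave the caller's vector untouched
    axpy(-stepSize, gradient, w)       // in place: w += -stepSize * gradient
    w
  }

  def main(args: Array[String]): Unit = {
    println(step(BDV(1.0, 2.0, 3.0), BDV(0.5, 0.5, 0.5), stepSize = 0.1))
    // DenseVector(0.95, 1.95, 2.95)
  }
}
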
@@ -808,7 +810,9 @@ private[ml] class FeedForwardTrainer(
       getWeights
     }
     // TODO: deprecate standard optimizer because it needs Vector
-    val newWeights = optimizer.optimize(dataStacker.stack(data), w)
+    val newWeights = optimizer.optimize(dataStacker.stack(data).map { v =>
+      (v._1, OldVectors.fromML(v._2))
+    }, w)
     topology.model(newWeights)
   }
 
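
The last hunk is the actual bridge: dataStacker.stack(data) now produces new-style ml.linalg vectors, but the standard optimizer still consumes the legacy type, so each record is converted with OldVectors.fromML before optimize runs; the TODO marks this per-record copy as the cost to be removed by deprecating the Vector-based optimizer API. A standalone sketch of the same conversion, assuming Spark 2.0+ (ConvertDemo and the session setup are hypothetical):

// Hypothetical demo of bridging an RDD of new vectors to the legacy type.
import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object ConvertDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[2]").appName("ConvertDemo").getOrCreate()
    val stacked: RDD[(Double, Vector)] =
      spark.sparkContext.parallelize(Seq((0.0, Vectors.dense(0.1, 0.2))))
    // Same shape as the change in FeedForwardTrainer.train above.
    val legacy: RDD[(Double, OldVector)] =
      stacked.map { case (label, v) => (label, OldVectors.fromML(v)) }
    legacy.collect().foreach(println)
    spark.stop()
  }
}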