aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala9
1 files changed, 7 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index 576584c627..88909a9fb9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.optimization._
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom
/**
@@ -810,9 +811,13 @@ private[ml] class FeedForwardTrainer(
getWeights
}
// TODO: deprecate standard optimizer because it needs Vector
- val newWeights = optimizer.optimize(dataStacker.stack(data).map { v =>
+ val trainData = dataStacker.stack(data).map { v =>
(v._1, OldVectors.fromML(v._2))
- }, w)
+ }
+ val handlePersistence = trainData.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) trainData.persist(StorageLevel.MEMORY_AND_DISK)
+ val newWeights = optimizer.optimize(trainData, w)
+ if (handlePersistence) trainData.unpersist()
topology.model(newWeights)
}