From 462784ffad77e43455dd0364064ce4994826a426 Mon Sep 17 00:00:00 2001 From: WeichenXu Date: Thu, 4 Aug 2016 21:41:35 +0100 Subject: [SPARK-16880][ML][MLLIB] make ann training data persisted if needed ## What changes were proposed in this pull request? To Make sure ANN layer input training data to be persisted, so that it can avoid overhead cost if the RDD need to be computed from lineage. ## How was this patch tested? Existing Tests. Author: WeichenXu Closes #14483 from WeichenXu123/add_ann_persist_training_data. --- mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala index 576584c627..88909a9fb9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala @@ -26,6 +26,7 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors import org.apache.spark.mllib.linalg.VectorImplicits._ import org.apache.spark.mllib.optimization._ import org.apache.spark.rdd.RDD +import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.XORShiftRandom /** @@ -810,9 +811,13 @@ private[ml] class FeedForwardTrainer( getWeights } // TODO: deprecate standard optimizer because it needs Vector - val newWeights = optimizer.optimize(dataStacker.stack(data).map { v => + val trainData = dataStacker.stack(data).map { v => (v._1, OldVectors.fromML(v._2)) - }, w) + } + val handlePersistence = trainData.getStorageLevel == StorageLevel.NONE + if (handlePersistence) trainData.persist(StorageLevel.MEMORY_AND_DISK) + val newWeights = optimizer.optimize(trainData, w) + if (handlePersistence) trainData.unpersist() topology.model(newWeights) } -- cgit v1.2.3