diff options
author | WeichenXu <WeichenXu123@outlook.com> | 2016-08-04 21:41:35 +0100 |
---|---|---|
committer | Sean Owen <sowen@cloudera.com> | 2016-08-04 21:41:35 +0100 |
commit | 462784ffad77e43455dd0364064ce4994826a426 (patch) | |
tree | 7a0e5b1b7c9cad629f2a7e68374e2ecb9ea0fb68 /mllib/src/main/scala | |
parent | be8ea4b2f7ddf1196111acb61fe1a79866376003 (diff) | |
download | spark-462784ffad77e43455dd0364064ce4994826a426.tar.gz spark-462784ffad77e43455dd0364064ce4994826a426.tar.bz2 spark-462784ffad77e43455dd0364064ce4994826a426.zip |
[SPARK-16880][ML][MLLIB] make ann training data persisted if needed
## What changes were proposed in this pull request?
Make sure the ANN layer's input training data is persisted,
so that we avoid the overhead of recomputing the RDD from its lineage during iterative optimization.
## How was this patch tested?
Existing Tests.
Author: WeichenXu <WeichenXu123@outlook.com>
Closes #14483 from WeichenXu123/add_ann_persist_training_data.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala | 9 |
1 file changed, 7 insertions(+), 2 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index 576584c627..88909a9fb9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors
 import org.apache.spark.mllib.linalg.VectorImplicits._
 import org.apache.spark.mllib.optimization._
 import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.random.XORShiftRandom
 
 /**
@@ -810,9 +811,13 @@ private[ml] class FeedForwardTrainer(
       getWeights
     }
     // TODO: deprecate standard optimizer because it needs Vector
-    val newWeights = optimizer.optimize(dataStacker.stack(data).map { v =>
+    val trainData = dataStacker.stack(data).map { v =>
       (v._1, OldVectors.fromML(v._2))
-    }, w)
+    }
+    val handlePersistence = trainData.getStorageLevel == StorageLevel.NONE
+    if (handlePersistence) trainData.persist(StorageLevel.MEMORY_AND_DISK)
+    val newWeights = optimizer.optimize(trainData, w)
+    if (handlePersistence) trainData.unpersist()
     topology.model(newWeights)
   }