aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorWeichenXu <WeichenXu123@outlook.com>2016-08-04 21:41:35 +0100
committerSean Owen <sowen@cloudera.com>2016-08-04 21:41:35 +0100
commit462784ffad77e43455dd0364064ce4994826a426 (patch)
tree7a0e5b1b7c9cad629f2a7e68374e2ecb9ea0fb68
parentbe8ea4b2f7ddf1196111acb61fe1a79866376003 (diff)
downloadspark-462784ffad77e43455dd0364064ce4994826a426.tar.gz
spark-462784ffad77e43455dd0364064ce4994826a426.tar.bz2
spark-462784ffad77e43455dd0364064ce4994826a426.zip
[SPARK-16880][ML][MLLIB] make ann training data persisted if needed
## What changes were proposed in this pull request? To Make sure ANN layer input training data to be persisted, so that it can avoid overhead cost if the RDD need to be computed from lineage. ## How was this patch tested? Existing Tests. Author: WeichenXu <WeichenXu123@outlook.com> Closes #14483 from WeichenXu123/add_ann_persist_training_data.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala9
1 files changed, 7 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
index 576584c627..88909a9fb9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/ann/Layer.scala
@@ -26,6 +26,7 @@ import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors
import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.mllib.optimization._
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom
/**
@@ -810,9 +811,13 @@ private[ml] class FeedForwardTrainer(
getWeights
}
// TODO: deprecate standard optimizer because it needs Vector
- val newWeights = optimizer.optimize(dataStacker.stack(data).map { v =>
+ val trainData = dataStacker.stack(data).map { v =>
(v._1, OldVectors.fromML(v._2))
- }, w)
+ }
+ val handlePersistence = trainData.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) trainData.persist(StorageLevel.MEMORY_AND_DISK)
+ val newWeights = optimizer.optimize(trainData, w)
+ if (handlePersistence) trainData.unpersist()
topology.model(newWeights)
}