diff options
author | Peter Rudenko <petro.rudenko@gmail.com> | 2015-02-16 00:07:23 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-02-16 00:07:31 -0800 |
commit | 0d932058ed95c2b65dc308fd523cfea6d9b29b16 (patch) | |
tree | 1e57423c8d7c150c41d19b4c0f60b3455bac87aa /mllib | |
parent | 9cf7d7088d245b9b41ec78295cd2d6e3e395793d (diff) | |
download | spark-0d932058ed95c2b65dc308fd523cfea6d9b29b16.tar.gz spark-0d932058ed95c2b65dc308fd523cfea6d9b29b16.tar.bz2 spark-0d932058ed95c2b65dc308fd523cfea6d9b29b16.zip |
[Ml] SPARK-5804 Explicitly manage cache in Crossvalidator k-fold loop
On a big dataset explicitly unpersist train and validation folds allows to load more data into memory in the next loop iteration. On my environment (single node 8Gb worker RAM, 2 GB dataset file, 3 folds for cross validation), saved more than 5 minutes.
Author: Peter Rudenko <petro.rudenko@gmail.com>
Closes #4595 from petro-rudenko/patch-2 and squashes the following commits:
66a7cfb [Peter Rudenko] Move validationDataset cache to declaration
c5f3265 [Peter Rudenko] [Ml] SPARK-5804 Explicitly manage cache in Crossvalidator k-fold loop
(cherry picked from commit d51d6ba1547ae75ac76c9e6d8ea99e937eb7d09f)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index b139bc8dcb..b07a68269c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -108,6 +108,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + trainingDataset.unpersist() var i = 0 while (i < numModels) { val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) @@ -115,6 +116,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP metrics(i) += metric i += 1 } + validationDataset.unpersist() } f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") |