diff options
author | Peter Rudenko <petro.rudenko@gmail.com> | 2015-02-16 00:07:23 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-02-16 00:07:23 -0800 |
commit | d51d6ba1547ae75ac76c9e6d8ea99e937eb7d09f (patch) | |
tree | 5c33800c77fad824ad451a988cf4f8ee706dbb43 | |
parent | c78a12c4cc4d4312c4ee1069d3b218882d32d678 (diff) | |
download | spark-d51d6ba1547ae75ac76c9e6d8ea99e937eb7d09f.tar.gz spark-d51d6ba1547ae75ac76c9e6d8ea99e937eb7d09f.tar.bz2 spark-d51d6ba1547ae75ac76c9e6d8ea99e937eb7d09f.zip |
[Ml] SPARK-5804 Explicitly manage cache in Crossvalidator k-fold loop
On a big dataset explicitly unpersist train and validation folds allows to load more data into memory in the next loop iteration. On my environment (single node 8Gb worker RAM, 2 GB dataset file, 3 folds for cross validation), saved more than 5 minutes.
Author: Peter Rudenko <petro.rudenko@gmail.com>
Closes #4595 from petro-rudenko/patch-2 and squashes the following commits:
66a7cfb [Peter Rudenko] Move validationDataset cache to declaration
c5f3265 [Peter Rudenko] [Ml] SPARK-5804 Explicitly manage cache in Crossvalidator k-fold loop
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 2 |
1 files changed, 2 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index b139bc8dcb..b07a68269c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -108,6 +108,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP // multi-model training logDebug(s"Train split $splitIndex with multiple sets of parameters.") val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]] + trainingDataset.unpersist() var i = 0 while (i < numModels) { val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map) @@ -115,6 +116,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP metrics(i) += metric i += 1 } + validationDataset.unpersist() } f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") |