aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorPeter Rudenko <petro.rudenko@gmail.com>2015-02-16 00:07:23 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-16 00:07:23 -0800
commitd51d6ba1547ae75ac76c9e6d8ea99e937eb7d09f (patch)
tree5c33800c77fad824ad451a988cf4f8ee706dbb43 /mllib
parentc78a12c4cc4d4312c4ee1069d3b218882d32d678 (diff)
downloadspark-d51d6ba1547ae75ac76c9e6d8ea99e937eb7d09f.tar.gz
spark-d51d6ba1547ae75ac76c9e6d8ea99e937eb7d09f.tar.bz2
spark-d51d6ba1547ae75ac76c9e6d8ea99e937eb7d09f.zip
[Ml] SPARK-5804 Explicitly manage cache in Crossvalidator k-fold loop
On a big dataset explicitly unpersist train and validation folds allows to load more data into memory in the next loop iteration. On my environment (single node 8Gb worker RAM, 2 GB dataset file, 3 folds for cross validation), saved more than 5 minutes. Author: Peter Rudenko <petro.rudenko@gmail.com> Closes #4595 from petro-rudenko/patch-2 and squashes the following commits: 66a7cfb [Peter Rudenko] Move validationDataset cache to declaration c5f3265 [Peter Rudenko] [Ml] SPARK-5804 Explicitly manage cache in Crossvalidator k-fold loop
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala2
1 files changed, 2 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index b139bc8dcb..b07a68269c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -108,6 +108,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+ trainingDataset.unpersist()
var i = 0
while (i < numModels) {
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map)
@@ -115,6 +116,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
metrics(i) += metric
i += 1
}
+ validationDataset.unpersist()
}
f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1)
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")