aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala2
1 files changed, 2 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index b139bc8dcb..b07a68269c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -108,6 +108,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
// multi-model training
logDebug(s"Train split $splitIndex with multiple sets of parameters.")
val models = est.fit(trainingDataset, epm).asInstanceOf[Seq[Model[_]]]
+ trainingDataset.unpersist()
var i = 0
while (i < numModels) {
val metric = eval.evaluate(models(i).transform(validationDataset, epm(i)), map)
@@ -115,6 +116,7 @@ class CrossValidator extends Estimator[CrossValidatorModel] with CrossValidatorP
metrics(i) += metric
i += 1
}
+ validationDataset.unpersist()
}
f2jBLAS.dscal(numModels, 1.0 / map(numFolds), metrics, 1)
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")