aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala10
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala1
2 files changed, 8 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 6434b64aed..cb29392e8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
- copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
+ copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
- val bestModel: Model[_])
+ val bestModel: Model[_],
+ val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams {
override def validateParams(): Unit = {
@@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] (
}
override def copy(extra: ParamMap): CrossValidatorModel = {
- val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
+ val copied = new CrossValidatorModel(
+ uid,
+ bestModel.copy(extra).asInstanceOf[Model[_]],
+ avgMetrics.clone())
copyValues(copied, extra)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 5ba469c7b1..9b3619f004 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
+ assert(cvModel.avgMetrics.length === lrParamMaps.length)
}
test("validateParams should check estimatorParamMaps") {