aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author leahmcguire <lmcguire@salesforce.com> 2015-06-03 15:46:38 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2015-06-03 15:46:38 -0700
commit d8662cd909a41575df6e0ea1630d2386d3711240 (patch)
tree 6b30157f11fb19753fb3677b12630650d6b8a699 /mllib
parent 26c9d7a0f975009e22ec91e5c0b5cfcada79b35e (diff)
download spark-d8662cd909a41575df6e0ea1630d2386d3711240.tar.gz
spark-d8662cd909a41575df6e0ea1630d2386d3711240.tar.bz2
spark-d8662cd909a41575df6e0ea1630d2386d3711240.zip
[SPARK-6164] [ML] CrossValidatorModel should keep stats from fitting
Added stats from cross validation as a val in the cross validation model to save them for user access. Author: leahmcguire <lmcguire@salesforce.com> Closes #5915 from leahmcguire/saveCVmetrics and squashes the following commits: 49b507b [leahmcguire] fixed style error 67537b1 [leahmcguire] rebased 85907f0 [leahmcguire] fixed name 59987cc [leahmcguire] changed param name and test according to comments 36e71e3 [leahmcguire] rebasing 4b8223e [leahmcguire] fixed name 4ddffc6 [leahmcguire] changed param name and test according to comments 3a995da [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 10
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala 1
2 files changed, 8 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 6434b64aed..cb29392e8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
- copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
+ copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
}
override def transformSchema(schema: StructType): StructType = {
@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
@Experimental
class CrossValidatorModel private[ml] (
override val uid: String,
- val bestModel: Model[_])
+ val bestModel: Model[_],
+ val avgMetrics: Array[Double])
extends Model[CrossValidatorModel] with CrossValidatorParams {
override def validateParams(): Unit = {
@@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] (
}
override def copy(extra: ParamMap): CrossValidatorModel = {
- val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
+ val copied = new CrossValidatorModel(
+ uid,
+ bestModel.copy(extra).asInstanceOf[Model[_]],
+ avgMetrics.clone())
copyValues(copied, extra)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 5ba469c7b1..9b3619f004 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(parent.getRegParam === 0.001)
assert(parent.getMaxIter === 10)
+ assert(cvModel.avgMetrics.length === lrParamMaps.length)
}
test("validateParams should check estimatorParamMaps") {