diff options
author | leahmcguire <lmcguire@salesforce.com> | 2015-06-03 15:46:38 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-06-03 15:46:38 -0700 |
commit | d8662cd909a41575df6e0ea1630d2386d3711240 (patch) | |
tree | 6b30157f11fb19753fb3677b12630650d6b8a699 /mllib | |
parent | 26c9d7a0f975009e22ec91e5c0b5cfcada79b35e (diff) | |
download | spark-d8662cd909a41575df6e0ea1630d2386d3711240.tar.gz spark-d8662cd909a41575df6e0ea1630d2386d3711240.tar.bz2 spark-d8662cd909a41575df6e0ea1630d2386d3711240.zip |
[SPARK-6164] [ML] CrossValidatorModel should keep stats from fitting
Added stats from cross validation as a val in the cross validation model to save them for user access.
Author: leahmcguire <lmcguire@salesforce.com>
Closes #5915 from leahmcguire/saveCVmetrics and squashes the following commits:
49b507b [leahmcguire] fixed style error
67537b1 [leahmcguire] rebased
85907f0 [leahmcguire] fixed name
59987cc [leahmcguire] changed param name and test according to comments
36e71e3 [leahmcguire] rebasing
4b8223e [leahmcguire] fixed name
4ddffc6 [leahmcguire] changed param name and test according to comments
3a995da [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala | 10 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 1 |
2 files changed, 8 insertions, 3 deletions
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 6434b64aed..cb29392e8b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -135,7 +135,7 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
     logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
     logInfo(s"Best cross-validation metric: $bestMetric.")
     val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
-    copyValues(new CrossValidatorModel(uid, bestModel).setParent(this))
+    copyValues(new CrossValidatorModel(uid, bestModel, metrics).setParent(this))
   }
 
   override def transformSchema(schema: StructType): StructType = {
@@ -158,7 +158,8 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
 @Experimental
 class CrossValidatorModel private[ml] (
     override val uid: String,
-    val bestModel: Model[_])
+    val bestModel: Model[_],
+    val avgMetrics: Array[Double])
   extends Model[CrossValidatorModel] with CrossValidatorParams {
 
   override def validateParams(): Unit = {
@@ -175,7 +176,10 @@ class CrossValidatorModel private[ml] (
   }
 
   override def copy(extra: ParamMap): CrossValidatorModel = {
-    val copied = new CrossValidatorModel(uid, bestModel.copy(extra).asInstanceOf[Model[_]])
+    val copied = new CrossValidatorModel(
+      uid,
+      bestModel.copy(extra).asInstanceOf[Model[_]],
+      avgMetrics.clone())
     copyValues(copied, extra)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 5ba469c7b1..9b3619f004 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -56,6 +56,7 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
     val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
     assert(parent.getRegParam === 0.001)
     assert(parent.getMaxIter === 10)
+    assert(cvModel.avgMetrics.length === lrParamMaps.length)
   }
 
   test("validateParams should check estimatorParamMaps") {
```