From 2e2a6211c4391d67edb2a252f26647fb059bc18b Mon Sep 17 00:00:00 2001 From: yinxusen Date: Tue, 3 May 2016 14:19:13 -0700 Subject: [SPARK-14973][ML] The CrossValidator and TrainValidationSplit miss the seed when saving and loading ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14973 Add seed support when saving/loading of CrossValidator and TrainValidationSplit. ## How was this patch tested? Spark unit test. Author: yinxusen Closes #12825 from yinxusen/SPARK-14973. --- .../org/apache/spark/ml/tuning/CrossValidator.scala | 17 ++++++++++------- .../apache/spark/ml/tuning/TrainValidationSplit.scala | 17 ++++++++++------- .../org/apache/spark/ml/tuning/ValidatorParams.scala | 9 +++++---- .../apache/spark/ml/tuning/CrossValidatorSuite.scala | 3 +++ .../spark/ml/tuning/TrainValidationSplitSuite.scala | 2 ++ 5 files changed, 30 insertions(+), 18 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index a41d02cde7..7d42da4a2f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -30,7 +30,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml._ import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.sql.{DataFrame, Dataset} @@ -39,7 +38,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[CrossValidator]] and [[CrossValidatorModel]]. */ -private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed { +private[ml] trait CrossValidatorParams extends ValidatorParams { /** * Param for number of folds for cross validation. Must be >= 2. * Default: 3 @@ -179,11 +178,13 @@ object CrossValidator extends MLReadable[CrossValidator] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val numFolds = (metadata.params \ "numFolds").extract[Int] + val seed = (metadata.params \ "seed").extract[Long] new CrossValidator(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) .setNumFolds(numFolds) + .setSeed(seed) } } } @@ -267,14 +268,16 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val numFolds = (metadata.params \ "numFolds").extract[Int] + val seed = (metadata.params \ "seed").extract[Long] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray - val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) - cv.set(cv.estimator, estimator) - .set(cv.evaluator, evaluator) - .set(cv.estimatorParamMaps, estimatorParamMaps) - .set(cv.numFolds, numFolds) + val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics) + model.set(model.estimator, estimator) + .set(model.evaluator, evaluator) + .set(model.estimatorParamMaps, estimatorParamMaps) + .set(model.numFolds, numFolds) + .set(model.seed, seed) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index f2b7badbe5..f6f2bad401 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -30,7 +30,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators} -import org.apache.spark.ml.param.shared.HasSeed import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.types.StructType @@ -38,7 +37,7 @@ import org.apache.spark.sql.types.StructType /** * Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]]. */ -private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed { +private[ml] trait TrainValidationSplitParams extends ValidatorParams { /** * Param for ratio between train and validation data. Must be between 0 and 1. * Default: 0.75 @@ -177,11 +176,13 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val trainRatio = (metadata.params \ "trainRatio").extract[Double] + val seed = (metadata.params \ "seed").extract[Long] new TrainValidationSplit(metadata.uid) .setEstimator(estimator) .setEvaluator(evaluator) .setEstimatorParamMaps(estimatorParamMaps) .setTrainRatio(trainRatio) + .setSeed(seed) } } } @@ -265,14 +266,16 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] { val (metadata, estimator, evaluator, estimatorParamMaps) = ValidatorParams.loadImpl(path, sc, className) val trainRatio = (metadata.params \ "trainRatio").extract[Double] + val seed = (metadata.params \ "seed").extract[Long] val bestModelPath = new Path(path, "bestModel").toString val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc) val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray - val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) - tvs.set(tvs.estimator, estimator) - .set(tvs.evaluator, evaluator) - .set(tvs.estimatorParamMaps, estimatorParamMaps) - .set(tvs.trainRatio, trainRatio) + val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics) + model.set(model.estimator, estimator) + .set(model.evaluator, evaluator) + .set(model.estimatorParamMaps, estimatorParamMaps) + .set(model.trainRatio, trainRatio) + .set(model.seed, seed) } } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala index 7a4e106aeb..26fd73814d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala @@ -25,15 +25,15 @@ import org.apache.spark.SparkContext import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.evaluation.Evaluator import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} -import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, - MLWritable} +import org.apache.spark.ml.param.shared.HasSeed +import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, MLWritable} import org.apache.spark.ml.util.DefaultParamsReader.Metadata import org.apache.spark.sql.types.StructType /** * Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]]. */ -private[ml] trait ValidatorParams extends Params { +private[ml] trait ValidatorParams extends HasSeed with Params { /** * param for the estimator to be validated @@ -137,7 +137,8 @@ private[ml] object ValidatorParams { } val jsonParams = validatorSpecificParams ++ List( - "estimatorParamMaps" -> parse(estimatorParamMapsJson)) + "estimatorParamMaps" -> parse(estimatorParamMapsJson), + "seed" -> parse(instance.seed.jsonEncode(instance.getSeed))) DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams)) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 3e734aabc5..061d04c932 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -136,6 +136,7 @@ class CrossValidatorSuite assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] @@ -186,6 +187,7 @@ class CrossValidatorSuite assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) @@ -259,6 +261,7 @@ class CrossValidatorSuite assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index dbee47c847..df9ba418b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -127,6 +127,7 @@ class TrainValidationSplitSuite val tvs2 = testDefaultReadWrite(tvs, testParams = false) assert(tvs.getTrainRatio === tvs2.getTrainRatio) + assert(tvs.getSeed === tvs2.getSeed) } test("read/write: TrainValidationSplitModel") { @@ -149,6 +150,7 @@ class TrainValidationSplitSuite assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.validationMetrics === tvs2.validationMetrics) + assert(tvs.getSeed === tvs2.getSeed) } } -- cgit v1.2.3