aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authoryinxusen <yinxusen@gmail.com>2016-05-03 14:19:13 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-05-03 14:19:13 -0700
commit2e2a6211c4391d67edb2a252f26647fb059bc18b (patch)
tree366ab0e2c2c9a073f5c39a42076540369b2e897d /mllib
parentd6c7b2a5cc11a82e5137ee86350550e06e81f609 (diff)
downloadspark-2e2a6211c4391d67edb2a252f26647fb059bc18b.tar.gz
spark-2e2a6211c4391d67edb2a252f26647fb059bc18b.tar.bz2
spark-2e2a6211c4391d67edb2a252f26647fb059bc18b.zip
[SPARK-14973][ML] The CrossValidator and TrainValidationSplit miss the seed when saving and loading
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14973 Add seed support when saving/loading CrossValidator and TrainValidationSplit. ## How was this patch tested? Spark unit tests. Author: yinxusen <yinxusen@gmail.com> Closes #12825 from yinxusen/SPARK-14973.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala17
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala17
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala9
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala2
5 files changed, 30 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index a41d02cde7..7d42da4a2f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -30,7 +30,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml._
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.{DataFrame, Dataset}
@@ -39,7 +38,7 @@ import org.apache.spark.sql.types.StructType
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
-private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
+private[ml] trait CrossValidatorParams extends ValidatorParams {
/**
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
@@ -179,11 +178,13 @@ object CrossValidator extends MLReadable[CrossValidator] {
val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
+ val seed = (metadata.params \ "seed").extract[Long]
new CrossValidator(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setNumFolds(numFolds)
+ .setSeed(seed)
}
}
}
@@ -267,14 +268,16 @@ object CrossValidatorModel extends MLReadable[CrossValidatorModel] {
val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val numFolds = (metadata.params \ "numFolds").extract[Int]
+ val seed = (metadata.params \ "seed").extract[Long]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val avgMetrics = (metadata.metadata \ "avgMetrics").extract[Seq[Double]].toArray
- val cv = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
- cv.set(cv.estimator, estimator)
- .set(cv.evaluator, evaluator)
- .set(cv.estimatorParamMaps, estimatorParamMaps)
- .set(cv.numFolds, numFolds)
+ val model = new CrossValidatorModel(metadata.uid, bestModel, avgMetrics)
+ model.set(model.estimator, estimator)
+ .set(model.evaluator, evaluator)
+ .set(model.estimatorParamMaps, estimatorParamMaps)
+ .set(model.numFolds, numFolds)
+ .set(model.seed, seed)
}
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index f2b7badbe5..f6f2bad401 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -30,7 +30,6 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
-import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.types.StructType
@@ -38,7 +37,7 @@ import org.apache.spark.sql.types.StructType
/**
* Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
*/
-private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed {
+private[ml] trait TrainValidationSplitParams extends ValidatorParams {
/**
* Param for ratio between train and validation data. Must be between 0 and 1.
* Default: 0.75
@@ -177,11 +176,13 @@ object TrainValidationSplit extends MLReadable[TrainValidationSplit] {
val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+ val seed = (metadata.params \ "seed").extract[Long]
new TrainValidationSplit(metadata.uid)
.setEstimator(estimator)
.setEvaluator(evaluator)
.setEstimatorParamMaps(estimatorParamMaps)
.setTrainRatio(trainRatio)
+ .setSeed(seed)
}
}
}
@@ -265,14 +266,16 @@ object TrainValidationSplitModel extends MLReadable[TrainValidationSplitModel] {
val (metadata, estimator, evaluator, estimatorParamMaps) =
ValidatorParams.loadImpl(path, sc, className)
val trainRatio = (metadata.params \ "trainRatio").extract[Double]
+ val seed = (metadata.params \ "seed").extract[Long]
val bestModelPath = new Path(path, "bestModel").toString
val bestModel = DefaultParamsReader.loadParamsInstance[Model[_]](bestModelPath, sc)
val validationMetrics = (metadata.metadata \ "validationMetrics").extract[Seq[Double]].toArray
- val tvs = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
- tvs.set(tvs.estimator, estimator)
- .set(tvs.evaluator, evaluator)
- .set(tvs.estimatorParamMaps, estimatorParamMaps)
- .set(tvs.trainRatio, trainRatio)
+ val model = new TrainValidationSplitModel(metadata.uid, bestModel, validationMetrics)
+ model.set(model.estimator, estimator)
+ .set(model.evaluator, evaluator)
+ .set(model.estimatorParamMaps, estimatorParamMaps)
+ .set(model.trainRatio, trainRatio)
+ .set(model.seed, seed)
}
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
index 7a4e106aeb..26fd73814d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/ValidatorParams.scala
@@ -25,15 +25,15 @@ import org.apache.spark.SparkContext
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params}
-import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite,
- MLWritable}
+import org.apache.spark.ml.param.shared.HasSeed
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MetaAlgorithmReadWrite, MLWritable}
import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.sql.types.StructType
/**
* Common params for [[TrainValidationSplitParams]] and [[CrossValidatorParams]].
*/
-private[ml] trait ValidatorParams extends Params {
+private[ml] trait ValidatorParams extends HasSeed with Params {
/**
* param for the estimator to be validated
@@ -137,7 +137,8 @@ private[ml] object ValidatorParams {
}
val jsonParams = validatorSpecificParams ++ List(
- "estimatorParamMaps" -> parse(estimatorParamMapsJson))
+ "estimatorParamMaps" -> parse(estimatorParamMapsJson),
+ "seed" -> parse(instance.seed.jsonEncode(instance.getSeed)))
DefaultParamsWriter.saveMetadata(instance, path, sc, extraMetadata, Some(jsonParams))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 3e734aabc5..061d04c932 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -136,6 +136,7 @@ class CrossValidatorSuite
assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
+ assert(cv.getSeed === cv2.getSeed)
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
@@ -186,6 +187,7 @@ class CrossValidatorSuite
assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
+ assert(cv.getSeed === cv2.getSeed)
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)
@@ -259,6 +261,7 @@ class CrossValidatorSuite
assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
+ assert(cv.getSeed === cv2.getSeed)
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index dbee47c847..df9ba418b8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -127,6 +127,7 @@ class TrainValidationSplitSuite
val tvs2 = testDefaultReadWrite(tvs, testParams = false)
assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+ assert(tvs.getSeed === tvs2.getSeed)
}
test("read/write: TrainValidationSplitModel") {
@@ -149,6 +150,7 @@ class TrainValidationSplitSuite
assert(tvs.getTrainRatio === tvs2.getTrainRatio)
assert(tvs.validationMetrics === tvs2.validationMetrics)
+ assert(tvs.getSeed === tvs2.getSeed)
}
}