aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala
diff options
context:
space:
mode:
authoryinxusen <yinxusen@gmail.com>2016-05-03 14:19:13 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-05-03 14:19:13 -0700
commit2e2a6211c4391d67edb2a252f26647fb059bc18b (patch)
tree366ab0e2c2c9a073f5c39a42076540369b2e897d /mllib/src/test/scala
parentd6c7b2a5cc11a82e5137ee86350550e06e81f609 (diff)
downloadspark-2e2a6211c4391d67edb2a252f26647fb059bc18b.tar.gz
spark-2e2a6211c4391d67edb2a252f26647fb059bc18b.tar.bz2
spark-2e2a6211c4391d67edb2a252f26647fb059bc18b.zip
[SPARK-14973][ML] The CrossValidator and TrainValidationSplit miss the seed when saving and loading
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-14973 Add seed support when saving/loading of CrossValidator and TrainValidationSplit. ## How was this patch tested? Spark unit test. Author: yinxusen <yinxusen@gmail.com> Closes #12825 from yinxusen/SPARK-14973.
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala2
2 files changed, 5 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 3e734aabc5..061d04c932 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -136,6 +136,7 @@ class CrossValidatorSuite
assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
+ assert(cv.getSeed === cv2.getSeed)
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
@@ -186,6 +187,7 @@ class CrossValidatorSuite
assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
+ assert(cv.getSeed === cv2.getSeed)
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
assert(cv.getEvaluator.uid === cv2.getEvaluator.uid)
@@ -259,6 +261,7 @@ class CrossValidatorSuite
assert(cv.uid === cv2.uid)
assert(cv.getNumFolds === cv2.getNumFolds)
+ assert(cv.getSeed === cv2.getSeed)
assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator])
val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator]
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index dbee47c847..df9ba418b8 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -127,6 +127,7 @@ class TrainValidationSplitSuite
val tvs2 = testDefaultReadWrite(tvs, testParams = false)
assert(tvs.getTrainRatio === tvs2.getTrainRatio)
+ assert(tvs.getSeed === tvs2.getSeed)
}
test("read/write: TrainValidationSplitModel") {
@@ -149,6 +150,7 @@ class TrainValidationSplitSuite
assert(tvs.getTrainRatio === tvs2.getTrainRatio)
assert(tvs.validationMetrics === tvs2.validationMetrics)
+ assert(tvs.getSeed === tvs2.getSeed)
}
}