diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala | 3 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala | 2 |
2 files changed, 5 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index 3e734aabc5..061d04c932 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -136,6 +136,7 @@ class CrossValidatorSuite assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] @@ -186,6 +187,7 @@ class CrossValidatorSuite assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) assert(cv.getEvaluator.uid === cv2.getEvaluator.uid) @@ -259,6 +261,7 @@ class CrossValidatorSuite assert(cv.uid === cv2.uid) assert(cv.getNumFolds === cv2.getNumFolds) + assert(cv.getSeed === cv2.getSeed) assert(cv2.getEvaluator.isInstanceOf[BinaryClassificationEvaluator]) val evaluator2 = cv2.getEvaluator.asInstanceOf[BinaryClassificationEvaluator] diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index dbee47c847..df9ba418b8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -127,6 +127,7 @@ class TrainValidationSplitSuite val tvs2 = testDefaultReadWrite(tvs, testParams = false) assert(tvs.getTrainRatio === tvs2.getTrainRatio) + assert(tvs.getSeed === tvs2.getSeed) } test("read/write: TrainValidationSplitModel") { @@ -149,6 +150,7 @@ class TrainValidationSplitSuite assert(tvs.getTrainRatio === tvs2.getTrainRatio) assert(tvs.validationMetrics === tvs2.validationMetrics) + assert(tvs.getSeed === tvs2.getSeed) } } |