diff options
author | Yanbo Liang <ybliang8@gmail.com> | 2016-12-28 01:24:18 -0800 |
---|---|---|
committer | Yanbo Liang <ybliang8@gmail.com> | 2016-12-28 01:24:18 -0800 |
commit | 9cff67f3465bc6ffe1b5abee9501e3c17f8fd194 (patch) | |
tree | 305d9659b94863ffea51b7f5f3456297e7abfedb /mllib/src/test/scala | |
parent | 79ff8536315aef97ee940c52d71cd8de777c7ce6 (diff) | |
download | spark-9cff67f3465bc6ffe1b5abee9501e3c17f8fd194.tar.gz spark-9cff67f3465bc6ffe1b5abee9501e3c17f8fd194.tar.bz2 spark-9cff67f3465bc6ffe1b5abee9501e3c17f8fd194.zip |
[MINOR][ML] Correct test cases of LoR raw2prediction & probability2prediction.
## What changes were proposed in this pull request?
Correct test cases of ```LogisticRegression``` raw2prediction & probability2prediction.
## How was this patch tested?
Changed unit tests.
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #16407 from yanboliang/raw-probability.
Diffstat (limited to 'mllib/src/test/scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 20 |
1 files changed, 18 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index 9c4c59a5e6..f8bcbeedfb 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -359,8 +359,16 @@ class LogisticRegressionSuite assert(pred == predFromProb) } - // force it to use probability2prediction + // force it to use raw2prediction model.setProbabilityCol("") + val resultsUsingRaw2Predict = + model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() + resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use probability2prediction + model.setRawPredictionCol("") val resultsUsingProb2Predict = model.transform(smallMultinomialDataset).select("prediction").as[Double].collect() resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { @@ -405,8 +413,16 @@ class LogisticRegressionSuite assert(pred == predFromProb) } - // force it to use probability2prediction + // force it to use raw2prediction model.setProbabilityCol("") + val resultsUsingRaw2Predict = + model.transform(smallBinaryDataset).select("prediction").as[Double].collect() + resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach { + case (pred1, pred2) => assert(pred1 === pred2) + } + + // force it to use probability2prediction + model.setRawPredictionCol("") val resultsUsingProb2Predict = model.transform(smallBinaryDataset).select("prediction").as[Double].collect() resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach { |