aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2016-12-28 01:24:18 -0800
committerYanbo Liang <ybliang8@gmail.com>2016-12-28 01:24:18 -0800
commit9cff67f3465bc6ffe1b5abee9501e3c17f8fd194 (patch)
tree305d9659b94863ffea51b7f5f3456297e7abfedb /mllib/src
parent79ff8536315aef97ee940c52d71cd8de777c7ce6 (diff)
downloadspark-9cff67f3465bc6ffe1b5abee9501e3c17f8fd194.tar.gz
spark-9cff67f3465bc6ffe1b5abee9501e3c17f8fd194.tar.bz2
spark-9cff67f3465bc6ffe1b5abee9501e3c17f8fd194.zip
[MINOR][ML] Correct test cases of LoR raw2prediction & probability2prediction.
## What changes were proposed in this pull request? Correct test cases of ```LogisticRegression``` raw2prediction & probability2prediction. ## How was this patch tested? Changed unit tests. Author: Yanbo Liang <ybliang8@gmail.com> Closes #16407 from yanboliang/raw-probability.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala20
1 files changed, 18 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 9c4c59a5e6..f8bcbeedfb 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -359,8 +359,16 @@ class LogisticRegressionSuite
assert(pred == predFromProb)
}
- // force it to use probability2prediction
+ // force it to use raw2prediction
model.setProbabilityCol("")
+ val resultsUsingRaw2Predict =
+ model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
+ resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
+ case (pred1, pred2) => assert(pred1 === pred2)
+ }
+
+ // force it to use probability2prediction
+ model.setRawPredictionCol("")
val resultsUsingProb2Predict =
model.transform(smallMultinomialDataset).select("prediction").as[Double].collect()
resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
@@ -405,8 +413,16 @@ class LogisticRegressionSuite
assert(pred == predFromProb)
}
- // force it to use probability2prediction
+ // force it to use raw2prediction
model.setProbabilityCol("")
+ val resultsUsingRaw2Predict =
+ model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
+ resultsUsingRaw2Predict.zip(results.select("prediction").as[Double].collect()).foreach {
+ case (pred1, pred2) => assert(pred1 === pred2)
+ }
+
+ // force it to use probability2prediction
+ model.setRawPredictionCol("")
val resultsUsingProb2Predict =
model.transform(smallBinaryDataset).select("prediction").as[Double].collect()
resultsUsingProb2Predict.zip(results.select("prediction").as[Double].collect()).foreach {