aboutsummaryrefslogtreecommitdiff
path: root/mllib-local/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib-local/src')
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala8
1 files changed, 8 insertions, 0 deletions
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
index 2327917e2c..30edd00fb5 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -32,6 +32,10 @@ object TestingUtils {
* the relative tolerance is meaningless, so the exception will be raised to warn users.
*/
private def RelativeErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ // Special case for NaNs
+ if (x.isNaN && y.isNaN) {
+ return true
+ }
val absX = math.abs(x)
val absY = math.abs(y)
val diff = math.abs(x - y)
@@ -49,6 +53,10 @@ object TestingUtils {
* Private helper function for comparing two values using absolute tolerance.
*/
private def AbsoluteErrorComparison(x: Double, y: Double, eps: Double): Boolean = {
+ // Special case for NaNs
+ if (x.isNaN && y.isNaN) {
+ return true
+ }
math.abs(x - y) < eps
}