aboutsummaryrefslogtreecommitdiff
path: root/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala')
-rw-r--r--mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala10
1 files changed, 6 insertions, 4 deletions
diff --git a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
index 2bebaa35ba..2327917e2c 100644
--- a/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
+++ b/mllib-local/src/test/scala/org/apache/spark/ml/util/TestingUtils.scala
@@ -154,7 +154,7 @@ object TestingUtils {
*/
def absTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide(
(x: Vector, y: Vector, eps: Double) => {
- x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps)
+ x.size == y.size && x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps)
}, x, eps, ABS_TOL_MSG)
/**
@@ -164,7 +164,7 @@ object TestingUtils {
*/
def relTol(eps: Double): CompareVectorRightSide = CompareVectorRightSide(
(x: Vector, y: Vector, eps: Double) => {
- x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps)
+ x.size == y.size && x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps)
}, x, eps, REL_TOL_MSG)
override def toString: String = x.toString
@@ -217,7 +217,8 @@ object TestingUtils {
*/
def absTol(eps: Double): CompareMatrixRightSide = CompareMatrixRightSide(
(x: Matrix, y: Matrix, eps: Double) => {
- x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps)
+ x.numRows == y.numRows && x.numCols == y.numCols &&
+ x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 absTol eps)
}, x, eps, ABS_TOL_MSG)
/**
@@ -227,7 +228,8 @@ object TestingUtils {
*/
def relTol(eps: Double): CompareMatrixRightSide = CompareMatrixRightSide(
(x: Matrix, y: Matrix, eps: Double) => {
- x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps)
+ x.numRows == y.numRows && x.numCols == y.numCols &&
+ x.toArray.zip(y.toArray).forall(x => x._1 ~= x._2 relTol eps)
}, x, eps, REL_TOL_MSG)
override def toString: String = x.toString