aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Feynman Liang <fliang@databricks.com> 2015-08-11 12:49:47 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2015-08-11 12:49:47 -0700
commit 520ad44b17f72e6465bf990f64b4e289f8a83447 (patch)
tree 9e633e74f6c975f874622b3a4e6e44a292b4cd60 /mllib
parent 5831294a7a8fa2524133c5d718cbc8187d2b0620 (diff)
download spark-520ad44b17f72e6465bf990f64b4e289f8a83447.tar.gz
spark-520ad44b17f72e6465bf990f64b4e289f8a83447.tar.bz2
spark-520ad44b17f72e6465bf990f64b4e289f8a83447.zip
[SPARK-9750] [MLLIB] Improve equals on SparseMatrix and DenseMatrix
Adds unit test for `equals` on `mllib.linalg.Matrix` class and `equals` to both `SparseMatrix` and `DenseMatrix`. Supports equality testing between `SparseMatrix` and `DenseMatrix`.

mengxr

Author: Feynman Liang <fliang@databricks.com>

Closes #8042 from feynmanliang/SPARK-9750 and squashes the following commits:

bb70d5e [Feynman Liang] Breeze compare for dense matrices as well, in case other is sparse
ab6f3c8 [Feynman Liang] Sparse matrix compare for equals
22782df [Feynman Liang] Add equality based on matrix semantics, not representation
78f9426 [Feynman Liang] Add casts
43d28fa [Feynman Liang] Fix failing test
6416fa0 [Feynman Liang] Add failing sparse matrix equals tests
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 8
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala | 18
2 files changed, 24 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 1c858348bf..1139ce36d5 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -257,8 +257,7 @@ class DenseMatrix(
this(numRows, numCols, values, false)
override def equals(o: Any): Boolean = o match {
- case m: DenseMatrix =>
- m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray)
+ case m: Matrix => toBreeze == m.toBreeze
case _ => false
}
@@ -519,6 +518,11 @@ class SparseMatrix(
rowIndices: Array[Int],
values: Array[Double]) = this(numRows, numCols, colPtrs, rowIndices, values, false)
+ override def equals(o: Any): Boolean = o match {
+ case m: Matrix => toBreeze == m.toBreeze
+ case _ => false
+ }
+
private[mllib] def toBreeze: BM[Double] = {
if (!isTransposed) {
new BSM[Double](values, numRows, numCols, colPtrs, rowIndices)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index a270ba2562..bfd6d5495f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -74,6 +74,24 @@ class MatricesSuite extends SparkFunSuite {
}
}
+ test("equals") {
+ val dm1 = Matrices.dense(2, 2, Array(0.0, 1.0, 2.0, 3.0))
+ assert(dm1 === dm1)
+ assert(dm1 !== dm1.transpose)
+
+ val dm2 = Matrices.dense(2, 2, Array(0.0, 2.0, 1.0, 3.0))
+ assert(dm1 === dm2.transpose)
+
+ val sm1 = dm1.asInstanceOf[DenseMatrix].toSparse
+ assert(sm1 === sm1)
+ assert(sm1 === dm1)
+ assert(sm1 !== sm1.transpose)
+
+ val sm2 = dm2.asInstanceOf[DenseMatrix].toSparse
+ assert(sm1 === sm2.transpose)
+ assert(sm1 === dm2.transpose)
+ }
+
test("matrix copies are deep copies") {
val m = 3
val n = 2