about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala 19
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala 10
-rw-r--r-- project/MimaExcludes.scala 6
3 files changed, 35 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 0a615494bb..75e7004464 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -114,6 +114,16 @@ sealed trait Matrix extends Serializable {
* corresponding value in the matrix with type `Double`.
*/
private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
+
+ /**
+ * Find the number of non-zero active values.
+ */
+ def numNonzeros: Int
+
+ /**
+ * Find the number of values stored explicitly. These values can be zero as well.
+ */
+ def numActives: Int
}
@DeveloperApi
@@ -324,6 +334,10 @@ class DenseMatrix(
}
}
+ override def numNonzeros: Int = values.count(_ != 0)
+
+ override def numActives: Int = values.length
+
/**
* Generate a `SparseMatrix` from the given `DenseMatrix`. The new matrix will have isTransposed
* set to false.
@@ -593,6 +607,11 @@ class SparseMatrix(
def toDense: DenseMatrix = {
new DenseMatrix(numRows, numCols, toArray)
}
+
+ override def numNonzeros: Int = values.count(_ != 0)
+
+ override def numActives: Int = values.length
+
}
/**
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 8dbb70f5d1..a270ba2562 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -455,4 +455,14 @@ class MatricesSuite extends SparkFunSuite {
lines = mat.toString(5, 100).lines.toArray
assert(lines.size == 5 && lines.forall(_.size <= 100))
}
+
+ test("numNonzeros and numActives") {
+ val dm1 = Matrices.dense(3, 2, Array(0, 0, -1, 1, 0, 1))
+ assert(dm1.numNonzeros === 3)
+ assert(dm1.numActives === 6)
+
+ val sm1 = Matrices.sparse(3, 2, Array(0, 2, 3), Array(0, 2, 1), Array(0.0, -1.2, 0.0))
+ assert(sm1.numNonzeros === 1)
+ assert(sm1.numActives === 3)
+ }
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 6f86a505b3..680b699e9e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -75,6 +75,12 @@ object MimaExcludes {
"org.apache.spark.sql.parquet.ParquetTypeInfo"),
ProblemFilters.exclude[MissingClassProblem](
"org.apache.spark.sql.parquet.ParquetTypeInfo$")
+ ) ++ Seq(
+ // SPARK-8479 Add numNonzeros and numActives to Matrix.
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.numNonzeros"),
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.mllib.linalg.Matrix.numActives")
)
case v if v.startsWith("1.4") =>
Seq(