author     Xiangrui Meng <meng@databricks.com>   2016-03-16 14:19:54 -0700
committer  DB Tsai <dbt@netflix.com>             2016-03-16 14:19:54 -0700
commit     85c42fda99973a0c35c743816a06ce9117bb1aad (patch)
tree       ded163492ceb349b435d611135ada7d7aba7f43e
parent     6fc2b6541fd5ab73b289af5f7296fc602b5b4dce (diff)
[SPARK-13927][MLLIB] add row/column iterator to local matrices
## What changes were proposed in this pull request?

Add row/column iterators to local matrices to simplify tasks like BlockMatrix => RowMatrix conversion. It handles dense and sparse matrices properly.

## How was this patch tested?

Unit tests on sparse and dense matrices.

cc: dbtsai

Author: Xiangrui Meng <meng@databricks.com>

Closes #11757 from mengxr/SPARK-13927.
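As a quick illustration (not part of this commit), here is a minimal sketch of the new API in use, mirroring the unit test added below:

import org.apache.spark.mllib.linalg.DenseMatrix

// 3x2 column-major matrix:
//   0.0  3.0
//   1.0  4.0
//   2.0  0.0
val m = new DenseMatrix(3, 2, Array(0.0, 1.0, 2.0, 3.0, 4.0, 0.0))

m.rowIter.foreach(println)           // one Vector per row: [0.0,3.0], [1.0,4.0], [2.0,0.0]
m.colIter.foreach(println)           // one Vector per column: [0.0,1.0,2.0], [3.0,4.0,0.0]
m.toSparse.colIter.foreach(println)  // same columns, yielded as SparseVectors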
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala       64
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala  13
-rw-r--r--  project/MimaExcludes.scala                                               4
3 files changed, 80 insertions(+), 1 deletion(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 0fdb402fd6..fdede2ad39 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -22,10 +22,11 @@ import java.util.{Arrays, Random}
import scala.collection.mutable.{ArrayBuffer, ArrayBuilder => MArrayBuilder, HashSet => MHashSet}
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
+import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.annotation.{DeveloperApi, Since}
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.catalyst.util.GenericArrayData
import org.apache.spark.sql.types._
@@ -58,6 +59,20 @@ sealed trait Matrix extends Serializable {
newArray
}
+ /**
+ * Returns an iterator of column vectors.
+ * This operation could be expensive, depending on the underlying storage.
+ */
+ @Since("2.0.0")
+ def colIter: Iterator[Vector]
+
+ /**
+ * Returns an iterator of row vectors.
+ * This operation could be expensive, depending on the underlying storage.
+ */
+ @Since("2.0.0")
+ def rowIter: Iterator[Vector] = this.transpose.colIter
+
/** Converts to a breeze matrix. */
private[mllib] def toBreeze: BM[Double]
@@ -386,6 +401,21 @@ class DenseMatrix @Since("1.3.0") (
}
new SparseMatrix(numRows, numCols, colPtrs, rowIndices.result(), spVals.result())
}
+
+ @Since("2.0.0")
+ override def colIter: Iterator[Vector] = {
+ if (isTransposed) {
+ Iterator.tabulate(numCols) { j =>
+ val col = new Array[Double](numRows)
+ blas.dcopy(numRows, values, j, numCols, col, 0, 1)
+ new DenseVector(col)
+ }
+ } else {
+ Iterator.tabulate(numCols) { j =>
+ new DenseVector(values.slice(j * numRows, (j + 1) * numRows))
+ }
+ }
+ }
}
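(Side note, not part of the diff: when isTransposed is true, values is laid out row-major, so column j occupies values(j), values(j + numCols), values(j + 2 * numCols), ... The blas.dcopy call above gathers exactly that strided slice. A plain-Scala sketch of the same copy, using a hypothetical helper name:)

// Hypothetical equivalent of blas.dcopy(numRows, values, j, numCols, col, 0, 1):
// copy numRows elements starting at offset j with a stride of numCols.
def columnOfRowMajor(values: Array[Double], numRows: Int, numCols: Int, j: Int): Array[Double] = {
  val col = new Array[Double](numRows)
  var i = 0
  while (i < numRows) {
    col(i) = values(j + i * numCols)
    i += 1
  }
  col
}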
/**
@@ -656,6 +686,38 @@ class SparseMatrix @Since("1.3.0") (
@Since("1.5.0")
override def numActives: Int = values.length
+ @Since("2.0.0")
+ override def colIter: Iterator[Vector] = {
+ if (isTransposed) {
+ val indicesArray = Array.fill(numCols)(MArrayBuilder.make[Int])
+ val valuesArray = Array.fill(numCols)(MArrayBuilder.make[Double])
+ var i = 0
+ while (i < numRows) {
+ var k = colPtrs(i)
+ val rowEnd = colPtrs(i + 1)
+ while (k < rowEnd) {
+ val j = rowIndices(k)
+ indicesArray(j) += i
+ valuesArray(j) += values(k)
+ k += 1
+ }
+ i += 1
+ }
+ Iterator.tabulate(numCols) { j =>
+ val ii = indicesArray(j).result()
+ val vv = valuesArray(j).result()
+ new SparseVector(numRows, ii, vv)
+ }
+ } else {
+ Iterator.tabulate(numCols) { j =>
+ val colStart = colPtrs(j)
+ val colEnd = colPtrs(j + 1)
+ val ii = rowIndices.slice(colStart, colEnd)
+ val vv = values.slice(colStart, colEnd)
+ new SparseVector(numRows, ii, vv)
+ }
+ }
+ }
}
/**
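(Side note, not part of the diff: the transposed branch of SparseMatrix.colIter above converts CSR-style storage to per-column, CSC-style entries in a single O(nnz) pass. A standalone sketch of that pass, with hypothetical names:)

import scala.collection.mutable.ArrayBuffer

// Scan each row of a CSR matrix once and scatter (rowIndex, value) pairs
// into per-column buffers. Each buffer ends up sorted by row index because
// rows are visited in increasing order.
def cscColumns(
    numRows: Int,
    numCols: Int,
    rowPtrs: Array[Int],     // CSR row pointers, length numRows + 1
    colIndices: Array[Int],  // column index of each stored value
    values: Array[Double]): Array[ArrayBuffer[(Int, Double)]] = {
  val cols = Array.fill(numCols)(ArrayBuffer.empty[(Int, Double)])
  var i = 0
  while (i < numRows) {
    var k = rowPtrs(i)
    while (k < rowPtrs(i + 1)) {
      cols(colIndices(k)) += ((i, values(k)))
      k += 1
    }
    i += 1
  }
  cols
}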
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index 1833cf3833..a02b8c9635 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -494,4 +494,17 @@ class MatricesSuite extends SparkFunSuite {
assert(sm1.numNonzeros === 1)
assert(sm1.numActives === 3)
}
+
+ test("row/col iterator") {
+ val dm = new DenseMatrix(3, 2, Array(0, 1, 2, 3, 4, 0))
+ val sm = dm.toSparse
+ val rows = Seq(Vectors.dense(0, 3), Vectors.dense(1, 4), Vectors.dense(2, 0))
+ val cols = Seq(Vectors.dense(0, 1, 2), Vectors.dense(3, 4, 0))
+ for (m <- Seq(dm, sm)) {
+ assert(m.rowIter.toSeq === rows)
+ assert(m.colIter.toSeq === cols)
+ assert(m.transpose.rowIter.toSeq === cols)
+ assert(m.transpose.colIter.toSeq === rows)
+ }
+ }
}
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 985eb98bc3..59c7e7db2e 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -531,6 +531,10 @@ object MimaExcludes {
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.scheduler.SparkListener.onOtherEvent"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.CreatableRelationProvider.createRelation"),
ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.sources.InsertableRelation.insert")
+ ) ++ Seq(
+ // SPARK-13927: add row/column iterator to local matrices
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.rowIter"),
+ ProblemFilters.exclude[MissingMethodProblem]("org.apache.spark.mllib.linalg.Matrix.colIter")
)
case v if v.startsWith("1.6") =>
Seq(