aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorMechCoder <manojkumarsivaraj334@gmail.com>2015-03-20 17:13:18 -0400
committerXiangrui Meng <meng@databricks.com>2015-03-20 17:13:18 -0400
commit11e025956be3818c00effef0d650734f8feeb436 (patch)
treeb63913e36a5e40819a32696b13b180894eafbaba /mllib/src
parent49a01c7ea2c48feee7ab4551c4fa03fd1cdb1a32 (diff)
downloadspark-11e025956be3818c00effef0d650734f8feeb436.tar.gz
spark-11e025956be3818c00effef0d650734f8feeb436.tar.bz2
spark-11e025956be3818c00effef0d650734f8feeb436.zip
[SPARK-6309] [SQL] [MLlib] Implement MatrixUDT
Utilities to serialize and deserialize Matrices in MLlib Author: MechCoder <manojkumarsivaraj334@gmail.com> Closes #5048 from MechCoder/spark-6309 and squashes the following commits: 05dc6f2 [MechCoder] Hashcode and organize imports 16d5d47 [MechCoder] Test some more 6e67020 [MechCoder] TST: Test using Array conversion instead of equals 7fa7a2c [MechCoder] [SPARK-6309] [SQL] [MLlib] Implement MatrixUDT
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala90
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala13
2 files changed, 103 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index fdd8848189..849f44295f 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -23,9 +23,15 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+
/**
* Trait for a local matrix.
*/
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
sealed trait Matrix extends Serializable {
/** Number of rows. */
@@ -102,6 +108,88 @@ sealed trait Matrix extends Serializable {
private[spark] def foreachActive(f: (Int, Int, Double) => Unit)
}
+@DeveloperApi
+private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
+
+ override def sqlType: StructType = {
+ // type: 0 = sparse, 1 = dense
+ // the dense matrix is built by numRows, numCols, values and isTransposed, all of which are
+ // set as not nullable, except values since in the future, support for binary matrices might
+ // be added for which values are not needed.
+ // the sparse matrix needs colPtrs and rowIndices, which are set as
+ // null, while building the dense matrix.
+ StructType(Seq(
+ StructField("type", ByteType, nullable = false),
+ StructField("numRows", IntegerType, nullable = false),
+ StructField("numCols", IntegerType, nullable = false),
+ StructField("colPtrs", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("rowIndices", ArrayType(IntegerType, containsNull = false), nullable = true),
+ StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true),
+ StructField("isTransposed", BooleanType, nullable = false)
+ ))
+ }
+
+ override def serialize(obj: Any): Row = {
+ val row = new GenericMutableRow(7)
+ obj match {
+ case sm: SparseMatrix =>
+ row.setByte(0, 0)
+ row.setInt(1, sm.numRows)
+ row.setInt(2, sm.numCols)
+ row.update(3, sm.colPtrs.toSeq)
+ row.update(4, sm.rowIndices.toSeq)
+ row.update(5, sm.values.toSeq)
+ row.setBoolean(6, sm.isTransposed)
+
+ case dm: DenseMatrix =>
+ row.setByte(0, 1)
+ row.setInt(1, dm.numRows)
+ row.setInt(2, dm.numCols)
+ row.setNullAt(3)
+ row.setNullAt(4)
+ row.update(5, dm.values.toSeq)
+ row.setBoolean(6, dm.isTransposed)
+ }
+ row
+ }
+
+ override def deserialize(datum: Any): Matrix = {
+ datum match {
+ // TODO: something wrong with UDT serialization, should never happen.
+ case m: Matrix => m
+ case row: Row =>
+ require(row.length == 7,
+ s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
+ val tpe = row.getByte(0)
+ val numRows = row.getInt(1)
+ val numCols = row.getInt(2)
+ val values = row.getAs[Iterable[Double]](5).toArray
+ val isTransposed = row.getBoolean(6)
+ tpe match {
+ case 0 =>
+ val colPtrs = row.getAs[Iterable[Int]](3).toArray
+ val rowIndices = row.getAs[Iterable[Int]](4).toArray
+ new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
+ case 1 =>
+ new DenseMatrix(numRows, numCols, values, isTransposed)
+ }
+ }
+ }
+
+ override def userClass: Class[Matrix] = classOf[Matrix]
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case v: MatrixUDT => true
+ case _ => false
+ }
+ }
+
+ override def hashCode(): Int = 1994
+
+ private[spark] override def asNullable: MatrixUDT = this
+}
+
/**
* Column-major dense matrix.
* The entry values are stored in a single array of doubles with columns listed in sequence.
@@ -119,6 +207,7 @@ sealed trait Matrix extends Serializable {
* @param isTransposed whether the matrix is transposed. If true, `values` stores the matrix in
* row major.
*/
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
class DenseMatrix(
val numRows: Int,
val numCols: Int,
@@ -360,6 +449,7 @@ object DenseMatrix {
* Compressed Sparse Row (CSR) format, where `colPtrs` behaves as rowPtrs,
* and `rowIndices` behave as colIndices, and `values` are stored in row major.
*/
+@SQLUserDefinedType(udt = classOf[MatrixUDT])
class SparseMatrix(
val numRows: Int,
val numCols: Int,
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
index c098b5458f..96f677db3f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/MatricesSuite.scala
@@ -424,4 +424,17 @@ class MatricesSuite extends FunSuite {
assert(mat.rowIndices.toSeq === Seq(3, 0, 2, 1))
assert(mat.values.toSeq === Seq(1.0, 2.0, 3.0, 4.0))
}
+
+ test("MatrixUDT") {
+ val dm1 = new DenseMatrix(2, 2, Array(0.9, 1.2, 2.3, 9.8))
+ val dm2 = new DenseMatrix(3, 2, Array(0.0, 1.21, 2.3, 9.8, 9.0, 0.0))
+ val dm3 = new DenseMatrix(0, 0, Array())
+ val sm1 = dm1.toSparse
+ val sm2 = dm2.toSparse
+ val sm3 = dm3.toSparse
+ val mUDT = new MatrixUDT()
+ Seq(dm1, dm2, dm3, sm1, sm2, sm3).foreach {
+ mat => assert(mat.toArray === mUDT.deserialize(mUDT.serialize(mat)).toArray)
+ }
+ }
}