aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/linalg.py')
-rw-r--r--python/pyspark/mllib/linalg.py59
1 files changed, 57 insertions, 2 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 23d1a79ffe..e96c5ef87d 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -36,7 +36,7 @@ else:
import numpy as np
from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
- IntegerType, ByteType
+ IntegerType, ByteType, BooleanType
__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors',
@@ -163,6 +163,59 @@ class VectorUDT(UserDefinedType):
return "vector"
+class MatrixUDT(UserDefinedType):
+ """
+ SQL user-defined type (UDT) for Matrix.
+ """
+
+ @classmethod
+ def sqlType(cls):
+ return StructType([
+ StructField("type", ByteType(), False),
+ StructField("numRows", IntegerType(), False),
+ StructField("numCols", IntegerType(), False),
+ StructField("colPtrs", ArrayType(IntegerType(), False), True),
+ StructField("rowIndices", ArrayType(IntegerType(), False), True),
+ StructField("values", ArrayType(DoubleType(), False), True),
+ StructField("isTransposed", BooleanType(), False)])
+
+ @classmethod
+ def module(cls):
+ return "pyspark.mllib.linalg"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.mllib.linalg.MatrixUDT"
+
+ def serialize(self, obj):
+ if isinstance(obj, SparseMatrix):
+ colPtrs = [int(i) for i in obj.colPtrs]
+ rowIndices = [int(i) for i in obj.rowIndices]
+ values = [float(v) for v in obj.values]
+ return (0, obj.numRows, obj.numCols, colPtrs,
+ rowIndices, values, bool(obj.isTransposed))
+ elif isinstance(obj, DenseMatrix):
+ values = [float(v) for v in obj.values]
+ return (1, obj.numRows, obj.numCols, None, None, values,
+ bool(obj.isTransposed))
+ else:
+ raise TypeError("cannot serialize type %r" % (type(obj)))
+
+ def deserialize(self, datum):
+ assert len(datum) == 7, \
+ "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
+ tpe = datum[0]
+ if tpe == 0:
+ return SparseMatrix(*datum[1:])
+ elif tpe == 1:
+ return DenseMatrix(datum[1], datum[2], datum[5], datum[6])
+ else:
+ raise ValueError("do not recognize type %r" % tpe)
+
+ def simpleString(self):
+ return "matrix"
+
+
class Vector(object):
__UDT__ = VectorUDT()
@@ -781,10 +834,12 @@ class Vectors(object):
class Matrix(object):
+
+ __UDT__ = MatrixUDT()
+
"""
Represents a local matrix.
"""
-
def __init__(self, numRows, numCols, isTransposed=False):
self.numRows = numRows
self.numCols = numCols