aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/linalg.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/linalg.py')
-rw-r--r--python/pyspark/mllib/linalg.py13
1 files changed, 12 insertions, 1 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index c0c3dff31e..e35202dca0 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -33,7 +33,7 @@ from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, Dou
IntegerType, ByteType, Row
-__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']
+__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices']
if sys.version_info[:2] == (2, 7):
@@ -578,6 +578,8 @@ class DenseMatrix(Matrix):
def __init__(self, numRows, numCols, values):
Matrix.__init__(self, numRows, numCols)
assert len(values) == numRows * numCols
+ if not isinstance(values, array.array):
+ values = array.array('d', values)
self.values = values
def __reduce__(self):
@@ -596,6 +598,15 @@ class DenseMatrix(Matrix):
return np.reshape(self.values, (self.numRows, self.numCols), order='F')
+class Matrices(object):
+ @staticmethod
+ def dense(numRows, numCols, values):
+ """
+ Create a DenseMatrix
+ """
+ return DenseMatrix(numRows, numCols, values)
+
+
def _test():
import doctest
(failure_count, test_count) = doctest.testmod(optionflags=doctest.ELLIPSIS)