diff options
author | MechCoder <manojkumarsivaraj334@gmail.com> | 2015-06-17 11:10:16 -0700 |
---|---|---|
committer | Davies Liu <davies@databricks.com> | 2015-06-17 11:10:16 -0700 |
commit | 6765ef98dff070768bbcd585d341ee7664fbe76c (patch) | |
tree | 9e88fed33b78a098e8fb91b276eefe0957644e72 /python/pyspark/mllib/linalg.py | |
parent | 104f30c36f3d44b7567f6f77adb92e0a96494541 (diff) | |
download | spark-6765ef98dff070768bbcd585d341ee7664fbe76c.tar.gz spark-6765ef98dff070768bbcd585d341ee7664fbe76c.tar.bz2 spark-6765ef98dff070768bbcd585d341ee7664fbe76c.zip |
[SPARK-6390] [SQL] [MLlib] Port MatrixUDT to PySpark
MatrixUDT was recently coded in scala. This has been ported to PySpark
Author: MechCoder <manojkumarsivaraj334@gmail.com>
Closes #6354 from MechCoder/spark-6390 and squashes the following commits:
fc4dc1e [MechCoder] Better error message
c940a44 [MechCoder] Added test
aa9c391 [MechCoder] Add pyUDT to MatrixUDT
62a2a7d [MechCoder] [SPARK-6390] Port MatrixUDT to PySpark
Diffstat (limited to 'python/pyspark/mllib/linalg.py')
-rw-r--r-- | python/pyspark/mllib/linalg.py | 59 |
1 files changed, 57 insertions, 2 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 23d1a79ffe..e96c5ef87d 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -36,7 +36,7 @@ else: import numpy as np from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ - IntegerType, ByteType + IntegerType, ByteType, BooleanType __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', @@ -163,6 +163,59 @@ class VectorUDT(UserDefinedType): return "vector" +class MatrixUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Matrix. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("numRows", IntegerType(), False), + StructField("numCols", IntegerType(), False), + StructField("colPtrs", ArrayType(IntegerType(), False), True), + StructField("rowIndices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True), + StructField("isTransposed", BooleanType(), False)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.MatrixUDT" + + def serialize(self, obj): + if isinstance(obj, SparseMatrix): + colPtrs = [int(i) for i in obj.colPtrs] + rowIndices = [int(i) for i in obj.rowIndices] + values = [float(v) for v in obj.values] + return (0, obj.numRows, obj.numCols, colPtrs, + rowIndices, values, bool(obj.isTransposed)) + elif isinstance(obj, DenseMatrix): + values = [float(v) for v in obj.values] + return (1, obj.numRows, obj.numCols, None, None, values, + bool(obj.isTransposed)) + else: + raise TypeError("cannot serialize type %r" % (type(obj))) + + def deserialize(self, datum): + assert len(datum) == 7, \ + "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseMatrix(*datum[1:]) + elif tpe == 1: + return DenseMatrix(datum[1], datum[2], datum[5], datum[6]) + else: + raise ValueError("do not recognize type %r" % tpe) + + def simpleString(self): + return "matrix" + + class Vector(object): __UDT__ = VectorUDT() @@ -781,10 +834,12 @@ class Vectors(object): class Matrix(object): + + __UDT__ = MatrixUDT() + """ Represents a local matrix. """ - def __init__(self, numRows, numCols, isTransposed=False): self.numRows = numRows self.numCols = numCols |