aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
author: MechCoder <manojkumarsivaraj334@gmail.com> 2015-06-17 11:10:16 -0700
committer: Davies Liu <davies@databricks.com> 2015-06-17 11:10:16 -0700
commit: 6765ef98dff070768bbcd585d341ee7664fbe76c (patch)
tree: 9e88fed33b78a098e8fb91b276eefe0957644e72 /python
parent: 104f30c36f3d44b7567f6f77adb92e0a96494541 (diff)
download: spark-6765ef98dff070768bbcd585d341ee7664fbe76c.tar.gz
spark-6765ef98dff070768bbcd585d341ee7664fbe76c.tar.bz2
spark-6765ef98dff070768bbcd585d341ee7664fbe76c.zip
[SPARK-6390] [SQL] [MLlib] Port MatrixUDT to PySpark
MatrixUDT was recently implemented in Scala; this commit ports it to PySpark. Author: MechCoder <manojkumarsivaraj334@gmail.com>. Closes #6354 from MechCoder/spark-6390 and squashes the following commits: fc4dc1e [MechCoder] Better error message; c940a44 [MechCoder] Added test; aa9c391 [MechCoder] Add pyUDT to MatrixUDT; 62a2a7d [MechCoder] [SPARK-6390] Port MatrixUDT to PySpark.
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/mllib/linalg.py 59
-rw-r--r-- python/pyspark/mllib/tests.py 34
-rw-r--r-- python/pyspark/sql/dataframe.py 6
3 files changed, 95 insertions, 4 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 23d1a79ffe..e96c5ef87d 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -36,7 +36,7 @@ else:
import numpy as np
from pyspark.sql.types import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
- IntegerType, ByteType
+ IntegerType, ByteType, BooleanType
__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors',
@@ -163,6 +163,59 @@ class VectorUDT(UserDefinedType):
return "vector"
+class MatrixUDT(UserDefinedType):
+ """
+ SQL user-defined type (UDT) for Matrix.
+ """
+
+ @classmethod
+ def sqlType(cls):
+ return StructType([
+ StructField("type", ByteType(), False),  # tag: 0 = sparse, 1 = dense (see serialize)
+ StructField("numRows", IntegerType(), False),
+ StructField("numCols", IntegerType(), False),
+ StructField("colPtrs", ArrayType(IntegerType(), False), True),  # None for dense
+ StructField("rowIndices", ArrayType(IntegerType(), False), True),  # None for dense
+ StructField("values", ArrayType(DoubleType(), False), True),
+ StructField("isTransposed", BooleanType(), False)])
+
+ @classmethod
+ def module(cls):
+ return "pyspark.mllib.linalg"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.mllib.linalg.MatrixUDT"
+
+ def serialize(self, obj):
+ if isinstance(obj, SparseMatrix):  # checked before DenseMatrix; emits tag 0
+ colPtrs = [int(i) for i in obj.colPtrs]  # elements converted with int()
+ rowIndices = [int(i) for i in obj.rowIndices]
+ values = [float(v) for v in obj.values]  # elements converted with float()
+ return (0, obj.numRows, obj.numCols, colPtrs,
+ rowIndices, values, bool(obj.isTransposed))
+ elif isinstance(obj, DenseMatrix):  # emits tag 1
+ values = [float(v) for v in obj.values]
+ return (1, obj.numRows, obj.numCols, None, None, values,
+ bool(obj.isTransposed))
+ else:
+ raise TypeError("cannot serialize type %r" % (type(obj)))
+
+ def deserialize(self, datum):
+ assert len(datum) == 7, \
+ "MatrixUDT.deserialize given row with length %d but requires 7" % len(datum)
+ tpe = datum[0]  # type tag written by serialize
+ if tpe == 0:
+ return SparseMatrix(*datum[1:])  # remaining six fields forwarded positionally
+ elif tpe == 1:
+ return DenseMatrix(datum[1], datum[2], datum[5], datum[6])  # skips colPtrs/rowIndices
+ else:
+ raise ValueError("do not recognize type %r" % tpe)
+
+ def simpleString(self):
+ return "matrix"
+
+
class Vector(object):
__UDT__ = VectorUDT()
@@ -781,10 +834,12 @@ class Vectors(object):
class Matrix(object):
+
+ __UDT__ = MatrixUDT()
+
"""
Represents a local matrix.
"""
-
def __init__(self, numRows, numCols, isTransposed=False):
self.numRows = numRows
self.numCols = numCols
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 36a4c7a540..f4c997261e 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -39,7 +39,7 @@ else:
from pyspark import SparkContext
from pyspark.mllib.common import _to_java_object_rdd
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
- DenseMatrix, SparseMatrix, Vectors, Matrices
+ DenseMatrix, SparseMatrix, Vectors, Matrices, MatrixUDT
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
@@ -507,6 +507,38 @@ class VectorUDTTests(MLlibTestCase):
raise TypeError("expecting a vector but got %r of type %r" % (v, type(v)))
+class MatrixUDTTests(MLlibTestCase):
+
+    dm1 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10])
+    dm2 = DenseMatrix(3, 2, [0, 1, 4, 5, 9, 10], isTransposed=True)
+    sm1 = SparseMatrix(1, 1, [0, 1], [0], [2.0])
+    sm2 = SparseMatrix(2, 1, [0, 0, 1], [0], [5.0], isTransposed=True)
+    udt = MatrixUDT()
+
+    def test_json_schema(self):
+        self.assertEqual(MatrixUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+    def test_serialization(self):
+        for m in [self.dm1, self.dm2, self.sm1, self.sm2]:
+            self.assertEqual(m, self.udt.deserialize(self.udt.serialize(m)))
+
+    def test_infer_schema(self):
+        sqlCtx = SQLContext(self.sc)  # NOTE(review): looks required for rdd.toDF(); confirm
+        rdd = self.sc.parallelize([("dense", self.dm1), ("sparse", self.sm1)])
+        df = rdd.toDF()
+        schema = df.schema
+        self.assertEqual(schema.fields[1].dataType, self.udt)  # compare, not just truthiness
+        matrices = df.map(lambda x: x._2).collect()
+        self.assertEqual(len(matrices), 2)
+        for m in matrices:
+            if isinstance(m, DenseMatrix):
+                self.assertEqual(m, self.dm1)  # round-tripped value must equal the input
+            elif isinstance(m, SparseMatrix):
+                self.assertEqual(m, self.sm1)
+            else:
+                raise ValueError("Expected a matrix but got type %r" % type(m))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(MLlibTestCase):
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index 9615e57649..152b87351d 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -194,7 +194,11 @@ class DataFrame(object):
StructType(List(StructField(age,IntegerType,true),StructField(name,StringType,true)))
"""
if self._schema is None:
- self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+ try:
+ self._schema = _parse_datatype_json_string(self._jdf.schema().json())
+ except AttributeError as e:
+ raise Exception(
+ "Unable to parse datatype from schema. %s" % e)
return self._schema
@since(1.3)