author     Xiangrui Meng <meng@databricks.com>    2014-11-03 22:29:48 -0800
committer  Xiangrui Meng <meng@databricks.com>    2014-11-03 22:31:43 -0800
commit     8395e8fbdf23bef286ec68a4bbadcc448b504c2c (patch)
tree       8bea8ca2bba38d861a8e428b9e295bb8782d8d85 /python/pyspark
parent     42d02db86cd973cf31ceeede0c5a723238bbe746 (diff)
[SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD
Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and
Python. With this PR, we can easily map an RDD[LabeledPoint] to a SchemaRDD,
and then select columns or save to a Parquet file. Examples in Scala/Python
are attached. The Scala code was copied from jkbradley.

~~This PR contains the changes from #3068. I will rebase after #3068 is merged.~~

marmbrus jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #3070 from mengxr/SPARK-3573 and squashes the following commits:

3a0b6e5 [Xiangrui Meng] organize imports
236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples

(cherry picked from commit 1a9c6cddadebdc53d083ac3e0da276ce979b5d1f)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
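As a quick illustration of the workflow described above, here is a minimal
Python sketch (assuming the Spark 1.2-era SQLContext.inferSchema and
SchemaRDD.saveAsParquetFile APIs; the output path is hypothetical) that maps
an RDD[LabeledPoint] to a SchemaRDD, reads the vectors back out, and saves
the data to Parquet:

# Minimal sketch: RDD[LabeledPoint] -> SchemaRDD -> Parquet.
# Assumes the Spark 1.2-era Python API; the output path is hypothetical.
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.mllib.linalg import DenseVector, SparseVector
from pyspark.mllib.regression import LabeledPoint

sc = SparkContext(appName="VectorUDTExample")
sqlCtx = SQLContext(sc)

# With Vector registered as a UDT, inferSchema can map the 'features'
# field of LabeledPoint to a SQL column of VectorUDT.
points = sc.parallelize([
    LabeledPoint(1.0, DenseVector([1.0, 2.0])),
    LabeledPoint(0.0, SparseVector(2, [1], [2.0])),
])
srdd = sqlCtx.inferSchema(points)

# Read the features column back out as MLlib vectors ...
features = srdd.map(lambda row: row.features).collect()

# ... or persist the whole dataset as a Parquet file (path is illustrative).
srdd.saveAsParquetFile("/tmp/labeled_points.parquet")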
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/mllib/linalg.py  50
-rw-r--r--  python/pyspark/mllib/tests.py   39
2 files changed, 86 insertions, 3 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index d0a0e102a1..c0c3dff31e 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -29,6 +29,9 @@ import copy_reg
import numpy as np
+from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
+ IntegerType, ByteType, Row
+
__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']
@@ -106,7 +109,54 @@ def _format_float(f, digits=4):
return s
+class VectorUDT(UserDefinedType):
+ """
+ SQL user-defined type (UDT) for Vector.
+ """
+
+ @classmethod
+ def sqlType(cls):
+ return StructType([
+ StructField("type", ByteType(), False),
+ StructField("size", IntegerType(), True),
+ StructField("indices", ArrayType(IntegerType(), False), True),
+ StructField("values", ArrayType(DoubleType(), False), True)])
+
+ @classmethod
+ def module(cls):
+ return "pyspark.mllib.linalg"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.mllib.linalg.VectorUDT"
+
+ def serialize(self, obj):
+ if isinstance(obj, SparseVector):
+ indices = [int(i) for i in obj.indices]
+ values = [float(v) for v in obj.values]
+ return (0, obj.size, indices, values)
+ elif isinstance(obj, DenseVector):
+ values = [float(v) for v in obj]
+ return (1, None, None, values)
+ else:
+ raise ValueError("cannot serialize %r of type %r" % (obj, type(obj)))
+
+ def deserialize(self, datum):
+ assert len(datum) == 4, \
+ "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
+ tpe = datum[0]
+ if tpe == 0:
+ return SparseVector(datum[1], datum[2], datum[3])
+ elif tpe == 1:
+ return DenseVector(datum[3])
+ else:
+ raise ValueError("do not recognize type %r" % tpe)
+
+
class Vector(object):
+
+ __UDT__ = VectorUDT()
+
"""
Abstract class for DenseVector and SparseVector
"""
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d6fb87b378..9fa4d6f6a2 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -33,14 +33,14 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-from pyspark.serializers import PickleSerializer
-from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.serializers import PickleSerializer
+from pyspark.sql import SQLContext
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-
_have_scipy = False
try:
import scipy.sparse
@@ -221,6 +221,39 @@ class StatTests(PySparkTestCase):
self.assertEqual(10, summary.count())
+class VectorUDTTests(PySparkTestCase):
+
+ dv0 = DenseVector([])
+ dv1 = DenseVector([1.0, 2.0])
+ sv0 = SparseVector(2, [], [])
+ sv1 = SparseVector(2, [1], [2.0])
+ udt = VectorUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
+ self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
+
+ def test_infer_schema(self):
+ sqlCtx = SQLContext(self.sc)
+ rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
+ srdd = sqlCtx.inferSchema(rdd)
+ schema = srdd.schema()
+ field = [f for f in schema.fields if f.name == "features"][0]
+ self.assertEqual(field.dataType, self.udt)
+ vectors = srdd.map(lambda p: p.features).collect()
+ self.assertEqual(len(vectors), 2)
+ for v in vectors:
+ if isinstance(v, SparseVector):
+ self.assertEqual(v, self.sv1)
+ elif isinstance(v, DenseVector):
+ self.assertEqual(v, self.dv1)
+ else:
+ raise ValueError("expecting a vector but got %r of type %r" % (v, type(v)))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):