path: root/python/pyspark/mllib/linalg.py
author     Xiangrui Meng <meng@databricks.com>  2014-11-03 22:29:48 -0800
committer  Xiangrui Meng <meng@databricks.com>  2014-11-03 22:29:48 -0800
commit     1a9c6cddadebdc53d083ac3e0da276ce979b5d1f (patch)
tree       b485818ba52a9287ae7124e57ef55f1d974f3a1f /python/pyspark/mllib/linalg.py
parent     04450d11548cfb25d4fb77d4a33e3a7cd4254183 (diff)
[SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD
Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and Python. With this PR, we can easily map an RDD[LabeledPoint] to a SchemaRDD, and then select columns or save it to a Parquet file. Examples in Scala/Python are attached. The Scala code was copied from jkbradley.

~~This PR contains the changes from #3068. I will rebase after #3068 is merged.~~

marmbrus jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #3070 from mengxr/SPARK-3573 and squashes the following commits:

3a0b6e5 [Xiangrui Meng] organize imports
236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples
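In the spirit of the attached examples, a minimal PySpark sketch of the workflow this patch enables (1.x-era SQL API; a live SparkContext `sc` and the output path are illustration-only assumptions, not part of the patch):

from pyspark.sql import SQLContext
from pyspark.mllib.linalg import Vectors
from pyspark.mllib.regression import LabeledPoint

sqlCtx = SQLContext(sc)
points = sc.parallelize([
    LabeledPoint(1.0, Vectors.dense([1.0, 0.0])),
    LabeledPoint(0.0, Vectors.sparse(2, [1], [1.0]))])

# Vector fields round-trip through VectorUDT, so the schema is inferable
schemaRDD = sqlCtx.inferSchema(points)
schemaRDD.registerTempTable("points")
print(sqlCtx.sql("SELECT label, features FROM points").collect())

# ...and the SchemaRDD can be written out as Parquet
schemaRDD.saveAsParquetFile("/tmp/points.parquet")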
Diffstat (limited to 'python/pyspark/mllib/linalg.py')
-rw-r--r--  python/pyspark/mllib/linalg.py  | 50
1 file changed, 50 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index d0a0e102a1..c0c3dff31e 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -29,6 +29,9 @@ import copy_reg
import numpy as np
+from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
+ IntegerType, ByteType, Row
+
__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']
@@ -106,7 +109,54 @@ def _format_float(f, digits=4):
return s
+class VectorUDT(UserDefinedType):
+ """
+ SQL user-defined type (UDT) for Vector.
+ """
+
+ @classmethod
+ def sqlType(cls):
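+        # One struct encodes both vector types: type=0 marks a SparseVector
+        # (size, indices, and values all populated); type=1 marks a
+        # DenseVector (only values populated; size and indices are null).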
+ return StructType([
+ StructField("type", ByteType(), False),
+ StructField("size", IntegerType(), True),
+ StructField("indices", ArrayType(IntegerType(), False), True),
+ StructField("values", ArrayType(DoubleType(), False), True)])
+
+ @classmethod
+ def module(cls):
+ return "pyspark.mllib.linalg"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.mllib.linalg.VectorUDT"
+
+ def serialize(self, obj):
+ if isinstance(obj, SparseVector):
+ indices = [int(i) for i in obj.indices]
+ values = [float(v) for v in obj.values]
+ return (0, obj.size, indices, values)
+ elif isinstance(obj, DenseVector):
+ values = [float(v) for v in obj]
+ return (1, None, None, values)
+ else:
+ raise ValueError("cannot serialize %r of type %r" % (obj, type(obj)))
+
+ def deserialize(self, datum):
+ assert len(datum) == 4, \
+ "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
+ tpe = datum[0]
+ if tpe == 0:
+ return SparseVector(datum[1], datum[2], datum[3])
+ elif tpe == 1:
+ return DenseVector(datum[3])
+ else:
+ raise ValueError("do not recognize type %r" % tpe)
+
+
class Vector(object):
    """
    Abstract class for DenseVector and SparseVector
    """
+
+    __UDT__ = VectorUDT()