From 1a9c6cddadebdc53d083ac3e0da276ce979b5d1f Mon Sep 17 00:00:00 2001 From: Xiangrui Meng Date: Mon, 3 Nov 2014 22:29:48 -0800 Subject: [SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and Python. With this PR, we can easily map a RDD[LabeledPoint] to a SchemaRDD, and then select columns or save to a Parquet file. Examples in Scala/Python are attached. The Scala code was copied from jkbradley. ~~This PR contains the changes from #3068 . I will rebase after #3068 is merged.~~ marmbrus jkbradley Author: Xiangrui Meng Closes #3070 from mengxr/SPARK-3573 and squashes the following commits: 3a0b6e5 [Xiangrui Meng] organize imports 236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples --- python/pyspark/mllib/linalg.py | 50 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) (limited to 'python/pyspark/mllib/linalg.py') diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index d0a0e102a1..c0c3dff31e 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -29,6 +29,9 @@ import copy_reg import numpy as np +from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \ + IntegerType, ByteType, Row + __all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors'] @@ -106,7 +109,54 @@ def _format_float(f, digits=4): return s +class VectorUDT(UserDefinedType): + """ + SQL user-defined type (UDT) for Vector. + """ + + @classmethod + def sqlType(cls): + return StructType([ + StructField("type", ByteType(), False), + StructField("size", IntegerType(), True), + StructField("indices", ArrayType(IntegerType(), False), True), + StructField("values", ArrayType(DoubleType(), False), True)]) + + @classmethod + def module(cls): + return "pyspark.mllib.linalg" + + @classmethod + def scalaUDT(cls): + return "org.apache.spark.mllib.linalg.VectorUDT" + + def serialize(self, obj): + if isinstance(obj, SparseVector): + indices = [int(i) for i in obj.indices] + values = [float(v) for v in obj.values] + return (0, obj.size, indices, values) + elif isinstance(obj, DenseVector): + values = [float(v) for v in obj] + return (1, None, None, values) + else: + raise ValueError("cannot serialize %r of type %r" % (obj, type(obj))) + + def deserialize(self, datum): + assert len(datum) == 4, \ + "VectorUDT.deserialize given row with length %d but requires 4" % len(datum) + tpe = datum[0] + if tpe == 0: + return SparseVector(datum[1], datum[2], datum[3]) + elif tpe == 1: + return DenseVector(datum[3]) + else: + raise ValueError("do not recognize type %r" % tpe) + + class Vector(object): + + __UDT__ = VectorUDT() + """ Abstract class for DenseVector and SparseVector """ -- cgit v1.2.3