diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-28 12:03:46 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-28 12:03:46 -0700 |
commit | 530efe3e80c62b25c869b85167e00330eb1ddea6 (patch) | |
tree | ff3ef6ff89d24c21e5a8aaea24146d21466893de /mllib | |
parent | 000df2f0d6af068bb188e81bbb207f0c2f43bf16 (diff) | |
download | spark-530efe3e80c62b25c869b85167e00330eb1ddea6.tar.gz spark-530efe3e80c62b25c869b85167e00330eb1ddea6.tar.bz2 spark-530efe3e80c62b25c869b85167e00330eb1ddea6.zip |
[SPARK-7911] [MLLIB] A workaround for VectorUDT serialize (or deserialize) being called multiple times
~~A PythonUDT shouldn't be serialized into external Scala types in PythonRDD. I'm not sure whether this should fix one of the bugs related to SQL UDT/UDF in PySpark.~~
The fix above didn't work. So I added a workaround for this. If a Python UDF is applied to a Python UDT. This will put the Python SQL types as inputs. Still incorrect, but at least it doesn't throw exceptions on the Scala side. davies harsha2010
Author: Xiangrui Meng <meng@databricks.com>
Closes #6442 from mengxr/SPARK-7903 and squashes the following commits:
c257d2a [Xiangrui Meng] add a workaround for VectorUDT
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 19 |
1 files changed, 14 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index f6bcdf83cd..2ffa497a99 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { } override def serialize(obj: Any): Row = { - val row = new GenericMutableRow(4) obj match { case SparseVector(size, indices, values) => + val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) row.update(2, indices.toSeq) row.update(3, values.toSeq) + row case DenseVector(values) => + val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) row.update(3, values.toSeq) + row + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. + case row: Row => + row } - row } override def deserialize(datum: Any): Vector = { datum match { - // TODO: something wrong with UDT serialization - case v: Vector => - v case row: Row => require(row.length == 4, s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4") @@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val values = row.getAs[Iterable[Double]](3).toArray new DenseVector(values) } + // TODO: There are bugs in UDT serialization because we don't have a clear separation between + // TODO: internal SQL types and language specific types (including UDT). UDT serialize and + // TODO: deserialize may get called twice. See SPARK-7186. + case v: Vector => + v } } |