author     Xiangrui Meng <meng@databricks.com>    2015-05-28 12:03:46 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-05-28 12:03:46 -0700
commit     530efe3e80c62b25c869b85167e00330eb1ddea6 (patch)
tree       ff3ef6ff89d24c21e5a8aaea24146d21466893de /mllib
parent     000df2f0d6af068bb188e81bbb207f0c2f43bf16 (diff)
[SPARK-7911] [MLLIB] A workaround for VectorUDT serialize (or deserialize) being called multiple times
~~A PythonUDT shouldn't be serialized into external Scala types in PythonRDD. I'm not sure whether this should fix one of the bugs related to SQL UDT/UDF in PySpark.~~ The fix above didn't work, so I added a workaround: if a Python UDF is applied to a Python UDT, the Python SQL types are passed in as inputs. This is still incorrect, but at least it no longer throws exceptions on the Scala side. davies harsha2010

Author: Xiangrui Meng <meng@databricks.com>

Closes #6442 from mengxr/SPARK-7903 and squashes the following commits:

c257d2a [Xiangrui Meng] add a workaround for VectorUDT
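For readers unfamiliar with the failure mode, here is a minimal, self-contained sketch of the pattern this patch adopts. It is not Spark's actual `VectorUDT`; the `Vec`/`Dense`/`Sparse` types and the `Seq[Any]` row stand-in are hypothetical. The point is that `serialize` and `deserialize` each gain a pass-through arm, so a spurious second invocation returns its input unchanged instead of throwing a `MatchError`:

```scala
// Minimal model of the SPARK-7911 workaround. `Vec`, `Dense`, `Sparse`,
// and the Seq[Any] "row" are stand-ins, not Spark classes.
object UDTWorkaroundSketch {
  sealed trait Vec
  final case class Dense(values: Seq[Double]) extends Vec
  final case class Sparse(size: Int, indices: Seq[Int], values: Seq[Double]) extends Vec

  def serialize(obj: Any): Seq[Any] = obj match {
    case Sparse(size, indices, values) => Seq(0.toByte, size, indices, values)
    case Dense(values)                 => Seq(1.toByte, null, null, values)
    // Workaround: serialize may be called on an already-serialized row
    // (see SPARK-7186); pass it through instead of failing.
    case row: Seq[_] => row
  }

  def deserialize(datum: Any): Vec = datum match {
    case Seq(tpe: Byte, size, indices, values) =>
      if (tpe == 0)
        Sparse(size.asInstanceOf[Int],
               indices.asInstanceOf[Seq[Int]],
               values.asInstanceOf[Seq[Double]])
      else Dense(values.asInstanceOf[Seq[Double]])
    // Workaround: deserialize may be handed a value that was never
    // serialized (or was already deserialized); return it as-is.
    case v: Vec => v
  }

  def main(args: Array[String]): Unit = {
    val v = Dense(Seq(1.0, 2.0))
    val row = serialize(v)
    assert(serialize(row) == row) // double serialize: now a no-op
    assert(deserialize(row) == v) // normal round trip
    assert(deserialize(v) == v)   // double deserialize: also a no-op
  }
}
```

The pass-through arms make both conversions idempotent, which is why the workaround tolerates the double calls tracked in SPARK-7186 without fixing their root cause.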
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 19
1 file changed, 14 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index f6bcdf83cd..2ffa497a99 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -176,27 +176,31 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
   }
 
   override def serialize(obj: Any): Row = {
-    val row = new GenericMutableRow(4)
     obj match {
       case SparseVector(size, indices, values) =>
+        val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
         row.update(2, indices.toSeq)
         row.update(3, values.toSeq)
+        row
       case DenseVector(values) =>
+        val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
         row.update(3, values.toSeq)
+        row
+      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+      // TODO: deserialize may get called twice. See SPARK-7186.
+      case row: Row =>
+        row
     }
-    row
   }
 
   override def deserialize(datum: Any): Vector = {
     datum match {
-      // TODO: something wrong with UDT serialization
-      case v: Vector =>
-        v
       case row: Row =>
         require(row.length == 4,
           s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
@@ -211,6 +215,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
             val values = row.getAs[Iterable[Double]](3).toArray
             new DenseVector(values)
         }
+      // TODO: There are bugs in UDT serialization because we don't have a clear separation between
+      // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
+      // TODO: deserialize may get called twice. See SPARK-7186.
+      case v: Vector =>
+        v
     }
   }
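For contrast, a sketch of what happened before this patch, under the same hypothetical model as above (not Spark code): with no pass-through arm, the second `serialize` call falls off the end of the `match` and throws `scala.MatchError`, which is the Scala-side exception this commit suppresses:

```scala
// Hypothetical pre-patch model: no pass-through arm in serialize.
object PrePatchSketch {
  final case class Dense(values: Seq[Double])

  def serializeOld(obj: Any): Seq[Any] = obj match {
    case Dense(values) => Seq(1.toByte, null, null, values)
    // No `case row` arm: a second call has nothing to match.
  }

  def main(args: Array[String]): Unit = {
    val row = serializeOld(Dense(Seq(1.0)))
    try serializeOld(row) // serialize called twice, as in SPARK-7186
    catch { case e: MatchError => println(s"pre-patch failure: $e") }
  }
}
```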