author    Wenchen Fan <cloud0fan@outlook.com>    2015-07-30 10:04:30 -0700
committer Reynold Xin <rxin@databricks.com>      2015-07-30 10:04:30 -0700
commit    c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08 (patch)
tree      582bad5631cde3bac3b5c69e1f22b3c4098de684 /mllib
parent    7492a33fdd074446c30c657d771a69932a00246d (diff)
download  spark-c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08.tar.gz
          spark-c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08.tar.bz2
          spark-c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08.zip
[SPARK-9390][SQL] create a wrapper for array type
Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #7724 from cloud-fan/array-data and squashes the following commits:

d0408a1 [Wenchen Fan] fix python
661e608 [Wenchen Fan] rebase
f39256c [Wenchen Fan] fix hive...
6dbfa6f [Wenchen Fan] fix hive again...
8cb8842 [Wenchen Fan] remove element type parameter from getArray
43e9816 [Wenchen Fan] fix mllib
e719afc [Wenchen Fan] fix hive
4346290 [Wenchen Fan] address comment
d4a38da [Wenchen Fan] remove sizeInBytes and add license
7e283e2 [Wenchen Fan] create a wrapper for array type
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala  16
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala   15
2 files changed, 13 insertions, 18 deletions
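
For orientation before the hunks: on the write path the patch replaces the Seq-based array fields in MatrixUDT and VectorUDT with the new GenericArrayData wrapper, and on the read path it fetches them back through row.getArray(...). The sketch below illustrates that round trip for the dense-vector layout only; it is not the patch itself, and the class and package names (GenericInternalRow, org.apache.spark.sql.catalyst.util.GenericArrayData) are assumptions taken from later Spark releases, whereas this commit still uses GenericMutableRow.

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
import org.apache.spark.sql.catalyst.util.GenericArrayData

object DenseVectorRowSketch {
  // Write path: box the primitive values and wrap them in GenericArrayData
  // instead of converting them to a Seq, mirroring the + lines in the diff below.
  def serialize(values: Array[Double]): InternalRow = {
    val row = new GenericInternalRow(4)
    row.setByte(0, 1)   // type tag: 1 = dense
    row.setNullAt(1)    // size field: unused for dense vectors
    row.setNullAt(2)    // indices field: unused for dense vectors
    row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
    row
  }

  // Read path: getArray returns the ArrayData wrapper; toDoubleArray() is the
  // typed accessor it exposes in later Spark releases.
  def deserialize(row: InternalRow): Array[Double] =
    row.getArray(3).toDoubleArray()
}

The hunks below apply exactly this Seq-to-wrapper swap to the sparse and dense branches of both UDTs.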
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index d82ba2456d..88914fa875 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setByte(0, 0)
         row.setInt(1, sm.numRows)
         row.setInt(2, sm.numCols)
-        row.update(3, sm.colPtrs.toSeq)
-        row.update(4, sm.rowIndices.toSeq)
-        row.update(5, sm.values.toSeq)
+        row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
+        row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
+        row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
         row.setBoolean(6, sm.isTransposed)
 
       case dm: DenseMatrix =>
@@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         row.setInt(2, dm.numCols)
         row.setNullAt(3)
         row.setNullAt(4)
-        row.update(5, dm.values.toSeq)
+        row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
         row.setBoolean(6, dm.isTransposed)
     }
     row
@@ -179,14 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
         val tpe = row.getByte(0)
         val numRows = row.getInt(1)
         val numCols = row.getInt(2)
-        val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray
+        val values = row.getArray(5).toArray.map(_.asInstanceOf[Double])
         val isTransposed = row.getBoolean(6)
         tpe match {
           case 0 =>
-            val colPtrs =
-              row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray
-            val rowIndices =
-              row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray
+            val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int])
+            val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int])
             new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed)
           case 1 =>
             new DenseMatrix(numRows, numCols, values, isTransposed)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 23c2c16d68..89a1818db0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -187,15 +187,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
         val row = new GenericMutableRow(4)
         row.setByte(0, 0)
         row.setInt(1, size)
-        row.update(2, indices.toSeq)
-        row.update(3, values.toSeq)
+        row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
+        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
         row
       case DenseVector(values) =>
         val row = new GenericMutableRow(4)
         row.setByte(0, 1)
         row.setNullAt(1)
         row.setNullAt(2)
-        row.update(3, values.toSeq)
+        row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
         row
     }
   }
@@ -209,14 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
         tpe match {
           case 0 =>
             val size = row.getInt(1)
-            val indices =
-              row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray
-            val values =
-              row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
+            val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int])
+            val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
             new SparseVector(size, indices, values)
           case 1 =>
-            val values =
-              row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray
+            val values = row.getArray(3).toArray().map(_.asInstanceOf[Double])
             new DenseVector(values)
         }
     }
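
One follow-up note on the read path: the element-wise asInstanceOf casts in the deserializers come from the parameterless toArray this patch relies on (note the "remove element type parameter from getArray" squash entry). Purely as a hedged illustration, and assuming the ArrayData accessors of later Spark releases rather than this commit's tree, the sparse-vector fields could be read through typed accessors instead:

import org.apache.spark.sql.catalyst.InternalRow

object SparseVectorReadSketch {
  // Mirrors the sparse branch of VectorUDT.deserialize: field 2 holds indices,
  // field 3 holds values, per the row layout in the diff above.
  // toIntArray()/toDoubleArray() are ArrayData accessors in later Spark releases
  // and may not exist in the tree this commit targets.
  def readSparseFields(row: InternalRow): (Array[Int], Array[Double]) = {
    val indices = row.getArray(2).toIntArray()
    val values = row.getArray(3).toDoubleArray()
    (indices, values)
  }
}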