diff options
author | Wenchen Fan <cloud0fan@outlook.com> | 2015-07-30 10:04:30 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2015-07-30 10:04:30 -0700 |
commit | c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08 (patch) | |
tree | 582bad5631cde3bac3b5c69e1f22b3c4098de684 /mllib | |
parent | 7492a33fdd074446c30c657d771a69932a00246d (diff) | |
download | spark-c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08.tar.gz spark-c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08.tar.bz2 spark-c0cc0eaec67208c087a30c1b1f50c00b2c1ebf08.zip |
[SPARK-9390][SQL] create a wrapper for array type
Author: Wenchen Fan <cloud0fan@outlook.com>
Closes #7724 from cloud-fan/array-data and squashes the following commits:
d0408a1 [Wenchen Fan] fix python
661e608 [Wenchen Fan] rebase
f39256c [Wenchen Fan] fix hive...
6dbfa6f [Wenchen Fan] fix hive again...
8cb8842 [Wenchen Fan] remove element type parameter from getArray
43e9816 [Wenchen Fan] fix mllib
e719afc [Wenchen Fan] fix hive
4346290 [Wenchen Fan] address comment
d4a38da [Wenchen Fan] remove sizeInBytes and add license
7e283e2 [Wenchen Fan] create a wrapper for array type
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala | 16 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 15 |
2 files changed, 13 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala index d82ba2456d..88914fa875 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala @@ -154,9 +154,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setByte(0, 0) row.setInt(1, sm.numRows) row.setInt(2, sm.numCols) - row.update(3, sm.colPtrs.toSeq) - row.update(4, sm.rowIndices.toSeq) - row.update(5, sm.values.toSeq) + row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any]))) + row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any]))) + row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, sm.isTransposed) case dm: DenseMatrix => @@ -165,7 +165,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { row.setInt(2, dm.numCols) row.setNullAt(3) row.setNullAt(4) - row.update(5, dm.values.toSeq) + row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any]))) row.setBoolean(6, dm.isTransposed) } row @@ -179,14 +179,12 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] { val tpe = row.getByte(0) val numRows = row.getInt(1) val numCols = row.getInt(2) - val values = row.getAs[Seq[Double]](5, ArrayType(DoubleType, containsNull = false)).toArray + val values = row.getArray(5).toArray.map(_.asInstanceOf[Double]) val isTransposed = row.getBoolean(6) tpe match { case 0 => - val colPtrs = - row.getAs[Seq[Int]](3, ArrayType(IntegerType, containsNull = false)).toArray - val rowIndices = - row.getAs[Seq[Int]](4, ArrayType(IntegerType, containsNull = false)).toArray + val colPtrs = row.getArray(3).toArray.map(_.asInstanceOf[Int]) + val rowIndices = row.getArray(4).toArray.map(_.asInstanceOf[Int]) new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, isTransposed) case 1 => new DenseMatrix(numRows, numCols, values, isTransposed) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 23c2c16d68..89a1818db0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -187,15 +187,15 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) - row.update(2, indices.toSeq) - row.update(3, values.toSeq) + row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any]))) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row case DenseVector(values) => val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) - row.update(3, values.toSeq) + row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) row } } @@ -209,14 +209,11 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] { tpe match { case 0 => val size = row.getInt(1) - val indices = - row.getAs[Seq[Int]](2, ArrayType(IntegerType, containsNull = false)).toArray - val values = - row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray + val indices = row.getArray(2).toArray().map(_.asInstanceOf[Int]) + val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) new SparseVector(size, indices, values) case 1 => - val values = - row.getAs[Seq[Double]](3, ArrayType(DoubleType, containsNull = false)).toArray + val values = row.getArray(3).toArray().map(_.asInstanceOf[Double]) new DenseVector(values) } } |