diff options
author | Xiangrui Meng <meng@databricks.com> | 2016-04-29 23:51:01 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-29 23:51:01 -0700 |
commit | 3d09ceeef9212d4f3a8cd286ce369ace47242358 (patch) | |
tree | 30d4592df6c79f0c00862bb17f5a54d8a22420c6 /mllib/src | |
parent | 73c20bf32524c2232febc8c4b12d5fa228347163 (diff) | |
download | spark-3d09ceeef9212d4f3a8cd286ce369ace47242358.tar.gz spark-3d09ceeef9212d4f3a8cd286ce369ace47242358.tar.bz2 spark-3d09ceeef9212d4f3a8cd286ce369ace47242358.zip |
[SPARK-14850][.2][ML] use UnsafeArrayData.fromPrimitiveArray in ml.VectorUDT/MatrixUDT
## What changes were proposed in this pull request?
This PR uses `UnsafeArrayData.fromPrimitiveArray` to implement `ml.VectorUDT/MatrixUDT` to avoid boxing/unboxing.
## How was this patch tested?
Exiting unit tests.
cc: cloud-fan
Author: Xiangrui Meng <meng@databricks.com>
Closes #12805 from mengxr/SPARK-14850.
Diffstat (limited to 'mllib/src')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala | 11 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala | 9 |
2 files changed, 9 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala index 53f4d55971..521a216c67 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala @@ -18,8 +18,7 @@ package org.apache.spark.ml.linalg import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -53,9 +52,9 @@ private[ml] class MatrixUDT extends UserDefinedType[Matrix] { row.setByte(0, 0) row.setInt(1, sm.numRows) row.setInt(2, sm.numCols) - row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any]))) - row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any]))) - row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any]))) + row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs)) + row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices)) + row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values)) row.setBoolean(6, sm.isTransposed) case dm: DenseMatrix => @@ -64,7 +63,7 @@ private[ml] class MatrixUDT extends UserDefinedType[Matrix] { row.setInt(2, dm.numCols) row.setNullAt(3) row.setNullAt(4) - row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any]))) + row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values)) row.setBoolean(6, dm.isTransposed) } row diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala index fe93a12d06..c29f7f86e9 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala @@ -18,8 +18,7 @@ package org.apache.spark.ml.linalg import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.GenericMutableRow -import org.apache.spark.sql.catalyst.util.GenericArrayData +import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData} import org.apache.spark.sql.types._ /** @@ -46,15 +45,15 @@ private[ml] class VectorUDT extends UserDefinedType[Vector] { val row = new GenericMutableRow(4) row.setByte(0, 0) row.setInt(1, size) - row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any]))) - row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) + row.update(2, UnsafeArrayData.fromPrimitiveArray(indices)) + row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row case DenseVector(values) => val row = new GenericMutableRow(4) row.setByte(0, 1) row.setNullAt(1) row.setNullAt(2) - row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any]))) + row.update(3, UnsafeArrayData.fromPrimitiveArray(values)) row } } |