aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2016-04-29 23:51:01 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-29 23:51:01 -0700
commit3d09ceeef9212d4f3a8cd286ce369ace47242358 (patch)
tree30d4592df6c79f0c00862bb17f5a54d8a22420c6
parent73c20bf32524c2232febc8c4b12d5fa228347163 (diff)
downloadspark-3d09ceeef9212d4f3a8cd286ce369ace47242358.tar.gz
spark-3d09ceeef9212d4f3a8cd286ce369ace47242358.tar.bz2
spark-3d09ceeef9212d4f3a8cd286ce369ace47242358.zip
[SPARK-14850][.2][ML] use UnsafeArrayData.fromPrimitiveArray in ml.VectorUDT/MatrixUDT
## What changes were proposed in this pull request? This PR uses `UnsafeArrayData.fromPrimitiveArray` to implement `ml.VectorUDT/MatrixUDT` to avoid boxing/unboxing. ## How was this patch tested? Exiting unit tests. cc: cloud-fan Author: Xiangrui Meng <meng@databricks.com> Closes #12805 from mengxr/SPARK-14850.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala9
2 files changed, 9 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
index 53f4d55971..521a216c67 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/MatrixUDT.scala
@@ -18,8 +18,7 @@
package org.apache.spark.ml.linalg
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
import org.apache.spark.sql.types._
/**
@@ -53,9 +52,9 @@ private[ml] class MatrixUDT extends UserDefinedType[Matrix] {
row.setByte(0, 0)
row.setInt(1, sm.numRows)
row.setInt(2, sm.numCols)
- row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
- row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
- row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
+ row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
+ row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
+ row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
row.setBoolean(6, sm.isTransposed)
case dm: DenseMatrix =>
@@ -64,7 +63,7 @@ private[ml] class MatrixUDT extends UserDefinedType[Matrix] {
row.setInt(2, dm.numCols)
row.setNullAt(3)
row.setNullAt(4)
- row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
+ row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
row.setBoolean(6, dm.isTransposed)
}
row
diff --git a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala
index fe93a12d06..c29f7f86e9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/linalg/VectorUDT.scala
@@ -18,8 +18,7 @@
package org.apache.spark.ml.linalg
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
import org.apache.spark.sql.types._
/**
@@ -46,15 +45,15 @@ private[ml] class VectorUDT extends UserDefinedType[Vector] {
val row = new GenericMutableRow(4)
row.setByte(0, 0)
row.setInt(1, size)
- row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
- row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+ row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
+ row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
row
case DenseVector(values) =>
val row = new GenericMutableRow(4)
row.setByte(0, 1)
row.setNullAt(1)
row.setNullAt(2)
- row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+ row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
row
}
}