aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala
diff options
context:
space:
mode:
author: Wenchen Fan <wenchen@databricks.com> 2016-04-29 23:04:51 -0700
committer: Xiangrui Meng <meng@databricks.com> 2016-04-29 23:04:51 -0700
commit: 43b149fb885a27f9467aab28e5195f6f03aadcf0 (patch)
tree: c8620d5d0f42e9f3238020e3bce8f8ea527182eb /mllib/src/main/scala
parent: 4bac703eb9dcc286d6b89630cf433f95b63a4a1f (diff)
download: spark-43b149fb885a27f9467aab28e5195f6f03aadcf0.tar.gz
spark-43b149fb885a27f9467aab28e5195f6f03aadcf0.tar.bz2
spark-43b149fb885a27f9467aab28e5195f6f03aadcf0.zip
[SPARK-14850][ML] convert primitive array from/to unsafe array directly in VectorUDT/MatrixUDT
## What changes were proposed in this pull request?

This PR adds `fromPrimitiveArray` and `toPrimitiveArray` in `UnsafeArrayData`, so that we can do the conversion much faster in VectorUDT/MatrixUDT.

## How was this patch tested?

Existing tests and the new test suite `UnsafeArraySuite`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #12640 from cloud-fan/ml.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala 11
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala 9
2 files changed, 9 insertions, 11 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 90fa4fbbc6..076cca6016 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -27,8 +27,7 @@ import com.github.fommil.netlib.BLAS.{getInstance => blas}
import org.apache.spark.annotation.Since
import org.apache.spark.ml.{linalg => newlinalg}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
import org.apache.spark.sql.types._
/**
@@ -194,9 +193,9 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setByte(0, 0)
row.setInt(1, sm.numRows)
row.setInt(2, sm.numCols)
- row.update(3, new GenericArrayData(sm.colPtrs.map(_.asInstanceOf[Any])))
- row.update(4, new GenericArrayData(sm.rowIndices.map(_.asInstanceOf[Any])))
- row.update(5, new GenericArrayData(sm.values.map(_.asInstanceOf[Any])))
+ row.update(3, UnsafeArrayData.fromPrimitiveArray(sm.colPtrs))
+ row.update(4, UnsafeArrayData.fromPrimitiveArray(sm.rowIndices))
+ row.update(5, UnsafeArrayData.fromPrimitiveArray(sm.values))
row.setBoolean(6, sm.isTransposed)
case dm: DenseMatrix =>
@@ -205,7 +204,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
row.setInt(2, dm.numCols)
row.setNullAt(3)
row.setNullAt(4)
- row.update(5, new GenericArrayData(dm.values.map(_.asInstanceOf[Any])))
+ row.update(5, UnsafeArrayData.fromPrimitiveArray(dm.values))
row.setBoolean(6, dm.isTransposed)
}
row
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 6e3da6b701..132e54a8c3 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -33,8 +33,7 @@ import org.apache.spark.annotation.{AlphaComponent, Since}
import org.apache.spark.ml.{linalg => newlinalg}
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, UnsafeArrayData}
import org.apache.spark.sql.types._
/**
@@ -216,15 +215,15 @@ class VectorUDT extends UserDefinedType[Vector] {
val row = new GenericMutableRow(4)
row.setByte(0, 0)
row.setInt(1, size)
- row.update(2, new GenericArrayData(indices.map(_.asInstanceOf[Any])))
- row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+ row.update(2, UnsafeArrayData.fromPrimitiveArray(indices))
+ row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
row
case DenseVector(values) =>
val row = new GenericMutableRow(4)
row.setByte(0, 1)
row.setNullAt(1)
row.setNullAt(2)
- row.update(3, new GenericArrayData(values.map(_.asInstanceOf[Any])))
+ row.update(3, UnsafeArrayData.fromPrimitiveArray(values))
row
}
}