author     Davies Liu <davies@databricks.com>  2014-11-24 16:37:14 -0800
committer  Xiangrui Meng <meng@databricks.com>  2014-11-24 16:37:14 -0800
commit     b660de7a9cbdea3df4a37fbcf60c1c33c71782b8 (patch)
tree       34bec7b4e6789d63846f7519e1c9d51351033b61 /mllib
parent     cb0e9b0980f38befe88bf52aa037fe33262730f7 (diff)
[SPARK-4562] [MLlib] speedup vector
This PR changes the underlying array of DenseVector to numpy.ndarray to avoid conversion, since most users work with numpy.array. It also improves the serialization of DenseVector.

Before this change:

trial | trainingTime | testTime
------|--------------|---------
0     | 5.126        | 1.786
1     | 2.698        | 1.693

After the change:

trial | trainingTime | testTime
------|--------------|---------
0     | 4.692        | 0.554
1     | 2.307        | 0.525

This partially fixes the performance regression seen during testing.

Author: Davies Liu <davies@databricks.com>

Closes #3420 from davies/ser2 and squashes the following commits:

0e1e6f3 [Davies Liu] fix tests
426f5db [Davies Liu] improve toArray()
44707ec [Davies Liu] add name for ISO-8859-1
fa7d791 [Davies Liu] address comments
1cfb137 [Davies Liu] handle zero sparse vector
2548ee2 [Davies Liu] fix tests
9e6389d [Davies Liu] bugfix
470f702 [Davies Liu] speed up DenseMatrix
f0d3c40 [Davies Liu] speed up SparseVector
ef6ce70 [Davies Liu] speed up dense vector
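The core of the change, visible in the diff below, is copying the whole double array through a native-order ByteBuffer instead of pickling it element by element, so the Python side can wrap the resulting bytes in a numpy.ndarray without per-element conversion. A minimal, self-contained sketch of that round trip (the object and method names here are illustrative, not part of the patch):

import java.nio.{ByteBuffer, ByteOrder}

object PackingSketch {
  // Pack doubles into a native-order byte array, as the new saveState does.
  def toBytes(values: Array[Double]): Array[Byte] = {
    val bytes = new Array[Byte](8 * values.length)
    ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asDoubleBuffer().put(values)
    bytes
  }

  // Recover the doubles from the raw bytes, as the new construct does.
  def fromBytes(bytes: Array[Byte]): Array[Double] = {
    val values = new Array[Double](bytes.length / 8)
    ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asDoubleBuffer().get(values)
    values
  }

  def main(args: Array[String]): Unit = {
    val xs = Array(1.0, 2.5, -3.75)
    assert(fromBytes(toBytes(xs)).sameElements(xs))
  }
}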
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala  73
1 file changed, 65 insertions(+), 8 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f04df1c156..9f20cd5d00 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.api.python
import java.io.OutputStream
+import java.nio.{ByteBuffer, ByteOrder}
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
@@ -684,6 +685,7 @@ class PythonMLLibAPI extends Serializable {
private[spark] object SerDe extends Serializable {
val PYSPARK_PACKAGE = "pyspark.mllib"
+ val LATIN1 = "ISO-8859-1"
/**
* Base class used for pickle
@@ -735,7 +737,16 @@ private[spark] object SerDe extends Serializable {
def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
val vector: DenseVector = obj.asInstanceOf[DenseVector]
- saveObjects(out, pickler, vector.toArray)
+ val bytes = new Array[Byte](8 * vector.size)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ val db = bb.asDoubleBuffer()
+ db.put(vector.values)
+
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(bytes.length))
+ out.write(bytes)
+ out.write(Opcodes.TUPLE1)
}
def construct(args: Array[Object]): Object = {
@@ -743,7 +754,13 @@ private[spark] object SerDe extends Serializable {
if (args.length != 1) {
throw new PickleException("should be 1")
}
- new DenseVector(args(0).asInstanceOf[Array[Double]])
+ val bytes = args(0).asInstanceOf[String].getBytes(LATIN1)
+ val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
+ bb.order(ByteOrder.nativeOrder())
+ val db = bb.asDoubleBuffer()
+ val ans = new Array[Double](bytes.length / 8)
+ db.get(ans)
+ Vectors.dense(ans)
}
}
@@ -752,15 +769,30 @@ private[spark] object SerDe extends Serializable {
def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
- saveObjects(out, pickler, m.numRows, m.numCols, m.values)
+ val bytes = new Array[Byte](8 * m.values.size)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
+
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(m.numRows))
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(m.numCols))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(bytes.length))
+ out.write(bytes)
+ out.write(Opcodes.TUPLE3)
}
def construct(args: Array[Object]): Object = {
if (args.length != 3) {
throw new PickleException("should be 3")
}
- new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
- args(2).asInstanceOf[Array[Double]])
+ val bytes = args(2).asInstanceOf[String].getBytes(LATIN1)
+ val n = bytes.length / 8
+ val values = new Array[Double](n)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
+ new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values)
}
}
@@ -769,15 +801,40 @@ private[spark] object SerDe extends Serializable {
def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
val v: SparseVector = obj.asInstanceOf[SparseVector]
- saveObjects(out, pickler, v.size, v.indices, v.values)
+ val n = v.indices.size
+ val indiceBytes = new Array[Byte](4 * n)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
+ val valueBytes = new Array[Byte](8 * n)
+ ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)
+
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(v.size))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
+ out.write(indiceBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(valueBytes.length))
+ out.write(valueBytes)
+ out.write(Opcodes.TUPLE3)
}
def construct(args: Array[Object]): Object = {
if (args.length != 3) {
throw new PickleException("should be 3")
}
- new SparseVector(args(0).asInstanceOf[Int], args(1).asInstanceOf[Array[Int]],
- args(2).asInstanceOf[Array[Double]])
+ val size = args(0).asInstanceOf[Int]
+ val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1)
+ val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1)
+ val n = indiceBytes.length / 4
+ val indices = new Array[Int](n)
+ val values = new Array[Double](n)
+ if (n > 0) {
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
+ ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
+ }
+ new SparseVector(size, indices, values)
}
}
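A note on the LATIN1 ("ISO-8859-1") constant used by every construct() above: Pyrolite delivers BINSTRING payloads to the JVM as a String (hence the asInstanceOf[String] casts), and ISO-8859-1 maps each byte 0x00-0xFF to the Unicode code point of the same value, so getBytes(LATIN1) recovers the original bytes losslessly. A standalone sketch of why that round trip is safe (not part of the patch):

import java.nio.charset.StandardCharsets

object Latin1Roundtrip {
  def main(args: Array[String]): Unit = {
    // ISO-8859-1 is a byte-to-code-point identity mapping for 0x00..0xFF,
    // so arbitrary binary data survives a trip through String unchanged.
    val raw = Array.tabulate[Byte](256)(_.toByte)
    val s = new String(raw, StandardCharsets.ISO_8859_1)
    assert(s.getBytes(StandardCharsets.ISO_8859_1).sameElements(raw))
  }
}

The same property makes the int-encoded indices of SparseVector round-trip safely, not just the doubles.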