Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala  73
1 file changed, 65 insertions(+), 8 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index f04df1c156..9f20cd5d00 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.api.python
import java.io.OutputStream
+import java.nio.{ByteBuffer, ByteOrder}
import java.util.{ArrayList => JArrayList, List => JList, Map => JMap}
import scala.collection.JavaConverters._
@@ -684,6 +685,7 @@ class PythonMLLibAPI extends Serializable {
private[spark] object SerDe extends Serializable {
val PYSPARK_PACKAGE = "pyspark.mllib"
+ val LATIN1 = "ISO-8859-1"
/**
* Base class used for pickle
@@ -735,7 +737,16 @@ private[spark] object SerDe extends Serializable {
def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
val vector: DenseVector = obj.asInstanceOf[DenseVector]
- saveObjects(out, pickler, vector.toArray)
+ val bytes = new Array[Byte](8 * vector.size)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ val db = bb.asDoubleBuffer()
+ db.put(vector.values)
+
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(bytes.length))
+ out.write(bytes)
+ out.write(Opcodes.TUPLE1)
}
def construct(args: Array[Object]): Object = {
@@ -743,7 +754,13 @@ private[spark] object SerDe extends Serializable {
if (args.length != 1) {
throw new PickleException("should be 1")
}
- new DenseVector(args(0).asInstanceOf[Array[Double]])
+ val bytes = args(0).asInstanceOf[String].getBytes(LATIN1)
+ val bb = ByteBuffer.wrap(bytes, 0, bytes.length)
+ bb.order(ByteOrder.nativeOrder())
+ val db = bb.asDoubleBuffer()
+ val ans = new Array[Double](bytes.length / 8)
+ db.get(ans)
+ Vectors.dense(ans)
}
}
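For context (not part of the patch): the new DenseVector pickler relies on packing the doubles into a byte array in native byte order and reading them back the same way on construction. A minimal standalone sketch of that roundtrip, using only java.nio (the names below are illustrative, not from the patch):

    import java.nio.{ByteBuffer, ByteOrder}

    // Pack a double array into raw bytes, native byte order (as saveState does).
    val values = Array(1.0, 2.5, -3.75)
    val bytes = new Array[Byte](8 * values.length)
    ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asDoubleBuffer().put(values)

    // Read the doubles back (as construct does), 8 bytes per element.
    val decoded = new Array[Double](bytes.length / 8)
    ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder()).asDoubleBuffer().get(decoded)
    assert(decoded.sameElements(values))

The same packing is reused below for DenseMatrix values and for SparseVector indices/values.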
@@ -752,15 +769,30 @@ private[spark] object SerDe extends Serializable {
def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
- saveObjects(out, pickler, m.numRows, m.numCols, m.values)
+ val bytes = new Array[Byte](8 * m.values.size)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().put(m.values)
+
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(m.numRows))
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(m.numCols))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(bytes.length))
+ out.write(bytes)
+ out.write(Opcodes.TUPLE3)
}
def construct(args: Array[Object]): Object = {
if (args.length != 3) {
throw new PickleException("should be 3")
}
- new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
- args(2).asInstanceOf[Array[Double]])
+ val bytes = args(2).asInstanceOf[String].getBytes(LATIN1)
+ val n = bytes.length / 8
+ val values = new Array[Double](n)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(bytes).order(order).asDoubleBuffer().get(values)
+ new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int], values)
}
}
@@ -769,15 +801,40 @@ private[spark] object SerDe extends Serializable {
def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
val v: SparseVector = obj.asInstanceOf[SparseVector]
- saveObjects(out, pickler, v.size, v.indices, v.values)
+ val n = v.indices.size
+ val indiceBytes = new Array[Byte](4 * n)
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().put(v.indices)
+ val valueBytes = new Array[Byte](8 * n)
+ ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(v.values)
+
+ out.write(Opcodes.BININT)
+ out.write(PickleUtils.integer_to_bytes(v.size))
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(indiceBytes.length))
+ out.write(indiceBytes)
+ out.write(Opcodes.BINSTRING)
+ out.write(PickleUtils.integer_to_bytes(valueBytes.length))
+ out.write(valueBytes)
+ out.write(Opcodes.TUPLE3)
}
def construct(args: Array[Object]): Object = {
if (args.length != 3) {
throw new PickleException("should be 3")
}
- new SparseVector(args(0).asInstanceOf[Int], args(1).asInstanceOf[Array[Int]],
- args(2).asInstanceOf[Array[Double]])
+ val size = args(0).asInstanceOf[Int]
+ val indiceBytes = args(1).asInstanceOf[String].getBytes(LATIN1)
+ val valueBytes = args(2).asInstanceOf[String].getBytes(LATIN1)
+ val n = indiceBytes.length / 4
+ val indices = new Array[Int](n)
+ val values = new Array[Double](n)
+ if (n > 0) {
+ val order = ByteOrder.nativeOrder()
+ ByteBuffer.wrap(indiceBytes).order(order).asIntBuffer().get(indices)
+ ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().get(values)
+ }
+ new SparseVector(size, indices, values)
}
}
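A note on the construct() side (again illustrative, not part of the patch): these methods assume the BINSTRING payload comes back from the unpickler as a java.lang.String whose characters map one-to-one to the original bytes, which is why getBytes(LATIN1) recovers them losslessly; ISO-8859-1 maps every byte value 0-255 to the code point of the same value. A small sketch of that assumption, plus the 4-byte int / 8-byte double widths used by the SparseVector pickler:

    import java.nio.{ByteBuffer, ByteOrder}

    val LATIN1 = "ISO-8859-1"

    // Latin-1 maps bytes 0-255 to chars 0-255 one-to-one, so a String built
    // from arbitrary bytes yields the same bytes back.
    val raw = Array.tabulate[Byte](256)(i => i.toByte)
    val asString = new String(raw, LATIN1)  // stand-in for the String handed to construct()
    assert(asString.getBytes(LATIN1).sameElements(raw))

    // SparseVector packs indices as 4-byte ints and values as 8-byte doubles,
    // so for n entries the two payloads are 4*n and 8*n bytes respectively.
    val indices = Array(0, 3, 7)
    val values = Array(1.0, 2.0, 3.0)
    val order = ByteOrder.nativeOrder()
    val indexBytes = new Array[Byte](4 * indices.length)
    ByteBuffer.wrap(indexBytes).order(order).asIntBuffer().put(indices)
    val valueBytes = new Array[Byte](8 * values.length)
    ByteBuffer.wrap(valueBytes).order(order).asDoubleBuffer().put(values)
    assert(indexBytes.length / 4 == valueBytes.length / 8)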