Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala | 156
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala | 8
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala | 18
3 files changed, 132 insertions(+), 50 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index a6c049e517..7c65b0d475 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
import org.apache.spark.rdd.RDD
@@ -31,56 +31,112 @@ import org.apache.spark.rdd.RDD
/**
* :: DeveloperApi ::
* The Java stubs necessary for the Python mllib bindings.
+ *
+ * See python/pyspark/mllib/_common.py for the mutually agreed upon data format.
*/
@DeveloperApi
class PythonMLLibAPI extends Serializable {
- private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
- val packetLength = bytes.length
- if (packetLength < 16) {
- throw new IllegalArgumentException("Byte array too short.")
- }
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.getLong()
- if (magic != 1) {
+ private val DENSE_VECTOR_MAGIC: Byte = 1
+ private val SPARSE_VECTOR_MAGIC: Byte = 2
+ private val DENSE_MATRIX_MAGIC: Byte = 3
+ private val LABELED_POINT_MAGIC: Byte = 4
+
+ private def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
+ require(bytes.length - offset >= 5, "Byte array too short")
+ val magic = bytes(offset)
+ if (magic == DENSE_VECTOR_MAGIC) {
+ deserializeDenseVector(bytes, offset)
+ } else if (magic == SPARSE_VECTOR_MAGIC) {
+ deserializeSparseVector(bytes, offset)
+ } else {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
- val length = bb.getLong()
- if (packetLength != 16 + 8 * length) {
- throw new IllegalArgumentException("Length " + length + " is wrong.")
- }
+ }
+
+ private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
+ val packetLength = bytes.length - offset
+ require(packetLength >= 5, "Byte array too short")
+ val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
+ bb.order(ByteOrder.nativeOrder())
+ val magic = bb.get()
+ require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
+ val length = bb.getInt()
+ require(packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength)
val db = bb.asDoubleBuffer()
val ans = new Array[Double](length.toInt)
db.get(ans)
- ans
+ Vectors.dense(ans)
}
- private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
+ private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
+ val packetLength = bytes.length - offset
+ require(packetLength >= 9, "Byte array too short")
+ val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
+ bb.order(ByteOrder.nativeOrder())
+ val magic = bb.get()
+ require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
+ val size = bb.getInt()
+ val nonZeros = bb.getInt()
+ require(packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength)
+ val ib = bb.asIntBuffer()
+ val indices = new Array[Int](nonZeros)
+ ib.get(indices)
+ bb.position(bb.position() + 4 * nonZeros)
+ val db = bb.asDoubleBuffer()
+ val values = new Array[Double](nonZeros)
+ db.get(values)
+ Vectors.sparse(size, indices, values)
+ }
+
+ private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
val len = doubles.length
- val bytes = new Array[Byte](16 + 8 * len)
+ val bytes = new Array[Byte](5 + 8 * len)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
- bb.putLong(1)
- bb.putLong(len)
+ bb.put(DENSE_VECTOR_MAGIC)
+ bb.putInt(len)
val db = bb.asDoubleBuffer()
db.put(doubles)
bytes
}
+ private def serializeSparseVector(vector: SparseVector): Array[Byte] = {
+ val nonZeros = vector.indices.length
+ val bytes = new Array[Byte](9 + 12 * nonZeros)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.put(SPARSE_VECTOR_MAGIC)
+ bb.putInt(vector.size)
+ bb.putInt(nonZeros)
+ val ib = bb.asIntBuffer()
+ ib.put(vector.indices)
+ bb.position(bb.position() + 4 * nonZeros)
+ val db = bb.asDoubleBuffer()
+ db.put(vector.values)
+ bytes
+ }
+
+ private def serializeDoubleVector(vector: Vector): Array[Byte] = vector match {
+ case s: SparseVector =>
+ serializeSparseVector(s)
+ case _ =>
+ serializeDenseVector(vector.toArray)
+ }
+
private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
val packetLength = bytes.length
- if (packetLength < 24) {
+ if (packetLength < 9) {
throw new IllegalArgumentException("Byte array too short.")
}
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
- val magic = bb.getLong()
- if (magic != 2) {
+ val magic = bb.get()
+ if (magic != DENSE_MATRIX_MAGIC) {
throw new IllegalArgumentException("Magic " + magic + " is wrong.")
}
- val rows = bb.getLong()
- val cols = bb.getLong()
- if (packetLength != 24 + 8 * rows * cols) {
+ val rows = bb.getInt()
+ val cols = bb.getInt()
+ if (packetLength != 9 + 8 * rows * cols) {
throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
}
val db = bb.asDoubleBuffer()
@@ -98,12 +154,12 @@ class PythonMLLibAPI extends Serializable {
if (rows > 0) {
cols = doubles(0).length
}
- val bytes = new Array[Byte](24 + 8 * rows * cols)
+ val bytes = new Array[Byte](9 + 8 * rows * cols)
val bb = ByteBuffer.wrap(bytes)
bb.order(ByteOrder.nativeOrder())
- bb.putLong(2)
- bb.putLong(rows)
- bb.putLong(cols)
+ bb.put(DENSE_MATRIX_MAGIC)
+ bb.putInt(rows)
+ bb.putInt(cols)
val db = bb.asDoubleBuffer()
for (i <- 0 until rows) {
db.put(doubles(i))
@@ -111,18 +167,27 @@ class PythonMLLibAPI extends Serializable {
bytes
}
+ private def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
+ require(bytes.length >= 9, "Byte array too short")
+ val magic = bytes(0)
+ if (magic != LABELED_POINT_MAGIC) {
+ throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ }
+ val labelBytes = ByteBuffer.wrap(bytes, 1, 8)
+ labelBytes.order(ByteOrder.nativeOrder())
+ val label = labelBytes.asDoubleBuffer().get(0)
+ LabeledPoint(label, deserializeDoubleVector(bytes, 9))
+ }
+
private def trainRegressionModel(
- trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel,
+ trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel,
dataBytesJRDD: JavaRDD[Array[Byte]],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(xBytes => {
- val x = deserializeDoubleVector(xBytes)
- LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
- })
+ val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
val initialWeights = deserializeDoubleVector(initialWeightsBA)
val model = trainFunc(data, initialWeights)
val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(serializeDoubleVector(model.weights.toArray))
+ ret.add(serializeDoubleVector(model.weights))
ret.add(model.intercept: java.lang.Double)
ret
}
@@ -143,7 +208,7 @@ class PythonMLLibAPI extends Serializable {
numIterations,
stepSize,
miniBatchFraction,
- Vectors.dense(initialWeights)),
+ initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
@@ -166,7 +231,7 @@ class PythonMLLibAPI extends Serializable {
stepSize,
regParam,
miniBatchFraction,
- Vectors.dense(initialWeights)),
+ initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
@@ -189,7 +254,7 @@ class PythonMLLibAPI extends Serializable {
stepSize,
regParam,
miniBatchFraction,
- Vectors.dense(initialWeights)),
+ initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
@@ -212,7 +277,7 @@ class PythonMLLibAPI extends Serializable {
stepSize,
regParam,
miniBatchFraction,
- Vectors.dense(initialWeights)),
+ initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
@@ -233,7 +298,7 @@ class PythonMLLibAPI extends Serializable {
numIterations,
stepSize,
miniBatchFraction,
- Vectors.dense(initialWeights)),
+ initialWeights),
dataBytesJRDD,
initialWeightsBA)
}
@@ -244,14 +309,11 @@ class PythonMLLibAPI extends Serializable {
def trainNaiveBayes(
dataBytesJRDD: JavaRDD[Array[Byte]],
lambda: Double): java.util.List[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(xBytes => {
- val x = deserializeDoubleVector(xBytes)
- LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length)))
- })
+ val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
val model = NaiveBayes.train(data, lambda)
val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(serializeDoubleVector(model.labels))
- ret.add(serializeDoubleVector(model.pi))
+ ret.add(serializeDoubleVector(Vectors.dense(model.labels)))
+ ret.add(serializeDoubleVector(Vectors.dense(model.pi)))
ret.add(serializeDoubleMatrix(model.theta))
ret
}
@@ -265,7 +327,7 @@ class PythonMLLibAPI extends Serializable {
maxIterations: Int,
runs: Int,
initializationMode: String): java.util.List[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(xBytes => Vectors.dense(deserializeDoubleVector(xBytes)))
+ val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes))
val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
val ret = new java.util.LinkedList[java.lang.Object]()
ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
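For reference, the wire format these stubs now share with python/pyspark/mllib/_common.py is, in native byte order: dense vector = [magic 1: Byte][length: Int][length Doubles]; sparse vector = [magic 2: Byte][size: Int][nnz: Int][nnz Ints][nnz Doubles]; dense matrix = [magic 3: Byte][rows: Int][cols: Int][rows*cols Doubles]; labeled point = [magic 4: Byte][label: Double][vector packet]. Below is a minimal standalone sketch of the sparse-vector round trip using only java.nio; the object and method names are illustrative, not part of the patch:

import java.nio.{ByteBuffer, ByteOrder}

object SparsePacketSketch {
  // Encode: [2: Byte][size: Int][nnz: Int][nnz Ints][nnz Doubles], native order.
  def encode(size: Int, indices: Array[Int], values: Array[Double]): Array[Byte] = {
    val nnz = indices.length
    val bb = ByteBuffer.allocate(9 + 12 * nnz).order(ByteOrder.nativeOrder())
    bb.put(2.toByte).putInt(size).putInt(nnz)
    indices.foreach(bb.putInt(_))
    values.foreach(bb.putDouble(_))
    bb.array()
  }

  // Decode the same layout back into (size, indices, values).
  def decode(bytes: Array[Byte]): (Int, Array[Int], Array[Double]) = {
    val bb = ByteBuffer.wrap(bytes).order(ByteOrder.nativeOrder())
    require(bb.get() == 2, "not a sparse vector packet")
    val size = bb.getInt()
    val nnz = bb.getInt()
    val indices = Array.fill(nnz)(bb.getInt())
    val values = Array.fill(nnz)(bb.getDouble())
    (size, indices, values)
  }
}

Writing with sequential putInt/putDouble calls instead of the asIntBuffer/asDoubleBuffer views used in serializeSparseVector above produces an identical byte layout; the views just let the patch bulk-copy whole arrays.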
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 99a849f1c6..7cdf6bd56a 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -130,9 +130,11 @@ object Vectors {
private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = {
breezeVector match {
case v: BDV[Double] =>
- require(v.offset == 0, s"Do not support non-zero offset ${v.offset}.")
- require(v.stride == 1, s"Do not support stride other than 1, but got ${v.stride}.")
- new DenseVector(v.data)
+ if (v.offset == 0 && v.stride == 1) {
+ new DenseVector(v.data)
+ } else {
+ new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one
+ }
case v: BSV[Double] =>
new SparseVector(v.length, v.index, v.data)
case v: BV[_] =>
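The fromBreeze change matters because Breeze slicing returns views: a sliced DenseVector shares its parent's backing array with a non-zero offset (and possibly a stride other than 1), which the old require calls rejected. A small sketch of the case that now succeeds; note that Vectors.fromBreeze is private[mllib], so this only compiles inside the mllib package:

import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.mllib.linalg.Vectors

val parent = BDV(0.0, 1.0, 2.0, 3.0)
val view = parent(1 to 2)           // a view into parent.data with offset = 1
val v = Vectors.fromBreeze(view)    // previously threw; now copies into a fresh array
assert(v.toArray.sameElements(Array(1.0, 2.0)))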
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
index 8a200310e0..cfe8a27fcb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala
@@ -82,4 +82,22 @@ class VectorsSuite extends FunSuite {
assert(v.## != another.##)
}
}
+
+ test("indexing dense vectors") {
+ val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0)
+ assert(vec(0) === 1.0)
+ assert(vec(3) === 4.0)
+ }
+
+ test("indexing sparse vectors") {
+ val vec = Vectors.sparse(7, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0))
+ assert(vec(0) === 1.0)
+ assert(vec(1) === 0.0)
+ assert(vec(2) === 2.0)
+ assert(vec(3) === 0.0)
+ assert(vec(6) === 4.0)
+ val vec2 = Vectors.sparse(8, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0))
+ assert(vec2(6) === 4.0)
+ assert(vec2(7) === 0.0)
+ }
}
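A quick usage note on the indexing behavior these tests pin down: apply(i) on a SparseVector returns 0.0 for any in-range index that is absent from its indices array, so callers do not need to special-case missing entries:

import org.apache.spark.mllib.linalg.Vectors

val sv = Vectors.sparse(5, Array(1, 3), Array(10.0, 20.0))
println(sv(1))  // 10.0, a stored entry
println(sv(2))  // 0.0, an implicit zero (index 2 is not stored)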