diff options
Diffstat (limited to 'mllib')
3 files changed, 132 insertions, 50 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala index a6c049e517..7c65b0d475 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala @@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi import org.apache.spark.api.java.JavaRDD import org.apache.spark.mllib.classification._ import org.apache.spark.mllib.clustering._ -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{SparseVector, Vector, Vectors} import org.apache.spark.mllib.recommendation._ import org.apache.spark.mllib.regression._ import org.apache.spark.rdd.RDD @@ -31,56 +31,112 @@ import org.apache.spark.rdd.RDD /** * :: DeveloperApi :: * The Java stubs necessary for the Python mllib bindings. + * + * See python/pyspark/mllib/_common.py for the mutually agreed upon data format. */ @DeveloperApi class PythonMLLibAPI extends Serializable { - private def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = { - val packetLength = bytes.length - if (packetLength < 16) { - throw new IllegalArgumentException("Byte array too short.") - } - val bb = ByteBuffer.wrap(bytes) - bb.order(ByteOrder.nativeOrder()) - val magic = bb.getLong() - if (magic != 1) { + private val DENSE_VECTOR_MAGIC: Byte = 1 + private val SPARSE_VECTOR_MAGIC: Byte = 2 + private val DENSE_MATRIX_MAGIC: Byte = 3 + private val LABELED_POINT_MAGIC: Byte = 4 + + private def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = { + require(bytes.length - offset >= 5, "Byte array too short") + val magic = bytes(offset) + if (magic == DENSE_VECTOR_MAGIC) { + deserializeDenseVector(bytes, offset) + } else if (magic == SPARSE_VECTOR_MAGIC) { + deserializeSparseVector(bytes, offset) + } else { throw new IllegalArgumentException("Magic " + magic + " is wrong.") } - val length = bb.getLong() - if (packetLength != 16 + 8 * length) { - throw new IllegalArgumentException("Length " + length + " is wrong.") - } + } + + private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 5, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val length = bb.getInt() + require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength) val db = bb.asDoubleBuffer() val ans = new Array[Double](length.toInt) db.get(ans) - ans + Vectors.dense(ans) } - private def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = { + private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = { + val packetLength = bytes.length - offset + require(packetLength >= 9, "Byte array too short") + val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset) + bb.order(ByteOrder.nativeOrder()) + val magic = bb.get() + require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic) + val size = bb.getInt() + val nonZeros = bb.getInt() + require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength) + val ib = bb.asIntBuffer() + val indices = new Array[Int](nonZeros) + ib.get(indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + val values = new Array[Double](nonZeros) + db.get(values) + Vectors.sparse(size, indices, values) + } + + private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = { val len = doubles.length - val bytes = new Array[Byte](16 + 8 * len) + val bytes = new Array[Byte](5 + 8 * len) val bb = ByteBuffer.wrap(bytes) bb.order(ByteOrder.nativeOrder()) - bb.putLong(1) - bb.putLong(len) + bb.put(DENSE_VECTOR_MAGIC) + bb.putInt(len) val db = bb.asDoubleBuffer() db.put(doubles) bytes } + private def serializeSparseVector(vector: SparseVector): Array[Byte] = { + val nonZeros = vector.indices.length + val bytes = new Array[Byte](9 + 12 * nonZeros) + val bb = ByteBuffer.wrap(bytes) + bb.order(ByteOrder.nativeOrder()) + bb.put(SPARSE_VECTOR_MAGIC) + bb.putInt(vector.size) + bb.putInt(nonZeros) + val ib = bb.asIntBuffer() + ib.put(vector.indices) + bb.position(bb.position() + 4 * nonZeros) + val db = bb.asDoubleBuffer() + db.put(vector.values) + bytes + } + + private def serializeDoubleVector(vector: Vector): Array[Byte] = vector match { + case s: SparseVector => + serializeSparseVector(s) + case _ => + serializeDenseVector(vector.toArray) + } + private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = { val packetLength = bytes.length - if (packetLength < 24) { + if (packetLength < 9) { throw new IllegalArgumentException("Byte array too short.") } val bb = ByteBuffer.wrap(bytes) bb.order(ByteOrder.nativeOrder()) - val magic = bb.getLong() - if (magic != 2) { + val magic = bb.get() + if (magic != DENSE_MATRIX_MAGIC) { throw new IllegalArgumentException("Magic " + magic + " is wrong.") } - val rows = bb.getLong() - val cols = bb.getLong() - if (packetLength != 24 + 8 * rows * cols) { + val rows = bb.getInt() + val cols = bb.getInt() + if (packetLength != 9 + 8 * rows * cols) { throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.") } val db = bb.asDoubleBuffer() @@ -98,12 +154,12 @@ class PythonMLLibAPI extends Serializable { if (rows > 0) { cols = doubles(0).length } - val bytes = new Array[Byte](24 + 8 * rows * cols) + val bytes = new Array[Byte](9 + 8 * rows * cols) val bb = ByteBuffer.wrap(bytes) bb.order(ByteOrder.nativeOrder()) - bb.putLong(2) - bb.putLong(rows) - bb.putLong(cols) + bb.put(DENSE_MATRIX_MAGIC) + bb.putInt(rows) + bb.putInt(cols) val db = bb.asDoubleBuffer() for (i <- 0 until rows) { db.put(doubles(i)) @@ -111,18 +167,27 @@ class PythonMLLibAPI extends Serializable { bytes } + private def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = { + require(bytes.length >= 9, "Byte array too short") + val magic = bytes(0) + if (magic != LABELED_POINT_MAGIC) { + throw new IllegalArgumentException("Magic " + magic + " is wrong.") + } + val labelBytes = ByteBuffer.wrap(bytes, 1, 8) + labelBytes.order(ByteOrder.nativeOrder()) + val label = labelBytes.asDoubleBuffer().get(0) + LabeledPoint(label, deserializeDoubleVector(bytes, 9)) + } + private def trainRegressionModel( - trainFunc: (RDD[LabeledPoint], Array[Double]) => GeneralizedLinearModel, + trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel, dataBytesJRDD: JavaRDD[Array[Byte]], initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(xBytes => { - val x = deserializeDoubleVector(xBytes) - LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) - }) + val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) val initialWeights = deserializeDoubleVector(initialWeightsBA) val model = trainFunc(data, initialWeights) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(model.weights.toArray)) + ret.add(serializeDoubleVector(model.weights)) ret.add(model.intercept: java.lang.Double) ret } @@ -143,7 +208,7 @@ class PythonMLLibAPI extends Serializable { numIterations, stepSize, miniBatchFraction, - Vectors.dense(initialWeights)), + initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -166,7 +231,7 @@ class PythonMLLibAPI extends Serializable { stepSize, regParam, miniBatchFraction, - Vectors.dense(initialWeights)), + initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -189,7 +254,7 @@ class PythonMLLibAPI extends Serializable { stepSize, regParam, miniBatchFraction, - Vectors.dense(initialWeights)), + initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -212,7 +277,7 @@ class PythonMLLibAPI extends Serializable { stepSize, regParam, miniBatchFraction, - Vectors.dense(initialWeights)), + initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -233,7 +298,7 @@ class PythonMLLibAPI extends Serializable { numIterations, stepSize, miniBatchFraction, - Vectors.dense(initialWeights)), + initialWeights), dataBytesJRDD, initialWeightsBA) } @@ -244,14 +309,11 @@ class PythonMLLibAPI extends Serializable { def trainNaiveBayes( dataBytesJRDD: JavaRDD[Array[Byte]], lambda: Double): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(xBytes => { - val x = deserializeDoubleVector(xBytes) - LabeledPoint(x(0), Vectors.dense(x.slice(1, x.length))) - }) + val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint) val model = NaiveBayes.train(data, lambda) val ret = new java.util.LinkedList[java.lang.Object]() - ret.add(serializeDoubleVector(model.labels)) - ret.add(serializeDoubleVector(model.pi)) + ret.add(serializeDoubleVector(Vectors.dense(model.labels))) + ret.add(serializeDoubleVector(Vectors.dense(model.pi))) ret.add(serializeDoubleMatrix(model.theta)) ret } @@ -265,7 +327,7 @@ class PythonMLLibAPI extends Serializable { maxIterations: Int, runs: Int, initializationMode: String): java.util.List[java.lang.Object] = { - val data = dataBytesJRDD.rdd.map(xBytes => Vectors.dense(deserializeDoubleVector(xBytes))) + val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes)) val model = KMeans.train(data, k, maxIterations, runs, initializationMode) val ret = new java.util.LinkedList[java.lang.Object]() ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray))) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala index 99a849f1c6..7cdf6bd56a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala @@ -130,9 +130,11 @@ object Vectors { private[mllib] def fromBreeze(breezeVector: BV[Double]): Vector = { breezeVector match { case v: BDV[Double] => - require(v.offset == 0, s"Do not support non-zero offset ${v.offset}.") - require(v.stride == 1, s"Do not support stride other than 1, but got ${v.stride}.") - new DenseVector(v.data) + if (v.offset == 0 && v.stride == 1) { + new DenseVector(v.data) + } else { + new DenseVector(v.toArray) // Can't use underlying array directly, so make a new one + } case v: BSV[Double] => new SparseVector(v.length, v.index, v.data) case v: BV[_] => diff --git a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala index 8a200310e0..cfe8a27fcb 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/linalg/VectorsSuite.scala @@ -82,4 +82,22 @@ class VectorsSuite extends FunSuite { assert(v.## != another.##) } } + + test("indexing dense vectors") { + val vec = Vectors.dense(1.0, 2.0, 3.0, 4.0) + assert(vec(0) === 1.0) + assert(vec(3) === 4.0) + } + + test("indexing sparse vectors") { + val vec = Vectors.sparse(7, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0)) + assert(vec(0) === 1.0) + assert(vec(1) === 0.0) + assert(vec(2) === 2.0) + assert(vec(3) === 0.0) + assert(vec(6) === 4.0) + val vec2 = Vectors.sparse(8, Array(0, 2, 4, 6), Array(1.0, 2.0, 3.0, 4.0)) + assert(vec2(6) === 4.0) + assert(vec2(7) === 0.0) + } } |