author     Doris Xin <doris.s.xin@gmail.com>    2014-08-12 23:47:42 -0700
committer  Xiangrui Meng <meng@databricks.com>  2014-08-12 23:47:42 -0700
commit     fe4735958e62b1b32a01960503876000f3d2e520 (patch)
tree       4d17db757f96e2d70017cb2990b2b020d5cdc4b1 /mllib
parent     2bd812639c3d8c62a725fb7577365ef0816f2898 (diff)
[SPARK-2993] [MLLib] colStats (wrapper around MultivariateStatisticalSummary) in Statistics
For both Scala and Python. The ser/de util functions were moved out of `PythonMLLibAPI` and into their own object, `SerDe`, to avoid having to construct a `PythonMLLibAPI` instance inside `MultivariateStatisticalSummarySerialized`, which is itself instantiated inside a method of `PythonMLLibAPI`. `MultivariateStatisticalSummarySerialized` was added to expose the `Vector` fields of `MultivariateStatisticalSummary` in serialized form.

Author: Doris Xin <doris.s.xin@gmail.com>

Closes #1911 from dorx/colStats and squashes the following commits:

77b9924 [Doris Xin] developerAPI tag
de9cbbe [Doris Xin] reviewer comments and moved more ser/de
459faba [Doris Xin] colStats in Statistics for both Scala and Python
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala               | 532
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala |   7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala                         |  13
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala          |  17
4 files changed, 309 insertions, 260 deletions
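
As a quick orientation before the diff: the sketch below shows how the new `Statistics.colStats` API added by this patch would be called from Scala. It is a minimal, hypothetical usage example, not part of the patch; the SparkContext `sc` and the sample vectors are assumed.

    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
    import org.apache.spark.rdd.RDD

    // Assumes an existing SparkContext `sc`; the data is illustrative only.
    val observations: RDD[Vector] = sc.parallelize(Seq(
      Vectors.dense(1.0, 10.0, 100.0),
      Vectors.dense(2.0, 20.0, 200.0),
      Vectors.dense(3.0, 30.0, 300.0)))

    // colStats delegates to RowMatrix#computeColumnSummaryStatistics
    // (see the Statistics.scala hunk below).
    val summary: MultivariateStatisticalSummary = Statistics.colStats(observations)
    summary.mean        // column means
    summary.variance    // column-wise variances
    summary.count       // number of rows
    summary.numNonzeros // nonzero count per column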
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index ba7ccd8ce4..18dc087856 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -34,7 +34,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
import org.apache.spark.mllib.tree.DecisionTree
import org.apache.spark.mllib.tree.impurity._
import org.apache.spark.mllib.tree.model.DecisionTreeModel
-import org.apache.spark.mllib.stat.Statistics
+import org.apache.spark.mllib.stat.{MultivariateStatisticalSummary, Statistics}
import org.apache.spark.mllib.stat.correlation.CorrelationNames
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
@@ -48,182 +48,7 @@ import org.apache.spark.util.Utils
*/
@DeveloperApi
class PythonMLLibAPI extends Serializable {
- private val DENSE_VECTOR_MAGIC: Byte = 1
- private val SPARSE_VECTOR_MAGIC: Byte = 2
- private val DENSE_MATRIX_MAGIC: Byte = 3
- private val LABELED_POINT_MAGIC: Byte = 4
-
- private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
- require(bytes.length - offset >= 5, "Byte array too short")
- val magic = bytes(offset)
- if (magic == DENSE_VECTOR_MAGIC) {
- deserializeDenseVector(bytes, offset)
- } else if (magic == SPARSE_VECTOR_MAGIC) {
- deserializeSparseVector(bytes, offset)
- } else {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
- }
- }
-
- private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
- require(bytes.length - offset == 8, "Wrong size byte array for Double")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- bb.getDouble
- }
- private def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
- val packetLength = bytes.length - offset
- require(packetLength >= 5, "Byte array too short")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
- val length = bb.getInt()
- require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength)
- val db = bb.asDoubleBuffer()
- val ans = new Array[Double](length.toInt)
- db.get(ans)
- Vectors.dense(ans)
- }
-
- private def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
- val packetLength = bytes.length - offset
- require(packetLength >= 9, "Byte array too short")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
- val size = bb.getInt()
- val nonZeros = bb.getInt()
- require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength)
- val ib = bb.asIntBuffer()
- val indices = new Array[Int](nonZeros)
- ib.get(indices)
- bb.position(bb.position() + 4 * nonZeros)
- val db = bb.asDoubleBuffer()
- val values = new Array[Double](nonZeros)
- db.get(values)
- Vectors.sparse(size, indices, values)
- }
-
- /**
- * Returns an 8-byte array for the input Double.
- *
- * Note: we currently do not use a magic byte for double for storage efficiency.
- * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
- * The corresponding deserializer, deserializeDouble, needs to be modified as well if the
- * serialization scheme changes.
- */
- private[python] def serializeDouble(double: Double): Array[Byte] = {
- val bytes = new Array[Byte](8)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.putDouble(double)
- bytes
- }
-
- private def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
- val len = doubles.length
- val bytes = new Array[Byte](5 + 8 * len)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(DENSE_VECTOR_MAGIC)
- bb.putInt(len)
- val db = bb.asDoubleBuffer()
- db.put(doubles)
- bytes
- }
-
- private def serializeSparseVector(vector: SparseVector): Array[Byte] = {
- val nonZeros = vector.indices.length
- val bytes = new Array[Byte](9 + 12 * nonZeros)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(SPARSE_VECTOR_MAGIC)
- bb.putInt(vector.size)
- bb.putInt(nonZeros)
- val ib = bb.asIntBuffer()
- ib.put(vector.indices)
- bb.position(bb.position() + 4 * nonZeros)
- val db = bb.asDoubleBuffer()
- db.put(vector.values)
- bytes
- }
-
- private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match {
- case s: SparseVector =>
- serializeSparseVector(s)
- case _ =>
- serializeDenseVector(vector.toArray)
- }
-
- private def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
- val packetLength = bytes.length
- if (packetLength < 9) {
- throw new IllegalArgumentException("Byte array too short.")
- }
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- if (magic != DENSE_MATRIX_MAGIC) {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
- }
- val rows = bb.getInt()
- val cols = bb.getInt()
- if (packetLength != 9 + 8 * rows * cols) {
- throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
- }
- val db = bb.asDoubleBuffer()
- val ans = new Array[Array[Double]](rows.toInt)
- for (i <- 0 until rows.toInt) {
- ans(i) = new Array[Double](cols.toInt)
- db.get(ans(i))
- }
- ans
- }
-
- private def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
- val rows = doubles.length
- var cols = 0
- if (rows > 0) {
- cols = doubles(0).length
- }
- val bytes = new Array[Byte](9 + 8 * rows * cols)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(DENSE_MATRIX_MAGIC)
- bb.putInt(rows)
- bb.putInt(cols)
- val db = bb.asDoubleBuffer()
- for (i <- 0 until rows) {
- db.put(doubles(i))
- }
- bytes
- }
-
- private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = {
- val fb = serializeDoubleVector(p.features)
- val bytes = new Array[Byte](1 + 8 + fb.length)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(LABELED_POINT_MAGIC)
- bb.putDouble(p.label)
- bb.put(fb)
- bytes
- }
-
- private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
- require(bytes.length >= 9, "Byte array too short")
- val magic = bytes(0)
- if (magic != LABELED_POINT_MAGIC) {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
- }
- val labelBytes = ByteBuffer.wrap(bytes, 1, 8)
- labelBytes.order(ByteOrder.nativeOrder())
- val label = labelBytes.asDoubleBuffer().get(0)
- LabeledPoint(label, deserializeDoubleVector(bytes, 9))
- }
/**
* Loads and serializes labeled points saved with `RDD#saveAsTextFile`.
@@ -236,17 +61,17 @@ class PythonMLLibAPI extends Serializable {
jsc: JavaSparkContext,
path: String,
minPartitions: Int): JavaRDD[Array[Byte]] =
- MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(serializeLabeledPoint)
+ MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(SerDe.serializeLabeledPoint)
private def trainRegressionModel(
trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel,
dataBytesJRDD: JavaRDD[Array[Byte]],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
- val initialWeights = deserializeDoubleVector(initialWeightsBA)
+ val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
+ val initialWeights = SerDe.deserializeDoubleVector(initialWeightsBA)
val model = trainFunc(data, initialWeights)
val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(serializeDoubleVector(model.weights))
+ ret.add(SerDe.serializeDoubleVector(model.weights))
ret.add(model.intercept: java.lang.Double)
ret
}
@@ -405,12 +230,12 @@ class PythonMLLibAPI extends Serializable {
def trainNaiveBayes(
dataBytesJRDD: JavaRDD[Array[Byte]],
lambda: Double): java.util.List[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
+ val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
val model = NaiveBayes.train(data, lambda)
val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(serializeDoubleVector(Vectors.dense(model.labels)))
- ret.add(serializeDoubleVector(Vectors.dense(model.pi)))
- ret.add(serializeDoubleMatrix(model.theta))
+ ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.labels)))
+ ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.pi)))
+ ret.add(SerDe.serializeDoubleMatrix(model.theta))
ret
}
@@ -423,52 +248,13 @@ class PythonMLLibAPI extends Serializable {
maxIterations: Int,
runs: Int,
initializationMode: String): java.util.List[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(bytes => deserializeDoubleVector(bytes))
+ val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes))
val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
+ ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
ret
}
- /** Unpack a Rating object from an array of bytes */
- private def unpackRating(ratingBytes: Array[Byte]): Rating = {
- val bb = ByteBuffer.wrap(ratingBytes)
- bb.order(ByteOrder.nativeOrder())
- val user = bb.getInt()
- val product = bb.getInt()
- val rating = bb.getDouble()
- new Rating(user, product, rating)
- }
-
- /** Unpack a tuple of Ints from an array of bytes */
- private[spark] def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
- val bb = ByteBuffer.wrap(tupleBytes)
- bb.order(ByteOrder.nativeOrder())
- val v1 = bb.getInt()
- val v2 = bb.getInt()
- (v1, v2)
- }
-
- /**
- * Serialize a Rating object into an array of bytes.
- * It can be deserialized using RatingDeserializer().
- *
- * @param rate the Rating object to serialize
- * @return
- */
- private[spark] def serializeRating(rate: Rating): Array[Byte] = {
- val len = 3
- val bytes = new Array[Byte](4 + 8 * len)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.putInt(len)
- val db = bb.asDoubleBuffer()
- db.put(rate.user.toDouble)
- db.put(rate.product.toDouble)
- db.put(rate.rating)
- bytes
- }
-
/**
* Java stub for Python mllib ALS.train(). This stub returns a handle
* to the Java object instead of the content of the Java object. Extra care
@@ -481,7 +267,7 @@ class PythonMLLibAPI extends Serializable {
iterations: Int,
lambda: Double,
blocks: Int): MatrixFactorizationModel = {
- val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+ val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating)
ALS.train(ratings, rank, iterations, lambda, blocks)
}
@@ -498,7 +284,7 @@ class PythonMLLibAPI extends Serializable {
lambda: Double,
blocks: Int,
alpha: Double): MatrixFactorizationModel = {
- val ratings = ratingsBytesJRDD.rdd.map(unpackRating)
+ val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating)
ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
}
@@ -519,7 +305,7 @@ class PythonMLLibAPI extends Serializable {
maxDepth: Int,
maxBins: Int): DecisionTreeModel = {
- val data = dataBytesJRDD.rdd.map(deserializeLabeledPoint)
+ val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
val algo = Algo.fromString(algoStr)
val impurity = Impurities.fromString(impurityStr)
@@ -545,7 +331,7 @@ class PythonMLLibAPI extends Serializable {
def predictDecisionTreeModel(
model: DecisionTreeModel,
featuresBytes: Array[Byte]): Double = {
- val features: Vector = deserializeDoubleVector(featuresBytes)
+ val features: Vector = SerDe.deserializeDoubleVector(featuresBytes)
model.predict(features)
}
@@ -559,8 +345,17 @@ class PythonMLLibAPI extends Serializable {
def predictDecisionTreeModel(
model: DecisionTreeModel,
dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
- val data = dataJRDD.rdd.map(xBytes => deserializeDoubleVector(xBytes))
- model.predict(data).map(serializeDouble)
+ val data = dataJRDD.rdd.map(xBytes => SerDe.deserializeDoubleVector(xBytes))
+ model.predict(data).map(SerDe.serializeDouble)
+ }
+
+ /**
+ * Java stub for mllib Statistics.colStats(X: RDD[Vector]).
+ * TODO figure out return type.
+ */
+ def colStats(X: JavaRDD[Array[Byte]]): MultivariateStatisticalSummarySerialized = {
+ val cStats = Statistics.colStats(X.rdd.map(SerDe.deserializeDoubleVector(_)))
+ new MultivariateStatisticalSummarySerialized(cStats)
}
/**
@@ -569,17 +364,17 @@ class PythonMLLibAPI extends Serializable {
* pyspark.
*/
def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = {
- val inputMatrix = X.rdd.map(deserializeDoubleVector(_))
+ val inputMatrix = X.rdd.map(SerDe.deserializeDoubleVector(_))
val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method))
- serializeDoubleMatrix(to2dArray(result))
+ SerDe.serializeDoubleMatrix(SerDe.to2dArray(result))
}
/**
* Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String).
*/
def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = {
- val xDeser = x.rdd.map(deserializeDouble(_))
- val yDeser = y.rdd.map(deserializeDouble(_))
+ val xDeser = x.rdd.map(SerDe.deserializeDouble(_))
+ val yDeser = y.rdd.map(SerDe.deserializeDouble(_))
Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method))
}
@@ -588,12 +383,6 @@ class PythonMLLibAPI extends Serializable {
if (method == null) CorrelationNames.defaultCorrName else method
}
- // Reformat a Matrix into Array[Array[Double]] for serialization
- private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = {
- val values = matrix.toArray
- Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows))
- }
-
// Used by the *RDD methods to get default seed if not passed in from pyspark
private def getSeedOrDefault(seed: java.lang.Long): Long = {
if (seed == null) Utils.random.nextLong else seed
@@ -621,7 +410,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.uniformRDD(jsc.sc, size, parts, s).map(serializeDouble)
+ RG.uniformRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble)
}
/**
@@ -633,7 +422,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.normalRDD(jsc.sc, size, parts, s).map(serializeDouble)
+ RG.normalRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble)
}
/**
@@ -646,7 +435,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.poissonRDD(jsc.sc, mean, size, parts, s).map(serializeDouble)
+ RG.poissonRDD(jsc.sc, mean, size, parts, s).map(SerDe.serializeDouble)
}
/**
@@ -659,7 +448,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector)
+ RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector)
}
/**
@@ -672,7 +461,7 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(serializeDoubleVector)
+ RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector)
}
/**
@@ -686,7 +475,256 @@ class PythonMLLibAPI extends Serializable {
seed: java.lang.Long): JavaRDD[Array[Byte]] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(serializeDoubleVector)
+ RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector)
+ }
+
+}
+
+/**
+ * :: DeveloperApi ::
+ * MultivariateStatisticalSummary with Vector fields serialized.
+ */
+@DeveloperApi
+class MultivariateStatisticalSummarySerialized(val summary: MultivariateStatisticalSummary)
+ extends Serializable {
+
+ def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean)
+
+ def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance)
+
+ def count: Long = summary.count
+
+ def numNonzeros: Array[Byte] = SerDe.serializeDoubleVector(summary.numNonzeros)
+
+ def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max)
+
+ def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min)
+}
+
+/**
+ * SerDe utility functions for PythonMLLibAPI.
+ */
+private[spark] object SerDe extends Serializable {
+ private val DENSE_VECTOR_MAGIC: Byte = 1
+ private val SPARSE_VECTOR_MAGIC: Byte = 2
+ private val DENSE_MATRIX_MAGIC: Byte = 3
+ private val LABELED_POINT_MAGIC: Byte = 4
+
+ private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
+ require(bytes.length - offset >= 5, "Byte array too short")
+ val magic = bytes(offset)
+ if (magic == DENSE_VECTOR_MAGIC) {
+ deserializeDenseVector(bytes, offset)
+ } else if (magic == SPARSE_VECTOR_MAGIC) {
+ deserializeSparseVector(bytes, offset)
+ } else {
+ throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ }
}
+ private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
+ require(bytes.length - offset == 8, "Wrong size byte array for Double")
+ val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
+ bb.order(ByteOrder.nativeOrder())
+ bb.getDouble
+ }
+
+ private[python] def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
+ val packetLength = bytes.length - offset
+ require(packetLength >= 5, "Byte array too short")
+ val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
+ bb.order(ByteOrder.nativeOrder())
+ val magic = bb.get()
+ require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
+ val length = bb.getInt()
+ require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength)
+ val db = bb.asDoubleBuffer()
+ val ans = new Array[Double](length.toInt)
+ db.get(ans)
+ Vectors.dense(ans)
+ }
+
+ private[python] def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
+ val packetLength = bytes.length - offset
+ require(packetLength >= 9, "Byte array too short")
+ val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
+ bb.order(ByteOrder.nativeOrder())
+ val magic = bb.get()
+ require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
+ val size = bb.getInt()
+ val nonZeros = bb.getInt()
+ require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength)
+ val ib = bb.asIntBuffer()
+ val indices = new Array[Int](nonZeros)
+ ib.get(indices)
+ bb.position(bb.position() + 4 * nonZeros)
+ val db = bb.asDoubleBuffer()
+ val values = new Array[Double](nonZeros)
+ db.get(values)
+ Vectors.sparse(size, indices, values)
+ }
+
+ /**
+ * Returns an 8-byte array for the input Double.
+ *
+ * Note: we currently do not use a magic byte for double for storage efficiency.
+ * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
+ * The corresponding deserializer, deserializeDouble, needs to be modified as well if the
+ * serialization scheme changes.
+ */
+ private[python] def serializeDouble(double: Double): Array[Byte] = {
+ val bytes = new Array[Byte](8)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.putDouble(double)
+ bytes
+ }
+
+ private[python] def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
+ val len = doubles.length
+ val bytes = new Array[Byte](5 + 8 * len)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.put(DENSE_VECTOR_MAGIC)
+ bb.putInt(len)
+ val db = bb.asDoubleBuffer()
+ db.put(doubles)
+ bytes
+ }
+
+ private[python] def serializeSparseVector(vector: SparseVector): Array[Byte] = {
+ val nonZeros = vector.indices.length
+ val bytes = new Array[Byte](9 + 12 * nonZeros)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.put(SPARSE_VECTOR_MAGIC)
+ bb.putInt(vector.size)
+ bb.putInt(nonZeros)
+ val ib = bb.asIntBuffer()
+ ib.put(vector.indices)
+ bb.position(bb.position() + 4 * nonZeros)
+ val db = bb.asDoubleBuffer()
+ db.put(vector.values)
+ bytes
+ }
+
+ private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match {
+ case s: SparseVector =>
+ serializeSparseVector(s)
+ case _ =>
+ serializeDenseVector(vector.toArray)
+ }
+
+ private[python] def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
+ val packetLength = bytes.length
+ if (packetLength < 9) {
+ throw new IllegalArgumentException("Byte array too short.")
+ }
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ val magic = bb.get()
+ if (magic != DENSE_MATRIX_MAGIC) {
+ throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ }
+ val rows = bb.getInt()
+ val cols = bb.getInt()
+ if (packetLength != 9 + 8 * rows * cols) {
+ throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
+ }
+ val db = bb.asDoubleBuffer()
+ val ans = new Array[Array[Double]](rows.toInt)
+ for (i <- 0 until rows.toInt) {
+ ans(i) = new Array[Double](cols.toInt)
+ db.get(ans(i))
+ }
+ ans
+ }
+
+ private[python] def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
+ val rows = doubles.length
+ var cols = 0
+ if (rows > 0) {
+ cols = doubles(0).length
+ }
+ val bytes = new Array[Byte](9 + 8 * rows * cols)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.put(DENSE_MATRIX_MAGIC)
+ bb.putInt(rows)
+ bb.putInt(cols)
+ val db = bb.asDoubleBuffer()
+ for (i <- 0 until rows) {
+ db.put(doubles(i))
+ }
+ bytes
+ }
+
+ private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = {
+ val fb = serializeDoubleVector(p.features)
+ val bytes = new Array[Byte](1 + 8 + fb.length)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.put(LABELED_POINT_MAGIC)
+ bb.putDouble(p.label)
+ bb.put(fb)
+ bytes
+ }
+
+ private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
+ require(bytes.length >= 9, "Byte array too short")
+ val magic = bytes(0)
+ if (magic != LABELED_POINT_MAGIC) {
+ throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ }
+ val labelBytes = ByteBuffer.wrap(bytes, 1, 8)
+ labelBytes.order(ByteOrder.nativeOrder())
+ val label = labelBytes.asDoubleBuffer().get(0)
+ LabeledPoint(label, deserializeDoubleVector(bytes, 9))
+ }
+
+ // Reformat a Matrix into Array[Array[Double]] for serialization
+ private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = {
+ val values = matrix.toArray
+ Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows))
+ }
+
+
+ /** Unpack a Rating object from an array of bytes */
+ private[python] def unpackRating(ratingBytes: Array[Byte]): Rating = {
+ val bb = ByteBuffer.wrap(ratingBytes)
+ bb.order(ByteOrder.nativeOrder())
+ val user = bb.getInt()
+ val product = bb.getInt()
+ val rating = bb.getDouble()
+ new Rating(user, product, rating)
+ }
+
+ /** Unpack a tuple of Ints from an array of bytes */
+ def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
+ val bb = ByteBuffer.wrap(tupleBytes)
+ bb.order(ByteOrder.nativeOrder())
+ val v1 = bb.getInt()
+ val v2 = bb.getInt()
+ (v1, v2)
+ }
+
+ /**
+ * Serialize a Rating object into an array of bytes.
+ * It can be deserialized using RatingDeserializer().
+ *
+ * @param rate the Rating object to serialize
+ * @return
+ */
+ def serializeRating(rate: Rating): Array[Byte] = {
+ val len = 3
+ val bytes = new Array[Byte](4 + 8 * len)
+ val bb = ByteBuffer.wrap(bytes)
+ bb.order(ByteOrder.nativeOrder())
+ bb.putInt(len)
+ val db = bb.asDoubleBuffer()
+ db.put(rate.user.toDouble)
+ db.put(rate.product.toDouble)
+ db.put(rate.rating)
+ bytes
+ }
}
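
For context on the byte layout these SerDe helpers implement: a dense vector is encoded as a one-byte magic (DENSE_VECTOR_MAGIC = 1), a 4-byte int length, then `length` native-order doubles, i.e. 5 + 8 * length bytes in total. The standalone sketch below illustrates that round trip with plain java.nio; it mirrors `serializeDenseVector`/`deserializeDenseVector` above and is illustrative only, not part of the patch.

    import java.nio.{ByteBuffer, ByteOrder}

    // Encode Array(1.0, 2.5) the way SerDe.serializeDenseVector does:
    // [magic: 1 byte][length: 4-byte int][values: length * 8-byte doubles]
    val doubles = Array(1.0, 2.5)
    val bytes = new Array[Byte](5 + 8 * doubles.length)
    val bb = ByteBuffer.wrap(bytes)
    bb.order(ByteOrder.nativeOrder())  // native order, matching the Python side
    bb.put(1.toByte)                   // DENSE_VECTOR_MAGIC
    bb.putInt(doubles.length)
    bb.asDoubleBuffer().put(doubles)

    // Decoding reverses the steps and validates magic byte and packet length.
    val rb = ByteBuffer.wrap(bytes)
    rb.order(ByteOrder.nativeOrder())
    assert(rb.get() == 1.toByte)
    val len = rb.getInt()
    assert(bytes.length == 5 + 8 * len)
    val out = new Array[Double](len)
    rb.asDoubleBuffer().get(out)
    assert(out.sameElements(doubles))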
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index a1a76fcbe9..478c648505 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -23,7 +23,7 @@ import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
-import org.apache.spark.mllib.api.python.PythonMLLibAPI
+import org.apache.spark.mllib.api.python.SerDe
/**
* Model representing the result of matrix factorization.
@@ -117,9 +117,8 @@ class MatrixFactorizationModel private[mllib] (
*/
@DeveloperApi
def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
- val pythonAPI = new PythonMLLibAPI()
- val usersProducts = usersProductsJRDD.rdd.map(xBytes => pythonAPI.unpackTuple(xBytes))
- predict(usersProducts).map(rate => pythonAPI.serializeRating(rate))
+ val usersProducts = usersProductsJRDD.rdd.map(xBytes => SerDe.unpackTuple(xBytes))
+ predict(usersProducts).map(rate => SerDe.serializeRating(rate))
}
}
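
The `SerDe.unpackTuple` format used here is two native-order 4-byte ints per (user, product) pair. Below is a hedged, self-contained sketch of both directions; the `packTuple` helper is hypothetical, since in the real flow the Python side produces these bytes.

    import java.nio.{ByteBuffer, ByteOrder}

    // Pack a (user, product) pair as SerDe.unpackTuple expects:
    // two native-order 4-byte ints. Hypothetical helper for illustration.
    def packTuple(user: Int, product: Int): Array[Byte] = {
      val bb = ByteBuffer.allocate(8).order(ByteOrder.nativeOrder())
      bb.putInt(user)
      bb.putInt(product)
      bb.array()
    }

    // Mirror of SerDe.unpackTuple from the diff above.
    def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
      val bb = ByteBuffer.wrap(tupleBytes)
      bb.order(ByteOrder.nativeOrder())
      (bb.getInt(), bb.getInt())
    }

    assert(unpackTuple(packTuple(7, 42)) == ((7, 42)))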
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
index cf8679610e..3cf1028fbc 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/stat/Statistics.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.stat
import org.apache.spark.annotation.Experimental
+import org.apache.spark.mllib.linalg.distributed.RowMatrix
import org.apache.spark.mllib.linalg.{Matrix, Vector}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.stat.correlation.Correlations
@@ -32,6 +33,18 @@ object Statistics {
/**
* :: Experimental ::
+ * Computes column-wise summary statistics for the input RDD[Vector].
+ *
+ * @param X an RDD[Vector] for which column-wise summary statistics are to be computed.
+ * @return [[MultivariateStatisticalSummary]] object containing column-wise summary statistics.
+ */
+ @Experimental
+ def colStats(X: RDD[Vector]): MultivariateStatisticalSummary = {
+ new RowMatrix(X).computeColumnSummaryStatistics()
+ }
+
+ /**
+ * :: Experimental ::
* Compute the Pearson correlation matrix for the input RDD of Vectors.
* Columns with 0 covariance produce NaN entries in the correlation matrix.
*
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
index bd413a80f5..092d67bbc5 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/api/python/PythonMLLibAPISuite.scala
@@ -23,7 +23,6 @@ import org.apache.spark.mllib.linalg.{Matrices, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
class PythonMLLibAPISuite extends FunSuite {
- val py = new PythonMLLibAPI
test("vector serialization") {
val vectors = Seq(
@@ -34,8 +33,8 @@ class PythonMLLibAPISuite extends FunSuite {
Vectors.sparse(1, Array.empty[Int], Array.empty[Double]),
Vectors.sparse(2, Array(1), Array(-2.0)))
vectors.foreach { v =>
- val bytes = py.serializeDoubleVector(v)
- val u = py.deserializeDoubleVector(bytes)
+ val bytes = SerDe.serializeDoubleVector(v)
+ val u = SerDe.deserializeDoubleVector(bytes)
assert(u.getClass === v.getClass)
assert(u === v)
}
@@ -50,8 +49,8 @@ class PythonMLLibAPISuite extends FunSuite {
LabeledPoint(1.0, Vectors.sparse(1, Array.empty[Int], Array.empty[Double])),
LabeledPoint(-0.5, Vectors.sparse(2, Array(1), Array(-2.0))))
points.foreach { p =>
- val bytes = py.serializeLabeledPoint(p)
- val q = py.deserializeLabeledPoint(bytes)
+ val bytes = SerDe.serializeLabeledPoint(p)
+ val q = SerDe.deserializeLabeledPoint(bytes)
assert(q.label === p.label)
assert(q.features.getClass === p.features.getClass)
assert(q.features === p.features)
@@ -60,8 +59,8 @@ class PythonMLLibAPISuite extends FunSuite {
test("double serialization") {
for (x <- List(123.0, -10.0, 0.0, Double.MaxValue, Double.MinValue, Double.NaN)) {
- val bytes = py.serializeDouble(x)
- val deser = py.deserializeDouble(bytes)
+ val bytes = SerDe.serializeDouble(x)
+ val deser = SerDe.deserializeDouble(bytes)
// We use `equals` here for comparison because we cannot use `==` for NaN
assert(x.equals(deser))
}
@@ -70,14 +69,14 @@ class PythonMLLibAPISuite extends FunSuite {
test("matrix to 2D array") {
val values = Array[Double](0, 1.2, 3, 4.56, 7, 8)
val matrix = Matrices.dense(2, 3, values)
- val arr = py.to2dArray(matrix)
+ val arr = SerDe.to2dArray(matrix)
val expected = Array(Array[Double](0, 3, 7), Array[Double](1.2, 4.56, 8))
assert(arr === expected)
// Test conversion for empty matrix
val empty = Array[Double]()
val emptyMatrix = Matrices.dense(0, 0, empty)
- val empty2D = py.to2dArray(emptyMatrix)
+ val empty2D = SerDe.to2dArray(emptyMatrix)
assert(empty2D === Array[Array[Double]]())
}
}