aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2014-09-19 15:01:11 -0700
committerXiangrui Meng <meng@databricks.com>2014-09-19 15:01:11 -0700
commitfce5e251d636c788cda91345867e0294280c074d (patch)
tree4bded23a826bcfeb02deef73bd735cf0a05d4ee7 /mllib/src/main
parenta03e5b81e91d9d792b6a2e01d1505394ea303dd8 (diff)
downloadspark-fce5e251d636c788cda91345867e0294280c074d.tar.gz
spark-fce5e251d636c788cda91345867e0294280c074d.tar.bz2
spark-fce5e251d636c788cda91345867e0294280c074d.zip
[SPARK-3491] [MLlib] [PySpark] use pickle to serialize data in MLlib
Currently, we serialize the data between JVM and Python case by case manually, this cannot scale to support so many APIs in MLlib. This patch will try to address this problem by serialize the data using pickle protocol, using Pyrolite library to serialize/deserialize in JVM. Pickle protocol can be easily extended to support customized class. All the modules are refactored to use this protocol. Known issues: There will be some performance regression (both CPU and memory, the serialized data increased) Author: Davies Liu <davies.liu@gmail.com> Closes #2378 from davies/pickle_mllib and squashes the following commits: dffbba2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into pickle_mllib 810f97f [Davies Liu] fix equal of matrix 032cd62 [Davies Liu] add more type check and conversion for user_product bd738ab [Davies Liu] address comments e431377 [Davies Liu] fix cache of rdd, refactor 19d0967 [Davies Liu] refactor Picklers 2511e76 [Davies Liu] cleanup 1fccf1a [Davies Liu] address comments a2cc855 [Davies Liu] fix tests 9ceff73 [Davies Liu] test size of serialized Rating 44e0551 [Davies Liu] fix cache a379a81 [Davies Liu] fix pickle array in python2.7 df625c7 [Davies Liu] Merge commit '154d141' into pickle_mllib 154d141 [Davies Liu] fix autobatchedpickler 44736d7 [Davies Liu] speed up pickling array in Python 2.7 e1d1bfc [Davies Liu] refactor 708dc02 [Davies Liu] fix tests 9dcfb63 [Davies Liu] fix style 88034f0 [Davies Liu] rafactor, address comments 46a501e [Davies Liu] choose batch size automatically df19464 [Davies Liu] memorize the module and class name during pickleing f3506c5 [Davies Liu] Merge branch 'master' into pickle_mllib 722dd96 [Davies Liu] cleanup _common.py 0ee1525 [Davies Liu] remove outdated tests b02e34f [Davies Liu] remove _common.py 84c721d [Davies Liu] Merge branch 'master' into pickle_mllib 4d7963e [Davies Liu] remove muanlly serialization 6d26b03 [Davies Liu] fix tests c383544 [Davies Liu] classification f2a0856 [Davies Liu] mllib/regression d9f691f [Davies Liu] mllib/util cccb8b1 [Davies Liu] mllib/tree 8fe166a [Davies Liu] Merge branch 'pickle' into pickle_mllib aa2287e [Davies Liu] random f1544c4 [Davies Liu] refactor clustering 52d1350 [Davies Liu] use new protocol in mllib/stat b30ef35 [Davies Liu] use pickle to serialize data for mllib/recommendation f44f771 [Davies Liu] enable tests about array 3908f5c [Davies Liu] Merge branch 'master' into pickle c77c87b [Davies Liu] cleanup debugging code 60e4e2f [Davies Liu] support unpickle array.array for Python 2.6
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala487
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala15
3 files changed, 186 insertions, 326 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index fa0fa69f38..9164c294ac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -17,16 +17,20 @@
package org.apache.spark.mllib.api.python
-import java.nio.{ByteBuffer, ByteOrder}
+import java.io.OutputStream
import scala.collection.JavaConverters._
+import scala.language.existentials
+import scala.reflect.ClassTag
+
+import net.razorvine.pickle._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
import org.apache.spark.mllib.optimization._
-import org.apache.spark.mllib.linalg.{Matrix, SparseVector, Vector, Vectors}
+import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.random.{RandomRDDs => RG}
import org.apache.spark.mllib.recommendation._
import org.apache.spark.mllib.regression._
@@ -40,11 +44,10 @@ import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
import org.apache.spark.util.Utils
+
/**
* :: DeveloperApi ::
* The Java stubs necessary for the Python mllib bindings.
- *
- * See python/pyspark/mllib/_common.py for the mutually agreed upon data format.
*/
@DeveloperApi
class PythonMLLibAPI extends Serializable {
@@ -60,18 +63,17 @@ class PythonMLLibAPI extends Serializable {
def loadLabeledPoints(
jsc: JavaSparkContext,
path: String,
- minPartitions: Int): JavaRDD[Array[Byte]] =
- MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions).map(SerDe.serializeLabeledPoint)
+ minPartitions: Int): JavaRDD[LabeledPoint] =
+ MLUtils.loadLabeledPoints(jsc.sc, path, minPartitions)
private def trainRegressionModel(
trainFunc: (RDD[LabeledPoint], Vector) => GeneralizedLinearModel,
- dataBytesJRDD: JavaRDD[Array[Byte]],
+ data: JavaRDD[LabeledPoint],
initialWeightsBA: Array[Byte]): java.util.LinkedList[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
- val initialWeights = SerDe.deserializeDoubleVector(initialWeightsBA)
- val model = trainFunc(data, initialWeights)
+ val initialWeights = SerDe.loads(initialWeightsBA).asInstanceOf[Vector]
+ val model = trainFunc(data.rdd, initialWeights)
val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(SerDe.serializeDoubleVector(model.weights))
+ ret.add(SerDe.dumps(model.weights))
ret.add(model.intercept: java.lang.Double)
ret
}
@@ -80,7 +82,7 @@ class PythonMLLibAPI extends Serializable {
* Java stub for Python mllib LinearRegressionWithSGD.train()
*/
def trainLinearRegressionModelWithSGD(
- dataBytesJRDD: JavaRDD[Array[Byte]],
+ data: JavaRDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
@@ -106,7 +108,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
(data, initialWeights) =>
lrAlg.run(data, initialWeights),
- dataBytesJRDD,
+ data,
initialWeightsBA)
}
@@ -114,7 +116,7 @@ class PythonMLLibAPI extends Serializable {
* Java stub for Python mllib LassoWithSGD.train()
*/
def trainLassoModelWithSGD(
- dataBytesJRDD: JavaRDD[Array[Byte]],
+ data: JavaRDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double,
@@ -129,7 +131,7 @@ class PythonMLLibAPI extends Serializable {
regParam,
miniBatchFraction,
initialWeights),
- dataBytesJRDD,
+ data,
initialWeightsBA)
}
@@ -137,7 +139,7 @@ class PythonMLLibAPI extends Serializable {
* Java stub for Python mllib RidgeRegressionWithSGD.train()
*/
def trainRidgeModelWithSGD(
- dataBytesJRDD: JavaRDD[Array[Byte]],
+ data: JavaRDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double,
@@ -152,7 +154,7 @@ class PythonMLLibAPI extends Serializable {
regParam,
miniBatchFraction,
initialWeights),
- dataBytesJRDD,
+ data,
initialWeightsBA)
}
@@ -160,7 +162,7 @@ class PythonMLLibAPI extends Serializable {
* Java stub for Python mllib SVMWithSGD.train()
*/
def trainSVMModelWithSGD(
- dataBytesJRDD: JavaRDD[Array[Byte]],
+ data: JavaRDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
regParam: Double,
@@ -186,7 +188,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
(data, initialWeights) =>
SVMAlg.run(data, initialWeights),
- dataBytesJRDD,
+ data,
initialWeightsBA)
}
@@ -194,7 +196,7 @@ class PythonMLLibAPI extends Serializable {
* Java stub for Python mllib LogisticRegressionWithSGD.train()
*/
def trainLogisticRegressionModelWithSGD(
- dataBytesJRDD: JavaRDD[Array[Byte]],
+ data: JavaRDD[LabeledPoint],
numIterations: Int,
stepSize: Double,
miniBatchFraction: Double,
@@ -220,7 +222,7 @@ class PythonMLLibAPI extends Serializable {
trainRegressionModel(
(data, initialWeights) =>
LogRegAlg.run(data, initialWeights),
- dataBytesJRDD,
+ data,
initialWeightsBA)
}
@@ -228,14 +230,13 @@ class PythonMLLibAPI extends Serializable {
* Java stub for NaiveBayes.train()
*/
def trainNaiveBayes(
- dataBytesJRDD: JavaRDD[Array[Byte]],
+ data: JavaRDD[LabeledPoint],
lambda: Double): java.util.List[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
- val model = NaiveBayes.train(data, lambda)
+ val model = NaiveBayes.train(data.rdd, lambda)
val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.labels)))
- ret.add(SerDe.serializeDoubleVector(Vectors.dense(model.pi)))
- ret.add(SerDe.serializeDoubleMatrix(model.theta))
+ ret.add(Vectors.dense(model.labels))
+ ret.add(Vectors.dense(model.pi))
+ ret.add(model.theta)
ret
}
@@ -243,16 +244,12 @@ class PythonMLLibAPI extends Serializable {
* Java stub for Python mllib KMeans.train()
*/
def trainKMeansModel(
- dataBytesJRDD: JavaRDD[Array[Byte]],
+ data: JavaRDD[Vector],
k: Int,
maxIterations: Int,
runs: Int,
- initializationMode: String): java.util.List[java.lang.Object] = {
- val data = dataBytesJRDD.rdd.map(bytes => SerDe.deserializeDoubleVector(bytes))
- val model = KMeans.train(data, k, maxIterations, runs, initializationMode)
- val ret = new java.util.LinkedList[java.lang.Object]()
- ret.add(SerDe.serializeDoubleMatrix(model.clusterCenters.map(_.toArray)))
- ret
+ initializationMode: String): KMeansModel = {
+ KMeans.train(data.rdd, k, maxIterations, runs, initializationMode)
}
/**
@@ -262,13 +259,12 @@ class PythonMLLibAPI extends Serializable {
* the Py4J documentation.
*/
def trainALSModel(
- ratingsBytesJRDD: JavaRDD[Array[Byte]],
+ ratings: JavaRDD[Rating],
rank: Int,
iterations: Int,
lambda: Double,
blocks: Int): MatrixFactorizationModel = {
- val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating)
- ALS.train(ratings, rank, iterations, lambda, blocks)
+ ALS.train(ratings.rdd, rank, iterations, lambda, blocks)
}
/**
@@ -278,14 +274,13 @@ class PythonMLLibAPI extends Serializable {
* exit; see the Py4J documentation.
*/
def trainImplicitALSModel(
- ratingsBytesJRDD: JavaRDD[Array[Byte]],
+ ratingsJRDD: JavaRDD[Rating],
rank: Int,
iterations: Int,
lambda: Double,
blocks: Int,
alpha: Double): MatrixFactorizationModel = {
- val ratings = ratingsBytesJRDD.rdd.map(SerDe.unpackRating)
- ALS.trainImplicit(ratings, rank, iterations, lambda, blocks, alpha)
+ ALS.trainImplicit(ratingsJRDD.rdd, rank, iterations, lambda, blocks, alpha)
}
/**
@@ -293,11 +288,11 @@ class PythonMLLibAPI extends Serializable {
* This stub returns a handle to the Java object instead of the content of the Java object.
* Extra care needs to be taken in the Python code to ensure it gets freed on exit;
* see the Py4J documentation.
- * @param dataBytesJRDD Training data
+ * @param data Training data
* @param categoricalFeaturesInfoJMap Categorical features info, as Java map
*/
def trainDecisionTreeModel(
- dataBytesJRDD: JavaRDD[Array[Byte]],
+ data: JavaRDD[LabeledPoint],
algoStr: String,
numClasses: Int,
categoricalFeaturesInfoJMap: java.util.Map[Int, Int],
@@ -307,8 +302,6 @@ class PythonMLLibAPI extends Serializable {
minInstancesPerNode: Int,
minInfoGain: Double): DecisionTreeModel = {
- val data = dataBytesJRDD.rdd.map(SerDe.deserializeLabeledPoint)
-
val algo = Algo.fromString(algoStr)
val impurity = Impurities.fromString(impurityStr)
@@ -322,44 +315,15 @@ class PythonMLLibAPI extends Serializable {
minInstancesPerNode = minInstancesPerNode,
minInfoGain = minInfoGain)
- DecisionTree.train(data, strategy)
- }
-
- /**
- * Predict the label of the given data point.
- * This is a Java stub for python DecisionTreeModel.predict()
- *
- * @param featuresBytes Serialized feature vector for data point
- * @return predicted label
- */
- def predictDecisionTreeModel(
- model: DecisionTreeModel,
- featuresBytes: Array[Byte]): Double = {
- val features: Vector = SerDe.deserializeDoubleVector(featuresBytes)
- model.predict(features)
- }
-
- /**
- * Predict the labels of the given data points.
- * This is a Java stub for python DecisionTreeModel.predict()
- *
- * @param dataJRDD A JavaRDD with serialized feature vectors
- * @return JavaRDD of serialized predictions
- */
- def predictDecisionTreeModel(
- model: DecisionTreeModel,
- dataJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
- val data = dataJRDD.rdd.map(xBytes => SerDe.deserializeDoubleVector(xBytes))
- model.predict(data).map(SerDe.serializeDouble)
+ DecisionTree.train(data.rdd, strategy)
}
/**
* Java stub for mllib Statistics.colStats(X: RDD[Vector]).
* TODO figure out return type.
*/
- def colStats(X: JavaRDD[Array[Byte]]): MultivariateStatisticalSummarySerialized = {
- val cStats = Statistics.colStats(X.rdd.map(SerDe.deserializeDoubleVector(_)))
- new MultivariateStatisticalSummarySerialized(cStats)
+ def colStats(rdd: JavaRDD[Vector]): MultivariateStatisticalSummary = {
+ Statistics.colStats(rdd.rdd)
}
/**
@@ -367,19 +331,15 @@ class PythonMLLibAPI extends Serializable {
* Returns the correlation matrix serialized into a byte array understood by deserializers in
* pyspark.
*/
- def corr(X: JavaRDD[Array[Byte]], method: String): Array[Byte] = {
- val inputMatrix = X.rdd.map(SerDe.deserializeDoubleVector(_))
- val result = Statistics.corr(inputMatrix, getCorrNameOrDefault(method))
- SerDe.serializeDoubleMatrix(SerDe.to2dArray(result))
+ def corr(x: JavaRDD[Vector], method: String): Matrix = {
+ Statistics.corr(x.rdd, getCorrNameOrDefault(method))
}
/**
* Java stub for mllib Statistics.corr(x: RDD[Double], y: RDD[Double], method: String).
*/
- def corr(x: JavaRDD[Array[Byte]], y: JavaRDD[Array[Byte]], method: String): Double = {
- val xDeser = x.rdd.map(SerDe.deserializeDouble(_))
- val yDeser = y.rdd.map(SerDe.deserializeDouble(_))
- Statistics.corr(xDeser, yDeser, getCorrNameOrDefault(method))
+ def corr(x: JavaRDD[Double], y: JavaRDD[Double], method: String): Double = {
+ Statistics.corr(x.rdd, y.rdd, getCorrNameOrDefault(method))
}
// used by the corr methods to retrieve the name of the correlation method passed in via pyspark
@@ -411,10 +371,10 @@ class PythonMLLibAPI extends Serializable {
def uniformRDD(jsc: JavaSparkContext,
size: Long,
numPartitions: java.lang.Integer,
- seed: java.lang.Long): JavaRDD[Array[Byte]] = {
+ seed: java.lang.Long): JavaRDD[Double] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.uniformRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble)
+ RG.uniformRDD(jsc.sc, size, parts, s)
}
/**
@@ -423,10 +383,10 @@ class PythonMLLibAPI extends Serializable {
def normalRDD(jsc: JavaSparkContext,
size: Long,
numPartitions: java.lang.Integer,
- seed: java.lang.Long): JavaRDD[Array[Byte]] = {
+ seed: java.lang.Long): JavaRDD[Double] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.normalRDD(jsc.sc, size, parts, s).map(SerDe.serializeDouble)
+ RG.normalRDD(jsc.sc, size, parts, s)
}
/**
@@ -436,10 +396,10 @@ class PythonMLLibAPI extends Serializable {
mean: Double,
size: Long,
numPartitions: java.lang.Integer,
- seed: java.lang.Long): JavaRDD[Array[Byte]] = {
+ seed: java.lang.Long): JavaRDD[Double] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.poissonRDD(jsc.sc, mean, size, parts, s).map(SerDe.serializeDouble)
+ RG.poissonRDD(jsc.sc, mean, size, parts, s)
}
/**
@@ -449,10 +409,10 @@ class PythonMLLibAPI extends Serializable {
numRows: Long,
numCols: Int,
numPartitions: java.lang.Integer,
- seed: java.lang.Long): JavaRDD[Array[Byte]] = {
+ seed: java.lang.Long): JavaRDD[Vector] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector)
+ RG.uniformVectorRDD(jsc.sc, numRows, numCols, parts, s)
}
/**
@@ -462,10 +422,10 @@ class PythonMLLibAPI extends Serializable {
numRows: Long,
numCols: Int,
numPartitions: java.lang.Integer,
- seed: java.lang.Long): JavaRDD[Array[Byte]] = {
+ seed: java.lang.Long): JavaRDD[Vector] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector)
+ RG.normalVectorRDD(jsc.sc, numRows, numCols, parts, s)
}
/**
@@ -476,259 +436,168 @@ class PythonMLLibAPI extends Serializable {
numRows: Long,
numCols: Int,
numPartitions: java.lang.Integer,
- seed: java.lang.Long): JavaRDD[Array[Byte]] = {
+ seed: java.lang.Long): JavaRDD[Vector] = {
val parts = getNumPartitionsOrDefault(numPartitions, jsc)
val s = getSeedOrDefault(seed)
- RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s).map(SerDe.serializeDoubleVector)
+ RG.poissonVectorRDD(jsc.sc, mean, numRows, numCols, parts, s)
}
}
/**
- * :: DeveloperApi ::
- * MultivariateStatisticalSummary with Vector fields serialized.
+ * SerDe utility functions for PythonMLLibAPI.
*/
-@DeveloperApi
-class MultivariateStatisticalSummarySerialized(val summary: MultivariateStatisticalSummary)
- extends Serializable {
+private[spark] object SerDe extends Serializable {
- def mean: Array[Byte] = SerDe.serializeDoubleVector(summary.mean)
+ val PYSPARK_PACKAGE = "pyspark.mllib"
- def variance: Array[Byte] = SerDe.serializeDoubleVector(summary.variance)
+ /**
+ * Base class used for pickle
+ */
+ private[python] abstract class BasePickler[T: ClassTag]
+ extends IObjectPickler with IObjectConstructor {
+
+ private val cls = implicitly[ClassTag[T]].runtimeClass
+ private val module = PYSPARK_PACKAGE + "." + cls.getName.split('.')(4)
+ private val name = cls.getSimpleName
+
+ // register this to Pickler and Unpickler
+ def register(): Unit = {
+ Pickler.registerCustomPickler(this.getClass, this)
+ Pickler.registerCustomPickler(cls, this)
+ Unpickler.registerConstructor(module, name, this)
+ }
- def count: Long = summary.count
+ def pickle(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
+ if (obj == this) {
+ out.write(Opcodes.GLOBAL)
+ out.write((module + "\n" + name + "\n").getBytes())
+ } else {
+ pickler.save(this) // it will be memorized by Pickler
+ saveState(obj, out, pickler)
+ out.write(Opcodes.REDUCE)
+ }
+ }
+
+ private[python] def saveObjects(out: OutputStream, pickler: Pickler, objects: Any*) = {
+ if (objects.length == 0 || objects.length > 3) {
+ out.write(Opcodes.MARK)
+ }
+ objects.foreach(pickler.save(_))
+ val code = objects.length match {
+ case 1 => Opcodes.TUPLE1
+ case 2 => Opcodes.TUPLE2
+ case 3 => Opcodes.TUPLE3
+ case _ => Opcodes.TUPLE
+ }
+ out.write(code)
+ }
- def numNonzeros: Array[Byte] = SerDe.serializeDoubleVector(summary.numNonzeros)
+ private[python] def saveState(obj: Object, out: OutputStream, pickler: Pickler)
+ }
- def max: Array[Byte] = SerDe.serializeDoubleVector(summary.max)
+ // Pickler for DenseVector
+ private[python] class DenseVectorPickler extends BasePickler[DenseVector] {
- def min: Array[Byte] = SerDe.serializeDoubleVector(summary.min)
-}
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val vector: DenseVector = obj.asInstanceOf[DenseVector]
+ saveObjects(out, pickler, vector.toArray)
+ }
-/**
- * SerDe utility functions for PythonMLLibAPI.
- */
-private[spark] object SerDe extends Serializable {
- private val DENSE_VECTOR_MAGIC: Byte = 1
- private val SPARSE_VECTOR_MAGIC: Byte = 2
- private val DENSE_MATRIX_MAGIC: Byte = 3
- private val LABELED_POINT_MAGIC: Byte = 4
-
- private[python] def deserializeDoubleVector(bytes: Array[Byte], offset: Int = 0): Vector = {
- require(bytes.length - offset >= 5, "Byte array too short")
- val magic = bytes(offset)
- if (magic == DENSE_VECTOR_MAGIC) {
- deserializeDenseVector(bytes, offset)
- } else if (magic == SPARSE_VECTOR_MAGIC) {
- deserializeSparseVector(bytes, offset)
- } else {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ def construct(args: Array[Object]): Object = {
+ require(args.length == 1)
+ if (args.length != 1) {
+ throw new PickleException("should be 1")
+ }
+ new DenseVector(args(0).asInstanceOf[Array[Double]])
}
}
- private[python] def deserializeDouble(bytes: Array[Byte], offset: Int = 0): Double = {
- require(bytes.length - offset == 8, "Wrong size byte array for Double")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- bb.getDouble
- }
-
- private[python] def deserializeDenseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
- val packetLength = bytes.length - offset
- require(packetLength >= 5, "Byte array too short")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- require(magic == DENSE_VECTOR_MAGIC, "Invalid magic: " + magic)
- val length = bb.getInt()
- require (packetLength == 5 + 8 * length, "Invalid packet length: " + packetLength)
- val db = bb.asDoubleBuffer()
- val ans = new Array[Double](length.toInt)
- db.get(ans)
- Vectors.dense(ans)
- }
-
- private[python] def deserializeSparseVector(bytes: Array[Byte], offset: Int = 0): Vector = {
- val packetLength = bytes.length - offset
- require(packetLength >= 9, "Byte array too short")
- val bb = ByteBuffer.wrap(bytes, offset, bytes.length - offset)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- require(magic == SPARSE_VECTOR_MAGIC, "Invalid magic: " + magic)
- val size = bb.getInt()
- val nonZeros = bb.getInt()
- require (packetLength == 9 + 12 * nonZeros, "Invalid packet length: " + packetLength)
- val ib = bb.asIntBuffer()
- val indices = new Array[Int](nonZeros)
- ib.get(indices)
- bb.position(bb.position() + 4 * nonZeros)
- val db = bb.asDoubleBuffer()
- val values = new Array[Double](nonZeros)
- db.get(values)
- Vectors.sparse(size, indices, values)
- }
+ // Pickler for DenseMatrix
+ private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {
- /**
- * Returns an 8-byte array for the input Double.
- *
- * Note: we currently do not use a magic byte for double for storage efficiency.
- * This should be reconsidered when we add Ser/De for other 8-byte types (e.g. Long), for safety.
- * The corresponding deserializer, deserializeDouble, needs to be modified as well if the
- * serialization scheme changes.
- */
- private[python] def serializeDouble(double: Double): Array[Byte] = {
- val bytes = new Array[Byte](8)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.putDouble(double)
- bytes
- }
-
- private[python] def serializeDenseVector(doubles: Array[Double]): Array[Byte] = {
- val len = doubles.length
- val bytes = new Array[Byte](5 + 8 * len)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(DENSE_VECTOR_MAGIC)
- bb.putInt(len)
- val db = bb.asDoubleBuffer()
- db.put(doubles)
- bytes
- }
-
- private[python] def serializeSparseVector(vector: SparseVector): Array[Byte] = {
- val nonZeros = vector.indices.length
- val bytes = new Array[Byte](9 + 12 * nonZeros)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(SPARSE_VECTOR_MAGIC)
- bb.putInt(vector.size)
- bb.putInt(nonZeros)
- val ib = bb.asIntBuffer()
- ib.put(vector.indices)
- bb.position(bb.position() + 4 * nonZeros)
- val db = bb.asDoubleBuffer()
- db.put(vector.values)
- bytes
- }
-
- private[python] def serializeDoubleVector(vector: Vector): Array[Byte] = vector match {
- case s: SparseVector =>
- serializeSparseVector(s)
- case _ =>
- serializeDenseVector(vector.toArray)
- }
-
- private[python] def deserializeDoubleMatrix(bytes: Array[Byte]): Array[Array[Double]] = {
- val packetLength = bytes.length
- if (packetLength < 9) {
- throw new IllegalArgumentException("Byte array too short.")
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val m: DenseMatrix = obj.asInstanceOf[DenseMatrix]
+ saveObjects(out, pickler, m.numRows, m.numCols, m.values)
}
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- val magic = bb.get()
- if (magic != DENSE_MATRIX_MAGIC) {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+
+ def construct(args: Array[Object]): Object = {
+ if (args.length != 3) {
+ throw new PickleException("should be 3")
+ }
+ new DenseMatrix(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
+ args(2).asInstanceOf[Array[Double]])
}
- val rows = bb.getInt()
- val cols = bb.getInt()
- if (packetLength != 9 + 8 * rows * cols) {
- throw new IllegalArgumentException("Size " + rows + "x" + cols + " is wrong.")
+ }
+
+ // Pickler for SparseVector
+ private[python] class SparseVectorPickler extends BasePickler[SparseVector] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val v: SparseVector = obj.asInstanceOf[SparseVector]
+ saveObjects(out, pickler, v.size, v.indices, v.values)
}
- val db = bb.asDoubleBuffer()
- val ans = new Array[Array[Double]](rows.toInt)
- for (i <- 0 until rows.toInt) {
- ans(i) = new Array[Double](cols.toInt)
- db.get(ans(i))
+
+ def construct(args: Array[Object]): Object = {
+ if (args.length != 3) {
+ throw new PickleException("should be 3")
+ }
+ new SparseVector(args(0).asInstanceOf[Int], args(1).asInstanceOf[Array[Int]],
+ args(2).asInstanceOf[Array[Double]])
}
- ans
}
- private[python] def serializeDoubleMatrix(doubles: Array[Array[Double]]): Array[Byte] = {
- val rows = doubles.length
- var cols = 0
- if (rows > 0) {
- cols = doubles(0).length
+ // Pickler for LabeledPoint
+ private[python] class LabeledPointPickler extends BasePickler[LabeledPoint] {
+
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val point: LabeledPoint = obj.asInstanceOf[LabeledPoint]
+ saveObjects(out, pickler, point.label, point.features)
}
- val bytes = new Array[Byte](9 + 8 * rows * cols)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(DENSE_MATRIX_MAGIC)
- bb.putInt(rows)
- bb.putInt(cols)
- val db = bb.asDoubleBuffer()
- for (i <- 0 until rows) {
- db.put(doubles(i))
+
+ def construct(args: Array[Object]): Object = {
+ if (args.length != 2) {
+ throw new PickleException("should be 2")
+ }
+ new LabeledPoint(args(0).asInstanceOf[Double], args(1).asInstanceOf[Vector])
}
- bytes
}
- private[python] def serializeLabeledPoint(p: LabeledPoint): Array[Byte] = {
- val fb = serializeDoubleVector(p.features)
- val bytes = new Array[Byte](1 + 8 + fb.length)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.put(LABELED_POINT_MAGIC)
- bb.putDouble(p.label)
- bb.put(fb)
- bytes
- }
+ // Pickler for Rating
+ private[python] class RatingPickler extends BasePickler[Rating] {
- private[python] def deserializeLabeledPoint(bytes: Array[Byte]): LabeledPoint = {
- require(bytes.length >= 9, "Byte array too short")
- val magic = bytes(0)
- if (magic != LABELED_POINT_MAGIC) {
- throw new IllegalArgumentException("Magic " + magic + " is wrong.")
+ def saveState(obj: Object, out: OutputStream, pickler: Pickler) = {
+ val rating: Rating = obj.asInstanceOf[Rating]
+ saveObjects(out, pickler, rating.user, rating.product, rating.rating)
}
- val labelBytes = ByteBuffer.wrap(bytes, 1, 8)
- labelBytes.order(ByteOrder.nativeOrder())
- val label = labelBytes.asDoubleBuffer().get(0)
- LabeledPoint(label, deserializeDoubleVector(bytes, 9))
- }
- // Reformat a Matrix into Array[Array[Double]] for serialization
- private[python] def to2dArray(matrix: Matrix): Array[Array[Double]] = {
- val values = matrix.toArray
- Array.tabulate(matrix.numRows, matrix.numCols)((i, j) => values(i + j * matrix.numRows))
+ def construct(args: Array[Object]): Object = {
+ if (args.length != 3) {
+ throw new PickleException("should be 3")
+ }
+ new Rating(args(0).asInstanceOf[Int], args(1).asInstanceOf[Int],
+ args(2).asInstanceOf[Double])
+ }
}
+ def initialize(): Unit = {
+ new DenseVectorPickler().register()
+ new DenseMatrixPickler().register()
+ new SparseVectorPickler().register()
+ new LabeledPointPickler().register()
+ new RatingPickler().register()
+ }
- /** Unpack a Rating object from an array of bytes */
- private[python] def unpackRating(ratingBytes: Array[Byte]): Rating = {
- val bb = ByteBuffer.wrap(ratingBytes)
- bb.order(ByteOrder.nativeOrder())
- val user = bb.getInt()
- val product = bb.getInt()
- val rating = bb.getDouble()
- new Rating(user, product, rating)
+ def dumps(obj: AnyRef): Array[Byte] = {
+ new Pickler().dumps(obj)
}
- /** Unpack a tuple of Ints from an array of bytes */
- def unpackTuple(tupleBytes: Array[Byte]): (Int, Int) = {
- val bb = ByteBuffer.wrap(tupleBytes)
- bb.order(ByteOrder.nativeOrder())
- val v1 = bb.getInt()
- val v2 = bb.getInt()
- (v1, v2)
+ def loads(bytes: Array[Byte]): AnyRef = {
+ new Unpickler().loads(bytes)
}
- /**
- * Serialize a Rating object into an array of bytes.
- * It can be deserialized using RatingDeserializer().
- *
- * @param rate the Rating object to serialize
- * @return
- */
- def serializeRating(rate: Rating): Array[Byte] = {
- val len = 3
- val bytes = new Array[Byte](4 + 8 * len)
- val bb = ByteBuffer.wrap(bytes)
- bb.order(ByteOrder.nativeOrder())
- bb.putInt(len)
- val db = bb.asDoubleBuffer()
- db.put(rate.user.toDouble)
- db.put(rate.product.toDouble)
- db.put(rate.rating)
- bytes
+ /* convert object into Tuple */
+ def asTupleRDD(rdd: RDD[Array[Any]]): RDD[(Int, Int)] = {
+ rdd.map(x => (x(0).asInstanceOf[Int], x(1).asInstanceOf[Int]))
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 5711532abc..4e87fe088e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -17,12 +17,12 @@
package org.apache.spark.mllib.linalg
+import java.util.Arrays
+
import breeze.linalg.{Matrix => BM, DenseMatrix => BDM, CSCMatrix => BSM}
import org.apache.spark.util.random.XORShiftRandom
-import java.util.Arrays
-
/**
* Trait for a local matrix.
*/
@@ -106,6 +106,12 @@ class DenseMatrix(val numRows: Int, val numCols: Int, val values: Array[Double])
override def toArray: Array[Double] = values
+ override def equals(o: Any) = o match {
+ case m: DenseMatrix =>
+ m.numRows == numRows && m.numCols == numCols && Arrays.equals(toArray, m.toArray)
+ case _ => false
+ }
+
private[mllib] def toBreeze: BM[Double] = new BDM[Double](numRows, numCols, values)
private[mllib] def apply(i: Int): Double = values(i)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
index 478c648505..66b58ba770 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/MatrixFactorizationModel.scala
@@ -106,19 +106,4 @@ class MatrixFactorizationModel private[mllib] (
}
scored.top(num)(Ordering.by(_._2))
}
-
- /**
- * :: DeveloperApi ::
- * Predict the rating of many users for many products.
- * This is a Java stub for python predictAll()
- *
- * @param usersProductsJRDD A JavaRDD with serialized tuples (user, product)
- * @return JavaRDD of serialized Rating objects.
- */
- @DeveloperApi
- def predict(usersProductsJRDD: JavaRDD[Array[Byte]]): JavaRDD[Array[Byte]] = {
- val usersProducts = usersProductsJRDD.rdd.map(xBytes => SerDe.unpackTuple(xBytes))
- predict(usersProducts).map(rate => SerDe.serializeRating(rate))
- }
-
}