diff options
author | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-19 01:22:18 -0500 |
---|---|---|
committer | Tor Myklebust <tmyklebu@gmail.com> | 2013-12-19 01:29:09 -0500 |
commit | 95915f8b3b6d07a9dddb09a637aa23c8622bff9b (patch) | |
tree | bc22a7cb8758079a0c8896d022dca0b418e66ec8 | |
parent | d3b1af4b6c7766bbf7a09ee6d5c1b13eda6b098f (diff) | |
download | spark-95915f8b3b6d07a9dddb09a637aa23c8622bff9b.tar.gz spark-95915f8b3b6d07a9dddb09a637aa23c8622bff9b.tar.bz2 spark-95915f8b3b6d07a9dddb09a637aa23c8622bff9b.zip |
First cut at python mllib bindings. Only LinearRegression is supported.
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/api/PythonMLLibAPI.scala | 51 | ||||
-rw-r--r-- | python/pyspark/mllib.py | 114 |
2 files changed, 165 insertions, 0 deletions
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.mllib.regression._
import java.nio.ByteBuffer
import java.nio.ByteOrder
import java.nio.DoubleBuffer

/**
 * JVM-side helpers for the Python MLlib bindings.
 *
 * Wire format for a double vector (machine-endian; the JVM and the Python
 * interpreter must agree on the machine's endianness):
 *
 *   [8-byte magic 1] [8-byte length] [length * 8 bytes of data]
 */
class PythonMLLibAPI extends Serializable {

  /**
   * Decode a byte packet in the double-vector wire format.
   *
   * @param bytes packet produced by `serializeDoubleVector` (or by the
   *              matching Python serializer)
   * @return the decoded doubles
   * @throws IllegalArgumentException if the packet is too short, has the
   *         wrong magic number, or its length field disagrees with the
   *         actual packet size
   */
  def deserializeDoubleVector(bytes: Array[Byte]): Array[Double] = {
    val packetLength = bytes.length
    if (packetLength < 16) {
      throw new IllegalArgumentException("Byte array too short.")
    }
    val bb = ByteBuffer.wrap(bytes)
    // Native order: the Python side writes with numpy's machine endianness.
    bb.order(ByteOrder.nativeOrder())
    val magic = bb.getLong()
    if (magic != 1) {
      throw new IllegalArgumentException("Magic " + magic + " is wrong.")
    }
    val length = bb.getLong()
    if (packetLength != 16 + 8 * length) {
      // Bug fix: the original message was missing a space and rendered as
      // e.g. "Length 5is wrong."
      throw new IllegalArgumentException("Length " + length + " is wrong.")
    }
    val db = bb.asDoubleBuffer()
    val ans = new Array[Double](length.toInt)
    db.get(ans)
    ans
  }

  /**
   * Encode an array of doubles in the double-vector wire format.
   *
   * @param doubles values to encode
   * @return a packet that `deserializeDoubleVector` (or the Python
   *         deserializer) can decode
   */
  def serializeDoubleVector(doubles: Array[Double]): Array[Byte] = {
    val len = doubles.length
    val bytes = new Array[Byte](16 + 8 * len)
    val bb = ByteBuffer.wrap(bytes)
    bb.order(ByteOrder.nativeOrder())
    bb.putLong(1)   // magic
    bb.putLong(len)
    val db = bb.asDoubleBuffer()
    db.put(doubles)
    bytes
  }

  /**
   * Train a linear regression model on an RDD of serialized double vectors.
   *
   * Each vector holds the label in element 0 and the features in the
   * remaining elements.
   *
   * @param dataBytesJRDD training data, one wire-format packet per point
   * @return a two-element list: the serialized weight vector, then the
   *         intercept as a boxed Double
   */
  def trainLinearRegressionModel(dataBytesJRDD: JavaRDD[Array[Byte]]):
      java.util.List[java.lang.Object] = {
    val data = dataBytesJRDD.rdd.map(x => deserializeDoubleVector(x))
      .map(v => LabeledPoint(v(0), v.slice(1, v.length)))
    // TODO: expose the iteration count to callers; 222 is an arbitrary
    // placeholder.  (Not parameterized here because the Py4J caller looks
    // the method up by arity.)
    val model = LinearRegressionWithSGD.train(data, 222)
    val ret = new java.util.LinkedList[java.lang.Object]()
    ret.add(serializeDoubleVector(model.weights))
    ret.add(model.intercept: java.lang.Double)
    ret
  }
}
from numpy import ndarray, float64, copyto, dot

# __all__ = ["train_linear_regression_model"]

# Double vector format:
#
#   [8-byte 1] [8-byte length] [length*8 bytes of data]
#
# Double matrix format:
#
#   [8-byte 2] [8-byte rows] [8-byte cols] [rows*cols*8 bytes of data]
#
# This is all in machine-endian.  That means that the Java interpreter and
# the Python interpreter must agree on what endian the machine is.


def deserialize_byte_array(shape, ba, offset):
    """Implementation detail.  Do not use directly.

    View ``ba`` at ``offset`` as a C-ordered float64 array of ``shape`` and
    return an independent copy (so the caller may discard the buffer).
    """
    ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype=float64,
                 order='C')
    return ar.copy()


def serialize_double_vector(v):
    """Implementation detail.  Do not use directly.

    Pack a 1-d float64 ndarray into the double-vector wire format.
    Raises TypeError for anything else.
    """
    if not (isinstance(v, ndarray) and v.dtype == float64 and v.ndim == 1):
        raise TypeError("serialize_double_vector called on a non-double-vector")
    length = v.shape[0]
    ba = bytearray(16 + 8 * length)
    header = ndarray(shape=[2], buffer=ba, dtype="int64")
    header[0] = 1          # magic
    header[1] = length
    copyto(ndarray(shape=[length], buffer=ba, offset=16, dtype="float64"), v)
    return ba


def deserialize_double_vector(ba):
    """Implementation detail.  Do not use directly.

    Unpack a bytearray in the double-vector wire format into a 1-d float64
    ndarray.  Raises TypeError on a malformed or wrongly-typed packet.
    """
    if not (isinstance(ba, bytearray) and len(ba) >= 16 and len(ba) % 8 == 0):
        raise TypeError("deserialize_double_vector called on a non-bytearray")
    header = ndarray(shape=[2], buffer=ba, dtype="int64")
    if header[0] != 1:
        raise TypeError("deserialize_double_vector called on bytearray with "
                        "wrong magic")
    length = header[1]
    if len(ba) != 8 * length + 16:
        raise TypeError("deserialize_double_vector called on bytearray with "
                        "wrong length")
    return deserialize_byte_array([length], ba, 16)


def serialize_double_matrix(m):
    """Implementation detail.  Do not use directly.

    Pack a 2-d float64 ndarray into the double-matrix wire format.
    Raises TypeError for anything else.
    """
    if not (isinstance(m, ndarray) and m.dtype == float64 and m.ndim == 2):
        # (Debug prints that preceded this raise in the first cut have been
        # removed; they were also Python-2-only syntax.)
        raise TypeError("serialize_double_matrix called on a non-double-matrix")
    rows, cols = m.shape
    ba = bytearray(24 + 8 * rows * cols)
    header = ndarray(shape=[3], buffer=ba, dtype="int64")
    header[0] = 2          # magic
    header[1] = rows
    header[2] = cols
    copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24, dtype="float64",
                   order='C'), m)
    return ba


def deserialize_double_matrix(ba):
    """Implementation detail.  Do not use directly.

    Unpack a bytearray in the double-matrix wire format into a 2-d float64
    ndarray.  Raises TypeError on a malformed or wrongly-typed packet.
    """
    if not (isinstance(ba, bytearray) and len(ba) >= 24 and len(ba) % 8 == 0):
        raise TypeError("deserialize_double_matrix called on a non-bytearray")
    header = ndarray(shape=[3], buffer=ba, dtype="int64")
    if header[0] != 2:
        raise TypeError("deserialize_double_matrix called on bytearray with "
                        "wrong magic")
    rows = header[1]
    cols = header[2]
    if len(ba) != 8 * rows * cols + 24:
        raise TypeError("deserialize_double_matrix called on bytearray with "
                        "wrong length")
    return deserialize_byte_array([rows, cols], ba, 24)


class LinearRegressionModel:
    """Linear regression model produced by train_linear_regression_model."""

    def __init__(self, coeff, intercept):
        # Weight vector (1-d float64 ndarray) and scalar intercept.
        self._coeff = coeff
        self._intercept = intercept

    def predict(self, x):
        """Predict the label of a single 1-d float64 feature vector."""
        if isinstance(x, ndarray):
            if x.ndim == 1:
                # Bug fix: the original referenced the bare names _coeff /
                # _intercept (NameError) and subtracted the intercept; a
                # linear model predicts w . x + b.
                return dot(self._coeff, x) + self._intercept
            raise RuntimeError("Bulk predict not yet supported.")
        # Checked by class name so this module need not import pyspark's RDD;
        # bulk prediction over an RDD is not implemented yet in any case.
        if type(x).__name__ == "RDD":
            raise RuntimeError("Bulk predict not yet supported.")
        raise TypeError("Bad type argument to LinearRegressionModel::predict")


def train_linear_regression_model(sc, data):
    """Train a linear regression model on the given data.

    sc: a SparkContext whose JVM gateway exposes PythonMLLibAPI.
    data: an RDD of 1-d float64 ndarrays; element 0 of each vector is the
        label and the remaining elements are the features.
    Returns a LinearRegressionModel; raises RuntimeError if the JVM reply
    is not (weights bytearray, float intercept).
    """
    # Imported here rather than at module scope so the serialization helpers
    # above remain importable without a full pyspark installation.
    from pyspark.serializers import NoOpSerializer
    dataBytes = data.map(serialize_double_vector)
    sc.serializer = NoOpSerializer()
    dataBytes.cache()
    api = sc._jvm.PythonMLLibAPI()
    ans = api.trainLinearRegressionModel(dataBytes._jrdd)
    if len(ans) != 2 or type(ans[0]) != bytearray or type(ans[1]) != float:
        raise RuntimeError("train_linear_regression_model received garbage "
                           "from JVM")
    return LinearRegressionModel(deserialize_double_vector(ans[0]), ans[1])