about | summary | refs | log | tree | commit | diff
path: root/python
diff options
context:
space:
mode:
author Tor Myklebust <tmyklebu@gmail.com> 2013-12-19 03:40:57 -0500
committer Tor Myklebust <tmyklebu@gmail.com> 2013-12-19 03:40:57 -0500
commit bf20591a006b9d2fdd9a674d637f5e929fd065a2 (patch)
tree 90be0c9c61b009e2379d1c05b0abc4d038f358c4 /python
parent bf491bb3c0a9008caa4ac112672a4760b3d1c7b8 (diff)
download spark-bf20591a006b9d2fdd9a674d637f5e929fd065a2.tar.gz
spark-bf20591a006b9d2fdd9a674d637f5e929fd065a2.tar.bz2
spark-bf20591a006b9d2fdd9a674d637f5e929fd065a2.zip
Incorporate most of Josh's style suggestions. I don't want to deal with the type and length checking errors until we've got at least one working stub that we're all happy with.
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/__init__.py | 4
-rw-r--r-- python/pyspark/mllib.py | 185
2 files changed, 91 insertions, 98 deletions
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 949406c57b..9f71db397d 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -42,7 +42,7 @@ from pyspark.context import SparkContext
from pyspark.rdd import RDD
from pyspark.files import SparkFiles
from pyspark.storagelevel import StorageLevel
-from pyspark.mllib import train_linear_regression_model
+from pyspark.mllib import LinearRegressionModel
-__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", "train_linear_regression_model"]
+__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", "LinearRegressionModel"];
diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py
index 8237f66d67..0dfc4909c7 100644
--- a/python/pyspark/mllib.py
+++ b/python/pyspark/mllib.py
@@ -1,8 +1,4 @@
-from numpy import *;
-from pyspark.serializers import NoOpSerializer, FramedSerializer, \
- BatchedSerializer, CloudPickleSerializer, pack_long
-
-#__all__ = ["train_linear_regression_model"];
+from numpy import *
# Double vector format:
#
@@ -15,100 +11,97 @@ from pyspark.serializers import NoOpSerializer, FramedSerializer, \
# This is all in machine-endian. That means that the Java interpreter and the
# Python interpreter must agree on what endian the machine is.
-def deserialize_byte_array(shape, ba, offset):
- """Implementation detail. Do not use directly."""
- ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64", \
- order='C');
- return ar.copy();
-
-def serialize_double_vector(v):
- """Implementation detail. Do not use directly."""
- if (type(v) == ndarray and v.dtype == float64 and v.ndim == 1):
- length = v.shape[0];
- ba = bytearray(16 + 8*length);
- header = ndarray(shape=[2], buffer=ba, dtype="int64");
- header[0] = 1;
- header[1] = length;
- copyto(ndarray(shape=[length], buffer=ba, offset=16, dtype="float64"), v);
- return ba;
- else:
- raise TypeError("serialize_double_vector called on a non-double-vector");
+def _deserialize_byte_array(shape, ba, offset):
+ ar = ndarray(shape=shape, buffer=ba, offset=offset, dtype="float64",
+ order='C')
+ return ar.copy()
-def deserialize_double_vector(ba):
- """Implementation detail. Do not use directly."""
- if (type(ba) == bytearray and len(ba) >= 16 and (len(ba) & 7 == 0)):
- header = ndarray(shape=[2], buffer=ba, dtype="int64");
- if (header[0] != 1):
- raise TypeError("deserialize_double_vector called on bytearray with " \
- "wrong magic");
- length = header[1];
- if (len(ba) != 8*length + 16):
- raise TypeError("deserialize_double_vector called on bytearray with " \
- "wrong length");
- return deserialize_byte_array([length], ba, 16);
- else:
- raise TypeError("deserialize_double_vector called on a non-bytearray");
+def _serialize_double_vector(v):
+ if (type(v) == ndarray and v.dtype == float64 and v.ndim == 1):
+ length = v.shape[0]
+ ba = bytearray(16 + 8*length)
+ header = ndarray(shape=[2], buffer=ba, dtype="int64")
+ header[0] = 1
+ header[1] = length
+ copyto(ndarray(shape=[length], buffer=ba, offset=16,
+ dtype="float64"), v)
+ return ba
+ else:
+ raise TypeError("_serialize_double_vector called on a "
+ "non-double-vector")
-def serialize_double_matrix(m):
- """Implementation detail. Do not use directly."""
- if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2):
- rows = m.shape[0];
- cols = m.shape[1];
- ba = bytearray(24 + 8 * rows * cols);
- header = ndarray(shape=[3], buffer=ba, dtype="int64");
- header[0] = 2;
- header[1] = rows;
- header[2] = cols;
- copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24, dtype="float64", \
- order='C'), m);
- return ba;
- else:
- print type(m);
- print m.dtype;
- print m.ndim;
- raise TypeError("serialize_double_matrix called on a non-double-matrix");
+def _deserialize_double_vector(ba):
+ if (type(ba) == bytearray and len(ba) >= 16 and (len(ba) & 7 == 0)):
+ header = ndarray(shape=[2], buffer=ba, dtype="int64")
+ if (header[0] != 1):
+ raise TypeError("_deserialize_double_vector called on bytearray "
+ "with wrong magic")
+ length = header[1]
+ if (len(ba) != 8*length + 16):
+ raise TypeError("_deserialize_double_vector called on bytearray "
+ "with wrong length")
+ return _deserialize_byte_array([length], ba, 16)
+ else:
+ raise TypeError("_deserialize_double_vector called on a non-bytearray")
-def deserialize_double_matrix(ba):
- """Implementation detail. Do not use directly."""
- if (type(ba) == bytearray and len(ba) >= 24 and (len(ba) & 7 == 0)):
- header = ndarray(shape=[3], buffer=ba, dtype="int64");
- if (header[0] != 2):
- raise TypeError("deserialize_double_matrix called on bytearray with " \
- "wrong magic");
- rows = header[1];
- cols = header[2];
- if (len(ba) != 8*rows*cols + 24):
- raise TypeError("deserialize_double_matrix called on bytearray with " \
- "wrong length");
- return deserialize_byte_array([rows, cols], ba, 24);
- else:
- raise TypeError("deserialize_double_matrix called on a non-bytearray");
+def _serialize_double_matrix(m):
+ if (type(m) == ndarray and m.dtype == float64 and m.ndim == 2):
+ rows = m.shape[0]
+ cols = m.shape[1]
+ ba = bytearray(24 + 8 * rows * cols)
+ header = ndarray(shape=[3], buffer=ba, dtype="int64")
+ header[0] = 2
+ header[1] = rows
+ header[2] = cols
+ copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24,
+ dtype="float64", order='C'), m)
+ return ba
+ else:
+ raise TypeError("_serialize_double_matrix called on a "
+ "non-double-matrix")
-class LinearRegressionModel:
- _coeff = None;
- _intercept = None;
- def __init__(self, coeff, intercept):
- self._coeff = coeff;
- self._intercept = intercept;
- def predict(self, x):
- if (type(x) == ndarray):
- if (x.ndim == 1):
- return dot(_coeff, x) - _intercept;
- else:
- raise RuntimeError("Bulk predict not yet supported.");
- elif (type(x) == RDD):
- raise RuntimeError("Bulk predict not yet supported.");
+def _deserialize_double_matrix(ba):
+ if (type(ba) == bytearray and len(ba) >= 24 and (len(ba) & 7 == 0)):
+ header = ndarray(shape=[3], buffer=ba, dtype="int64")
+ if (header[0] != 2):
+ raise TypeError("_deserialize_double_matrix called on bytearray "
+ "with wrong magic")
+ rows = header[1]
+ cols = header[2]
+ if (len(ba) != 8*rows*cols + 24):
+ raise TypeError("_deserialize_double_matrix called on bytearray "
+ "with wrong length")
+ return _deserialize_byte_array([rows, cols], ba, 24)
else:
- raise TypeError("Bad type argument to LinearRegressionModel::predict");
+ raise TypeError("_deserialize_double_matrix called on a non-bytearray")
+
+class LinearRegressionModel(object):
+ def __init__(self, coeff, intercept):
+ self._coeff = coeff
+ self._intercept = intercept
+
+ def predict(self, x):
+ if (type(x) == ndarray):
+ if (x.ndim == 1):
+ return dot(_coeff, x) - _intercept
+ else:
+ raise RuntimeError("Bulk predict not yet supported.")
+ elif (type(x) == RDD):
+ raise RuntimeError("Bulk predict not yet supported.")
+ else:
+ raise TypeError("Bad type argument to "
+ "LinearRegressionModel::predict")
-def train_linear_regression_model(sc, data):
- """Train a linear regression model on the given data."""
- dataBytes = data.map(serialize_double_vector);
- sc.serializer = NoOpSerializer();
- dataBytes.cache();
- api = sc._jvm.PythonMLLibAPI();
- ans = api.trainLinearRegressionModel(dataBytes._jrdd);
- if (len(ans) != 2 or type(ans[0]) != bytearray or type(ans[1]) != float):
- raise RuntimeError("train_linear_regression_model received garbage " \
- "from JVM");
- return LinearRegressionModel(deserialize_double_vector(ans[0]), ans[1]);
+ @classmethod
+ def train(cls, sc, data):
+ """Train a linear regression model on the given data."""
+ dataBytes = data.map(_serialize_double_vector)
+ dataBytes._bypass_serializer = True
+ dataBytes.cache()
+ api = sc._jvm.PythonMLLibAPI()
+ ans = api.trainLinearRegressionModel(dataBytes._jrdd)
+ if (len(ans) != 2 or type(ans[0]) != bytearray
+ or type(ans[1]) != float):
+ raise RuntimeError("train_linear_regression_model received "
+ "garbage from JVM")
+ return LinearRegressionModel(_deserialize_double_vector(ans[0]), ans[1])