From 076fc1622190d342e20592c00ca19f8c0a56997f Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Sat, 21 Dec 2013 14:54:01 -0500 Subject: Python stubs for ALSModel. --- python/pyspark/__init__.py | 5 ++-- python/pyspark/mllib.py | 59 +++++++++++++++++++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 8 deletions(-) (limited to 'python') diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py index 8b5bb79a18..3d73d95909 100644 --- a/python/pyspark/__init__.py +++ b/python/pyspark/__init__.py @@ -43,9 +43,10 @@ from pyspark.rdd import RDD from pyspark.files import SparkFiles from pyspark.storagelevel import StorageLevel from pyspark.mllib import LinearRegressionModel, LassoModel, \ - RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel + RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel, \ + ALSModel __all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel", "LinearRegressionModel", "LassoModel", "RidgeRegressionModel", - "LogisticRegressionModel", "SVMModel", "KMeansModel"]; + "LogisticRegressionModel", "SVMModel", "KMeansModel", "ALSModel"]; diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py index 8848284a5e..22187eb4dd 100644 --- a/python/pyspark/mllib.py +++ b/python/pyspark/mllib.py @@ -164,14 +164,17 @@ class LinearRegressionModelBase(LinearModel): _linear_predictor_typecheck(x, self._coeff) return dot(self._coeff, x) + self._intercept -# Map a pickled Python RDD of numpy double vectors to a Java RDD of -# _serialized_double_vectors -def _get_unmangled_double_vector_rdd(data): - dataBytes = data.map(_serialize_double_vector) +def _get_unmangled_rdd(data, serializer): + dataBytes = data.map(serializer) dataBytes._bypass_serializer = True dataBytes.cache() return dataBytes +# Map a pickled Python RDD of numpy double vectors to a Java RDD of +# _serialized_double_vectors +def _get_unmangled_double_vector_rdd(data): + return _get_unmangled_rdd(data, _serialize_double_vector) + # If we weren't given initial weights, take a zero vector of the appropriate # length. def _get_initial_weights(initial_weights, data): @@ -317,7 +320,7 @@ class KMeansModel(object): return best @classmethod - def train(cls, sc, data, k, maxIterations = 100, runs = 1, + def train(cls, sc, data, k, maxIterations=100, runs=1, initialization_mode="k-means||"): """Train a k-means clustering model.""" dataBytes = _get_unmangled_double_vector_rdd(data) @@ -330,12 +333,56 @@ class KMeansModel(object): + type(ans[0]) + " which is not bytearray") return KMeansModel(_deserialize_double_matrix(ans[0])) +def _serialize_rating(r): + ba = bytearray(16) + intpart = ndarray(shape=[2], buffer=ba, dtype=int32) + doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8) + intpart[0], intpart[1], doublepart[0] = r + return ba + +class ALSModel(object): + """A matrix factorisation model trained by regularized alternating + least-squares. + + >>> r1 = (1, 1, 1.0) + >>> r2 = (1, 2, 2.0) + >>> r3 = (2, 1, 2.0) + >>> ratings = sc.parallelize([r1, r2, r3]) + >>> model = ALSModel.trainImplicit(sc, ratings, 1) + >>> model.predict(2,2) is not None + True + """ + + def __init__(self, sc, java_model): + self._context = sc + self._java_model = java_model + + #def __del__(self): + #self._gateway.detach(self._java_model) + + def predict(self, user, product): + return self._java_model.predict(user, product) + + @classmethod + def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1): + ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) + mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd, + rank, iterations, lambda_, blocks) + return ALSModel(sc, mod) + + @classmethod + def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01): + ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating) + mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd, + rank, iterations, lambda_, blocks, alpha) + return ALSModel(sc, mod) + def _test(): import doctest globs = globals().copy() globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2) (failure_count, test_count) = doctest.testmod(globs=globs, - optionflags=doctest.ELLIPSIS) + optionflags=doctest.ELLIPSIS) globs['sc'].stop() print failure_count,"failures among",test_count,"tests" if failure_count: -- cgit v1.2.3