author    Tor Myklebust <tmyklebu@gmail.com>  2013-12-21 14:54:01 -0500
committer Tor Myklebust <tmyklebu@gmail.com>  2013-12-21 14:54:01 -0500
commit    076fc1622190d342e20592c00ca19f8c0a56997f
tree      36fc493073a72ed856ed8e43213ed615c57084cc /python
parent    b454fdc2ebc495e4d13162f4bea8cf3e33909463
Python stubs for ALSModel.
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/__init__.py  |  5
-rw-r--r--  python/pyspark/mllib.py     | 59
2 files changed, 56 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index 8b5bb79a18..3d73d95909 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -43,9 +43,10 @@ from pyspark.rdd import RDD
from pyspark.files import SparkFiles
from pyspark.storagelevel import StorageLevel
from pyspark.mllib import LinearRegressionModel, LassoModel, \
- RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel
+ RidgeRegressionModel, LogisticRegressionModel, SVMModel, KMeansModel, \
+ ALSModel
__all__ = ["SparkContext", "RDD", "SparkFiles", "StorageLevel",
"LinearRegressionModel", "LassoModel", "RidgeRegressionModel",
- "LogisticRegressionModel", "SVMModel", "KMeansModel"];
+ "LogisticRegressionModel", "SVMModel", "KMeansModel", "ALSModel"];
diff --git a/python/pyspark/mllib.py b/python/pyspark/mllib.py
index 8848284a5e..22187eb4dd 100644
--- a/python/pyspark/mllib.py
+++ b/python/pyspark/mllib.py
@@ -164,14 +164,17 @@ class LinearRegressionModelBase(LinearModel):
_linear_predictor_typecheck(x, self._coeff)
return dot(self._coeff, x) + self._intercept
-# Map a pickled Python RDD of numpy double vectors to a Java RDD of
-# _serialized_double_vectors
-def _get_unmangled_double_vector_rdd(data):
- dataBytes = data.map(_serialize_double_vector)
+def _get_unmangled_rdd(data, serializer):
+ dataBytes = data.map(serializer)
dataBytes._bypass_serializer = True
dataBytes.cache()
return dataBytes
+# Map a pickled Python RDD of numpy double vectors to a Java RDD of
+# _serialized_double_vectors
+def _get_unmangled_double_vector_rdd(data):
+ return _get_unmangled_rdd(data, _serialize_double_vector)
+
# If we weren't given initial weights, take a zero vector of the appropriate
# length.
def _get_initial_weights(initial_weights, data):
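The hunk above generalizes the byte-level RDD helper: _get_unmangled_rdd now takes any per-element serializer, and the old double-vector helper becomes a one-line wrapper around it. A new model type then only needs its own serializer; a hypothetical wrapper for the rating serializer introduced further down would follow the same pattern:

# Hypothetical companion to _get_unmangled_double_vector_rdd (not part of
# this patch): map an RDD of (user, product, rating) triples to an RDD of
# raw bytearrays, bypassing the pickle serializer so the JVM side can
# consume the bytes directly.
def _get_unmangled_rating_rdd(data):
    return _get_unmangled_rdd(data, _serialize_rating)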
@@ -317,7 +320,7 @@ class KMeansModel(object):
return best
@classmethod
- def train(cls, sc, data, k, maxIterations = 100, runs = 1,
+ def train(cls, sc, data, k, maxIterations=100, runs=1,
initialization_mode="k-means||"):
"""Train a k-means clustering model."""
dataBytes = _get_unmangled_double_vector_rdd(data)
@@ -330,12 +333,56 @@ class KMeansModel(object):
+ type(ans[0]) + " which is not bytearray")
return KMeansModel(_deserialize_double_matrix(ans[0]))
+def _serialize_rating(r):
+ ba = bytearray(16)
+ intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
+ doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8)
+ intpart[0], intpart[1], doublepart[0] = r
+ return ba
+
+class ALSModel(object):
+ """A matrix factorisation model trained by regularized alternating
+ least-squares.
+
+ >>> r1 = (1, 1, 1.0)
+ >>> r2 = (1, 2, 2.0)
+ >>> r3 = (2, 1, 2.0)
+ >>> ratings = sc.parallelize([r1, r2, r3])
+ >>> model = ALSModel.trainImplicit(sc, ratings, 1)
+ >>> model.predict(2,2) is not None
+ True
+ """
+
+ def __init__(self, sc, java_model):
+ self._context = sc
+ self._java_model = java_model
+
+ #def __del__(self):
+ #self._gateway.detach(self._java_model)
+
+ def predict(self, user, product):
+ return self._java_model.predict(user, product)
+
+ @classmethod
+ def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
+ ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+ mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd,
+ rank, iterations, lambda_, blocks)
+ return ALSModel(sc, mod)
+
+ @classmethod
+ def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
+ ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
+ mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd,
+ rank, iterations, lambda_, blocks, alpha)
+ return ALSModel(sc, mod)
+
def _test():
import doctest
globs = globals().copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs,
- optionflags=doctest.ELLIPSIS)
+ optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
print failure_count,"failures among",test_count,"tests"
if failure_count:
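The _serialize_rating helper above packs each (user, product, rating) triple into a fixed 16-byte record: two native int32 values at offsets 0 and 4, then one float64 at offset 8. A standalone sketch of the same layout, checked against an equivalent struct.pack call ('=' selects native byte order with no padding); the struct equivalence is an illustration, not part of the patch:

import struct
from numpy import ndarray, int32, float64

def _serialize_rating(r):
    # Two int32s (user, product) followed by one float64 (rating).
    ba = bytearray(16)
    intpart = ndarray(shape=[2], buffer=ba, dtype=int32)
    doublepart = ndarray(shape=[1], buffer=ba, dtype=float64, offset=8)
    intpart[0], intpart[1], doublepart[0] = r
    return ba

# '=iid' is int32, int32, float64 in native byte order with no padding,
# so it reproduces the same 16 bytes.
assert bytes(_serialize_rating((1, 2, 2.0))) == struct.pack('=iid', 1, 2, 2.0)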