diff options
Diffstat (limited to 'python/pyspark/mllib/recommendation.py')
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 10 |
1 files changed, 6 insertions, 4 deletions
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 0eeb5bb66b..f4a83f0209 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -32,11 +32,11 @@ class MatrixFactorizationModel(object):
     >>> r2 = (1, 2, 2.0)
     >>> r3 = (2, 1, 2.0)
     >>> ratings = sc.parallelize([r1, r2, r3])
-    >>> model = ALS.trainImplicit(sc, ratings, 1)
+    >>> model = ALS.trainImplicit(ratings, 1)
     >>> model.predict(2,2) is not None
     True
     >>> testset = sc.parallelize([(1, 2), (1, 1)])
-    >>> model.predictAll(testset).count == 2
+    >>> model.predictAll(testset).count() == 2
     True
     """
 
@@ -57,14 +57,16 @@ class MatrixFactorizationModel(object):
 
 class ALS(object):
     @classmethod
-    def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
+    def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
+        sc = ratings.context
         ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
         mod = sc._jvm.PythonMLLibAPI().trainALSModel(ratingBytes._jrdd,
                 rank, iterations, lambda_, blocks)
         return MatrixFactorizationModel(sc, mod)
 
     @classmethod
-    def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
+    def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
+        sc = ratings.context
         ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
         mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(ratingBytes._jrdd,
                 rank, iterations, lambda_, blocks, alpha)