diff options
author | Matei Zaharia <matei@databricks.com> | 2013-12-29 15:08:08 -0500 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2013-12-29 15:08:08 -0500 |
commit | b4ceed40d6e511a1d475b3f4fbcdd2ad24c02b5a (patch) | |
tree | 486226041e35962c1543902d8ffc10a81f4223a5 /python/pyspark/mllib/recommendation.py | |
parent | 58c6fa2041b99160f254b17c2b71de9d82c53f8c (diff) | |
parent | ad3dfd153196497fefe6c1925e0a495a4373f1c5 (diff) | |
download | spark-b4ceed40d6e511a1d475b3f4fbcdd2ad24c02b5a.tar.gz spark-b4ceed40d6e511a1d475b3f4fbcdd2ad24c02b5a.tar.bz2 spark-b4ceed40d6e511a1d475b3f4fbcdd2ad24c02b5a.zip |
Merge remote-tracking branch 'origin/master' into conf2
Conflicts:
core/src/main/scala/org/apache/spark/SparkContext.scala
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala
streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala
streaming/src/main/scala/org/apache/spark/streaming/scheduler/JobGenerator.scala
streaming/src/test/scala/org/apache/spark/streaming/BasicOperationsSuite.scala
streaming/src/test/scala/org/apache/spark/streaming/CheckpointSuite.scala
streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala
streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala
streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
Diffstat (limited to 'python/pyspark/mllib/recommendation.py')
-rw-r--r-- | python/pyspark/mllib/recommendation.py | 74 |
1 files changed, 74 insertions, 0 deletions
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Reconstructed from the diff hunk that adds
# python/pyspark/mllib/recommendation.py (new file, @@ -0,0 +1,74 @@).

from pyspark import SparkContext
from pyspark.mllib._common import \
    _get_unmangled_rdd, _get_unmangled_double_vector_rdd, \
    _serialize_double_matrix, _deserialize_double_matrix, \
    _serialize_double_vector, _deserialize_double_vector, \
    _get_initial_weights, _serialize_rating, _regression_train_wrapper


class MatrixFactorizationModel(object):
    """A matrix factorization model trained by regularized alternating
    least-squares.

    Thin Python wrapper around a model object that lives in the JVM; all
    predictions are delegated to that object.

    >>> r1 = (1, 1, 1.0)
    >>> r2 = (1, 2, 2.0)
    >>> r3 = (2, 1, 2.0)
    >>> ratings = sc.parallelize([r1, r2, r3])
    >>> model = ALS.trainImplicit(sc, ratings, 1)
    >>> model.predict(2,2) is not None
    True
    """

    def __init__(self, sc, java_model):
        """Wrap ``java_model``, keeping ``sc`` so the model can be detached
        from the context's gateway later.

        :param sc: the SparkContext the model was trained with.
        :param java_model: handle to the JVM-side model object.
        """
        self._context = sc
        self._java_model = java_model

    def __del__(self):
        """Detach the wrapped JVM model from the context's gateway when this
        wrapper is garbage-collected."""
        self._context._gateway.detach(self._java_model)

    def predict(self, user, product):
        """Return the JVM model's prediction for the given user/product pair.

        :param user: user identifier, forwarded unchanged to the JVM model.
        :param product: product identifier, forwarded unchanged to the JVM
            model.
        """
        return self._java_model.predict(user, product)


class ALS(object):
    """Entry points for training a :class:`MatrixFactorizationModel` with
    alternating least squares via the JVM-side ``PythonMLLibAPI``."""

    @classmethod
    def train(cls, sc, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
        """Train a model by calling ``PythonMLLibAPI.trainALSModel``.

        :param sc: SparkContext whose JVM is used for training.
        :param ratings: RDD of rating records; each is serialized with
            ``_serialize_rating`` (the class doctest uses 3-tuples --
            presumably ``(user, product, rating)``; confirm against
            ``_serialize_rating``).
        :param rank: forwarded to the JVM trainer.
        :param iterations: forwarded to the JVM trainer.
        :param lambda_: forwarded to the JVM trainer.
        :param blocks: forwarded to the JVM trainer (NOTE(review): -1 is
            presumably an "auto" sentinel on the JVM side; confirm).
        :return: a :class:`MatrixFactorizationModel` wrapping the result.
        """
        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
        mod = sc._jvm.PythonMLLibAPI().trainALSModel(
            ratingBytes._jrdd, rank, iterations, lambda_, blocks)
        return MatrixFactorizationModel(sc, mod)

    @classmethod
    def trainImplicit(cls, sc, ratings, rank, iterations=5, lambda_=0.01,
                      blocks=-1, alpha=0.01):
        """Train a model by calling ``PythonMLLibAPI.trainImplicitALSModel``.

        Takes the same arguments as :meth:`train` plus ``alpha``, which is
        forwarded unchanged to the JVM trainer.

        :return: a :class:`MatrixFactorizationModel` wrapping the result.
        """
        ratingBytes = _get_unmangled_rdd(ratings, _serialize_rating)
        mod = sc._jvm.PythonMLLibAPI().trainImplicitALSModel(
            ratingBytes._jrdd, rank, iterations, lambda_, blocks, alpha)
        return MatrixFactorizationModel(sc, mod)


def _test():
    """Run this module's doctests against a local SparkContext.

    The context is exposed to the doctests as ``sc``. The process exits
    with status -1 if any doctest fails.
    """
    import doctest
    # Local import: ``exit`` is only injected by the ``site`` module, so use
    # ``sys.exit`` to stay correct under ``python -S`` / embedded runs.
    import sys

    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
    try:
        (failure_count, test_count) = doctest.testmod(
            globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Always stop the context, even if doctest itself raises.
        globs['sc'].stop()
    if failure_count:
        sys.exit(-1)


if __name__ == "__main__":
    _test()