aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/fpm.py
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-04-22 17:22:26 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-22 17:22:26 -0700
commitf4f39981f4f5e88c30eec7d0b107e2c3cdc268c9 (patch)
treed26235eae02cab27c9cdd537d53d41b4978fecfd /python/pyspark/mllib/fpm.py
parentbaf865ddc2cff9b99d6aeab9861e030da511257f (diff)
downloadspark-f4f39981f4f5e88c30eec7d0b107e2c3cdc268c9.tar.gz
spark-f4f39981f4f5e88c30eec7d0b107e2c3cdc268c9.tar.bz2
spark-f4f39981f4f5e88c30eec7d0b107e2c3cdc268c9.zip
[SPARK-6827] [MLLIB] Wrap FPGrowthModel.freqItemsets and make it consistent with Java API
Make PySpark ```FPGrowthModel.freqItemsets``` consistent with Java/Scala API like ```MatrixFactorizationModel.userFeatures``` It return a RDD with each tuple is composed of an array and a long value. I think it's difficult to implement namedtuples to wrap the output because items of freqItemsets can be any type with arbitrary length which is tedious to impelement corresponding SerDe function. Author: Yanbo Liang <ybliang8@gmail.com> Closes #5614 from yanboliang/spark-6827 and squashes the following commits: da8c404 [Yanbo Liang] use namedtuple 5532e78 [Yanbo Liang] Wrap FPGrowthModel.freqItemsets and make it consistent with Java API
Diffstat (limited to 'python/pyspark/mllib/fpm.py')
-rw-r--r--python/pyspark/mllib/fpm.py15
1 files changed, 12 insertions, 3 deletions
diff --git a/python/pyspark/mllib/fpm.py b/python/pyspark/mllib/fpm.py
index 628ccc01cf..d8df02bdba 100644
--- a/python/pyspark/mllib/fpm.py
+++ b/python/pyspark/mllib/fpm.py
@@ -15,6 +15,10 @@
# limitations under the License.
#
+import numpy
+from numpy import array
+from collections import namedtuple
+
from pyspark import SparkContext
from pyspark.rdd import ignore_unicode_prefix
from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, inherit_doc
@@ -36,14 +40,14 @@ class FPGrowthModel(JavaModelWrapper):
>>> rdd = sc.parallelize(data, 2)
>>> model = FPGrowth.train(rdd, 0.6, 2)
>>> sorted(model.freqItemsets().collect())
- [([u'a'], 4), ([u'c'], 3), ([u'c', u'a'], 3)]
+ [FreqItemset(items=[u'a'], freq=4), FreqItemset(items=[u'c'], freq=3), ...
"""
def freqItemsets(self):
"""
- Get the frequent itemsets of this model
+ Returns the frequent itemsets of this model.
"""
- return self.call("getFreqItemsets")
+ return self.call("getFreqItemsets").map(lambda x: (FPGrowth.FreqItemset(x[0], x[1])))
class FPGrowth(object):
@@ -67,6 +71,11 @@ class FPGrowth(object):
model = callMLlibFunc("trainFPGrowthModel", data, float(minSupport), int(numPartitions))
return FPGrowthModel(model)
+ class FreqItemset(namedtuple("FreqItemset", ["items", "freq"])):
+ """
+ Represents an (items, freq) tuple.
+ """
+
def _test():
import doctest