from base64 import standard_b64encode as b64enc
from collections import defaultdict
from itertools import chain, ifilter, imap
import os
import shlex
from subprocess import Popen, PIPE
from threading import Thread

from pyspark import cloudpickle
from pyspark.serializers import dump_pickle, load_pickle
from pyspark.join import python_join, python_left_outer_join, \
    python_right_outer_join, python_cogroup

from py4j.java_collections import ListConverter, MapConverter


class RDD(object):

    def __init__(self, jrdd, ctx):
        self._jrdd = jrdd
        self.is_cached = False
        self.ctx = ctx

    def cache(self):
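        """
        Persist this RDD in memory.  Later transformations will not be
        pipelined past the cached result (see PipelinedRDD).
        """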
        self.is_cached = True
        self._jrdd.cache()
        return self

    def map(self, f, preservesPartitioning=False):
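        """
        Return a new RDD containing the result of applying f to each element.

        >>> sc.parallelize([1, 2, 3]).map(lambda x: x + 1).collect()
        [2, 3, 4]
        """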
        def func(iterator): return imap(f, iterator)
        return PipelinedRDD(self, func, preservesPartitioning)

    def flatMap(self, f):
        """
        >>> rdd = sc.parallelize([2, 3, 4])
        >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
        [1, 1, 1, 2, 2, 3]
        >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
        [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
        """
        def func(iterator): return chain.from_iterable(imap(f, iterator))
        return self.mapPartitions(func)

    def mapPartitions(self, f):
        """
        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
        >>> def f(iterator): yield sum(iterator)
        >>> rdd.mapPartitions(f).collect()
        [3, 7]
        """
        return PipelinedRDD(self, f)

    def filter(self, f):
        """
        >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
        >>> rdd.filter(lambda x: x % 2 == 0).collect()
        [2, 4]
        """
        def func(iterator): return ifilter(f, iterator)
        return self.mapPartitions(func)

    def distinct(self):
        """
        >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
        [1, 2, 3]
        """
        return self.map(lambda x: (x, "")) \
                   .reduceByKey(lambda x, _: x) \
                   .map(lambda (x, _): x)

    def sample(self, withReplacement, fraction, seed):
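        """
        Return a new RDD containing a random sample of this RDD's elements,
        with or without replacement, using the given random seed.
        """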
        jrdd = self._jrdd.sample(withReplacement, fraction, seed)
        return RDD(jrdd, self.ctx)

    def takeSample(self, withReplacement, num, seed):
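        """
        Return a sample of this RDD's elements as a local list, with or
        without replacement, using the given random seed.
        """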
        vals = self._jrdd.takeSample(withReplacement, num, seed)
        return [load_pickle(bytes(x)) for x in vals]

    def union(self, other):
        """
        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> rdd.union(rdd).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]
        """
        return RDD(self._jrdd.union(other._jrdd), self.ctx)

    def __add__(self, other):
        """
        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> (rdd + rdd).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]
        """
        if not isinstance(other, RDD):
            raise TypeError
        return self.union(other)

    # TODO: sort

    def glom(self):
        """
        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
        >>> rdd.glom().first()
        [1, 2]
        """
        def func(iterator): yield list(iterator)
        return self.mapPartitions(func)

    def cartesian(self, other):
        """
        >>> rdd = sc.parallelize([1, 2])
        >>> sorted(rdd.cartesian(rdd).collect())
        [(1, 1), (1, 2), (2, 1), (2, 2)]
        """
        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx)

    def groupBy(self, f, numSplits=None):
        """
        >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
        >>> result = rdd.groupBy(lambda x: x % 2).collect()
        >>> sorted([(x, sorted(y)) for (x, y) in result])
        [(0, [2, 8]), (1, [1, 1, 3, 5])]
        """
        return self.map(lambda x: (f(x), x)).groupByKey(numSplits)

    def pipe(self, command, env={}):
        """
        >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
        ['1', '2', '3']
        """
        def func(iterator):
            pipe = Popen(shlex.split(command), env=env, stdin=PIPE, stdout=PIPE)
            def pipe_objs(out):
                for obj in iterator:
                    out.write(str(obj).rstrip('\n') + '\n')
                out.close()
            Thread(target=pipe_objs, args=[pipe.stdin]).start()
            return (x.rstrip('\n') for x in pipe.stdout)
        return self.mapPartitions(func)

    def foreach(self, f):
        """
        >>> def f(x): print x
        >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
        """
        self.map(f).collect()  # Force evaluation

    def collect(self):
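        """
        Return a list containing all of the elements in this RDD.
        """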
        def asList(iterator):
            yield list(iterator)
        pickles = self.mapPartitions(asList)._jrdd.rdd().collect()
        return list(chain.from_iterable(load_pickle(bytes(p)) for p in pickles))

    def reduce(self, f):
        """
        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
        15
        >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
        10
        """
        def func(iterator):
            acc = None
            for obj in iterator:
                if acc is None:
                    acc = obj
                else:
                    acc = f(obj, acc)
            if acc is not None:
                yield acc
        vals = self.mapPartitions(func).collect()
        return reduce(f, vals)

    def fold(self, zeroValue, op):
        """
        Aggregate the elements of each partition, and then the results for all
        the partitions, using a given associative function and a neutral "zero
        value." The function op(t1, t2) is allowed to modify t1 and return it
        as its result value to avoid object allocation; however, it should not
        modify t2.

        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
        15
        """
        def func(iterator):
            acc = zeroValue
            for obj in iterator:
                acc = op(obj, acc)
            yield acc
        vals = self.mapPartitions(func).collect()
        return reduce(op, vals, zeroValue)

    # TODO: aggregate

    def count(self):
        """
        >>> sc.parallelize([2, 3, 4]).count()
        3L
        """
        return self._jrdd.count()

    def countByValue(self):
        """
        >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
        [(1, 2), (2, 3)]
        """
        def countPartition(iterator):
            counts = defaultdict(int)
            for obj in iterator:
                counts[obj] += 1
            yield counts
        def mergeMaps(m1, m2):
            for (k, v) in m2.iteritems():
                m1[k] += v
            return m1
        return self.mapPartitions(countPartition).reduce(mergeMaps)

    def take(self, num):
        """
        >>> sc.parallelize([2, 3, 4]).take(2)
        [2, 3]
        """
        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
        return load_pickle(bytes(pickle))

    def first(self):
        """
        >>> sc.parallelize([2, 3, 4]).first()
        2
        """
        return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))

    def saveAsTextFile(self, path):
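        """
        Save this RDD as a text file at the given path, using the string
        representation of each element as a line.
        """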
        def func(iterator):
            return (str(x).encode("utf-8") for x in iterator)
        keyed = PipelinedRDD(self, func)
        keyed._bypass_serializer = True
        keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)

    # Pair functions

    def collectAsMap(self):
        """
        >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
        >>> m[1]
        2
        >>> m[3]
        4
        """
        return dict(self.collect())

    def reduceByKey(self, func, numSplits=None):
        """
        >>> from operator import add
        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.reduceByKey(add).collect())
        [('a', 2), ('b', 1)]
        """
        return self.combineByKey(lambda x: x, func, func, numSplits)

    def reduceByKeyLocally(self, func):
        """
        >>> from operator import add
        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.reduceByKeyLocally(add).items())
        [('a', 2), ('b', 1)]
        """
        def reducePartition(iterator):
            m = {}
            for (k, v) in iterator:
                m[k] = v if k not in m else func(m[k], v)
            yield m
        def mergeMaps(m1, m2):
            for (k, v) in m2.iteritems():
                m1[k] = v if k not in m1 else func(m1[k], v)
            return m1
        return self.mapPartitions(reducePartition).reduce(mergeMaps)

    def countByKey(self):
        """
        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.countByKey().items())
        [('a', 2), ('b', 1)]
        """
        return self.map(lambda x: x[0]).countByValue()

    def join(self, other, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2), ("a", 3)])
        >>> sorted(x.join(y).collect())
        [('a', (1, 2)), ('a', (1, 3))]
        """
        return python_join(self, other, numSplits)

    def leftOuterJoin(self, other, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.leftOuterJoin(y).collect())
        [('a', (1, 2)), ('b', (4, None))]
        """
        return python_left_outer_join(self, other, numSplits)

    def rightOuterJoin(self, other, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(y.rightOuterJoin(x).collect())
        [('a', (2, 1)), ('b', (None, 4))]
        """
        return python_right_outer_join(self, other, numSplits)

    def partitionBy(self, numSplits, hashFunc=hash):
        """
        >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
        >>> sets = pairs.partitionBy(2).glom().collect()
        >>> set(sets[0]).intersection(set(sets[1]))
        set([])
        """
        if numSplits is None:
            numSplits = self.ctx.defaultParallelism
        # Transferring O(n) objects to Java is too expensive.  Instead, we'll
        # form the hash buckets in Python, transferring O(numSplits) objects
        # to Java.  Each object is a (splitNumber, [objects]) pair.
        def add_shuffle_key(iterator):
            buckets = defaultdict(list)
            for (k, v) in iterator:
                buckets[hashFunc(k) % numSplits].append((k, v))
            for (split, items) in buckets.iteritems():
                yield str(split)
                yield dump_pickle(items)
        keyed = PipelinedRDD(self, add_shuffle_key)
        keyed._bypass_serializer = True
        pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
        jrdd = pairRDD.partitionBy(partitioner)
        jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
        # Flatten the resulting RDD:
        return RDD(jrdd, self.ctx).flatMap(lambda items: items)

    def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                     numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> def f(x): return x
        >>> def add(a, b): return a + str(b)
        >>> sorted(x.combineByKey(str, add, add).collect())
        [('a', '11'), ('b', '1')]
        """
        if numSplits is None:
            numSplits = self.ctx.defaultParallelism
        def combineLocally(iterator):
            combiners = {}
            for (k, v) in iterator:
                if k not in combiners:
                    combiners[k] = createCombiner(v)
                else:
                    combiners[k] = mergeValue(combiners[k], v)
            return combiners.iteritems()
        locally_combined = self.mapPartitions(combineLocally)
        shuffled = locally_combined.partitionBy(numSplits)
        def _mergeCombiners(iterator):
            combiners = {}
            for (k, v) in iterator:
                if k not in combiners:
                    combiners[k] = v
                else:
                    combiners[k] = mergeCombiners(combiners[k], v)
            return combiners.iteritems()
        return shuffled.mapPartitions(_mergeCombiners)

    def groupByKey(self, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(x.groupByKey().collect())
        [('a', [1, 1]), ('b', [1])]
        """

        def createCombiner(x):
            return [x]

        def mergeValue(xs, x):
            xs.append(x)
            return xs

        def mergeCombiners(a, b):
            return a + b

        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                numSplits)

    def flatMapValues(self, f):
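        """
        Pass each value of a key-value pair RDD through a function that
        returns a sequence, pairing each resulting value with the original
        key.

        >>> x = sc.parallelize([("a", ["x", "y", "z"]), ("b", ["p", "r"])])
        >>> x.flatMapValues(lambda value: value).collect()
        [('a', 'x'), ('a', 'y'), ('a', 'z'), ('b', 'p'), ('b', 'r')]
        """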
        flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
        return self.flatMap(flat_map_fn)

    def mapValues(self, f):
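        """
        Pass each value of a key-value pair RDD through a function without
        changing the keys; the original partitioning is preserved.

        >>> x = sc.parallelize([("a", ["apple", "banana"]), ("b", ["grapes"])])
        >>> x.mapValues(len).collect()
        [('a', 2), ('b', 1)]
        """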
        map_values_fn = lambda (k, v): (k, f(v))
        return self.map(map_values_fn, preservesPartitioning=True)

    # TODO: support varargs cogroup of several RDDs.
    def groupWith(self, other):
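        """
        Alias for cogroup.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.groupWith(y).collect())
        [('a', ([1], [2])), ('b', ([4], []))]
        """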
        return self.cogroup(other)

    def cogroup(self, other, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.cogroup(y).collect())
        [('a', ([1], [2])), ('b', ([4], []))]
        """
        return python_cogroup(self, other, numSplits)

    # TODO: `lookup` is disabled because we can't make direct comparisons based
    # on the key; we need to compare the hash of the key to the hash of the
    # keys in the pairs.  This could be an expensive operation, since those
    # hashes aren't retained.


class PipelinedRDD(RDD):
    """
    Pipelined maps:
    >>> rdd = sc.parallelize([1, 2, 3, 4])
    >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]
    >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]

    Pipelined reduces:
    >>> from operator import add
    >>> rdd.map(lambda x: 2 * x).reduce(add)
    20
    >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
    20
    """
    def __init__(self, prev, func, preservesPartitioning=False):
        if isinstance(prev, PipelinedRDD) and not prev.is_cached:
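            # Fuse this function with the previous uncached pipelined stage,
            # so both run in a single pass over the same parent JavaRDD.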
            prev_func = prev.func
            def pipeline_func(iterator):
                return func(prev_func(iterator))
            self.func = pipeline_func
            self.preservesPartitioning = \
                prev.preservesPartitioning and preservesPartitioning
            self._prev_jrdd = prev._prev_jrdd
        else:
            self.func = func
            self.preservesPartitioning = preservesPartitioning
            self._prev_jrdd = prev._jrdd
        self.is_cached = False
        self.ctx = prev.ctx
        self.prev = prev
        self._jrdd_val = None
        self._bypass_serializer = False

    @property
    def _jrdd(self):
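        """
        Lazily construct the Java PythonRDD backing this pipelined RDD: the
        pipelined function (plus the bypass-serializer flag) is pickled with
        cloudpickle and base64-encoded into the pipe command, then shipped to
        the JVM together with any pickled broadcast variables and the
        PYTHONPATH environment for the Python workers.
        """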
        if self._jrdd_val:
            return self._jrdd_val
        funcs = [self.func, self._bypass_serializer]
        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx.gateway._gateway_client)
        self.ctx._pickled_broadcast_vars.clear()
        class_manifest = self._prev_jrdd.classManifest()
        env = MapConverter().convert(
            {'PYTHONPATH' : os.environ.get("PYTHONPATH", "")},
            self.ctx.gateway._gateway_client)
        python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
            pipe_command, env, self.preservesPartitioning, self.ctx.pythonExec,
            broadcast_vars, class_manifest)
        self._jrdd_val = python_rdd.asJavaRDD()
        return self._jrdd_val


def _test():
    import doctest
    from pyspark.context import SparkContext
    globs = globals().copy()
    globs['sc'] = SparkContext('local[4]', 'PythonTest')
    doctest.testmod(globs=globs)
    globs['sc'].stop()


if __name__ == "__main__":
    _test()