# pyspark/pyspark/rdd.py
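"""
Python bindings for Spark's RDD API.

This module defines RDD, a wrapper around a JavaRDD that ships pickled
Python functions to Python worker processes, and MappedRDD, which pipelines
consecutive map/flatMap/reduce steps into a single worker invocation.
"""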
from base64 import standard_b64encode as b64enc

from pyspark import cloudpickle
from pyspark.serializers import PickleSerializer
from pyspark.join import python_join, python_left_outer_join, \
    python_right_outer_join, python_cogroup


class RDD(object):
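    """
    A Resilient Distributed Dataset (RDD): an immutable, partitioned
    collection of elements that can be operated on in parallel.  This class
    wraps a JavaRDD and submits work to it by pickling Python functions and
    piping elements through Python worker processes.
    """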

    def __init__(self, jrdd, ctx):
        self._jrdd = jrdd
        self.is_cached = False
        self.ctx = ctx

    @classmethod
    def _get_pipe_command(cls, command, functions):
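        """
        Build the command string sent to the Python worker: the command name
        followed by each function serialized with cloudpickle and
        base64-encoded.
        """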
        if functions and not isinstance(functions, (list, tuple)):
            functions = [functions]
        worker_args = [command]
        for f in functions:
            worker_args.append(b64enc(cloudpickle.dumps(f)))
        return " ".join(worker_args)

    def cache(self):
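        """
        Mark this RDD to be cached and return it so that calls can be
        chained.
        """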
        self.is_cached = True
        self._jrdd.cache()
        return self

    def map(self, f, preservesPartitioning=False):
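        """
        Return a new RDD containing the result of applying ``f`` to each
        element of this RDD.

        >>> sc.parallelize([1, 2, 3]).map(lambda x: 2 * x).collect()
        [2, 4, 6]
        """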
        return MappedRDD(self, f, preservesPartitioning)

    def flatMap(self, f):
        """
        >>> rdd = sc.parallelize([2, 3, 4])
        >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
        [1, 1, 1, 2, 2, 3]
        >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
        [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
        """
        return MappedRDD(self, f, preservesPartitioning=False, command='flatmap')

    def filter(self, f):
        """
        >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
        >>> rdd.filter(lambda x: x % 2 == 0).collect()
        [2, 4]
        """
        def filter_func(x):
            # Elements failing the predicate are mapped to None, which the
            # pipeline treats as a dropped element.
            return x if f(x) else None
        return RDD(self._pipe(filter_func), self.ctx)

    def _pipe(self, functions, command="map"):
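        """
        Wrap this RDD in a PythonRDD that pipes its elements through the
        given functions, and return the result as a JavaRDD.
        """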
        class_manifest = self._jrdd.classManifest()
        pipe_command = RDD._get_pipe_command(command, functions)
        python_rdd = self.ctx.jvm.PythonRDD(self._jrdd.rdd(), pipe_command,
            False, self.ctx.pythonExec, class_manifest)
        return python_rdd.asJavaRDD()

    def distinct(self):
        """
        >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
        [1, 2, 3]
        """
        return self.map(lambda x: (x, "")) \
                   .reduceByKey(lambda x, _: x) \
                   .map(lambda (x, _): x)

    def sample(self, withReplacement, fraction, seed):
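        """
        Return a sampled subset of this RDD (delegates to the underlying
        JavaRDD).
        """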
        jrdd = self._jrdd.sample(withReplacement, fraction, seed)
        return RDD(jrdd, self.ctx)

    def takeSample(self, withReplacement, num, seed):
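        """
        Return a fixed-size sampled subset of this RDD, unpickling the
        values handed back by the JVM.
        """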
        vals = self._jrdd.takeSample(withReplacement, num, seed)
        return [PickleSerializer.loads(bytes(x)) for x in vals]

    def union(self, other):
        """
        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> rdd.union(rdd).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]
        """
        return RDD(self._jrdd.union(other._jrdd), self.ctx)

    # TODO: sort

    # TODO: Overload __add__?

    # TODO: glom

    def cartesian(self, other):
        """
        >>> rdd = sc.parallelize([1, 2])
        >>> sorted(rdd.cartesian(rdd).collect())
        [(1, 1), (1, 2), (2, 1), (2, 2)]
        """
        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx)

    def groupBy(self, f, numSplits=None):
        """
        >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
        >>> result = rdd.groupBy(lambda x: x % 2).collect()
        >>> sorted([(x, sorted(y)) for (x, y) in result])
        [(0, [2, 8]), (1, [1, 1, 3, 5])]
        """
        return self.map(lambda x: (f(x), x)).groupByKey(numSplits)

    # TODO: pipe

    # TODO: mapPartitions

    def foreach(self, f):
        """
        >>> def f(x): print x
        >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
        """
        self.map(f).collect()  # Force evaluation

    def collect(self):
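        """
        Return a list that contains all of the elements in this RDD.

        >>> sc.parallelize([1, 2, 3]).collect()
        [1, 2, 3]
        """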
        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect())
        return PickleSerializer.loads(bytes(pickle))

    def reduce(self, f):
        """
        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
        15
        >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
        10
        """
        vals = MappedRDD(self, f, command="reduce", preservesPartitioning=False).collect()
        return reduce(f, vals)

    # TODO: fold

    # TODO: aggregate

    def count(self):
        """
        >>> sc.parallelize([2, 3, 4]).count()
        3L
        """
        return self._jrdd.count()

    # TODO: count approx methods

    def take(self, num):
        """
        >>> sc.parallelize([2, 3, 4]).take(2)
        [2, 3]
        """
        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
        return PickleSerializer.loads(bytes(pickle))

    def first(self):
        """
        >>> sc.parallelize([2, 3, 4]).first()
        2
        """
        return PickleSerializer.loads(bytes(self.ctx.asPickle(self._jrdd.first())))

    # TODO: saveAsTextFile

    # TODO: saveAsObjectFile

    # Pair functions

    def collectAsMap(self):
        """
        >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
        >>> m[1]
        2
        >>> m[3]
        4
        """
        return dict(self.collect())

    def reduceByKey(self, func, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(x.reduceByKey(lambda a, b: a + b).collect())
        [('a', 2), ('b', 1)]
        """
        return self.combineByKey(lambda x: x, func, func, numSplits)

    # TODO: reduceByKeyLocally()

    # TODO: countByKey()

    # TODO: partitionBy

    def join(self, other, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2), ("a", 3)])
        >>> sorted(x.join(y).collect())
        [('a', (1, 2)), ('a', (1, 3))]
        """
        return python_join(self, other, numSplits)

    def leftOuterJoin(self, other, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.leftOuterJoin(y).collect())
        [('a', (1, 2)), ('b', (4, None))]
        """
        return python_left_outer_join(self, other, numSplits)

    def rightOuterJoin(self, other, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(y.rightOuterJoin(x).collect())
        [('a', (2, 1)), ('b', (None, 4))]
        """
        return python_right_outer_join(self, other, numSplits)

    # TODO: pipelining
    # TODO: optimizations
    def shuffle(self, numSplits, hashFunc=hash):
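        """
        Redistribute this RDD of (key, value) pairs across ``numSplits``
        partitions, hashing keys on the Python side with ``hashFunc`` and
        partitioning the resulting pair RDD with a ``HashPartitioner``.
        """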
        if numSplits is None:
            numSplits = self.ctx.defaultParallelism
        pipe_command = RDD._get_pipe_command('shuffle_map_step', [hashFunc])
        class_manifest = self._jrdd.classManifest()
        python_rdd = self.ctx.jvm.PythonPairRDD(self._jrdd.rdd(),
            pipe_command, False, self.ctx.pythonExec, class_manifest)
        partitioner = self.ctx.jvm.spark.HashPartitioner(numSplits)
        jrdd = python_rdd.asJavaPairRDD().partitionBy(partitioner)
        jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
        return RDD(jrdd, self.ctx)

    def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                     numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> def f(x): return x
        >>> def add(a, b): return a + str(b)
        >>> sorted(x.combineByKey(str, add, add).collect())
        [('a', '11'), ('b', '1')]
        """
        if numSplits is None:
            numSplits = self.ctx.defaultParallelism
        shuffled = self.shuffle(numSplits)
        functions = [createCombiner, mergeValue, mergeCombiners]
        jpairs = shuffled._pipe(functions, "combine_by_key")
        return RDD(jpairs, self.ctx)

    def groupByKey(self, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(x.groupByKey().collect())
        [('a', [1, 1]), ('b', [1])]
        """

        def createCombiner(x):
            return [x]

        def mergeValue(xs, x):
            xs.append(x)
            return xs

        def mergeCombiners(a, b):
            return a + b

        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                numSplits)

    def flatMapValues(self, f):
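        """
        Pass each value of this (key, value) pair RDD through a flatMap
        function without changing the keys.
        """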
        flat_map_fn = lambda (k, v): ((k, x) for x in f(v))
        return self.flatMap(flat_map_fn)

    def mapValues(self, f):
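        """
        Pass each value of this (key, value) pair RDD through ``f``,
        preserving partitioning.

        >>> x = sc.parallelize([("a", 1), ("b", 2)])
        >>> sorted(x.mapValues(lambda v: 10 * v).collect())
        [('a', 10), ('b', 20)]
        """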
        map_values_fn = lambda (k, v): (k, f(v))
        return self.map(map_values_fn, preservesPartitioning=True)

    # TODO: support varargs cogroup of several RDDs.
    def groupWith(self, other):
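        """
        Alias for cogroup().
        """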
        return self.cogroup(other)

    def cogroup(self, other, numSplits=None):
        """
        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> x.cogroup(y).collect()
        [('a', ([1], [2])), ('b', ([4], []))]
        """
        return python_cogroup(self, other, numSplits)

    # TODO: `lookup` is disabled because we can't make direct comparisons based
    # on the key; we need to compare the hash of the key to the hash of the
    # keys in the pairs.  This could be an expensive operation, since those
    # hashes aren't retained.

    # TODO: file saving


class MappedRDD(RDD):
    """
    Pipelined maps:
    >>> rdd = sc.parallelize([1, 2, 3, 4])
    >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]
    >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]

    Pipelined reduces:
    >>> from operator import add
    >>> rdd.map(lambda x: 2 * x).reduce(add)
    20
    >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
    20
    """
    def __init__(self, prev, func, preservesPartitioning=False, command='map'):
        if isinstance(prev, MappedRDD) and not prev.is_cached:
            prev_func = prev.func
            if command == 'reduce':
                if prev.command == 'flatmap':
                    def flatmap_reduce_func(x, acc):
                        # Fold the values produced by the upstream flatMap
                        # function into the running accumulator.
                        values = prev_func(x)
                        if values is None:
                            return acc
                        # The upstream function may return a generator, so
                        # materialize it before indexing into it.
                        values = list(values)
                        if not values:
                            return acc
                        if acc is None:
                            # No accumulator yet: seed it with the first value.
                            return reduce(func, values[1:], values[0])
                        else:
                            return reduce(func, values, acc)
                    self.func = flatmap_reduce_func
                else:
                    def reduce_func(x, acc):
                        val = prev_func(x)
                        # Compare with None explicitly so that falsy values
                        # such as 0 are still reduced; None is treated as a
                        # dropped element.
                        if val is None:
                            return acc
                        if acc is None:
                            return val
                        else:
                            return func(val, acc)
                    self.func = reduce_func
            else:
                if prev.command == 'flatmap':
                    command = 'flatmap'
                    self.func = lambda x: (func(y) for y in prev_func(x))
                else:
                    self.func = lambda x: func(prev_func(x))

            self.preservesPartitioning = \
                prev.preservesPartitioning and preservesPartitioning
            self._prev_jrdd = prev._prev_jrdd
            self.is_pipelined = True
        else:
            if command == 'reduce':
                def reduce_func(val, acc):
                    if acc is None:
                        return val
                    else:
                        return func(val, acc)
                self.func = reduce_func
            else:
                self.func = func
            self.preservesPartitioning = preservesPartitioning
            self._prev_jrdd = prev._jrdd
            self.is_pipelined = False
        self.is_cached = False
        self.ctx = prev.ctx
        self.prev = prev
        self._jrdd_val = None
        self.command = command

    @property
    def _jrdd(self):
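        """
        Lazily build (and memoize) the JavaRDD for this stage by wrapping
        the parent JavaRDD in a PythonRDD that runs ``self.func``.
        """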
        if self._jrdd_val is None:
            funcs = [self.func]
            pipe_command = RDD._get_pipe_command(self.command, funcs)
            class_manifest = self._prev_jrdd.classManifest()
            python_rdd = self.ctx.jvm.PythonRDD(self._prev_jrdd.rdd(),
                pipe_command, self.preservesPartitioning, self.ctx.pythonExec,
                class_manifest)
            self._jrdd_val = python_rdd.asJavaRDD()
        return self._jrdd_val


def _test():
    import doctest
    from pyspark.context import SparkContext
    globs = globals().copy()
    globs['sc'] = SparkContext('local', 'PythonTest')
    doctest.testmod(globs=globs)
    globs['sc'].stop()


if __name__ == "__main__":
    _test()