from base64 import standard_b64encode as b64enc
from collections import defaultdict
from itertools import chain, ifilter, imap
import os
import shlex
from subprocess import Popen, PIPE
from threading import Thread
from pyspark import cloudpickle
from pyspark.serializers import dump_pickle, load_pickle
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from py4j.java_collections import ListConverter, MapConverter
class RDD(object):
    """
    Python-side handle on a dataset held by a Java RDD object.

    Transformations return new RDD (or PipelinedRDD) instances; actions such
    as collect(), reduce() and count() return plain Python values.  Elements
    coming back from the Java side are decoded with load_pickle.
    """

    def __init__(self, jrdd, ctx):
        """
        Wrap an existing Java-side RDD.

        jrdd -- the Java RDD object this instance delegates to
        ctx  -- the context that owns this RDD
        """
        self._jrdd = jrdd
        self.is_cached = False
        self.ctx = ctx

    def cache(self):
        """
        Mark this RDD as cached, ask the Java RDD to cache itself, and
        return self so that calls can be chained.
        """
        self.is_cached = True
        self._jrdd.cache()
        return self

    def map(self, f, preservesPartitioning=False):
        """
        Return a new RDD obtained by applying f to every element.
        """
        def func(iterator):
            return imap(f, iterator)
        return PipelinedRDD(self, func, preservesPartitioning)

    def flatMap(self, f):
        """
        Apply f to every element and flatten the iterables it returns.

        >>> rdd = sc.parallelize([2, 3, 4])
        >>> sorted(rdd.flatMap(lambda x: range(1, x)).collect())
        [1, 1, 1, 2, 2, 3]
        >>> sorted(rdd.flatMap(lambda x: [(x, x), (x, x)]).collect())
        [(2, 2), (2, 2), (3, 3), (3, 3), (4, 4), (4, 4)]
        """
        def func(iterator):
            return chain.from_iterable(imap(f, iterator))
        return self.mapPartitions(func)

    def mapPartitions(self, f):
        """
        Return a new RDD obtained by applying f to the iterator over each
        partition; f must return an iterable.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
        >>> def f(iterator): yield sum(iterator)
        >>> rdd.mapPartitions(f).collect()
        [3, 7]
        """
        return PipelinedRDD(self, f)

    def filter(self, f):
        """
        Return a new RDD holding only the elements for which f is true.

        >>> rdd = sc.parallelize([1, 2, 3, 4, 5])
        >>> rdd.filter(lambda x: x % 2 == 0).collect()
        [2, 4]
        """
        def func(iterator):
            return ifilter(f, iterator)
        return self.mapPartitions(func)

    def distinct(self):
        """
        Return a new RDD with duplicate elements removed.

        >>> sorted(sc.parallelize([1, 1, 2, 3]).distinct().collect())
        [1, 2, 3]
        """
        # Key every element by itself, keep one value per key, drop the
        # placeholder value again.
        return self.map(lambda x: (x, "")) \
                   .reduceByKey(lambda x, _: x) \
                   .map(lambda pair: pair[0])

    def sample(self, withReplacement, fraction, seed):
        """
        Return an RDD wrapping the result of the Java RDD's sample(), called
        with the same three arguments.
        """
        jrdd = self._jrdd.sample(withReplacement, fraction, seed)
        return RDD(jrdd, self.ctx)

    def takeSample(self, withReplacement, num, seed):
        """
        Call the Java RDD's takeSample() with the same three arguments and
        return the unpickled results as a list.
        """
        vals = self._jrdd.takeSample(withReplacement, num, seed)
        return [load_pickle(bytes(x)) for x in vals]

    def union(self, other):
        """
        Return the union of this RDD and another one.

        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> rdd.union(rdd).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]
        """
        return RDD(self._jrdd.union(other._jrdd), self.ctx)

    def __add__(self, other):
        """
        Operator form of union(); raises TypeError if other is not an RDD.

        >>> rdd = sc.parallelize([1, 1, 2, 3])
        >>> (rdd + rdd).collect()
        [1, 1, 2, 3, 1, 1, 2, 3]
        """
        if not isinstance(other, RDD):
            raise TypeError
        return self.union(other)

    # TODO: sort

    def glom(self):
        """
        Return an RDD in which each partition is turned into a single list.

        >>> rdd = sc.parallelize([1, 2, 3, 4], 2)
        >>> rdd.glom().first()
        [1, 2]
        """
        def func(iterator):
            yield list(iterator)
        return self.mapPartitions(func)

    def cartesian(self, other):
        """
        Return the Cartesian product of this RDD and another one.

        >>> rdd = sc.parallelize([1, 2])
        >>> sorted(rdd.cartesian(rdd).collect())
        [(1, 1), (1, 2), (2, 1), (2, 2)]
        """
        return RDD(self._jrdd.cartesian(other._jrdd), self.ctx)

    def groupBy(self, f, numSplits=None):
        """
        Group elements by the key f(x); implemented as a map to
        (f(x), x) pairs followed by groupByKey(numSplits).

        >>> rdd = sc.parallelize([1, 1, 2, 3, 5, 8])
        >>> result = rdd.groupBy(lambda x: x % 2).collect()
        >>> sorted([(x, sorted(y)) for (x, y) in result])
        [(0, [2, 8]), (1, [1, 1, 3, 5])]
        """
        return self.map(lambda x: (f(x), x)).groupByKey(numSplits)

    def pipe(self, command, env=None):
        """
        Return an RDD built by piping each partition through an external
        command: every element is written as one line to the command's
        stdin, and every line of its stdout becomes an output element.

        command -- command line, split with shlex.split
        env     -- environment mapping handed to Popen; when omitted, an
                   empty mapping is used

        >>> sc.parallelize([1, 2, 3]).pipe('cat').collect()
        ['1', '2', '3']
        """
        if env is None:
            # Same effective default as before (an empty environment), but
            # without a mutable object in the signature.
            env = {}

        def func(iterator):
            pipe = Popen(shlex.split(command), env=env, stdin=PIPE,
                         stdout=PIPE)

            def pipe_objs(out):
                for obj in iterator:
                    out.write(str(obj).rstrip('\n') + '\n')
                out.close()
            # Feed stdin from a separate thread while the caller consumes
            # stdout, so neither side has to finish before the other starts.
            Thread(target=pipe_objs, args=[pipe.stdin]).start()
            return (x.rstrip('\n') for x in pipe.stdout)
        return self.mapPartitions(func)

    def foreach(self, f):
        """
        Apply f to every element for its side effects; returns None.

        >>> def f(x): print x
        >>> sc.parallelize([1, 2, 3, 4, 5]).foreach(f)
        """
        def processPartition(iterator):
            for x in iterator:
                f(x)
            # One placeholder per partition is enough to force evaluation;
            # the return values of f are not sent back to the driver.
            yield None
        self.mapPartitions(processPartition).collect()  # Force evaluation

    def collect(self):
        """
        Return all elements of this RDD as a Python list.
        """
        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().collect())
        return load_pickle(bytes(pickle))

    def reduce(self, f):
        """
        Reduce the elements with the binary function f: each partition is
        reduced first, then the per-partition results are reduced again.
        On an RDD without elements the final builtin reduce() raises
        TypeError.

        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).reduce(add)
        15
        >>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
        10
        """
        def func(iterator):
            # An explicit flag (instead of "acc is None") keeps None usable
            # as an element or as an intermediate result of f.
            started = False
            acc = None
            for obj in iterator:
                if started:
                    # Accumulator first: same left-to-right order as the
                    # final merge of the partition results below.
                    acc = f(acc, obj)
                else:
                    acc = obj
                    started = True
            if started:
                yield acc
        vals = self.mapPartitions(func).collect()
        return reduce(f, vals)

    def fold(self, zeroValue, op):
        """
        Aggregate the elements of each partition, and then the results for all
        the partitions, using a given associative function and a neutral "zero
        value." The function op(t1, t2) is allowed to modify t1 and return it
        as its result value to avoid object allocation; however, it should not
        modify t2.

        >>> from operator import add
        >>> sc.parallelize([1, 2, 3, 4, 5]).fold(0, add)
        15
        """
        def func(iterator):
            acc = zeroValue
            for obj in iterator:
                # The accumulator is t1 (may be mutated); the element is t2.
                acc = op(acc, obj)
            yield acc
        vals = self.mapPartitions(func).collect()
        return reduce(op, vals, zeroValue)

    # TODO: aggregate

    def count(self):
        """
        Return the number of elements, as reported by the Java RDD.

        >>> sc.parallelize([2, 3, 4]).count()
        3L
        """
        return self._jrdd.count()

    def countByValue(self):
        """
        Return a mapping from each distinct element to its number of
        occurrences.

        >>> sorted(sc.parallelize([1, 2, 1, 2, 2], 2).countByValue().items())
        [(1, 2), (2, 3)]
        """
        def countPartition(iterator):
            counts = defaultdict(int)
            for obj in iterator:
                counts[obj] += 1
            yield counts

        def mergeMaps(m1, m2):
            for (k, v) in m2.iteritems():
                m1[k] += v
            return m1
        return self.mapPartitions(countPartition).reduce(mergeMaps)

    def take(self, num):
        """
        Return the elements produced by the Java RDD's take(num) as a list.

        >>> sc.parallelize([2, 3, 4]).take(2)
        [2, 3]
        """
        pickle = self.ctx.arrayAsPickle(self._jrdd.rdd().take(num))
        return load_pickle(bytes(pickle))

    def first(self):
        """
        Return the first element of this RDD.

        >>> sc.parallelize([2, 3, 4]).first()
        2
        """
        return load_pickle(bytes(self.ctx.asPickle(self._jrdd.first())))

    def saveAsTextFile(self, path):
        """
        Write the string form of every element to path through the Java
        RDD's saveAsTextFile().
        """
        def func(iterator):
            for x in iterator:
                if isinstance(x, unicode):
                    # str() would try an implicit ASCII encode and fail on
                    # non-ASCII text, so encode unicode explicitly.
                    yield x.encode("utf-8")
                else:
                    # Byte strings pass through untouched (presumably already
                    # UTF-8 -- TODO confirm); other objects use str().
                    yield str(x)
        keyed = PipelinedRDD(self, func)
        # The partition function already emits raw bytes.
        keyed._bypass_serializer = True
        keyed._jrdd.map(self.ctx.jvm.BytesToString()).saveAsTextFile(path)

    # Pair functions

    def collectAsMap(self):
        """
        Return the (key, value) pairs of this RDD as a dict.

        >>> m = sc.parallelize([(1, 2), (3, 4)]).collectAsMap()
        >>> m[1]
        2
        >>> m[3]
        4
        """
        return dict(self.collect())

    def reduceByKey(self, func, numSplits=None):
        """
        Merge the values of each key with func; implemented with
        combineByKey using func for both merge steps.

        >>> from operator import add
        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.reduceByKey(add).collect())
        [('a', 2), ('b', 1)]
        """
        return self.combineByKey(lambda x: x, func, func, numSplits)

    def reduceByKeyLocally(self, func):
        """
        Merge the values of each key with func and return the result as a
        dict instead of an RDD.

        >>> from operator import add
        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.reduceByKeyLocally(add).items())
        [('a', 2), ('b', 1)]
        """
        def reducePartition(iterator):
            m = {}
            for (k, v) in iterator:
                m[k] = v if k not in m else func(m[k], v)
            yield m

        def mergeMaps(m1, m2):
            for (k, v) in m2.iteritems():
                m1[k] = v if k not in m1 else func(m1[k], v)
            return m1
        return self.mapPartitions(reducePartition).reduce(mergeMaps)

    def countByKey(self):
        """
        Count the elements per key and return the counts as a mapping.

        >>> rdd = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(rdd.countByKey().items())
        [('a', 2), ('b', 1)]
        """
        return self.map(lambda x: x[0]).countByValue()

    def join(self, other, numSplits=None):
        """
        Join with another pair RDD; delegates to python_join.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2), ("a", 3)])
        >>> sorted(x.join(y).collect())
        [('a', (1, 2)), ('a', (1, 3))]
        """
        return python_join(self, other, numSplits)

    def leftOuterJoin(self, other, numSplits=None):
        """
        Left outer join with another pair RDD; delegates to
        python_left_outer_join.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.leftOuterJoin(y).collect())
        [('a', (1, 2)), ('b', (4, None))]
        """
        return python_left_outer_join(self, other, numSplits)

    def rightOuterJoin(self, other, numSplits=None):
        """
        Right outer join with another pair RDD; delegates to
        python_right_outer_join.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(y.rightOuterJoin(x).collect())
        [('a', (2, 1)), ('b', (None, 4))]
        """
        return python_right_outer_join(self, other, numSplits)

    def partitionBy(self, numSplits, hashFunc=hash):
        """
        Return a copy of this pair RDD partitioned by hashFunc(key) into
        numSplits partitions (the context's defaultParallelism when
        numSplits is None).

        >>> pairs = sc.parallelize([1, 2, 3, 4, 2, 4, 1]).map(lambda x: (x, x))
        >>> sets = pairs.partitionBy(2).glom().collect()
        >>> set(sets[0]).intersection(set(sets[1]))
        set([])
        """
        if numSplits is None:
            numSplits = self.ctx.defaultParallelism

        def add_shuffle_key(iterator):
            # Emit two items per pair: the stringified hash of the key,
            # then the pickled pair itself.
            for (k, v) in iterator:
                yield str(hashFunc(k))
                yield dump_pickle((k, v))
        keyed = PipelinedRDD(self, add_shuffle_key)
        # The items above are already serialized by hand.
        keyed._bypass_serializer = True
        pairRDD = self.ctx.jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
        partitioner = self.ctx.jvm.spark.api.python.PythonPartitioner(numSplits)
        jrdd = pairRDD.partitionBy(partitioner)
        jrdd = jrdd.map(self.ctx.jvm.ExtractValue())
        return RDD(jrdd, self.ctx)

    def combineByKey(self, createCombiner, mergeValue, mergeCombiners,
                     numSplits=None):
        """
        Combine the values of each key with three user functions:

        createCombiner(v)     -- builds a combiner from a key's first value
        mergeValue(c, v)      -- folds another value into a combiner
        mergeCombiners(c1, c2) -- merges two combiners for the same key

        Values are combined inside each partition first, the partial results
        are repartitioned with partitionBy(numSplits), and the combiners of
        each key are then merged.

        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> def f(x): return x
        >>> def add(a, b): return a + str(b)
        >>> sorted(x.combineByKey(str, add, add).collect())
        [('a', '11'), ('b', '1')]
        """
        if numSplits is None:
            numSplits = self.ctx.defaultParallelism

        def combineLocally(iterator):
            combiners = {}
            for (k, v) in iterator:
                if k not in combiners:
                    combiners[k] = createCombiner(v)
                else:
                    combiners[k] = mergeValue(combiners[k], v)
            return combiners.iteritems()
        locally_combined = self.mapPartitions(combineLocally)
        shuffled = locally_combined.partitionBy(numSplits)

        def _mergeCombiners(iterator):
            combiners = {}
            for (k, v) in iterator:
                if k not in combiners:
                    combiners[k] = v
                else:
                    combiners[k] = mergeCombiners(combiners[k], v)
            return combiners.iteritems()
        return shuffled.mapPartitions(_mergeCombiners)

    def groupByKey(self, numSplits=None):
        """
        Group the values of each key into a single list.

        >>> x = sc.parallelize([("a", 1), ("b", 1), ("a", 1)])
        >>> sorted(x.groupByKey().collect())
        [('a', [1, 1]), ('b', [1])]
        """
        def createCombiner(x):
            return [x]

        def mergeValue(xs, x):
            xs.append(x)
            return xs

        def mergeCombiners(a, b):
            return a + b
        return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
                                 numSplits)

    def flatMapValues(self, f):
        """
        Apply f to the value of every (key, value) pair and emit one
        (key, x) pair for each x in the result.
        """
        def flat_map_fn(pair):
            # Explicit unpacking still rejects elements that are not pairs.
            k, v = pair
            return ((k, x) for x in f(v))
        return self.flatMap(flat_map_fn)

    def mapValues(self, f):
        """
        Apply f to the value of every (key, value) pair, leaving keys
        untouched; the map is created with preservesPartitioning=True.
        """
        def map_values_fn(pair):
            # Explicit unpacking still rejects elements that are not pairs.
            k, v = pair
            return (k, f(v))
        return self.map(map_values_fn, preservesPartitioning=True)

    # TODO: support varargs cogroup of several RDDs.
    def groupWith(self, other):
        """
        Alias for cogroup(other).
        """
        return self.cogroup(other)

    def cogroup(self, other, numSplits=None):
        """
        Group the values of this RDD and another one by key; delegates to
        python_cogroup.

        >>> x = sc.parallelize([("a", 1), ("b", 4)])
        >>> y = sc.parallelize([("a", 2)])
        >>> sorted(x.cogroup(y).collect())
        [('a', ([1], [2])), ('b', ([4], []))]
        """
        return python_cogroup(self, other, numSplits)

    # TODO: `lookup` is disabled because we can't make direct comparisons based
    # on the key; we need to compare the hash of the key to the hash of the
    # keys in the pairs.  This could be an expensive operation, since those
    # hashes aren't retained.
class PipelinedRDD(RDD):
    """
    An RDD defined by a Python function applied to the partitions of a
    parent RDD.  When the parent is itself a PipelinedRDD that has not been
    cached, the two functions are composed so that both run in one pass over
    the parent's own source RDD.

    Pipelined maps:

    >>> rdd = sc.parallelize([1, 2, 3, 4])
    >>> rdd.map(lambda x: 2 * x).cache().map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]
    >>> rdd.map(lambda x: 2 * x).map(lambda x: 2 * x).collect()
    [4, 8, 12, 16]

    Pipelined reduces:

    >>> from operator import add
    >>> rdd.map(lambda x: 2 * x).reduce(add)
    20
    >>> rdd.flatMap(lambda x: [x, x]).reduce(add)
    20
    """

    def __init__(self, prev, func, preservesPartitioning=False):
        """
        prev  -- the parent RDD
        func  -- function taking a partition iterator and returning an
                 iterable of output elements
        preservesPartitioning -- forwarded to the Java-side PythonRDD; when
                 functions are fused it only stays true if it is true for
                 both stages
        """
        if isinstance(prev, PipelinedRDD) and not prev.is_cached:
            # Fuse with the un-cached parent: compose the functions and read
            # from the parent's source RDD directly.
            prev_func = prev.func

            def pipeline_func(iterator):
                return func(prev_func(iterator))
            self.func = pipeline_func
            self.preservesPartitioning = \
                prev.preservesPartitioning and preservesPartitioning
            self._prev_jrdd = prev._prev_jrdd
        else:
            # Plain or cached parent: start a new pipeline on top of it.
            self.func = func
            self.preservesPartitioning = preservesPartitioning
            self._prev_jrdd = prev._jrdd
        self.is_cached = False
        self.ctx = prev.ctx
        self.prev = prev
        self._jrdd_val = None
        self._bypass_serializer = False

    @property
    def _jrdd(self):
        """
        The Java RDD for this pipeline, built on first access and reused
        afterwards.  Building lazily lets callers set _bypass_serializer
        after construction but before the Java object is created.
        """
        # Explicit None check: do not depend on the truthiness of the
        # cached Java object.
        if self._jrdd_val is not None:
            return self._jrdd_val
        # The function and the serializer flag are pickled, base64-encoded
        # and joined with a space to form the command string.
        funcs = [self.func, self._bypass_serializer]
        pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in funcs)
        broadcast_vars = ListConverter().convert(
            [x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
            self.ctx.gateway._gateway_client)
        # The broadcast variables gathered so far have been captured above.
        self.ctx._pickled_broadcast_vars.clear()
        class_manifest = self._prev_jrdd.classManifest()
        env = MapConverter().convert(
            {'PYTHONPATH': os.environ.get("PYTHONPATH", "")},
            self.ctx.gateway._gateway_client)
        python_rdd = self.ctx.jvm.PythonRDD(
            self._prev_jrdd.rdd(), pipe_command, env,
            self.preservesPartitioning, self.ctx.pythonExec,
            broadcast_vars, class_manifest)
        self._jrdd_val = python_rdd.asJavaRDD()
        return self._jrdd_val
def _test():
    """
    Run this module's doctests against a local SparkContext that is made
    available to the examples under the name `sc`.
    """
    import doctest
    from pyspark.context import SparkContext
    globs = globals().copy()
    sc = SparkContext('local[4]', 'PythonTest')
    globs['sc'] = sc
    try:
        doctest.testmod(globs=globs)
    finally:
        # Stop the context even if the doctest run itself raises.
        sc.stop()
# Script entry point: run the doctests only when executed directly.
if __name__ == "__main__":
    _test()