path: root/python/pyspark/rdd.py
author     Davies Liu <davies.liu@gmail.com>    2014-08-27 13:18:33 -0700
committer  Josh Rosen <joshrosen@apache.org>    2014-08-27 13:18:33 -0700
commit     4fa2fda88fc7beebb579ba808e400113b512533b (patch)
tree       d872443f0281f52ee6ab6f19e34c5d1437d8e640 /python/pyspark/rdd.py
parent     48f42781dedecd38ddcb2dcf67dead92bb4318f5 (diff)
[SPARK-2871] [PySpark] add RDD.lookup(key)
RDD.lookup(key): Return the list of values in the RDD for key `key`. This operation is done efficiently if the RDD has a known partitioner, by only searching the partition that the key maps to.

>>> l = range(1000)
>>> rdd = sc.parallelize(zip(l, l), 10)
>>> rdd.lookup(42)  # slow
[42]
>>> sorted = rdd.sortByKey()
>>> sorted.lookup(42)  # fast
[42]

This change also cleans up the code in rdd.py and fixes several bugs (related to preservesPartitioning).

Author: Davies Liu <davies.liu@gmail.com>

Closes #2093 from davies/lookup and squashes the following commits:

1789cd4 [Davies Liu] `f` in foreach could be generator or not.
2871b80 [Davies Liu] Merge branch 'master' into lookup
c6390ea [Davies Liu] address all comments
0f1bce8 [Davies Liu] add test case for lookup()
be0e8ba [Davies Liu] fix preservesPartitioning
eb1305d [Davies Liu] add RDD.lookup(key)
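In short, lookup() falls back to a full filter-and-collect when the RDD has no known partitioner, and runs a job on just the one matching partition when it does. A minimal sketch of the two paths, assuming a live SparkContext `sc` as in the doctests above:

    rdd = sc.parallelize(zip(range(1000), range(1000)), 10)
    rdd.lookup(42)          # no partitioner: every partition is filtered, then collected
    srt = rdd.sortByKey()   # sortByKey leaves the result range-partitioned
    srt.lookup(42)          # known partitioner: only the matching partition is evaluated
    srt.lookup(1024)        # an absent key simply returns []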
Diffstat (limited to 'python/pyspark/rdd.py')
-rw-r--r--    python/pyspark/rdd.py    211
1 file changed, 79 insertions, 132 deletions
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 31919741e9..2d80fad796 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -147,76 +147,6 @@ class BoundedFloat(float):
return obj
-class MaxHeapQ(object):
-
- """
- An implementation of MaxHeap.
-
- >>> import pyspark.rdd
- >>> heap = pyspark.rdd.MaxHeapQ(5)
- >>> [heap.insert(i) for i in range(10)]
- [None, None, None, None, None, None, None, None, None, None]
- >>> sorted(heap.getElements())
- [0, 1, 2, 3, 4]
- >>> heap = pyspark.rdd.MaxHeapQ(5)
- >>> [heap.insert(i) for i in range(9, -1, -1)]
- [None, None, None, None, None, None, None, None, None, None]
- >>> sorted(heap.getElements())
- [0, 1, 2, 3, 4]
- >>> heap = pyspark.rdd.MaxHeapQ(1)
- >>> [heap.insert(i) for i in range(9, -1, -1)]
- [None, None, None, None, None, None, None, None, None, None]
- >>> heap.getElements()
- [0]
- """
-
- def __init__(self, maxsize):
- # We start from q[1], so its children are always 2 * k
- self.q = [0]
- self.maxsize = maxsize
-
- def _swim(self, k):
- while (k > 1) and (self.q[k / 2] < self.q[k]):
- self._swap(k, k / 2)
- k = k / 2
-
- def _swap(self, i, j):
- t = self.q[i]
- self.q[i] = self.q[j]
- self.q[j] = t
-
- def _sink(self, k):
- N = self.size()
- while 2 * k <= N:
- j = 2 * k
- # Here we test if both children are greater than parent
- # if not swap with larger one.
- if j < N and self.q[j] < self.q[j + 1]:
- j = j + 1
- if(self.q[k] > self.q[j]):
- break
- self._swap(k, j)
- k = j
-
- def size(self):
- return len(self.q) - 1
-
- def insert(self, value):
- if (self.size()) < self.maxsize:
- self.q.append(value)
- self._swim(self.size())
- else:
- self._replaceRoot(value)
-
- def getElements(self):
- return self.q[1:]
-
- def _replaceRoot(self, value):
- if(self.q[1] > value):
- self.q[1] = value
- self._sink(1)
-
-
def _parse_memory(s):
"""
Parse a memory string in the format supported by Java (e.g. 1g, 200m) and
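The removed MaxHeapQ was a bounded max-heap: once it held `maxsize` elements, a new value only displaced the root when it was smaller, so the heap always retained the smallest `maxsize` values seen. The standard library already provides this, which is what top() and takeOrdered() switch to further down; a quick equivalence check against the MaxHeapQ(5) doctest above:

    import heapq
    data = list(range(9, -1, -1))
    heapq.nsmallest(5, data)   # [0, 1, 2, 3, 4], same result as sorted(heap.getElements())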
@@ -248,6 +178,7 @@ class RDD(object):
self.ctx = ctx
self._jrdd_deserializer = jrdd_deserializer
self._id = jrdd.id()
+ self._partitionFunc = None
def _toPickleSerialization(self):
if (self._jrdd_deserializer == PickleSerializer() or
@@ -325,8 +256,6 @@ class RDD(object):
checkpointFile = self._jrdd.rdd().getCheckpointFile()
if checkpointFile.isDefined():
return checkpointFile.get()
- else:
- return None
def map(self, f, preservesPartitioning=False):
"""
@@ -366,7 +295,7 @@ class RDD(object):
"""
def func(s, iterator):
return f(iterator)
- return self.mapPartitionsWithIndex(func)
+ return self.mapPartitionsWithIndex(func, preservesPartitioning)
def mapPartitionsWithIndex(self, f, preservesPartitioning=False):
"""
@@ -416,7 +345,7 @@ class RDD(object):
"""
def func(iterator):
return ifilter(f, iterator)
- return self.mapPartitions(func)
+ return self.mapPartitions(func, True)
def distinct(self):
"""
@@ -561,7 +490,7 @@ class RDD(object):
"""
return self.map(lambda v: (v, None)) \
.cogroup(other.map(lambda v: (v, None))) \
- .filter(lambda x: (len(x[1][0]) != 0) and (len(x[1][1]) != 0)) \
+ .filter(lambda (k, vs): all(vs)) \
.keys()
def _reserialize(self, serializer=None):
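A side note on the new filter: cogroup() yields (key, (vals_in_self, vals_in_other)) pairs, and `all(vs)` keeps a key only when both result iterables are non-empty. The `lambda (k, vs): ...` form uses Python 2 tuple parameter unpacking; an equivalent predicate without it, shown here on plain tuples:

    keep = lambda kv: all(kv[1])
    keep((1, ([None], [2])))   # True: values present on both sides
    keep((1, ([], [2])))       # False: nothing on the left side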
@@ -616,7 +545,7 @@ class RDD(object):
if numPartitions == 1:
if self.getNumPartitions() > 1:
self = self.coalesce(1)
- return self.mapPartitions(sortPartition)
+ return self.mapPartitions(sortPartition, True)
# first compute the boundary of each part via sampling: we want to partition
# the key-space into bins such that the bins have roughly the same
@@ -721,8 +650,8 @@ class RDD(object):
def processPartition(iterator):
for x in iterator:
f(x)
- yield None
- self.mapPartitions(processPartition).collect() # Force evaluation
+ return iter([])
+ self.mapPartitions(processPartition).count() # Force evaluation
def foreachPartition(self, f):
"""
@@ -731,10 +660,15 @@ class RDD(object):
>>> def f(iterator):
... for x in iterator:
... print x
- ... yield None
>>> sc.parallelize([1, 2, 3, 4, 5]).foreachPartition(f)
"""
- self.mapPartitions(f).collect() # Force evaluation
+ def func(it):
+ r = f(it)
+ try:
+ return iter(r)
+ except TypeError:
+ return iter([])
+ self.mapPartitions(func).count() # Force evaluation
def collect(self):
"""
@@ -767,18 +701,23 @@ class RDD(object):
15
>>> sc.parallelize((2 for _ in range(10))).map(lambda x: 1).cache().reduce(add)
10
+ >>> sc.parallelize([]).reduce(add)
+ Traceback (most recent call last):
+ ...
+ ValueError: Can not reduce() empty RDD
"""
def func(iterator):
- acc = None
- for obj in iterator:
- if acc is None:
- acc = obj
- else:
- acc = f(obj, acc)
- if acc is not None:
- yield acc
+ iterator = iter(iterator)
+ try:
+ initial = next(iterator)
+ except StopIteration:
+ return
+ yield reduce(f, iterator, initial)
+
vals = self.mapPartitions(func).collect()
- return reduce(f, vals)
+ if vals:
+ return reduce(f, vals)
+ raise ValueError("Can not reduce() empty RDD")
def fold(self, zeroValue, op):
"""
@@ -1081,7 +1020,7 @@ class RDD(object):
yield counts
def mergeMaps(m1, m2):
- for (k, v) in m2.iteritems():
+ for k, v in m2.iteritems():
m1[k] += v
return m1
return self.mapPartitions(countPartition).reduce(mergeMaps)
@@ -1117,24 +1056,10 @@ class RDD(object):
[10, 9, 7, 6, 5, 4]
"""
- def topNKeyedElems(iterator, key_=None):
- q = MaxHeapQ(num)
- for k in iterator:
- if key_ is not None:
- k = (key_(k), k)
- q.insert(k)
- yield q.getElements()
-
- def unKey(x, key_=None):
- if key_ is not None:
- x = [i[1] for i in x]
- return x
-
def merge(a, b):
- return next(topNKeyedElems(a + b))
- result = self.mapPartitions(
- lambda i: topNKeyedElems(i, key)).reduce(merge)
- return sorted(unKey(result, key), key=key)
+ return heapq.nsmallest(num, a + b, key)
+
+ return self.mapPartitions(lambda it: [heapq.nsmallest(num, it, key)]).reduce(merge)
def take(self, num):
"""
@@ -1174,13 +1099,13 @@ class RDD(object):
left = num - len(items)
def takeUpToNumLeft(iterator):
+ iterator = iter(iterator)
taken = 0
while taken < left:
yield next(iterator)
taken += 1
- p = range(
- partsScanned, min(partsScanned + numPartsToTry, totalParts))
+ p = range(partsScanned, min(partsScanned + numPartsToTry, totalParts))
res = self.context.runJob(self, takeUpToNumLeft, p, True)
items += res
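The added iter() call matters because takeUpToNumLeft() uses next(), which requires a true iterator; presumably the partition data can arrive as a plain list, and next() on a list raises TypeError. A minimal reproduction of the pattern outside Spark:

    def take_up_to(n, iterable):
        it = iter(iterable)        # next() needs an iterator, not just an iterable
        taken = 0
        while taken < n:
            yield next(it)
            taken += 1

    list(take_up_to(3, [1, 2, 3, 4, 5]))   # [1, 2, 3]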
@@ -1194,8 +1119,15 @@ class RDD(object):
>>> sc.parallelize([2, 3, 4]).first()
2
+ >>> sc.parallelize([]).first()
+ Traceback (most recent call last):
+ ...
+ ValueError: RDD is empty
"""
- return self.take(1)[0]
+ rs = self.take(1)
+ if rs:
+ return rs[0]
+ raise ValueError("RDD is empty")
def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
"""
@@ -1420,13 +1352,13 @@ class RDD(object):
"""
def reducePartition(iterator):
m = {}
- for (k, v) in iterator:
- m[k] = v if k not in m else func(m[k], v)
+ for k, v in iterator:
+ m[k] = func(m[k], v) if k in m else v
yield m
def mergeMaps(m1, m2):
- for (k, v) in m2.iteritems():
- m1[k] = v if k not in m1 else func(m1[k], v)
+ for k, v in m2.iteritems():
+ m1[k] = func(m1[k], v) if k in m1 else v
return m1
return self.mapPartitions(reducePartition).reduce(mergeMaps)
@@ -1523,7 +1455,7 @@ class RDD(object):
buckets = defaultdict(list)
c, batch = 0, min(10 * numPartitions, 1000)
- for (k, v) in iterator:
+ for k, v in iterator:
buckets[partitionFunc(k) % numPartitions].append((k, v))
c += 1
@@ -1546,7 +1478,7 @@ class RDD(object):
batch = max(batch / 1.5, 1)
c = 0
- for (split, items) in buckets.iteritems():
+ for split, items in buckets.iteritems():
yield pack_long(split)
yield outputSerializer.dumps(items)
@@ -1616,7 +1548,7 @@ class RDD(object):
merger.mergeCombiners(iterator)
return merger.iteritems()
- return shuffled.mapPartitions(_mergeCombiners)
+ return shuffled.mapPartitions(_mergeCombiners, True)
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
@@ -1680,7 +1612,6 @@ class RDD(object):
return self.combineByKey(createCombiner, mergeValue, mergeCombiners,
numPartitions).mapValues(lambda x: ResultIterable(x))
- # TODO: add tests
def flatMapValues(self, f):
"""
Pass each value in the key-value pair RDD through a flatMap function
@@ -1770,9 +1701,8 @@ class RDD(object):
[('b', 4), ('b', 5)]
"""
def filter_func((key, vals)):
- return len(vals[0]) > 0 and len(vals[1]) == 0
- map_func = lambda (key, vals): [(key, val) for val in vals[0]]
- return self.cogroup(other, numPartitions).filter(filter_func).flatMap(map_func)
+ return vals[0] and not vals[1]
+ return self.cogroup(other, numPartitions).filter(filter_func).flatMapValues(lambda x: x[0])
def subtract(self, other, numPartitions=None):
"""
@@ -1785,7 +1715,7 @@ class RDD(object):
"""
# note: here 'True' is just a placeholder
rdd = other.map(lambda x: (x, True))
- return self.map(lambda x: (x, True)).subtractByKey(rdd).map(lambda tpl: tpl[0])
+ return self.map(lambda x: (x, True)).subtractByKey(rdd, numPartitions).keys()
def keyBy(self, f):
"""
@@ -1925,9 +1855,8 @@ class RDD(object):
Return the name of this RDD.
"""
name_ = self._jrdd.name()
- if not name_:
- return None
- return name_.encode('utf-8')
+ if name_:
+ return name_.encode('utf-8')
def setName(self, name):
"""
@@ -1945,9 +1874,8 @@ class RDD(object):
A description of this RDD and its recursive dependencies for debugging.
"""
debug_string = self._jrdd.toDebugString()
- if not debug_string:
- return None
- return debug_string.encode('utf-8')
+ if debug_string:
+ return debug_string.encode('utf-8')
def getStorageLevel(self):
"""
@@ -1982,10 +1910,28 @@ class RDD(object):
else:
return self.getNumPartitions()
- # TODO: `lookup` is disabled because we can't make direct comparisons based
- # on the key; we need to compare the hash of the key to the hash of the
- # keys in the pairs. This could be an expensive operation, since those
- # hashes aren't retained.
+ def lookup(self, key):
+ """
+ Return the list of values in the RDD for key `key`. This operation
+ is done efficiently if the RDD has a known partitioner by only
+ searching the partition that the key maps to.
+
+ >>> l = range(1000)
+ >>> rdd = sc.parallelize(zip(l, l), 10)
+ >>> rdd.lookup(42) # slow
+ [42]
+ >>> sorted = rdd.sortByKey()
+ >>> sorted.lookup(42) # fast
+ [42]
+ >>> sorted.lookup(1024)
+ []
+ """
+ values = self.filter(lambda (k, v): k == key).values()
+
+ if self._partitionFunc is not None:
+ return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False)
+
+ return values.collect()
def _is_pickled(self):
""" Return this RDD is serialized by Pickle or not. """
@@ -2096,6 +2042,7 @@ class PipelinedRDD(RDD):
self._jrdd_val = None
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
+ self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
@property
def _jrdd(self):
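The final hunk ties the pieces together: a pipelined (narrow) transformation inherits the parent's partition function only when preservesPartitioning is set; otherwise the partitioner information is dropped. That is what keeps lookup() fast through further narrow transformations. A small sketch, assuming a live SparkContext `sc` and that mapValues passes preservesPartitioning=True (it leaves the keys untouched):

    pairs = sc.parallelize(zip(range(100), range(100))).sortByKey()
    labelled = pairs.mapValues(str)   # narrow transformation that keeps the partitioner
    labelled.lookup(42)               # still evaluates a single partition: ['42']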