Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/join.py              |  8
-rw-r--r--  python/pyspark/rdd.py               | 49
-rw-r--r--  python/pyspark/streaming/dstream.py |  2
-rw-r--r--  python/pyspark/tests.py             | 38
4 files changed, 75 insertions(+), 22 deletions(-)
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index b4a8447137..efc1ef9396 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -35,8 +35,8 @@ from pyspark.resultiterable import ResultIterable
def _do_python_join(rdd, other, numPartitions, dispatch):
- vs = rdd.map(lambda (k, v): (k, (1, v)))
- ws = other.map(lambda (k, v): (k, (2, v)))
+ vs = rdd.mapValues(lambda v: (1, v))
+ ws = other.mapValues(lambda v: (2, v))
return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__()))
@@ -98,8 +98,8 @@ def python_full_outer_join(rdd, other, numPartitions):
def python_cogroup(rdds, numPartitions):
def make_mapper(i):
- return lambda (k, v): (k, (i, v))
- vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
+ return lambda v: (i, v)
+ vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)]
union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
rdd_len = len(vrdds)
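
The switch from map() with tuple-unpacking lambdas to mapValues() is what makes the rest of this patch work: mapValues leaves keys untouched, so the tagged RDDs keep their parent's partitioner and the join/cogroup machinery below can avoid an extra shuffle when both inputs are already co-partitioned. A minimal sketch of the effect (illustrative, not part of the patch; assumes a running SparkContext `sc` with these changes applied):

    pairs = sc.parallelize([(1, "a"), (2, "b"), (3, "c")]).partitionBy(2)
    tagged = pairs.mapValues(lambda v: (1, v))            # keys unchanged, partitioner kept
    remapped = pairs.map(lambda kv: (kv[0], (1, kv[1])))  # Spark cannot tell the keys survived
    print(tagged.partitioner == pairs.partitioner)        # True
    print(remapped.partitioner)                           # None
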
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index bd4f16e058..ba2347ae76 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -111,6 +111,19 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
+class Partitioner(object):
+ def __init__(self, numPartitions, partitionFunc):
+ self.numPartitions = numPartitions
+ self.partitionFunc = partitionFunc
+
+ def __eq__(self, other):
+ return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions
+ and self.partitionFunc == other.partitionFunc)
+
+ def __call__(self, k):
+ return self.partitionFunc(k) % self.numPartitions
+
+
class RDD(object):
"""
@@ -126,7 +139,7 @@ class RDD(object):
self.ctx = ctx
self._jrdd_deserializer = jrdd_deserializer
self._id = jrdd.id()
- self._partitionFunc = None
+ self.partitioner = None
def _pickled(self):
return self._reserialize(AutoBatchedSerializer(PickleSerializer()))
@@ -450,14 +463,17 @@ class RDD(object):
if self._jrdd_deserializer == other._jrdd_deserializer:
rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
self._jrdd_deserializer)
- return rdd
else:
# These RDDs contain data in different serialized formats, so we
# must normalize them to the default serializer.
self_copy = self._reserialize()
other_copy = other._reserialize()
- return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
- self.ctx.serializer)
+ rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+ self.ctx.serializer)
+ if (self.partitioner == other.partitioner and
+ self.getNumPartitions() == rdd.getNumPartitions()):
+ rdd.partitioner = self.partitioner
+ return rdd
def intersection(self, other):
"""
@@ -1588,6 +1604,9 @@ class RDD(object):
"""
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+ partitioner = Partitioner(numPartitions, partitionFunc)
+ if self.partitioner == partitioner:
+ return self
# Transferring O(n) objects to Java is too expensive.
# Instead, we'll form the hash buckets in Python,
@@ -1632,18 +1651,16 @@ class RDD(object):
yield pack_long(split)
yield outputSerializer.dumps(items)
- keyed = self.mapPartitionsWithIndex(add_shuffle_key)
+ keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True)
keyed._bypass_serializer = True
with SCCallSiteSync(self.context) as css:
pairRDD = self.ctx._jvm.PairwiseRDD(
keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
- id(partitionFunc))
- jrdd = pairRDD.partitionBy(partitioner).values()
+ jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+ id(partitionFunc))
+ jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner))
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
- # This is required so that id(partitionFunc) remains unique,
- # even if partitionFunc is a lambda:
- rdd._partitionFunc = partitionFunc
+ rdd.partitioner = partitioner
return rdd
# TODO: add control over map-side aggregation
@@ -1689,7 +1706,7 @@ class RDD(object):
merger.mergeValues(iterator)
return merger.iteritems()
- locally_combined = self.mapPartitions(combineLocally)
+ locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
shuffled = locally_combined.partitionBy(numPartitions)
def _mergeCombiners(iterator):
@@ -1698,7 +1715,7 @@ class RDD(object):
merger.mergeCombiners(iterator)
return merger.iteritems()
- return shuffled.mapPartitions(_mergeCombiners, True)
+ return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
@@ -2077,8 +2094,8 @@ class RDD(object):
"""
values = self.filter(lambda (k, v): k == key).values()
- if self._partitionFunc is not None:
- return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False)
+ if self.partitioner is not None:
+ return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)
return values.collect()
@@ -2243,7 +2260,7 @@ class PipelinedRDD(RDD):
self._id = None
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
- self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+ self.partitioner = prev.partitioner if self.preservesPartitioning else None
self._broadcast = None
def __del__(self):
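
Taken together, the rdd.py changes track partitioning on the Python side: partitionBy() records a Partitioner on the RDD it returns, a PipelinedRDD inherits it when preservesPartitioning is set, union() keeps it when both inputs agree, and calling partitionBy() again with an equivalent Partitioner becomes a no-op. An illustrative sketch of the resulting behaviour (not part of the patch; assumes a SparkContext `sc`):

    rdd = sc.parallelize(range(10)).map(lambda x: (x, x))
    parted = rdd.partitionBy(2)     # one shuffle; the Partitioner is remembered
    again = parted.partitionBy(2)   # same partition count and partition function
    print(again is parted)          # True: partitionBy returns self, no second shuffle

    # Both sides of this self-join carry the same Partitioner, so the join can
    # reuse the existing partitioning instead of shuffling its inputs again,
    # which is what test_narrow_dependency_in_join checks via the status tracker.
    print(sorted(parted.join(parted).collect())[:3])   # [(0, (0, 0)), (1, (1, 1)), (2, (2, 2))]
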
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 2fe39392ff..3fa4244423 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -578,7 +578,7 @@ class DStream(object):
if a is None:
g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
else:
- g = a.cogroup(b, numPartitions)
+ g = a.cogroup(b.partitionBy(numPartitions), numPartitions)
g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
return state.filter(lambda (k, v): v is not None)
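
The DStream change applies the same idea to updateStateByKey(): the new batch b is partitioned with the same number of partitions as the state a before the cogroup, so the state side no longer has to be reshuffled on every interval. The underlying pattern, shown on plain RDDs (illustrative, not part of the patch; assumes a SparkContext `sc`):

    state = sc.parallelize([(1, 10), (2, 20)]).partitionBy(2)
    batch = sc.parallelize([(1, 1), (3, 3)])
    grouped = state.cogroup(batch.partitionBy(2), 2)   # both sides share one partitioning
    print(sorted((k, (list(a), list(b))) for k, (a, b) in grouped.collect()))
    # [(1, ([10], [1])), (2, ([20], [])), (3, ([], [3]))]
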
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index d6afc1cdaa..f64e25c607 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -727,7 +727,6 @@ class RDDTests(ReusedPySparkTestCase):
(u'1', {u'director': u'David Lean'}),
(u'2', {u'director': u'Andrew Dominik'})
]
- from pyspark.rdd import RDD
data_rdd = self.sc.parallelize(data)
data_java_rdd = data_rdd._to_java_object_rdd()
data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd)
@@ -740,6 +739,43 @@ class RDDTests(ReusedPySparkTestCase):
converted_rdd = RDD(data_python_rdd, self.sc)
self.assertEqual(2, converted_rdd.count())
+ def test_narrow_dependency_in_join(self):
+ rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x))
+ parted = rdd.partitionBy(2)
+ self.assertEqual(2, parted.union(parted).getNumPartitions())
+ self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
+ self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
+
+ self.sc.setJobGroup("test1", "test", True)
+ tracker = self.sc.statusTracker()
+
+ d = sorted(parted.join(parted).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual((0, (0, 0)), d[0])
+ jobId = tracker.getJobIdsForGroup("test1")[0]
+ self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test2", "test", True)
+ d = sorted(parted.join(rdd).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual((0, (0, 0)), d[0])
+ jobId = tracker.getJobIdsForGroup("test2")[0]
+ self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test3", "test", True)
+ d = sorted(parted.cogroup(parted).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual([[0], [0]], map(list, d[0][1]))
+ jobId = tracker.getJobIdsForGroup("test3")[0]
+ self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test4", "test", True)
+ d = sorted(parted.cogroup(rdd).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual([[0], [0]], map(list, d[0][1]))
+ jobId = tracker.getJobIdsForGroup("test4")[0]
+ self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
+
class ProfilerTests(PySparkTestCase):