authorDavies Liu <davies@databricks.com>2015-02-17 16:54:57 -0800
committerJosh Rosen <joshrosen@databricks.com>2015-02-17 16:54:57 -0800
commitc3d2b90bde2e11823909605d518167548df66bd8 (patch)
treeeab646a984d8c91b533789fc07fea1221cfe6460
parent117121a4ecaadda156a82255333670775e7727db (diff)
[SPARK-5785] [PySpark] narrow dependency for cogroup/join in PySpark
Currently, PySpark does not support narrow dependency during cogroup/join when the two RDDs have the partitioner, another unnecessary shuffle stage will come in. The Python implementation of cogroup/join is different than Scala one, it depends on union() and partitionBy(). This patch will try to use PartitionerAwareUnionRDD() in union(), when all the RDDs have the same partitioner. It also fix `reservePartitioner` in all the map() or mapPartitions(), then partitionBy() can skip the unnecessary shuffle stage. Author: Davies Liu <davies@databricks.com> Closes #4629 from davies/narrow and squashes the following commits: dffe34e [Davies Liu] improve test, check number of stages for join/cogroup 1ed3ba2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into narrow 4d29932 [Davies Liu] address comment cc28d97 [Davies Liu] add unit tests 940245e [Davies Liu] address comments ff5a0a6 [Davies Liu] skip the partitionBy() on Python side eb26c62 [Davies Liu] narrow dependency in PySpark
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala          11
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala  10
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala                8
-rw-r--r--  python/pyspark/join.py                                             8
-rw-r--r--  python/pyspark/rdd.py                                             49
-rw-r--r--  python/pyspark/streaming/dstream.py                                2
-rw-r--r--  python/pyspark/tests.py                                           38
7 files changed, 101 insertions, 25 deletions
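
Before the diff, a short sketch (not part of the patch) of what the change means for user code: once both sides of a join share the same partitioner, the join is planned as a narrow dependency and no extra shuffle stage is added. The stage counts below mirror the unit test added in python/pyspark/tests.py; the SparkContext setup and job-group names are illustrative only.

from pyspark import SparkContext

sc = SparkContext("local[2]", "narrow-join-sketch")
tracker = sc.statusTracker()

rdd = sc.parallelize(range(10)).map(lambda x: (x, x))
parted = rdd.partitionBy(2)          # hash-partitions into 2 partitions and records the partitioner

sc.setJobGroup("copartitioned", "join of two co-partitioned RDDs", True)
parted.join(parted).collect()        # both sides share a partitioner: narrow dependency, no extra shuffle
job_id = tracker.getJobIdsForGroup("copartitioned")[0]
print(len(tracker.getJobInfo(job_id).stageIds))   # 2 stages (per the new unit test)

sc.setJobGroup("mixed", "join with an unpartitioned RDD", True)
parted.join(rdd).collect()           # the unpartitioned side still has to be shuffled
job_id = tracker.getJobIdsForGroup("mixed")[0]
print(len(tracker.getJobInfo(job_id).stageIds))   # 3 stages (per the new unit test)

sc.stop()
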
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index fd8fac6df0..d59b466830 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -961,11 +961,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/** Build the union of a list of RDDs. */
- def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
+ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = {
+ val partitioners = rdds.flatMap(_.partitioner).toSet
+ if (partitioners.size == 1) {
+ new PartitionerAwareUnionRDD(this, rdds)
+ } else {
+ new UnionRDD(this, rdds)
+ }
+ }
/** Build the union of a list of RDDs passed as variable-length arguments. */
def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] =
- new UnionRDD(this, Seq(first) ++ rest)
+ union(Seq(first) ++ rest)
/** Get an RDD that has no partitions or elements. */
def emptyRDD[T: ClassTag] = new EmptyRDD[T](this)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 2527211929..dcb6e6313a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -303,6 +303,7 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Long, Array[Byte])](prev) {
override def getPartitions = prev.partitions
+ override val partitioner = prev.partitioner
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (Utils.deserializeLongValue(a), b)
@@ -330,6 +331,15 @@ private[spark] object PythonRDD extends Logging {
}
/**
+ * Return an RDD of values from an RDD of (Long, Array[Byte]), with preservePartitions=true
+ *
+ * This is useful for PySpark to have the partitioner after partitionBy()
+ */
+ def valueOfPair(pair: JavaPairRDD[Long, Array[Byte]]): JavaRDD[Array[Byte]] = {
+ pair.rdd.mapPartitions(it => it.map(_._2), true)
+ }
+
+ /**
* Adapter for calling SparkContext#runJob from Python.
*
* This method will return an iterator of an array that contains all elements in the RDD
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index fe55a5124f..3ab9e54f0e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -462,7 +462,13 @@ abstract class RDD[T: ClassTag](
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
- def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
+ def union(other: RDD[T]): RDD[T] = {
+ if (partitioner.isDefined && other.partitioner == partitioner) {
+ new PartitionerAwareUnionRDD(sc, Array(this, other))
+ } else {
+ new UnionRDD(sc, Array(this, other))
+ }
+ }
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple
diff --git a/python/pyspark/join.py b/python/pyspark/join.py
index b4a8447137..efc1ef9396 100644
--- a/python/pyspark/join.py
+++ b/python/pyspark/join.py
@@ -35,8 +35,8 @@ from pyspark.resultiterable import ResultIterable
def _do_python_join(rdd, other, numPartitions, dispatch):
- vs = rdd.map(lambda (k, v): (k, (1, v)))
- ws = other.map(lambda (k, v): (k, (2, v)))
+ vs = rdd.mapValues(lambda v: (1, v))
+ ws = other.mapValues(lambda v: (2, v))
return vs.union(ws).groupByKey(numPartitions).flatMapValues(lambda x: dispatch(x.__iter__()))
@@ -98,8 +98,8 @@ def python_full_outer_join(rdd, other, numPartitions):
def python_cogroup(rdds, numPartitions):
def make_mapper(i):
- return lambda (k, v): (k, (i, v))
- vrdds = [rdd.map(make_mapper(i)) for i, rdd in enumerate(rdds)]
+ return lambda v: (i, v)
+ vrdds = [rdd.mapValues(make_mapper(i)) for i, rdd in enumerate(rdds)]
union_vrdds = reduce(lambda acc, other: acc.union(other), vrdds)
rdd_len = len(vrdds)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index bd4f16e058..ba2347ae76 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -111,6 +111,19 @@ def _parse_memory(s):
return int(float(s[:-1]) * units[s[-1].lower()])
+class Partitioner(object):
+ def __init__(self, numPartitions, partitionFunc):
+ self.numPartitions = numPartitions
+ self.partitionFunc = partitionFunc
+
+ def __eq__(self, other):
+ return (isinstance(other, Partitioner) and self.numPartitions == other.numPartitions
+ and self.partitionFunc == other.partitionFunc)
+
+ def __call__(self, k):
+ return self.partitionFunc(k) % self.numPartitions
+
+
class RDD(object):
"""
@@ -126,7 +139,7 @@ class RDD(object):
self.ctx = ctx
self._jrdd_deserializer = jrdd_deserializer
self._id = jrdd.id()
- self._partitionFunc = None
+ self.partitioner = None
def _pickled(self):
return self._reserialize(AutoBatchedSerializer(PickleSerializer()))
@@ -450,14 +463,17 @@ class RDD(object):
if self._jrdd_deserializer == other._jrdd_deserializer:
rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
self._jrdd_deserializer)
- return rdd
else:
# These RDDs contain data in different serialized formats, so we
# must normalize them to the default serializer.
self_copy = self._reserialize()
other_copy = other._reserialize()
- return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
- self.ctx.serializer)
+ rdd = RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+ self.ctx.serializer)
+ if (self.partitioner == other.partitioner and
+ self.getNumPartitions() == rdd.getNumPartitions()):
+ rdd.partitioner = self.partitioner
+ return rdd
def intersection(self, other):
"""
@@ -1588,6 +1604,9 @@ class RDD(object):
"""
if numPartitions is None:
numPartitions = self._defaultReducePartitions()
+ partitioner = Partitioner(numPartitions, partitionFunc)
+ if self.partitioner == partitioner:
+ return self
# Transferring O(n) objects to Java is too expensive.
# Instead, we'll form the hash buckets in Python,
@@ -1632,18 +1651,16 @@ class RDD(object):
yield pack_long(split)
yield outputSerializer.dumps(items)
- keyed = self.mapPartitionsWithIndex(add_shuffle_key)
+ keyed = self.mapPartitionsWithIndex(add_shuffle_key, preservesPartitioning=True)
keyed._bypass_serializer = True
with SCCallSiteSync(self.context) as css:
pairRDD = self.ctx._jvm.PairwiseRDD(
keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
- id(partitionFunc))
- jrdd = pairRDD.partitionBy(partitioner).values()
+ jpartitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+ id(partitionFunc))
+ jrdd = self.ctx._jvm.PythonRDD.valueOfPair(pairRDD.partitionBy(jpartitioner))
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
- # This is required so that id(partitionFunc) remains unique,
- # even if partitionFunc is a lambda:
- rdd._partitionFunc = partitionFunc
+ rdd.partitioner = partitioner
return rdd
# TODO: add control over map-side aggregation
@@ -1689,7 +1706,7 @@ class RDD(object):
merger.mergeValues(iterator)
return merger.iteritems()
- locally_combined = self.mapPartitions(combineLocally)
+ locally_combined = self.mapPartitions(combineLocally, preservesPartitioning=True)
shuffled = locally_combined.partitionBy(numPartitions)
def _mergeCombiners(iterator):
@@ -1698,7 +1715,7 @@ class RDD(object):
merger.mergeCombiners(iterator)
return merger.iteritems()
- return shuffled.mapPartitions(_mergeCombiners, True)
+ return shuffled.mapPartitions(_mergeCombiners, preservesPartitioning=True)
def aggregateByKey(self, zeroValue, seqFunc, combFunc, numPartitions=None):
"""
@@ -2077,8 +2094,8 @@ class RDD(object):
"""
values = self.filter(lambda (k, v): k == key).values()
- if self._partitionFunc is not None:
- return self.ctx.runJob(values, lambda x: x, [self._partitionFunc(key)], False)
+ if self.partitioner is not None:
+ return self.ctx.runJob(values, lambda x: x, [self.partitioner(key)], False)
return values.collect()
@@ -2243,7 +2260,7 @@ class PipelinedRDD(RDD):
self._id = None
self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
- self._partitionFunc = prev._partitionFunc if self.preservesPartitioning else None
+ self.partitioner = prev.partitioner if self.preservesPartitioning else None
self._broadcast = None
def __del__(self):
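
The rdd.py changes above add a Python-side Partitioner object and let partitionBy() short-circuit when the RDD is already partitioned the way the caller asks. A minimal sketch of that behaviour (illustrative only; it assumes this patched PySpark and the default portable_hash partition function):

from pyspark import SparkContext

sc = SparkContext("local[2]", "partitioner-sketch")

pairs = sc.parallelize(range(10)).map(lambda x: (x, x))

parted = pairs.partitionBy(2)    # one shuffle; the result remembers Partitioner(2, portable_hash)
again = parted.partitionBy(2)    # equal Partitioner: partitionBy() returns the same RDD, no shuffle
print(again is parted)           # True

resized = parted.partitionBy(3)  # different numPartitions: a real shuffle happens again
print(resized is parted)         # False

sc.stop()
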
diff --git a/python/pyspark/streaming/dstream.py b/python/pyspark/streaming/dstream.py
index 2fe39392ff..3fa4244423 100644
--- a/python/pyspark/streaming/dstream.py
+++ b/python/pyspark/streaming/dstream.py
@@ -578,7 +578,7 @@ class DStream(object):
if a is None:
g = b.groupByKey(numPartitions).mapValues(lambda vs: (list(vs), None))
else:
- g = a.cogroup(b, numPartitions)
+ g = a.cogroup(b.partitionBy(numPartitions), numPartitions)
g = g.mapValues(lambda (va, vb): (list(vb), list(va)[0] if len(va) else None))
state = g.mapValues(lambda (vs, s): updateFunc(vs, s))
return state.filter(lambda (k, v): v is not None)
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index d6afc1cdaa..f64e25c607 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -727,7 +727,6 @@ class RDDTests(ReusedPySparkTestCase):
(u'1', {u'director': u'David Lean'}),
(u'2', {u'director': u'Andrew Dominik'})
]
- from pyspark.rdd import RDD
data_rdd = self.sc.parallelize(data)
data_java_rdd = data_rdd._to_java_object_rdd()
data_python_rdd = self.sc._jvm.SerDe.javaToPython(data_java_rdd)
@@ -740,6 +739,43 @@ class RDDTests(ReusedPySparkTestCase):
converted_rdd = RDD(data_python_rdd, self.sc)
self.assertEqual(2, converted_rdd.count())
+ def test_narrow_dependency_in_join(self):
+ rdd = self.sc.parallelize(range(10)).map(lambda x: (x, x))
+ parted = rdd.partitionBy(2)
+ self.assertEqual(2, parted.union(parted).getNumPartitions())
+ self.assertEqual(rdd.getNumPartitions() + 2, parted.union(rdd).getNumPartitions())
+ self.assertEqual(rdd.getNumPartitions() + 2, rdd.union(parted).getNumPartitions())
+
+ self.sc.setJobGroup("test1", "test", True)
+ tracker = self.sc.statusTracker()
+
+ d = sorted(parted.join(parted).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual((0, (0, 0)), d[0])
+ jobId = tracker.getJobIdsForGroup("test1")[0]
+ self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test2", "test", True)
+ d = sorted(parted.join(rdd).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual((0, (0, 0)), d[0])
+ jobId = tracker.getJobIdsForGroup("test2")[0]
+ self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test3", "test", True)
+ d = sorted(parted.cogroup(parted).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual([[0], [0]], map(list, d[0][1]))
+ jobId = tracker.getJobIdsForGroup("test3")[0]
+ self.assertEqual(2, len(tracker.getJobInfo(jobId).stageIds))
+
+ self.sc.setJobGroup("test4", "test", True)
+ d = sorted(parted.cogroup(rdd).collect())
+ self.assertEqual(10, len(d))
+ self.assertEqual([[0], [0]], map(list, d[0][1]))
+ jobId = tracker.getJobIdsForGroup("test4")[0]
+ self.assertEqual(3, len(tracker.getJobInfo(jobId).stageIds))
+
class ProfilerTests(PySparkTestCase):