author     Davies Liu <davies@databricks.com>        2015-02-17 16:54:57 -0800
committer  Josh Rosen <joshrosen@databricks.com>     2015-02-17 16:54:57 -0800
commit     c3d2b90bde2e11823909605d518167548df66bd8 (patch)
tree       eab646a984d8c91b533789fc07fea1221cfe6460 /core
parent     117121a4ecaadda156a82255333670775e7727db (diff)
[SPARK-5785] [PySpark] narrow dependency for cogroup/join in PySpark
Currently, PySpark does not support narrow dependencies for cogroup/join: even when the two RDDs share the same partitioner, an unnecessary extra shuffle stage is introduced. The Python implementation of cogroup/join differs from the Scala one in that it is built on union() and partitionBy(). This patch makes union() use PartitionerAwareUnionRDD() when all of the RDDs have the same partitioner. It also fixes the `preservesPartitioning` flag in the relevant map() and mapPartitions() calls, so that partitionBy() can skip the unnecessary shuffle stage.

Author: Davies Liu <davies@databricks.com>

Closes #4629 from davies/narrow and squashes the following commits:

dffe34e [Davies Liu] improve test, check number of stages for join/cogroup
1ed3ba2 [Davies Liu] Merge branch 'master' of github.com:apache/spark into narrow
4d29932 [Davies Liu] address comment
cc28d97 [Davies Liu] add unit tests
940245e [Davies Liu] address comments
ff5a0a6 [Davies Liu] skip the partitionBy() on Python side
eb26c62 [Davies Liu] narrow dependency in PySpark
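As a rough illustration of the behaviour this patch targets (a minimal Scala sketch written for this summary, not taken from the patch; the app name, data, and partition counts are made up): when two RDDs already share a partitioner, their union can keep that partitioner, so later work over the same keys can be planned as a narrow dependency instead of adding a shuffle stage.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("narrow-union-sketch"))
val part = new HashPartitioner(4)

// Two pair RDDs hash-partitioned the same way.
val left  = sc.parallelize(Seq(1 -> "a", 2 -> "b")).partitionBy(part)
val right = sc.parallelize(Seq(1 -> "x", 3 -> "y")).partitionBy(part)

// With this patch, sc.union sees a single distinct partitioner and builds a
// PartitionerAwareUnionRDD, so the union still reports HashPartitioner(4)
// and does not force a re-partitioning before a later join/cogroup.
val unioned = sc.union(Seq(left, right))
assert(unioned.partitioner == Some(part))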
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala          | 11
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala  | 10
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala               |  8
3 files changed, 26 insertions(+), 3 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index fd8fac6df0..d59b466830 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -961,11 +961,18 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli
}
/** Build the union of a list of RDDs. */
- def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
+ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = {
+ val partitioners = rdds.flatMap(_.partitioner).toSet
+ if (partitioners.size == 1) {
+ new PartitionerAwareUnionRDD(this, rdds)
+ } else {
+ new UnionRDD(this, rdds)
+ }
+ }
/** Build the union of a list of RDDs passed as variable-length arguments. */
def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] =
- new UnionRDD(this, Seq(first) ++ rest)
+ union(Seq(first) ++ rest)
/** Get an RDD that has no partitions or elements. */
def emptyRDD[T: ClassTag] = new EmptyRDD[T](this)
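A hedged sketch of the dispatch added above (the RDDs, values, and partition counts are illustrative only): the Seq-based union keeps the partitioner only when every input reports the same one; otherwise it falls back to the plain UnionRDD, and the variadic overload now simply delegates to the Seq-based version.

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("union-dispatch-sketch"))
val a = sc.parallelize(Seq(1 -> "a")).partitionBy(new HashPartitioner(4))
val b = sc.parallelize(Seq(2 -> "b")).partitionBy(new HashPartitioner(4))
val c = sc.parallelize(Seq(3 -> "c")).partitionBy(new HashPartitioner(8))

// Same partitioner on both inputs -> PartitionerAwareUnionRDD, partitioner kept.
assert(sc.union(Seq(a, b)).partitioner.isDefined)

// Different partitioners -> plain UnionRDD, no partitioner on the result.
assert(sc.union(Seq(a, c)).partitioner.isEmpty)

// The variadic overload now routes through the Seq-based union as well.
assert(sc.union(a, b).partitioner.isDefined)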
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 2527211929..dcb6e6313a 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -303,6 +303,7 @@ private class PythonException(msg: String, cause: Exception) extends RuntimeExce
private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
RDD[(Long, Array[Byte])](prev) {
override def getPartitions = prev.partitions
+ override val partitioner = prev.partitioner
override def compute(split: Partition, context: TaskContext) =
prev.iterator(split, context).grouped(2).map {
case Seq(a, b) => (Utils.deserializeLongValue(a), b)
@@ -330,6 +331,15 @@ private[spark] object PythonRDD extends Logging {
}
/**
+ * Return an RDD of values from an RDD of (Long, Array[Byte]), with preservePartitions=true
+ *
+ * This is useful for PySpark to have the partitioner after partitionBy()
+ */
+ def valueOfPair(pair: JavaPairRDD[Long, Array[Byte]]): JavaRDD[Array[Byte]] = {
+ pair.rdd.mapPartitions(it => it.map(_._2), true)
+ }
+
+ /**
* Adapter for calling SparkContext#runJob from Python.
*
* This method will return an iterator of an array that contains all elements in the RDD
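A small sketch of what valueOfPair relies on (written for this note, not part of the patch; the data is made up): mapPartitions with preservesPartitioning = true drops the keys without moving any records between partitions, so Spark is allowed to keep the parent's partitioner, which is what lets PySpark retain the partitioner established by partitionBy().

import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("preserve-partitioning-sketch"))
val pairs = sc.parallelize(Seq(1L -> Array[Byte](1), 2L -> Array[Byte](2)))
  .partitionBy(new HashPartitioner(4))

// Extracting only the values would normally discard the partitioner; passing
// preservesPartitioning = true declares that no record changes partition,
// so the resulting RDD keeps the partitioner set by partitionBy().
val values = pairs.mapPartitions(it => it.map(_._2), preservesPartitioning = true)
assert(values.partitioner == pairs.partitioner)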
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index fe55a5124f..3ab9e54f0e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -462,7 +462,13 @@ abstract class RDD[T: ClassTag](
* Return the union of this RDD and another one. Any identical elements will appear multiple
* times (use `.distinct()` to eliminate them).
*/
- def union(other: RDD[T]): RDD[T] = new UnionRDD(sc, Array(this, other))
+ def union(other: RDD[T]): RDD[T] = {
+ if (partitioner.isDefined && other.partitioner == partitioner) {
+ new PartitionerAwareUnionRDD(sc, Array(this, other))
+ } else {
+ new UnionRDD(sc, Array(this, other))
+ }
+ }
/**
* Return the union of this RDD and another one. Any identical elements will appear multiple