 core/src/main/scala/org/apache/spark/rdd/RDD.scala                                  | 21 +++++++++++++++++++++
 core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala                             | 33 +++++++++++++++++++++++++++++++++
 docs/streaming-programming-guide.md                                                 |  4 ++++
 streaming/src/main/scala/org/apache/spark/streaming/DStream.scala                   |  7 +++++++
 streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala  |  7 +++++++
 5 files changed, 72 insertions(+), 0 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 0355618e43..e2652f13c4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -266,6 +266,27 @@ abstract class RDD[T: ClassManifest](
def distinct(): RDD[T] = distinct(partitions.size)
/**
+ * Return a new RDD that has exactly numPartitions partitions.
+ *
+ * Used to increase or decrease the level of parallelism in this RDD. By default, this uses
+ * a shuffle to redistribute data. If you are shrinking the RDD into fewer partitions, you can
+ * set skipShuffle = true to avoid a shuffle. Skipping the shuffle is not supported when
+ * increasing the number of partitions.
+ *
+ * Similar to `coalesce`, but shuffles by default, so it is safe to call even when you do not
+ * know the current number of partitions.
+ */
+ def repartition(numPartitions: Int, skipShuffle: Boolean = false): RDD[T] = {
+ if (skipShuffle && numPartitions > this.partitions.size) {
+ val msg = "repartition must grow %s from %s to %s partitions, cannot skip shuffle.".format(
+ this.name, this.partitions.size, numPartitions
+ )
+ throw new IllegalArgumentException(msg)
+ }
+ coalesce(numPartitions, !skipShuffle)
+ }
+
+ /**
* Return a new RDD that is reduced into `numPartitions` partitions.
*
* This results in a narrow dependency, e.g. if you go from 1000 partitions
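
For reviewers, a minimal sketch of how the new method behaves end to end; it assumes a spark-shell session where `sc` is the usual SparkContext, and the partition counts are illustrative only:

    val rdd = sc.parallelize(1 to 1000, 10)

    // Shrink with the default shuffle: rows are redistributed evenly across 2 partitions.
    val shuffled = rdd.repartition(2)

    // Shrink while skipping the shuffle: adjacent parent partitions are merged instead.
    val merged = rdd.repartition(2, skipShuffle = true)

    // Grow: this always needs a shuffle, so asking to skip it is rejected.
    val grown = rdd.repartition(100)
    // rdd.repartition(100, skipShuffle = true)   // throws IllegalArgumentException
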
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 6d1bc5e296..fd00183668 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -139,6 +139,39 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(rdd.union(emptyKv).collect().size === 2)
}
+ test("repartitioned RDDs") {
+ val data = sc.parallelize(1 to 1000, 10)
+
+ // Coalesce partitions
+ val repartitioned1 = data.repartition(2)
+ assert(repartitioned1.partitions.size == 2)
+ val partitions1 = repartitioned1.glom().collect()
+ assert(partitions1(0).length > 0)
+ assert(partitions1(1).length > 0)
+ assert(repartitioned1.collect().toSet === (1 to 1000).toSet)
+
+ // Split partitions
+ val repartitioned2 = data.repartition(20)
+ assert(repartitioned2.partitions.size == 20)
+ val partitions2 = repartitioned2.glom().collect()
+ assert(partitions2(0).length > 0)
+ assert(partitions2(19).length > 0)
+ assert(repartitioned2.collect().toSet === (1 to 1000).toSet)
+
+ // Coalesce partitions - no shuffle
+ val repartitioned3 = data.repartition(2, skipShuffle = true)
+ assert(repartitioned3.partitions.size == 2)
+ val partitions3 = repartitioned3.glom().collect()
+ assert(partitions3(0).toList === (1 to 500).toList)
+ assert(partitions3(1).toList === (501 to 1000).toList)
+ assert(repartitioned3.collect().toSet === (1 to 1000).toSet)
+
+ // Split partitions - no shuffle (should throw an exception)
+ intercept[IllegalArgumentException] {
+ data.repartition(20, skipShuffle = true)
+ }
+ }
+
test("coalesced RDDs") {
val data = sc.parallelize(1 to 10, 10)
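
A note on the `skipShuffle = true` assertions above: `coalesce` without a shuffle builds a narrow dependency that merges adjacent parent partitions in order, so the 10 input partitions of 100 contiguous elements collapse into two contiguous, ordered halves. A quick way to confirm this from a shell (illustrative; assumes `sc` as above):

    val data = sc.parallelize(1 to 1000, 10)
    // glom() materializes each partition as an array, exposing the partition boundaries.
    val halves = data.repartition(2, skipShuffle = true).glom().collect()
    halves.map(_.length)     // expected: Array(500, 500)
    halves(0).take(3)        // expected: Array(1, 2, 3), order preserved
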
diff --git a/docs/streaming-programming-guide.md b/docs/streaming-programming-guide.md
index 835b257238..851e30fe76 100644
--- a/docs/streaming-programming-guide.md
+++ b/docs/streaming-programming-guide.md
@@ -73,6 +73,10 @@ DStreams support many of the transformations available on normal Spark RDDs:
Iterator[T] => Iterator[U] when running on a DStream of type T. </td>
</tr>
<tr>
+ <td> <b>repartition</b>(<i>numPartitions</i>) </td>
+ <td> Changes the level of parallelism in this DStream by creating more or fewer partitions. </td>
+</tr>
+<tr>
<td> <b>union</b>(<i>otherStream</i>) </td>
<td> Return a new DStream that contains the union of the elements in the source DStream and the argument DStream. </td>
</tr>
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
index 80da6bd30b..6da2261f06 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala
@@ -438,6 +438,13 @@ abstract class DStream[T: ClassManifest] (
*/
def glom(): DStream[Array[T]] = new GlommedDStream(this)
+
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): DStream[T] = this.transform(_.repartition(numPartitions))
+
/**
* Return a new DStream in which each RDD is generated by applying mapPartitions() to each RDD
* of this DStream. Applying mapPartitions() to an RDD applies a function to each partition
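
Since the DStream method is just `transform(_.repartition(...))`, it can be dropped into an existing job to fan a low-parallelism input (for example, a single network receiver) out across the cluster before per-record work. A hedged end-to-end sketch; the master, app name, host, port, batch interval, and partition count below are all made up for illustration:

    import org.apache.spark.streaming.{Seconds, StreamingContext}

    val ssc = new StreamingContext("local[4]", "RepartitionExample", Seconds(1))
    val lines = ssc.socketTextStream("localhost", 9999)  // typically arrives in very few partitions

    // Spread each batch's RDD across 8 partitions before the per-record work.
    val words = lines.repartition(8).flatMap(_.split(" "))
    words.print()

    ssc.start()
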
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
index 459695b7ca..eae517cff0 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaDStreamLike.scala
@@ -123,6 +123,13 @@ trait JavaDStreamLike[T, This <: JavaDStreamLike[T, This, R], R <: JavaRDDLike[T
def glom(): JavaDStream[JList[T]] =
new JavaDStream(dstream.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
+ /**
+ * Return a new DStream with an increased or decreased level of parallelism. Each RDD in the
+ * returned DStream has exactly numPartitions partitions.
+ */
+ def repartition(numPartitions: Int): JavaDStream[T] =
+ new JavaDStream(dstream.repartition(numPartitions))
+
/** Return the StreamingContext associated with this DStream */
def context(): StreamingContext = dstream.context()