aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala26
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala110
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala361
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala27
-rw-r--r--python/pyspark/rdd.py66
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala13
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala18
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala9
-rw-r--r--streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala16
-rw-r--r--streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala4
16 files changed, 497 insertions, 205 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 84bd0f7ffd..4d6a97e255 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -729,6 +729,26 @@ class SparkContext(
}
/**
+ * Support function for API backtraces.
+ */
+ def setCallSite(site: String) {
+ setLocalProperty("externalCallSite", site)
+ }
+
+ /**
+ * Support function for API backtraces.
+ */
+ def clearCallSite() {
+ setLocalProperty("externalCallSite", null)
+ }
+
+ private[spark] def getCallSite(): String = {
+ val callSite = getLocalProperty("externalCallSite")
+ if (callSite == null) return Utils.formatSparkCallSite
+ callSite
+ }
+
+ /**
* Run a function on a given set of partitions in an RDD and pass the results to the given
* handler function. This is the main entry point for all actions in Spark. The allowLocal
* flag specifies whether the scheduler can run the computation on the driver rather than
@@ -740,7 +760,7 @@ class SparkContext(
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- val callSite = Utils.formatSparkCallSite
+ val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite)
val start = System.nanoTime
@@ -824,7 +844,7 @@ class SparkContext(
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
- val callSite = Utils.formatSparkCallSite
+ val callSite = getCallSite
logInfo("Starting job: " + callSite)
val start = System.nanoTime
val result = dagScheduler.runApproximateJob(rdd, func, evaluator, callSite, timeout,
@@ -844,7 +864,7 @@ class SparkContext(
resultFunc: => R): SimpleFutureAction[R] =
{
val cleanF = clean(processPartition)
- val callSite = Utils.formatSparkCallSite
+ val callSite = getCallSite
val waiter = dagScheduler.submitJob(
rdd,
(context: TaskContext, iter: Iterator[T]) => cleanF(iter),
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 0680a065a3..5be5317f40 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -411,6 +411,20 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* changed at runtime.
*/
def getConf: SparkConf = sc.getConf
+
+ /**
+ * Pass-through to SparkContext.setCallSite. For API support only.
+ */
+ def setCallSite(site: String) {
+ sc.setCallSite(site)
+ }
+
+ /**
+ * Pass-through to SparkContext.setCallSite. For API support only.
+ */
+ def clearCallSite() {
+ sc.clearCallSite()
+ }
}
object JavaSparkContext {
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 3c92c205ea..e51d274d33 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -141,11 +141,6 @@ private[spark] class Executor(
val tr = runningTasks.get(taskId)
if (tr != null) {
tr.kill()
- // We remove the task also in the finally block in TaskRunner.run.
- // The reason we need to remove it here is because killTask might be called before the task
- // is even launched, and never reaching that finally block. ConcurrentHashMap's remove is
- // idempotent.
- runningTasks.remove(taskId)
}
}
@@ -167,6 +162,8 @@ private[spark] class Executor(
class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
extends Runnable {
+ object TaskKilledException extends Exception
+
@volatile private var killed = false
@volatile private var task: Task[Any] = _
@@ -200,9 +197,11 @@ private[spark] class Executor(
// If this task has been killed before we deserialized it, let's quit now. Otherwise,
// continue executing the task.
if (killed) {
- logInfo("Executor killed task " + taskId)
- execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
- return
+ // Throw an exception rather than returning, because returning within a try{} block
+ // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
+ // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
+ // for the task.
+ throw TaskKilledException
}
attemptedTask = Some(task)
@@ -216,9 +215,7 @@ private[spark] class Executor(
// If the task has been killed, let's fail it.
if (task.killed) {
- logInfo("Executor killed task " + taskId)
- execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
- return
+ throw TaskKilledException
}
val resultSer = SparkEnv.get.serializer.newInstance()
@@ -260,6 +257,11 @@ private[spark] class Executor(
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
}
+ case TaskKilledException => {
+ logInfo("Executor killed task " + taskId)
+ execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
+ }
+
case t: Throwable => {
val serviceTime = (System.currentTimeMillis() - taskStart).toInt
val metrics = attemptedTask.flatMap(t => t.metrics)
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
new file mode 100644
index 0000000000..4c625d062e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+import java.io.{ObjectOutputStream, IOException}
+import org.apache.spark.{TaskContext, OneToOneDependency, SparkContext, Partition}
+
+
+/**
+ * Class representing partitions of PartitionerAwareUnionRDD, which maintains the list of corresponding partitions
+ * of parent RDDs.
+ */
+private[spark]
+class PartitionerAwareUnionRDDPartition(
+ @transient val rdds: Seq[RDD[_]],
+ val idx: Int
+ ) extends Partition {
+ var parents = rdds.map(_.partitions(idx)).toArray
+
+ override val index = idx
+ override def hashCode(): Int = idx
+
+ @throws(classOf[IOException])
+ private def writeObject(oos: ObjectOutputStream) {
+ // Update the reference to parent partition at the time of task serialization
+ parents = rdds.map(_.partitions(index)).toArray
+ oos.defaultWriteObject()
+ }
+}
+
+/**
+ * Class representing an RDD that can take multiple RDDs partitioned by the same partitioner and
+ * unify them into a single RDD while preserving the partitioner. So m RDDs with p partitions each
+ * will be unified to a single RDD with p partitions and the same partitioner. The preferred
+ * location for each partition of the unified RDD will be the most common preferred location
+ * of the corresponding partitions of the parent RDDs. For example, location of partition 0
+ * of the unified RDD will be where most of partition 0 of the parent RDDs are located.
+ */
+private[spark]
+class PartitionerAwareUnionRDD[T: ClassTag](
+ sc: SparkContext,
+ var rdds: Seq[RDD[T]]
+ ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
+ require(rdds.length > 0)
+ require(rdds.flatMap(_.partitioner).toSet.size == 1,
+ "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
+
+ override val partitioner = rdds.head.partitioner
+
+ override def getPartitions: Array[Partition] = {
+ val numPartitions = partitioner.get.numPartitions
+ (0 until numPartitions).map(index => {
+ new PartitionerAwareUnionRDDPartition(rdds, index)
+ }).toArray
+ }
+
+ // Get the location where most of the partitions of parent RDDs are located
+ override def getPreferredLocations(s: Partition): Seq[String] = {
+ logDebug("Finding preferred location for " + this + ", partition " + s.index)
+ val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
+ val locations = rdds.zip(parentPartitions).flatMap {
+ case (rdd, part) => {
+ val parentLocations = currPrefLocs(rdd, part)
+ logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
+ parentLocations
+ }
+ }
+ val location = if (locations.isEmpty) {
+ None
+ } else {
+ // Find the location that maximum number of parent partitions prefer
+ Some(locations.groupBy(x => x).maxBy(_._2.length)._1)
+ }
+ logDebug("Selected location for " + this + ", partition " + s.index + " = " + location)
+ location.toSeq
+ }
+
+ override def compute(s: Partition, context: TaskContext): Iterator[T] = {
+ val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].parents
+ rdds.zip(parentPartitions).iterator.flatMap {
+ case (rdd, p) => rdd.iterator(p, context)
+ }
+ }
+
+ override def clearDependencies() {
+ super.clearDependencies()
+ rdds = null
+ }
+
+ // Get the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
+ private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
+ rdd.context.getPreferredLocs(rdd, part.index).map(tl => tl.host)
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6a7b0f8a86..3f41b66279 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -953,7 +953,7 @@ abstract class RDD[T: ClassTag](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- @transient private[spark] val origin = Utils.formatSparkCallSite
+ @transient private[spark] val origin = sc.getCallSite
private[spark] def elementClassTag: ClassTag[T] = classTag[T]
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 642dabaad5..bc688110f4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -40,7 +40,7 @@ private[spark] object CheckpointState extends Enumeration {
* manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations
* of the checkpointed RDD.
*/
-private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T])
+private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
extends Logging with Serializable {
import CheckpointState._
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index d94b706854..0c8ed62759 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -285,7 +285,8 @@ private[spark] class TaskSchedulerImpl(
}
}
case None =>
- logInfo("Ignoring update from TID " + tid + " because its task set is gone")
+ logInfo("Ignoring update with state %s from TID %s because its task set is gone"
+ .format(state, tid))
}
} catch {
case e: Exception => logError("Exception in statusUpdate", e)
@@ -328,7 +329,7 @@ private[spark] class TaskSchedulerImpl(
// Have each task set throw a SparkException with the error
for ((taskSetId, manager) <- activeTaskSets) {
try {
- manager.error(message)
+ manager.abort(message)
} catch {
case e: Exception => logError("Exception in error callback", e)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 67ad99a4d7..6dd1469d8f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -548,11 +548,6 @@ private[spark] class TaskSetManager(
}
}
- def error(message: String) {
- // Save the error message
- abort("Error: " + message)
- }
-
def abort(message: String) {
// TODO: Kill running tasks if we were not terminated due to a Mesos error
sched.dagScheduler.taskSetFailed(taskSet, message)
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index 70bfb81661..ec13b329b2 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -55,15 +55,15 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("RDDs with one-to-one dependencies") {
- testCheckpointing(_.map(x => x.toString))
- testCheckpointing(_.flatMap(x => 1 to x))
- testCheckpointing(_.filter(_ % 2 == 0))
- testCheckpointing(_.sample(false, 0.5, 0))
- testCheckpointing(_.glom())
- testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
- testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
- testCheckpointing(_.pipe(Seq("cat")))
+ testRDD(_.map(x => x.toString))
+ testRDD(_.flatMap(x => 1 to x))
+ testRDD(_.filter(_ % 2 == 0))
+ testRDD(_.sample(false, 0.5, 0))
+ testRDD(_.glom())
+ testRDD(_.mapPartitions(_.map(_.toString)))
+ testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
+ testRDD(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
+ testRDD(_.pipe(Seq("cat")))
}
test("ParallelCollection") {
@@ -95,7 +95,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
test("ShuffledRDD") {
- testCheckpointing(rdd => {
+ testRDD(rdd => {
// Creating ShuffledRDD directly as PairRDDFunctions.combineByKey produces a MapPartitionedRDD
new ShuffledRDD[Int, Int, (Int, Int)](rdd.map(x => (x % 2, 1)), partitioner)
})
@@ -103,25 +103,17 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
test("UnionRDD") {
def otherRDD = sc.makeRDD(1 to 10, 1)
-
- // Test whether the size of UnionRDDPartitions reduce in size after parent RDD is checkpointed.
- // Current implementation of UnionRDD has transient reference to parent RDDs,
- // so only the partitions will reduce in serialized size, not the RDD.
- testCheckpointing(_.union(otherRDD), false, true)
- testParentCheckpointing(_.union(otherRDD), false, true)
+ testRDD(_.union(otherRDD))
+ testRDDPartitions(_.union(otherRDD))
}
test("CartesianRDD") {
def otherRDD = sc.makeRDD(1 to 10, 1)
- testCheckpointing(new CartesianRDD(sc, _, otherRDD))
-
- // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
- // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
- // so only the RDD will reduce in serialized size, not the partitions.
- testParentCheckpointing(new CartesianRDD(sc, _, otherRDD), true, false)
+ testRDD(new CartesianRDD(sc, _, otherRDD))
+ testRDDPartitions(new CartesianRDD(sc, _, otherRDD))
// Test that the CartesianRDD updates parent partitions (CartesianRDD.s1/s2) after
- // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
+ // the parent RDD has been checkpointed and parent partitions have been changed.
// Note that this test is very specific to the current implementation of CartesianRDD.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
ones.checkpoint() // checkpoint that MappedRDD
@@ -132,23 +124,20 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
val splitAfterCheckpoint =
serializeDeserialize(cartesian.partitions.head.asInstanceOf[CartesianPartition])
assert(
- (splitAfterCheckpoint.s1 != splitBeforeCheckpoint.s1) &&
- (splitAfterCheckpoint.s2 != splitBeforeCheckpoint.s2),
- "CartesianRDD.parents not updated after parent RDD checkpointed"
+ (splitAfterCheckpoint.s1.getClass != splitBeforeCheckpoint.s1.getClass) &&
+ (splitAfterCheckpoint.s2.getClass != splitBeforeCheckpoint.s2.getClass),
+ "CartesianRDD.s1 and CartesianRDD.s2 not updated after parent RDD is checkpointed"
)
}
test("CoalescedRDD") {
- testCheckpointing(_.coalesce(2))
+ testRDD(_.coalesce(2))
+ testRDDPartitions(_.coalesce(2))
- // Test whether size of CoalescedRDD reduce in size after parent RDD is checkpointed
- // Current implementation of CoalescedRDDPartition has transient reference to parent RDD,
- // so only the RDD will reduce in serialized size, not the partitions.
- testParentCheckpointing(_.coalesce(2), true, false)
-
- // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents) after
- // the parent RDD has been checkpointed and parent partitions have been changed to HadoopPartitions.
- // Note that this test is very specific to the current implementation of CoalescedRDDPartitions
+ // Test that the CoalescedRDDPartition updates parent partitions (CoalescedRDDPartition.parents)
+ // after the parent RDD has been checkpointed and parent partitions have been changed.
+ // Note that this test is very specific to the current implementation of
+ // CoalescedRDDPartitions.
val ones = sc.makeRDD(1 to 100, 10).map(x => x)
ones.checkpoint() // checkpoint that MappedRDD
val coalesced = new CoalescedRDD(ones, 2)
@@ -158,33 +147,78 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
val splitAfterCheckpoint =
serializeDeserialize(coalesced.partitions.head.asInstanceOf[CoalescedRDDPartition])
assert(
- splitAfterCheckpoint.parents.head != splitBeforeCheckpoint.parents.head,
- "CoalescedRDDPartition.parents not updated after parent RDD checkpointed"
+ splitAfterCheckpoint.parents.head.getClass != splitBeforeCheckpoint.parents.head.getClass,
+ "CoalescedRDDPartition.parents not updated after parent RDD is checkpointed"
)
}
test("CoGroupedRDD") {
- val longLineageRDD1 = generateLongLineageRDDForCoGroupedRDD()
- testCheckpointing(rdd => {
+ val longLineageRDD1 = generateFatPairRDD()
+ testRDD(rdd => {
CheckpointSuite.cogroup(longLineageRDD1, rdd.map(x => (x % 2, 1)), partitioner)
- }, false, true)
+ })
- val longLineageRDD2 = generateLongLineageRDDForCoGroupedRDD()
- testParentCheckpointing(rdd => {
+ val longLineageRDD2 = generateFatPairRDD()
+ testRDDPartitions(rdd => {
CheckpointSuite.cogroup(
longLineageRDD2, sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)), partitioner)
- }, false, true)
+ })
}
test("ZippedRDD") {
- testCheckpointing(
- rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
-
- // Test whether size of ZippedRDD reduce in size after parent RDD is checkpointed
- // Current implementation of ZippedRDDPartitions has transient references to parent RDDs,
- // so only the RDD will reduce in serialized size, not the partitions.
- testParentCheckpointing(
- rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)), true, false)
+ testRDD(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
+ testRDDPartitions(rdd => new ZippedRDD(sc, rdd, rdd.map(x => x)))
+
+ // Test that the ZippedPartition updates parent partitions
+ // after the parent RDD has been checkpointed and parent partitions have been changed.
+ // Note that this test is very specific to the current implementation of ZippedRDD.
+ val rdd = generateFatRDD()
+ val zippedRDD = new ZippedRDD(sc, rdd, rdd.map(x => x))
+ zippedRDD.rdd1.checkpoint()
+ zippedRDD.rdd2.checkpoint()
+ val partitionBeforeCheckpoint =
+ serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+ zippedRDD.count()
+ val partitionAfterCheckpoint =
+ serializeDeserialize(zippedRDD.partitions.head.asInstanceOf[ZippedPartition[_, _]])
+ assert(
+ partitionAfterCheckpoint.partition1.getClass != partitionBeforeCheckpoint.partition1.getClass &&
+ partitionAfterCheckpoint.partition2.getClass != partitionBeforeCheckpoint.partition2.getClass,
+ "ZippedRDD.partition1 and ZippedRDD.partition2 not updated after parent RDD is checkpointed"
+ )
+ }
+
+ test("PartitionerAwareUnionRDD") {
+ testRDD(rdd => {
+ new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
+ generateFatPairRDD(),
+ rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+ ))
+ })
+
+ testRDDPartitions(rdd => {
+ new PartitionerAwareUnionRDD[(Int, Int)](sc, Array(
+ generateFatPairRDD(),
+ rdd.map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
+ ))
+ })
+
+ // Test that the PartitionerAwareUnionRDD updates parent partitions
+ // (PartitionerAwareUnionRDD.parents) after the parent RDD has been checkpointed and parent
+ // partitions have been changed. Note that this test is very specific to the current
+ // implementation of PartitionerAwareUnionRDD.
+ val pairRDD = generateFatPairRDD()
+ pairRDD.checkpoint()
+ val unionRDD = new PartitionerAwareUnionRDD(sc, Array(pairRDD))
+ val partitionBeforeCheckpoint = serializeDeserialize(
+ unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
+ pairRDD.count()
+ val partitionAfterCheckpoint = serializeDeserialize(
+ unionRDD.partitions.head.asInstanceOf[PartitionerAwareUnionRDDPartition])
+ assert(
+ partitionBeforeCheckpoint.parents.head.getClass != partitionAfterCheckpoint.parents.head.getClass,
+ "PartitionerAwareUnionRDDPartition.parents not updated after parent RDD is checkpointed"
+ )
}
test("CheckpointRDD with zero partitions") {
@@ -198,29 +232,32 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
}
/**
- * Test checkpointing of the final RDD generated by the given operation. By default,
- * this method tests whether the size of serialized RDD has reduced after checkpointing or not.
- * It can also test whether the size of serialized RDD partitions has reduced after checkpointing or
- * not, but this is not done by default as usually the partitions do not refer to any RDD and
- * therefore never store the lineage.
+ * Test checkpointing of the RDD generated by the given operation. It tests whether the
+ * serialized size of the RDD is reduce after checkpointing or not. This function should be called
+ * on all RDDs that have a parent RDD (i.e., do not call on ParallelCollection, BlockRDD, etc.).
*/
- def testCheckpointing[U: ClassTag](
- op: (RDD[Int]) => RDD[U],
- testRDDSize: Boolean = true,
- testRDDPartitionSize: Boolean = false
- ) {
+ def testRDD[U: ClassTag](op: (RDD[Int]) => RDD[U]) {
// Generate the final RDD using given RDD operation
- val baseRDD = generateLongLineageRDD()
+ val baseRDD = generateFatRDD()
val operatedRDD = op(baseRDD)
val parentRDD = operatedRDD.dependencies.headOption.orNull
val rddType = operatedRDD.getClass.getSimpleName
val numPartitions = operatedRDD.partitions.length
+ // Force initialization of all the data structures in RDDs
+ // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
+ initializeRdd(operatedRDD)
+
+ val partitionsBeforeCheckpoint = operatedRDD.partitions
+
// Find serialized sizes before and after the checkpoint
- val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+ val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
operatedRDD.checkpoint()
val result = operatedRDD.collect()
- val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+ operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
+ val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+ logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
// Test whether the checkpoint file has been created
assert(sc.checkpointFile[U](operatedRDD.getCheckpointFile.get).collect() === result)
@@ -228,6 +265,9 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// Test whether dependencies have been changed from its earlier parent RDD
assert(operatedRDD.dependencies.head.rdd != parentRDD)
+ // Test whether the partitions have been changed from its earlier partitions
+ assert(operatedRDD.partitions.toList != partitionsBeforeCheckpoint.toList)
+
// Test whether the partitions have been changed to the new Hadoop partitions
assert(operatedRDD.partitions.toList === operatedRDD.checkpointData.get.getPartitions.toList)
@@ -237,122 +277,72 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
// Test whether the data in the checkpointed RDD is same as original
assert(operatedRDD.collect() === result)
- // Test whether serialized size of the RDD has reduced. If the RDD
- // does not have any dependency to another RDD (e.g., ParallelCollection,
- // ShuffleRDD with ShuffleDependency), it may not reduce in size after checkpointing.
- if (testRDDSize) {
- logInfo("Size of " + rddType +
- "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
- assert(
- rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
- "Size of " + rddType + " did not reduce after checkpointing " +
- "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
- )
- }
+ // Test whether serialized size of the RDD has reduced.
+ logInfo("Size of " + rddType +
+ " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]")
+ assert(
+ rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
+ "Size of " + rddType + " did not reduce after checkpointing " +
+ " [" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
+ )
- // Test whether serialized size of the partitions has reduced. If the partitions
- // do not have any non-transient reference to another RDD or another RDD's partitions, it
- // does not refer to a lineage and therefore may not reduce in size after checkpointing.
- // However, if the original partitions before checkpointing do refer to a parent RDD, the partitions
- // must be forgotten after checkpointing (to remove all reference to parent RDDs) and
- // replaced with the HadooPartitions of the checkpointed RDD.
- if (testRDDPartitionSize) {
- logInfo("Size of " + rddType + " partitions "
- + "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]")
- assert(
- splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
- "Size of " + rddType + " partitions did not reduce after checkpointing " +
- "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
- )
- }
}
/**
* Test whether checkpointing of the parent of the generated RDD also
* truncates the lineage or not. Some RDDs like CoGroupedRDD hold on to its parent
* RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed,
- * this RDD will remember the partitions and therefore potentially the whole lineage.
+ * the generated RDD will remember the partitions and therefore potentially the whole lineage.
+ * This function should be called only those RDD whose partitions refer to parent RDD's
+ * partitions (i.e., do not call it on simple RDD like MappedRDD).
+ *
*/
- def testParentCheckpointing[U: ClassTag](
- op: (RDD[Int]) => RDD[U],
- testRDDSize: Boolean,
- testRDDPartitionSize: Boolean
- ) {
+ def testRDDPartitions[U: ClassTag](op: (RDD[Int]) => RDD[U]) {
// Generate the final RDD using given RDD operation
- val baseRDD = generateLongLineageRDD()
+ val baseRDD = generateFatRDD()
val operatedRDD = op(baseRDD)
- val parentRDD = operatedRDD.dependencies.head.rdd
+ val parentRDDs = operatedRDD.dependencies.map(_.rdd)
val rddType = operatedRDD.getClass.getSimpleName
- val parentRDDType = parentRDD.getClass.getSimpleName
- // Get the partitions and dependencies of the parent in case they're lazily computed
- parentRDD.dependencies
- parentRDD.partitions
+ // Force initialization of all the data structures in RDDs
+ // Without this, serializing the RDD will give a wrong estimate of the size of the RDD
+ initializeRdd(operatedRDD)
// Find serialized sizes before and after the checkpoint
- val (rddSizeBeforeCheckpoint, splitSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
- parentRDD.checkpoint() // checkpoint the parent RDD, not the generated one
- val result = operatedRDD.collect()
- val (rddSizeAfterCheckpoint, splitSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+ logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
+ val (rddSizeBeforeCheckpoint, partitionSizeBeforeCheckpoint) = getSerializedSizes(operatedRDD)
+ parentRDDs.foreach(_.checkpoint()) // checkpoint the parent RDD, not the generated one
+ val result = operatedRDD.collect() // force checkpointing
+ operatedRDD.collect() // force re-initialization of post-checkpoint lazy variables
+ val (rddSizeAfterCheckpoint, partitionSizeAfterCheckpoint) = getSerializedSizes(operatedRDD)
+ logInfo("RDD after checkpoint: " + operatedRDD + "\n" + operatedRDD.toDebugString)
// Test whether the data in the checkpointed RDD is same as original
assert(operatedRDD.collect() === result)
- // Test whether serialized size of the RDD has reduced because of its parent being
- // checkpointed. If this RDD or its parent RDD do not have any dependency
- // to another RDD (e.g., ParallelCollection, ShuffleRDD with ShuffleDependency), it may
- // not reduce in size after checkpointing.
- if (testRDDSize) {
- assert(
- rddSizeAfterCheckpoint < rddSizeBeforeCheckpoint,
- "Size of " + rddType + " did not reduce after checkpointing parent " + parentRDDType +
- "[" + rddSizeBeforeCheckpoint + " --> " + rddSizeAfterCheckpoint + "]"
- )
- }
-
- // Test whether serialized size of the partitions has reduced because of its parent being
- // checkpointed. If the partitions do not have any non-transient reference to another RDD
- // or another RDD's partitions, it does not refer to a lineage and therefore may not reduce
- // in size after checkpointing. However, if the partitions do refer to the *partitions* of a parent
- // RDD, then these partitions must update reference to the parent RDD partitions as the parent RDD's
- // partitions must have changed after checkpointing.
- if (testRDDPartitionSize) {
- assert(
- splitSizeAfterCheckpoint < splitSizeBeforeCheckpoint,
- "Size of " + rddType + " partitions did not reduce after checkpointing parent " + parentRDDType +
- "[" + splitSizeBeforeCheckpoint + " --> " + splitSizeAfterCheckpoint + "]"
- )
- }
-
+ // Test whether serialized size of the partitions has reduced
+ logInfo("Size of partitions of " + rddType +
+ " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]")
+ assert(
+ partitionSizeAfterCheckpoint < partitionSizeBeforeCheckpoint,
+ "Size of " + rddType + " partitions did not reduce after checkpointing parent RDDs" +
+ " [" + partitionSizeBeforeCheckpoint + " --> " + partitionSizeAfterCheckpoint + "]"
+ )
}
/**
- * Generate an RDD with a long lineage of one-to-one dependencies.
+ * Generate an RDD such that both the RDD and its partitions have large size.
*/
- def generateLongLineageRDD(): RDD[Int] = {
- var rdd = sc.makeRDD(1 to 100, 4)
- for (i <- 1 to 50) {
- rdd = rdd.map(x => x + 1)
- }
- rdd
+ def generateFatRDD(): RDD[Int] = {
+ new FatRDD(sc.makeRDD(1 to 100, 4)).map(x => x)
}
/**
- * Generate an RDD with a long lineage specifically for CoGroupedRDD.
- * A CoGroupedRDD can have a long lineage only one of its parents have a long lineage
- * and narrow dependency with this RDD. This method generate such an RDD by a sequence
- * of cogroups and mapValues which creates a long lineage of narrow dependencies.
+ * Generate an pair RDD (with partitioner) such that both the RDD and its partitions
+ * have large size.
*/
- def generateLongLineageRDDForCoGroupedRDD() = {
- val add = (x: (Seq[Int], Seq[Int])) => (x._1 ++ x._2).reduce(_ + _)
-
- def ones: RDD[(Int, Int)] = sc.makeRDD(1 to 2, 2).map(x => (x % 2, 1)).reduceByKey(partitioner, _ + _)
-
- var cogrouped: RDD[(Int, (Seq[Int], Seq[Int]))] = ones.cogroup(ones)
- for(i <- 1 to 10) {
- cogrouped = cogrouped.mapValues(add).cogroup(ones)
- }
- cogrouped.mapValues(add)
+ def generateFatPairRDD() = {
+ new FatPairRDD(sc.makeRDD(1 to 100, 4), partitioner).mapValues(x => x)
}
/**
@@ -360,8 +350,26 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
* upon checkpointing. Ignores the checkpointData field, which may grow when we checkpoint.
*/
def getSerializedSizes(rdd: RDD[_]): (Int, Int) = {
- (Utils.serialize(rdd).length - Utils.serialize(rdd.checkpointData).length,
- Utils.serialize(rdd.partitions).length)
+ val rddSize = Utils.serialize(rdd).size
+ val rddCpDataSize = Utils.serialize(rdd.checkpointData).size
+ val rddPartitionSize = Utils.serialize(rdd.partitions).size
+ val rddDependenciesSize = Utils.serialize(rdd.dependencies).size
+
+ // Print detailed size, helps in debugging
+ logInfo("Serialized sizes of " + rdd +
+ ": RDD = " + rddSize +
+ ", RDD checkpoint data = " + rddCpDataSize +
+ ", RDD partitions = " + rddPartitionSize +
+ ", RDD dependencies = " + rddDependenciesSize
+ )
+ // this makes sure that serializing the RDD's checkpoint data does not
+ // serialize the whole RDD as well
+ assert(
+ rddSize > rddCpDataSize,
+ "RDD's checkpoint data (" + rddCpDataSize + ") is equal or larger than the " +
+ "whole RDD with checkpoint data (" + rddSize + ")"
+ )
+ (rddSize - rddCpDataSize, rddPartitionSize)
}
/**
@@ -373,8 +381,49 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
val bytes = Utils.serialize(obj)
Utils.deserialize[T](bytes)
}
+
+ /**
+ * Recursively force the initialization of the all members of an RDD and it parents.
+ */
+ def initializeRdd(rdd: RDD[_]) {
+ rdd.partitions // forces the
+ rdd.dependencies.map(_.rdd).foreach(initializeRdd(_))
+ }
}
+/** RDD partition that has large serialized size. */
+class FatPartition(val partition: Partition) extends Partition {
+ val bigData = new Array[Byte](10000)
+ def index: Int = partition.index
+}
+
+/** RDD that has large serialized size. */
+class FatRDD(parent: RDD[Int]) extends RDD[Int](parent) {
+ val bigData = new Array[Byte](100000)
+
+ protected def getPartitions: Array[Partition] = {
+ parent.partitions.map(p => new FatPartition(p))
+ }
+
+ def compute(split: Partition, context: TaskContext): Iterator[Int] = {
+ parent.compute(split.asInstanceOf[FatPartition].partition, context)
+ }
+}
+
+/** Pair RDD that has large serialized size. */
+class FatPairRDD(parent: RDD[Int], _partitioner: Partitioner) extends RDD[(Int, Int)](parent) {
+ val bigData = new Array[Byte](100000)
+
+ protected def getPartitions: Array[Partition] = {
+ parent.partitions.map(p => new FatPartition(p))
+ }
+
+ @transient override val partitioner = Some(_partitioner)
+
+ def compute(split: Partition, context: TaskContext): Iterator[(Int, Int)] = {
+ parent.compute(split.asInstanceOf[FatPartition].partition, context).map(x => (x, x))
+ }
+}
object CheckpointSuite {
// This is a custom cogroup function that does not use mapValues like
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 1383359f85..559ea051d3 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -84,6 +84,33 @@ class RDDSuite extends FunSuite with SharedSparkContext {
assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
}
+ test("partitioner aware union") {
+ import SparkContext._
+ def makeRDDWithPartitioner(seq: Seq[Int]) = {
+ sc.makeRDD(seq, 1)
+ .map(x => (x, null))
+ .partitionBy(new HashPartitioner(2))
+ .mapPartitions(_.map(_._1), true)
+ }
+
+ val nums1 = makeRDDWithPartitioner(1 to 4)
+ val nums2 = makeRDDWithPartitioner(5 to 8)
+ assert(nums1.partitioner == nums2.partitioner)
+ assert(new PartitionerAwareUnionRDD(sc, Seq(nums1)).collect().toSet === Set(1, 2, 3, 4))
+
+ val union = new PartitionerAwareUnionRDD(sc, Seq(nums1, nums2))
+ assert(union.collect().toSet === Set(1, 2, 3, 4, 5, 6, 7, 8))
+ val nums1Parts = nums1.collectPartitions()
+ val nums2Parts = nums2.collectPartitions()
+ val unionParts = union.collectPartitions()
+ assert(nums1Parts.length === 2)
+ assert(nums2Parts.length === 2)
+ assert(unionParts.length === 2)
+ assert((nums1Parts(0) ++ nums2Parts(0)).toList === unionParts(0).toList)
+ assert((nums1Parts(1) ++ nums2Parts(1)).toList === unionParts(1).toList)
+ assert(union.partitioner === nums1.partitioner)
+ }
+
test("aggregate") {
val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
type StringMap = HashMap[String, Int]
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index f87923e6fa..6fb4a7b3be 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -23,6 +23,7 @@ import operator
import os
import sys
import shlex
+import traceback
from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
@@ -39,6 +40,46 @@ from py4j.java_collections import ListConverter, MapConverter
__all__ = ["RDD"]
+def _extract_concise_traceback():
+ tb = traceback.extract_stack()
+ if len(tb) == 0:
+ return "I'm lost!"
+ # HACK: This function is in a file called 'rdd.py' in the top level of
+ # everything PySpark. Just trim off the directory name and assume
+ # everything in that tree is PySpark guts.
+ file, line, module, what = tb[len(tb) - 1]
+ sparkpath = os.path.dirname(file)
+ first_spark_frame = len(tb) - 1
+ for i in range(0, len(tb)):
+ file, line, fun, what = tb[i]
+ if file.startswith(sparkpath):
+ first_spark_frame = i
+ break
+ if first_spark_frame == 0:
+ file, line, fun, what = tb[0]
+ return "%s at %s:%d" % (fun, file, line)
+ sfile, sline, sfun, swhat = tb[first_spark_frame]
+ ufile, uline, ufun, uwhat = tb[first_spark_frame-1]
+ return "%s at %s:%d" % (sfun, ufile, uline)
+
+_spark_stack_depth = 0
+
+class _JavaStackTrace(object):
+ def __init__(self, sc):
+ self._traceback = _extract_concise_traceback()
+ self._context = sc
+
+ def __enter__(self):
+ global _spark_stack_depth
+ if _spark_stack_depth == 0:
+ self._context._jsc.setCallSite(self._traceback)
+ _spark_stack_depth += 1
+
+ def __exit__(self, type, value, tb):
+ global _spark_stack_depth
+ _spark_stack_depth -= 1
+ if _spark_stack_depth == 0:
+ self._context._jsc.setCallSite(None)
class RDD(object):
"""
@@ -401,7 +442,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- bytesInJava = self._jrdd.collect().iterator()
+ with _JavaStackTrace(self.context) as st:
+ bytesInJava = self._jrdd.collect().iterator()
return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
@@ -582,13 +624,14 @@ class RDD(object):
# TODO(shivaram): Similar to the scala implementation, update the take
# method to scan multiple splits based on an estimate of how many elements
# we have per-split.
- for partition in range(mapped._jrdd.splits().size()):
- partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
- partitionsToTake[0] = partition
- iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
- items.extend(mapped._collect_iterator_through_file(iterator))
- if len(items) >= num:
- break
+ with _JavaStackTrace(self.context) as st:
+ for partition in range(mapped._jrdd.splits().size()):
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
+ items.extend(mapped._collect_iterator_through_file(iterator))
+ if len(items) >= num:
+ break
return items[:num]
def first(self):
@@ -765,9 +808,10 @@ class RDD(object):
yield outputSerializer.dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
- pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
- partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
- id(partitionFunc))
+ with _JavaStackTrace(self.context) as st:
+ pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
+ partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
+ id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
index 80af96c060..56dbcbda23 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/PairDStreamFunctions.scala
@@ -108,8 +108,9 @@ extends Serializable {
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiner: (C, C) => C,
- partitioner: Partitioner) : DStream[(K, C)] = {
- new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner)
+ partitioner: Partitioner,
+ mapSideCombine: Boolean = true): DStream[(K, C)] = {
+ new ShuffledDStream[K, V, C](self, createCombiner, mergeValue, mergeCombiner, partitioner, mapSideCombine)
}
/**
@@ -173,7 +174,13 @@ extends Serializable {
slideDuration: Duration,
partitioner: Partitioner
): DStream[(K, Seq[V])] = {
- self.window(windowDuration, slideDuration).groupByKey(partitioner)
+ val createCombiner = (v: Seq[V]) => new ArrayBuffer[V] ++= v
+ val mergeValue = (buf: ArrayBuffer[V], v: Seq[V]) => buf ++= v
+ val mergeCombiner = (buf1: ArrayBuffer[V], buf2: ArrayBuffer[V]) => buf1 ++= buf2
+ self.groupByKey(partitioner)
+ .window(windowDuration, slideDuration)
+ .combineByKey[ArrayBuffer[V]](createCombiner, mergeValue, mergeCombiner, partitioner)
+ .asInstanceOf[DStream[(K, Seq[V])]]
}
/**
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
index dfd6e27c3e..6c3467d405 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaPairDStream.scala
@@ -155,7 +155,7 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
/**
* Combine elements of each key in DStream's RDDs using custom function. This is similar to the
- * combineByKey for RDDs. Please refer to combineByKey in [[org.apache.spark.PairRDDFunctions]] for more
+ * combineByKey for RDDs. Please refer to combineByKey in [[org.apache.spark.rdd.PairRDDFunctions]] for more
* information.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
@@ -169,6 +169,22 @@ class JavaPairDStream[K, V](val dstream: DStream[(K, V)])(
}
/**
+ * Combine elements of each key in DStream's RDDs using custom function. This is similar to the
+ * combineByKey for RDDs. Please refer to combineByKey in [[org.apache.spark.rdd.PairRDDFunctions]] for more
+ * information.
+ */
+ def combineByKey[C](createCombiner: JFunction[V, C],
+ mergeValue: JFunction2[C, V, C],
+ mergeCombiners: JFunction2[C, C, C],
+ partitioner: Partitioner,
+ mapSideCombine: Boolean
+ ): JavaPairDStream[K, C] = {
+ implicit val cm: ClassTag[C] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
+ dstream.combineByKey(createCombiner, mergeValue, mergeCombiners, partitioner, mapSideCombine)
+ }
+
+ /**
* Return a new DStream by applying `groupByKey` over a sliding window. This is similar to
* `DStream.groupByKey()` but applies it over a sliding window. The new DStream generates RDDs
* with the same interval as this DStream. Hash partitioning is used to generate the RDDs with
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
index e6e0022097..84e69f277b 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/ShuffledDStream.scala
@@ -29,8 +29,9 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag](
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiner: (C, C) => C,
- partitioner: Partitioner
- ) extends DStream [(K,C)] (parent.ssc) {
+ partitioner: Partitioner,
+ mapSideCombine: Boolean = true
+ ) extends DStream[(K,C)] (parent.ssc) {
override def dependencies = List(parent)
@@ -38,8 +39,8 @@ class ShuffledDStream[K: ClassTag, V: ClassTag, C: ClassTag](
override def compute(validTime: Time): Option[RDD[(K,C)]] = {
parent.getOrCompute(validTime) match {
- case Some(rdd) =>
- Some(rdd.combineByKey[C](createCombiner, mergeValue, mergeCombiner, partitioner))
+ case Some(rdd) => Some(rdd.combineByKey[C](
+ createCombiner, mergeValue, mergeCombiner, partitioner, mapSideCombine))
case None => None
}
}
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
index 73d959331a..89c43ff935 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
@@ -17,10 +17,10 @@
package org.apache.spark.streaming.dstream
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.UnionRDD
+import org.apache.spark.rdd.{PartitionerAwareUnionRDD, RDD, UnionRDD}
import org.apache.spark.storage.StorageLevel
-import org.apache.spark.streaming.{Duration, Interval, Time, DStream}
+import org.apache.spark.streaming._
+import org.apache.spark.streaming.Duration
import scala.reflect.ClassTag
@@ -51,6 +51,14 @@ class WindowedDStream[T: ClassTag](
override def compute(validTime: Time): Option[RDD[T]] = {
val currentWindow = new Interval(validTime - windowDuration + parent.slideDuration, validTime)
- Some(new UnionRDD(ssc.sc, parent.slice(currentWindow)))
+ val rddsInWindow = parent.slice(currentWindow)
+ val windowRDD = if (rddsInWindow.flatMap(_.partitioner).distinct.length == 1) {
+ logDebug("Using partition aware union for windowing at " + validTime)
+ new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow)
+ } else {
+ logDebug("Using normal union for windowing at " + validTime)
+ new UnionRDD(ssc.sc,rddsInWindow)
+ }
+ Some(windowRDD)
}
}
diff --git a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
index c92c34d49b..c39abfc21b 100644
--- a/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
+++ b/streaming/src/test/scala/org/apache/spark/streaming/WindowOperationsSuite.scala
@@ -224,9 +224,7 @@ class WindowOperationsSuite extends TestSuiteBase {
val slideDuration = Seconds(1)
val numBatches = expectedOutput.size * (slideDuration / batchDuration).toInt
val operation = (s: DStream[(String, Int)]) => {
- s.groupByKeyAndWindow(windowDuration, slideDuration)
- .map(x => (x._1, x._2.toSet))
- .persist()
+ s.groupByKeyAndWindow(windowDuration, slideDuration).map(x => (x._1, x._2.toSet))
}
testOperation(input, operation, expectedOutput, numBatches, true)
}